Compare commits

...

2 Commits

22 changed files with 313 additions and 124 deletions

1
.gitignore vendored
View File

@ -35,4 +35,5 @@ suppressions
release
.gofuzz
*-fuzz.zip
*.test

View File

@ -43,7 +43,7 @@ func main() {
log.Fatalf(err)
}
sg, err = core.NewSuperGraph(conf, db)
sg, err := core.NewSuperGraph(conf, db)
if err != nil {
log.Fatalf(err)
}

View File

@ -24,7 +24,7 @@
log.Fatalf(err)
}
sg, err = core.NewSuperGraph(conf, db)
sg, err := core.NewSuperGraph(conf, db)
if err != nil {
log.Fatalf(err)
}
@ -82,6 +82,7 @@ type SuperGraph struct {
conf *Config
db *sql.DB
log *_log.Logger
dbinfo *psql.DBInfo
schema *psql.DBSchema
allowList *allow.List
encKey [32]byte
@ -93,15 +94,26 @@ type SuperGraph struct {
anonExists bool
qc *qcode.Compiler
pc *psql.Compiler
ge *graphql.Engine
}
// NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its
// schemas and relationships
func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) {
return newSuperGraph(conf, db, nil)
}
// newSuperGraph helps with writing tests and benchmarks
func newSuperGraph(conf *Config, db *sql.DB, dbinfo *psql.DBInfo) (*SuperGraph, error) {
if conf == nil {
conf = &Config{}
}
sg := &SuperGraph{
conf: conf,
db: db,
log: _log.New(os.Stdout, "", 0),
conf: conf,
db: db,
dbinfo: dbinfo,
log: _log.New(os.Stdout, "", 0),
}
if err := sg.initConfig(); err != nil {
@ -124,6 +136,10 @@ func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) {
return nil, err
}
if err := sg.initGraphQLEgine(); err != nil {
return nil, err
}
if len(conf.SecretKey) != 0 {
sk := sha256.Sum256([]byte(conf.SecretKey))
conf.SecretKey = ""
@ -163,14 +179,9 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
// use the chirino/graphql library for introspection queries
// disabled when allow list is enforced
if !sg.conf.UseAllowList && res.name == "IntrospectionQuery" {
engine, err := sg.createGraphQLEgine()
if err != nil {
res.Error = err.Error()
return &res, err
}
r := engine.ExecuteOne(&graphql.EngineRequest{Query: query})
r := sg.ge.ExecuteOne(&graphql.EngineRequest{Query: query})
res.Data = r.Data
if r.Error() != nil {
res.Error = r.Error().Error()
}
@ -199,10 +210,8 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
return &ct.res, nil
}
// GraphQLSchema function return the GraphQL schema for the underlying database connected
// to this instance of Super Graph
func (sg *SuperGraph) GraphQLSchema() (string, error) {
engine, err := sg.createGraphQLEgine()
if err != nil {
return "", err
}
return engine.Schema.String(), nil
return sg.ge.Schema.String(), nil
}

62
core/api_test.go Normal file
View File

@ -0,0 +1,62 @@
package core
import (
"context"
"fmt"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/dosco/super-graph/core/internal/psql"
)
func BenchmarkGraphQL(b *testing.B) {
ct := context.WithValue(context.Background(), UserIDKey, "1")
db, _, err := sqlmock.New()
if err != nil {
b.Fatal(err)
}
defer db.Close()
// mock.ExpectQuery(`^SELECT jsonb_build_object`).WithArgs()
sg, err := newSuperGraph(nil, db, psql.GetTestDBInfo())
if err != nil {
b.Fatal(err)
}
query := `
query {
products {
id
name
user {
full_name
phone
email
}
customers {
id
email
}
}
users {
id
name
}
}`
b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err = sg.GraphQL(ct, query, nil)
}
})
fmt.Println(err)
//fmt.Println(mock.ExpectationsWereMet())
}

View File

@ -50,20 +50,26 @@ type scontext struct {
}
func (sg *SuperGraph) initCompilers() error {
di, err := psql.GetDBInfo(sg.db)
if err != nil {
var err error
// If sg.di is not null then it's probably set
// for tests
if sg.dbinfo == nil {
sg.dbinfo, err = psql.GetDBInfo(sg.db)
if err != nil {
return err
}
}
if err = addTables(sg.conf, sg.dbinfo); err != nil {
return err
}
if err = addTables(sg.conf, di); err != nil {
if err = addForeignKeys(sg.conf, sg.dbinfo); err != nil {
return err
}
if err = addForeignKeys(sg.conf, di); err != nil {
return err
}
sg.schema, err = psql.NewDBSchema(di, getDBTableAliases(sg.conf))
sg.schema, err = psql.NewDBSchema(sg.dbinfo, getDBTableAliases(sg.conf))
if err != nil {
return err
}

View File

@ -1,8 +1,6 @@
package core
import (
"errors"
"regexp"
"strings"
"github.com/chirino/graphql"
@ -26,7 +24,7 @@ var typeMap map[string]string = map[string]string{
"boolean": "Boolean",
}
func (sg *SuperGraph) createGraphQLEgine() (*graphql.Engine, error) {
func (sg *SuperGraph) initGraphQLEgine() error {
engine := graphql.New()
engineSchema := engine.Schema
dbSchema := sg.schema
@ -63,15 +61,16 @@ enum OrderDirection {
engineSchema.EntryPoints[schema.Query] = query
engineSchema.EntryPoints[schema.Mutation] = mutation
validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
//validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
scalarExpressionTypesNeeded := map[string]bool{}
tableNames := dbSchema.GetTableNames()
for _, table := range tableNames {
funcs := dbSchema.GetFunctions()
for _, table := range tableNames {
ti, err := dbSchema.GetTable(table)
if err != nil {
return nil, err
return err
}
if !ti.IsSingular {
@ -79,13 +78,13 @@ enum OrderDirection {
}
singularName := ti.Singular
if !validGraphQLIdentifierRegex.MatchString(singularName) {
return nil, errors.New("table name is not a valid GraphQL identifier: " + singularName)
}
// if !validGraphQLIdentifierRegex.MatchString(singularName) {
// return errors.New("table name is not a valid GraphQL identifier: " + singularName)
// }
pluralName := ti.Plural
if !validGraphQLIdentifierRegex.MatchString(pluralName) {
return nil, errors.New("table name is not a valid GraphQL identifier: " + pluralName)
}
// if !validGraphQLIdentifierRegex.MatchString(pluralName) {
// return errors.New("table name is not a valid GraphQL identifier: " + pluralName)
// }
outputType := &schema.Object{
Name: singularName + "Output",
@ -127,9 +126,9 @@ enum OrderDirection {
for _, col := range ti.Columns {
colName := col.Name
if !validGraphQLIdentifierRegex.MatchString(colName) {
return nil, errors.New("column name is not a valid GraphQL identifier: " + colName)
}
// if !validGraphQLIdentifierRegex.MatchString(colName) {
// return errors.New("column name is not a valid GraphQL identifier: " + colName)
// }
colType := gqltype(col)
nullableColType := ""
@ -144,6 +143,16 @@ enum OrderDirection {
Type: colType,
})
for _, f := range funcs {
if col.Type != f.Params[0].Type {
continue
}
outputType.Fields = append(outputType.Fields, &schema.Field{
Name: f.Name + "_" + colName,
Type: colType,
})
}
// If it's a numeric type...
if nullableColType == "Float" || nullableColType == "Int" {
outputType.Fields = append(outputType.Fields, &schema.Field{
@ -464,7 +473,7 @@ enum OrderDirection {
err := engineSchema.ResolveTypes()
if err != nil {
return nil, err
return err
}
engine.Resolver = resolvers.Func(func(request *resolvers.ResolveRequest, next resolvers.Resolution) resolvers.Resolution {
@ -479,5 +488,7 @@ enum OrderDirection {
return nil
})
return engine, nil
sg.ge = engine
return nil
}

View File

@ -167,7 +167,7 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
}
func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInfo, col qcode.Column, columnsRendered int) error {
pl := funcPrefixLen(col.Name)
pl := funcPrefixLen(c.schema.fm, col.Name)
// if pl == 0 {
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
// io.WriteString(c.w, `'`)

View File

@ -10,7 +10,7 @@ import (
var (
qcompileTest, _ = qcode.NewCompiler(qcode.Config{})
schema = getTestSchema()
schema = GetTestSchema()
vars = NewVariables(map[string]string{
"admin_account_id": "5",

View File

@ -1,4 +1,4 @@
package psql
package psql_test
import (
"encoding/json"

View File

@ -1,4 +1,4 @@
package psql
package psql_test
import (
"encoding/json"

View File

@ -1,4 +1,4 @@
package psql
package psql_test
import (
"fmt"
@ -8,6 +8,7 @@ import (
"strings"
"testing"
"github.com/dosco/super-graph/core/internal/psql"
"github.com/dosco/super-graph/core/internal/qcode"
)
@ -19,7 +20,7 @@ const (
var (
qcompile *qcode.Compiler
pcompile *Compiler
pcompile *psql.Compiler
expected map[string][]string
)
@ -133,13 +134,16 @@ func TestMain(m *testing.M) {
log.Fatal(err)
}
schema := getTestSchema()
schema, err := psql.GetTestSchema()
if err != nil {
log.Fatal(err)
}
vars := NewVariables(map[string]string{
vars := psql.NewVariables(map[string]string{
"admin_account_id": "5",
})
pcompile = NewCompiler(Config{
pcompile = psql.NewCompiler(psql.Config{
Schema: schema,
Vars: vars,
})
@ -173,7 +177,7 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
func compileGQLToPSQL(t *testing.T, gql string, vars Variables, role string) {
func compileGQLToPSQL(t *testing.T, gql string, vars psql.Variables, role string) {
generateTestFile := false
if generateTestFile {

View File

@ -543,7 +543,7 @@ func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo, skip
var cn string
for _, col := range sel.Cols {
if n := funcPrefixLen(col.Name); n != 0 {
if n := funcPrefixLen(c.schema.fm, col.Name); n != 0 {
if !sel.Functions {
continue
}
@ -1193,7 +1193,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
io.WriteString(c.w, col.Type)
}
func funcPrefixLen(fn string) int {
func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
switch {
case strings.HasPrefix(fn, "avg_"):
return 4
@ -1218,6 +1218,14 @@ func funcPrefixLen(fn string) int {
case strings.HasPrefix(fn, "var_samp_"):
return 9
}
fnLen := len(fn)
for k := range fm {
kLen := len(k)
if kLen < fnLen && k[0] == fn[0] && strings.HasPrefix(fn, k) && fn[kLen] == '_' {
return kLen + 1
}
}
return 0
}

View File

@ -1,4 +1,4 @@
package psql
package psql_test
import (
"bytes"

View File

@ -11,6 +11,7 @@ type DBSchema struct {
ver int
t map[string]*DBTableInfo
rm map[string]map[string]*DBRel
fm map[string]*DBFunction
}
type DBTableInfo struct {
@ -56,8 +57,10 @@ type DBRel struct {
func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
schema := &DBSchema{
t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel),
ver: info.Version,
t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel),
fm: make(map[string]*DBFunction, len(info.Functions)),
}
for i, t := range info.Tables {
@ -81,6 +84,12 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
}
}
for k, f := range info.Functions {
if len(f.Params) == 1 {
schema.fm[strings.ToLower(f.Name)] = &info.Functions[k]
}
}
return schema, nil
}
@ -439,3 +448,11 @@ func (s *DBSchema) GetRel(child, parent string) (*DBRel, error) {
}
return rel, nil
}
func (s *DBSchema) GetFunctions() []*DBFunction {
var funcs []*DBFunction
for _, f := range s.fm {
funcs = append(funcs, f)
}
return funcs
}

View File

@ -10,10 +10,11 @@ import (
)
type DBInfo struct {
Version int
Tables []DBTable
Columns [][]DBColumn
colmap map[string]map[string]*DBColumn
Version int
Tables []DBTable
Columns [][]DBColumn
Functions []DBFunction
colMap map[string]map[string]*DBColumn
}
func GetDBInfo(db *sql.DB) (*DBInfo, error) {
@ -35,41 +36,56 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
return nil, err
}
di.colmap = make(map[string]map[string]*DBColumn, len(di.Tables))
for i, t := range di.Tables {
for _, t := range di.Tables {
cols, err := GetColumns(db, "public", t.Name)
if err != nil {
return nil, err
}
di.Columns = append(di.Columns, cols)
di.colmap[t.Key] = make(map[string]*DBColumn, len(cols))
}
for n, c := range di.Columns[i] {
di.colmap[t.Key][c.Key] = &di.Columns[i][n]
}
di.colMap = newColMap(di.Tables, di.Columns)
di.Functions, err = GetFunctions(db)
if err != nil {
return nil, err
}
return di, nil
}
func newColMap(tables []DBTable, columns [][]DBColumn) map[string]map[string]*DBColumn {
cm := make(map[string]map[string]*DBColumn, len(tables))
for i, t := range tables {
cols := columns[i]
cm[t.Key] = make(map[string]*DBColumn, len(cols))
for n, c := range cols {
cm[t.Key][c.Key] = &columns[i][n]
}
}
return cm
}
func (di *DBInfo) AddTable(t DBTable, cols []DBColumn) {
t.ID = di.Tables[len(di.Tables)-1].ID
di.Tables = append(di.Tables, t)
di.colmap[t.Key] = make(map[string]*DBColumn, len(cols))
di.colMap[t.Key] = make(map[string]*DBColumn, len(cols))
for i := range cols {
cols[i].ID = int16(i)
c := &cols[i]
di.colmap[t.Key][c.Key] = c
di.colMap[t.Key][c.Key] = c
}
di.Columns = append(di.Columns, cols)
}
func (di *DBInfo) GetColumn(table, column string) (*DBColumn, bool) {
v, ok := di.colmap[strings.ToLower(table)][strings.ToLower(column)]
v, ok := di.colMap[strings.ToLower(table)][strings.ToLower(column)]
return v, ok
}
@ -237,6 +253,64 @@ ORDER BY id;`
return cols, nil
}
type DBFunction struct {
Name string
Params []DBFuncParam
}
type DBFuncParam struct {
ID int
Name string
Type string
}
func GetFunctions(db *sql.DB) ([]DBFunction, error) {
sqlStmt := `
SELECT
routines.routine_name,
parameters.specific_name,
parameters.data_type,
parameters.parameter_name,
parameters.ordinal_position
FROM
information_schema.routines
RIGHT JOIN
information_schema.parameters
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
WHERE
routines.specific_schema = 'public'
ORDER BY
routines.routine_name, parameters.ordinal_position;`
rows, err := db.Query(sqlStmt)
if err != nil {
return nil, fmt.Errorf("Error fetching functions: %s", err)
}
defer rows.Close()
var funcs []DBFunction
fm := make(map[string]int)
for rows.Next() {
var fn, fid string
fp := DBFuncParam{}
err = rows.Scan(&fn, &fid, &fp.Type, &fp.Name, &fp.ID)
if err != nil {
return nil, err
}
if i, ok := fm[fid]; ok {
funcs[i].Params = append(funcs[i].Params, fp)
} else {
funcs = append(funcs, DBFunction{Name: fn, Params: []DBFuncParam{fp}})
fm[fid] = len(funcs) - 1
}
}
return funcs, nil
}
// func GetValType(type string) qcode.ValType {
// switch {
// case "bigint", "integer", "smallint", "numeric", "bigserial":

View File

@ -1,11 +1,10 @@
package psql
import (
"log"
"strings"
)
func getTestSchema() *DBSchema {
func GetTestDBInfo() *DBInfo {
tables := []DBTable{
DBTable{Name: "customers", Type: "table"},
DBTable{Name: "users", Type: "table"},
@ -74,36 +73,19 @@ func getTestSchema() *DBSchema {
}
}
schema := &DBSchema{
ver: 110000,
t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel),
return &DBInfo{
Version: 110000,
Tables: tables,
Columns: columns,
Functions: []DBFunction{},
colMap: newColMap(tables, columns),
}
}
func GetTestSchema() (*DBSchema, error) {
aliases := map[string][]string{
"users": []string{"mes"},
}
for i, t := range tables {
err := schema.addTable(t, columns[i], aliases)
if err != nil {
log.Fatal(err)
}
}
for i, t := range tables {
err := schema.firstDegreeRels(t, columns[i])
if err != nil {
log.Fatal(err)
}
}
for i, t := range tables {
err := schema.secondDegreeRels(t, columns[i])
if err != nil {
log.Fatal(err)
}
}
return schema
return NewDBSchema(GetTestDBInfo(), aliases)
}

View File

@ -1,4 +1,4 @@
package psql
package psql_test
import (
"encoding/json"

View File

@ -730,6 +730,32 @@ query {
}
```
### Custom Functions
Any function defined in the database like the below `add_five` that adds 5 to any number given to it can be used
within your query. The one limitation is that it should be a function that only accepts a single argument. The function is used within you're GraphQL in similar way to how aggregrations are used above. Example below
```grahql
query {
thread(id: 5) {
id
total_votes
add_five_total_votes
}
}
```
Postgres user-defined function `add_five`
```
CREATE OR REPLACE FUNCTION add_five(a integer) RETURNS integer AS $$
BEGIN
RETURN a + 5;
END;
$$ LANGUAGE plpgsql;
```
In GraphQL mutations is the operation type for when you need to modify data. Super Graph supports the `insert`, `update`, `upsert` and `delete`. You can also do complex nested inserts and updates.
When using mutations the data must be passed as variables since Super Graphs compiles the query into an prepared statement in the database for maximum speed. Prepared statements are are functions in your code when called they accept arguments and your variables are passed in as those arguments.

1
go.mod
View File

@ -1,6 +1,7 @@
module github.com/dosco/super-graph
require (
github.com/DATA-DOG/go-sqlmock v1.4.1
github.com/GeertJohan/go.rice v1.0.0
github.com/NYTimes/gziphandler v1.1.1
github.com/adjust/gorails v0.0.0-20171013043634-2786ed0c03d3

2
go.sum
View File

@ -1,6 +1,8 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.4.1 h1:ThlnYciV1iM/V0OSF/dtkqWb6xo5qITT1TJBG1MRDJM=
github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/GeertJohan/go.incremental v1.0.0 h1:7AH+pY1XUgQE4Y1HcXYaMqAI0m9yrFqo/jt0CW30vsg=
github.com/GeertJohan/go.incremental v1.0.0/go.mod h1:6fAjUhbVuX1KcMD3c8TEgVUqmo4seqhv0i0kdATSkM0=
github.com/GeertJohan/go.rice v1.0.0 h1:KkI6O9uMaQU3VEKaj01ulavtF7o1fWT7+pk/4voiMLQ=

View File

@ -23,7 +23,7 @@ func newAction(a *Action) (http.Handler, error) {
httpFn := func(w http.ResponseWriter, r *http.Request) {
if err := fn(w, r); err != nil {
renderErr(w, err, nil)
renderErr(w, err)
}
}

View File

@ -7,7 +7,6 @@ import (
"io/ioutil"
"net/http"
"github.com/dosco/super-graph/core"
"github.com/dosco/super-graph/internal/serv/internal/auth"
"github.com/rs/cors"
"go.uber.org/zap"
@ -29,7 +28,7 @@ type gqlReq struct {
}
type errorResp struct {
Error error `json:"error"`
Error string `json:"error"`
}
func apiV1Handler() http.Handler {
@ -55,13 +54,13 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
//nolint: errcheck
if conf.AuthFailBlock && !auth.IsAuth(ct) {
renderErr(w, errUnauthorized, nil)
renderErr(w, errUnauthorized)
return
}
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
if err != nil {
renderErr(w, err, nil)
renderErr(w, err)
return
}
defer r.Body.Close()
@ -70,7 +69,7 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
err = json.Unmarshal(b, &req)
if err != nil {
renderErr(w, err, nil)
renderErr(w, err)
return
}
@ -86,12 +85,11 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
}
if err != nil {
renderErr(w, err, res)
return
renderErr(w, err)
} else {
json.NewEncoder(w).Encode(res)
}
json.NewEncoder(w).Encode(res)
if doLog && logLevel >= LogLevelInfo {
zlog.Info("success",
zap.String("op", res.Operation()),
@ -102,22 +100,10 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
}
//nolint: errcheck
func renderErr(w http.ResponseWriter, err error, res *core.Result) {
func renderErr(w http.ResponseWriter, err error) {
if err == errUnauthorized {
w.WriteHeader(http.StatusUnauthorized)
}
json.NewEncoder(w).Encode(&errorResp{err})
if logLevel >= LogLevelError {
if res != nil {
zlog.Error(err.Error(),
zap.String("op", res.Operation()),
zap.String("name", res.QueryName()),
zap.String("role", res.Role()),
)
} else {
zlog.Error(err.Error())
}
}
json.NewEncoder(w).Encode(errorResp{err.Error()})
}