diff --git a/.gitignore b/.gitignore index a286ab1..a2f2017 100644 --- a/.gitignore +++ b/.gitignore @@ -35,4 +35,5 @@ suppressions release .gofuzz *-fuzz.zip +*.test diff --git a/README.md b/README.md index 0e28ecb..a7ac43d 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ func main() { log.Fatalf(err) } - sg, err = core.NewSuperGraph(conf, db) + sg, err := core.NewSuperGraph(conf, db) if err != nil { log.Fatalf(err) } diff --git a/core/api.go b/core/api.go index e33d1d7..6e2a727 100644 --- a/core/api.go +++ b/core/api.go @@ -24,7 +24,7 @@ log.Fatalf(err) } - sg, err = core.NewSuperGraph(conf, db) + sg, err := core.NewSuperGraph(conf, db) if err != nil { log.Fatalf(err) } @@ -82,6 +82,7 @@ type SuperGraph struct { conf *Config db *sql.DB log *_log.Logger + dbinfo *psql.DBInfo schema *psql.DBSchema allowList *allow.List encKey [32]byte @@ -99,10 +100,20 @@ type SuperGraph struct { // NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its // schemas and relationships func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) { + return newSuperGraph(conf, db, nil) +} + +// newSuperGraph helps with writing tests and benchmarks +func newSuperGraph(conf *Config, db *sql.DB, dbinfo *psql.DBInfo) (*SuperGraph, error) { + if conf == nil { + conf = &Config{} + } + sg := &SuperGraph{ - conf: conf, - db: db, - log: _log.New(os.Stdout, "", 0), + conf: conf, + db: db, + dbinfo: dbinfo, + log: _log.New(os.Stdout, "", 0), } if err := sg.initConfig(); err != nil { @@ -199,6 +210,8 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess return &ct.res, nil } +// GraphQLSchema function return the GraphQL schema for the underlying database connected +// to this instance of Super Graph func (sg *SuperGraph) GraphQLSchema() (string, error) { return sg.ge.Schema.String(), nil } diff --git a/core/api_test.go b/core/api_test.go new file mode 100644 index 0000000..2cf793a --- /dev/null +++ b/core/api_test.go @@ -0,0 +1,62 @@ +package core + +import ( + "context" + "fmt" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/dosco/super-graph/core/internal/psql" +) + +func BenchmarkGraphQL(b *testing.B) { + ct := context.WithValue(context.Background(), UserIDKey, "1") + + db, _, err := sqlmock.New() + if err != nil { + b.Fatal(err) + } + defer db.Close() + + // mock.ExpectQuery(`^SELECT jsonb_build_object`).WithArgs() + + sg, err := newSuperGraph(nil, db, psql.GetTestDBInfo()) + if err != nil { + b.Fatal(err) + } + + query := ` + query { + products { + id + name + user { + full_name + phone + email + } + customers { + id + email + } + } + users { + id + name + } + }` + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err = sg.GraphQL(ct, query, nil) + } + }) + + fmt.Println(err) + + //fmt.Println(mock.ExpectationsWereMet()) + +} diff --git a/core/core.go b/core/core.go index 9d544b2..7c945c8 100644 --- a/core/core.go +++ b/core/core.go @@ -50,20 +50,26 @@ type scontext struct { } func (sg *SuperGraph) initCompilers() error { - di, err := psql.GetDBInfo(sg.db) - if err != nil { + var err error + + // If sg.di is not null then it's probably set + // for tests + if sg.dbinfo == nil { + sg.dbinfo, err = psql.GetDBInfo(sg.db) + if err != nil { + return err + } + } + + if err = addTables(sg.conf, sg.dbinfo); err != nil { return err } - if err = addTables(sg.conf, di); err != nil { + if err = addForeignKeys(sg.conf, sg.dbinfo); err != nil { return err } - if err = addForeignKeys(sg.conf, di); err != nil { - return err - } - - sg.schema, err = psql.NewDBSchema(di, getDBTableAliases(sg.conf)) + sg.schema, err = psql.NewDBSchema(sg.dbinfo, getDBTableAliases(sg.conf)) if err != nil { return err } diff --git a/core/internal/psql/fuzz.go b/core/internal/psql/fuzz.go index a7e41d9..12d63c5 100644 --- a/core/internal/psql/fuzz.go +++ b/core/internal/psql/fuzz.go @@ -10,7 +10,7 @@ import ( var ( qcompileTest, _ = qcode.NewCompiler(qcode.Config{}) - schema = getTestSchema() + schema = GetTestSchema() vars = NewVariables(map[string]string{ "admin_account_id": "5", diff --git a/core/internal/psql/insert_test.go b/core/internal/psql/insert_test.go index 95c3579..76f03e4 100644 --- a/core/internal/psql/insert_test.go +++ b/core/internal/psql/insert_test.go @@ -1,4 +1,4 @@ -package psql +package psql_test import ( "encoding/json" diff --git a/core/internal/psql/mutate_test.go b/core/internal/psql/mutate_test.go index c0a7fc2..ffd99d5 100644 --- a/core/internal/psql/mutate_test.go +++ b/core/internal/psql/mutate_test.go @@ -1,4 +1,4 @@ -package psql +package psql_test import ( "encoding/json" diff --git a/core/internal/psql/psql_test.go b/core/internal/psql/psql_test.go index a536644..bad3f1f 100644 --- a/core/internal/psql/psql_test.go +++ b/core/internal/psql/psql_test.go @@ -1,4 +1,4 @@ -package psql +package psql_test import ( "fmt" @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/dosco/super-graph/core/internal/psql" "github.com/dosco/super-graph/core/internal/qcode" ) @@ -19,7 +20,7 @@ const ( var ( qcompile *qcode.Compiler - pcompile *Compiler + pcompile *psql.Compiler expected map[string][]string ) @@ -133,13 +134,16 @@ func TestMain(m *testing.M) { log.Fatal(err) } - schema := getTestSchema() + schema, err := psql.GetTestSchema() + if err != nil { + log.Fatal(err) + } - vars := NewVariables(map[string]string{ + vars := psql.NewVariables(map[string]string{ "admin_account_id": "5", }) - pcompile = NewCompiler(Config{ + pcompile = psql.NewCompiler(psql.Config{ Schema: schema, Vars: vars, }) @@ -173,7 +177,7 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func compileGQLToPSQL(t *testing.T, gql string, vars Variables, role string) { +func compileGQLToPSQL(t *testing.T, gql string, vars psql.Variables, role string) { generateTestFile := false if generateTestFile { diff --git a/core/internal/psql/query_test.go b/core/internal/psql/query_test.go index be7a653..5e88cf9 100644 --- a/core/internal/psql/query_test.go +++ b/core/internal/psql/query_test.go @@ -1,4 +1,4 @@ -package psql +package psql_test import ( "bytes" diff --git a/core/internal/psql/schema.go b/core/internal/psql/schema.go index d5a88e6..5ff57cb 100644 --- a/core/internal/psql/schema.go +++ b/core/internal/psql/schema.go @@ -57,9 +57,10 @@ type DBRel struct { func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) { schema := &DBSchema{ - t: make(map[string]*DBTableInfo), - rm: make(map[string]map[string]*DBRel), - fm: make(map[string]*DBFunction, len(info.Functions)), + ver: info.Version, + t: make(map[string]*DBTableInfo), + rm: make(map[string]map[string]*DBRel), + fm: make(map[string]*DBFunction, len(info.Functions)), } for i, t := range info.Tables { diff --git a/core/internal/psql/tables.go b/core/internal/psql/tables.go index f7aa573..0464138 100644 --- a/core/internal/psql/tables.go +++ b/core/internal/psql/tables.go @@ -14,7 +14,7 @@ type DBInfo struct { Tables []DBTable Columns [][]DBColumn Functions []DBFunction - colmap map[string]map[string]*DBColumn + colMap map[string]map[string]*DBColumn } func GetDBInfo(db *sql.DB) (*DBInfo, error) { @@ -36,22 +36,17 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) { return nil, err } - di.colmap = make(map[string]map[string]*DBColumn, len(di.Tables)) - - for i, t := range di.Tables { + for _, t := range di.Tables { cols, err := GetColumns(db, "public", t.Name) if err != nil { return nil, err } di.Columns = append(di.Columns, cols) - di.colmap[t.Key] = make(map[string]*DBColumn, len(cols)) - - for n, c := range di.Columns[i] { - di.colmap[t.Key][c.Key] = &di.Columns[i][n] - } } + di.colMap = newColMap(di.Tables, di.Columns) + di.Functions, err = GetFunctions(db) if err != nil { return nil, err @@ -60,22 +55,37 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) { return di, nil } +func newColMap(tables []DBTable, columns [][]DBColumn) map[string]map[string]*DBColumn { + cm := make(map[string]map[string]*DBColumn, len(tables)) + + for i, t := range tables { + cols := columns[i] + cm[t.Key] = make(map[string]*DBColumn, len(cols)) + + for n, c := range cols { + cm[t.Key][c.Key] = &columns[i][n] + } + } + + return cm +} + func (di *DBInfo) AddTable(t DBTable, cols []DBColumn) { t.ID = di.Tables[len(di.Tables)-1].ID di.Tables = append(di.Tables, t) - di.colmap[t.Key] = make(map[string]*DBColumn, len(cols)) + di.colMap[t.Key] = make(map[string]*DBColumn, len(cols)) for i := range cols { cols[i].ID = int16(i) c := &cols[i] - di.colmap[t.Key][c.Key] = c + di.colMap[t.Key][c.Key] = c } di.Columns = append(di.Columns, cols) } func (di *DBInfo) GetColumn(table, column string) (*DBColumn, bool) { - v, ok := di.colmap[strings.ToLower(table)][strings.ToLower(column)] + v, ok := di.colMap[strings.ToLower(table)][strings.ToLower(column)] return v, ok } diff --git a/core/internal/psql/test_schema.go b/core/internal/psql/test_dbinfo.go similarity index 91% rename from core/internal/psql/test_schema.go rename to core/internal/psql/test_dbinfo.go index 32da488..6379d32 100644 --- a/core/internal/psql/test_schema.go +++ b/core/internal/psql/test_dbinfo.go @@ -1,11 +1,10 @@ package psql import ( - "log" "strings" ) -func getTestSchema() *DBSchema { +func GetTestDBInfo() *DBInfo { tables := []DBTable{ DBTable{Name: "customers", Type: "table"}, DBTable{Name: "users", Type: "table"}, @@ -74,36 +73,19 @@ func getTestSchema() *DBSchema { } } - schema := &DBSchema{ - ver: 110000, - t: make(map[string]*DBTableInfo), - rm: make(map[string]map[string]*DBRel), + return &DBInfo{ + Version: 110000, + Tables: tables, + Columns: columns, + Functions: []DBFunction{}, + colMap: newColMap(tables, columns), } +} +func GetTestSchema() (*DBSchema, error) { aliases := map[string][]string{ "users": []string{"mes"}, } - for i, t := range tables { - err := schema.addTable(t, columns[i], aliases) - if err != nil { - log.Fatal(err) - } - } - - for i, t := range tables { - err := schema.firstDegreeRels(t, columns[i]) - if err != nil { - log.Fatal(err) - } - } - - for i, t := range tables { - err := schema.secondDegreeRels(t, columns[i]) - if err != nil { - log.Fatal(err) - } - } - - return schema + return NewDBSchema(GetTestDBInfo(), aliases) } diff --git a/core/internal/psql/update_test.go b/core/internal/psql/update_test.go index 87e1fe0..f410a18 100644 --- a/core/internal/psql/update_test.go +++ b/core/internal/psql/update_test.go @@ -1,4 +1,4 @@ -package psql +package psql_test import ( "encoding/json" diff --git a/go.mod b/go.mod index 7e70b30..de4cad1 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/dosco/super-graph require ( + github.com/DATA-DOG/go-sqlmock v1.4.1 github.com/GeertJohan/go.rice v1.0.0 github.com/NYTimes/gziphandler v1.1.1 github.com/adjust/gorails v0.0.0-20171013043634-2786ed0c03d3 diff --git a/go.sum b/go.sum index 7f74369..323b23c 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.4.1 h1:ThlnYciV1iM/V0OSF/dtkqWb6xo5qITT1TJBG1MRDJM= +github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/GeertJohan/go.incremental v1.0.0 h1:7AH+pY1XUgQE4Y1HcXYaMqAI0m9yrFqo/jt0CW30vsg= github.com/GeertJohan/go.incremental v1.0.0/go.mod h1:6fAjUhbVuX1KcMD3c8TEgVUqmo4seqhv0i0kdATSkM0= github.com/GeertJohan/go.rice v1.0.0 h1:KkI6O9uMaQU3VEKaj01ulavtF7o1fWT7+pk/4voiMLQ=