Compare commits

...

2 Commits

22 changed files with 313 additions and 124 deletions

1
.gitignore vendored
View File

@@ -35,4 +35,5 @@ suppressions
 release
 .gofuzz
 *-fuzz.zip
+*.test

View File

@@ -43,7 +43,7 @@ func main() {
 		log.Fatalf(err)
 	}
 
-	sg, err = core.NewSuperGraph(conf, db)
+	sg, err := core.NewSuperGraph(conf, db)
 	if err != nil {
 		log.Fatalf(err)
 	}

View File

@@ -24,7 +24,7 @@
 		log.Fatalf(err)
 	}
 
-	sg, err = core.NewSuperGraph(conf, db)
+	sg, err := core.NewSuperGraph(conf, db)
 	if err != nil {
 		log.Fatalf(err)
 	}
@@ -82,6 +82,7 @@ type SuperGraph struct {
 	conf      *Config
 	db        *sql.DB
 	log       *_log.Logger
+	dbinfo    *psql.DBInfo
 	schema    *psql.DBSchema
 	allowList *allow.List
 	encKey    [32]byte
@@ -93,14 +94,25 @@ type SuperGraph struct {
 	anonExists bool
 	qc         *qcode.Compiler
 	pc         *psql.Compiler
+	ge         *graphql.Engine
 }
 
 // NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its
 // schemas and relationships
 func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) {
+	return newSuperGraph(conf, db, nil)
+}
+
+// newSuperGraph helps with writing tests and benchmarks
+func newSuperGraph(conf *Config, db *sql.DB, dbinfo *psql.DBInfo) (*SuperGraph, error) {
+	if conf == nil {
+		conf = &Config{}
+	}
+
 	sg := &SuperGraph{
 		conf:   conf,
 		db:     db,
+		dbinfo: dbinfo,
 		log:    _log.New(os.Stdout, "", 0),
 	}
@@ -124,6 +136,10 @@ func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) {
 		return nil, err
 	}
 
+	if err := sg.initGraphQLEgine(); err != nil {
+		return nil, err
+	}
+
 	if len(conf.SecretKey) != 0 {
 		sk := sha256.Sum256([]byte(conf.SecretKey))
 		conf.SecretKey = ""
@@ -163,14 +179,9 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
 	// use the chirino/graphql library for introspection queries
 	// disabled when allow list is enforced
 	if !sg.conf.UseAllowList && res.name == "IntrospectionQuery" {
-		engine, err := sg.createGraphQLEgine()
-		if err != nil {
-			res.Error = err.Error()
-			return &res, err
-		}
-
-		r := engine.ExecuteOne(&graphql.EngineRequest{Query: query})
+		r := sg.ge.ExecuteOne(&graphql.EngineRequest{Query: query})
 		res.Data = r.Data
+
 		if r.Error() != nil {
 			res.Error = r.Error().Error()
 		}
@@ -199,10 +210,8 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
 	return &ct.res, nil
 }
 
+// GraphQLSchema function return the GraphQL schema for the underlying database connected
+// to this instance of Super Graph
 func (sg *SuperGraph) GraphQLSchema() (string, error) {
-	engine, err := sg.createGraphQLEgine()
-	if err != nil {
-		return "", err
-	}
-	return engine.Schema.String(), nil
+	return sg.ge.Schema.String(), nil
 }

62
core/api_test.go Normal file
View File

@@ -0,0 +1,62 @@
+package core
+
+import (
+	"context"
+	"fmt"
+	"testing"
+
+	"github.com/DATA-DOG/go-sqlmock"
+	"github.com/dosco/super-graph/core/internal/psql"
+)
+
+func BenchmarkGraphQL(b *testing.B) {
+	ct := context.WithValue(context.Background(), UserIDKey, "1")
+
+	db, _, err := sqlmock.New()
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer db.Close()
+
+	// mock.ExpectQuery(`^SELECT jsonb_build_object`).WithArgs()
+
+	sg, err := newSuperGraph(nil, db, psql.GetTestDBInfo())
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	query := `
+	query {
+		products {
+			id
+			name
+			user {
+				full_name
+				phone
+				email
+			}
+			customers {
+				id
+				email
+			}
+		}
+		users {
+			id
+			name
+		}
+	}`
+
+	b.ResetTimer()
+	b.ReportAllocs()
+
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			_, err = sg.GraphQL(ct, query, nil)
+		}
+	})
+
+	fmt.Println(err)
+
+	//fmt.Println(mock.ExpectationsWereMet())
+}
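The benchmark above builds a `SuperGraph` through the new `newSuperGraph(nil, db, psql.GetTestDBInfo())` path against a `go-sqlmock` connection, so no real database is needed (the mock expectation is commented out, so the returned error is simply printed). It can be run with the standard Go tooling, for example:

```
go test -bench=GraphQL -benchmem ./core
```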

View File

@@ -50,20 +50,26 @@ type scontext struct {
 }
 
 func (sg *SuperGraph) initCompilers() error {
-	di, err := psql.GetDBInfo(sg.db)
-	if err != nil {
-		return err
+	var err error
+
+	// If sg.dbinfo is not nil then it's probably set
+	// for tests
+	if sg.dbinfo == nil {
+		sg.dbinfo, err = psql.GetDBInfo(sg.db)
+		if err != nil {
+			return err
+		}
 	}
 
-	if err = addTables(sg.conf, di); err != nil {
+	if err = addTables(sg.conf, sg.dbinfo); err != nil {
 		return err
 	}
 
-	if err = addForeignKeys(sg.conf, di); err != nil {
+	if err = addForeignKeys(sg.conf, sg.dbinfo); err != nil {
 		return err
 	}
 
-	sg.schema, err = psql.NewDBSchema(di, getDBTableAliases(sg.conf))
+	sg.schema, err = psql.NewDBSchema(sg.dbinfo, getDBTableAliases(sg.conf))
 	if err != nil {
 		return err
 	}

View File

@@ -1,8 +1,6 @@
 package core
 
 import (
-	"errors"
-	"regexp"
 	"strings"
 
 	"github.com/chirino/graphql"
@@ -26,7 +24,7 @@ var typeMap map[string]string = map[string]string{
 	"boolean": "Boolean",
 }
 
-func (sg *SuperGraph) createGraphQLEgine() (*graphql.Engine, error) {
+func (sg *SuperGraph) initGraphQLEgine() error {
 	engine := graphql.New()
 	engineSchema := engine.Schema
 	dbSchema := sg.schema
@@ -63,15 +61,16 @@ enum OrderDirection {
 	engineSchema.EntryPoints[schema.Query] = query
 	engineSchema.EntryPoints[schema.Mutation] = mutation
 
-	validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
+	//validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
 
 	scalarExpressionTypesNeeded := map[string]bool{}
 	tableNames := dbSchema.GetTableNames()
-	for _, table := range tableNames {
+	funcs := dbSchema.GetFunctions()
 
+	for _, table := range tableNames {
 		ti, err := dbSchema.GetTable(table)
 		if err != nil {
-			return nil, err
+			return err
 		}
 
 		if !ti.IsSingular {
@@ -79,13 +78,13 @@ enum OrderDirection {
 		}
 
 		singularName := ti.Singular
-		if !validGraphQLIdentifierRegex.MatchString(singularName) {
-			return nil, errors.New("table name is not a valid GraphQL identifier: " + singularName)
-		}
+		// if !validGraphQLIdentifierRegex.MatchString(singularName) {
+		// 	return errors.New("table name is not a valid GraphQL identifier: " + singularName)
+		// }
 
 		pluralName := ti.Plural
-		if !validGraphQLIdentifierRegex.MatchString(pluralName) {
-			return nil, errors.New("table name is not a valid GraphQL identifier: " + pluralName)
-		}
+		// if !validGraphQLIdentifierRegex.MatchString(pluralName) {
+		// 	return errors.New("table name is not a valid GraphQL identifier: " + pluralName)
+		// }
 
 		outputType := &schema.Object{
 			Name: singularName + "Output",
@@ -127,9 +126,9 @@ enum OrderDirection {
 		for _, col := range ti.Columns {
 			colName := col.Name
-			if !validGraphQLIdentifierRegex.MatchString(colName) {
-				return nil, errors.New("column name is not a valid GraphQL identifier: " + colName)
-			}
+			// if !validGraphQLIdentifierRegex.MatchString(colName) {
+			// 	return errors.New("column name is not a valid GraphQL identifier: " + colName)
+			// }
 
 			colType := gqltype(col)
 			nullableColType := ""
@@ -144,6 +143,16 @@ enum OrderDirection {
 				Type: colType,
 			})
 
+			for _, f := range funcs {
+				if col.Type != f.Params[0].Type {
+					continue
+				}
+				outputType.Fields = append(outputType.Fields, &schema.Field{
+					Name: f.Name + "_" + colName,
+					Type: colType,
+				})
+			}
+
 			// If it's a numeric type...
 			if nullableColType == "Float" || nullableColType == "Int" {
 				outputType.Fields = append(outputType.Fields, &schema.Field{
@@ -464,7 +473,7 @@ enum OrderDirection {
 	err := engineSchema.ResolveTypes()
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	engine.Resolver = resolvers.Func(func(request *resolvers.ResolveRequest, next resolvers.Resolution) resolvers.Resolution {
@@ -479,5 +488,7 @@ enum OrderDirection {
 		return nil
 	})
 
-	return engine, nil
+	sg.ge = engine
+
+	return nil
 }
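With the `funcs` loop added above, the introspection schema gains a `<function>_<column>` field on a table's output type whenever a single-argument database function's parameter type matches the column's type. For the `add_five(a integer)` function documented further down in this diff, an integer column such as `total_votes` would therefore be queryable roughly like this (illustrative, not part of the diff):

```graphql
query {
  thread(id: 5) {
    total_votes
    add_five_total_votes
  }
}
```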

View File

@@ -167,7 +167,7 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
 }
 
 func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInfo, col qcode.Column, columnsRendered int) error {
-	pl := funcPrefixLen(col.Name)
+	pl := funcPrefixLen(c.schema.fm, col.Name)
 	// if pl == 0 {
 	// 	//fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
 	// 	io.WriteString(c.w, `'`)

View File

@@ -10,7 +10,7 @@ import (
 
 var (
 	qcompileTest, _ = qcode.NewCompiler(qcode.Config{})
-	schema          = getTestSchema()
+	schema          = GetTestSchema()
 	vars            = NewVariables(map[string]string{
 		"admin_account_id": "5",
View File

@@ -1,4 +1,4 @@
-package psql
+package psql_test
 
 import (
 	"encoding/json"

View File

@@ -1,4 +1,4 @@
-package psql
+package psql_test
 
 import (
 	"encoding/json"

View File

@@ -1,4 +1,4 @@
-package psql
+package psql_test
 
 import (
 	"fmt"
@@ -8,6 +8,7 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/dosco/super-graph/core/internal/psql"
 	"github.com/dosco/super-graph/core/internal/qcode"
 )
@@ -19,7 +20,7 @@ const (
 
 var (
 	qcompile *qcode.Compiler
-	pcompile *Compiler
+	pcompile *psql.Compiler
 	expected map[string][]string
 )
@@ -133,13 +134,16 @@ func TestMain(m *testing.M) {
 		log.Fatal(err)
 	}
 
-	schema := getTestSchema()
+	schema, err := psql.GetTestSchema()
+	if err != nil {
+		log.Fatal(err)
+	}
 
-	vars := NewVariables(map[string]string{
+	vars := psql.NewVariables(map[string]string{
 		"admin_account_id": "5",
 	})
 
-	pcompile = NewCompiler(Config{
+	pcompile = psql.NewCompiler(psql.Config{
 		Schema: schema,
 		Vars:   vars,
 	})
@@ -173,7 +177,7 @@ func TestMain(m *testing.M) {
 	os.Exit(m.Run())
 }
 
-func compileGQLToPSQL(t *testing.T, gql string, vars Variables, role string) {
+func compileGQLToPSQL(t *testing.T, gql string, vars psql.Variables, role string) {
 	generateTestFile := false
 
 	if generateTestFile {

View File

@@ -543,7 +543,7 @@ func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo, skip
 	var cn string
 
 	for _, col := range sel.Cols {
-		if n := funcPrefixLen(col.Name); n != 0 {
+		if n := funcPrefixLen(c.schema.fm, col.Name); n != 0 {
 			if !sel.Functions {
 				continue
 			}
@@ -1193,7 +1193,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
 	io.WriteString(c.w, col.Type)
 }
 
-func funcPrefixLen(fn string) int {
+func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
 	switch {
 	case strings.HasPrefix(fn, "avg_"):
 		return 4
@@ -1218,6 +1218,14 @@ func funcPrefixLen(fn string) int {
 	case strings.HasPrefix(fn, "var_samp_"):
 		return 9
 	}
+	fnLen := len(fn)
+
+	for k := range fm {
+		kLen := len(k)
+		if kLen < fnLen && k[0] == fn[0] && strings.HasPrefix(fn, k) && fn[kLen] == '_' {
+			return kLen + 1
+		}
+	}
 	return 0
 }
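Beyond the hard-coded aggregate prefixes, the new fallback scans the schema's function map for a key that is a proper prefix of the column name and is followed by an underscore. A simplified, self-contained sketch of that rule (using a plain `map[string]bool` in place of `map[string]*DBFunction`, and dropping the first-byte short-circuit):

```go
package main

import (
	"fmt"
	"strings"
)

// prefixLen mimics the new fallback in funcPrefixLen: if a known function
// name prefixes fn and is followed by '_', return the length of that prefix
// including the underscore; otherwise return 0.
func prefixLen(fm map[string]bool, fn string) int {
	for k := range fm {
		if len(k) < len(fn) && strings.HasPrefix(fn, k) && fn[len(k)] == '_' {
			return len(k) + 1
		}
	}
	return 0
}

func main() {
	fm := map[string]bool{"add_five": true}
	fmt.Println(prefixLen(fm, "add_five_total_votes")) // 9: "total_votes" is the column
	fmt.Println(prefixLen(fm, "total_votes"))          // 0: no matching function
}
```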

View File

@@ -1,4 +1,4 @@
-package psql
+package psql_test
 
 import (
 	"bytes"

View File

@@ -11,6 +11,7 @@ type DBSchema struct {
 	ver int
 	t   map[string]*DBTableInfo
 	rm  map[string]map[string]*DBRel
+	fm  map[string]*DBFunction
 }
 
 type DBTableInfo struct {
@@ -56,8 +57,10 @@ type DBRel struct {
 func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
 	schema := &DBSchema{
+		ver: info.Version,
 		t:   make(map[string]*DBTableInfo),
 		rm:  make(map[string]map[string]*DBRel),
+		fm:  make(map[string]*DBFunction, len(info.Functions)),
 	}
 
 	for i, t := range info.Tables {
@@ -81,6 +84,12 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
 		}
 	}
 
+	for k, f := range info.Functions {
+		if len(f.Params) == 1 {
+			schema.fm[strings.ToLower(f.Name)] = &info.Functions[k]
+		}
+	}
+
 	return schema, nil
 }
@@ -439,3 +448,11 @@ func (s *DBSchema) GetRel(child, parent string) (*DBRel, error) {
 	}
 	return rel, nil
 }
+
+func (s *DBSchema) GetFunctions() []*DBFunction {
+	var funcs []*DBFunction
+	for _, f := range s.fm {
+		funcs = append(funcs, f)
+	}
+	return funcs
+}

View File

@@ -13,7 +13,8 @@ type DBInfo struct {
 	Version   int
 	Tables    []DBTable
 	Columns   [][]DBColumn
-	colmap    map[string]map[string]*DBColumn
+	Functions []DBFunction
+	colMap    map[string]map[string]*DBColumn
 }
 
 func GetDBInfo(db *sql.DB) (*DBInfo, error) {
@@ -35,41 +36,56 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
 		return nil, err
 	}
 
-	di.colmap = make(map[string]map[string]*DBColumn, len(di.Tables))
-
-	for i, t := range di.Tables {
+	for _, t := range di.Tables {
 		cols, err := GetColumns(db, "public", t.Name)
 		if err != nil {
 			return nil, err
 		}
 
 		di.Columns = append(di.Columns, cols)
-		di.colmap[t.Key] = make(map[string]*DBColumn, len(cols))
-
-		for n, c := range di.Columns[i] {
-			di.colmap[t.Key][c.Key] = &di.Columns[i][n]
-		}
+	}
+
+	di.colMap = newColMap(di.Tables, di.Columns)
+
+	di.Functions, err = GetFunctions(db)
+	if err != nil {
+		return nil, err
 	}
 
 	return di, nil
 }
 
+func newColMap(tables []DBTable, columns [][]DBColumn) map[string]map[string]*DBColumn {
+	cm := make(map[string]map[string]*DBColumn, len(tables))
+
+	for i, t := range tables {
+		cols := columns[i]
+		cm[t.Key] = make(map[string]*DBColumn, len(cols))
+
+		for n, c := range cols {
+			cm[t.Key][c.Key] = &columns[i][n]
+		}
+	}
+
+	return cm
+}
+
 func (di *DBInfo) AddTable(t DBTable, cols []DBColumn) {
 	t.ID = di.Tables[len(di.Tables)-1].ID
 
 	di.Tables = append(di.Tables, t)
-	di.colmap[t.Key] = make(map[string]*DBColumn, len(cols))
+	di.colMap[t.Key] = make(map[string]*DBColumn, len(cols))
 
 	for i := range cols {
 		cols[i].ID = int16(i)
 		c := &cols[i]
-		di.colmap[t.Key][c.Key] = c
+		di.colMap[t.Key][c.Key] = c
 	}
 
 	di.Columns = append(di.Columns, cols)
 }
 
 func (di *DBInfo) GetColumn(table, column string) (*DBColumn, bool) {
-	v, ok := di.colmap[strings.ToLower(table)][strings.ToLower(column)]
+	v, ok := di.colMap[strings.ToLower(table)][strings.ToLower(column)]
 	return v, ok
 }
@@ -237,6 +253,64 @@ ORDER BY id;`
 	return cols, nil
 }
 
+type DBFunction struct {
+	Name   string
+	Params []DBFuncParam
+}
+
+type DBFuncParam struct {
+	ID   int
+	Name string
+	Type string
+}
+
+func GetFunctions(db *sql.DB) ([]DBFunction, error) {
+	sqlStmt := `
+SELECT
+	routines.routine_name,
+	parameters.specific_name,
+	parameters.data_type,
+	parameters.parameter_name,
+	parameters.ordinal_position
+FROM
+	information_schema.routines
+RIGHT JOIN
+	information_schema.parameters
+	ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
+WHERE
+	routines.specific_schema = 'public'
+ORDER BY
+	routines.routine_name, parameters.ordinal_position;`
+
+	rows, err := db.Query(sqlStmt)
+	if err != nil {
+		return nil, fmt.Errorf("Error fetching functions: %s", err)
+	}
+	defer rows.Close()
+
+	var funcs []DBFunction
+	fm := make(map[string]int)
+
+	for rows.Next() {
+		var fn, fid string
+		fp := DBFuncParam{}
+
+		err = rows.Scan(&fn, &fid, &fp.Type, &fp.Name, &fp.ID)
+		if err != nil {
+			return nil, err
+		}
+
+		if i, ok := fm[fid]; ok {
+			funcs[i].Params = append(funcs[i].Params, fp)
+		} else {
+			funcs = append(funcs, DBFunction{Name: fn, Params: []DBFuncParam{fp}})
+			fm[fid] = len(funcs) - 1
+		}
+	}
+
+	return funcs, nil
+}
+
 // func GetValType(type string) qcode.ValType {
 // 	switch {
 // 	case "bigint", "integer", "smallint", "numeric", "bigserial":
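For illustration, here is roughly what `GetFunctions` would produce for the single-argument `add_five(a integer)` function shown in the docs change further down (a standalone sketch with local copies of the new types; `ordinal_position` in `information_schema.parameters` is 1-based):

```go
package main

import "fmt"

// Local copies of the types introduced above, for illustration only.
type DBFuncParam struct {
	ID   int
	Name string
	Type string
}

type DBFunction struct {
	Name   string
	Params []DBFuncParam
}

func main() {
	// One information_schema row, grouped under the function's specific_name.
	fn := DBFunction{
		Name:   "add_five",
		Params: []DBFuncParam{{ID: 1, Name: "a", Type: "integer"}},
	}
	fmt.Printf("%+v\n", fn)
	// NewDBSchema only registers functions with exactly one parameter in
	// schema.fm, which is what makes add_five usable as a column prefix.
}
```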

View File

@@ -1,11 +1,10 @@
 package psql
 
 import (
-	"log"
 	"strings"
 )
 
-func getTestSchema() *DBSchema {
+func GetTestDBInfo() *DBInfo {
 	tables := []DBTable{
 		DBTable{Name: "customers", Type: "table"},
 		DBTable{Name: "users", Type: "table"},
@@ -74,36 +73,19 @@ func getTestSchema() *DBSchema {
 		}
 	}
 
-	schema := &DBSchema{
-		ver: 110000,
-		t:   make(map[string]*DBTableInfo),
-		rm:  make(map[string]map[string]*DBRel),
+	return &DBInfo{
+		Version:   110000,
+		Tables:    tables,
+		Columns:   columns,
+		Functions: []DBFunction{},
+		colMap:    newColMap(tables, columns),
 	}
+}
 
+func GetTestSchema() (*DBSchema, error) {
 	aliases := map[string][]string{
 		"users": []string{"mes"},
 	}
 
-	for i, t := range tables {
-		err := schema.addTable(t, columns[i], aliases)
-		if err != nil {
-			log.Fatal(err)
-		}
-	}
-
-	for i, t := range tables {
-		err := schema.firstDegreeRels(t, columns[i])
-		if err != nil {
-			log.Fatal(err)
-		}
-	}
-
-	for i, t := range tables {
-		err := schema.secondDegreeRels(t, columns[i])
-		if err != nil {
-			log.Fatal(err)
-		}
-	}
-
-	return schema
+	return NewDBSchema(GetTestDBInfo(), aliases)
 }

View File

@@ -1,4 +1,4 @@
-package psql
+package psql_test
 
 import (
 	"encoding/json"

View File

@@ -730,6 +730,32 @@ query {
 }
 ```
 
+### Custom Functions
+
+Any function defined in the database, like the `add_five` function below that adds 5 to any number given to it, can be used
+within your query. The one limitation is that the function must accept only a single argument. Functions are used within your GraphQL query in a similar way to the aggregations shown above. Example below:
+
+```graphql
+query {
+  thread(id: 5) {
+    id
+    total_votes
+    add_five_total_votes
+  }
+}
+```
+
+Postgres user-defined function `add_five`:
+
+```sql
+CREATE OR REPLACE FUNCTION add_five(a integer) RETURNS integer AS $$
+BEGIN
+  RETURN a + 5;
+END;
+$$ LANGUAGE plpgsql;
+```
+
 In GraphQL, mutations are the operation type used when you need to modify data. Super Graph supports `insert`, `update`, `upsert` and `delete` mutations. You can also do complex nested inserts and updates.
 
 When using mutations, the data must be passed as variables, since Super Graph compiles the query into a prepared statement in the database for maximum speed. Prepared statements are like functions in your code: when called, they accept arguments, and your variables are passed in as those arguments.
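As a minimal sketch of that variables pattern (illustrative only, not part of this diff; the `insert` argument form and the column names follow the mutation examples elsewhere in this guide), the mutation references a `$data` variable:

```graphql
mutation {
  product(insert: $data) {
    id
    name
  }
}
```

and the data itself is sent separately as the variables JSON:

```json
{
  "data": {
    "name": "Art of Computer Programming",
    "price": 30.5
  }
}
```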

1
go.mod
View File

@@ -1,6 +1,7 @@
 module github.com/dosco/super-graph
 
 require (
+	github.com/DATA-DOG/go-sqlmock v1.4.1
 	github.com/GeertJohan/go.rice v1.0.0
 	github.com/NYTimes/gziphandler v1.1.1
 	github.com/adjust/gorails v0.0.0-20171013043634-2786ed0c03d3

2
go.sum
View File

@@ -1,6 +1,8 @@
 cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
 github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
+github.com/DATA-DOG/go-sqlmock v1.4.1 h1:ThlnYciV1iM/V0OSF/dtkqWb6xo5qITT1TJBG1MRDJM=
+github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
 github.com/GeertJohan/go.incremental v1.0.0 h1:7AH+pY1XUgQE4Y1HcXYaMqAI0m9yrFqo/jt0CW30vsg=
 github.com/GeertJohan/go.incremental v1.0.0/go.mod h1:6fAjUhbVuX1KcMD3c8TEgVUqmo4seqhv0i0kdATSkM0=
 github.com/GeertJohan/go.rice v1.0.0 h1:KkI6O9uMaQU3VEKaj01ulavtF7o1fWT7+pk/4voiMLQ=

View File

@@ -23,7 +23,7 @@ func newAction(a *Action) (http.Handler, error) {
 	httpFn := func(w http.ResponseWriter, r *http.Request) {
 		if err := fn(w, r); err != nil {
-			renderErr(w, err, nil)
+			renderErr(w, err)
 		}
 	}

View File

@@ -7,7 +7,6 @@ import (
 	"io/ioutil"
 	"net/http"
 
-	"github.com/dosco/super-graph/core"
 	"github.com/dosco/super-graph/internal/serv/internal/auth"
 	"github.com/rs/cors"
 	"go.uber.org/zap"
@@ -29,7 +28,7 @@ type gqlReq struct {
 }
 
 type errorResp struct {
-	Error error `json:"error"`
+	Error string `json:"error"`
 }
 
 func apiV1Handler() http.Handler {
@@ -55,13 +54,13 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
 	//nolint: errcheck
 	if conf.AuthFailBlock && !auth.IsAuth(ct) {
-		renderErr(w, errUnauthorized, nil)
+		renderErr(w, errUnauthorized)
 		return
 	}
 
 	b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
 	if err != nil {
-		renderErr(w, err, nil)
+		renderErr(w, err)
 		return
 	}
 	defer r.Body.Close()
@@ -70,7 +69,7 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
 	err = json.Unmarshal(b, &req)
 	if err != nil {
-		renderErr(w, err, nil)
+		renderErr(w, err)
 		return
 	}
@@ -86,11 +85,10 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
 	}
 
 	if err != nil {
-		renderErr(w, err, res)
-		return
-	}
-
-	json.NewEncoder(w).Encode(res)
+		renderErr(w, err)
+	} else {
+		json.NewEncoder(w).Encode(res)
+	}
 
 	if doLog && logLevel >= LogLevelInfo {
 		zlog.Info("success",
@@ -102,22 +100,10 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
 }
 
 //nolint: errcheck
-func renderErr(w http.ResponseWriter, err error, res *core.Result) {
+func renderErr(w http.ResponseWriter, err error) {
 	if err == errUnauthorized {
 		w.WriteHeader(http.StatusUnauthorized)
 	}
 
-	json.NewEncoder(w).Encode(&errorResp{err})
-
-	if logLevel >= LogLevelError {
-		if res != nil {
-			zlog.Error(err.Error(),
-				zap.String("op", res.Operation()),
-				zap.String("name", res.QueryName()),
-				zap.String("role", res.Role()),
-			)
-		} else {
-			zlog.Error(err.Error())
-		}
-	}
+	json.NewEncoder(w).Encode(errorResp{err.Error()})
 }