From 652b31ce38c572f1a23abf71f84f2e71426015a4 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Fri, 19 Apr 2019 01:55:03 -0400 Subject: [PATCH] Add support for GraphQL variables --- fuzz.yaml | 2 +- psql/psql.go | 9 ++- psql/psql_test.go | 2 +- qcode/parse.go | 6 +- qcode/qcode.go | 4 +- serv/core.go | 64 +++++++++++++++++++ serv/http.go | 158 ++++------------------------------------------ serv/utils.go | 44 +++++++++++++ serv/vars.go | 53 ++++++++++++++++ serv/ws.go | 33 ++++++++++ 10 files changed, 223 insertions(+), 152 deletions(-) create mode 100644 serv/core.go create mode 100644 serv/utils.go create mode 100644 serv/vars.go create mode 100644 serv/ws.go diff --git a/fuzz.yaml b/fuzz.yaml index 84a886c..1297620 100644 --- a/fuzz.yaml +++ b/fuzz.yaml @@ -12,4 +12,4 @@ targets: package: github.com/dosco/super-graph/qcode # the repository will be cloned to # $GOPATH/src/github.com/fuzzbuzz/tutorial - checkout: github.com/dosco/super-graph/ + checkout: github.com/dosco/super-graph diff --git a/psql/psql.go b/psql/psql.go index 2d80df4..59326b1 100644 --- a/psql/psql.go +++ b/psql/psql.go @@ -539,8 +539,7 @@ func (v *selectBlock) renderWhere(w io.Writer) error { if len(v.ti.PrimaryCol) == 0 { return fmt.Errorf("no primary key column defined for %s", v.sel.Table) } - fmt.Fprintf(w, `(("%s") = ('%s'))`, v.ti.PrimaryCol, val.Val) - valExists = false + fmt.Fprintf(w, `(("%s") =`, v.ti.PrimaryCol) case qcode.OpTsQuery: if len(v.ti.TSVCol) == 0 { return fmt.Errorf("no tsv column defined for %s", v.sel.Table) @@ -632,7 +631,11 @@ func renderVal(w io.Writer, ex *qcode.Exp, vars map[string]string) { io.WriteString(w, ` (`) switch ex.Type { case qcode.ValBool, qcode.ValInt, qcode.ValFloat: - io.WriteString(w, ex.Val) + if len(ex.Val) != 0 { + fmt.Fprintf(w, `%s`, ex.Val) + } else { + io.WriteString(w, `''`) + } case qcode.ValStr: fmt.Fprintf(w, `'%s'`, ex.Val) case qcode.ValVar: diff --git a/psql/psql_test.go b/psql/psql_test.go index 628bd70..a409956 100644 --- a/psql/psql_test.go +++ b/psql/psql_test.go @@ -353,7 +353,7 @@ func fetchByID(t *testing.T) { } }` - sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = ('15'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { diff --git a/qcode/parse.go b/qcode/parse.go index 0d1970d..f008f1b 100644 --- a/qcode/parse.go +++ b/qcode/parse.go @@ -14,6 +14,8 @@ var ( type parserType int16 const ( + maxNested = 50 + parserError parserType = iota parserEOF opQuery @@ -50,6 +52,8 @@ func (t parserType) String() string { v = "node-float" case nodeBool: v = "node-bool" + case nodeVar: + v = "node-var" case nodeObj: v = "node-obj" case nodeList: @@ -253,7 +257,7 @@ func (p *Parser) parseFields() ([]*Field, int16, error) { continue } - if i > 500 { + if i > maxNested { return nil, 0, errors.New("too many fields") } diff --git a/qcode/qcode.go b/qcode/qcode.go index 3ac6f64..3674a66 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -465,8 +465,10 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error { ex.Type = ValInt case nodeFloat: ex.Type = ValFloat + case nodeVar: + ex.Type = ValVar default: - fmt.Errorf("expecting an string, int or float") + fmt.Errorf("expecting a string, int, float or variable") } sel.Where = ex diff --git a/serv/core.go b/serv/core.go new file mode 100644 index 0000000..a26d6ce --- /dev/null +++ b/serv/core.go @@ -0,0 +1,64 @@ +package serv + +import ( + "context" + "encoding/json" + "fmt" + "io" + "strings" + "time" + + "github.com/go-pg/pg" + "github.com/valyala/fasttemplate" +) + +func handleReq(ctx context.Context, w io.Writer, req *gqlReq) error { + qc, err := qcompile.CompileQuery(req.Query) + if err != nil { + return err + } + + var sqlStmt strings.Builder + + if err := pcompile.Compile(&sqlStmt, qc); err != nil { + return err + } + + t := fasttemplate.New(sqlStmt.String(), openVar, closeVar) + sqlStmt.Reset() + + _, err = t.Execute(&sqlStmt, varMap(ctx, req.Vars)) + + if err == errNoUserID && + authFailBlock == authFailBlockPerQuery && + authCheck(ctx) == false { + return errUnauthorized + } + + if err != nil { + return err + } + + finalSQL := sqlStmt.String() + if conf.DebugLevel > 0 { + fmt.Println(finalSQL) + } + st := time.Now() + + var root json.RawMessage + _, err = db.Query(pg.Scan(&root), finalSQL) + + if err != nil { + return err + } + + et := time.Now() + resp := gqlResp{Data: json.RawMessage(root)} + + if conf.EnableTracing { + resp.Extensions = &extensions{newTrace(st, et, qc)} + } + + json.NewEncoder(w).Encode(resp) + return nil +} diff --git a/serv/http.go b/serv/http.go index 4deab71..262aaf4 100644 --- a/serv/http.go +++ b/serv/http.go @@ -1,42 +1,41 @@ package serv import ( - "context" "encoding/json" "errors" - "fmt" "io" "io/ioutil" "net/http" "strings" "time" - "github.com/dosco/super-graph/qcode" - "github.com/go-pg/pg" "github.com/gorilla/websocket" - "github.com/valyala/fasttemplate" ) const ( + maxReadBytes = 100000 // 100Kb introspectionQuery = "IntrospectionQuery" openVar = "{{" closeVar = "}}" ) var ( - upgrader = websocket.Upgrader{} - errNoUserID = errors.New("no user_id available") + upgrader = websocket.Upgrader{} + errNoUserID = errors.New("no user_id available") + errUnauthorized = errors.New("not authorized") ) type gqlReq struct { - OpName string `json:"operationName"` - Query string `json:"query"` - Variables map[string]string `json:"variables"` + OpName string `json:"operationName"` + Query string `json:"query"` + Vars variables `json:"variables"` } +type variables map[string]interface{} + type gqlResp struct { Error string `json:"error,omitempty"` - Data json.RawMessage `json:"data,omitempty"` + Data json.RawMessage `json:"data"` Extensions *extensions `json:"extensions,omitempty"` } @@ -73,7 +72,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { return } - b, err := ioutil.ReadAll(r.Body) + b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes)) defer r.Body.Close() if err != nil { errorResp(w, err) @@ -96,144 +95,13 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { return } - qc, err := qcompile.CompileQuery(req.Query) - if err != nil { - errorResp(w, err) - return - } + err = handleReq(ctx, w, req) - var sqlStmt strings.Builder - - if err := pcompile.Compile(&sqlStmt, qc); err != nil { - errorResp(w, err) - return - } - - t := fasttemplate.New(sqlStmt.String(), openVar, closeVar) - sqlStmt.Reset() - - _, err = t.Execute(&sqlStmt, varValues(ctx)) - if err == errNoUserID && - authFailBlock == authFailBlockPerQuery && - authCheck(ctx) == false { + if err == errUnauthorized { http.Error(w, "Not authorized", 401) - return } - if err != nil { - errorResp(w, err) - return - } - - finalSQL := sqlStmt.String() - if conf.DebugLevel > 0 { - fmt.Println(finalSQL) - } - st := time.Now() - - var root json.RawMessage - _, err = db.Query(pg.Scan(&root), finalSQL) if err != nil { errorResp(w, err) - return - } - - et := time.Now() - resp := gqlResp{} - - if conf.EnableTracing { - resp.Extensions = &extensions{newTrace(st, et, qc)} - } - - resp.Data = json.RawMessage(root) - json.NewEncoder(w).Encode(resp) -} - -/* -func apiv1Ws(w http.ResponseWriter, r *http.Request) { - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - fmt.Println(err) - return - } - defer c.Close() - for { - mt, message, err := c.ReadMessage() - if err != nil { - fmt.Println("read:", err) - break - } - fmt.Printf("recv: %s", message) - err = c.WriteMessage(mt, message) - if err != nil { - fmt.Println("write:", err) - break - } } } - -func serve(w http.ResponseWriter, r *http.Request) { - // if websocket.IsWebSocketUpgrade(r) { - // apiv1Ws(w, r) - // return - // } - apiv1Http(w, r) -} -*/ - -func errorResp(w http.ResponseWriter, err error) { - b, _ := json.Marshal(gqlResp{Error: err.Error()}) - http.Error(w, string(b), http.StatusBadRequest) -} - -func authCheck(ctx context.Context) bool { - return (ctx.Value(userIDKey) != nil) -} - -func varValues(ctx context.Context) map[string]interface{} { - uidFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) { - if v := ctx.Value(userIDKey); v != nil { - return w.Write([]byte(v.(string))) - } - return 0, errNoUserID - }) - - uidpFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) { - if v := ctx.Value(userIDProviderKey); v != nil { - return w.Write([]byte(v.(string))) - } - return 0, errNoUserID - }) - - return map[string]interface{}{ - "USER_ID": uidFn, - "user_id": uidFn, - "USER_ID_PROVIDER": uidpFn, - "user_id_provider": uidpFn, - } -} - -func newTrace(st, et time.Time, qc *qcode.QCode) *trace { - du := et.Sub(et) - - t := &trace{ - Version: 1, - StartTime: st, - EndTime: et, - Duration: du, - Execution: execution{ - []resolver{ - resolver{ - Path: []string{qc.Query.Select.Table}, - ParentType: "Query", - FieldName: qc.Query.Select.Table, - ReturnType: "object", - StartOffset: 1, - Duration: du, - }, - }, - }, - } - - return t -} diff --git a/serv/utils.go b/serv/utils.go new file mode 100644 index 0000000..650154c --- /dev/null +++ b/serv/utils.go @@ -0,0 +1,44 @@ +package serv + +import ( + "context" + "encoding/json" + "net/http" + "time" + + "github.com/dosco/super-graph/qcode" +) + +func errorResp(w http.ResponseWriter, err error) { + b, _ := json.Marshal(gqlResp{Error: err.Error()}) + http.Error(w, string(b), http.StatusBadRequest) +} + +func authCheck(ctx context.Context) bool { + return (ctx.Value(userIDKey) != nil) +} + +func newTrace(st, et time.Time, qc *qcode.QCode) *trace { + du := et.Sub(et) + + t := &trace{ + Version: 1, + StartTime: st, + EndTime: et, + Duration: du, + Execution: execution{ + []resolver{ + resolver{ + Path: []string{qc.Query.Select.Table}, + ParentType: "Query", + FieldName: qc.Query.Select.Table, + ReturnType: "object", + StartOffset: 1, + Duration: du, + }, + }, + }, + } + + return t +} diff --git a/serv/vars.go b/serv/vars.go new file mode 100644 index 0000000..4544483 --- /dev/null +++ b/serv/vars.go @@ -0,0 +1,53 @@ +package serv + +import ( + "context" + "fmt" + "io" + "strconv" + + "github.com/valyala/fasttemplate" +) + +func varMap(ctx context.Context, vars variables) variables { + userIDFn := func(w io.Writer, _ string) (int, error) { + if v := ctx.Value(userIDKey); v != nil { + return w.Write([]byte(v.(string))) + } + return 0, errNoUserID + } + + userIDProviderFn := func(w io.Writer, _ string) (int, error) { + if v := ctx.Value(userIDProviderKey); v != nil { + return w.Write([]byte(v.(string))) + } + return 0, errNoUserID + } + + userIDTag := fasttemplate.TagFunc(userIDFn) + userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn) + + vm := variables{ + "user_id": userIDTag, + "user_id_provider": userIDProviderTag, + "USER_ID": userIDTag, + "USER_ID_PROVIDER": userIDProviderTag, + } + + for k, v := range vars { + if _, ok := vm[k]; ok { + continue + } + switch val := v.(type) { + case string: + vm[k] = val + case int: + vm[k] = strconv.Itoa(val) + case int64: + vm[k] = strconv.FormatInt(val, 64) + case float64: + vm[k] = fmt.Sprintf("%.0f", val) + } + } + return vm +} diff --git a/serv/ws.go b/serv/ws.go new file mode 100644 index 0000000..12f628a --- /dev/null +++ b/serv/ws.go @@ -0,0 +1,33 @@ +package serv + +/* +func apiv1Ws(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + for { + mt, message, err := c.ReadMessage() + if err != nil { + fmt.Println("read:", err) + break + } + fmt.Printf("recv: %s", message) + err = c.WriteMessage(mt, message) + if err != nil { + fmt.Println("write:", err) + break + } + } +} + +func serve(w http.ResponseWriter, r *http.Request) { + // if websocket.IsWebSocketUpgrade(r) { + // apiv1Ws(w, r) + // return + // } + apiv1Http(w, r) +} +*/