Add support for GraphQL variables

This commit is contained in:
Vikram Rangnekar 2019-04-19 01:55:03 -04:00
parent 0755ecf6bd
commit 652b31ce38
10 changed files with 223 additions and 152 deletions

View File

@ -12,4 +12,4 @@ targets:
package: github.com/dosco/super-graph/qcode package: github.com/dosco/super-graph/qcode
# the repository will be cloned to # the repository will be cloned to
# $GOPATH/src/github.com/fuzzbuzz/tutorial # $GOPATH/src/github.com/fuzzbuzz/tutorial
checkout: github.com/dosco/super-graph/ checkout: github.com/dosco/super-graph

View File

@ -539,8 +539,7 @@ func (v *selectBlock) renderWhere(w io.Writer) error {
if len(v.ti.PrimaryCol) == 0 { if len(v.ti.PrimaryCol) == 0 {
return fmt.Errorf("no primary key column defined for %s", v.sel.Table) return fmt.Errorf("no primary key column defined for %s", v.sel.Table)
} }
fmt.Fprintf(w, `(("%s") = ('%s'))`, v.ti.PrimaryCol, val.Val) fmt.Fprintf(w, `(("%s") =`, v.ti.PrimaryCol)
valExists = false
case qcode.OpTsQuery: case qcode.OpTsQuery:
if len(v.ti.TSVCol) == 0 { if len(v.ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", v.sel.Table) return fmt.Errorf("no tsv column defined for %s", v.sel.Table)
@ -632,7 +631,11 @@ func renderVal(w io.Writer, ex *qcode.Exp, vars map[string]string) {
io.WriteString(w, ` (`) io.WriteString(w, ` (`)
switch ex.Type { switch ex.Type {
case qcode.ValBool, qcode.ValInt, qcode.ValFloat: case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
io.WriteString(w, ex.Val) if len(ex.Val) != 0 {
fmt.Fprintf(w, `%s`, ex.Val)
} else {
io.WriteString(w, `''`)
}
case qcode.ValStr: case qcode.ValStr:
fmt.Fprintf(w, `'%s'`, ex.Val) fmt.Fprintf(w, `'%s'`, ex.Val)
case qcode.ValVar: case qcode.ValVar:

View File

@ -353,7 +353,7 @@ func fetchByID(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = ('15'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {

View File

@ -14,6 +14,8 @@ var (
type parserType int16 type parserType int16
const ( const (
maxNested = 50
parserError parserType = iota parserError parserType = iota
parserEOF parserEOF
opQuery opQuery
@ -50,6 +52,8 @@ func (t parserType) String() string {
v = "node-float" v = "node-float"
case nodeBool: case nodeBool:
v = "node-bool" v = "node-bool"
case nodeVar:
v = "node-var"
case nodeObj: case nodeObj:
v = "node-obj" v = "node-obj"
case nodeList: case nodeList:
@ -253,7 +257,7 @@ func (p *Parser) parseFields() ([]*Field, int16, error) {
continue continue
} }
if i > 500 { if i > maxNested {
return nil, 0, errors.New("too many fields") return nil, 0, errors.New("too many fields")
} }

View File

@ -465,8 +465,10 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
ex.Type = ValInt ex.Type = ValInt
case nodeFloat: case nodeFloat:
ex.Type = ValFloat ex.Type = ValFloat
case nodeVar:
ex.Type = ValVar
default: default:
fmt.Errorf("expecting an string, int or float") fmt.Errorf("expecting a string, int, float or variable")
} }
sel.Where = ex sel.Where = ex

64
serv/core.go Normal file
View File

@ -0,0 +1,64 @@
package serv
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
"github.com/go-pg/pg"
"github.com/valyala/fasttemplate"
)
func handleReq(ctx context.Context, w io.Writer, req *gqlReq) error {
qc, err := qcompile.CompileQuery(req.Query)
if err != nil {
return err
}
var sqlStmt strings.Builder
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
return err
}
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
sqlStmt.Reset()
_, err = t.Execute(&sqlStmt, varMap(ctx, req.Vars))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(ctx) == false {
return errUnauthorized
}
if err != nil {
return err
}
finalSQL := sqlStmt.String()
if conf.DebugLevel > 0 {
fmt.Println(finalSQL)
}
st := time.Now()
var root json.RawMessage
_, err = db.Query(pg.Scan(&root), finalSQL)
if err != nil {
return err
}
et := time.Now()
resp := gqlResp{Data: json.RawMessage(root)}
if conf.EnableTracing {
resp.Extensions = &extensions{newTrace(st, et, qc)}
}
json.NewEncoder(w).Encode(resp)
return nil
}

View File

@ -1,23 +1,19 @@
package serv package serv
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/valyala/fasttemplate"
) )
const ( const (
maxReadBytes = 100000 // 100Kb
introspectionQuery = "IntrospectionQuery" introspectionQuery = "IntrospectionQuery"
openVar = "{{" openVar = "{{"
closeVar = "}}" closeVar = "}}"
@ -26,17 +22,20 @@ const (
var ( var (
upgrader = websocket.Upgrader{} upgrader = websocket.Upgrader{}
errNoUserID = errors.New("no user_id available") errNoUserID = errors.New("no user_id available")
errUnauthorized = errors.New("not authorized")
) )
type gqlReq struct { type gqlReq struct {
OpName string `json:"operationName"` OpName string `json:"operationName"`
Query string `json:"query"` Query string `json:"query"`
Variables map[string]string `json:"variables"` Vars variables `json:"variables"`
} }
type variables map[string]interface{}
type gqlResp struct { type gqlResp struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
Data json.RawMessage `json:"data,omitempty"` Data json.RawMessage `json:"data"`
Extensions *extensions `json:"extensions,omitempty"` Extensions *extensions `json:"extensions,omitempty"`
} }
@ -73,7 +72,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
return return
} }
b, err := ioutil.ReadAll(r.Body) b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
defer r.Body.Close() defer r.Body.Close()
if err != nil { if err != nil {
errorResp(w, err) errorResp(w, err)
@ -96,144 +95,13 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
return return
} }
qc, err := qcompile.CompileQuery(req.Query) err = handleReq(ctx, w, req)
if err != nil {
errorResp(w, err)
return
}
var sqlStmt strings.Builder if err == errUnauthorized {
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
errorResp(w, err)
return
}
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
sqlStmt.Reset()
_, err = t.Execute(&sqlStmt, varValues(ctx))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(ctx) == false {
http.Error(w, "Not authorized", 401) http.Error(w, "Not authorized", 401)
return
} }
if err != nil {
errorResp(w, err)
return
}
finalSQL := sqlStmt.String()
if conf.DebugLevel > 0 {
fmt.Println(finalSQL)
}
st := time.Now()
var root json.RawMessage
_, err = db.Query(pg.Scan(&root), finalSQL)
if err != nil { if err != nil {
errorResp(w, err) errorResp(w, err)
return
}
et := time.Now()
resp := gqlResp{}
if conf.EnableTracing {
resp.Extensions = &extensions{newTrace(st, et, qc)}
}
resp.Data = json.RawMessage(root)
json.NewEncoder(w).Encode(resp)
}
/*
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
fmt.Println("read:", err)
break
}
fmt.Printf("recv: %s", message)
err = c.WriteMessage(mt, message)
if err != nil {
fmt.Println("write:", err)
break
}
} }
} }
func serve(w http.ResponseWriter, r *http.Request) {
// if websocket.IsWebSocketUpgrade(r) {
// apiv1Ws(w, r)
// return
// }
apiv1Http(w, r)
}
*/
func errorResp(w http.ResponseWriter, err error) {
b, _ := json.Marshal(gqlResp{Error: err.Error()})
http.Error(w, string(b), http.StatusBadRequest)
}
func authCheck(ctx context.Context) bool {
return (ctx.Value(userIDKey) != nil)
}
func varValues(ctx context.Context) map[string]interface{} {
uidFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
})
uidpFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDProviderKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
})
return map[string]interface{}{
"USER_ID": uidFn,
"user_id": uidFn,
"USER_ID_PROVIDER": uidpFn,
"user_id_provider": uidpFn,
}
}
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
du := et.Sub(et)
t := &trace{
Version: 1,
StartTime: st,
EndTime: et,
Duration: du,
Execution: execution{
[]resolver{
resolver{
Path: []string{qc.Query.Select.Table},
ParentType: "Query",
FieldName: qc.Query.Select.Table,
ReturnType: "object",
StartOffset: 1,
Duration: du,
},
},
},
}
return t
}

44
serv/utils.go Normal file
View File

@ -0,0 +1,44 @@
package serv
import (
"context"
"encoding/json"
"net/http"
"time"
"github.com/dosco/super-graph/qcode"
)
func errorResp(w http.ResponseWriter, err error) {
b, _ := json.Marshal(gqlResp{Error: err.Error()})
http.Error(w, string(b), http.StatusBadRequest)
}
func authCheck(ctx context.Context) bool {
return (ctx.Value(userIDKey) != nil)
}
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
du := et.Sub(et)
t := &trace{
Version: 1,
StartTime: st,
EndTime: et,
Duration: du,
Execution: execution{
[]resolver{
resolver{
Path: []string{qc.Query.Select.Table},
ParentType: "Query",
FieldName: qc.Query.Select.Table,
ReturnType: "object",
StartOffset: 1,
Duration: du,
},
},
},
}
return t
}

53
serv/vars.go Normal file
View File

@ -0,0 +1,53 @@
package serv
import (
"context"
"fmt"
"io"
"strconv"
"github.com/valyala/fasttemplate"
)
func varMap(ctx context.Context, vars variables) variables {
userIDFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
}
userIDProviderFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDProviderKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
}
userIDTag := fasttemplate.TagFunc(userIDFn)
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn)
vm := variables{
"user_id": userIDTag,
"user_id_provider": userIDProviderTag,
"USER_ID": userIDTag,
"USER_ID_PROVIDER": userIDProviderTag,
}
for k, v := range vars {
if _, ok := vm[k]; ok {
continue
}
switch val := v.(type) {
case string:
vm[k] = val
case int:
vm[k] = strconv.Itoa(val)
case int64:
vm[k] = strconv.FormatInt(val, 64)
case float64:
vm[k] = fmt.Sprintf("%.0f", val)
}
}
return vm
}

33
serv/ws.go Normal file
View File

@ -0,0 +1,33 @@
package serv
/*
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
fmt.Println("read:", err)
break
}
fmt.Printf("recv: %s", message)
err = c.WriteMessage(mt, message)
if err != nil {
fmt.Println("write:", err)
break
}
}
}
func serve(w http.ResponseWriter, r *http.Request) {
// if websocket.IsWebSocketUpgrade(r) {
// apiv1Ws(w, r)
// return
// }
apiv1Http(w, r)
}
*/