Add support for GraphQL variables

This commit is contained in:
Vikram Rangnekar 2019-04-19 01:55:03 -04:00
parent 0755ecf6bd
commit 652b31ce38
10 changed files with 223 additions and 152 deletions

View File

@ -12,4 +12,4 @@ targets:
package: github.com/dosco/super-graph/qcode
# the repository will be cloned to
# $GOPATH/src/github.com/fuzzbuzz/tutorial
checkout: github.com/dosco/super-graph/
checkout: github.com/dosco/super-graph

View File

@ -539,8 +539,7 @@ func (v *selectBlock) renderWhere(w io.Writer) error {
if len(v.ti.PrimaryCol) == 0 {
return fmt.Errorf("no primary key column defined for %s", v.sel.Table)
}
fmt.Fprintf(w, `(("%s") = ('%s'))`, v.ti.PrimaryCol, val.Val)
valExists = false
fmt.Fprintf(w, `(("%s") =`, v.ti.PrimaryCol)
case qcode.OpTsQuery:
if len(v.ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", v.sel.Table)
@ -632,7 +631,11 @@ func renderVal(w io.Writer, ex *qcode.Exp, vars map[string]string) {
io.WriteString(w, ` (`)
switch ex.Type {
case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
io.WriteString(w, ex.Val)
if len(ex.Val) != 0 {
fmt.Fprintf(w, `%s`, ex.Val)
} else {
io.WriteString(w, `''`)
}
case qcode.ValStr:
fmt.Fprintf(w, `'%s'`, ex.Val)
case qcode.ValVar:

View File

@ -353,7 +353,7 @@ func fetchByID(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = ('15'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {

View File

@ -14,6 +14,8 @@ var (
type parserType int16
const (
maxNested = 50
parserError parserType = iota
parserEOF
opQuery
@ -50,6 +52,8 @@ func (t parserType) String() string {
v = "node-float"
case nodeBool:
v = "node-bool"
case nodeVar:
v = "node-var"
case nodeObj:
v = "node-obj"
case nodeList:
@ -253,7 +257,7 @@ func (p *Parser) parseFields() ([]*Field, int16, error) {
continue
}
if i > 500 {
if i > maxNested {
return nil, 0, errors.New("too many fields")
}

View File

@ -465,8 +465,10 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
ex.Type = ValInt
case nodeFloat:
ex.Type = ValFloat
case nodeVar:
ex.Type = ValVar
default:
fmt.Errorf("expecting an string, int or float")
fmt.Errorf("expecting a string, int, float or variable")
}
sel.Where = ex

64
serv/core.go Normal file
View File

@ -0,0 +1,64 @@
package serv
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
"github.com/go-pg/pg"
"github.com/valyala/fasttemplate"
)
func handleReq(ctx context.Context, w io.Writer, req *gqlReq) error {
qc, err := qcompile.CompileQuery(req.Query)
if err != nil {
return err
}
var sqlStmt strings.Builder
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
return err
}
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
sqlStmt.Reset()
_, err = t.Execute(&sqlStmt, varMap(ctx, req.Vars))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(ctx) == false {
return errUnauthorized
}
if err != nil {
return err
}
finalSQL := sqlStmt.String()
if conf.DebugLevel > 0 {
fmt.Println(finalSQL)
}
st := time.Now()
var root json.RawMessage
_, err = db.Query(pg.Scan(&root), finalSQL)
if err != nil {
return err
}
et := time.Now()
resp := gqlResp{Data: json.RawMessage(root)}
if conf.EnableTracing {
resp.Extensions = &extensions{newTrace(st, et, qc)}
}
json.NewEncoder(w).Encode(resp)
return nil
}

View File

@ -1,42 +1,41 @@
package serv
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/gorilla/websocket"
"github.com/valyala/fasttemplate"
)
const (
maxReadBytes = 100000 // 100Kb
introspectionQuery = "IntrospectionQuery"
openVar = "{{"
closeVar = "}}"
)
var (
upgrader = websocket.Upgrader{}
errNoUserID = errors.New("no user_id available")
upgrader = websocket.Upgrader{}
errNoUserID = errors.New("no user_id available")
errUnauthorized = errors.New("not authorized")
)
type gqlReq struct {
OpName string `json:"operationName"`
Query string `json:"query"`
Variables map[string]string `json:"variables"`
OpName string `json:"operationName"`
Query string `json:"query"`
Vars variables `json:"variables"`
}
type variables map[string]interface{}
type gqlResp struct {
Error string `json:"error,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
Data json.RawMessage `json:"data"`
Extensions *extensions `json:"extensions,omitempty"`
}
@ -73,7 +72,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
return
}
b, err := ioutil.ReadAll(r.Body)
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
defer r.Body.Close()
if err != nil {
errorResp(w, err)
@ -96,144 +95,13 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
return
}
qc, err := qcompile.CompileQuery(req.Query)
if err != nil {
errorResp(w, err)
return
}
err = handleReq(ctx, w, req)
var sqlStmt strings.Builder
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
errorResp(w, err)
return
}
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
sqlStmt.Reset()
_, err = t.Execute(&sqlStmt, varValues(ctx))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(ctx) == false {
if err == errUnauthorized {
http.Error(w, "Not authorized", 401)
return
}
if err != nil {
errorResp(w, err)
return
}
finalSQL := sqlStmt.String()
if conf.DebugLevel > 0 {
fmt.Println(finalSQL)
}
st := time.Now()
var root json.RawMessage
_, err = db.Query(pg.Scan(&root), finalSQL)
if err != nil {
errorResp(w, err)
return
}
et := time.Now()
resp := gqlResp{}
if conf.EnableTracing {
resp.Extensions = &extensions{newTrace(st, et, qc)}
}
resp.Data = json.RawMessage(root)
json.NewEncoder(w).Encode(resp)
}
/*
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
fmt.Println("read:", err)
break
}
fmt.Printf("recv: %s", message)
err = c.WriteMessage(mt, message)
if err != nil {
fmt.Println("write:", err)
break
}
}
}
func serve(w http.ResponseWriter, r *http.Request) {
// if websocket.IsWebSocketUpgrade(r) {
// apiv1Ws(w, r)
// return
// }
apiv1Http(w, r)
}
*/
func errorResp(w http.ResponseWriter, err error) {
b, _ := json.Marshal(gqlResp{Error: err.Error()})
http.Error(w, string(b), http.StatusBadRequest)
}
func authCheck(ctx context.Context) bool {
return (ctx.Value(userIDKey) != nil)
}
func varValues(ctx context.Context) map[string]interface{} {
uidFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
})
uidpFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDProviderKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
})
return map[string]interface{}{
"USER_ID": uidFn,
"user_id": uidFn,
"USER_ID_PROVIDER": uidpFn,
"user_id_provider": uidpFn,
}
}
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
du := et.Sub(et)
t := &trace{
Version: 1,
StartTime: st,
EndTime: et,
Duration: du,
Execution: execution{
[]resolver{
resolver{
Path: []string{qc.Query.Select.Table},
ParentType: "Query",
FieldName: qc.Query.Select.Table,
ReturnType: "object",
StartOffset: 1,
Duration: du,
},
},
},
}
return t
}

44
serv/utils.go Normal file
View File

@ -0,0 +1,44 @@
package serv
import (
"context"
"encoding/json"
"net/http"
"time"
"github.com/dosco/super-graph/qcode"
)
func errorResp(w http.ResponseWriter, err error) {
b, _ := json.Marshal(gqlResp{Error: err.Error()})
http.Error(w, string(b), http.StatusBadRequest)
}
func authCheck(ctx context.Context) bool {
return (ctx.Value(userIDKey) != nil)
}
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
du := et.Sub(et)
t := &trace{
Version: 1,
StartTime: st,
EndTime: et,
Duration: du,
Execution: execution{
[]resolver{
resolver{
Path: []string{qc.Query.Select.Table},
ParentType: "Query",
FieldName: qc.Query.Select.Table,
ReturnType: "object",
StartOffset: 1,
Duration: du,
},
},
},
}
return t
}

53
serv/vars.go Normal file
View File

@ -0,0 +1,53 @@
package serv
import (
"context"
"fmt"
"io"
"strconv"
"github.com/valyala/fasttemplate"
)
func varMap(ctx context.Context, vars variables) variables {
userIDFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
}
userIDProviderFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDProviderKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
}
userIDTag := fasttemplate.TagFunc(userIDFn)
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn)
vm := variables{
"user_id": userIDTag,
"user_id_provider": userIDProviderTag,
"USER_ID": userIDTag,
"USER_ID_PROVIDER": userIDProviderTag,
}
for k, v := range vars {
if _, ok := vm[k]; ok {
continue
}
switch val := v.(type) {
case string:
vm[k] = val
case int:
vm[k] = strconv.Itoa(val)
case int64:
vm[k] = strconv.FormatInt(val, 64)
case float64:
vm[k] = fmt.Sprintf("%.0f", val)
}
}
return vm
}

33
serv/ws.go Normal file
View File

@ -0,0 +1,33 @@
package serv
/*
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
fmt.Println("read:", err)
break
}
fmt.Printf("recv: %s", message)
err = c.WriteMessage(mt, message)
if err != nil {
fmt.Println("write:", err)
break
}
}
}
func serve(w http.ResponseWriter, r *http.Request) {
// if websocket.IsWebSocketUpgrade(r) {
// apiv1Ws(w, r)
// return
// }
apiv1Http(w, r)
}
*/