Add support for GraphQL variables
This commit is contained in:
parent
0755ecf6bd
commit
652b31ce38
|
@ -12,4 +12,4 @@ targets:
|
|||
package: github.com/dosco/super-graph/qcode
|
||||
# the repository will be cloned to
|
||||
# $GOPATH/src/github.com/fuzzbuzz/tutorial
|
||||
checkout: github.com/dosco/super-graph/
|
||||
checkout: github.com/dosco/super-graph
|
||||
|
|
|
@ -539,8 +539,7 @@ func (v *selectBlock) renderWhere(w io.Writer) error {
|
|||
if len(v.ti.PrimaryCol) == 0 {
|
||||
return fmt.Errorf("no primary key column defined for %s", v.sel.Table)
|
||||
}
|
||||
fmt.Fprintf(w, `(("%s") = ('%s'))`, v.ti.PrimaryCol, val.Val)
|
||||
valExists = false
|
||||
fmt.Fprintf(w, `(("%s") =`, v.ti.PrimaryCol)
|
||||
case qcode.OpTsQuery:
|
||||
if len(v.ti.TSVCol) == 0 {
|
||||
return fmt.Errorf("no tsv column defined for %s", v.sel.Table)
|
||||
|
@ -632,7 +631,11 @@ func renderVal(w io.Writer, ex *qcode.Exp, vars map[string]string) {
|
|||
io.WriteString(w, ` (`)
|
||||
switch ex.Type {
|
||||
case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
|
||||
io.WriteString(w, ex.Val)
|
||||
if len(ex.Val) != 0 {
|
||||
fmt.Fprintf(w, `%s`, ex.Val)
|
||||
} else {
|
||||
io.WriteString(w, `''`)
|
||||
}
|
||||
case qcode.ValStr:
|
||||
fmt.Fprintf(w, `'%s'`, ex.Val)
|
||||
case qcode.ValVar:
|
||||
|
|
|
@ -353,7 +353,7 @@ func fetchByID(t *testing.T) {
|
|||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = ('15'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
if err != nil {
|
||||
|
|
|
@ -14,6 +14,8 @@ var (
|
|||
type parserType int16
|
||||
|
||||
const (
|
||||
maxNested = 50
|
||||
|
||||
parserError parserType = iota
|
||||
parserEOF
|
||||
opQuery
|
||||
|
@ -50,6 +52,8 @@ func (t parserType) String() string {
|
|||
v = "node-float"
|
||||
case nodeBool:
|
||||
v = "node-bool"
|
||||
case nodeVar:
|
||||
v = "node-var"
|
||||
case nodeObj:
|
||||
v = "node-obj"
|
||||
case nodeList:
|
||||
|
@ -253,7 +257,7 @@ func (p *Parser) parseFields() ([]*Field, int16, error) {
|
|||
continue
|
||||
}
|
||||
|
||||
if i > 500 {
|
||||
if i > maxNested {
|
||||
return nil, 0, errors.New("too many fields")
|
||||
}
|
||||
|
||||
|
|
|
@ -465,8 +465,10 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
|
|||
ex.Type = ValInt
|
||||
case nodeFloat:
|
||||
ex.Type = ValFloat
|
||||
case nodeVar:
|
||||
ex.Type = ValVar
|
||||
default:
|
||||
fmt.Errorf("expecting an string, int or float")
|
||||
fmt.Errorf("expecting a string, int, float or variable")
|
||||
}
|
||||
|
||||
sel.Where = ex
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
func handleReq(ctx context.Context, w io.Writer, req *gqlReq) error {
|
||||
qc, err := qcompile.CompileQuery(req.Query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sqlStmt strings.Builder
|
||||
|
||||
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
|
||||
sqlStmt.Reset()
|
||||
|
||||
_, err = t.Execute(&sqlStmt, varMap(ctx, req.Vars))
|
||||
|
||||
if err == errNoUserID &&
|
||||
authFailBlock == authFailBlockPerQuery &&
|
||||
authCheck(ctx) == false {
|
||||
return errUnauthorized
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finalSQL := sqlStmt.String()
|
||||
if conf.DebugLevel > 0 {
|
||||
fmt.Println(finalSQL)
|
||||
}
|
||||
st := time.Now()
|
||||
|
||||
var root json.RawMessage
|
||||
_, err = db.Query(pg.Scan(&root), finalSQL)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
et := time.Now()
|
||||
resp := gqlResp{Data: json.RawMessage(root)}
|
||||
|
||||
if conf.EnableTracing {
|
||||
resp.Extensions = &extensions{newTrace(st, et, qc)}
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
return nil
|
||||
}
|
150
serv/http.go
150
serv/http.go
|
@ -1,23 +1,19 @@
|
|||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
const (
|
||||
maxReadBytes = 100000 // 100Kb
|
||||
introspectionQuery = "IntrospectionQuery"
|
||||
openVar = "{{"
|
||||
closeVar = "}}"
|
||||
|
@ -26,17 +22,20 @@ const (
|
|||
var (
|
||||
upgrader = websocket.Upgrader{}
|
||||
errNoUserID = errors.New("no user_id available")
|
||||
errUnauthorized = errors.New("not authorized")
|
||||
)
|
||||
|
||||
type gqlReq struct {
|
||||
OpName string `json:"operationName"`
|
||||
Query string `json:"query"`
|
||||
Variables map[string]string `json:"variables"`
|
||||
Vars variables `json:"variables"`
|
||||
}
|
||||
|
||||
type variables map[string]interface{}
|
||||
|
||||
type gqlResp struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Extensions *extensions `json:"extensions,omitempty"`
|
||||
}
|
||||
|
||||
|
@ -73,7 +72,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(r.Body)
|
||||
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
|
||||
defer r.Body.Close()
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
|
@ -96,144 +95,13 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
qc, err := qcompile.CompileQuery(req.Query)
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
err = handleReq(ctx, w, req)
|
||||
|
||||
var sqlStmt strings.Builder
|
||||
|
||||
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
|
||||
sqlStmt.Reset()
|
||||
|
||||
_, err = t.Execute(&sqlStmt, varValues(ctx))
|
||||
if err == errNoUserID &&
|
||||
authFailBlock == authFailBlockPerQuery &&
|
||||
authCheck(ctx) == false {
|
||||
if err == errUnauthorized {
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
finalSQL := sqlStmt.String()
|
||||
if conf.DebugLevel > 0 {
|
||||
fmt.Println(finalSQL)
|
||||
}
|
||||
st := time.Now()
|
||||
|
||||
var root json.RawMessage
|
||||
_, err = db.Query(pg.Scan(&root), finalSQL)
|
||||
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
et := time.Now()
|
||||
resp := gqlResp{}
|
||||
|
||||
if conf.EnableTracing {
|
||||
resp.Extensions = &extensions{newTrace(st, et, qc)}
|
||||
}
|
||||
|
||||
resp.Data = json.RawMessage(root)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
/*
|
||||
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
fmt.Println("read:", err)
|
||||
break
|
||||
}
|
||||
fmt.Printf("recv: %s", message)
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
fmt.Println("write:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func serve(w http.ResponseWriter, r *http.Request) {
|
||||
// if websocket.IsWebSocketUpgrade(r) {
|
||||
// apiv1Ws(w, r)
|
||||
// return
|
||||
// }
|
||||
apiv1Http(w, r)
|
||||
}
|
||||
*/
|
||||
|
||||
func errorResp(w http.ResponseWriter, err error) {
|
||||
b, _ := json.Marshal(gqlResp{Error: err.Error()})
|
||||
http.Error(w, string(b), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func authCheck(ctx context.Context) bool {
|
||||
return (ctx.Value(userIDKey) != nil)
|
||||
}
|
||||
|
||||
func varValues(ctx context.Context) map[string]interface{} {
|
||||
uidFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
})
|
||||
|
||||
uidpFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
})
|
||||
|
||||
return map[string]interface{}{
|
||||
"USER_ID": uidFn,
|
||||
"user_id": uidFn,
|
||||
"USER_ID_PROVIDER": uidpFn,
|
||||
"user_id_provider": uidpFn,
|
||||
}
|
||||
}
|
||||
|
||||
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
|
||||
du := et.Sub(et)
|
||||
|
||||
t := &trace{
|
||||
Version: 1,
|
||||
StartTime: st,
|
||||
EndTime: et,
|
||||
Duration: du,
|
||||
Execution: execution{
|
||||
[]resolver{
|
||||
resolver{
|
||||
Path: []string{qc.Query.Select.Table},
|
||||
ParentType: "Query",
|
||||
FieldName: qc.Query.Select.Table,
|
||||
ReturnType: "object",
|
||||
StartOffset: 1,
|
||||
Duration: du,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
)
|
||||
|
||||
func errorResp(w http.ResponseWriter, err error) {
|
||||
b, _ := json.Marshal(gqlResp{Error: err.Error()})
|
||||
http.Error(w, string(b), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func authCheck(ctx context.Context) bool {
|
||||
return (ctx.Value(userIDKey) != nil)
|
||||
}
|
||||
|
||||
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
|
||||
du := et.Sub(et)
|
||||
|
||||
t := &trace{
|
||||
Version: 1,
|
||||
StartTime: st,
|
||||
EndTime: et,
|
||||
Duration: du,
|
||||
Execution: execution{
|
||||
[]resolver{
|
||||
resolver{
|
||||
Path: []string{qc.Query.Select.Table},
|
||||
ParentType: "Query",
|
||||
FieldName: qc.Query.Select.Table,
|
||||
ReturnType: "object",
|
||||
StartOffset: 1,
|
||||
Duration: du,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
func varMap(ctx context.Context, vars variables) variables {
|
||||
userIDFn := func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
}
|
||||
|
||||
userIDProviderFn := func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
}
|
||||
|
||||
userIDTag := fasttemplate.TagFunc(userIDFn)
|
||||
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn)
|
||||
|
||||
vm := variables{
|
||||
"user_id": userIDTag,
|
||||
"user_id_provider": userIDProviderTag,
|
||||
"USER_ID": userIDTag,
|
||||
"USER_ID_PROVIDER": userIDProviderTag,
|
||||
}
|
||||
|
||||
for k, v := range vars {
|
||||
if _, ok := vm[k]; ok {
|
||||
continue
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
vm[k] = val
|
||||
case int:
|
||||
vm[k] = strconv.Itoa(val)
|
||||
case int64:
|
||||
vm[k] = strconv.FormatInt(val, 64)
|
||||
case float64:
|
||||
vm[k] = fmt.Sprintf("%.0f", val)
|
||||
}
|
||||
}
|
||||
return vm
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
package serv
|
||||
|
||||
/*
|
||||
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
fmt.Println("read:", err)
|
||||
break
|
||||
}
|
||||
fmt.Printf("recv: %s", message)
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
fmt.Println("write:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func serve(w http.ResponseWriter, r *http.Request) {
|
||||
// if websocket.IsWebSocketUpgrade(r) {
|
||||
// apiv1Ws(w, r)
|
||||
// return
|
||||
// }
|
||||
apiv1Http(w, r)
|
||||
}
|
||||
*/
|
Loading…
Reference in New Issue