Add support for GraphQL variables
This commit is contained in:
parent
0755ecf6bd
commit
652b31ce38
|
@ -12,4 +12,4 @@ targets:
|
||||||
package: github.com/dosco/super-graph/qcode
|
package: github.com/dosco/super-graph/qcode
|
||||||
# the repository will be cloned to
|
# the repository will be cloned to
|
||||||
# $GOPATH/src/github.com/fuzzbuzz/tutorial
|
# $GOPATH/src/github.com/fuzzbuzz/tutorial
|
||||||
checkout: github.com/dosco/super-graph/
|
checkout: github.com/dosco/super-graph
|
||||||
|
|
|
@ -539,8 +539,7 @@ func (v *selectBlock) renderWhere(w io.Writer) error {
|
||||||
if len(v.ti.PrimaryCol) == 0 {
|
if len(v.ti.PrimaryCol) == 0 {
|
||||||
return fmt.Errorf("no primary key column defined for %s", v.sel.Table)
|
return fmt.Errorf("no primary key column defined for %s", v.sel.Table)
|
||||||
}
|
}
|
||||||
fmt.Fprintf(w, `(("%s") = ('%s'))`, v.ti.PrimaryCol, val.Val)
|
fmt.Fprintf(w, `(("%s") =`, v.ti.PrimaryCol)
|
||||||
valExists = false
|
|
||||||
case qcode.OpTsQuery:
|
case qcode.OpTsQuery:
|
||||||
if len(v.ti.TSVCol) == 0 {
|
if len(v.ti.TSVCol) == 0 {
|
||||||
return fmt.Errorf("no tsv column defined for %s", v.sel.Table)
|
return fmt.Errorf("no tsv column defined for %s", v.sel.Table)
|
||||||
|
@ -632,7 +631,11 @@ func renderVal(w io.Writer, ex *qcode.Exp, vars map[string]string) {
|
||||||
io.WriteString(w, ` (`)
|
io.WriteString(w, ` (`)
|
||||||
switch ex.Type {
|
switch ex.Type {
|
||||||
case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
|
case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
|
||||||
io.WriteString(w, ex.Val)
|
if len(ex.Val) != 0 {
|
||||||
|
fmt.Fprintf(w, `%s`, ex.Val)
|
||||||
|
} else {
|
||||||
|
io.WriteString(w, `''`)
|
||||||
|
}
|
||||||
case qcode.ValStr:
|
case qcode.ValStr:
|
||||||
fmt.Fprintf(w, `'%s'`, ex.Val)
|
fmt.Fprintf(w, `'%s'`, ex.Val)
|
||||||
case qcode.ValVar:
|
case qcode.ValVar:
|
||||||
|
|
|
@ -353,7 +353,7 @@ func fetchByID(t *testing.T) {
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = ('15'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
|
sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -14,6 +14,8 @@ var (
|
||||||
type parserType int16
|
type parserType int16
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
maxNested = 50
|
||||||
|
|
||||||
parserError parserType = iota
|
parserError parserType = iota
|
||||||
parserEOF
|
parserEOF
|
||||||
opQuery
|
opQuery
|
||||||
|
@ -50,6 +52,8 @@ func (t parserType) String() string {
|
||||||
v = "node-float"
|
v = "node-float"
|
||||||
case nodeBool:
|
case nodeBool:
|
||||||
v = "node-bool"
|
v = "node-bool"
|
||||||
|
case nodeVar:
|
||||||
|
v = "node-var"
|
||||||
case nodeObj:
|
case nodeObj:
|
||||||
v = "node-obj"
|
v = "node-obj"
|
||||||
case nodeList:
|
case nodeList:
|
||||||
|
@ -253,7 +257,7 @@ func (p *Parser) parseFields() ([]*Field, int16, error) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if i > 500 {
|
if i > maxNested {
|
||||||
return nil, 0, errors.New("too many fields")
|
return nil, 0, errors.New("too many fields")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -465,8 +465,10 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
|
||||||
ex.Type = ValInt
|
ex.Type = ValInt
|
||||||
case nodeFloat:
|
case nodeFloat:
|
||||||
ex.Type = ValFloat
|
ex.Type = ValFloat
|
||||||
|
case nodeVar:
|
||||||
|
ex.Type = ValVar
|
||||||
default:
|
default:
|
||||||
fmt.Errorf("expecting an string, int or float")
|
fmt.Errorf("expecting a string, int, float or variable")
|
||||||
}
|
}
|
||||||
|
|
||||||
sel.Where = ex
|
sel.Where = ex
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
package serv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-pg/pg"
|
||||||
|
"github.com/valyala/fasttemplate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handleReq(ctx context.Context, w io.Writer, req *gqlReq) error {
|
||||||
|
qc, err := qcompile.CompileQuery(req.Query)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var sqlStmt strings.Builder
|
||||||
|
|
||||||
|
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
|
||||||
|
sqlStmt.Reset()
|
||||||
|
|
||||||
|
_, err = t.Execute(&sqlStmt, varMap(ctx, req.Vars))
|
||||||
|
|
||||||
|
if err == errNoUserID &&
|
||||||
|
authFailBlock == authFailBlockPerQuery &&
|
||||||
|
authCheck(ctx) == false {
|
||||||
|
return errUnauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
finalSQL := sqlStmt.String()
|
||||||
|
if conf.DebugLevel > 0 {
|
||||||
|
fmt.Println(finalSQL)
|
||||||
|
}
|
||||||
|
st := time.Now()
|
||||||
|
|
||||||
|
var root json.RawMessage
|
||||||
|
_, err = db.Query(pg.Scan(&root), finalSQL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
et := time.Now()
|
||||||
|
resp := gqlResp{Data: json.RawMessage(root)}
|
||||||
|
|
||||||
|
if conf.EnableTracing {
|
||||||
|
resp.Extensions = &extensions{newTrace(st, et, qc)}
|
||||||
|
}
|
||||||
|
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
return nil
|
||||||
|
}
|
158
serv/http.go
158
serv/http.go
|
@ -1,42 +1,41 @@
|
||||||
package serv
|
package serv
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/dosco/super-graph/qcode"
|
|
||||||
"github.com/go-pg/pg"
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/valyala/fasttemplate"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
maxReadBytes = 100000 // 100Kb
|
||||||
introspectionQuery = "IntrospectionQuery"
|
introspectionQuery = "IntrospectionQuery"
|
||||||
openVar = "{{"
|
openVar = "{{"
|
||||||
closeVar = "}}"
|
closeVar = "}}"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
upgrader = websocket.Upgrader{}
|
upgrader = websocket.Upgrader{}
|
||||||
errNoUserID = errors.New("no user_id available")
|
errNoUserID = errors.New("no user_id available")
|
||||||
|
errUnauthorized = errors.New("not authorized")
|
||||||
)
|
)
|
||||||
|
|
||||||
type gqlReq struct {
|
type gqlReq struct {
|
||||||
OpName string `json:"operationName"`
|
OpName string `json:"operationName"`
|
||||||
Query string `json:"query"`
|
Query string `json:"query"`
|
||||||
Variables map[string]string `json:"variables"`
|
Vars variables `json:"variables"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type variables map[string]interface{}
|
||||||
|
|
||||||
type gqlResp struct {
|
type gqlResp struct {
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
Data json.RawMessage `json:"data,omitempty"`
|
Data json.RawMessage `json:"data"`
|
||||||
Extensions *extensions `json:"extensions,omitempty"`
|
Extensions *extensions `json:"extensions,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,7 +72,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := ioutil.ReadAll(r.Body)
|
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
|
||||||
defer r.Body.Close()
|
defer r.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorResp(w, err)
|
errorResp(w, err)
|
||||||
|
@ -96,144 +95,13 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
qc, err := qcompile.CompileQuery(req.Query)
|
err = handleReq(ctx, w, req)
|
||||||
if err != nil {
|
|
||||||
errorResp(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var sqlStmt strings.Builder
|
if err == errUnauthorized {
|
||||||
|
|
||||||
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
|
|
||||||
errorResp(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
|
|
||||||
sqlStmt.Reset()
|
|
||||||
|
|
||||||
_, err = t.Execute(&sqlStmt, varValues(ctx))
|
|
||||||
if err == errNoUserID &&
|
|
||||||
authFailBlock == authFailBlockPerQuery &&
|
|
||||||
authCheck(ctx) == false {
|
|
||||||
http.Error(w, "Not authorized", 401)
|
http.Error(w, "Not authorized", 401)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
errorResp(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
finalSQL := sqlStmt.String()
|
|
||||||
if conf.DebugLevel > 0 {
|
|
||||||
fmt.Println(finalSQL)
|
|
||||||
}
|
|
||||||
st := time.Now()
|
|
||||||
|
|
||||||
var root json.RawMessage
|
|
||||||
_, err = db.Query(pg.Scan(&root), finalSQL)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorResp(w, err)
|
errorResp(w, err)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
et := time.Now()
|
|
||||||
resp := gqlResp{}
|
|
||||||
|
|
||||||
if conf.EnableTracing {
|
|
||||||
resp.Extensions = &extensions{newTrace(st, et, qc)}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.Data = json.RawMessage(root)
|
|
||||||
json.NewEncoder(w).Encode(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
|
|
||||||
c, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer c.Close()
|
|
||||||
for {
|
|
||||||
mt, message, err := c.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("read:", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
fmt.Printf("recv: %s", message)
|
|
||||||
err = c.WriteMessage(mt, message)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("write:", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func serve(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// if websocket.IsWebSocketUpgrade(r) {
|
|
||||||
// apiv1Ws(w, r)
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
apiv1Http(w, r)
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
func errorResp(w http.ResponseWriter, err error) {
|
|
||||||
b, _ := json.Marshal(gqlResp{Error: err.Error()})
|
|
||||||
http.Error(w, string(b), http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
func authCheck(ctx context.Context) bool {
|
|
||||||
return (ctx.Value(userIDKey) != nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func varValues(ctx context.Context) map[string]interface{} {
|
|
||||||
uidFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
|
|
||||||
if v := ctx.Value(userIDKey); v != nil {
|
|
||||||
return w.Write([]byte(v.(string)))
|
|
||||||
}
|
|
||||||
return 0, errNoUserID
|
|
||||||
})
|
|
||||||
|
|
||||||
uidpFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
|
|
||||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
|
||||||
return w.Write([]byte(v.(string)))
|
|
||||||
}
|
|
||||||
return 0, errNoUserID
|
|
||||||
})
|
|
||||||
|
|
||||||
return map[string]interface{}{
|
|
||||||
"USER_ID": uidFn,
|
|
||||||
"user_id": uidFn,
|
|
||||||
"USER_ID_PROVIDER": uidpFn,
|
|
||||||
"user_id_provider": uidpFn,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
|
|
||||||
du := et.Sub(et)
|
|
||||||
|
|
||||||
t := &trace{
|
|
||||||
Version: 1,
|
|
||||||
StartTime: st,
|
|
||||||
EndTime: et,
|
|
||||||
Duration: du,
|
|
||||||
Execution: execution{
|
|
||||||
[]resolver{
|
|
||||||
resolver{
|
|
||||||
Path: []string{qc.Query.Select.Table},
|
|
||||||
ParentType: "Query",
|
|
||||||
FieldName: qc.Query.Select.Table,
|
|
||||||
ReturnType: "object",
|
|
||||||
StartOffset: 1,
|
|
||||||
Duration: du,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
package serv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dosco/super-graph/qcode"
|
||||||
|
)
|
||||||
|
|
||||||
|
func errorResp(w http.ResponseWriter, err error) {
|
||||||
|
b, _ := json.Marshal(gqlResp{Error: err.Error()})
|
||||||
|
http.Error(w, string(b), http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func authCheck(ctx context.Context) bool {
|
||||||
|
return (ctx.Value(userIDKey) != nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
|
||||||
|
du := et.Sub(et)
|
||||||
|
|
||||||
|
t := &trace{
|
||||||
|
Version: 1,
|
||||||
|
StartTime: st,
|
||||||
|
EndTime: et,
|
||||||
|
Duration: du,
|
||||||
|
Execution: execution{
|
||||||
|
[]resolver{
|
||||||
|
resolver{
|
||||||
|
Path: []string{qc.Query.Select.Table},
|
||||||
|
ParentType: "Query",
|
||||||
|
FieldName: qc.Query.Select.Table,
|
||||||
|
ReturnType: "object",
|
||||||
|
StartOffset: 1,
|
||||||
|
Duration: du,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return t
|
||||||
|
}
|
|
@ -0,0 +1,53 @@
|
||||||
|
package serv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/valyala/fasttemplate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func varMap(ctx context.Context, vars variables) variables {
|
||||||
|
userIDFn := func(w io.Writer, _ string) (int, error) {
|
||||||
|
if v := ctx.Value(userIDKey); v != nil {
|
||||||
|
return w.Write([]byte(v.(string)))
|
||||||
|
}
|
||||||
|
return 0, errNoUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
userIDProviderFn := func(w io.Writer, _ string) (int, error) {
|
||||||
|
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||||
|
return w.Write([]byte(v.(string)))
|
||||||
|
}
|
||||||
|
return 0, errNoUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
userIDTag := fasttemplate.TagFunc(userIDFn)
|
||||||
|
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn)
|
||||||
|
|
||||||
|
vm := variables{
|
||||||
|
"user_id": userIDTag,
|
||||||
|
"user_id_provider": userIDProviderTag,
|
||||||
|
"USER_ID": userIDTag,
|
||||||
|
"USER_ID_PROVIDER": userIDProviderTag,
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range vars {
|
||||||
|
if _, ok := vm[k]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
vm[k] = val
|
||||||
|
case int:
|
||||||
|
vm[k] = strconv.Itoa(val)
|
||||||
|
case int64:
|
||||||
|
vm[k] = strconv.FormatInt(val, 64)
|
||||||
|
case float64:
|
||||||
|
vm[k] = fmt.Sprintf("%.0f", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return vm
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package serv
|
||||||
|
|
||||||
|
/*
|
||||||
|
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
|
||||||
|
c, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
for {
|
||||||
|
mt, message, err := c.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("read:", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fmt.Printf("recv: %s", message)
|
||||||
|
err = c.WriteMessage(mt, message)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("write:", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serve(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// if websocket.IsWebSocketUpgrade(r) {
|
||||||
|
// apiv1Ws(w, r)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
apiv1Http(w, r)
|
||||||
|
}
|
||||||
|
*/
|
Loading…
Reference in New Issue