Add support for GraphQL variables

This commit is contained in:
Vikram Rangnekar
2019-04-19 01:55:03 -04:00
parent 0755ecf6bd
commit 652b31ce38
10 changed files with 223 additions and 152 deletions

64
serv/core.go Normal file
View File

@ -0,0 +1,64 @@
package serv
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
"github.com/go-pg/pg"
"github.com/valyala/fasttemplate"
)
func handleReq(ctx context.Context, w io.Writer, req *gqlReq) error {
qc, err := qcompile.CompileQuery(req.Query)
if err != nil {
return err
}
var sqlStmt strings.Builder
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
return err
}
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
sqlStmt.Reset()
_, err = t.Execute(&sqlStmt, varMap(ctx, req.Vars))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(ctx) == false {
return errUnauthorized
}
if err != nil {
return err
}
finalSQL := sqlStmt.String()
if conf.DebugLevel > 0 {
fmt.Println(finalSQL)
}
st := time.Now()
var root json.RawMessage
_, err = db.Query(pg.Scan(&root), finalSQL)
if err != nil {
return err
}
et := time.Now()
resp := gqlResp{Data: json.RawMessage(root)}
if conf.EnableTracing {
resp.Extensions = &extensions{newTrace(st, et, qc)}
}
json.NewEncoder(w).Encode(resp)
return nil
}

View File

@ -1,42 +1,41 @@
package serv
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/gorilla/websocket"
"github.com/valyala/fasttemplate"
)
const (
maxReadBytes = 100000 // 100Kb
introspectionQuery = "IntrospectionQuery"
openVar = "{{"
closeVar = "}}"
)
var (
upgrader = websocket.Upgrader{}
errNoUserID = errors.New("no user_id available")
upgrader = websocket.Upgrader{}
errNoUserID = errors.New("no user_id available")
errUnauthorized = errors.New("not authorized")
)
type gqlReq struct {
OpName string `json:"operationName"`
Query string `json:"query"`
Variables map[string]string `json:"variables"`
OpName string `json:"operationName"`
Query string `json:"query"`
Vars variables `json:"variables"`
}
type variables map[string]interface{}
type gqlResp struct {
Error string `json:"error,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
Data json.RawMessage `json:"data"`
Extensions *extensions `json:"extensions,omitempty"`
}
@ -73,7 +72,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
return
}
b, err := ioutil.ReadAll(r.Body)
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
defer r.Body.Close()
if err != nil {
errorResp(w, err)
@ -96,144 +95,13 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
return
}
qc, err := qcompile.CompileQuery(req.Query)
if err != nil {
errorResp(w, err)
return
}
err = handleReq(ctx, w, req)
var sqlStmt strings.Builder
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
errorResp(w, err)
return
}
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
sqlStmt.Reset()
_, err = t.Execute(&sqlStmt, varValues(ctx))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(ctx) == false {
if err == errUnauthorized {
http.Error(w, "Not authorized", 401)
return
}
if err != nil {
errorResp(w, err)
return
}
finalSQL := sqlStmt.String()
if conf.DebugLevel > 0 {
fmt.Println(finalSQL)
}
st := time.Now()
var root json.RawMessage
_, err = db.Query(pg.Scan(&root), finalSQL)
if err != nil {
errorResp(w, err)
return
}
et := time.Now()
resp := gqlResp{}
if conf.EnableTracing {
resp.Extensions = &extensions{newTrace(st, et, qc)}
}
resp.Data = json.RawMessage(root)
json.NewEncoder(w).Encode(resp)
}
/*
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
fmt.Println("read:", err)
break
}
fmt.Printf("recv: %s", message)
err = c.WriteMessage(mt, message)
if err != nil {
fmt.Println("write:", err)
break
}
}
}
func serve(w http.ResponseWriter, r *http.Request) {
// if websocket.IsWebSocketUpgrade(r) {
// apiv1Ws(w, r)
// return
// }
apiv1Http(w, r)
}
*/
func errorResp(w http.ResponseWriter, err error) {
b, _ := json.Marshal(gqlResp{Error: err.Error()})
http.Error(w, string(b), http.StatusBadRequest)
}
func authCheck(ctx context.Context) bool {
return (ctx.Value(userIDKey) != nil)
}
func varValues(ctx context.Context) map[string]interface{} {
uidFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
})
uidpFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDProviderKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
})
return map[string]interface{}{
"USER_ID": uidFn,
"user_id": uidFn,
"USER_ID_PROVIDER": uidpFn,
"user_id_provider": uidpFn,
}
}
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
du := et.Sub(et)
t := &trace{
Version: 1,
StartTime: st,
EndTime: et,
Duration: du,
Execution: execution{
[]resolver{
resolver{
Path: []string{qc.Query.Select.Table},
ParentType: "Query",
FieldName: qc.Query.Select.Table,
ReturnType: "object",
StartOffset: 1,
Duration: du,
},
},
},
}
return t
}

44
serv/utils.go Normal file
View File

@ -0,0 +1,44 @@
package serv
import (
"context"
"encoding/json"
"net/http"
"time"
"github.com/dosco/super-graph/qcode"
)
func errorResp(w http.ResponseWriter, err error) {
b, _ := json.Marshal(gqlResp{Error: err.Error()})
http.Error(w, string(b), http.StatusBadRequest)
}
func authCheck(ctx context.Context) bool {
return (ctx.Value(userIDKey) != nil)
}
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
du := et.Sub(et)
t := &trace{
Version: 1,
StartTime: st,
EndTime: et,
Duration: du,
Execution: execution{
[]resolver{
resolver{
Path: []string{qc.Query.Select.Table},
ParentType: "Query",
FieldName: qc.Query.Select.Table,
ReturnType: "object",
StartOffset: 1,
Duration: du,
},
},
},
}
return t
}

53
serv/vars.go Normal file
View File

@ -0,0 +1,53 @@
package serv
import (
"context"
"fmt"
"io"
"strconv"
"github.com/valyala/fasttemplate"
)
func varMap(ctx context.Context, vars variables) variables {
userIDFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
}
userIDProviderFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDProviderKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
}
userIDTag := fasttemplate.TagFunc(userIDFn)
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn)
vm := variables{
"user_id": userIDTag,
"user_id_provider": userIDProviderTag,
"USER_ID": userIDTag,
"USER_ID_PROVIDER": userIDProviderTag,
}
for k, v := range vars {
if _, ok := vm[k]; ok {
continue
}
switch val := v.(type) {
case string:
vm[k] = val
case int:
vm[k] = strconv.Itoa(val)
case int64:
vm[k] = strconv.FormatInt(val, 64)
case float64:
vm[k] = fmt.Sprintf("%.0f", val)
}
}
return vm
}

33
serv/ws.go Normal file
View File

@ -0,0 +1,33 @@
package serv
/*
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
fmt.Println("read:", err)
break
}
fmt.Printf("recv: %s", message)
err = c.WriteMessage(mt, message)
if err != nil {
fmt.Println("write:", err)
break
}
}
}
func serve(w http.ResponseWriter, r *http.Request) {
// if websocket.IsWebSocketUpgrade(r) {
// apiv1Ws(w, r)
// return
// }
apiv1Http(w, r)
}
*/