Add support for GraphQL variables
This commit is contained in:
64
serv/core.go
Normal file
64
serv/core.go
Normal file
@ -0,0 +1,64 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
func handleReq(ctx context.Context, w io.Writer, req *gqlReq) error {
|
||||
qc, err := qcompile.CompileQuery(req.Query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sqlStmt strings.Builder
|
||||
|
||||
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
|
||||
sqlStmt.Reset()
|
||||
|
||||
_, err = t.Execute(&sqlStmt, varMap(ctx, req.Vars))
|
||||
|
||||
if err == errNoUserID &&
|
||||
authFailBlock == authFailBlockPerQuery &&
|
||||
authCheck(ctx) == false {
|
||||
return errUnauthorized
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finalSQL := sqlStmt.String()
|
||||
if conf.DebugLevel > 0 {
|
||||
fmt.Println(finalSQL)
|
||||
}
|
||||
st := time.Now()
|
||||
|
||||
var root json.RawMessage
|
||||
_, err = db.Query(pg.Scan(&root), finalSQL)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
et := time.Now()
|
||||
resp := gqlResp{Data: json.RawMessage(root)}
|
||||
|
||||
if conf.EnableTracing {
|
||||
resp.Extensions = &extensions{newTrace(st, et, qc)}
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
return nil
|
||||
}
|
158
serv/http.go
158
serv/http.go
@ -1,42 +1,41 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
const (
|
||||
maxReadBytes = 100000 // 100Kb
|
||||
introspectionQuery = "IntrospectionQuery"
|
||||
openVar = "{{"
|
||||
closeVar = "}}"
|
||||
)
|
||||
|
||||
var (
|
||||
upgrader = websocket.Upgrader{}
|
||||
errNoUserID = errors.New("no user_id available")
|
||||
upgrader = websocket.Upgrader{}
|
||||
errNoUserID = errors.New("no user_id available")
|
||||
errUnauthorized = errors.New("not authorized")
|
||||
)
|
||||
|
||||
type gqlReq struct {
|
||||
OpName string `json:"operationName"`
|
||||
Query string `json:"query"`
|
||||
Variables map[string]string `json:"variables"`
|
||||
OpName string `json:"operationName"`
|
||||
Query string `json:"query"`
|
||||
Vars variables `json:"variables"`
|
||||
}
|
||||
|
||||
type variables map[string]interface{}
|
||||
|
||||
type gqlResp struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Extensions *extensions `json:"extensions,omitempty"`
|
||||
}
|
||||
|
||||
@ -73,7 +72,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(r.Body)
|
||||
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
|
||||
defer r.Body.Close()
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
@ -96,144 +95,13 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
qc, err := qcompile.CompileQuery(req.Query)
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
err = handleReq(ctx, w, req)
|
||||
|
||||
var sqlStmt strings.Builder
|
||||
|
||||
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
|
||||
sqlStmt.Reset()
|
||||
|
||||
_, err = t.Execute(&sqlStmt, varValues(ctx))
|
||||
if err == errNoUserID &&
|
||||
authFailBlock == authFailBlockPerQuery &&
|
||||
authCheck(ctx) == false {
|
||||
if err == errUnauthorized {
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
finalSQL := sqlStmt.String()
|
||||
if conf.DebugLevel > 0 {
|
||||
fmt.Println(finalSQL)
|
||||
}
|
||||
st := time.Now()
|
||||
|
||||
var root json.RawMessage
|
||||
_, err = db.Query(pg.Scan(&root), finalSQL)
|
||||
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
et := time.Now()
|
||||
resp := gqlResp{}
|
||||
|
||||
if conf.EnableTracing {
|
||||
resp.Extensions = &extensions{newTrace(st, et, qc)}
|
||||
}
|
||||
|
||||
resp.Data = json.RawMessage(root)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
/*
|
||||
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
fmt.Println("read:", err)
|
||||
break
|
||||
}
|
||||
fmt.Printf("recv: %s", message)
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
fmt.Println("write:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func serve(w http.ResponseWriter, r *http.Request) {
|
||||
// if websocket.IsWebSocketUpgrade(r) {
|
||||
// apiv1Ws(w, r)
|
||||
// return
|
||||
// }
|
||||
apiv1Http(w, r)
|
||||
}
|
||||
*/
|
||||
|
||||
func errorResp(w http.ResponseWriter, err error) {
|
||||
b, _ := json.Marshal(gqlResp{Error: err.Error()})
|
||||
http.Error(w, string(b), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func authCheck(ctx context.Context) bool {
|
||||
return (ctx.Value(userIDKey) != nil)
|
||||
}
|
||||
|
||||
func varValues(ctx context.Context) map[string]interface{} {
|
||||
uidFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
})
|
||||
|
||||
uidpFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
})
|
||||
|
||||
return map[string]interface{}{
|
||||
"USER_ID": uidFn,
|
||||
"user_id": uidFn,
|
||||
"USER_ID_PROVIDER": uidpFn,
|
||||
"user_id_provider": uidpFn,
|
||||
}
|
||||
}
|
||||
|
||||
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
|
||||
du := et.Sub(et)
|
||||
|
||||
t := &trace{
|
||||
Version: 1,
|
||||
StartTime: st,
|
||||
EndTime: et,
|
||||
Duration: du,
|
||||
Execution: execution{
|
||||
[]resolver{
|
||||
resolver{
|
||||
Path: []string{qc.Query.Select.Table},
|
||||
ParentType: "Query",
|
||||
FieldName: qc.Query.Select.Table,
|
||||
ReturnType: "object",
|
||||
StartOffset: 1,
|
||||
Duration: du,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
44
serv/utils.go
Normal file
44
serv/utils.go
Normal file
@ -0,0 +1,44 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
)
|
||||
|
||||
func errorResp(w http.ResponseWriter, err error) {
|
||||
b, _ := json.Marshal(gqlResp{Error: err.Error()})
|
||||
http.Error(w, string(b), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func authCheck(ctx context.Context) bool {
|
||||
return (ctx.Value(userIDKey) != nil)
|
||||
}
|
||||
|
||||
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
|
||||
du := et.Sub(et)
|
||||
|
||||
t := &trace{
|
||||
Version: 1,
|
||||
StartTime: st,
|
||||
EndTime: et,
|
||||
Duration: du,
|
||||
Execution: execution{
|
||||
[]resolver{
|
||||
resolver{
|
||||
Path: []string{qc.Query.Select.Table},
|
||||
ParentType: "Query",
|
||||
FieldName: qc.Query.Select.Table,
|
||||
ReturnType: "object",
|
||||
StartOffset: 1,
|
||||
Duration: du,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
53
serv/vars.go
Normal file
53
serv/vars.go
Normal file
@ -0,0 +1,53 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
func varMap(ctx context.Context, vars variables) variables {
|
||||
userIDFn := func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
}
|
||||
|
||||
userIDProviderFn := func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
}
|
||||
|
||||
userIDTag := fasttemplate.TagFunc(userIDFn)
|
||||
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn)
|
||||
|
||||
vm := variables{
|
||||
"user_id": userIDTag,
|
||||
"user_id_provider": userIDProviderTag,
|
||||
"USER_ID": userIDTag,
|
||||
"USER_ID_PROVIDER": userIDProviderTag,
|
||||
}
|
||||
|
||||
for k, v := range vars {
|
||||
if _, ok := vm[k]; ok {
|
||||
continue
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
vm[k] = val
|
||||
case int:
|
||||
vm[k] = strconv.Itoa(val)
|
||||
case int64:
|
||||
vm[k] = strconv.FormatInt(val, 64)
|
||||
case float64:
|
||||
vm[k] = fmt.Sprintf("%.0f", val)
|
||||
}
|
||||
}
|
||||
return vm
|
||||
}
|
33
serv/ws.go
Normal file
33
serv/ws.go
Normal file
@ -0,0 +1,33 @@
|
||||
package serv
|
||||
|
||||
/*
|
||||
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
fmt.Println("read:", err)
|
||||
break
|
||||
}
|
||||
fmt.Printf("recv: %s", message)
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
fmt.Println("write:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func serve(w http.ResponseWriter, r *http.Request) {
|
||||
// if websocket.IsWebSocketUpgrade(r) {
|
||||
// apiv1Ws(w, r)
|
||||
// return
|
||||
// }
|
||||
apiv1Http(w, r)
|
||||
}
|
||||
*/
|
Reference in New Issue
Block a user