super-graph/serv/http.go
2019-04-08 02:47:59 -04:00

240 lines
4.8 KiB
Go

package serv
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/gorilla/websocket"
"github.com/valyala/fasttemplate"
)
// introspectionQuery is the operationName that marks a schema-introspection
// request. apiv1Http matches it case-insensitively and answers with the
// contents of test.schema instead of compiling the query to SQL.
const introspectionQuery = "IntrospectionQuery"

// openVar and closeVar delimit the placeholders (for example {{user_id}})
// that may appear in the compiled SQL; apiv1Http expands them per request
// with fasttemplate using the tag map from varValues.
const (
	openVar  = "{{"
	closeVar = "}}"
)
// upgrader is the websocket upgrader with default settings; in this file it
// is referenced only by the commented-out apiv1Ws handler.
var upgrader websocket.Upgrader

// errNoUserID is returned by the varValues tag functions when the request
// context lacks the value a SQL placeholder needs. apiv1Http compares against
// it to decide whether to answer 401 under per-query auth.
var errNoUserID = errors.New("no user_id available")
// gqlReq is the JSON request body of a GraphQL request, decoded by apiv1Http
// with json.Unmarshal.
type gqlReq struct {
	// OpName is the client-supplied operation name; apiv1Http compares it
	// case-insensitively with introspectionQuery.
	OpName string `json:"operationName"`
	// Query is the GraphQL query text that apiv1Http compiles to SQL.
	Query string `json:"query"`
	// Variables holds the request's GraphQL variables; apiv1Http does not
	// read it. NOTE(review): with a string value type, a body whose
	// variables contain numbers, booleans or objects makes json.Unmarshal
	// fail, and apiv1Http rejects the whole request — confirm intended.
	Variables map[string]string `json:"variables"`
}
// gqlResp is the JSON document written back to the client: apiv1Http fills
// Data (and, when tracing is enabled, Extensions); errorResp fills Error.
// NOTE(review): the GraphQL spec reports failures as an "errors" array of
// objects, whereas this type uses a single "error" string — confirm the
// clients of this server expect that shape.
type gqlResp struct {
	// Error is the error message; omitted when empty.
	Error string `json:"error,omitempty"`
	// Data is the raw JSON produced by the SQL query, embedded without
	// being decoded into Go values; omitted when empty.
	Data json.RawMessage `json:"data,omitempty"`
	// Extensions carries optional tracing information; omitted when nil.
	Extensions *extensions `json:"extensions,omitempty"`
}
// extensions is the "extensions" member of a gqlResp.
type extensions struct {
	// Tracing is set by apiv1Http (via newTrace) when conf.EnableTracing is on.
	Tracing *trace `json:"tracing,omitempty"`
}

// trace records the timing of one request. NOTE(review): the field names
// appear to follow the Apollo Tracing extension (version 1) — confirm.
// encoding/json emits time.Time fields as RFC 3339 strings and, because
// time.Duration has no JSON methods, Duration fields as integer nanoseconds.
// Field order here is the key order of the emitted JSON.
type trace struct {
	Version   int           `json:"version"`
	StartTime time.Time     `json:"startTime"`
	EndTime   time.Time     `json:"endTime"`
	Duration  time.Duration `json:"duration"`
	Execution execution     `json:"execution"`
}

// execution lists the per-resolver timings of a traced request.
type execution struct {
	Resolvers []resolver `json:"resolvers"`
}

// resolver is the timing record for one resolved field.
type resolver struct {
	Path       []string `json:"path"`
	ParentType string   `json:"parentType"`
	FieldName  string   `json:"fieldName"`
	ReturnType string   `json:"returnType"`
	// StartOffset is presumably the resolver's start relative to
	// trace.StartTime; its unit is not established here and newTrace
	// currently always sets 1. TODO confirm.
	StartOffset int           `json:"startOffset"`
	Duration    time.Duration `json:"duration"`
}
// apiv1Http serves a single GraphQL request sent over HTTP.
//
// The JSON request body is decoded into a gqlReq. Requests whose operation
// name is introspectionQuery are answered with the contents of the file
// "test.schema" (resolved against the process working directory). Every
// other query is compiled to SQL, its {{...}} placeholders are filled from
// the request context (see varValues), the SQL is executed and the resulting
// JSON document is returned inside a gqlResp with Content-Type
// application/json.
//
// Authorization follows authFailBlock: with authFailBlockAlways a request
// without a user ID is rejected up front; with authFailBlockPerQuery only a
// query whose SQL needs the missing user ID is rejected. Both answer 401.
// Other failures are reported through errorResp.
func apiv1Http(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	if authFailBlock == authFailBlockAlways && !authCheck(ctx) {
		http.Error(w, "Not authorized", http.StatusUnauthorized)
		return
	}

	b, err := ioutil.ReadAll(r.Body)
	defer r.Body.Close()
	if err != nil {
		errorResp(w, err)
		return
	}

	req := &gqlReq{}
	if err := json.Unmarshal(b, req); err != nil {
		errorResp(w, err)
		return
	}

	// Introspection is not compiled to SQL; the schema is served from disk.
	if strings.EqualFold(req.OpName, introspectionQuery) {
		dat, err := ioutil.ReadFile("test.schema")
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		w.Write(dat)
		return
	}

	qc, err := qcompile.CompileQuery(req.Query)
	if err != nil {
		errorResp(w, err)
		return
	}

	var sqlStmt strings.Builder
	if err := pcompile.Compile(&sqlStmt, qc); err != nil {
		errorResp(w, err)
		return
	}

	// The compiled SQL may contain {{...}} placeholders (user_id,
	// user_id_provider, ...). Parse it as a template, then reuse the builder
	// for the expanded output.
	t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
	sqlStmt.Reset()

	_, err = t.Execute(&sqlStmt, varValues(ctx))
	// A placeholder needed a user ID the context does not carry: under
	// per-query auth that means "not authorized" rather than a bad request.
	if err == errNoUserID &&
		authFailBlock == authFailBlockPerQuery &&
		!authCheck(ctx) {
		http.Error(w, "Not authorized", http.StatusUnauthorized)
		return
	}
	if err != nil {
		errorResp(w, err)
		return
	}

	finalSQL := sqlStmt.String()
	if conf.DebugLevel > 0 {
		fmt.Println(finalSQL)
	}

	st := time.Now()

	var root json.RawMessage
	if _, err := db.Query(pg.Scan(&root), finalSQL); err != nil {
		errorResp(w, err)
		return
	}

	et := time.Now()

	resp := gqlResp{Data: root}
	if conf.EnableTracing {
		resp.Extensions = &extensions{Tracing: newTrace(st, et, qc)}
	}

	// Marshal before writing anything so a failure can still be reported
	// with a proper status instead of a silently empty 200 response.
	out, err := json.Marshal(resp)
	if err != nil {
		errorResp(w, err)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// The trailing newline matches what json.Encoder used to emit. A write
	// error here means the client went away; there is no one left to tell.
	_, _ = w.Write(append(out, '\n'))
}
/*
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
fmt.Println("read:", err)
break
}
fmt.Printf("recv: %s", message)
err = c.WriteMessage(mt, message)
if err != nil {
fmt.Println("write:", err)
break
}
}
}
func serve(w http.ResponseWriter, r *http.Request) {
// if websocket.IsWebSocketUpgrade(r) {
// apiv1Ws(w, r)
// return
// }
apiv1Http(w, r)
}
*/
// errorResp reports err to the client as a gqlResp{Error: err.Error()} with
// status 400 Bad Request.
//
// The body is byte-for-byte what http.Error would have sent (the JSON plus a
// trailing newline), but it is labelled application/json instead of
// text/plain so JSON clients can parse it.
func errorResp(w http.ResponseWriter, err error) {
	h := w.Header()
	h.Set("Content-Type", "application/json")
	// Kept from http.Error's behavior: forbid content-type sniffing.
	h.Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(http.StatusBadRequest)

	// A gqlResp holding only a string cannot fail to marshal; a write error
	// means the client is gone and there is nobody left to report it to.
	_ = json.NewEncoder(w).Encode(gqlResp{Error: err.Error()})
}
// authCheck reports whether ctx holds a non-nil value under userIDKey.
// apiv1Http uses this as its test for "the request is authenticated".
func authCheck(ctx context.Context) bool {
	uid := ctx.Value(userIDKey)
	return uid != nil
}
// varValues returns the fasttemplate tag map used by apiv1Http to expand the
// placeholders in compiled SQL from values stored on ctx. Upper- and
// lower-case spellings of each tag are both accepted:
//
//   - USER_ID, user_id                   -> ctx.Value(userIDKey)
//   - USER_ID_PROVIDER, user_id_provider -> ctx.Value(userIDProviderKey)
//
// When the context value is missing the tag function fails with
// errNoUserID. A present value is asserted to be a string (anything else
// panics) and is written into the SQL verbatim.
//
// NOTE(review): the value is spliced into SQL text without escaping; confirm
// the compiled template quotes it and that the ID source can never contain
// quote characters.
func varValues(ctx context.Context) map[string]interface{} {
	// fromCtx builds a tag function that writes the string stored under key.
	fromCtx := func(key interface{}) fasttemplate.TagFunc {
		return func(w io.Writer, _ string) (int, error) {
			v := ctx.Value(key)
			if v == nil {
				return 0, errNoUserID
			}
			return w.Write([]byte(v.(string)))
		}
	}

	userID := fromCtx(userIDKey)
	provider := fromCtx(userIDProviderKey)

	return map[string]interface{}{
		"USER_ID":          userID,
		"user_id":          userID,
		"USER_ID_PROVIDER": provider,
		"user_id_provider": provider,
	}
}
// newTrace builds the tracing payload for a request whose measured work
// started at st and ended at et (apiv1Http measures around the SQL query).
// The whole query is reported as one resolver named after the root table
// selected by qc, with the same duration as the overall trace.
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
	// Elapsed time from start to end; must be et-st, not et-et (always 0).
	du := et.Sub(st)
	table := qc.Query.Select.Table

	return &trace{
		Version:   1,
		StartTime: st,
		EndTime:   et,
		Duration:  du,
		Execution: execution{
			Resolvers: []resolver{{
				Path:       []string{table},
				ParentType: "Query",
				FieldName:  table,
				ReturnType: "object",
				// NOTE(review): hard-coded offset kept as-is; its intended
				// unit and meaning are not established here.
				StartOffset: 1,
				Duration:    du,
			}},
		},
	}
}