super-graph/serv/core.go

490 lines
8.7 KiB
Go
Raw Normal View History

2019-04-19 07:55:03 +02:00
package serv
import (
2019-05-13 01:27:26 +02:00
"bytes"
2019-04-19 07:55:03 +02:00
"context"
"encoding/json"
2019-06-02 01:48:42 +02:00
"errors"
2019-06-04 16:54:51 +02:00
"fmt"
2019-04-19 07:55:03 +02:00
"io"
2019-05-13 01:27:26 +02:00
"net/http"
"sync"
2019-04-19 07:55:03 +02:00
"time"
2019-05-13 01:27:26 +02:00
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
2019-09-05 06:09:56 +02:00
"github.com/dosco/super-graph/psql"
2019-04-20 06:35:57 +02:00
"github.com/dosco/super-graph/qcode"
2019-04-19 07:55:03 +02:00
"github.com/valyala/fasttemplate"
)
2019-05-13 01:27:26 +02:00
// empty is a named empty-string constant.
const empty = ""
2019-05-13 01:27:26 +02:00
// coreContext carries the state of a single GraphQL request while it is
// executed and rendered.
//
// NOTE(review): embedding context.Context in a struct goes against the usual
// Go guidance of passing ctx explicitly; the methods of this type rely on it
// by passing c itself wherever a context is needed (db.Begin(c), c.Value(...)).
type coreContext struct {
// req is the GraphQL request being executed; handleReq records the HTTP
// Referer and headers on it before execution.
req gqlReq
// res is the response envelope that render JSON-encodes; addTrace stores
// tracing data in its Extensions field.
res gqlResp
context.Context
}
func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
2019-09-20 06:19:11 +02:00
c.req.ref = req.Referer()
c.req.hdr = req.Header
b, err := c.execQuery()
if err != nil {
return err
}
return c.render(w, b)
}
func (c *coreContext) execQuery() ([]byte, error) {
2019-04-20 06:35:57 +02:00
var err error
2019-07-29 07:13:33 +02:00
var skipped uint32
var qc *qcode.QCode
var data []byte
2019-04-19 07:55:03 +02:00
2019-07-29 07:13:33 +02:00
if conf.UseAllowList {
var ps *preparedItem
2019-09-05 06:09:56 +02:00
data, ps, err = c.resolvePreparedSQL(c.req.Query)
2019-07-29 07:13:33 +02:00
if err != nil {
2019-09-20 06:19:11 +02:00
return nil, err
2019-07-29 07:13:33 +02:00
}
skipped = ps.skipped
qc = ps.qc
} else {
2019-09-05 06:09:56 +02:00
qc, err = qcompile.Compile([]byte(c.req.Query))
2019-07-29 07:13:33 +02:00
if err != nil {
2019-09-20 06:19:11 +02:00
return nil, err
2019-07-29 07:13:33 +02:00
}
data, skipped, err = c.resolveSQL(qc)
if err != nil {
2019-09-20 06:19:11 +02:00
return nil, err
2019-07-29 07:13:33 +02:00
}
2019-05-13 01:27:26 +02:00
}
2019-04-19 07:55:03 +02:00
2019-05-13 01:27:26 +02:00
if len(data) == 0 || skipped == 0 {
2019-09-20 06:19:11 +02:00
return data, nil
2019-05-13 01:27:26 +02:00
}
2019-04-19 07:55:03 +02:00
2019-09-05 06:09:56 +02:00
sel := qc.Selects
2019-05-13 01:27:26 +02:00
h := xxhash.New()
// fetch the field name used within the db response json
// that are used to mark insertion points and the mapping between
// those field names and their select objects
fids, sfmap := parentFieldIds(h, sel, skipped)
// fetch the field values of the marked insertion points
// these values contain the id to be used with fetching remote data
from := jsn.Get(data, fids)
2019-06-02 01:48:42 +02:00
var to []jsn.Field
switch {
case len(from) == 1:
2019-09-20 06:19:11 +02:00
to, err = c.resolveRemote(c.req.hdr, h, from[0], sel, sfmap)
2019-06-02 01:48:42 +02:00
case len(from) > 1:
2019-09-20 06:19:11 +02:00
to, err = c.resolveRemotes(c.req.hdr, h, from, sel, sfmap)
2019-06-02 01:48:42 +02:00
default:
2019-09-20 06:19:11 +02:00
return nil, errors.New("something wrong no remote ids found in db response")
2019-06-02 01:48:42 +02:00
}
2019-06-04 16:54:51 +02:00
if err != nil {
2019-09-20 06:19:11 +02:00
return nil, err
2019-06-04 16:54:51 +02:00
}
2019-06-02 01:48:42 +02:00
var ob bytes.Buffer
err = jsn.Replace(&ob, data, from, to)
if err != nil {
2019-09-20 06:19:11 +02:00
return nil, err
2019-06-02 01:48:42 +02:00
}
2019-09-20 06:19:11 +02:00
return ob.Bytes(), nil
2019-06-02 01:48:42 +02:00
}
func (c *coreContext) resolveRemote(
2019-09-20 06:19:11 +02:00
hdr http.Header,
2019-06-02 01:48:42 +02:00
h *xxhash.Digest,
field jsn.Field,
sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
// replacement data for the marked insertion points
// key and value will be replaced by whats below
toA := [1]jsn.Field{}
to := toA[:1]
2019-06-02 01:48:42 +02:00
// use the json key to find the related Select object
k1 := xxhash.Sum64(field.Key)
s, ok := sfmap[k1]
if !ok {
return nil, nil
}
p := sel[s.ParentID]
// then use the Table nme in the Select and it's parent
// to find the resolver to use for this relationship
k2 := mkkey(h, s.Table, p.Table)
r, ok := rmap[k2]
if !ok {
return nil, nil
}
id := jsn.Value(field.Value)
if len(id) == 0 {
return nil, nil
}
st := time.Now()
2019-09-20 06:19:11 +02:00
b, err := r.Fn(hdr, id)
2019-06-02 01:48:42 +02:00
if err != nil {
return nil, err
}
if conf.EnableTracing {
c.addTrace(sel, s.ID, st)
2019-06-02 01:48:42 +02:00
}
if len(r.Path) != 0 {
b = jsn.Strip(b, r.Path)
}
var ob bytes.Buffer
if len(s.Cols) != 0 {
err = jsn.Filter(&ob, b, colsToList(s.Cols))
if err != nil {
return nil, err
}
} else {
ob.WriteString("null")
}
to[0] = jsn.Field{[]byte(s.FieldName), ob.Bytes()}
return to, nil
}
func (c *coreContext) resolveRemotes(
2019-09-20 06:19:11 +02:00
hdr http.Header,
2019-06-02 01:48:42 +02:00
h *xxhash.Digest,
from []jsn.Field,
sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
2019-05-13 01:27:26 +02:00
// replacement data for the marked insertion points
// key and value will be replaced by whats below
to := make([]jsn.Field, len(from))
2019-05-13 01:27:26 +02:00
var wg sync.WaitGroup
wg.Add(len(from))
var cerr error
for i, id := range from {
2019-06-02 01:48:42 +02:00
2019-05-13 01:27:26 +02:00
// use the json key to find the related Select object
k1 := xxhash.Sum64(id.Key)
s, ok := sfmap[k1]
if !ok {
2019-06-02 01:48:42 +02:00
return nil, nil
2019-04-20 06:35:57 +02:00
}
2019-05-13 01:27:26 +02:00
p := sel[s.ParentID]
2019-04-19 07:55:03 +02:00
2019-05-13 01:27:26 +02:00
// then use the Table nme in the Select and it's parent
// to find the resolver to use for this relationship
k2 := mkkey(h, s.Table, p.Table)
2019-04-20 06:35:57 +02:00
2019-05-13 01:27:26 +02:00
r, ok := rmap[k2]
if !ok {
2019-06-02 01:48:42 +02:00
return nil, nil
2019-05-13 01:27:26 +02:00
}
2019-04-20 06:35:57 +02:00
2019-05-13 01:27:26 +02:00
id := jsn.Value(id.Value)
if len(id) == 0 {
2019-06-02 01:48:42 +02:00
return nil, nil
2019-04-20 06:35:57 +02:00
}
go func(n int, id []byte, s *qcode.Select) {
defer wg.Done()
2019-04-20 06:35:57 +02:00
st := time.Now()
2019-04-20 06:35:57 +02:00
2019-09-20 06:19:11 +02:00
b, err := r.Fn(hdr, id)
if err != nil {
2019-06-04 16:54:51 +02:00
cerr = fmt.Errorf("%s: %s", s.Table, err)
return
}
2019-05-13 01:27:26 +02:00
if conf.EnableTracing {
c.addTrace(sel, s.ID, st)
}
2019-05-13 01:27:26 +02:00
if len(r.Path) != 0 {
b = jsn.Strip(b, r.Path)
2019-05-13 06:05:08 +02:00
}
var ob bytes.Buffer
if len(s.Cols) != 0 {
err = jsn.Filter(&ob, b, colsToList(s.Cols))
if err != nil {
2019-06-04 16:54:51 +02:00
cerr = fmt.Errorf("%s: %s", s.Table, err)
return
}
2019-05-13 01:27:26 +02:00
} else {
ob.WriteString("null")
}
to[n] = jsn.Field{[]byte(s.FieldName), ob.Bytes()}
}(i, id, s)
2019-05-13 01:27:26 +02:00
}
wg.Wait()
return to, cerr
2019-05-13 01:27:26 +02:00
}
2019-09-05 06:09:56 +02:00
func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) {
ps, ok := _preparedList[gqlHash(gql, c.req.Vars)]
2019-07-29 07:13:33 +02:00
if !ok {
return nil, nil, errUnauthorized
}
2019-05-13 01:27:26 +02:00
2019-09-26 06:35:31 +02:00
var root []byte
2019-07-29 07:13:33 +02:00
vars := varList(c, ps.args)
2019-09-26 06:35:31 +02:00
tx, err := db.Begin(c)
2019-07-29 07:13:33 +02:00
if err != nil {
return nil, nil, err
}
2019-09-26 06:35:31 +02:00
defer tx.Rollback(c)
2019-09-08 07:54:38 +02:00
if v := c.Value(userIDKey); v != nil {
2019-09-26 06:35:31 +02:00
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
2019-09-08 07:54:38 +02:00
if err != nil {
return nil, nil, err
}
}
2019-09-26 06:35:31 +02:00
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
2019-09-08 07:54:38 +02:00
if err != nil {
return nil, nil, err
}
2019-09-26 06:35:31 +02:00
if err := tx.Commit(c); err != nil {
2019-09-08 07:54:38 +02:00
return nil, nil, err
}
2019-09-05 06:09:56 +02:00
fmt.Printf("PRE: %v\n", ps.stmt)
2019-05-13 01:27:26 +02:00
2019-09-26 06:35:31 +02:00
return root, ps, nil
2019-07-29 07:13:33 +02:00
}
2019-05-13 01:27:26 +02:00
2019-09-26 06:35:31 +02:00
func (c *coreContext) resolveSQL(qc *qcode.QCode) ([]byte, uint32, error) {
var vars map[string]json.RawMessage
2019-06-02 01:48:42 +02:00
stmt := &bytes.Buffer{}
2019-09-26 06:35:31 +02:00
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, 0, err
}
2019-09-05 06:09:56 +02:00
}
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
2019-05-13 01:27:26 +02:00
if err != nil {
return nil, 0, err
}
2019-06-02 01:48:42 +02:00
t := fasttemplate.New(stmt.String(), openVar, closeVar)
2019-05-13 01:27:26 +02:00
2019-06-02 01:48:42 +02:00
stmt.Reset()
2019-09-05 06:09:56 +02:00
_, err = t.ExecuteFunc(stmt, varMap(c))
2019-05-13 01:27:26 +02:00
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(c) == false {
return nil, 0, errUnauthorized
2019-04-19 07:55:03 +02:00
}
2019-05-13 01:27:26 +02:00
if err != nil {
return nil, 0, err
}
2019-06-02 01:48:42 +02:00
finalSQL := stmt.String()
2019-05-13 01:27:26 +02:00
2019-09-26 06:35:31 +02:00
// if conf.LogLevel == "debug" {
// os.Stdout.WriteString(finalSQL)
// os.Stdout.WriteString("\n\n")
// }
2019-05-13 01:27:26 +02:00
var st time.Time
if conf.EnableTracing {
st = time.Now()
}
2019-04-19 07:55:03 +02:00
2019-09-26 06:35:31 +02:00
tx, err := db.Begin(c)
2019-09-08 07:54:38 +02:00
if err != nil {
return nil, 0, err
}
2019-09-26 06:35:31 +02:00
defer tx.Rollback(c)
2019-09-08 07:54:38 +02:00
if v := c.Value(userIDKey); v != nil {
2019-09-26 06:35:31 +02:00
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
2019-09-08 07:54:38 +02:00
if err != nil {
return nil, 0, err
}
}
2019-09-20 06:19:11 +02:00
//fmt.Printf("\nRAW: %#v\n", finalSQL)
2019-07-29 07:13:33 +02:00
2019-09-26 06:35:31 +02:00
var root []byte
2019-04-19 07:55:03 +02:00
2019-09-26 06:35:31 +02:00
err = tx.QueryRow(c, finalSQL).Scan(&root)
2019-04-19 07:55:03 +02:00
if err != nil {
2019-05-13 01:27:26 +02:00
return nil, 0, err
2019-04-19 07:55:03 +02:00
}
2019-09-26 06:35:31 +02:00
if err := tx.Commit(c); err != nil {
2019-09-08 07:54:38 +02:00
return nil, 0, err
}
2019-09-05 06:09:56 +02:00
if conf.EnableTracing && len(qc.Selects) != 0 {
c.addTrace(
2019-09-05 06:09:56 +02:00
qc.Selects,
qc.Selects[0].ID,
st)
2019-05-13 01:27:26 +02:00
}
2019-04-19 07:55:03 +02:00
2019-07-29 07:13:33 +02:00
if conf.UseAllowList == false {
_allowList.add(&c.req)
}
2019-09-26 06:35:31 +02:00
return root, skipped, nil
2019-05-13 01:27:26 +02:00
}
// render stores data, unmodified, as the response's Data field and writes
// the JSON encoding of the whole response to w.
func (c *coreContext) render(w io.Writer, data []byte) error {
	enc := json.NewEncoder(w)
	c.res.Data = json.RawMessage(data)
	return enc.Encode(c.res)
}
2019-06-08 02:53:08 +02:00
func (c *coreContext) addTrace(sel []qcode.Select, id int32, st time.Time) {
2019-05-13 06:05:08 +02:00
et := time.Now()
du := et.Sub(st)
if c.res.Extensions == nil {
c.res.Extensions = &extensions{&trace{
Version: 1,
StartTime: st,
Execution: execution{},
}}
}
c.res.Extensions.Tracing.EndTime = et
c.res.Extensions.Tracing.Duration = du
2019-06-08 02:53:08 +02:00
n := 1
for i := id; i != 0; i = sel[i].ParentID {
n++
}
path := make([]string, n)
2019-06-08 02:53:08 +02:00
n--
2019-06-08 02:53:08 +02:00
for i := id; ; i = sel[i].ParentID {
path[n] = sel[i].Table
2019-06-08 02:53:08 +02:00
if sel[i].ID == 0 {
break
}
n--
}
2019-05-13 06:05:08 +02:00
tr := resolver{
Path: path,
2019-05-13 06:05:08 +02:00
ParentType: "Query",
FieldName: sel[id].Table,
2019-05-13 06:05:08 +02:00
ReturnType: "object",
StartOffset: 1,
Duration: du,
}
c.res.Extensions.Tracing.Execution.Resolvers =
append(c.res.Extensions.Tracing.Execution.Resolvers, tr)
}
2019-05-13 01:27:26 +02:00
func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
[][]byte,
map[uint64]*qcode.Select) {
c := 0
for i := range sel {
s := &sel[i]
2019-06-08 02:53:08 +02:00
if isSkipped(skipped, uint32(s.ID)) {
2019-05-13 01:27:26 +02:00
c++
2019-04-20 06:35:57 +02:00
}
}
2019-05-13 01:27:26 +02:00
// list of keys (and it's related value) to extract from
// the db json response
fm := make([][]byte, c)
// mapping between the above extracted key and a Select
// object
sm := make(map[uint64]*qcode.Select, c)
n := 0
for i := range sel {
s := &sel[i]
2019-06-08 02:53:08 +02:00
if isSkipped(skipped, uint32(s.ID)) == false {
2019-05-13 01:27:26 +02:00
continue
}
p := sel[s.ParentID]
k := mkkey(h, s.Table, p.Table)
if r, ok := rmap[k]; ok {
fm[n] = r.IDField
n++
k := xxhash.Sum64(r.IDField)
sm[k] = s
}
}
return fm, sm
}
2019-06-08 02:53:08 +02:00
// isSkipped reports whether bit pos is set in the bitmask n. Positions of 32
// or more always report false, because the shifted mask becomes zero.
func isSkipped(n uint32, pos uint32) bool {
	mask := uint32(1) << pos
	return n&mask != 0
}
// authCheck reports whether a user id value is present on the request context.
func authCheck(ctx *coreContext) bool {
	uid := ctx.Value(userIDKey)
	return uid != nil
}
2019-05-13 06:05:08 +02:00
func colsToList(cols []qcode.Column) []string {
var f []string
2019-05-13 01:27:26 +02:00
2019-05-13 06:05:08 +02:00
for i := range cols {
f = append(f, cols[i].Name)
2019-04-19 07:55:03 +02:00
}
2019-05-13 06:05:08 +02:00
return f
2019-04-19 07:55:03 +02:00
}