super-graph/serv/core.go

403 lines
7.2 KiB
Go
Raw Normal View History

2019-04-19 07:55:03 +02:00
package serv
import (
2019-05-13 01:27:26 +02:00
"bytes"
2019-04-19 07:55:03 +02:00
"context"
"encoding/json"
2019-06-02 01:48:42 +02:00
"errors"
2019-06-04 16:54:51 +02:00
"fmt"
2019-04-19 07:55:03 +02:00
"io"
2019-05-13 01:27:26 +02:00
"net/http"
2019-04-19 07:55:03 +02:00
"time"
2019-05-13 01:27:26 +02:00
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/allow"
2019-04-20 06:35:57 +02:00
"github.com/dosco/super-graph/qcode"
2020-02-10 07:45:37 +01:00
"github.com/jackc/pgx/v4"
2019-04-19 07:55:03 +02:00
"github.com/valyala/fasttemplate"
)
2019-05-13 01:27:26 +02:00
// coreContext carries the state used while handling a GraphQL request: the
// request itself, the response being built, and the request's context.
type coreContext struct {
// req is the incoming request; handleReq fills in its referer, headers and
// role before the query is executed.
req gqlReq
// res is the response that render serializes as JSON.
res gqlResp
// The embedded Context supplies per-request auth values (read via Value with
// userIDKey / userRoleKey) and is passed to the database calls.
context.Context
}
func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
2019-09-20 06:19:11 +02:00
c.req.ref = req.Referer()
c.req.hdr = req.Header
2019-11-25 08:22:33 +01:00
if len(c.req.Vars) == 2 {
c.req.Vars = nil
}
2019-10-15 08:30:19 +02:00
if authCheck(c) {
c.req.role = "user"
2019-10-15 08:30:19 +02:00
} else {
c.req.role = "anon"
2019-10-15 08:30:19 +02:00
}
b, err := c.execQuery()
2019-09-20 06:19:11 +02:00
if err != nil {
return err
}
return c.render(w, b)
}
func (c *coreContext) execQuery() ([]byte, error) {
2019-07-29 07:13:33 +02:00
var data []byte
2019-11-25 08:22:33 +01:00
var st *stmt
var err error
2019-04-19 07:55:03 +02:00
2019-11-07 08:37:24 +01:00
if conf.Production {
2019-11-25 08:22:33 +01:00
data, st, err = c.resolvePreparedSQL()
2019-07-29 07:13:33 +02:00
if err != nil {
2019-11-25 08:22:33 +01:00
logger.Error().
Err(err).
Str("default_role", c.req.role).
Msg(c.req.Query)
2019-07-29 07:13:33 +02:00
2019-11-25 08:22:33 +01:00
return nil, errors.New("query failed. check logs for error")
}
2019-07-29 07:13:33 +02:00
} else {
2019-11-25 08:22:33 +01:00
if data, st, err = c.resolveSQL(); err != nil {
2019-09-20 06:19:11 +02:00
return nil, err
2019-07-29 07:13:33 +02:00
}
2019-05-13 01:27:26 +02:00
}
2019-04-19 07:55:03 +02:00
2019-11-25 08:22:33 +01:00
return execRemoteJoin(st, data, c.req.hdr)
}
2019-11-25 08:22:33 +01:00
func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
var tx pgx.Tx
var err error
2019-11-25 08:22:33 +01:00
qt := qcode.GetQType(c.req.Query)
mutation := (qt == qcode.QTMutation)
2019-12-10 06:03:44 +01:00
useRoleQuery := conf.isABACEnabled() && mutation
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
2019-12-03 05:08:35 +01:00
if tx, err = db.Begin(c.Context); err != nil {
return nil, nil, err
}
defer tx.Rollback(c) //nolint: errcheck
}
if conf.DB.SetUserID {
2019-12-03 05:08:35 +01:00
if err := setLocalUserID(c.Context, tx); err != nil {
return nil, nil, err
}
}
var role string
if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil {
return nil, nil, err
}
} else if v := c.Value(userRoleKey); v != nil {
role = v.(string)
2019-11-25 08:22:33 +01:00
} else {
role = c.req.role
}
ps, ok := _preparedList[stmtHash(allow.QueryName(c.req.Query), role)]
if !ok {
return nil, nil, errUnauthorized
}
var root []byte
var row pgx.Row
2019-11-25 08:22:33 +01:00
vars, err := argList(c, ps.args)
if err != nil {
return nil, nil, err
}
if useTx {
2019-12-03 05:08:35 +01:00
row = tx.QueryRow(c.Context, ps.sd.SQL, vars...)
} else {
2019-12-03 05:08:35 +01:00
row = db.QueryRow(c.Context, ps.sd.SQL, vars...)
}
2019-12-10 06:03:44 +01:00
if ps.roleArg {
err = row.Scan(&role, &root)
2019-12-10 06:03:44 +01:00
} else {
err = row.Scan(&root)
}
2019-11-19 06:47:55 +01:00
2019-11-25 08:22:33 +01:00
if len(role) == 0 {
logger.Debug().Str("default_role", c.req.role).Msg(c.req.Query)
} else {
logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query)
}
2019-11-19 06:47:55 +01:00
if err != nil {
return nil, nil, err
}
2019-11-19 06:47:55 +01:00
c.req.role = role
if useTx {
2019-12-03 05:08:35 +01:00
if err := tx.Commit(c.Context); err != nil {
return nil, nil, err
}
}
2019-12-10 06:03:44 +01:00
return root, &ps.st, nil
}
2019-11-25 08:22:33 +01:00
func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
var tx pgx.Tx
var err error
2019-11-25 08:22:33 +01:00
qt := qcode.GetQType(c.req.Query)
mutation := (qt == qcode.QTMutation)
2019-12-10 06:03:44 +01:00
useRoleQuery := conf.isABACEnabled() && mutation
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
2019-12-03 05:08:35 +01:00
if tx, err = db.Begin(c.Context); err != nil {
2019-11-25 08:22:33 +01:00
return nil, nil, err
}
2019-12-03 05:08:35 +01:00
defer tx.Rollback(c.Context) //nolint: errcheck
}
if conf.DB.SetUserID {
2019-12-03 05:08:35 +01:00
if err := setLocalUserID(c.Context, tx); err != nil {
2019-11-25 08:22:33 +01:00
return nil, nil, err
}
}
if useRoleQuery {
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
2019-11-25 08:22:33 +01:00
return nil, nil, err
}
} else if v := c.Value(userRoleKey); v != nil {
c.req.role = v.(string)
}
2019-11-25 08:22:33 +01:00
stmts, err := buildStmt(qt, []byte(c.req.Query), c.req.Vars, c.req.role)
if err != nil {
2019-11-25 08:22:33 +01:00
return nil, nil, err
}
2019-11-25 08:22:33 +01:00
st := &stmts[0]
t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{}
2019-11-25 08:22:33 +01:00
_, err = t.ExecuteFunc(buf, argMap(c, c.req.Vars))
if err != nil {
2019-11-25 08:22:33 +01:00
return nil, nil, err
}
finalSQL := buf.String()
var stime time.Time
if conf.EnableTracing {
stime = time.Now()
}
var root []byte
2019-11-25 08:22:33 +01:00
var role string
var row pgx.Row
2019-11-19 06:47:55 +01:00
2019-11-25 08:22:33 +01:00
defaultRole := c.req.role
if useTx {
2019-12-03 05:08:35 +01:00
row = tx.QueryRow(c.Context, finalSQL)
} else {
2019-12-03 05:08:35 +01:00
row = db.QueryRow(c.Context, finalSQL)
}
2019-12-10 06:03:44 +01:00
if len(stmts) > 1 {
err = row.Scan(&role, &root)
2019-12-10 06:03:44 +01:00
} else {
err = row.Scan(&root)
}
2019-11-19 06:47:55 +01:00
2019-11-25 08:22:33 +01:00
if len(role) == 0 {
logger.Debug().Str("default_role", defaultRole).Msg(c.req.Query)
} else {
logger.Debug().Str("default_role", defaultRole).Str("role", role).Msg(c.req.Query)
}
2019-11-19 06:47:55 +01:00
if err != nil {
2019-11-25 08:22:33 +01:00
return nil, nil, err
}
if useTx {
2019-12-03 05:08:35 +01:00
if err := tx.Commit(c.Context); err != nil {
2019-11-25 08:22:33 +01:00
return nil, nil, err
2019-11-19 06:47:55 +01:00
}
}
2020-02-10 07:45:37 +01:00
if root, err = encryptCursor(st.qc, root); err != nil {
return nil, nil, err
}
if allowList.IsPersist() {
if err := allowList.Set(c.req.Vars, c.req.Query, c.req.ref); err != nil {
return nil, nil, err
}
}
2019-06-02 01:48:42 +02:00
2019-11-25 08:22:33 +01:00
if len(stmts) > 1 {
if st = findStmt(role, stmts); st == nil {
return nil, nil, fmt.Errorf("invalid role '%s' returned", role)
2019-06-02 01:48:42 +02:00
}
}
2019-11-25 08:22:33 +01:00
if conf.EnableTracing {
for _, id := range st.qc.Roots {
c.addTrace(st.qc.Selects, id, stime)
2019-04-20 06:35:57 +02:00
}
2019-05-13 01:27:26 +02:00
}
2019-11-25 08:22:33 +01:00
return root, st, nil
2019-05-13 01:27:26 +02:00
}
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
2019-12-25 07:24:30 +01:00
userID := c.Value(userIDKey)
if userID == nil {
return "anon", nil
}
var role string
2019-12-25 07:24:30 +01:00
row := tx.QueryRow(c.Context, "_sg_get_role", userID, c.req.role)
2019-04-19 07:55:03 +02:00
if err := row.Scan(&role); err != nil {
return "", err
2019-07-29 07:13:33 +02:00
}
return role, nil
2019-05-13 01:27:26 +02:00
}
// render stores the raw JSON result as the response data and writes the
// whole response, JSON-encoded, to w.
func (c *coreContext) render(w io.Writer, data []byte) error {
	c.res.Data = json.RawMessage(data)
	enc := json.NewEncoder(w)
	return enc.Encode(c.res)
}
2019-06-08 02:53:08 +02:00
func (c *coreContext) addTrace(sel []qcode.Select, id int32, st time.Time) {
2019-05-13 06:05:08 +02:00
et := time.Now()
du := et.Sub(st)
if c.res.Extensions == nil {
c.res.Extensions = &extensions{&trace{
Version: 1,
StartTime: st,
Execution: execution{},
}}
}
c.res.Extensions.Tracing.EndTime = et
c.res.Extensions.Tracing.Duration = du
2019-06-08 02:53:08 +02:00
n := 1
2019-11-19 06:47:55 +01:00
for i := id; i != -1; i = sel[i].ParentID {
n++
}
path := make([]string, n)
2019-06-08 02:53:08 +02:00
n--
2019-06-08 02:53:08 +02:00
for i := id; ; i = sel[i].ParentID {
path[n] = sel[i].Name
2019-11-19 06:47:55 +01:00
if sel[i].ParentID == -1 {
2019-06-08 02:53:08 +02:00
break
}
n--
}
2019-05-13 06:05:08 +02:00
tr := resolver{
Path: path,
2019-05-13 06:05:08 +02:00
ParentType: "Query",
FieldName: sel[id].Name,
2019-05-13 06:05:08 +02:00
ReturnType: "object",
StartOffset: 1,
Duration: du,
}
c.res.Extensions.Tracing.Execution.Resolvers =
append(c.res.Extensions.Tracing.Execution.Resolvers, tr)
}
// setLocalUserID stores the user ID found under userIDKey in ctx into the
// transaction-local Postgres setting "user.id". It does nothing when the
// context carries no user ID.
//
// set_config(name, value, true) is the function form of SET LOCAL. Unlike
// SET, it accepts a bind parameter, so the ID is sent as data and never
// spliced into the SQL text. The previous Sprintf-built SET was open to SQL
// injection and failed on IDs such as UUIDs. The caller's ctx is used so the
// statement honors request cancellation, like the other calls in the same
// transaction.
func setLocalUserID(ctx context.Context, tx pgx.Tx) error {
	v := ctx.Value(userIDKey)
	if v == nil {
		return nil
	}
	_, err := tx.Exec(ctx, `SELECT set_config('user.id', $1, true)`, fmt.Sprint(v))
	return err
}
2019-05-13 01:27:26 +02:00
func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
[][]byte,
map[uint64]*qcode.Select) {
c := 0
for i := range sel {
s := &sel[i]
2019-06-08 02:53:08 +02:00
if isSkipped(skipped, uint32(s.ID)) {
2019-05-13 01:27:26 +02:00
c++
2019-04-20 06:35:57 +02:00
}
}
2019-05-13 01:27:26 +02:00
// list of keys (and it's related value) to extract from
// the db json response
fm := make([][]byte, c)
// mapping between the above extracted key and a Select
// object
sm := make(map[uint64]*qcode.Select, c)
n := 0
for i := range sel {
s := &sel[i]
if !isSkipped(skipped, uint32(s.ID)) {
2019-05-13 01:27:26 +02:00
continue
}
p := sel[s.ParentID]
k := mkkey(h, s.Name, p.Name)
2019-05-13 01:27:26 +02:00
if r, ok := rmap[k]; ok {
fm[n] = r.IDField
n++
k := xxhash.Sum64(r.IDField)
sm[k] = s
}
}
return fm, sm
}
2019-06-08 02:53:08 +02:00
// isSkipped reports whether bit pos is set in the bitmask n. A position of 32
// or more is never set, because shifting a uint32 that far yields zero.
func isSkipped(n uint32, pos uint32) bool {
	mask := uint32(1) << pos
	return n&mask != 0
}
// authCheck reports whether the request context carries a user ID, i.e.
// whether the request is treated as authenticated.
func authCheck(ctx *coreContext) bool {
	userID := ctx.Value(userIDKey)
	return userID != nil
}
2019-05-13 06:05:08 +02:00
func colsToList(cols []qcode.Column) []string {
var f []string
2019-05-13 01:27:26 +02:00
2019-05-13 06:05:08 +02:00
for i := range cols {
f = append(f, cols[i].Name)
2019-04-19 07:55:03 +02:00
}
2019-05-13 06:05:08 +02:00
return f
2019-04-19 07:55:03 +02:00
}