super-graph/core/core.go

399 lines
7.1 KiB
Go

package core
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/dosco/super-graph/core/internal/psql"
"github.com/dosco/super-graph/core/internal/qcode"
"github.com/valyala/fasttemplate"
)
type extensions struct {
Tracing *trace `json:"tracing,omitempty"`
}
type trace struct {
Version int `json:"version"`
StartTime time.Time `json:"startTime"`
EndTime time.Time `json:"endTime"`
Duration time.Duration `json:"duration"`
Execution execution `json:"execution"`
}
type execution struct {
Resolvers []resolver `json:"resolvers"`
}
type resolver struct {
Path []string `json:"path"`
ParentType string `json:"parentType"`
FieldName string `json:"fieldName"`
ReturnType string `json:"returnType"`
StartOffset int `json:"startOffset"`
Duration time.Duration `json:"duration"`
}
type scontext struct {
context.Context
sg *SuperGraph
query string
vars json.RawMessage
role string
res Result
}
func (sg *SuperGraph) initCompilers() error {
var err error
// If sg.di is not null then it's probably set
// for tests
if sg.dbinfo == nil {
sg.dbinfo, err = psql.GetDBInfo(sg.db)
if err != nil {
return err
}
}
if err = addTables(sg.conf, sg.dbinfo); err != nil {
return err
}
if err = addForeignKeys(sg.conf, sg.dbinfo); err != nil {
return err
}
sg.schema, err = psql.NewDBSchema(sg.dbinfo, getDBTableAliases(sg.conf))
if err != nil {
return err
}
sg.qc, err = qcode.NewCompiler(qcode.Config{
Blocklist: sg.conf.Blocklist,
})
if err != nil {
return err
}
if err := addRoles(sg.conf, sg.qc); err != nil {
return err
}
sg.pc = psql.NewCompiler(psql.Config{
Schema: sg.schema,
Vars: sg.conf.Vars,
})
return nil
}
func (c *scontext) execQuery() ([]byte, error) {
var data []byte
var st *stmt
var err error
if c.sg.conf.UseAllowList {
data, st, err = c.resolvePreparedSQL()
} else {
data, st, err = c.resolveSQL()
}
if err != nil {
return nil, err
}
if len(data) == 0 || st.skipped == 0 {
return data, nil
}
// return c.sg.execRemoteJoin(st, data, c.req.hdr)
return c.sg.execRemoteJoin(st, data, nil)
}
func (c *scontext) resolvePreparedSQL() ([]byte, *stmt, error) {
var tx *sql.Tx
var err error
mutation := (c.res.op == qcode.QTMutation)
useRoleQuery := c.sg.abacEnabled && mutation
useTx := useRoleQuery || c.sg.conf.SetUserID
if useTx {
if tx, err = c.sg.db.BeginTx(c, nil); err != nil {
return nil, nil, err
}
defer tx.Rollback() //nolint: errcheck
}
if c.sg.conf.SetUserID {
if err := setLocalUserID(c, tx); err != nil {
return nil, nil, err
}
}
var role string
if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil {
return nil, nil, err
}
} else if v := c.Value(UserRoleKey); v != nil {
role = v.(string)
} else {
role = c.role
}
c.res.role = role
ps, ok := c.sg.prepared[stmtHash(c.res.name, role)]
if !ok {
return nil, nil, errNotFound
}
c.res.sql = ps.st.sql
var root []byte
var row *sql.Row
varsList, err := c.argList(ps.args)
if err != nil {
return nil, nil, err
}
if useTx {
row = tx.Stmt(ps.sd).QueryRow(varsList...)
} else {
row = ps.sd.QueryRow(varsList...)
}
if ps.roleArg {
err = row.Scan(&role, &root)
} else {
err = row.Scan(&root)
}
if err != nil {
return nil, nil, err
}
c.role = role
if useTx {
if err := tx.Commit(); err != nil {
return nil, nil, err
}
}
if root, err = c.sg.encryptCursor(ps.st.qc, root); err != nil {
return nil, nil, err
}
return root, &ps.st, nil
}
func (c *scontext) resolveSQL() ([]byte, *stmt, error) {
var tx *sql.Tx
var err error
mutation := (c.res.op == qcode.QTMutation)
useRoleQuery := c.sg.abacEnabled && mutation
useTx := useRoleQuery || c.sg.conf.SetUserID
if useTx {
if tx, err = c.sg.db.BeginTx(c, nil); err != nil {
return nil, nil, err
}
defer tx.Rollback() //nolint: errcheck
}
if c.sg.conf.SetUserID {
if err := setLocalUserID(c, tx); err != nil {
return nil, nil, err
}
}
if useRoleQuery {
if c.role, err = c.executeRoleQuery(tx); err != nil {
return nil, nil, err
}
} else if v := c.Value(UserRoleKey); v != nil {
c.role = v.(string)
}
stmts, err := c.sg.buildStmt(c.res.op, []byte(c.query), c.vars, c.role)
if err != nil {
return nil, nil, err
}
st := &stmts[0]
t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, c.argMap())
if err != nil {
return nil, nil, err
}
finalSQL := buf.String()
// var stime time.Time
// if c.sg.conf.EnableTracing {
// stime = time.Now()
// }
var root []byte
var role string
var row *sql.Row
// defaultRole := c.role
if useTx {
row = tx.QueryRow(finalSQL)
} else {
row = c.sg.db.QueryRow(finalSQL)
}
if len(stmts) > 1 {
err = row.Scan(&role, &root)
} else {
err = row.Scan(&root)
}
c.res.sql = finalSQL
if len(role) == 0 {
c.res.role = c.role
} else {
c.res.role = role
}
if err != nil {
return nil, nil, err
}
if useTx {
if err := tx.Commit(); err != nil {
return nil, nil, err
}
}
if root, err = c.sg.encryptCursor(st.qc, root); err != nil {
return nil, nil, err
}
if c.sg.allowList.IsPersist() {
if err := c.sg.allowList.Set(c.vars, c.query, ""); err != nil {
return nil, nil, err
}
}
if len(stmts) > 1 {
if st = findStmt(role, stmts); st == nil {
return nil, nil, fmt.Errorf("invalid role '%s' returned", role)
}
}
// if c.sg.conf.EnableTracing {
// for _, id := range st.qc.Roots {
// c.addTrace(st.qc.Selects, id, stime)
// }
// }
return root, st, nil
}
func (c *scontext) executeRoleQuery(tx *sql.Tx) (string, error) {
userID := c.Value(UserIDKey)
if userID == nil {
return "anon", nil
}
var role string
row := c.sg.getRole.QueryRow(userID, c.role)
if err := row.Scan(&role); err != nil {
return "", err
}
return role, nil
}
func (r *Result) Operation() string {
return r.op.String()
}
func (r *Result) QueryName() string {
return r.name
}
func (r *Result) Role() string {
return r.role
}
func (r *Result) SQL() string {
return r.sql
}
// func (c *scontext) addTrace(sel []qcode.Select, id int32, st time.Time) {
// et := time.Now()
// du := et.Sub(st)
// if c.res.Extensions == nil {
// c.res.Extensions = &extensions{&trace{
// Version: 1,
// StartTime: st,
// Execution: execution{},
// }}
// }
// c.res.Extensions.Tracing.EndTime = et
// c.res.Extensions.Tracing.Duration = du
// n := 1
// for i := id; i != -1; i = sel[i].ParentID {
// n++
// }
// path := make([]string, n)
// n--
// for i := id; ; i = sel[i].ParentID {
// path[n] = sel[i].Name
// if sel[i].ParentID == -1 {
// break
// }
// n--
// }
// tr := resolver{
// Path: path,
// ParentType: "Query",
// FieldName: sel[id].Name,
// ReturnType: "object",
// StartOffset: 1,
// Duration: du,
// }
// c.res.Extensions.Tracing.Execution.Resolvers =
// append(c.res.Extensions.Tracing.Execution.Resolvers, tr)
// }
func findStmt(role string, stmts []stmt) *stmt {
for i := range stmts {
if stmts[i].role.Name != role {
continue
}
return &stmts[i]
}
return nil
}