Add nested mutations

This commit is contained in:
Vikram Rangnekar
2019-12-25 01:24:30 -05:00
parent 96ed3413fc
commit 6831d3f56f
23 changed files with 1617 additions and 404 deletions

View File

@ -3,6 +3,7 @@ package serv
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
@ -48,7 +49,7 @@ func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int
func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) {
vars := make([]interface{}, len(args))
var fields map[string]interface{}
var fields map[string]json.RawMessage
var err error
if len(ctx.req.Vars) != 0 {
@ -86,10 +87,19 @@ func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) {
default:
if v, ok := fields[string(av)]; ok {
vars[i] = v
switch v[0] {
case '[', '{':
vars[i] = v
default:
var val interface{}
if err := json.Unmarshal(v, &val); err != nil {
return nil, err
}
vars[i] = val
}
} else {
return nil, fmt.Errorf("query requires variable $%s", string(av))
}
}
}

View File

@ -260,8 +260,14 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
}
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
userID := c.Value(userIDKey)
if userID == nil {
return "anon", nil
}
var role string
row := tx.QueryRow(c.Context, "_sg_get_role", c.req.role, 1)
row := tx.QueryRow(c.Context, "_sg_get_role", userID, c.req.role)
if err := row.Scan(&role); err != nil {
return "", err

View File

@ -15,7 +15,9 @@ func health(w http.ResponseWriter, _ *http.Request) {
return
}
ctx, _ := context.WithTimeout(context.Background(), conf.DB.PingTimeout)
ctx, cancel := context.WithTimeout(context.Background(), conf.DB.PingTimeout)
defer cancel()
if err := conn.Conn().Ping(ctx); err != nil {
errlog.Error().Err(err).Msg("error pinging database")
w.WriteHeader(http.StatusInternalServerError)

View File

@ -70,6 +70,12 @@ func prepareStmt(gql string, vars []byte) error {
qt := qcode.GetQType(gql)
q := []byte(gql)
if len(vars) == 0 {
logger.Debug().Msgf("Prepared statement:\n%s\n", gql)
} else {
logger.Debug().Msgf("Prepared statement:\n%s\n%s\n", vars, gql)
}
tx, err := db.Begin(context.Background())
if err != nil {
return err
@ -91,12 +97,16 @@ func prepareStmt(gql string, vars []byte) error {
return err
}
logger.Debug().Msg("Prepared statement role: user")
err = prepare(tx, stmts1, gqlHash(gql, vars, "user"))
if err != nil {
return err
}
if conf.isAnonRoleDefined() {
logger.Debug().Msg("Prepared statement for role: anon")
stmts2, err := buildRoleStmt(q, vars, "anon")
if err != nil {
return err
@ -110,6 +120,8 @@ func prepareStmt(gql string, vars []byte) error {
case qcode.QTMutation:
for _, role := range conf.Roles {
logger.Debug().Msgf("Prepared statement for role: %s", role.Name)
stmts, err := buildRoleStmt(q, vars, role.Name)
if err != nil {
return err
@ -122,12 +134,6 @@ func prepareStmt(gql string, vars []byte) error {
}
}
if len(vars) == 0 {
logger.Debug().Msgf("Building prepared statement for:\n %s", gql)
} else {
logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql)
}
if err := tx.Commit(context.Background()); err != nil {
return err
}
@ -160,7 +166,11 @@ func prepareRoleStmt(tx pgx.Tx) error {
w := &bytes.Buffer{}
io.WriteString(w, `SELECT (CASE`)
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`)
for _, role := range conf.Roles {
if len(role.Match) == 0 {
continue
@ -174,7 +184,8 @@ func prepareRoleStmt(tx pgx.Tx) error {
io.WriteString(w, ` ELSE {{role}} END) FROM (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query"`)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `)
roleSQL, _ := processTemplate(w.String())