Add nested mutations
This commit is contained in:
16
serv/args.go
16
serv/args.go
@ -3,6 +3,7 @@ package serv
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -48,7 +49,7 @@ func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int
|
||||
func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) {
|
||||
vars := make([]interface{}, len(args))
|
||||
|
||||
var fields map[string]interface{}
|
||||
var fields map[string]json.RawMessage
|
||||
var err error
|
||||
|
||||
if len(ctx.req.Vars) != 0 {
|
||||
@ -86,10 +87,19 @@ func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) {
|
||||
|
||||
default:
|
||||
if v, ok := fields[string(av)]; ok {
|
||||
vars[i] = v
|
||||
switch v[0] {
|
||||
case '[', '{':
|
||||
vars[i] = v
|
||||
default:
|
||||
var val interface{}
|
||||
if err := json.Unmarshal(v, &val); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vars[i] = val
|
||||
}
|
||||
|
||||
} else {
|
||||
return nil, fmt.Errorf("query requires variable $%s", string(av))
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -260,8 +260,14 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||
}
|
||||
|
||||
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
||||
userID := c.Value(userIDKey)
|
||||
|
||||
if userID == nil {
|
||||
return "anon", nil
|
||||
}
|
||||
|
||||
var role string
|
||||
row := tx.QueryRow(c.Context, "_sg_get_role", c.req.role, 1)
|
||||
row := tx.QueryRow(c.Context, "_sg_get_role", userID, c.req.role)
|
||||
|
||||
if err := row.Scan(&role); err != nil {
|
||||
return "", err
|
||||
|
@ -15,7 +15,9 @@ func health(w http.ResponseWriter, _ *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, _ := context.WithTimeout(context.Background(), conf.DB.PingTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), conf.DB.PingTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := conn.Conn().Ping(ctx); err != nil {
|
||||
errlog.Error().Err(err).Msg("error pinging database")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
@ -70,6 +70,12 @@ func prepareStmt(gql string, vars []byte) error {
|
||||
qt := qcode.GetQType(gql)
|
||||
q := []byte(gql)
|
||||
|
||||
if len(vars) == 0 {
|
||||
logger.Debug().Msgf("Prepared statement:\n%s\n", gql)
|
||||
} else {
|
||||
logger.Debug().Msgf("Prepared statement:\n%s\n%s\n", vars, gql)
|
||||
}
|
||||
|
||||
tx, err := db.Begin(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
@ -91,12 +97,16 @@ func prepareStmt(gql string, vars []byte) error {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Prepared statement role: user")
|
||||
|
||||
err = prepare(tx, stmts1, gqlHash(gql, vars, "user"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if conf.isAnonRoleDefined() {
|
||||
logger.Debug().Msg("Prepared statement for role: anon")
|
||||
|
||||
stmts2, err := buildRoleStmt(q, vars, "anon")
|
||||
if err != nil {
|
||||
return err
|
||||
@ -110,6 +120,8 @@ func prepareStmt(gql string, vars []byte) error {
|
||||
|
||||
case qcode.QTMutation:
|
||||
for _, role := range conf.Roles {
|
||||
logger.Debug().Msgf("Prepared statement for role: %s", role.Name)
|
||||
|
||||
stmts, err := buildRoleStmt(q, vars, role.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -122,12 +134,6 @@ func prepareStmt(gql string, vars []byte) error {
|
||||
}
|
||||
}
|
||||
|
||||
if len(vars) == 0 {
|
||||
logger.Debug().Msgf("Building prepared statement for:\n %s", gql)
|
||||
} else {
|
||||
logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql)
|
||||
}
|
||||
|
||||
if err := tx.Commit(context.Background()); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -160,7 +166,11 @@ func prepareRoleStmt(tx pgx.Tx) error {
|
||||
|
||||
w := &bytes.Buffer{}
|
||||
|
||||
io.WriteString(w, `SELECT (CASE`)
|
||||
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
|
||||
io.WriteString(w, conf.RolesQuery)
|
||||
io.WriteString(w, `) THEN `)
|
||||
|
||||
io.WriteString(w, `(SELECT (CASE`)
|
||||
for _, role := range conf.Roles {
|
||||
if len(role.Match) == 0 {
|
||||
continue
|
||||
@ -174,7 +184,8 @@ func prepareRoleStmt(tx pgx.Tx) error {
|
||||
|
||||
io.WriteString(w, ` ELSE {{role}} END) FROM (`)
|
||||
io.WriteString(w, conf.RolesQuery)
|
||||
io.WriteString(w, `) AS "_sg_auth_roles_query"`)
|
||||
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
|
||||
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `)
|
||||
|
||||
roleSQL, _ := processTemplate(w.String())
|
||||
|
||||
|
Reference in New Issue
Block a user