Add nested where clause to filter based on related tables

This commit is contained in:
Vikram Rangnekar
2019-11-04 23:44:42 -05:00
parent 77a51924a7
commit 89bc93e159
13 changed files with 358 additions and 206 deletions

View File

@ -8,19 +8,19 @@ import (
"github.com/dosco/super-graph/jsn"
)
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) {
switch tag {
case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil {
return stringVar(w, v.(string))
return stringArg(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
return stringArg(w, v.(string))
}
io.WriteString(w, "null")
@ -28,7 +28,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
case "user_role":
if v := ctx.Value(userRoleKey); v != nil {
return stringVar(w, v.(string))
return stringArg(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
@ -50,7 +50,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
}
if is {
return stringVarB(w, fields[0].Value)
return stringArgB(w, fields[0].Value)
}
w.Write(fields[0].Value)
@ -58,7 +58,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
}
}
func varList(ctx *coreContext, args [][]byte) []interface{} {
func argList(ctx *coreContext, args [][]byte) []interface{} {
vars := make([]interface{}, len(args))
var fields map[string]interface{}
@ -86,6 +86,11 @@ func varList(ctx *coreContext, args [][]byte) []interface{} {
vars[i] = v.(string)
}
case bytes.Equal(av, []byte("user_role")):
if v := ctx.Value(userRoleKey); v != nil {
vars[i] = v.(string)
}
default:
if v, ok := fields[string(av)]; ok {
vars[i] = v
@ -96,7 +101,7 @@ func varList(ctx *coreContext, args [][]byte) []interface{} {
return vars
}
func stringVar(w io.Writer, v string) (int, error) {
func stringArg(w io.Writer, v string) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
@ -106,7 +111,7 @@ func stringVar(w io.Writer, v string) (int, error) {
return w.Write([]byte(`'`))
}
func stringVarB(w io.Writer, v []byte) (int, error) {
func stringArgB(w io.Writer, v []byte) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}

View File

@ -66,6 +66,7 @@ type config struct {
PoolSize int32 `mapstructure:"pool_size"`
MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"`
SetUserID bool `mapstructure:"set_user_id"`
Vars map[string]string `mapstructure:"variables"`

View File

@ -122,10 +122,8 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil {
return nil, nil, err
}
}
@ -153,7 +151,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
}
var root []byte
vars := varList(c, ps.args)
vars := argList(c, ps.args)
if mutation {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
@ -206,7 +204,7 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, varMap(c))
_, err = t.ExecuteFunc(buf, argMap(c))
if err == errNoUserID {
logger.Warn().Msg("no user id found. query requires an authenicated request")
@ -224,10 +222,8 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
stime = time.Now()
}
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil {
return nil, 0, err
}
}
@ -425,6 +421,15 @@ func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
return role, nil
}
func (c *coreContext) setLocalUserID(tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func (c *coreContext) render(w io.Writer, data []byte) error {
c.res.Data = json.RawMessage(data)
return json.NewEncoder(w).Encode(c.res)