super-graph/core/args.go

209 lines
3.9 KiB
Go
Raw Normal View History

package core
2019-04-19 07:55:03 +02:00
import (
2019-09-05 06:09:56 +02:00
"bytes"
2019-12-25 07:24:30 +01:00
"encoding/json"
2019-09-05 06:09:56 +02:00
"fmt"
2019-04-19 07:55:03 +02:00
"io"
2019-09-05 06:09:56 +02:00
"github.com/dosco/super-graph/jsn"
2019-04-19 07:55:03 +02:00
)
// argMap function is used to string replace variables with values by
// the fasttemplate code
func (c *scontext) argMap() func(w io.Writer, tag string) (int, error) {
2019-09-05 06:09:56 +02:00
return func(w io.Writer, tag string) (int, error) {
switch tag {
case "user_id_provider":
if v := c.Value(UserIDProviderKey); v != nil {
return io.WriteString(w, v.(string))
}
return 0, argErr("user_id_provider")
2019-09-05 06:09:56 +02:00
case "user_id":
if v := c.Value(UserIDKey); v != nil {
return io.WriteString(w, v.(string))
2019-09-05 06:09:56 +02:00
}
return 0, argErr("user_id")
case "user_role":
if v := c.Value(UserRoleKey); v != nil {
return io.WriteString(w, v.(string))
2019-09-05 06:09:56 +02:00
}
return 0, argErr("user_role")
2019-04-19 07:55:03 +02:00
}
fields := jsn.Get(c.vars, [][]byte{[]byte(tag)})
2020-01-15 05:16:55 +01:00
2019-09-05 06:09:56 +02:00
if len(fields) == 0 {
return 0, argErr(tag)
2019-09-05 06:09:56 +02:00
}
v := fields[0].Value
2020-02-23 21:29:50 +01:00
2020-05-24 23:43:54 +02:00
if isJsonScalarArray(v) {
return w.Write(jsonListToValues(v))
}
2020-02-23 21:29:50 +01:00
// Open and close quotes
if len(v) >= 2 && v[0] == '"' && v[len(v)-1] == '"' {
fields[0].Value = v[1 : len(v)-1]
}
2019-04-19 07:55:03 +02:00
if tag == "cursor" {
2020-02-23 21:29:50 +01:00
if bytes.EqualFold(v, []byte("null")) {
return io.WriteString(w, ``)
}
v1, err := c.sg.decrypt(string(fields[0].Value))
if err != nil {
return 0, err
}
return w.Write(v1)
}
return w.Write(escSQuote(fields[0].Value))
2019-04-19 07:55:03 +02:00
}
}
2019-07-29 07:13:33 +02:00
// argList function is used to create a list of arguments to pass
// to a prepared statement. FYI no escaping of single quotes is
// needed here
func (c *scontext) argList(args [][]byte) ([]interface{}, error) {
2019-09-05 06:09:56 +02:00
vars := make([]interface{}, len(args))
2019-12-25 07:24:30 +01:00
var fields map[string]json.RawMessage
2019-09-05 06:09:56 +02:00
var err error
2019-07-29 07:13:33 +02:00
if len(c.vars) != 0 {
fields, _, err = jsn.Tree(c.vars)
2019-09-05 06:09:56 +02:00
if err != nil {
2019-11-25 08:22:33 +01:00
return nil, err
2019-09-05 06:09:56 +02:00
}
2019-07-29 07:13:33 +02:00
}
for i := range args {
2019-09-05 06:09:56 +02:00
av := args[i]
switch {
case bytes.Equal(av, []byte("user_id")):
if v := c.Value(UserIDKey); v != nil {
2019-09-05 06:09:56 +02:00
vars[i] = v.(string)
2019-11-25 08:22:33 +01:00
} else {
return nil, argErr("user_id")
2019-07-29 07:13:33 +02:00
}
2019-09-05 06:09:56 +02:00
case bytes.Equal(av, []byte("user_id_provider")):
if v := c.Value(UserIDProviderKey); v != nil {
2019-09-05 06:09:56 +02:00
vars[i] = v.(string)
2019-11-25 08:22:33 +01:00
} else {
return nil, argErr("user_id_provider")
2019-07-29 07:13:33 +02:00
}
case bytes.Equal(av, []byte("user_role")):
if v := c.Value(UserRoleKey); v != nil {
vars[i] = v.(string)
2019-11-25 08:22:33 +01:00
} else {
return nil, argErr("user_role")
}
case bytes.Equal(av, []byte("cursor")):
if v, ok := fields["cursor"]; ok && v[0] == '"' {
v1, err := c.sg.decrypt(string(v[1 : len(v)-1]))
if err != nil {
return nil, err
}
vars[i] = v1
} else {
return nil, argErr("cursor")
}
2019-09-05 06:09:56 +02:00
default:
if v, ok := fields[string(av)]; ok {
2019-12-25 07:24:30 +01:00
switch v[0] {
case '[', '{':
2020-05-24 23:43:54 +02:00
if isJsonScalarArray(v) {
vars[i] = jsonListToValues(v)
} else {
vars[i] = v
}
2019-12-25 07:24:30 +01:00
default:
var val interface{}
if err := json.Unmarshal(v, &val); err != nil {
return nil, err
}
vars[i] = val
}
2019-11-25 08:22:33 +01:00
} else {
return nil, argErr(string(av))
2019-07-29 07:13:33 +02:00
}
}
}
2019-11-25 08:22:33 +01:00
return vars, nil
2019-07-29 07:13:33 +02:00
}
2020-01-14 07:02:12 +01:00
// escSQuote returns b with every single quote doubled, which is how a quote
// is written inside an SQL string literal ('it''s'). When b contains no
// single quotes, b itself is returned without copying.
func escSQuote(b []byte) []byte {
	n := bytes.Count(b, []byte{'\''})
	if n == 0 {
		return b
	}

	// Each quote grows by exactly one byte, so the final size is known.
	out := make([]byte, 0, len(b)+n)
	for _, ch := range b {
		out = append(out, ch)
		if ch == '\'' {
			out = append(out, '\'')
		}
	}
	return out
}
2020-05-24 23:43:54 +02:00
// isJsonScalarArray reports whether b looks like a JSON array of scalars.
//
// b must start with '[' and end with ']'. Beyond that, only the first byte
// after the opening bracket(s) and any whitespace is inspected: '{' means
// an array of objects (false), anything else means scalars (true). Elements
// are not validated, nested arrays count as scalar arrays, and an empty
// array reports true. Inputs shorter than two bytes are never arrays.
func isJsonScalarArray(b []byte) bool {
	if len(b) < 2 || b[0] != '[' || b[len(b)-1] != ']' {
		return false
	}
	for _, ch := range b {
		switch ch {
		case '{':
			return false
		case '[', ' ', '\t', '\n', '\r':
			// Skip brackets and all JSON whitespace (RFC 8259 includes \r).
			continue
		default:
			return true
		}
	}
	return true
}
// jsonListToValues converts the JSON array b into a comma separated list of
// single-quoted values. It drops the outer brackets and turns every
// unescaped double quote into a single quote, so ["a","b"] becomes 'a','b'.
// A double quote preceded by an odd number of backslashes is escaped in JSON
// and is copied unchanged. No other JSON unescaping is performed.
//
// The result is always a newly allocated slice and b is never modified;
// callers may pass slices that alias shared request data. Inputs shorter
// than two bytes (which cannot be arrays) yield nil.
func jsonListToValues(b []byte) []byte {
	if len(b) < 2 {
		return nil
	}

	inner := b[1 : len(b)-1]
	out := make([]byte, len(inner))

	// escaped is true while the preceding run of backslashes has odd length.
	escaped := false
	for i, ch := range inner {
		if ch == '"' && !escaped {
			ch = '\''
		}
		out[i] = ch
		escaped = ch == '\\' && !escaped
	}
	return out
}
// argErr builds the error returned when a query needs the variable called
// name but no value for it was supplied.
func argErr(name string) (err error) {
	err = fmt.Errorf("query requires variable '%s' to be set", name)
	return err
}