Add insert mutation with bulk insert

This commit is contained in:
Vikram Rangnekar
2019-09-05 00:09:56 -04:00
parent 5b9105ff0c
commit c0a21e448f
30 changed files with 1080 additions and 265 deletions

View File

@ -1,6 +1,7 @@
package serv
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
@ -9,9 +10,15 @@ import (
"strings"
)
const (
AL_QUERY int = iota + 1
AL_VARS
)
type allowItem struct {
uri string
gql string
uri string
gql string
vars json.RawMessage
}
var _allowList allowList
@ -77,8 +84,9 @@ func (al *allowList) add(req *gqlReq) {
}
al.saveChan <- &allowItem{
uri: req.ref,
gql: req.Query,
uri: req.ref,
gql: req.Query,
vars: req.Vars,
}
}
@ -93,32 +101,62 @@ func (al *allowList) load() {
}
var uri string
var varBytes []byte
s, e, c := 0, 0, 0
ty := 0
for {
if c == 0 && b[e] == '#' {
s = e
for b[e] != '\n' && e < len(b) {
for e < len(b) && b[e] != '\n' {
e++
}
if (e - s) > 2 {
uri = strings.TrimSpace(string(b[(s + 1):e]))
}
}
if b[e] == '{' {
if e >= len(b) {
break
}
if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") {
if c == 0 {
s = e
}
ty = AL_QUERY
} else if matchPrefix(b, e, "variables") {
if c == 0 {
s = e + len("variables") + 1
}
ty = AL_VARS
} else if b[e] == '{' {
c++
} else if b[e] == '}' {
c--
if c == 0 {
q := b[s:(e + 1)]
al.list[gqlHash(q)] = &allowItem{
uri: uri,
gql: string(q),
if ty == AL_QUERY {
q := string(b[s:(e + 1)])
item := &allowItem{
uri: uri,
gql: q,
}
if len(varBytes) != 0 {
item.vars = varBytes
}
al.list[gqlHash(q, varBytes)] = item
varBytes = nil
} else if ty == AL_VARS {
varBytes = b[s:(e + 1)]
}
ty = 0
}
}
@ -130,7 +168,7 @@ func (al *allowList) load() {
}
func (al *allowList) save(item *allowItem) {
al.list[gqlHash([]byte(item.gql))] = item
al.list[gqlHash(item.gql, item.vars)] = item
f, err := os.Create(al.filepath)
if err != nil {
@ -141,10 +179,10 @@ func (al *allowList) save(item *allowItem) {
defer f.Close()
keys := []string{}
urlMap := make(map[string][]string)
urlMap := make(map[string][]*allowItem)
for _, v := range al.list {
urlMap[v.uri] = append(urlMap[v.uri], v.gql)
urlMap[v.uri] = append(urlMap[v.uri], v)
}
for k := range urlMap {
@ -159,7 +197,28 @@ func (al *allowList) save(item *allowItem) {
f.WriteString(fmt.Sprintf("# %s\n\n", k))
for i := range v {
f.WriteString(fmt.Sprintf("query %s\n\n", v[i]))
if len(v[i].vars) != 0 {
vj, err := json.MarshalIndent(v[i].vars, "", "\t")
if err != nil {
logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file")
continue
}
f.WriteString(fmt.Sprintf("variables %s\n\n", vj))
}
f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql))
}
}
}
func matchPrefix(b []byte, i int, s string) bool {
if (len(b) - i) < len(s) {
return false
}
for n := 0; n < len(s); n++ {
if b[(i+n)] != s[n] {
return false
}
}
return true
}

View File

@ -14,6 +14,7 @@ import (
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/valyala/fasttemplate"
@ -42,7 +43,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
if conf.UseAllowList {
var ps *preparedItem
data, ps, err = c.resolvePreparedSQL([]byte(c.req.Query))
data, ps, err = c.resolvePreparedSQL(c.req.Query)
if err != nil {
return err
}
@ -52,7 +53,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
} else {
qc, err = qcompile.CompileQuery([]byte(c.req.Query))
qc, err = qcompile.Compile([]byte(c.req.Query))
if err != nil {
return err
}
@ -67,7 +68,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
return c.render(w, data)
}
sel := qc.Query.Selects
sel := qc.Selects
h := xxhash.New()
// fetch the field name used within the db response json
@ -252,8 +253,8 @@ func (c *coreContext) resolveRemotes(
return to, cerr
}
func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, error) {
ps, ok := _preparedList[gqlHash(gql)]
func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) {
ps, ok := _preparedList[gqlHash(gql, c.req.Vars)]
if !ok {
return nil, nil, errUnauthorized
}
@ -266,17 +267,22 @@ func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, err
return nil, nil, err
}
fmt.Printf("PRE: %#v %#v\n", ps.stmt, vars)
fmt.Printf("PRE: %v\n", ps.stmt)
return []byte(root), ps, nil
}
func (c *coreContext) resolveSQL(qc *qcode.QCode) (
[]byte, uint32, error) {
stmt := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, stmt)
vars := make(map[string]json.RawMessage)
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, 0, err
}
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
if err != nil {
return nil, 0, err
}
@ -284,7 +290,7 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
t := fasttemplate.New(stmt.String(), openVar, closeVar)
stmt.Reset()
_, err = t.Execute(stmt, varMap(c))
_, err = t.ExecuteFunc(stmt, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
@ -317,10 +323,10 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
return nil, 0, err
}
if conf.EnableTracing && len(qc.Query.Selects) != 0 {
if conf.EnableTracing && len(qc.Selects) != 0 {
c.addTrace(
qc.Query.Selects,
qc.Query.Selects[0].ID,
qc.Selects,
qc.Selects[0].ID,
st)
}

35
serv/core_test.go Normal file
View File

@ -0,0 +1,35 @@
package serv
/*
func simpleMutation(t *testing.T) {
gql := `mutation {
product(id: 15, insert: { name: "Test", price: 20.5 }) {
id
name
}
}`
sql := `test`
backgroundCtx := context.Background()
ctx := &coreContext{Context: backgroundCtx}
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
t.Fatal(err)
}
fmt.Println(">", string(resSQL))
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func TestCompileGQL(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs)
t.Run("simpleMutation", simpleMutation)
}
*/

View File

@ -26,13 +26,13 @@ var (
)
type gqlReq struct {
OpName string `json:"operationName"`
Query string `json:"query"`
Vars variables `json:"variables"`
OpName string `json:"operationName"`
Query string `json:"query"`
Vars json.RawMessage `json:"variables"`
ref string
}
type variables map[string]interface{}
type variables map[string]json.RawMessage
type gqlResp struct {
Error string `json:"error,omitempty"`

View File

@ -2,9 +2,11 @@ package serv
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/valyala/fasttemplate"
@ -12,7 +14,7 @@ import (
type preparedItem struct {
stmt *pg.Stmt
args []string
args [][]byte
skipped uint32
qc *qcode.QCode
}
@ -25,36 +27,46 @@ func initPreparedList() {
_preparedList = make(map[string]*preparedItem)
for k, v := range _allowList.list {
err := prepareStmt(k, v.gql)
err := prepareStmt(k, v.gql, v.vars)
if err != nil {
panic(err)
}
}
}
func prepareStmt(key, gql string) error {
func prepareStmt(key, gql string, varBytes json.RawMessage) error {
if len(gql) == 0 || len(key) == 0 {
return nil
}
qc, err := qcompile.CompileQuery([]byte(gql))
qc, err := qcompile.Compile([]byte(gql))
if err != nil {
return err
}
var vars map[string]json.RawMessage
if len(varBytes) != 0 {
vars = make(map[string]json.RawMessage)
if err := json.Unmarshal(varBytes, &vars); err != nil {
return err
}
}
buf := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, buf)
skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars))
if err != nil {
return err
}
t := fasttemplate.New(buf.String(), `('{{`, `}}')`)
am := make([]string, 0, 5)
t := fasttemplate.New(buf.String(), `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
am = append(am, tag)
am = append(am, []byte(tag))
i++
return w.Write([]byte(fmt.Sprintf("$%d", i)))
})

View File

@ -4,8 +4,12 @@ import (
"bytes"
"crypto/sha1"
"encoding/hex"
"io"
"sort"
"strings"
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
)
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
@ -17,8 +21,8 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
return v
}
func gqlHash(b []byte) string {
b = bytes.TrimSpace(b)
func gqlHash(b string, vars []byte) string {
b = strings.TrimSpace(b)
h := sha1.New()
s, e := 0, 0
@ -45,13 +49,27 @@ func gqlHash(b []byte) string {
if e != 0 {
b0 = b[(e - 1)]
}
h.Write(bytes.ToLower(b[s:e]))
io.WriteString(h, strings.ToLower(b[s:e]))
}
if e >= len(b) {
break
}
}
if vars == nil {
return hex.EncodeToString(h.Sum(nil))
}
fields := jsn.Keys([]byte(vars))
sort.Slice(fields, func(i, j int) bool {
return bytes.Compare(fields[i], fields[j]) == -1
})
for i := range fields {
h.Write(fields[i])
}
return hex.EncodeToString(h.Sum(nil))
}

View File

@ -6,7 +6,7 @@ import (
)
func TestRelaxHash1(t *testing.T) {
var v1 = []byte(`
var v1 = `
products(
limit: 30,
@ -14,18 +14,18 @@ func TestRelaxHash1(t *testing.T) {
id
name
price
}`)
}`
var v2 = []byte(`
var v2 = `
products
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
} `)
} `
h1 := gqlHash(v1)
h2 := gqlHash(v2)
h1 := gqlHash(v1, nil)
h2 := gqlHash(v2, nil)
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
@ -33,7 +33,7 @@ func TestRelaxHash1(t *testing.T) {
}
func TestRelaxHash2(t *testing.T) {
var v1 = []byte(`
var v1 = `
{
products(
limit: 30
@ -49,12 +49,119 @@ func TestRelaxHash2(t *testing.T) {
email
}
}
}`)
}`
var v2 = []byte(` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `)
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
h1 := gqlHash(v1)
h2 := gqlHash(v2)
h1 := gqlHash(v1, nil)
h2 := gqlHash(v2, nil)
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHashWithVars1(t *testing.T) {
var q1 = `
products(
limit: 30,
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
}`
var v1 = `
{
"insert": {
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now",
"test": { "type2": "b", "type1": "a" }
},
"user": 123
}`
var q2 = `
products
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
} `
var v2 = `{
"insert": {
"created_at": "now",
"test": { "type1": "a", "type2": "b" },
"name": "Hello",
"updated_at": "now",
"description": "World"
},
"user": 123
}`
h1 := gqlHash(q1, []byte(va1))
h2 := gqlHash(q2, []byte(va2))
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHashWithVars2(t *testing.T) {
var q1 = `
products(
limit: 30,
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
}`
var v1 = `
{
"insert": [{
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now",
"test": { "type2": "b", "type1": "a" }
},
{
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now",
"test": { "type2": "b", "type1": "a" }
}],
"user": 123
}`
var q2 = `
products
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
} `
var v2 = `{
"insert": {
"created_at": "now",
"test": { "type1": "a", "type2": "b" },
"name": "Hello",
"updated_at": "now",
"description": "World"
},
"user": 123
}`
h1 := gqlHash(q1, []byte(va1))
h2 := gqlHash(q2, []byte(va2))
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")

View File

@ -1,95 +1,107 @@
package serv
import (
"bytes"
"fmt"
"io"
"strconv"
"strings"
"github.com/valyala/fasttemplate"
"github.com/dosco/super-graph/jsn"
)
func varMap(ctx *coreContext) variables {
userIDFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
}
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) {
switch tag {
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
}
return 0, errNoUserID
userIDProviderFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDProviderKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
}
userIDTag := fasttemplate.TagFunc(userIDFn)
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn)
vm := variables{
"user_id": userIDTag,
"user_id_provider": userIDProviderTag,
"USER_ID": userIDTag,
"USER_ID_PROVIDER": userIDProviderTag,
}
for k, v := range ctx.req.Vars {
var buf []byte
k = strings.ToLower(k)
if _, ok := vm[k]; ok {
continue
case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil {
return stringVar(w, v.(string))
}
return 0, errNoUserID
}
switch val := v.(type) {
case string:
vm[k] = val
case int:
vm[k] = strconv.AppendInt(buf, int64(val), 10)
case int64:
vm[k] = strconv.AppendInt(buf, val, 10)
case float64:
vm[k] = strconv.AppendFloat(buf, val, 'f', -1, 64)
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
if len(fields) == 0 {
return 0, fmt.Errorf("variable '%s' not found", tag)
}
is := false
for i := range fields[0].Value {
c := fields[0].Value[i]
if c != ' ' {
is = (c == '"') || (c == '{') || (c == '[')
break
}
}
if is {
return stringVarB(w, fields[0].Value)
}
w.Write(fields[0].Value)
return 0, nil
}
return vm
}
func varList(ctx *coreContext, args []string) []interface{} {
vars := make([]interface{}, 0, len(args))
func varList(ctx *coreContext, args [][]byte) []interface{} {
vars := make([]interface{}, len(args))
for k, v := range ctx.req.Vars {
ctx.req.Vars[strings.ToLower(k)] = v
var fields map[string]interface{}
var err error
if len(ctx.req.Vars) != 0 {
fields, _, err = jsn.Tree(ctx.req.Vars)
if err != nil {
logger.Warn().Err(err).Msg("Failed to parse variables")
}
}
for i := range args {
arg := strings.ToLower(args[i])
av := args[i]
if arg == "user_id" {
switch {
case bytes.Equal(av, []byte("user_id")):
if v := ctx.Value(userIDKey); v != nil {
vars = append(vars, v.(string))
vars[i] = v.(string)
}
}
if arg == "user_id_provider" {
case bytes.Equal(av, []byte("user_id_provider")):
if v := ctx.Value(userIDProviderKey); v != nil {
vars = append(vars, v.(string))
vars[i] = v.(string)
}
}
if v, ok := ctx.req.Vars[arg]; ok {
switch val := v.(type) {
case string:
vars = append(vars, val)
case int:
vars = append(vars, strconv.FormatInt(int64(val), 10))
case int64:
vars = append(vars, strconv.FormatInt(int64(val), 10))
case float64:
vars = append(vars, strconv.FormatFloat(val, 'f', -1, 64))
default:
if v, ok := fields[string(av)]; ok {
vars[i] = v
}
}
}
return vars
}
func stringVar(w io.Writer, v string) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write([]byte(v)); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}
func stringVarB(w io.Writer, v []byte) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write(v); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}