Compare commits

..

2 Commits

24 changed files with 225 additions and 160 deletions

View File

@ -85,10 +85,10 @@ type SuperGraph struct {
allowList *allow.List
encKey [32]byte
hashSeed maphash.Seed
queries map[uint64]*query
queries map[uint64]query
roles map[string]*Role
getRole *sql.Stmt
rmap map[uint64]*resolvFn
rmap map[uint64]resolvFn
abacEnabled bool
anonExists bool
qc *qcode.Compiler

View File

@ -12,7 +12,8 @@ import (
// to a prepared statement.
func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
vars := make([]interface{}, len(md.Params))
params := md.Params()
vars := make([]interface{}, len(params))
var fields map[string]json.RawMessage
var err error
@ -25,7 +26,7 @@ func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
}
}
for i, p := range md.Params {
for i, p := range params {
switch p.Name {
case "user_id":
if v := c.Value(UserIDKey); v != nil {

View File

@ -88,6 +88,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
stmts := make([]stmt, 0, len(sg.conf.Roles))
w := &bytes.Buffer{}
md := psql.Metadata{}
for i := 0; i < len(sg.conf.Roles); i++ {
role := &sg.conf.Roles[i]
@ -105,16 +106,18 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
stmts = append(stmts, stmt{role: role, qc: qc})
s := &stmts[len(stmts)-1]
s.md, err = sg.pc.Compile(w, qc, psql.Variables(vm))
md, err = sg.pc.CompileWithMetadata(w, qc, psql.Variables(vm), md)
if err != nil {
return nil, err
}
s.sql = w.String()
s.md = md
w.Reset()
}
sql, err := sg.renderUserQuery(stmts)
sql, err := sg.renderUserQuery(md, stmts)
if err != nil {
return nil, err
}
@ -124,7 +127,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
}
//nolint: errcheck
func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
func (sg *SuperGraph) renderUserQuery(md psql.Metadata, stmts []stmt) (string, error) {
w := &bytes.Buffer{}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
@ -142,7 +145,7 @@ func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
}
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, sg.conf.RolesQuery)
md.RenderVar(w, sg.conf.RolesQuery)
io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`)
@ -158,7 +161,7 @@ func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
}
io.WriteString(w, ` ELSE 'user' END) FROM (`)
io.WriteString(w, sg.conf.RolesQuery)
md.RenderVar(w, sg.conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)

View File

@ -125,7 +125,7 @@ func (c *scontext) execQuery() ([]byte, error) {
return nil, err
}
if len(data) == 0 || st.md.Skipped == 0 {
if len(data) == 0 || st.md.Skipped() == 0 {
return data, nil
}
@ -179,7 +179,7 @@ func (c *scontext) resolvePreparedSQL() ([]byte, *stmt, error) {
}
if q.sd == nil {
q.Do(func() { c.sg.prepare(q, role) })
q.Do(func() { c.sg.prepare(&q, role) })
if q.err != nil {
return nil, nil, err
@ -196,6 +196,8 @@ func (c *scontext) resolvePreparedSQL() ([]byte, *stmt, error) {
return nil, nil, err
}
fmt.Println(">>", varsList)
if useTx {
row = tx.Stmt(q.sd).QueryRow(varsList...)
} else {

View File

@ -75,13 +75,22 @@ func (sg *SuperGraph) initConfig() error {
if c.RolesQuery == "" {
sg.log.Printf("INF roles_query not defined: attribute based access control disabled")
} else {
n := 0
for k, v := range sg.roles {
if k == "user" || k == "anon" {
n++
} else if v.Match != "" {
n++
}
}
sg.abacEnabled = (n > 2)
if !sg.abacEnabled {
sg.log.Printf("WRN attribute based access control disabled: no custom roles found (with 'match' defined)")
}
}
_, userExists := sg.roles["user"]
_, sg.anonExists = sg.roles["anon"]
sg.abacEnabled = userExists && c.RolesQuery != ""
return nil
}

View File

@ -1,4 +1,3 @@
//nolint:errcheck
package psql
import (
@ -112,15 +111,15 @@ func (c *compilerContext) renderColumnSearchRank(sel *qcode.Select, ti *DBTableI
c.renderComma(columnsRendered)
//fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_rank(`)
_, _ = io.WriteString(c.w, `ts_rank(`)
colWithTable(c.w, ti.Name, cn)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery(`)
_, _ = io.WriteString(c.w, `, websearch_to_tsquery(`)
} else {
io.WriteString(c.w, `, to_tsquery(`)
_, _ = io.WriteString(c.w, `, to_tsquery(`)
}
c.renderValueExp(Param{Name: arg.Val, Type: "string"})
io.WriteString(c.w, `))`)
c.md.renderValueExp(c.w, Param{Name: arg.Val, Type: "string"})
_, _ = io.WriteString(c.w, `))`)
alias(c.w, col.Name)
return nil
@ -137,15 +136,15 @@ func (c *compilerContext) renderColumnSearchHeadline(sel *qcode.Select, ti *DBTa
c.renderComma(columnsRendered)
//fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_headline(`)
_, _ = io.WriteString(c.w, `ts_headline(`)
colWithTable(c.w, ti.Name, cn)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery(`)
_, _ = io.WriteString(c.w, `, websearch_to_tsquery(`)
} else {
io.WriteString(c.w, `, to_tsquery(`)
_, _ = io.WriteString(c.w, `, to_tsquery(`)
}
c.renderValueExp(Param{Name: arg.Val, Type: "string"})
io.WriteString(c.w, `))`)
c.md.renderValueExp(c.w, Param{Name: arg.Val, Type: "string"})
_, _ = io.WriteString(c.w, `))`)
alias(c.w, col.Name)
return nil
@ -157,9 +156,9 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
}
c.renderComma(columnsRendered)
io.WriteString(c.w, `(`)
_, _ = io.WriteString(c.w, `(`)
squoted(c.w, ti.Name)
io.WriteString(c.w, ` :: text)`)
_, _ = io.WriteString(c.w, ` :: text)`)
alias(c.w, col.Name)
return nil
@ -169,9 +168,9 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
pl := funcPrefixLen(c.schema.fm, col.Name)
// if pl == 0 {
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
// io.WriteString(c.w, `'`)
// io.WriteString(c.w, col.Name)
// io.WriteString(c.w, ` not defined'`)
// _, _ = io.WriteString(c.w, `'`)
// _, _ = io.WriteString(c.w, col.Name)
// _, _ = io.WriteString(c.w, ` not defined'`)
// alias(c.w, col.Name)
// }
@ -190,10 +189,10 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
c.renderComma(columnsRendered)
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Name, cn, col.Name)
io.WriteString(c.w, fn)
io.WriteString(c.w, `(`)
_, _ = io.WriteString(c.w, fn)
_, _ = io.WriteString(c.w, `(`)
colWithTable(c.w, ti.Name, cn)
io.WriteString(c.w, `)`)
_, _ = io.WriteString(c.w, `)`)
alias(c.w, col.Name)
return nil
@ -201,7 +200,7 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
func (c *compilerContext) renderComma(columnsRendered int) {
if columnsRendered != 0 {
io.WriteString(c.w, `, `)
_, _ = io.WriteString(c.w, `, `)
}
}

View File

@ -25,7 +25,7 @@ func (c *compilerContext) renderInsert(
if insert[0] == '[' {
io.WriteString(c.w, `json_array_elements(`)
}
c.renderValueExp(Param{Name: qc.ActionVar, Type: "json"})
c.md.renderValueExp(c.w, Param{Name: qc.ActionVar, Type: "json"})
io.WriteString(c.w, ` :: json`)
if insert[0] == '[' {
io.WriteString(c.w, `)`)

View File

@ -0,0 +1,61 @@
package psql
import (
"io"
)
// RenderVar writes the SQL fragment vv to w, replacing every unescaped
// variable reference of the form "$name" with a numbered placeholder
// (e.g. "$1") registered through renderValueExp. A '$' preceded by a
// backslash is treated as literal text (the backslash itself is kept).
// Variable names consist of ASCII letters, digits and underscores.
func (md *Metadata) RenderVar(w io.Writer, vv string) {
	// f is the index of the '$' opening the variable currently being
	// scanned (-1 when not inside one); s is the start of the pending
	// literal run not yet flushed to w.
	f, s := -1, 0
	for i := 0; i < len(vv); i++ {
		v := vv[i]
		switch {
		case v == '$' && (i == 0 || vv[i-1] != '\\'):
			// New variable marker: first emit whatever preceded it —
			// either the variable we were scanning or literal text.
			if f != -1 && (i-f) > 1 {
				md.renderValueExp(w, Param{Name: vv[f+1 : i]})
				s = i
			} else if i > s {
				_, _ = io.WriteString(w, vv[s:i])
				s = i
			}
			f = i
		// NOTE: the original used `v < 'a' && v > 'z'` (always false),
		// so a variable was only ever recognized at end-of-string; any
		// non-identifier byte must terminate the name.
		case f != -1 && !(v >= 'a' && v <= 'z' || v >= 'A' && v <= 'Z' ||
			v >= '0' && v <= '9' || v == '_'):
			if (i - f) > 1 {
				md.renderValueExp(w, Param{Name: vv[f+1 : i]})
			} else {
				// Lone '$' with no name: keep it as literal text.
				_, _ = io.WriteString(w, vv[f:i])
			}
			s = i
			f = -1
		}
	}
	// Flush the tail: either a trailing variable or remaining literal.
	if f != -1 && (len(vv)-f) > 1 {
		md.renderValueExp(w, Param{Name: vv[f+1:]})
	} else if s < len(vv) {
		_, _ = io.WriteString(w, vv[s:])
	}
}
// renderValueExp writes the positional placeholder ("$<n>") for the
// parameter p. The first time a name is seen it is appended to
// md.params and given the next free position; later occurrences reuse
// the position recorded in md.pindex.
func (md *Metadata) renderValueExp(w io.Writer, p Param) {
	_, _ = io.WriteString(w, `$`)
	n, seen := md.pindex[p.Name]
	if !seen {
		if md.pindex == nil {
			md.pindex = make(map[string]int)
		}
		md.params = append(md.params, p)
		n = len(md.params)
		md.pindex[p.Name] = n
	}
	int32String(w, int32(n))
}
// Skipped returns the bitmask of child select ids that the compiler
// marked as skipped (each bit set via `skipped |= 1 << id`).
func (md Metadata) Skipped() uint32 {
	return md.skipped
}
// Params returns the parameters registered while compiling the query,
// ordered to match their numbered placeholders ($1, $2, ...).
func (md Metadata) Params() []Param {
	return md.params
}

View File

@ -432,11 +432,11 @@ func (c *compilerContext) renderInsertUpdateColumns(
val := root.PresetMap[cn]
switch {
case ok && len(val) > 1 && val[0] == '$':
c.renderValueExp(Param{Name: val[1:], Type: col.Type})
c.md.renderValueExp(c.w, Param{Name: val[1:], Type: col.Type})
case ok && strings.HasPrefix(val, "sql:"):
io.WriteString(c.w, `(`)
c.renderVar(val[4:], c.renderValueExp)
c.md.RenderVar(c.w, val[4:])
io.WriteString(c.w, `)`)
case ok:

View File

@ -25,8 +25,8 @@ type Param struct {
}
type Metadata struct {
Skipped uint32
Params []Param
skipped uint32
params []Param
pindex map[string]int
}
@ -80,26 +80,30 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (Metadata, []byte
}
func (co *Compiler) Compile(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) {
return co.CompileWithMetadata(w, qc, vars, Metadata{})
}
func (co *Compiler) CompileWithMetadata(w io.Writer, qc *qcode.QCode, vars Variables, md Metadata) (Metadata, error) {
md.skipped = 0
if qc == nil {
return Metadata{}, fmt.Errorf("qcode is nil")
return md, fmt.Errorf("qcode is nil")
}
switch qc.Type {
case qcode.QTQuery:
return co.compileQuery(w, qc, vars)
return co.compileQueryWithMetadata(w, qc, vars, md)
case qcode.QTInsert,
qcode.QTUpdate,
qcode.QTDelete,
qcode.QTUpsert:
return co.compileMutation(w, qc, vars)
default:
return Metadata{}, fmt.Errorf("Unknown operation type %d", qc.Type)
}
return Metadata{}, fmt.Errorf("Unknown operation type %d", qc.Type)
}
func (co *Compiler) compileQuery(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) {
return co.compileQueryWithMetadata(w, qc, vars, Metadata{})
}
func (co *Compiler) compileQueryWithMetadata(
@ -176,7 +180,7 @@ func (co *Compiler) compileQueryWithMetadata(
}
for _, cid := range sel.Children {
if hasBit(c.md.Skipped, uint32(cid)) {
if hasBit(c.md.skipped, uint32(cid)) {
continue
}
child := &c.s[cid]
@ -354,7 +358,7 @@ func (c *compilerContext) initSelect(sel *qcode.Select, ti *DBTableInfo, vars Va
if _, ok := colmap[rel.Left.Col]; !ok {
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Left.Col, FieldName: rel.Right.Col})
colmap[rel.Left.Col] = struct{}{}
c.md.Skipped |= (1 << uint(id))
c.md.skipped |= (1 << uint(id))
}
default:
@ -622,7 +626,7 @@ func (c *compilerContext) renderJoinColumns(sel *qcode.Select, ti *DBTableInfo,
i := colsRendered
for _, id := range sel.Children {
if hasBit(c.md.Skipped, uint32(id)) {
if hasBit(c.md.skipped, uint32(id)) {
continue
}
childSel := &c.s[id]
@ -804,7 +808,7 @@ func (c *compilerContext) renderCursorCTE(sel *qcode.Select) error {
quoted(c.w, ob.Col)
}
io.WriteString(c.w, ` FROM string_to_array(`)
c.renderValueExp(Param{Name: "cursor", Type: "json"})
c.md.renderValueExp(c.w, Param{Name: "cursor", Type: "json"})
io.WriteString(c.w, `, ',') as a) `)
return nil
}
@ -1102,7 +1106,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, ti *DBTableInfo) error {
} else {
io.WriteString(c.w, `) @@ to_tsquery(`)
}
c.renderValueExp(Param{Name: ex.Val, Type: "string"})
c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: "string"})
io.WriteString(c.w, `))`)
return nil
@ -1191,7 +1195,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
switch {
case ok && strings.HasPrefix(val, "sql:"):
io.WriteString(c.w, `(`)
c.renderVar(val[4:], c.renderValueExp)
c.md.RenderVar(c.w, val[4:])
io.WriteString(c.w, `)`)
case ok:
@ -1199,7 +1203,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
case ex.Op == qcode.OpIn || ex.Op == qcode.OpNotIn:
io.WriteString(c.w, `(ARRAY(SELECT json_array_elements_text(`)
c.renderValueExp(Param{Name: ex.Val, Type: col.Type, IsArray: true})
c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: col.Type, IsArray: true})
io.WriteString(c.w, `))`)
io.WriteString(c.w, ` :: `)
@ -1208,7 +1212,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
return
default:
c.renderValueExp(Param{Name: ex.Val, Type: col.Type, IsArray: false})
c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: col.Type, IsArray: false})
}
case qcode.ValRef:
@ -1222,54 +1226,6 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
io.WriteString(c.w, col.Type)
}
func (c *compilerContext) renderValueExp(p Param) {
io.WriteString(c.w, `$`)
if v, ok := c.md.pindex[p.Name]; ok {
int32String(c.w, int32(v))
} else {
c.md.Params = append(c.md.Params, p)
n := len(c.md.Params)
if c.md.pindex == nil {
c.md.pindex = make(map[string]int)
}
c.md.pindex[p.Name] = n
int32String(c.w, int32(n))
}
}
func (c *compilerContext) renderVar(vv string, fn func(Param)) {
f, s := -1, 0
for i := range vv {
v := vv[i]
switch {
case (i > 0 && vv[i-1] != '\\' && v == '$') || v == '$':
if (i - s) > 0 {
io.WriteString(c.w, vv[s:i])
}
f = i
case (v < 'a' && v > 'z') &&
(v < 'A' && v > 'Z') &&
(v < '0' && v > '9') &&
v != '_' &&
f != -1 &&
(i-f) > 1:
fn(Param{Name: vv[f+1 : i]})
s = i
f = -1
}
}
if f != -1 && (len(vv)-f) > 1 {
fn(Param{Name: vv[f+1:]})
} else {
io.WriteString(c.w, vv[s:])
}
}
func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
switch {
case strings.HasPrefix(fn, "avg_"):

View File

@ -22,7 +22,7 @@ func (c *compilerContext) renderUpdate(
}
io.WriteString(c.w, `WITH "_sg_input" AS (SELECT `)
c.renderValueExp(Param{Name: qc.ActionVar, Type: "json"})
c.md.renderValueExp(c.w, Param{Name: qc.ActionVar, Type: "json"})
// io.WriteString(c.w, qc.ActionVar)
io.WriteString(c.w, ` :: json AS j)`)

View File

@ -2,8 +2,9 @@ package qcode
import (
"errors"
"github.com/chirino/graphql/schema"
"testing"
"github.com/chirino/graphql/schema"
)
func TestCompile1(t *testing.T) {
@ -130,6 +131,22 @@ updateThread {
}
func TestFragmentsCompile(t *testing.T) {
gql := `
fragment userFields on user {
name
email
}
query { users { ...userFields } }`
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(gql), "anon")
if err == nil {
t.Fatal(errors.New("expecting an error"))
}
}
var gql = []byte(`
{products(
# returns only 30 items

View File

@ -63,7 +63,7 @@ func (sg *SuperGraph) initPrepared() error {
return fmt.Errorf("role query: %w", err)
}
sg.queries = make(map[uint64]*query)
sg.queries = make(map[uint64]query)
list, err := sg.allowList.Load()
if err != nil {
@ -77,22 +77,19 @@ func (sg *SuperGraph) initPrepared() error {
if len(v.Query) == 0 {
continue
}
q := &query{ai: v, qt: qcode.GetQType(v.Query)}
qt := qcode.GetQType(v.Query)
switch q.qt {
switch qt {
case qcode.QTQuery:
sg.queries[queryID(&h, v.Name, "user")] = q
h.Reset()
sg.queries[queryID(&h, v.Name, "user")] = query{ai: v, qt: qt}
if sg.anonExists {
sg.queries[queryID(&h, v.Name, "anon")] = q
h.Reset()
sg.queries[queryID(&h, v.Name, "anon")] = query{ai: v, qt: qt}
}
case qcode.QTMutation:
for _, role := range sg.conf.Roles {
sg.queries[queryID(&h, v.Name, role.Name)] = q
h.Reset()
sg.queries[queryID(&h, v.Name, role.Name)] = query{ai: v, qt: qt}
}
}
}
@ -128,7 +125,7 @@ func (sg *SuperGraph) prepareRoleStmt() error {
}
io.WriteString(w, ` ELSE $2 END) FROM (`)
io.WriteString(w, sg.conf.RolesQuery)
io.WriteString(w, rq)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `)
@ -166,5 +163,8 @@ func (sg *SuperGraph) initAllowList() error {
func queryID(h *maphash.Hash, name string, role string) uint64 {
h.WriteString(name)
h.WriteString(role)
return h.Sum64()
v := h.Sum64()
h.Reset()
return v
}

View File

@ -4,10 +4,10 @@ import (
"bytes"
"errors"
"fmt"
"hash/maphash"
"net/http"
"sync"
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/core/internal/qcode"
"github.com/dosco/super-graph/jsn"
)
@ -16,12 +16,13 @@ func (sg *SuperGraph) execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]
var err error
sel := st.qc.Selects
h := xxhash.New()
h := maphash.Hash{}
h.SetSeed(sg.hashSeed)
// fetch the field name used within the db response json
// that are used to mark insertion points and the mapping between
// those field names and their select objects
fids, sfmap := sg.parentFieldIds(h, sel, st.md.Skipped)
fids, sfmap := sg.parentFieldIds(&h, sel, st.md.Skipped())
// fetch the field values of the marked insertion points
// these values contain the id to be used with fetching remote data
@ -30,10 +31,10 @@ func (sg *SuperGraph) execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]
switch {
case len(from) == 1:
to, err = sg.resolveRemote(hdr, h, from[0], sel, sfmap)
to, err = sg.resolveRemote(hdr, &h, from[0], sel, sfmap)
case len(from) > 1:
to, err = sg.resolveRemotes(hdr, h, from, sel, sfmap)
to, err = sg.resolveRemotes(hdr, &h, from, sel, sfmap)
default:
return nil, errors.New("something wrong no remote ids found in db response")
@ -55,7 +56,7 @@ func (sg *SuperGraph) execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]
func (sg *SuperGraph) resolveRemote(
hdr http.Header,
h *xxhash.Digest,
h *maphash.Hash,
field jsn.Field,
sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
@ -66,7 +67,8 @@ func (sg *SuperGraph) resolveRemote(
to := toA[:1]
// use the json key to find the related Select object
k1 := xxhash.Sum64(field.Key)
_, _ = h.Write(field.Key)
k1 := h.Sum64()
s, ok := sfmap[k1]
if !ok {
@ -117,7 +119,7 @@ func (sg *SuperGraph) resolveRemote(
func (sg *SuperGraph) resolveRemotes(
hdr http.Header,
h *xxhash.Digest,
h *maphash.Hash,
from []jsn.Field,
sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
@ -134,7 +136,8 @@ func (sg *SuperGraph) resolveRemotes(
for i, id := range from {
// use the json key to find the related Select object
k1 := xxhash.Sum64(id.Key)
_, _ = h.Write(id.Key)
k1 := h.Sum64()
s, ok := sfmap[k1]
if !ok {
@ -192,7 +195,7 @@ func (sg *SuperGraph) resolveRemotes(
return to, cerr
}
func (sg *SuperGraph) parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
func (sg *SuperGraph) parentFieldIds(h *maphash.Hash, sel []qcode.Select, skipped uint32) (
[][]byte,
map[uint64]*qcode.Select) {
@ -227,8 +230,8 @@ func (sg *SuperGraph) parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipp
fm[n] = r.IDField
n++
k := xxhash.Sum64(r.IDField)
sm[k] = s
_, _ = h.Write(r.IDField)
sm[h.Sum64()] = s
}
}

View File

@ -2,11 +2,11 @@ package core
import (
"fmt"
"hash/maphash"
"io/ioutil"
"net/http"
"strings"
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/core/internal/psql"
"github.com/dosco/super-graph/jsn"
)
@ -19,7 +19,7 @@ type resolvFn struct {
func (sg *SuperGraph) initResolvers() error {
var err error
sg.rmap = make(map[uint64]*resolvFn)
sg.rmap = make(map[uint64]resolvFn)
for _, t := range sg.conf.Tables {
err = sg.initRemotes(t)
@ -36,7 +36,8 @@ func (sg *SuperGraph) initResolvers() error {
}
func (sg *SuperGraph) initRemotes(t Table) error {
h := xxhash.New()
h := maphash.Hash{}
h.SetSeed(sg.hashSeed)
for _, r := range t.Remotes {
// defines the table column to be used as an id in the
@ -75,17 +76,18 @@ func (sg *SuperGraph) initRemotes(t Table) error {
path = append(path, []byte(p))
}
rf := &resolvFn{
rf := resolvFn{
IDField: []byte(idk),
Path: path,
Fn: fn,
}
// index resolver obj by parent and child names
sg.rmap[mkkey(h, r.Name, t.Name)] = rf
sg.rmap[mkkey(&h, r.Name, t.Name)] = rf
// index resolver obj by IDField
sg.rmap[xxhash.Sum64(rf.IDField)] = rf
_, _ = h.Write(rf.IDField)
sg.rmap[h.Sum64()] = rf
}
return nil

View File

@ -1,11 +1,9 @@
package core
import (
"github.com/cespare/xxhash/v2"
)
import "hash/maphash"
// nolint: errcheck
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
func mkkey(h *maphash.Hash, k1 string, k2 string) uint64 {
h.WriteString(k1)
h.WriteString(k2)
v := h.Sum64()

View File

@ -36,8 +36,8 @@ module.exports = {
position: "left",
},
{
label: "Art Compute",
href: "https://artcompute.com/s/super-graph",
label: "AbtCode",
href: "https://abtcode.com/s/super-graph",
position: "left",
},
],

3
go.mod
View File

@ -12,13 +12,11 @@ require (
github.com/adjust/gorails v0.0.0-20171013043634-2786ed0c03d3
github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b
github.com/brianvoe/gofakeit/v5 v5.2.0
github.com/cespare/xxhash/v2 v2.1.1
github.com/chirino/graphql v0.0.0-20200430165312-293648399b1a
github.com/daaku/go.zipexe v1.0.1 // indirect
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/dlclark/regexp2 v1.2.0 // indirect
github.com/dop251/goja v0.0.0-20200424152103-d0b8fda54cd0
github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813 // indirect
github.com/fsnotify/fsnotify v1.4.9
github.com/garyburd/redigo v1.6.0
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
@ -30,7 +28,6 @@ require (
github.com/openzipkin/zipkin-go v0.2.2
github.com/pelletier/go-toml v1.7.0 // indirect
github.com/pkg/errors v0.9.1
github.com/prometheus/common v0.4.0
github.com/rs/cors v1.7.0
github.com/spf13/afero v1.2.2 // indirect
github.com/spf13/cast v1.3.1 // indirect

4
go.sum
View File

@ -55,8 +55,6 @@ github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chirino/graphql v0.0.0-20200430165312-293648399b1a h1:WVu7r2vwlrBVmunbSSU+9/3M3AgsQyhE49CKDjHiFq4=
github.com/chirino/graphql v0.0.0-20200430165312-293648399b1a/go.mod h1:wQjjxFMFyMlsWh4Z3nMuHQtevD4Ul9UVQSnz1JOLuP8=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
@ -87,8 +85,6 @@ github.com/dlclark/regexp2 v1.2.0 h1:8sAhBGEM0dRWogWqWyQeIJnxjWO6oIjl8FKqREDsGfk
github.com/dlclark/regexp2 v1.2.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc=
github.com/dop251/goja v0.0.0-20200424152103-d0b8fda54cd0 h1:EfFAcaAwGai/wlDCWwIObHBm3T2C2CCPX/SaS0fpOJ4=
github.com/dop251/goja v0.0.0-20200424152103-d0b8fda54cd0/go.mod h1:Mw6PkjjMXWbTj+nnj4s3QPXq1jaT0s5pC0iFD4+BOAA=
github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813 h1:NgO45/5mBLRVfiXerEFzH6ikcZ7DNRPS639xFg3ENzU=
github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=

View File

@ -82,8 +82,6 @@ func graphQLFunc(sg *core.SuperGraph, query string, data interface{}, opt map[st
if v, ok := opt["user_id"]; ok && len(v) != 0 {
ct = context.WithValue(ct, core.UserIDKey, v)
} else {
ct = context.WithValue(ct, core.UserIDKey, "-1")
}
// var role string

13
jsn/bench.1 Normal file
View File

@ -0,0 +1,13 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/jsn
BenchmarkGet
BenchmarkGet-16 13898 85293 ns/op 3328 B/op 2 allocs/op
BenchmarkFilter
BenchmarkFilter-16 189328 6341 ns/op 448 B/op 1 allocs/op
BenchmarkStrip
BenchmarkStrip-16 219765 5543 ns/op 224 B/op 1 allocs/op
BenchmarkReplace
BenchmarkReplace-16 100899 12022 ns/op 416 B/op 1 allocs/op
PASS
ok github.com/dosco/super-graph/jsn 6.029s

View File

@ -2,17 +2,19 @@ package jsn
import (
"bytes"
"github.com/cespare/xxhash/v2"
"hash/maphash"
)
// Filter function filters the JSON keeping only the provided keys and removing all others
func Filter(w *bytes.Buffer, b []byte, keys []string) error {
var err error
kmap := make(map[uint64]struct{}, len(keys))
h := maphash.Hash{}
for i := range keys {
kmap[xxhash.Sum64String(keys[i])] = struct{}{}
_, _ = h.WriteString(keys[i])
kmap[h.Sum64()] = struct{}{}
h.Reset()
}
// is an list
@ -132,7 +134,11 @@ func Filter(w *bytes.Buffer, b []byte, keys []string) error {
cb := b[s:(e + 1)]
e = 0
if _, ok := kmap[xxhash.Sum64(k)]; !ok {
_, _ = h.Write(k)
_, ok := kmap[h.Sum64()]
h.Reset()
if !ok {
continue
}

View File

@ -1,7 +1,7 @@
package jsn
import (
"github.com/cespare/xxhash/v2"
"hash/maphash"
)
const (
@ -41,9 +41,12 @@ func Value(b []byte) []byte {
// Keys function fetches values for the provided keys
func Get(b []byte, keys [][]byte) []Field {
kmap := make(map[uint64]struct{}, len(keys))
h := maphash.Hash{}
for i := range keys {
kmap[xxhash.Sum64(keys[i])] = struct{}{}
_, _ = h.Write(keys[i])
kmap[h.Sum64()] = struct{}{}
h.Reset()
}
res := make([]Field, 0, 20)
@ -141,7 +144,9 @@ func Get(b []byte, keys [][]byte) []Field {
}
if e != 0 {
_, ok := kmap[xxhash.Sum64(k)]
_, _ = h.Write(k)
_, ok := kmap[h.Sum64()]
h.Reset()
if ok {
res = append(res, Field{k, b[s:(e + 1)]})

View File

@ -3,8 +3,7 @@ package jsn
import (
"bytes"
"errors"
"github.com/cespare/xxhash/v2"
"hash/maphash"
)
// Replace function replaces key-value pairs provided in the `from` argument with those in the `to` argument
@ -18,7 +17,7 @@ func Replace(w *bytes.Buffer, b []byte, from, to []Field) error {
return err
}
h := xxhash.New()
h := maphash.Hash{}
tmap := make(map[uint64]int, len(from))
for i, f := range from {