Compare commits

..

9 Commits

44 changed files with 1426 additions and 687 deletions

View File

@ -12,8 +12,7 @@ FROM golang:1.14-alpine as go-build
RUN apk update && \ RUN apk update && \
apk add --no-cache make && \ apk add --no-cache make && \
apk add --no-cache git && \ apk add --no-cache git && \
apk add --no-cache jq && \ apk add --no-cache jq
apk add --no-cache upx=3.95-r2
RUN GO111MODULE=off go get -u github.com/rafaelsq/wtc RUN GO111MODULE=off go get -u github.com/rafaelsq/wtc

View File

@ -49,6 +49,7 @@ import (
"crypto/sha256" "crypto/sha256"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"hash/maphash"
_log "log" _log "log"
"os" "os"
@ -83,10 +84,11 @@ type SuperGraph struct {
schema *psql.DBSchema schema *psql.DBSchema
allowList *allow.List allowList *allow.List
encKey [32]byte encKey [32]byte
prepared map[string]*preparedItem hashSeed maphash.Seed
queries map[uint64]query
roles map[string]*Role roles map[string]*Role
getRole *sql.Stmt getRole *sql.Stmt
rmap map[uint64]*resolvFn rmap map[uint64]resolvFn
abacEnabled bool abacEnabled bool
anonExists bool anonExists bool
qc *qcode.Compiler qc *qcode.Compiler
@ -107,10 +109,11 @@ func newSuperGraph(conf *Config, db *sql.DB, dbinfo *psql.DBInfo) (*SuperGraph,
} }
sg := &SuperGraph{ sg := &SuperGraph{
conf: conf, conf: conf,
db: db, db: db,
dbinfo: dbinfo, dbinfo: dbinfo,
log: _log.New(os.Stdout, "", 0), log: _log.New(os.Stdout, "", 0),
hashSeed: maphash.MakeSeed(),
} }
if err := sg.initConfig(); err != nil { if err := sg.initConfig(); err != nil {

View File

@ -12,7 +12,8 @@ import (
// to a prepared statement. // to a prepared statement.
func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) { func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
vars := make([]interface{}, len(md.Params)) params := md.Params()
vars := make([]interface{}, len(params))
var fields map[string]json.RawMessage var fields map[string]json.RawMessage
var err error var err error
@ -25,7 +26,7 @@ func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
} }
} }
for i, p := range md.Params { for i, p := range params {
switch p.Name { switch p.Name {
case "user_id": case "user_id":
if v := c.Value(UserIDKey); v != nil { if v := c.Value(UserIDKey); v != nil {

41
core/bench.11 Normal file
View File

@ -0,0 +1,41 @@
INF roles_query not defined: attribute based access control disabled
all expectations were already fulfilled, call to Query 'SELECT jsonb_build_object('users', "__sj_0"."json", 'products', "__sj_1"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_1"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "products_1"."id" AS "id", "products_1"."name" AS "name", "__sj_2"."json" AS "customers", "__sj_3"."json" AS "user" FROM (SELECT "products"."id", "products"."name", "products"."user_id" FROM "products" LIMIT ('20') :: integer) AS "products_1" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_3".*) AS "json"FROM (SELECT "users_3"."full_name" AS "full_name", "users_3"."phone" AS "phone", "users_3"."email" AS "email" FROM (SELECT "users"."full_name", "users"."phone", "users"."email" FROM "users" WHERE ((("users"."id") = ("products_1"."user_id"))) LIMIT ('1') :: integer) AS "users_3") AS "__sr_3") AS "__sj_3" ON ('true') LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_2"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "customers_2"."id" AS "id", "customers_2"."email" AS "email" FROM (SELECT "customers"."id", "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_1"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_2") AS "__sr_2") AS "__sj_2") AS "__sj_2" ON ('true')) AS "__sr_1") AS "__sj_1") AS "__sj_1", (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."id" AS "id", "users_0"."name" AS "name" FROM (SELECT "users"."id" FROM "users" GROUP BY "users"."id" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"' with args [] was not expected
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/core
BenchmarkGraphQL-16 INF roles_query not defined: attribute based access control disabled
all expectations were already fulfilled, call to Query 'SELECT jsonb_build_object('users', "__sj_0"."json", 'products', "__sj_1"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_1"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "products_1"."id" AS "id", "products_1"."name" AS "name", "__sj_2"."json" AS "customers", "__sj_3"."json" AS "user" FROM (SELECT "products"."id", "products"."name", "products"."user_id" FROM "products" LIMIT ('20') :: integer) AS "products_1" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_3".*) AS "json"FROM (SELECT "users_3"."full_name" AS "full_name", "users_3"."phone" AS "phone", "users_3"."email" AS "email" FROM (SELECT "users"."full_name", "users"."phone", "users"."email" FROM "users" WHERE ((("users"."id") = ("products_1"."user_id"))) LIMIT ('1') :: integer) AS "users_3") AS "__sr_3") AS "__sj_3" ON ('true') LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_2"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "customers_2"."id" AS "id", "customers_2"."email" AS "email" FROM (SELECT "customers"."id", "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_1"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_2") AS "__sr_2") AS "__sj_2") AS "__sj_2" ON ('true')) AS "__sr_1") AS "__sj_1") AS "__sj_1", (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."id" AS "id", "users_0"."name" AS "name" FROM (SELECT "users"."id" FROM "users" GROUP BY "users"."id" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"' with args [] was not expected
INF roles_query not defined: attribute based access control disabled
all expectations were already fulfilled, call to Query 'SELECT jsonb_build_object('users', "__sj_0"."json", 'products', "__sj_1"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_1"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "products_1"."id" AS "id", "products_1"."name" AS "name", "__sj_2"."json" AS "customers", "__sj_3"."json" AS "user" FROM (SELECT "products"."id", "products"."name", "products"."user_id" FROM "products" LIMIT ('20') :: integer) AS "products_1" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_3".*) AS "json"FROM (SELECT "users_3"."full_name" AS "full_name", "users_3"."phone" AS "phone", "users_3"."email" AS "email" FROM (SELECT "users"."full_name", "users"."phone", "users"."email" FROM "users" WHERE ((("users"."id") = ("products_1"."user_id"))) LIMIT ('1') :: integer) AS "users_3") AS "__sr_3") AS "__sj_3" ON ('true') LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_2"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "customers_2"."id" AS "id", "customers_2"."email" AS "email" FROM (SELECT "customers"."id", "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_1"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_2") AS "__sr_2") AS "__sj_2") AS "__sj_2" ON ('true')) AS "__sr_1") AS "__sj_1") AS "__sj_1", (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."id" AS "id", "users_0"."name" AS "name" FROM (SELECT "users"."id" FROM "users" GROUP BY "users"."id" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"' with args [] was not expected
INF roles_query not defined: attribute based access control disabled
all expectations were already fulfilled, call to Query 'SELECT jsonb_build_object('users', "__sj_0"."json", 'products', "__sj_1"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_1"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "products_1"."id" AS "id", "products_1"."name" AS "name", "__sj_2"."json" AS "customers", "__sj_3"."json" AS "user" FROM (SELECT "products"."id", "products"."name", "products"."user_id" FROM "products" LIMIT ('20') :: integer) AS "products_1" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_3".*) AS "json"FROM (SELECT "users_3"."full_name" AS "full_name", "users_3"."phone" AS "phone", "users_3"."email" AS "email" FROM (SELECT "users"."full_name", "users"."phone", "users"."email" FROM "users" WHERE ((("users"."id") = ("products_1"."user_id"))) LIMIT ('1') :: integer) AS "users_3") AS "__sr_3") AS "__sj_3" ON ('true') LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_2"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "customers_2"."id" AS "id", "customers_2"."email" AS "email" FROM (SELECT "customers"."id", "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_1"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_2") AS "__sr_2") AS "__sj_2") AS "__sj_2" ON ('true')) AS "__sr_1") AS "__sj_1") AS "__sj_1", (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."id" AS "id", "users_0"."name" AS "name" FROM (SELECT "users"."id" FROM "users" GROUP BY "users"."id" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"' with args [] was not expected
105048 10398 ns/op 18342 B/op 55 allocs/op
PASS
ok github.com/dosco/super-graph/core 1.328s
PASS
ok github.com/dosco/super-graph/core/internal/allow 0.088s
? github.com/dosco/super-graph/core/internal/crypto [no test files]
? github.com/dosco/super-graph/core/internal/integration_tests [no test files]
PASS
ok github.com/dosco/super-graph/core/internal/integration_tests/cockroachdb 0.121s
PASS
ok github.com/dosco/super-graph/core/internal/integration_tests/postgresql 0.118s
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/core/internal/psql
BenchmarkCompile-16 79845 14428 ns/op 4584 B/op 39 allocs/op
BenchmarkCompileParallel-16 326205 3918 ns/op 4633 B/op 39 allocs/op
PASS
ok github.com/dosco/super-graph/core/internal/psql 2.696s
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/core/internal/qcode
BenchmarkQCompile-16 146953 8049 ns/op 3756 B/op 28 allocs/op
BenchmarkQCompileP-16 475936 2447 ns/op 3790 B/op 28 allocs/op
BenchmarkParse-16 140811 8163 ns/op 3902 B/op 18 allocs/op
BenchmarkParseP-16 571345 2041 ns/op 3903 B/op 18 allocs/op
BenchmarkSchemaParse-16 230715 5012 ns/op 3968 B/op 57 allocs/op
BenchmarkSchemaParseP-16 802426 1565 ns/op 3968 B/op 57 allocs/op
PASS
ok github.com/dosco/super-graph/core/internal/qcode 8.427s
? github.com/dosco/super-graph/core/internal/util [no test files]

View File

@ -88,6 +88,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
stmts := make([]stmt, 0, len(sg.conf.Roles)) stmts := make([]stmt, 0, len(sg.conf.Roles))
w := &bytes.Buffer{} w := &bytes.Buffer{}
md := psql.Metadata{}
for i := 0; i < len(sg.conf.Roles); i++ { for i := 0; i < len(sg.conf.Roles); i++ {
role := &sg.conf.Roles[i] role := &sg.conf.Roles[i]
@ -105,16 +106,18 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
stmts = append(stmts, stmt{role: role, qc: qc}) stmts = append(stmts, stmt{role: role, qc: qc})
s := &stmts[len(stmts)-1] s := &stmts[len(stmts)-1]
s.md, err = sg.pc.Compile(w, qc, psql.Variables(vm)) md, err = sg.pc.CompileWithMetadata(w, qc, psql.Variables(vm), md)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.sql = w.String() s.sql = w.String()
s.md = md
w.Reset() w.Reset()
} }
sql, err := sg.renderUserQuery(stmts) sql, err := sg.renderUserQuery(md, stmts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -124,7 +127,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
} }
//nolint: errcheck //nolint: errcheck
func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) { func (sg *SuperGraph) renderUserQuery(md psql.Metadata, stmts []stmt) (string, error) {
w := &bytes.Buffer{} w := &bytes.Buffer{}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `) io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
@ -142,7 +145,7 @@ func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
} }
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`) io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, sg.conf.RolesQuery) md.RenderVar(w, sg.conf.RolesQuery)
io.WriteString(w, `) THEN `) io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`) io.WriteString(w, `(SELECT (CASE`)
@ -158,7 +161,7 @@ func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
} }
io.WriteString(w, ` ELSE 'user' END) FROM (`) io.WriteString(w, ` ELSE 'user' END) FROM (`)
io.WriteString(w, sg.conf.RolesQuery) md.RenderVar(w, sg.conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `) io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `) io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)

View File

@ -197,30 +197,26 @@ func (c *Config) AddRoleTable(role string, table string, conf interface{}) error
// ReadInConfig function reads in the config file for the environment specified in the GO_ENV // ReadInConfig function reads in the config file for the environment specified in the GO_ENV
// environment variable. This is the best way to create a new Super Graph config. // environment variable. This is the best way to create a new Super Graph config.
func ReadInConfig(configFile string) (*Config, error) { func ReadInConfig(configFile string) (*Config, error) {
cpath := path.Dir(configFile) cp := path.Dir(configFile)
cfile := path.Base(configFile) vi := newViper(cp, path.Base(configFile))
vi := newViper(cpath, cfile)
if err := vi.ReadInConfig(); err != nil { if err := vi.ReadInConfig(); err != nil {
return nil, err return nil, err
} }
inherits := vi.GetString("inherits") if pcf := vi.GetString("inherits"); pcf != "" {
cf := vi.ConfigFileUsed()
if inherits != "" { vi = newViper(cp, pcf)
vi = newViper(cpath, inherits)
if err := vi.ReadInConfig(); err != nil { if err := vi.ReadInConfig(); err != nil {
return nil, err return nil, err
} }
if vi.IsSet("inherits") { if v := vi.GetString("inherits"); v != "" {
return nil, fmt.Errorf("inherited config (%s) cannot itself inherit (%s)", return nil, fmt.Errorf("inherited config (%s) cannot itself inherit (%s)", pcf, v)
inherits,
vi.GetString("inherits"))
} }
vi.SetConfigName(cfile) vi.SetConfigFile(cf)
if err := vi.MergeInConfig(); err != nil { if err := vi.MergeInConfig(); err != nil {
return nil, err return nil, err
@ -234,7 +230,7 @@ func ReadInConfig(configFile string) (*Config, error) {
} }
if c.AllowListFile == "" { if c.AllowListFile == "" {
c.AllowListFile = path.Join(cpath, "allow.list") c.AllowListFile = path.Join(cp, "allow.list")
} }
return c, nil return c, nil
@ -248,7 +244,7 @@ func newViper(configPath, configFile string) *viper.Viper {
vi.AutomaticEnv() vi.AutomaticEnv()
if filepath.Ext(configFile) != "" { if filepath.Ext(configFile) != "" {
vi.SetConfigFile(configFile) vi.SetConfigFile(path.Join(configPath, configFile))
} else { } else {
vi.SetConfigName(configFile) vi.SetConfigName(configFile)
vi.AddConfigPath(configPath) vi.AddConfigPath(configPath)

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"hash/maphash"
"time" "time"
"github.com/dosco/super-graph/core/internal/psql" "github.com/dosco/super-graph/core/internal/psql"
@ -124,7 +125,7 @@ func (c *scontext) execQuery() ([]byte, error) {
return nil, err return nil, err
} }
if len(data) == 0 || st.md.Skipped == 0 { if len(data) == 0 || st.md.Skipped() == 0 {
return data, nil return data, nil
} }
@ -165,32 +166,43 @@ func (c *scontext) resolvePreparedSQL() ([]byte, *stmt, error) {
} else { } else {
role = c.role role = c.role
} }
c.res.role = role c.res.role = role
ps, ok := c.sg.prepared[stmtHash(c.res.name, role)] h := maphash.Hash{}
h.SetSeed(c.sg.hashSeed)
q, ok := c.sg.queries[queryID(&h, c.res.name, role)]
if !ok { if !ok {
return nil, nil, errNotFound return nil, nil, errNotFound
} }
c.res.sql = ps.st.sql
if q.sd == nil {
q.Do(func() { c.sg.prepare(&q, role) })
if q.err != nil {
return nil, nil, err
}
}
c.res.sql = q.st.sql
var root []byte var root []byte
var row *sql.Row var row *sql.Row
varsList, err := c.argList(ps.st.md) varsList, err := c.argList(q.st.md)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if useTx { if useTx {
row = tx.Stmt(ps.sd).QueryRow(varsList...) row = tx.Stmt(q.sd).QueryRow(varsList...)
} else { } else {
row = ps.sd.QueryRow(varsList...) row = q.sd.QueryRow(varsList...)
} }
if ps.roleArg { if q.roleArg {
err = row.Scan(&role, &root) err = row.Scan(&role, &root)
} else { } else {
err = row.Scan(&root) err = row.Scan(&root)
@ -204,15 +216,15 @@ func (c *scontext) resolvePreparedSQL() ([]byte, *stmt, error) {
if useTx { if useTx {
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return nil, nil, err return nil, nil, q.err
} }
} }
if root, err = c.sg.encryptCursor(ps.st.qc, root); err != nil { if root, err = c.sg.encryptCursor(q.st.qc, root); err != nil {
return nil, nil, err return nil, nil, err
} }
return root, &ps.st, nil return root, &q.st, nil
} }
func (c *scontext) resolveSQL() ([]byte, *stmt, error) { func (c *scontext) resolveSQL() ([]byte, *stmt, error) {

View File

@ -75,13 +75,22 @@ func (sg *SuperGraph) initConfig() error {
if c.RolesQuery == "" { if c.RolesQuery == "" {
sg.log.Printf("INF roles_query not defined: attribute based access control disabled") sg.log.Printf("INF roles_query not defined: attribute based access control disabled")
} else {
n := 0
for k, v := range sg.roles {
if k == "user" || k == "anon" {
n++
} else if v.Match != "" {
n++
}
}
sg.abacEnabled = (n > 2)
if !sg.abacEnabled {
sg.log.Printf("WRN attribute based access control disabled: no custom roles found (with 'match' defined)")
}
} }
_, userExists := sg.roles["user"]
_, sg.anonExists = sg.roles["anon"]
sg.abacEnabled = userExists && c.RolesQuery != ""
return nil return nil
} }

View File

@ -10,21 +10,23 @@ import (
"os" "os"
"sort" "sort"
"strings" "strings"
"text/scanner"
"github.com/chirino/graphql/schema" "github.com/chirino/graphql/schema"
"github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/jsn"
) )
const ( const (
AL_QUERY int = iota + 1 expComment = iota + 1
AL_VARS expVar
expQuery
) )
type Item struct { type Item struct {
Name string Name string
key string key string
Query string Query string
Vars json.RawMessage Vars string
Comment string Comment string
} }
@ -126,121 +128,101 @@ func (al *List) Set(vars []byte, query, comment string) error {
return errors.New("empty query") return errors.New("empty query")
} }
var q string
for i := 0; i < len(query); i++ {
c := query[i]
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
q = query
break
} else if c == '{' {
q = "query " + query
break
}
}
al.saveChan <- Item{ al.saveChan <- Item{
Comment: comment, Comment: comment,
Query: q, Query: query,
Vars: vars, Vars: string(vars),
} }
return nil return nil
} }
func (al *List) Load() ([]Item, error) { func (al *List) Load() ([]Item, error) {
var list []Item
varString := "variables"
b, err := ioutil.ReadFile(al.filepath) b, err := ioutil.ReadFile(al.filepath)
if err != nil { if err != nil {
return list, err return nil, err
} }
if len(b) == 0 { return parse(string(b), al.filepath)
return list, nil }
func parse(b string, filename string) ([]Item, error) {
var items []Item
var s scanner.Scanner
s.Init(strings.NewReader(b))
s.Filename = filename
s.Mode ^= scanner.SkipComments
var op, sp scanner.Position
var item Item
newComment := false
st := expComment
for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() {
txt := s.TokenText()
switch {
case strings.HasPrefix(txt, "/*"):
if st == expQuery {
v := b[sp.Offset:s.Pos().Offset]
item.Query = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1])
items = append(items, item)
}
item = Item{Comment: strings.TrimSpace(txt[2 : len(txt)-2])}
sp = s.Pos()
st = expComment
newComment = true
case !newComment && strings.HasPrefix(txt, "#"):
if st == expQuery {
v := b[sp.Offset:s.Pos().Offset]
item.Query = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1])
items = append(items, item)
}
item = Item{}
sp = s.Pos()
st = expComment
case strings.HasPrefix(txt, "variables"):
if st == expComment {
v := b[sp.Offset:s.Pos().Offset]
item.Comment = strings.TrimSpace(v[:strings.IndexByte(v, '\n')])
}
sp = s.Pos()
st = expVar
case isGraphQL(txt):
if st == expVar {
v := b[sp.Offset:s.Pos().Offset]
item.Vars = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1])
}
sp = op
st = expQuery
}
op = s.Pos()
} }
var comment bytes.Buffer if st == expQuery {
var varBytes []byte v := b[sp.Offset:s.Pos().Offset]
item.Query = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1])
itemMap := make(map[string]struct{}) items = append(items, item)
s, e, c := 0, 0, 0
ty := 0
for {
fq := false
if c == 0 && b[e] == '#' {
s = e
for e < len(b) && b[e] != '\n' {
e++
}
if (e - s) > 2 {
comment.Write(b[(s + 1):(e + 1)])
}
}
if e >= len(b) {
break
}
if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") {
if c == 0 {
s = e
}
ty = AL_QUERY
} else if matchPrefix(b, e, varString) {
if c == 0 {
s = e + len(varString) + 1
}
ty = AL_VARS
} else if b[e] == '{' {
c++
} else if b[e] == '}' {
c--
if c == 0 {
if ty == AL_QUERY {
fq = true
} else if ty == AL_VARS {
varBytes = b[s:(e + 1)]
}
ty = 0
}
}
if fq {
query := string(b[s:(e + 1)])
name := QueryName(query)
key := strings.ToLower(name)
if _, ok := itemMap[key]; !ok {
v := Item{
Name: name,
key: key,
Query: query,
Vars: varBytes,
Comment: comment.String(),
}
list = append(list, v)
comment.Reset()
}
varBytes = nil
}
e++
if e >= len(b) {
break
}
} }
return list, nil for i := range items {
items[i].Name = QueryName(items[i].Query)
items[i].key = strings.ToLower(items[i].Name)
}
return items, nil
}
func isGraphQL(s string) bool {
return strings.HasPrefix(s, "query") ||
strings.HasPrefix(s, "mutation") ||
strings.HasPrefix(s, "subscription")
} }
func (al *List) save(item Item) error { func (al *List) save(item Item) error {
@ -297,57 +279,39 @@ func (al *List) save(item Item) error {
return strings.Compare(list[i].key, list[j].key) == -1 return strings.Compare(list[i].key, list[j].key) == -1
}) })
for _, v := range list { for i, v := range list {
cmtLines := strings.Split(v.Comment, "\n") var vars string
if v.Vars != "" {
i := 0
for _, c := range cmtLines {
if c = strings.TrimSpace(c); c == "" {
continue
}
_, err := f.WriteString(fmt.Sprintf("# %s\n", c))
if err != nil {
return err
}
i++
}
if i != 0 {
if _, err := f.WriteString("\n"); err != nil {
return err
}
} else {
if _, err := f.WriteString(fmt.Sprintf("# Query named %s\n\n", v.Name)); err != nil {
return err
}
}
if len(v.Vars) != 0 && !bytes.Equal(v.Vars, []byte("{}")) {
buf.Reset() buf.Reset()
if err := jsn.Clear(&buf, []byte(v.Vars)); err != nil {
if err := jsn.Clear(&buf, v.Vars); err != nil { continue
return fmt.Errorf("failed to clean vars: %w", err)
} }
vj := json.RawMessage(buf.Bytes()) vj := json.RawMessage(buf.Bytes())
vj, err = json.MarshalIndent(vj, "", " ") if vj, err = json.MarshalIndent(vj, "", " "); err != nil {
if err != nil { continue
return fmt.Errorf("failed to marshal vars: %w", err)
} }
vars = string(vj)
}
list[i].Vars = vars
list[i].Comment = strings.TrimSpace(v.Comment)
}
_, err = f.WriteString(fmt.Sprintf("variables %s\n\n", vj)) for _, v := range list {
if v.Comment != "" {
f.WriteString(fmt.Sprintf("/* %s */\n\n", v.Comment))
} else {
f.WriteString(fmt.Sprintf("/* %s */\n\n", v.Name))
}
if v.Vars != "" {
_, err = f.WriteString(fmt.Sprintf("variables %s\n\n", v.Vars))
if err != nil { if err != nil {
return err return err
} }
} }
if v.Query[0] == '{' { _, err = f.WriteString(fmt.Sprintf("%s\n\n", v.Query))
_, err = f.WriteString(fmt.Sprintf("query %s\n\n", v.Query))
} else {
_, err = f.WriteString(fmt.Sprintf("%s\n\n", v.Query))
}
if err != nil { if err != nil {
return err return err
} }

View File

@ -82,3 +82,160 @@ func TestGQLName5(t *testing.T) {
t.Fatal("Name should be empty, not ", name) t.Fatal("Name should be empty, not ", name)
} }
} }
func TestParse1(t *testing.T) {
var al = `
# Hello world
variables {
"data": {
"slug": "",
"body": "",
"post": {
"connect": {
"slug": ""
}
}
}
}
mutation createComment {
comment(insert: $data) {
slug
body
createdAt: created_at
totalVotes: cached_votes_total
totalReplies: cached_replies_total
vote: comment_vote(where: {user_id: {eq: $user_id}}) {
created_at
__typename
}
author: user {
slug
firstName: first_name
lastName: last_name
pictureURL: picture_url
bio
__typename
}
__typename
}
}
# Query named createPost
query createPost {
post(insert: $data) {
slug
body
published
createdAt: created_at
totalVotes: cached_votes_total
totalComments: cached_comments_total
vote: post_vote(where: {user_id: {eq: $user_id}}) {
created_at
__typename
}
author: user {
slug
firstName: first_name
lastName: last_name
pictureURL: picture_url
bio
__typename
}
__typename
}
}`
_, err := parse(al, "allow.list")
if err != nil {
t.Fatal(err)
}
}
func TestParse2(t *testing.T) {
var al = `
/* Hello world */
variables {
"data": {
"slug": "",
"body": "",
"post": {
"connect": {
"slug": ""
}
}
}
}
mutation createComment {
comment(insert: $data) {
slug
body
createdAt: created_at
totalVotes: cached_votes_total
totalReplies: cached_replies_total
vote: comment_vote(where: {user_id: {eq: $user_id}}) {
created_at
__typename
}
author: user {
slug
firstName: first_name
lastName: last_name
pictureURL: picture_url
bio
__typename
}
__typename
}
}
/*
Query named createPost
*/
variables {
"data": {
"thread": {
"connect": {
"slug": ""
}
},
"slug": "",
"published": false,
"body": ""
}
}
query createPost {
post(insert: $data) {
slug
body
published
createdAt: created_at
totalVotes: cached_votes_total
totalComments: cached_comments_total
vote: post_vote(where: {user_id: {eq: $user_id}}) {
created_at
__typename
}
author: user {
slug
firstName: first_name
lastName: last_name
pictureURL: picture_url
bio
__typename
}
__typename
}
}`
_, err := parse(al, "allow.list")
if err != nil {
t.Fatal(err)
}
}

View File

@ -1,4 +1,3 @@
//nolint:errcheck
package psql package psql
import ( import (
@ -112,15 +111,15 @@ func (c *compilerContext) renderColumnSearchRank(sel *qcode.Select, ti *DBTableI
c.renderComma(columnsRendered) c.renderComma(columnsRendered)
//fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`, //fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name) //c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_rank(`) _, _ = io.WriteString(c.w, `ts_rank(`)
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn)
if c.schema.ver >= 110000 { if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery(`) _, _ = io.WriteString(c.w, `, websearch_to_tsquery(`)
} else { } else {
io.WriteString(c.w, `, to_tsquery(`) _, _ = io.WriteString(c.w, `, to_tsquery(`)
} }
c.renderValueExp(Param{Name: arg.Val, Type: "string"}) c.md.renderValueExp(c.w, Param{Name: arg.Val, Type: "string"})
io.WriteString(c.w, `))`) _, _ = io.WriteString(c.w, `))`)
alias(c.w, col.Name) alias(c.w, col.Name)
return nil return nil
@ -137,15 +136,15 @@ func (c *compilerContext) renderColumnSearchHeadline(sel *qcode.Select, ti *DBTa
c.renderComma(columnsRendered) c.renderComma(columnsRendered)
//fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`, //fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name) //c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_headline(`) _, _ = io.WriteString(c.w, `ts_headline(`)
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn)
if c.schema.ver >= 110000 { if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery(`) _, _ = io.WriteString(c.w, `, websearch_to_tsquery(`)
} else { } else {
io.WriteString(c.w, `, to_tsquery(`) _, _ = io.WriteString(c.w, `, to_tsquery(`)
} }
c.renderValueExp(Param{Name: arg.Val, Type: "string"}) c.md.renderValueExp(c.w, Param{Name: arg.Val, Type: "string"})
io.WriteString(c.w, `))`) _, _ = io.WriteString(c.w, `))`)
alias(c.w, col.Name) alias(c.w, col.Name)
return nil return nil
@ -157,9 +156,9 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
} }
c.renderComma(columnsRendered) c.renderComma(columnsRendered)
io.WriteString(c.w, `(`) _, _ = io.WriteString(c.w, `(`)
squoted(c.w, ti.Name) squoted(c.w, ti.Name)
io.WriteString(c.w, ` :: text)`) _, _ = io.WriteString(c.w, ` :: text)`)
alias(c.w, col.Name) alias(c.w, col.Name)
return nil return nil
@ -169,9 +168,9 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
pl := funcPrefixLen(c.schema.fm, col.Name) pl := funcPrefixLen(c.schema.fm, col.Name)
// if pl == 0 { // if pl == 0 {
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) // //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
// io.WriteString(c.w, `'`) // _, _ = io.WriteString(c.w, `'`)
// io.WriteString(c.w, col.Name) // _, _ = io.WriteString(c.w, col.Name)
// io.WriteString(c.w, ` not defined'`) // _, _ = io.WriteString(c.w, ` not defined'`)
// alias(c.w, col.Name) // alias(c.w, col.Name)
// } // }
@ -190,10 +189,10 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
c.renderComma(columnsRendered) c.renderComma(columnsRendered)
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Name, cn, col.Name) //fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Name, cn, col.Name)
io.WriteString(c.w, fn) _, _ = io.WriteString(c.w, fn)
io.WriteString(c.w, `(`) _, _ = io.WriteString(c.w, `(`)
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn)
io.WriteString(c.w, `)`) _, _ = io.WriteString(c.w, `)`)
alias(c.w, col.Name) alias(c.w, col.Name)
return nil return nil
@ -201,7 +200,7 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
func (c *compilerContext) renderComma(columnsRendered int) { func (c *compilerContext) renderComma(columnsRendered int) {
if columnsRendered != 0 { if columnsRendered != 0 {
io.WriteString(c.w, `, `) _, _ = io.WriteString(c.w, `, `)
} }
} }

View File

@ -25,7 +25,7 @@ func (c *compilerContext) renderInsert(
if insert[0] == '[' { if insert[0] == '[' {
io.WriteString(c.w, `json_array_elements(`) io.WriteString(c.w, `json_array_elements(`)
} }
c.renderValueExp(Param{Name: qc.ActionVar, Type: "json"}) c.md.renderValueExp(c.w, Param{Name: qc.ActionVar, Type: "json"})
io.WriteString(c.w, ` :: json`) io.WriteString(c.w, ` :: json`)
if insert[0] == '[' { if insert[0] == '[' {
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)

View File

@ -0,0 +1,61 @@
package psql
import (
"io"
)
func (md *Metadata) RenderVar(w io.Writer, vv string) {
f, s := -1, 0
for i := range vv {
v := vv[i]
switch {
case (i > 0 && vv[i-1] != '\\' && v == '$') || v == '$':
if (i - s) > 0 {
_, _ = io.WriteString(w, vv[s:i])
}
f = i
case (v < 'a' && v > 'z') &&
(v < 'A' && v > 'Z') &&
(v < '0' && v > '9') &&
v != '_' &&
f != -1 &&
(i-f) > 1:
md.renderValueExp(w, Param{Name: vv[f+1 : i]})
s = i
f = -1
}
}
if f != -1 && (len(vv)-f) > 1 {
md.renderValueExp(w, Param{Name: vv[f+1:]})
} else {
_, _ = io.WriteString(w, vv[s:])
}
}
// renderValueExp writes a positional parameter reference ("$N") for p.
// The first use of a parameter name claims the next free position and
// records it, so later uses of the same name reuse the same placeholder.
func (md *Metadata) renderValueExp(w io.Writer, p Param) {
	_, _ = io.WriteString(w, `$`)

	n, seen := md.pindex[p.Name]
	if !seen {
		// New parameter: append it and index its 1-based position.
		if md.pindex == nil {
			md.pindex = make(map[string]int)
		}
		md.params = append(md.params, p)
		n = len(md.params)
		md.pindex[p.Name] = n
	}
	int32String(w, int32(n))
}
// Skipped returns the bitmask of skipped selects: bit i is set when the
// compiler deferred the child select with id i (see the `skipped |=
// (1 << uint(id))` sites in the query compiler).
func (md Metadata) Skipped() uint32 {
	return md.skipped
}
// Params returns the positional parameters collected while compiling,
// ordered so that params[i] corresponds to placeholder "$i+1" in the
// generated SQL.
func (md Metadata) Params() []Param {
	return md.params
}

View File

@ -432,11 +432,11 @@ func (c *compilerContext) renderInsertUpdateColumns(
val := root.PresetMap[cn] val := root.PresetMap[cn]
switch { switch {
case ok && len(val) > 1 && val[0] == '$': case ok && len(val) > 1 && val[0] == '$':
c.renderValueExp(Param{Name: val[1:], Type: col.Type}) c.md.renderValueExp(c.w, Param{Name: val[1:], Type: col.Type})
case ok && strings.HasPrefix(val, "sql:"): case ok && strings.HasPrefix(val, "sql:"):
io.WriteString(c.w, `(`) io.WriteString(c.w, `(`)
c.renderVar(val[4:], c.renderValueExp) c.md.RenderVar(c.w, val[4:])
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
case ok: case ok:

View File

@ -25,8 +25,8 @@ type Param struct {
} }
type Metadata struct { type Metadata struct {
Skipped uint32 skipped uint32
Params []Param params []Param
pindex map[string]int pindex map[string]int
} }
@ -80,26 +80,30 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (Metadata, []byte
} }
func (co *Compiler) Compile(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) { func (co *Compiler) Compile(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) {
return co.CompileWithMetadata(w, qc, vars, Metadata{})
}
func (co *Compiler) CompileWithMetadata(w io.Writer, qc *qcode.QCode, vars Variables, md Metadata) (Metadata, error) {
md.skipped = 0
if qc == nil { if qc == nil {
return Metadata{}, fmt.Errorf("qcode is nil") return md, fmt.Errorf("qcode is nil")
} }
switch qc.Type { switch qc.Type {
case qcode.QTQuery: case qcode.QTQuery:
return co.compileQuery(w, qc, vars) return co.compileQueryWithMetadata(w, qc, vars, md)
case qcode.QTInsert, case qcode.QTInsert,
qcode.QTUpdate, qcode.QTUpdate,
qcode.QTDelete, qcode.QTDelete,
qcode.QTUpsert: qcode.QTUpsert:
return co.compileMutation(w, qc, vars) return co.compileMutation(w, qc, vars)
default:
return Metadata{}, fmt.Errorf("Unknown operation type %d", qc.Type)
} }
return Metadata{}, fmt.Errorf("Unknown operation type %d", qc.Type)
}
func (co *Compiler) compileQuery(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) {
return co.compileQueryWithMetadata(w, qc, vars, Metadata{})
} }
func (co *Compiler) compileQueryWithMetadata( func (co *Compiler) compileQueryWithMetadata(
@ -176,7 +180,7 @@ func (co *Compiler) compileQueryWithMetadata(
} }
for _, cid := range sel.Children { for _, cid := range sel.Children {
if hasBit(c.md.Skipped, uint32(cid)) { if hasBit(c.md.skipped, uint32(cid)) {
continue continue
} }
child := &c.s[cid] child := &c.s[cid]
@ -354,7 +358,7 @@ func (c *compilerContext) initSelect(sel *qcode.Select, ti *DBTableInfo, vars Va
if _, ok := colmap[rel.Left.Col]; !ok { if _, ok := colmap[rel.Left.Col]; !ok {
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Left.Col, FieldName: rel.Right.Col}) cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Left.Col, FieldName: rel.Right.Col})
colmap[rel.Left.Col] = struct{}{} colmap[rel.Left.Col] = struct{}{}
c.md.Skipped |= (1 << uint(id)) c.md.skipped |= (1 << uint(id))
} }
default: default:
@ -622,7 +626,7 @@ func (c *compilerContext) renderJoinColumns(sel *qcode.Select, ti *DBTableInfo,
i := colsRendered i := colsRendered
for _, id := range sel.Children { for _, id := range sel.Children {
if hasBit(c.md.Skipped, uint32(id)) { if hasBit(c.md.skipped, uint32(id)) {
continue continue
} }
childSel := &c.s[id] childSel := &c.s[id]
@ -804,7 +808,7 @@ func (c *compilerContext) renderCursorCTE(sel *qcode.Select) error {
quoted(c.w, ob.Col) quoted(c.w, ob.Col)
} }
io.WriteString(c.w, ` FROM string_to_array(`) io.WriteString(c.w, ` FROM string_to_array(`)
c.renderValueExp(Param{Name: "cursor", Type: "json"}) c.md.renderValueExp(c.w, Param{Name: "cursor", Type: "json"})
io.WriteString(c.w, `, ',') as a) `) io.WriteString(c.w, `, ',') as a) `)
return nil return nil
} }
@ -1102,7 +1106,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, ti *DBTableInfo) error {
} else { } else {
io.WriteString(c.w, `) @@ to_tsquery(`) io.WriteString(c.w, `) @@ to_tsquery(`)
} }
c.renderValueExp(Param{Name: ex.Val, Type: "string"}) c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: "string"})
io.WriteString(c.w, `))`) io.WriteString(c.w, `))`)
return nil return nil
@ -1191,7 +1195,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
switch { switch {
case ok && strings.HasPrefix(val, "sql:"): case ok && strings.HasPrefix(val, "sql:"):
io.WriteString(c.w, `(`) io.WriteString(c.w, `(`)
c.renderVar(val[4:], c.renderValueExp) c.md.RenderVar(c.w, val[4:])
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
case ok: case ok:
@ -1199,7 +1203,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
case ex.Op == qcode.OpIn || ex.Op == qcode.OpNotIn: case ex.Op == qcode.OpIn || ex.Op == qcode.OpNotIn:
io.WriteString(c.w, `(ARRAY(SELECT json_array_elements_text(`) io.WriteString(c.w, `(ARRAY(SELECT json_array_elements_text(`)
c.renderValueExp(Param{Name: ex.Val, Type: col.Type, IsArray: true}) c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: col.Type, IsArray: true})
io.WriteString(c.w, `))`) io.WriteString(c.w, `))`)
io.WriteString(c.w, ` :: `) io.WriteString(c.w, ` :: `)
@ -1208,7 +1212,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
return return
default: default:
c.renderValueExp(Param{Name: ex.Val, Type: col.Type, IsArray: false}) c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: col.Type, IsArray: false})
} }
case qcode.ValRef: case qcode.ValRef:
@ -1222,54 +1226,6 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
io.WriteString(c.w, col.Type) io.WriteString(c.w, col.Type)
} }
func (c *compilerContext) renderValueExp(p Param) {
io.WriteString(c.w, `$`)
if v, ok := c.md.pindex[p.Name]; ok {
int32String(c.w, int32(v))
} else {
c.md.Params = append(c.md.Params, p)
n := len(c.md.Params)
if c.md.pindex == nil {
c.md.pindex = make(map[string]int)
}
c.md.pindex[p.Name] = n
int32String(c.w, int32(n))
}
}
func (c *compilerContext) renderVar(vv string, fn func(Param)) {
f, s := -1, 0
for i := range vv {
v := vv[i]
switch {
case (i > 0 && vv[i-1] != '\\' && v == '$') || v == '$':
if (i - s) > 0 {
io.WriteString(c.w, vv[s:i])
}
f = i
case (v < 'a' && v > 'z') &&
(v < 'A' && v > 'Z') &&
(v < '0' && v > '9') &&
v != '_' &&
f != -1 &&
(i-f) > 1:
fn(Param{Name: vv[f+1 : i]})
s = i
f = -1
}
}
if f != -1 && (len(vv)-f) > 1 {
fn(Param{Name: vv[f+1:]})
} else {
io.WriteString(c.w, vv[s:])
}
}
func funcPrefixLen(fm map[string]*DBFunction, fn string) int { func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
switch { switch {
case strings.HasPrefix(fn, "avg_"): case strings.HasPrefix(fn, "avg_"):

View File

@ -307,6 +307,80 @@ func multiRoot(t *testing.T) {
compileGQLToPSQL(t, gql, nil, "user") compileGQLToPSQL(t, gql, nil, "user")
} }
// withFragment1 compiles a query whose fragment spreads reference one
// fragment declared before the query and one declared after it.
func withFragment1(t *testing.T) {
	gql := `
	fragment userFields1 on user {
		id
		email
	}

	query {
		users {
			...userFields2
			created_at
			...userFields1
		}
	}

	fragment userFields2 on user {
		first_name
		last_name
	}`

	compileGQLToPSQL(t, gql, nil, "anon")
}
// withFragment2 compiles a query where both referenced fragments are
// declared after the query that spreads them.
func withFragment2(t *testing.T) {
	gql := `
	query {
		users {
			...userFields2
			created_at
			...userFields1
		}
	}

	fragment userFields1 on user {
		id
		email
	}

	fragment userFields2 on user {
		first_name
		last_name
	}`

	compileGQLToPSQL(t, gql, nil, "anon")
}
// withFragment3 compiles a query where both referenced fragments are
// declared before the query that spreads them.
func withFragment3(t *testing.T) {
	gql := `
	fragment userFields1 on user {
		id
		email
	}

	fragment userFields2 on user {
		first_name
		last_name
	}

	query {
		users {
			...userFields2
			created_at
			...userFields1
		}
	}
	`

	compileGQLToPSQL(t, gql, nil, "anon")
}
func withCursor(t *testing.T) { func withCursor(t *testing.T) {
gql := `query { gql := `query {
Products( Products(
@ -400,6 +474,9 @@ func TestCompileQuery(t *testing.T) {
t.Run("queryWithVariables", queryWithVariables) t.Run("queryWithVariables", queryWithVariables)
t.Run("withWhereOnRelations", withWhereOnRelations) t.Run("withWhereOnRelations", withWhereOnRelations)
t.Run("multiRoot", multiRoot) t.Run("multiRoot", multiRoot)
t.Run("withFragment1", withFragment1)
t.Run("withFragment2", withFragment2)
t.Run("withFragment3", withFragment3)
t.Run("jsonColumnAsTable", jsonColumnAsTable) t.Run("jsonColumnAsTable", jsonColumnAsTable)
t.Run("withCursor", withCursor) t.Run("withCursor", withCursor)
t.Run("nullForAuthRequiredInAnon", nullForAuthRequiredInAnon) t.Run("nullForAuthRequiredInAnon", nullForAuthRequiredInAnon)

View File

@ -86,6 +86,12 @@ SELECT jsonb_build_object('product', "__sj_0"."json") as "__root" FROM (SELECT t
SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")) AND ((("products"."price") > '3' :: numeric(7,2))))) LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0" SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")) AND ((("products"."price") > '3' :: numeric(7,2))))) LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"
=== RUN TestCompileQuery/multiRoot === RUN TestCompileQuery/multiRoot
SELECT jsonb_build_object('customer', "__sj_0"."json", 'user', "__sj_1"."json", 'product', "__sj_2"."json") as "__root" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "products_2"."id" AS "id", "products_2"."name" AS "name", "__sj_3"."json" AS "customers", "__sj_4"."json" AS "customer" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE (((("products"."price") > '0' :: numeric(7,2)) AND (("products"."price") < '8' :: numeric(7,2)))) LIMIT ('1') :: integer) AS "products_2" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_4".*) AS "json"FROM (SELECT "customers_4"."email" AS "email" FROM (SELECT "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_2"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('1') :: integer) AS "customers_4") AS "__sr_4") AS "__sj_4" ON ('true') LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_3"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_3".*) AS "json"FROM (SELECT "customers_3"."email" AS "email" FROM (SELECT "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_2"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_3") AS "__sr_3") AS "__sj_3") AS "__sj_3" ON ('true')) AS "__sr_2") AS "__sj_2", (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "users_1"."id" AS "id", "users_1"."email" AS "email" FROM (SELECT "users"."id", "users"."email" FROM "users" LIMIT ('1') :: integer) AS "users_1") AS "__sr_1") AS "__sj_1", (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "customers_0"."id" AS "id" FROM (SELECT "customers"."id" FROM "customers" LIMIT ('1') :: integer) AS "customers_0") AS "__sr_0") AS "__sj_0" SELECT jsonb_build_object('customer', "__sj_0"."json", 'user', "__sj_1"."json", 'product', "__sj_2"."json") as "__root" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "products_2"."id" AS "id", 
"products_2"."name" AS "name", "__sj_3"."json" AS "customers", "__sj_4"."json" AS "customer" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE (((("products"."price") > '0' :: numeric(7,2)) AND (("products"."price") < '8' :: numeric(7,2)))) LIMIT ('1') :: integer) AS "products_2" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_4".*) AS "json"FROM (SELECT "customers_4"."email" AS "email" FROM (SELECT "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_2"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('1') :: integer) AS "customers_4") AS "__sr_4") AS "__sj_4" ON ('true') LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_3"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_3".*) AS "json"FROM (SELECT "customers_3"."email" AS "email" FROM (SELECT "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_2"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_3") AS "__sr_3") AS "__sj_3") AS "__sj_3" ON ('true')) AS "__sr_2") AS "__sj_2", (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "users_1"."id" AS "id", "users_1"."email" AS "email" FROM (SELECT "users"."id", "users"."email" FROM "users" LIMIT ('1') :: integer) AS "users_1") AS "__sr_1") AS "__sj_1", (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "customers_0"."id" AS "id" FROM (SELECT "customers"."id" FROM "customers" LIMIT ('1') :: integer) AS "customers_0") AS "__sr_0") AS "__sj_0"
=== RUN TestCompileQuery/withFragment1
SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."first_name" AS "first_name", "users_0"."last_name" AS "last_name", "users_0"."created_at" AS "created_at", "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT , "users"."created_at", "users"."id", "users"."email" FROM "users" GROUP BY "users"."created_at", "users"."id", "users"."email" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"
=== RUN TestCompileQuery/withFragment2
SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."first_name" AS "first_name", "users_0"."last_name" AS "last_name", "users_0"."created_at" AS "created_at", "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT , "users"."created_at", "users"."id", "users"."email" FROM "users" GROUP BY "users"."created_at", "users"."id", "users"."email" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"
=== RUN TestCompileQuery/withFragment3
SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."first_name" AS "first_name", "users_0"."last_name" AS "last_name", "users_0"."created_at" AS "created_at", "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT , "users"."created_at", "users"."id", "users"."email" FROM "users" GROUP BY "users"."created_at", "users"."id", "users"."email" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"
=== RUN TestCompileQuery/jsonColumnAsTable === RUN TestCompileQuery/jsonColumnAsTable
SELECT jsonb_build_object('products', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "__sj_1"."json" AS "tag_count" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "tag_count_1"."count" AS "count", "__sj_2"."json" AS "tags" FROM (SELECT "tag_count"."count", "tag_count"."tag_id" FROM "products", json_to_recordset("products"."tag_count") AS "tag_count"(tag_id bigint, count int) WHERE ((("products"."id") = ("products_0"."id"))) LIMIT ('1') :: integer) AS "tag_count_1" LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_2"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "tags_2"."name" AS "name" FROM (SELECT "tags"."name" FROM "tags" WHERE ((("tags"."id") = ("tag_count_1"."tag_id"))) LIMIT ('20') :: integer) AS "tags_2") AS "__sr_2") AS "__sj_2") AS "__sj_2" ON ('true')) AS "__sr_1") AS "__sj_1" ON ('true')) AS "__sr_0") AS "__sj_0") AS "__sj_0" SELECT jsonb_build_object('products', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "__sj_1"."json" AS "tag_count" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "tag_count_1"."count" AS "count", "__sj_2"."json" AS "tags" FROM (SELECT "tag_count"."count", "tag_count"."tag_id" FROM "products", json_to_recordset("products"."tag_count") AS "tag_count"(tag_id bigint, count int) WHERE ((("products"."id") = ("products_0"."id"))) LIMIT ('1') :: integer) AS "tag_count_1" LEFT OUTER JOIN LATERAL (SELECT 
coalesce(jsonb_agg("__sj_2"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "tags_2"."name" AS "name" FROM (SELECT "tags"."name" FROM "tags" WHERE ((("tags"."id") = ("tag_count_1"."tag_id"))) LIMIT ('20') :: integer) AS "tags_2") AS "__sr_2") AS "__sj_2") AS "__sj_2" ON ('true')) AS "__sr_1") AS "__sj_1" ON ('true')) AS "__sr_0") AS "__sj_0") AS "__sj_0"
=== RUN TestCompileQuery/withCursor === RUN TestCompileQuery/withCursor
@ -117,6 +123,9 @@ SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coa
--- PASS: TestCompileQuery/queryWithVariables (0.00s) --- PASS: TestCompileQuery/queryWithVariables (0.00s)
--- PASS: TestCompileQuery/withWhereOnRelations (0.00s) --- PASS: TestCompileQuery/withWhereOnRelations (0.00s)
--- PASS: TestCompileQuery/multiRoot (0.00s) --- PASS: TestCompileQuery/multiRoot (0.00s)
--- PASS: TestCompileQuery/withFragment1 (0.00s)
--- PASS: TestCompileQuery/withFragment2 (0.00s)
--- PASS: TestCompileQuery/withFragment3 (0.00s)
--- PASS: TestCompileQuery/jsonColumnAsTable (0.00s) --- PASS: TestCompileQuery/jsonColumnAsTable (0.00s)
--- PASS: TestCompileQuery/withCursor (0.00s) --- PASS: TestCompileQuery/withCursor (0.00s)
--- PASS: TestCompileQuery/nullForAuthRequiredInAnon (0.00s) --- PASS: TestCompileQuery/nullForAuthRequiredInAnon (0.00s)
@ -151,4 +160,4 @@ WITH "_sg_input" AS (SELECT $1 :: json AS j), "_x_users" AS (SELECT * FROM (VALU
--- PASS: TestCompileUpdate/nestedUpdateOneToOneWithConnect (0.00s) --- PASS: TestCompileUpdate/nestedUpdateOneToOneWithConnect (0.00s)
--- PASS: TestCompileUpdate/nestedUpdateOneToOneWithDisconnect (0.00s) --- PASS: TestCompileUpdate/nestedUpdateOneToOneWithDisconnect (0.00s)
PASS PASS
ok github.com/dosco/super-graph/core/internal/psql (cached) ok github.com/dosco/super-graph/core/internal/psql 0.374s

View File

@ -22,7 +22,7 @@ func (c *compilerContext) renderUpdate(
} }
io.WriteString(c.w, `WITH "_sg_input" AS (SELECT `) io.WriteString(c.w, `WITH "_sg_input" AS (SELECT `)
c.renderValueExp(Param{Name: qc.ActionVar, Type: "json"}) c.md.renderValueExp(c.w, Param{Name: qc.ActionVar, Type: "json"})
// io.WriteString(c.w, qc.ActionVar) // io.WriteString(c.w, qc.ActionVar)
io.WriteString(c.w, ` :: json AS j)`) io.WriteString(c.w, ` :: json AS j)`)

View File

@ -11,15 +11,18 @@ import (
var ( var (
queryToken = []byte("query") queryToken = []byte("query")
mutationToken = []byte("mutation") mutationToken = []byte("mutation")
fragmentToken = []byte("fragment")
subscriptionToken = []byte("subscription") subscriptionToken = []byte("subscription")
onToken = []byte("on")
trueToken = []byte("true") trueToken = []byte("true")
falseToken = []byte("false") falseToken = []byte("false")
quotesToken = []byte(`'"`) quotesToken = []byte(`'"`)
signsToken = []byte(`+-`) signsToken = []byte(`+-`)
punctuatorToken = []byte(`!():=[]{|}`)
spreadToken = []byte(`...`) spreadToken = []byte(`...`)
digitToken = []byte(`0123456789`) digitToken = []byte(`0123456789`)
dotToken = []byte(`.`) dotToken = []byte(`.`)
punctuatorToken = `!():=[]{|}`
) )
// Pos represents a byte position in the original input text from which // Pos represents a byte position in the original input text from which
@ -43,6 +46,8 @@ const (
itemName itemName
itemQuery itemQuery
itemMutation itemMutation
itemFragment
itemOn
itemSub itemSub
itemPunctuator itemPunctuator
itemArgsOpen itemArgsOpen
@ -263,11 +268,11 @@ func lexRoot(l *lexer) stateFn {
l.backup() l.backup()
return lexString return lexString
case r == '.': case r == '.':
if len(l.input) >= 3 { l.acceptRun(dotToken)
if equals(l.input, 0, 3, spreadToken) { s, e := l.current()
l.emit(itemSpread) if equals(l.input, s, e, spreadToken) {
return lexRoot l.emit(itemSpread)
} return lexRoot
} }
fallthrough // '.' can start a number. fallthrough // '.' can start a number.
case r == '+' || r == '-' || ('0' <= r && r <= '9'): case r == '+' || r == '-' || ('0' <= r && r <= '9'):
@ -299,10 +304,14 @@ func lexName(l *lexer) stateFn {
switch { switch {
case equals(l.input, s, e, queryToken): case equals(l.input, s, e, queryToken):
l.emitL(itemQuery) l.emitL(itemQuery)
case equals(l.input, s, e, fragmentToken):
l.emitL(itemFragment)
case equals(l.input, s, e, mutationToken): case equals(l.input, s, e, mutationToken):
l.emitL(itemMutation) l.emitL(itemMutation)
case equals(l.input, s, e, subscriptionToken): case equals(l.input, s, e, subscriptionToken):
l.emitL(itemSub) l.emitL(itemSub)
case equals(l.input, s, e, onToken):
l.emitL(itemOn)
case equals(l.input, s, e, trueToken): case equals(l.input, s, e, trueToken):
l.emitL(itemBoolVal) l.emitL(itemBoolVal)
case equals(l.input, s, e, falseToken): case equals(l.input, s, e, falseToken):
@ -396,31 +405,11 @@ func isAlphaNumeric(r rune) bool {
} }
func equals(b []byte, s Pos, e Pos, val []byte) bool { func equals(b []byte, s Pos, e Pos, val []byte) bool {
n := 0 return bytes.EqualFold(b[s:e], val)
for i := s; i < e; i++ {
if n >= len(val) {
return true
}
switch {
case b[i] >= 'A' && b[i] <= 'Z' && ('a'+(b[i]-'A')) != val[n]:
return false
case b[i] != val[n]:
return false
}
n++
}
return true
} }
func contains(b []byte, s Pos, e Pos, val []byte) bool { func contains(b []byte, s Pos, e Pos, chars string) bool {
for i := s; i < e; i++ { return bytes.ContainsAny(b[s:e], chars)
for n := 0; n < len(val); n++ {
if b[i] == val[n] {
return true
}
}
}
return false
} }
func lowercase(b []byte, s Pos, e Pos) { func lowercase(b []byte, s Pos, e Pos) {

View File

@ -1,12 +1,12 @@
package qcode package qcode
import ( import (
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"hash/maphash"
"sync" "sync"
"unsafe" "unsafe"
"github.com/dosco/super-graph/core/internal/util"
) )
var ( var (
@ -35,8 +35,7 @@ const (
NodeVar NodeVar
) )
type Operation struct { type SelectionSet struct {
Type parserType
Name string Name string
Args []Arg Args []Arg
argsA [10]Arg argsA [10]Arg
@ -44,12 +43,29 @@ type Operation struct {
fieldsA [10]Field fieldsA [10]Field
} }
type Operation struct {
Type parserType
SelectionSet
}
var zeroOperation = Operation{} var zeroOperation = Operation{}
func (o *Operation) Reset() { func (o *Operation) Reset() {
*o = zeroOperation *o = zeroOperation
} }
type Fragment struct {
Name string
On string
SelectionSet
}
var zeroFragment = Fragment{}
func (f *Fragment) Reset() {
*f = zeroFragment
}
type Field struct { type Field struct {
ID int32 ID int32
ParentID int32 ParentID int32
@ -82,6 +98,8 @@ func (n *Node) Reset() {
} }
type Parser struct { type Parser struct {
frags map[uint64]*Fragment
h maphash.Hash
input []byte // the string being scanned input []byte // the string being scanned
pos int pos int
items []item items []item
@ -96,12 +114,194 @@ var opPool = sync.Pool{
New: func() interface{} { return new(Operation) }, New: func() interface{} { return new(Operation) },
} }
var fragPool = sync.Pool{
New: func() interface{} { return new(Fragment) },
}
var lexPool = sync.Pool{ var lexPool = sync.Pool{
New: func() interface{} { return new(lexer) }, New: func() interface{} { return new(lexer) },
} }
func Parse(gql []byte) (*Operation, error) { func Parse(gql []byte) (*Operation, error) {
return parseSelectionSet(gql) var err error
if len(gql) == 0 {
return nil, errors.New("blank query")
}
l := lexPool.Get().(*lexer)
l.Reset()
defer lexPool.Put(l)
if err = lex(l, gql); err != nil {
return nil, err
}
p := &Parser{
input: l.input,
pos: -1,
items: l.items,
}
op := opPool.Get().(*Operation)
op.Reset()
op.Fields = op.fieldsA[:0]
s := -1
qf := false
for {
if p.peek(itemEOF) {
p.ignore()
break
}
if p.peek(itemFragment) {
p.ignore()
if err = p.parseFragment(op); err != nil {
return nil, err
}
} else {
if !qf && p.peek(itemQuery, itemMutation, itemSub, itemObjOpen) {
s = p.pos
qf = true
}
p.ignore()
}
}
p.reset(s)
if err := p.parseOp(op); err != nil {
return nil, err
}
return op, nil
}
// parseFragment parses one "fragment <name> on <table> { ... }"
// definition and stores it in p.frags, keyed by the maphash of its name,
// so later fragment spreads can resolve it.
// NOTE(review): op is currently unused; kept for signature stability.
func (p *Parser) parseFragment(op *Operation) error {
	frag := fragPool.Get().(*Fragment)
	frag.Reset()
	frag.Fields = frag.fieldsA[:0]
	frag.Args = frag.argsA[:0]

	if p.peek(itemName) {
		frag.Name = p.val(p.next())
	}

	if !p.peek(itemOn) {
		return errors.New("fragment: missing 'on' keyword")
	}
	p.ignore()

	if !p.peek(itemName) {
		return errors.New("fragment: missing table name after 'on' keyword")
	}
	frag.On = p.vall(p.next())

	if !p.peek(itemObjOpen) {
		return fmt.Errorf("fragment: expecting a '{', got: %s", p.next())
	}
	p.ignore()

	if err := p.parseSelectionSet(&frag.SelectionSet); err != nil {
		return fmt.Errorf("fragment: %v", err)
	}

	// Register the fragment under the hash of its name.
	if p.frags == nil {
		p.frags = make(map[uint64]*Fragment)
	}
	_, _ = p.h.WriteString(frag.Name)
	key := p.h.Sum64()
	p.h.Reset()
	p.frags[key] = frag

	return nil
}
// parseOp parses a single operation — query, mutation or subscription,
// optionally named and with variable definitions — into op. A bare
// "{ ... }" selection with no keyword is treated as a query.
func (p *Parser) parseOp(op *Operation) error {
	typeSet := false

	if p.peek(itemQuery, itemMutation, itemSub) {
		if err := p.parseOpTypeAndArgs(op); err != nil {
			return fmt.Errorf("%s: %v", op.Type, err)
		}
		typeSet = true
	}

	if !p.peek(itemObjOpen) {
		return fmt.Errorf("expecting a query, mutation or subscription, got: %s", p.next())
	}
	p.ignore()

	if !typeSet {
		// Shorthand form: default to a query.
		op.Type = opQuery
	}

	// Consume selection sets until the end of input or the start of a
	// trailing fragment definition.
	for {
		if p.peek(itemEOF, itemFragment) {
			p.ignore()
			return nil
		}
		if err := p.parseSelectionSet(&op.SelectionSet); err != nil {
			return fmt.Errorf("%s: %v", op.Type, err)
		}
	}
}
// parseOpTypeAndArgs consumes the operation keyword (query / mutation /
// subscription), sets op.Type accordingly, and parses the optional
// operation name and "(...)" variable-definition list into op.
func (p *Parser) parseOpTypeAndArgs(op *Operation) error {
	switch p.next()._type {
	case itemQuery:
		op.Type = opQuery
	case itemMutation:
		op.Type = opMutate
	case itemSub:
		op.Type = opSub
	}

	op.Args = op.argsA[:0]

	if p.peek(itemName) {
		op.Name = p.val(p.next())
	}

	if p.peek(itemArgsOpen) {
		p.ignore()

		var err error
		if op.Args, err = p.parseOpParams(op.Args); err != nil {
			return err
		}
	}

	return nil
}
func (p *Parser) parseSelectionSet(selset *SelectionSet) error {
var err error
selset.Fields, err = p.parseFields(selset.Fields)
if err != nil {
return err
}
return nil
} }
func ParseArgValue(argVal string) (*Node, error) { func ParseArgValue(argVal string) (*Node, error) {
@ -123,216 +323,111 @@ func ParseArgValue(argVal string) (*Node, error) {
return op, err return op, err
} }
func parseSelectionSet(gql []byte) (*Operation, error) {
var err error
if len(gql) == 0 {
return nil, errors.New("blank query")
}
l := lexPool.Get().(*lexer)
l.Reset()
if err = lex(l, gql); err != nil {
return nil, err
}
p := &Parser{
input: l.input,
pos: -1,
items: l.items,
}
var op *Operation
if p.peek(itemObjOpen) {
p.ignore()
op, err = p.parseQueryOp()
} else {
op, err = p.parseOp()
}
if err != nil {
return nil, err
}
if p.peek(itemObjClose) {
p.ignore()
} else {
return nil, fmt.Errorf("operation missing closing '}'")
}
if !p.peek(itemEOF) {
p.ignore()
return nil, fmt.Errorf("invalid '%s' found after closing '}'", p.current())
}
lexPool.Put(l)
return op, err
}
func (p *Parser) next() item {
n := p.pos + 1
if n >= len(p.items) {
p.err = errEOT
return item{_type: itemEOF}
}
p.pos = n
return p.items[p.pos]
}
func (p *Parser) ignore() {
n := p.pos + 1
if n >= len(p.items) {
p.err = errEOT
return
}
p.pos = n
}
func (p *Parser) current() string {
item := p.items[p.pos]
return b2s(p.input[item.pos:item.end])
}
func (p *Parser) peek(types ...itemType) bool {
n := p.pos + 1
// if p.items[n]._type == itemEOF {
// return false
// }
if n >= len(p.items) {
return false
}
for i := 0; i < len(types); i++ {
if p.items[n]._type == types[i] {
return true
}
}
return false
}
func (p *Parser) parseOp() (*Operation, error) {
if !p.peek(itemQuery, itemMutation, itemSub) {
err := errors.New("expecting a query, mutation or subscription")
return nil, err
}
item := p.next()
op := opPool.Get().(*Operation)
op.Reset()
switch item._type {
case itemQuery:
op.Type = opQuery
case itemMutation:
op.Type = opMutate
case itemSub:
op.Type = opSub
}
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
var err error
if p.peek(itemName) {
op.Name = p.val(p.next())
}
if p.peek(itemArgsOpen) {
p.ignore()
op.Args, err = p.parseOpParams(op.Args)
if err != nil {
return nil, err
}
}
if p.peek(itemObjOpen) {
p.ignore()
for n := 0; n < 10; n++ {
if !p.peek(itemName) {
break
}
op.Fields, err = p.parseFields(op.Fields)
if err != nil {
return nil, err
}
}
}
return op, nil
}
func (p *Parser) parseQueryOp() (*Operation, error) {
op := opPool.Get().(*Operation)
op.Reset()
op.Type = opQuery
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
var err error
for n := 0; n < 10; n++ {
if !p.peek(itemName) {
break
}
op.Fields, err = p.parseFields(op.Fields)
if err != nil {
return nil, err
}
}
return op, nil
}
func (p *Parser) parseFields(fields []Field) ([]Field, error) { func (p *Parser) parseFields(fields []Field) ([]Field, error) {
st := util.NewStack() st := NewStack()
if !p.peek(itemName, itemSpread) {
return nil, fmt.Errorf("unexpected token: %s", p.peekNext())
}
for { for {
if p.peek(itemEOF) {
p.ignore()
return nil, errors.New("invalid query")
}
if p.peek(itemObjClose) {
p.ignore()
if st.Len() != 0 {
st.Pop()
continue
} else {
break
}
}
if len(fields) >= maxFields { if len(fields) >= maxFields {
return nil, fmt.Errorf("too many fields (max %d)", maxFields) return nil, fmt.Errorf("too many fields (max %d)", maxFields)
} }
if p.peek(itemEOF, itemObjClose) { isFrag := false
p.ignore()
st.Pop()
if st.Len() == 0 { if p.peek(itemSpread) {
break p.ignore()
} else { isFrag = true
continue
}
} }
if !p.peek(itemName) { if !p.peek(itemName) {
return nil, errors.New("expecting an alias or field name") if isFrag {
return nil, fmt.Errorf("expecting a fragment name, got: %s", p.next())
} else {
return nil, fmt.Errorf("expecting an alias or field name, got: %s", p.next())
}
} }
fields = append(fields, Field{ID: int32(len(fields))}) var f *Field
f := &fields[(len(fields) - 1)] if isFrag {
f.Args = f.argsA[:0] name := p.val(p.next())
f.Children = f.childrenA[:0] p.h.WriteString(name)
k := p.h.Sum64()
p.h.Reset()
// Parse the inside of the the fields () parentheses fr, ok := p.frags[k]
// in short parse the args like id, where, etc if !ok {
if err := p.parseField(f); err != nil { return nil, fmt.Errorf("no fragment named '%s' defined", name)
return nil, err }
}
n := int32(len(fields))
fields = append(fields, fr.Fields...)
for i := 0; i < len(fr.Fields); i++ {
k := (n + int32(i))
f := &fields[k]
f.ID = int32(k)
// If this is the top-level point the parent to the parent of the
// previous field.
if f.ParentID == -1 {
pid := st.Peek()
f.ParentID = pid
if f.ParentID != -1 {
fields[pid].Children = append(fields[f.ParentID].Children, f.ID)
}
// Update all the other parents id's by our new place in this new array
} else {
f.ParentID += n
}
f.Children = make([]int32, len(f.Children))
copy(f.Children, fr.Fields[i].Children)
// Update all the children which is needed.
for j := range f.Children {
f.Children[j] += n
}
}
intf := st.Peek()
if pid, ok := intf.(int32); ok {
f.ParentID = pid
fields[pid].Children = append(fields[pid].Children, f.ID)
} else { } else {
f.ParentID = -1 fields = append(fields, Field{ID: int32(len(fields))})
f = &fields[(len(fields) - 1)]
f.Args = f.argsA[:0]
f.Children = f.childrenA[:0]
// Parse the field
if err := p.parseField(f); err != nil {
return nil, err
}
if st.Len() == 0 {
f.ParentID = -1
} else {
pid := st.Peek()
f.ParentID = pid
fields[pid].Children = append(fields[pid].Children, f.ID)
}
} }
// The first opening curley brackets after this // The first opening curley brackets after this
@ -340,13 +435,6 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) {
if p.peek(itemObjOpen) { if p.peek(itemObjOpen) {
p.ignore() p.ignore()
st.Push(f.ID) st.Push(f.ID)
} else if p.peek(itemObjClose) {
if st.Len() == 0 {
break
} else {
continue
}
} }
} }
@ -546,6 +634,72 @@ func (p *Parser) vall(v item) string {
return b2s(p.input[v.pos:v.end]) return b2s(p.input[v.pos:v.end])
} }
// peek reports whether the item after the current position matches any
// of the given token types, without consuming it. Past the end of the
// item list it matches only itemEOF.
func (p *Parser) peek(types ...itemType) bool {
	n := p.pos + 1
	if n >= len(p.items) {
		return types[0] == itemEOF
	}
	next := p.items[n]._type
	for _, t := range types {
		if next == t {
			return true
		}
	}
	return false
}
// next consumes and returns the next item. At the end of the input it
// sets p.err to errEOT and returns an itemEOF item without advancing.
func (p *Parser) next() item {
	if p.pos+1 >= len(p.items) {
		p.err = errEOT
		return item{_type: itemEOF}
	}
	p.pos++
	return p.items[p.pos]
}
// ignore consumes and discards the next item. At the end of the input
// it sets p.err to errEOT and leaves the position unchanged.
func (p *Parser) ignore() {
	if p.pos+1 >= len(p.items) {
		p.err = errEOT
		return
	}
	p.pos++
}
// peekCurrent returns the raw input text of the item at the current
// position without consuming anything.
func (p *Parser) peekCurrent() string {
	// Named "it" rather than "item" to avoid shadowing the item type.
	it := p.items[p.pos]
	return b2s(p.input[it.pos:it.end])
}
// peekNext returns the raw input text of the item just after the
// current position without consuming anything.
func (p *Parser) peekNext() string {
	it := p.items[p.pos+1]
	return b2s(p.input[it.pos:it.end])
}
// reset rewinds the parser position to a previously saved index,
// typically one captured before a speculative parse.
func (p *Parser) reset(to int) {
	p.pos = to
}
// fHash returns a hash key identifying a field by its name combined
// with the ID of its parent field. The parser's shared hash state is
// reset before returning.
func (p *Parser) fHash(name string, parentID int32) uint64 {
	// PutUint32 requires a destination of at least 4 bytes; the
	// previous `var b []byte` was nil and made PutUint32 panic on
	// every call.
	var b [4]byte
	binary.LittleEndian.PutUint32(b[:], uint32(parentID))
	p.h.WriteString(name)
	p.h.Write(b[:])
	v := p.h.Sum64()
	p.h.Reset()
	return v
}
func b2s(b []byte) string { func b2s(b []byte) string {
return *(*string)(unsafe.Pointer(&b)) return *(*string)(unsafe.Pointer(&b))
} }
@ -579,7 +733,7 @@ func (t parserType) String() string {
case NodeList: case NodeList:
v = "node-list" v = "node-list"
} }
return fmt.Sprintf("<%s>", v) return v
} }
// type Frees struct { // type Frees struct {

View File

@ -2,8 +2,9 @@ package qcode
import ( import (
"errors" "errors"
"github.com/chirino/graphql/schema"
"testing" "testing"
"github.com/chirino/graphql/schema"
) )
func TestCompile1(t *testing.T) { func TestCompile1(t *testing.T) {
@ -120,7 +121,7 @@ updateThread {
} }
} }
} }
}` }}`
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(gql), "anon") _, err := qcompile.Compile([]byte(gql), "anon")
@ -130,6 +131,93 @@ updateThread {
} }
// TestFragmentsCompile1 checks that a query compiles when one fragment
// is declared before the query that spreads it and another after it.
func TestFragmentsCompile1(t *testing.T) {
	gql := `
fragment userFields1 on user {
id
email
}
query {
users {
...userFields2
created_at
...userFields1
}
}
fragment userFields2 on user {
first_name
last_name
}
`
	qc, _ := NewCompiler(Config{})
	if _, err := qc.Compile([]byte(gql), "user"); err != nil {
		t.Fatal(err)
	}
}
// TestFragmentsCompile2 checks that a query compiles when every
// fragment it spreads is declared after the query itself.
func TestFragmentsCompile2(t *testing.T) {
	gql := `
query {
users {
...userFields2
created_at
...userFields1
}
}
fragment userFields1 on user {
id
email
}
fragment userFields2 on user {
first_name
last_name
}`
	qc, _ := NewCompiler(Config{})
	if _, err := qc.Compile([]byte(gql), "user"); err != nil {
		t.Fatal(err)
	}
}
// TestFragmentsCompile3 checks that a query compiles when every
// fragment it spreads is declared before the query itself.
func TestFragmentsCompile3(t *testing.T) {
	gql := `
fragment userFields1 on user {
id
email
}
fragment userFields2 on user {
first_name
last_name
}
query {
users {
...userFields2
created_at
...userFields1
}
}
`
	qc, _ := NewCompiler(Config{})
	if _, err := qc.Compile([]byte(gql), "user"); err != nil {
		t.Fatal(err)
	}
}
var gql = []byte(` var gql = []byte(`
{products( {products(
# returns only 30 items # returns only 30 items
@ -151,6 +239,29 @@ var gql = []byte(`
price price
}}`) }}`)
// gqlWithFragments is a fixture query that spreads two fragments (one
// declared before the query, one after); it is used by the fragment
// parse and compile benchmarks.
var gqlWithFragments = []byte(`
fragment userFields1 on user {
id
email
__typename
}
query {
users {
...userFields2
created_at
...userFields1
__typename
}
}
fragment userFields2 on user {
first_name
last_name
__typename
}`)
func BenchmarkQCompile(b *testing.B) { func BenchmarkQCompile(b *testing.B) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
@ -183,8 +294,22 @@ func BenchmarkQCompileP(b *testing.B) {
}) })
} }
func BenchmarkParse(b *testing.B) { func BenchmarkQCompileFragment(b *testing.B) {
qcompile, _ := NewCompiler(Config{})
b.ResetTimer()
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_, err := qcompile.Compile(gqlWithFragments, "user")
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkParse(b *testing.B) {
b.ResetTimer() b.ResetTimer()
b.ReportAllocs() b.ReportAllocs()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
@ -211,6 +336,18 @@ func BenchmarkParseP(b *testing.B) {
}) })
} }
// BenchmarkParseFragment measures parsing of a query that uses
// fragment spreads.
func BenchmarkParseFragment(b *testing.B) {
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := Parse(gqlWithFragments); err != nil {
			b.Fatal(err)
		}
	}
}
func BenchmarkSchemaParse(b *testing.B) { func BenchmarkSchemaParse(b *testing.B) {
b.ResetTimer() b.ResetTimer()

View File

@ -419,6 +419,7 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
com.AddFilters(qc, s, role) com.AddFilters(qc, s, role)
s.Cols = make([]Column, 0, len(field.Children)) s.Cols = make([]Column, 0, len(field.Children))
cm := make(map[string]struct{})
action = QTQuery action = QTQuery
for _, cid := range field.Children { for _, cid := range field.Children {
@ -428,19 +429,27 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
continue continue
} }
var fname string
if f.Alias != "" {
fname = f.Alias
} else {
fname = f.Name
}
if _, ok := cm[fname]; ok {
continue
} else {
cm[fname] = struct{}{}
}
if len(f.Children) != 0 { if len(f.Children) != 0 {
val := f.ID | (s.ID << 16) val := f.ID | (s.ID << 16)
st.Push(val) st.Push(val)
continue continue
} }
col := Column{Name: f.Name} col := Column{Name: f.Name, FieldName: fname}
if len(f.Alias) != 0 {
col.FieldName = f.Alias
} else {
col.FieldName = f.Name
}
s.Cols = append(s.Cols, col) s.Cols = append(s.Cols, col)
} }

View File

@ -2,120 +2,95 @@ package core
import ( import (
"bytes" "bytes"
"context"
"crypto/sha256"
"database/sql" "database/sql"
"encoding/hex"
"fmt" "fmt"
"hash/maphash"
"io" "io"
"strings" "strings"
"sync"
"github.com/dosco/super-graph/core/internal/allow" "github.com/dosco/super-graph/core/internal/allow"
"github.com/dosco/super-graph/core/internal/qcode" "github.com/dosco/super-graph/core/internal/qcode"
) )
type preparedItem struct { type query struct {
sync.Once
sd *sql.Stmt sd *sql.Stmt
ai allow.Item
qt qcode.QType
err error
st stmt st stmt
roleArg bool roleArg bool
} }
func (sg *SuperGraph) initPrepared() error { func (sg *SuperGraph) prepare(q *query, role string) {
ct := context.Background() var stmts []stmt
var err error
qb := []byte(q.ai.Query)
vars := []byte(q.ai.Vars)
switch q.qt {
case qcode.QTQuery:
if sg.abacEnabled {
stmts, err = sg.buildMultiStmt(qb, vars)
} else {
stmts, err = sg.buildRoleStmt(qb, vars, role)
}
case qcode.QTMutation:
stmts, err = sg.buildRoleStmt(qb, vars, role)
}
if err != nil {
sg.log.Printf("WRN %s %s: %v", q.qt, q.ai.Name, err)
}
q.st = stmts[0]
q.roleArg = len(stmts) > 1
q.sd, err = sg.db.Prepare(q.st.sql)
if err != nil {
q.err = fmt.Errorf("prepare failed: %v: %s", err, q.st.sql)
}
}
func (sg *SuperGraph) initPrepared() error {
if sg.allowList.IsPersist() { if sg.allowList.IsPersist() {
return nil return nil
} }
sg.prepared = make(map[string]*preparedItem)
tx, err := sg.db.BeginTx(ct, nil) if err := sg.prepareRoleStmt(); err != nil {
if err != nil { return fmt.Errorf("role query: %w", err)
return err
}
defer tx.Rollback() //nolint: errcheck
if err = sg.prepareRoleStmt(tx); err != nil {
return fmt.Errorf("prepareRoleStmt: %w", err)
} }
if err := tx.Commit(); err != nil { sg.queries = make(map[uint64]query)
return err
}
success := 0
list, err := sg.allowList.Load() list, err := sg.allowList.Load()
if err != nil { if err != nil {
return err return err
} }
h := maphash.Hash{}
h.SetSeed(sg.hashSeed)
for _, v := range list { for _, v := range list {
if len(v.Query) == 0 { if len(v.Query) == 0 {
continue continue
} }
qt := qcode.GetQType(v.Query)
err := sg.prepareStmt(v) switch qt {
if err != nil { case qcode.QTQuery:
return err sg.queries[queryID(&h, v.Name, "user")] = query{ai: v, qt: qt}
} else {
success++
}
}
sg.log.Printf("INF allow list: prepared %d / %d queries", success, len(list)) if sg.anonExists {
sg.queries[queryID(&h, v.Name, "anon")] = query{ai: v, qt: qt}
return nil
}
func (sg *SuperGraph) prepareStmt(item allow.Item) error {
query := item.Query
qb := []byte(query)
vars := item.Vars
qt := qcode.GetQType(query)
ct := context.Background()
switch qt {
case qcode.QTQuery:
var stmts1 []stmt
var err error
if sg.abacEnabled {
stmts1, err = sg.buildMultiStmt(qb, vars)
} else {
stmts1, err = sg.buildRoleStmt(qb, vars, "user")
}
if err == nil {
if err = sg.prepare(ct, stmts1, stmtHash(item.Name, "user")); err != nil {
return err
} }
} else {
sg.log.Printf("WRN query %s: %v", item.Name, err)
}
if sg.anonExists { case qcode.QTMutation:
stmts2, err := sg.buildRoleStmt(qb, vars, "anon") for _, role := range sg.conf.Roles {
sg.queries[queryID(&h, v.Name, role.Name)] = query{ai: v, qt: qt}
if err == nil {
if err = sg.prepare(ct, stmts2, stmtHash(item.Name, "anon")); err != nil {
return err
}
} else {
sg.log.Printf("WRN query %s: %v", item.Name, err)
}
}
case qcode.QTMutation:
for _, role := range sg.conf.Roles {
stmts, err := sg.buildRoleStmt(qb, vars, role.Name)
if err == nil {
if err = sg.prepare(ct, stmts, stmtHash(item.Name, role.Name)); err != nil {
return err
}
} else {
sg.log.Printf("WRN mutation %s: %v", item.Name, err)
} }
} }
} }
@ -123,22 +98,8 @@ func (sg *SuperGraph) prepareStmt(item allow.Item) error {
return nil return nil
} }
func (sg *SuperGraph) prepare(ct context.Context, st []stmt, key string) error {
sd, err := sg.db.PrepareContext(ct, st[0].sql)
if err != nil {
return fmt.Errorf("prepare failed: %v: %s", err, st[0].sql)
}
sg.prepared[key] = &preparedItem{
sd: sd,
st: st[0],
roleArg: len(st) > 1,
}
return nil
}
// nolint: errcheck // nolint: errcheck
func (sg *SuperGraph) prepareRoleStmt(tx *sql.Tx) error { func (sg *SuperGraph) prepareRoleStmt() error {
var err error var err error
if !sg.abacEnabled { if !sg.abacEnabled {
@ -165,11 +126,11 @@ func (sg *SuperGraph) prepareRoleStmt(tx *sql.Tx) error {
} }
io.WriteString(w, ` ELSE $2 END) FROM (`) io.WriteString(w, ` ELSE $2 END) FROM (`)
io.WriteString(w, sg.conf.RolesQuery) io.WriteString(w, rq)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `) io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `) io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `)
sg.getRole, err = tx.Prepare(w.String()) sg.getRole, err = sg.db.Prepare(w.String())
if err != nil { if err != nil {
return err return err
} }
@ -200,9 +161,11 @@ func (sg *SuperGraph) initAllowList() error {
} }
// nolint: errcheck // nolint: errcheck
func stmtHash(name string, role string) string { func queryID(h *maphash.Hash, name string, role string) uint64 {
h := sha256.New() h.WriteString(name)
io.WriteString(h, strings.ToLower(name)) h.WriteString(role)
io.WriteString(h, role) v := h.Sum64()
return hex.EncodeToString(h.Sum(nil)) h.Reset()
return v
} }

View File

@ -4,10 +4,10 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"hash/maphash"
"net/http" "net/http"
"sync" "sync"
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/core/internal/qcode" "github.com/dosco/super-graph/core/internal/qcode"
"github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/jsn"
) )
@ -16,12 +16,13 @@ func (sg *SuperGraph) execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]
var err error var err error
sel := st.qc.Selects sel := st.qc.Selects
h := xxhash.New() h := maphash.Hash{}
h.SetSeed(sg.hashSeed)
// fetch the field name used within the db response json // fetch the field name used within the db response json
// that are used to mark insertion points and the mapping between // that are used to mark insertion points and the mapping between
// those field names and their select objects // those field names and their select objects
fids, sfmap := sg.parentFieldIds(h, sel, st.md.Skipped) fids, sfmap := sg.parentFieldIds(&h, sel, st.md.Skipped())
// fetch the field values of the marked insertion points // fetch the field values of the marked insertion points
// these values contain the id to be used with fetching remote data // these values contain the id to be used with fetching remote data
@ -30,10 +31,10 @@ func (sg *SuperGraph) execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]
switch { switch {
case len(from) == 1: case len(from) == 1:
to, err = sg.resolveRemote(hdr, h, from[0], sel, sfmap) to, err = sg.resolveRemote(hdr, &h, from[0], sel, sfmap)
case len(from) > 1: case len(from) > 1:
to, err = sg.resolveRemotes(hdr, h, from, sel, sfmap) to, err = sg.resolveRemotes(hdr, &h, from, sel, sfmap)
default: default:
return nil, errors.New("something wrong no remote ids found in db response") return nil, errors.New("something wrong no remote ids found in db response")
@ -55,7 +56,7 @@ func (sg *SuperGraph) execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]
func (sg *SuperGraph) resolveRemote( func (sg *SuperGraph) resolveRemote(
hdr http.Header, hdr http.Header,
h *xxhash.Digest, h *maphash.Hash,
field jsn.Field, field jsn.Field,
sel []qcode.Select, sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) { sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
@ -66,7 +67,8 @@ func (sg *SuperGraph) resolveRemote(
to := toA[:1] to := toA[:1]
// use the json key to find the related Select object // use the json key to find the related Select object
k1 := xxhash.Sum64(field.Key) _, _ = h.Write(field.Key)
k1 := h.Sum64()
s, ok := sfmap[k1] s, ok := sfmap[k1]
if !ok { if !ok {
@ -117,7 +119,7 @@ func (sg *SuperGraph) resolveRemote(
func (sg *SuperGraph) resolveRemotes( func (sg *SuperGraph) resolveRemotes(
hdr http.Header, hdr http.Header,
h *xxhash.Digest, h *maphash.Hash,
from []jsn.Field, from []jsn.Field,
sel []qcode.Select, sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) { sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
@ -134,7 +136,8 @@ func (sg *SuperGraph) resolveRemotes(
for i, id := range from { for i, id := range from {
// use the json key to find the related Select object // use the json key to find the related Select object
k1 := xxhash.Sum64(id.Key) _, _ = h.Write(id.Key)
k1 := h.Sum64()
s, ok := sfmap[k1] s, ok := sfmap[k1]
if !ok { if !ok {
@ -192,7 +195,7 @@ func (sg *SuperGraph) resolveRemotes(
return to, cerr return to, cerr
} }
func (sg *SuperGraph) parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) ( func (sg *SuperGraph) parentFieldIds(h *maphash.Hash, sel []qcode.Select, skipped uint32) (
[][]byte, [][]byte,
map[uint64]*qcode.Select) { map[uint64]*qcode.Select) {
@ -227,8 +230,8 @@ func (sg *SuperGraph) parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipp
fm[n] = r.IDField fm[n] = r.IDField
n++ n++
k := xxhash.Sum64(r.IDField) _, _ = h.Write(r.IDField)
sm[k] = s sm[h.Sum64()] = s
} }
} }

View File

@ -2,11 +2,11 @@ package core
import ( import (
"fmt" "fmt"
"hash/maphash"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strings" "strings"
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/core/internal/psql" "github.com/dosco/super-graph/core/internal/psql"
"github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/jsn"
) )
@ -19,7 +19,7 @@ type resolvFn struct {
func (sg *SuperGraph) initResolvers() error { func (sg *SuperGraph) initResolvers() error {
var err error var err error
sg.rmap = make(map[uint64]*resolvFn) sg.rmap = make(map[uint64]resolvFn)
for _, t := range sg.conf.Tables { for _, t := range sg.conf.Tables {
err = sg.initRemotes(t) err = sg.initRemotes(t)
@ -36,7 +36,8 @@ func (sg *SuperGraph) initResolvers() error {
} }
func (sg *SuperGraph) initRemotes(t Table) error { func (sg *SuperGraph) initRemotes(t Table) error {
h := xxhash.New() h := maphash.Hash{}
h.SetSeed(sg.hashSeed)
for _, r := range t.Remotes { for _, r := range t.Remotes {
// defines the table column to be used as an id in the // defines the table column to be used as an id in the
@ -75,17 +76,18 @@ func (sg *SuperGraph) initRemotes(t Table) error {
path = append(path, []byte(p)) path = append(path, []byte(p))
} }
rf := &resolvFn{ rf := resolvFn{
IDField: []byte(idk), IDField: []byte(idk),
Path: path, Path: path,
Fn: fn, Fn: fn,
} }
// index resolver obj by parent and child names // index resolver obj by parent and child names
sg.rmap[mkkey(h, r.Name, t.Name)] = rf sg.rmap[mkkey(&h, r.Name, t.Name)] = rf
// index resolver obj by IDField // index resolver obj by IDField
sg.rmap[xxhash.Sum64(rf.IDField)] = rf _, _ = h.Write(rf.IDField)
sg.rmap[h.Sum64()] = rf
} }
return nil return nil

View File

@ -1,11 +1,9 @@
package core package core
import ( import "hash/maphash"
"github.com/cespare/xxhash/v2"
)
// nolint: errcheck // nolint: errcheck
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 { func mkkey(h *maphash.Hash, k1 string, k2 string) uint64 {
h.WriteString(k1) h.WriteString(k1)
h.WriteString(k2) h.WriteString(k2)
v := h.Sum64() v := h.Sum64()

View File

@ -55,6 +55,30 @@ query {
} }
``` ```
### Fragments
Fragments make it easy to build large complex queries with small composable and reusable fragment blocks.
```graphql
query {
users {
...userFields2
...userFields1
picture_url
}
}
fragment userFields1 on user {
id
email
}
fragment userFields2 on user {
first_name
last_name
}
```
### Sorting ### Sorting
To sort or ordering results just use the `order_by` argument. This can be combined with `where`, `search`, etc to build complex queries to fit you needs. To sort or ordering results just use the `order_by` argument. This can be combined with `where`, `search`, etc to build complex queries to fit you needs.

View File

@ -4,6 +4,8 @@ title: Introduction
sidebar_label: Introduction sidebar_label: Introduction
--- ---
import useBaseUrl from '@docusaurus/useBaseUrl'; // Add to the top of the file below the front matter.
Super Graph is a service that instantly and without code gives you a high performance and secure GraphQL API. Your GraphQL queries are auto translated into a single fast SQL query. No more spending weeks or months writing backend API code. Just make the query you need and Super Graph will do the rest. Super Graph is a service that instantly and without code gives you a high performance and secure GraphQL API. Your GraphQL queries are auto translated into a single fast SQL query. No more spending weeks or months writing backend API code. Just make the query you need and Super Graph will do the rest.
Super Graph has a rich feature set like integrating with your existing Ruby on Rails apps, joining your DB with data from remote APIs, Role and Attribute based access control, Support for JWT tokens, DB migrations, seeding and a lot more. Super Graph has a rich feature set like integrating with your existing Ruby on Rails apps, joining your DB with data from remote APIs, Role and Attribute based access control, Support for JWT tokens, DB migrations, seeding and a lot more.
@ -134,3 +136,9 @@ mutation {
} }
} }
``` ```
### Built-in GraphQL Editor
Quickly craft and test your queries with a full-featured GraphQL editor. Auto-complete and schema documentation are automatically available.
<img alt="Zipkin Traces" src={useBaseUrl("img/webui.jpg")} />

View File

@ -95,7 +95,7 @@ auth:
type: jwt type: jwt
jwt: jwt:
# the two providers are 'auth0' and 'none' # valid providers are auth0, firebase and none
provider: auth0 provider: auth0
secret: abc335bfcfdb04e50db5bb0a4d67ab9 secret: abc335bfcfdb04e50db5bb0a4d67ab9
public_key_file: /secrets/public_key.pem public_key_file: /secrets/public_key.pem
@ -108,6 +108,19 @@ We can get the JWT token either from the `authorization` header where we expect
For validation a `secret` or a public key (ecdsa or rsa) is required. When using public keys they have to be in a PEM format file. For validation a `secret` or a public key (ecdsa or rsa) is required. When using public keys they have to be in a PEM format file.
### Firebase Auth
```yaml
auth:
type: jwt
jwt:
provider: firebase
audience: <firebase-project-id>
```
Firebase auth also uses JWT; the signing keys are auto-fetched from Google and used according to Firebase's documented verification mechanism. The `audience` config value needs to be set to your project id, and everything else is taken care of for you.
### HTTP Headers ### HTTP Headers
```yaml ```yaml

View File

@ -0,0 +1,13 @@
---
id: webui
title: Web UI / GraphQL Editor
sidebar_label: Web UI
---
import useBaseUrl from '@docusaurus/useBaseUrl'; // Add to the top of the file below the front matter.
<img alt="Zipkin Traces" src={useBaseUrl("img/webui.jpg")} />
Super Graph comes with a built-in GraphQL editor that only runs in development. Use it to craft your queries and copy-paste them into your app once you're ready. The editor supports auto-completion and schema documentation. This makes it super easy to craft and test your query all in one go without knowing anything about the underlying database structure.
You can even set query variables or http headers as required. To simulate an authenticated user set the http header `"X-USER-ID": 5` to the user id of the user you want to test with.

View File

@ -36,8 +36,8 @@ module.exports = {
position: "left", position: "left",
}, },
{ {
label: "Art Compute", label: "AbtCode",
href: "https://artcompute.com/s/super-graph", href: "https://abtcode.com/s/super-graph",
position: "left", position: "left",
}, },
], ],

View File

@ -3,6 +3,7 @@ module.exports = {
Docusaurus: [ Docusaurus: [
"home", "home",
"intro", "intro",
"webui",
"start", "start",
"why", "why",
"graphql", "graphql",

Binary file not shown.

After

Width:  |  Height:  |  Size: 117 KiB

3
go.mod
View File

@ -12,13 +12,11 @@ require (
github.com/adjust/gorails v0.0.0-20171013043634-2786ed0c03d3 github.com/adjust/gorails v0.0.0-20171013043634-2786ed0c03d3
github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b
github.com/brianvoe/gofakeit/v5 v5.2.0 github.com/brianvoe/gofakeit/v5 v5.2.0
github.com/cespare/xxhash/v2 v2.1.1
github.com/chirino/graphql v0.0.0-20200430165312-293648399b1a github.com/chirino/graphql v0.0.0-20200430165312-293648399b1a
github.com/daaku/go.zipexe v1.0.1 // indirect github.com/daaku/go.zipexe v1.0.1 // indirect
github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/dlclark/regexp2 v1.2.0 // indirect github.com/dlclark/regexp2 v1.2.0 // indirect
github.com/dop251/goja v0.0.0-20200424152103-d0b8fda54cd0 github.com/dop251/goja v0.0.0-20200424152103-d0b8fda54cd0
github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813 // indirect
github.com/fsnotify/fsnotify v1.4.9 github.com/fsnotify/fsnotify v1.4.9
github.com/garyburd/redigo v1.6.0 github.com/garyburd/redigo v1.6.0
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
@ -30,7 +28,6 @@ require (
github.com/openzipkin/zipkin-go v0.2.2 github.com/openzipkin/zipkin-go v0.2.2
github.com/pelletier/go-toml v1.7.0 // indirect github.com/pelletier/go-toml v1.7.0 // indirect
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/prometheus/common v0.4.0
github.com/rs/cors v1.7.0 github.com/rs/cors v1.7.0
github.com/spf13/afero v1.2.2 // indirect github.com/spf13/afero v1.2.2 // indirect
github.com/spf13/cast v1.3.1 // indirect github.com/spf13/cast v1.3.1 // indirect

4
go.sum
View File

@ -55,8 +55,6 @@ github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chirino/graphql v0.0.0-20200430165312-293648399b1a h1:WVu7r2vwlrBVmunbSSU+9/3M3AgsQyhE49CKDjHiFq4= github.com/chirino/graphql v0.0.0-20200430165312-293648399b1a h1:WVu7r2vwlrBVmunbSSU+9/3M3AgsQyhE49CKDjHiFq4=
github.com/chirino/graphql v0.0.0-20200430165312-293648399b1a/go.mod h1:wQjjxFMFyMlsWh4Z3nMuHQtevD4Ul9UVQSnz1JOLuP8= github.com/chirino/graphql v0.0.0-20200430165312-293648399b1a/go.mod h1:wQjjxFMFyMlsWh4Z3nMuHQtevD4Ul9UVQSnz1JOLuP8=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
@ -87,8 +85,6 @@ github.com/dlclark/regexp2 v1.2.0 h1:8sAhBGEM0dRWogWqWyQeIJnxjWO6oIjl8FKqREDsGfk
github.com/dlclark/regexp2 v1.2.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= github.com/dlclark/regexp2 v1.2.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc=
github.com/dop251/goja v0.0.0-20200424152103-d0b8fda54cd0 h1:EfFAcaAwGai/wlDCWwIObHBm3T2C2CCPX/SaS0fpOJ4= github.com/dop251/goja v0.0.0-20200424152103-d0b8fda54cd0 h1:EfFAcaAwGai/wlDCWwIObHBm3T2C2CCPX/SaS0fpOJ4=
github.com/dop251/goja v0.0.0-20200424152103-d0b8fda54cd0/go.mod h1:Mw6PkjjMXWbTj+nnj4s3QPXq1jaT0s5pC0iFD4+BOAA= github.com/dop251/goja v0.0.0-20200424152103-d0b8fda54cd0/go.mod h1:Mw6PkjjMXWbTj+nnj4s3QPXq1jaT0s5pC0iFD4+BOAA=
github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813 h1:NgO45/5mBLRVfiXerEFzH6ikcZ7DNRPS639xFg3ENzU=
github.com/dvyukov/go-fuzz v0.0.0-20200318091601-be3528f3a813/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=

View File

@ -82,8 +82,6 @@ func graphQLFunc(sg *core.SuperGraph, query string, data interface{}, opt map[st
if v, ok := opt["user_id"]; ok && len(v) != 0 { if v, ok := opt["user_id"]; ok && len(v) != 0 {
ct = context.WithValue(ct, core.UserIDKey, v) ct = context.WithValue(ct, core.UserIDKey, v)
} else {
ct = context.WithValue(ct, core.UserIDKey, "-1")
} }
// var role string // var role string

View File

@ -66,7 +66,7 @@ func newViper(configPath, configFile string) *viper.Viper {
vi.SetDefault("host_port", "0.0.0.0:8080") vi.SetDefault("host_port", "0.0.0.0:8080")
vi.SetDefault("web_ui", false) vi.SetDefault("web_ui", false)
vi.SetDefault("enable_tracing", false) vi.SetDefault("enable_tracing", false)
vi.SetDefault("auth_fail_block", "always") vi.SetDefault("auth_fail_block", false)
vi.SetDefault("seed_file", "seed.js") vi.SetDefault("seed_file", "seed.js")
vi.SetDefault("default_block", true) vi.SetDefault("default_block", true)

View File

@ -32,6 +32,7 @@ type Auth struct {
Secret string Secret string
PubKeyFile string `mapstructure:"public_key_file"` PubKeyFile string `mapstructure:"public_key_file"`
PubKeyType string `mapstructure:"public_key_type"` PubKeyType string `mapstructure:"public_key_type"`
Audience string `mapstructure:"audience"`
} }
Header struct { Header struct {

View File

@ -2,19 +2,32 @@ package auth
import ( import (
"context" "context"
"encoding/json"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time"
jwt "github.com/dgrijalva/jwt-go" jwt "github.com/dgrijalva/jwt-go"
"github.com/dosco/super-graph/core" "github.com/dosco/super-graph/core"
) )
const ( const (
authHeader = "Authorization" authHeader = "Authorization"
jwtAuth0 int = iota + 1 jwtAuth0 int = iota + 1
jwtFirebase int = iota + 2
firebasePKEndpoint = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"
firebaseIssuerPrefix = "https://securetoken.google.com/"
) )
// firebasePKCache holds Google's firebase token-signing public keys
// (PEM strings keyed by key id) together with the time at which the
// cached set expires and must be re-fetched.
type firebasePKCache struct {
	PublicKeys map[string]string
	Expiration time.Time
}

// firebasePublicKeys caches the fetched keys across requests.
// NOTE(review): this package-level cache is read and written without
// any locking — concurrent token validations may race; confirm whether
// handlers run in parallel and guard with a mutex if so.
var firebasePublicKeys firebasePKCache
func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) { func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
var key interface{} var key interface{}
var jwtProvider int var jwtProvider int
@ -23,6 +36,8 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
if ac.JWT.Provider == "auth0" { if ac.JWT.Provider == "auth0" {
jwtProvider = jwtAuth0 jwtProvider = jwtAuth0
} else if ac.JWT.Provider == "firebase" {
jwtProvider = jwtFirebase
} }
secret := ac.JWT.Secret secret := ac.JWT.Secret
@ -56,6 +71,7 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var tok string var tok string
if len(cookie) != 0 { if len(cookie) != 0 {
@ -74,9 +90,16 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
tok = ah[7:] tok = ah[7:]
} }
token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) { var keyFunc jwt.Keyfunc
return key, nil if jwtProvider == jwtFirebase {
}) keyFunc = firebaseKeyFunction
} else {
keyFunc = func(token *jwt.Token) (interface{}, error) {
return key, nil
}
}
token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, keyFunc)
if err != nil { if err != nil {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@ -86,12 +109,20 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
if claims, ok := token.Claims.(*jwt.StandardClaims); ok { if claims, ok := token.Claims.(*jwt.StandardClaims); ok {
ctx := r.Context() ctx := r.Context()
if ac.JWT.Audience != "" && claims.Audience != ac.JWT.Audience {
next.ServeHTTP(w, r)
return
}
if jwtProvider == jwtAuth0 { if jwtProvider == jwtAuth0 {
sub := strings.Split(claims.Subject, "|") sub := strings.Split(claims.Subject, "|")
if len(sub) != 2 { if len(sub) != 2 {
ctx = context.WithValue(ctx, core.UserIDProviderKey, sub[0]) ctx = context.WithValue(ctx, core.UserIDProviderKey, sub[0])
ctx = context.WithValue(ctx, core.UserIDKey, sub[1]) ctx = context.WithValue(ctx, core.UserIDKey, sub[1])
} }
} else if jwtProvider == jwtFirebase &&
claims.Issuer == firebaseIssuerPrefix+ac.JWT.Audience {
ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
} else { } else {
ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject) ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
} }
@ -103,3 +134,92 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}, nil }, nil
} }
// firebaseKeyError describes a failure while fetching, parsing, or
// selecting a firebase public key. Message gives the human-readable
// context; Err, when non-nil, is the underlying cause.
type firebaseKeyError struct {
	Err     error
	Message string
}

// Error implements the error interface. Err may legitimately be nil —
// several construction sites set only Message — so guard against a nil
// dereference instead of calling e.Err.Error() unconditionally.
func (e *firebaseKeyError) Error() string {
	if e.Err == nil {
		return e.Message
	}
	return e.Message + " " + e.Err.Error()
}
// firebaseKeyFunction is a jwt.Keyfunc that returns the RSA public key used to
// sign a Firebase ID token. The key is selected by the token's "kid" header
// from the certificate set served at firebasePKEndpoint, which is cached in
// the package-level firebasePublicKeys until the max-age advertised by the
// server's Cache-Control header expires.
//
// NOTE(review): firebasePublicKeys is read and refreshed without any
// synchronization; concurrent token verifications may race on the refresh —
// confirm callers serialize access or guard the cache with a mutex.
func firebaseKeyFunction(token *jwt.Token) (interface{}, error) {
	kidValue, ok := token.Header["kid"]
	if !ok {
		return nil, &firebaseKeyError{
			Message: "Error 'kid' header not found in token",
		}
	}

	// The header comes from decoded JSON, so "kid" should be a string; check
	// the assertion rather than panicking on a malformed token.
	kid, ok := kidValue.(string)
	if !ok {
		return nil, &firebaseKeyError{
			Message: "Error 'kid' header is not a string",
		}
	}

	// Refresh the cached certificate set once it has expired.
	if firebasePublicKeys.Expiration.Before(time.Now()) {
		resp, err := http.Get(firebasePKEndpoint)
		if err != nil {
			return nil, &firebaseKeyError{
				Message: "Error connecting to firebase certificate server",
				Err:     err,
			}
		}
		defer resp.Body.Close()

		data, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return nil, &firebaseKeyError{
				Message: "Error reading firebase certificate server response",
				Err:     err,
			}
		}

		// The Cache-Control max-age directive says how long the certificate
		// set stays valid.
		cachePolicy := resp.Header.Get("cache-control")
		ageIndex := strings.Index(cachePolicy, "max-age=")
		if ageIndex < 0 {
			return nil, &firebaseKeyError{
				Message: "Error parsing cache-control header: 'max-age=' not found",
			}
		}

		ageToEnd := cachePolicy[ageIndex+8:]
		endIndex := strings.Index(ageToEnd, ",")
		if endIndex < 0 {
			// No further directive: the digits run to the end of the header.
			// (len(ageToEnd)-1 here would silently drop the last digit.)
			endIndex = len(ageToEnd)
		}
		ageString := ageToEnd[:endIndex]

		age, err := strconv.ParseInt(ageString, 10, 64)
		if err != nil {
			return nil, &firebaseKeyError{
				Message: "Error parsing max-age cache policy",
				Err:     err,
			}
		}

		expiration := time.Now().Add(time.Duration(age) * time.Second)

		if err := json.Unmarshal(data, &firebasePublicKeys.PublicKeys); err != nil {
			// Drop any partially-written state so a bad payload cannot leave
			// the cache half-populated.
			firebasePublicKeys = firebasePKCache{}
			return nil, &firebaseKeyError{
				Message: "Error unmarshalling firebase public key json",
				Err:     err,
			}
		}

		firebasePublicKeys.Expiration = expiration
	}

	if key, found := firebasePublicKeys.PublicKeys[kid]; found {
		return jwt.ParseRSAPublicKeyFromPEM([]byte(key))
	}

	return nil, &firebaseKeyError{
		Message: "Error no matching public key for kid supplied in jwt",
	}
}

View File

@ -6,9 +6,11 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"sort"
"strconv" "strconv"
"strings" "strings"
"text/template" "text/template"
@ -105,39 +107,40 @@ func (defaultMigratorFS) Glob(pattern string) ([]string, error) {
func FindMigrationsEx(path string, fs MigratorFS) ([]string, error) { func FindMigrationsEx(path string, fs MigratorFS) ([]string, error) {
path = strings.TrimRight(path, string(filepath.Separator)) path = strings.TrimRight(path, string(filepath.Separator))
fileInfos, err := fs.ReadDir(path) files, err := ioutil.ReadDir(path)
if err != nil { if err != nil {
return nil, err log.Fatal(err)
} }
paths := make([]string, 0, len(fileInfos)) fm := make(map[int]string, len(files))
for _, fi := range fileInfos { keys := make([]int, 0, len(files))
for _, fi := range files {
if fi.IsDir() { if fi.IsDir() {
continue continue
} }
matches := migrationPattern.FindStringSubmatch(fi.Name()) matches := migrationPattern.FindStringSubmatch(fi.Name())
if len(matches) != 2 { if len(matches) != 2 {
continue continue
} }
n, err := strconv.ParseInt(matches[1], 10, 32) n, err := strconv.Atoi(matches[1])
if err != nil { if err != nil {
// The regexp already validated that the prefix is all digits so this *should* never fail // The regexp already validated that the prefix is all digits so this *should* never fail
return nil, err return nil, err
} }
mcount := len(paths) fm[n] = filepath.Join(path, fi.Name())
keys = append(keys, n)
}
if n < int64(mcount) { sort.Ints(keys)
return nil, fmt.Errorf("Duplicate migration %d", n)
}
if int64(mcount) < n { paths := make([]string, 0, len(keys))
return nil, fmt.Errorf("Missing migration %d", mcount) for _, k := range keys {
} paths = append(paths, fm[k])
paths = append(paths, filepath.Join(path, fi.Name()))
} }
return paths, nil return paths, nil

13
jsn/bench.1 Normal file
View File

@ -0,0 +1,13 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/jsn
BenchmarkGet
BenchmarkGet-16 13898 85293 ns/op 3328 B/op 2 allocs/op
BenchmarkFilter
BenchmarkFilter-16 189328 6341 ns/op 448 B/op 1 allocs/op
BenchmarkStrip
BenchmarkStrip-16 219765 5543 ns/op 224 B/op 1 allocs/op
BenchmarkReplace
BenchmarkReplace-16 100899 12022 ns/op 416 B/op 1 allocs/op
PASS
ok github.com/dosco/super-graph/jsn 6.029s

View File

@ -2,17 +2,19 @@ package jsn
import ( import (
"bytes" "bytes"
"hash/maphash"
"github.com/cespare/xxhash/v2"
) )
// Filter function filters the JSON keeping only the provided keys and removing all others // Filter function filters the JSON keeping only the provided keys and removing all others
func Filter(w *bytes.Buffer, b []byte, keys []string) error { func Filter(w *bytes.Buffer, b []byte, keys []string) error {
var err error var err error
kmap := make(map[uint64]struct{}, len(keys)) kmap := make(map[uint64]struct{}, len(keys))
h := maphash.Hash{}
for i := range keys { for i := range keys {
kmap[xxhash.Sum64String(keys[i])] = struct{}{} _, _ = h.WriteString(keys[i])
kmap[h.Sum64()] = struct{}{}
h.Reset()
} }
// is an list // is an list
@ -132,7 +134,11 @@ func Filter(w *bytes.Buffer, b []byte, keys []string) error {
cb := b[s:(e + 1)] cb := b[s:(e + 1)]
e = 0 e = 0
if _, ok := kmap[xxhash.Sum64(k)]; !ok { _, _ = h.Write(k)
_, ok := kmap[h.Sum64()]
h.Reset()
if !ok {
continue continue
} }

View File

@ -1,7 +1,7 @@
package jsn package jsn
import ( import (
"github.com/cespare/xxhash/v2" "hash/maphash"
) )
const ( const (
@ -41,9 +41,12 @@ func Value(b []byte) []byte {
// Keys function fetches values for the provided keys // Keys function fetches values for the provided keys
func Get(b []byte, keys [][]byte) []Field { func Get(b []byte, keys [][]byte) []Field {
kmap := make(map[uint64]struct{}, len(keys)) kmap := make(map[uint64]struct{}, len(keys))
h := maphash.Hash{}
for i := range keys { for i := range keys {
kmap[xxhash.Sum64(keys[i])] = struct{}{} _, _ = h.Write(keys[i])
kmap[h.Sum64()] = struct{}{}
h.Reset()
} }
res := make([]Field, 0, 20) res := make([]Field, 0, 20)
@ -141,7 +144,9 @@ func Get(b []byte, keys [][]byte) []Field {
} }
if e != 0 { if e != 0 {
_, ok := kmap[xxhash.Sum64(k)] _, _ = h.Write(k)
_, ok := kmap[h.Sum64()]
h.Reset()
if ok { if ok {
res = append(res, Field{k, b[s:(e + 1)]}) res = append(res, Field{k, b[s:(e + 1)]})

View File

@ -3,8 +3,7 @@ package jsn
import ( import (
"bytes" "bytes"
"errors" "errors"
"hash/maphash"
"github.com/cespare/xxhash/v2"
) )
// Replace function replaces key-value pairs provided in the `from` argument with those in the `to` argument // Replace function replaces key-value pairs provided in the `from` argument with those in the `to` argument
@ -18,7 +17,7 @@ func Replace(w *bytes.Buffer, b []byte, from, to []Field) error {
return err return err
} }
h := xxhash.New() h := maphash.Hash{}
tmap := make(map[uint64]int, len(from)) tmap := make(map[uint64]int, len(from))
for i, f := range from { for i, f := range from {