Compare commits

...

4 Commits

27 changed files with 1192 additions and 524 deletions

View File

@ -12,7 +12,8 @@ import (
// to a prepared statement.
func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
vars := make([]interface{}, len(md.Params))
params := md.Params()
vars := make([]interface{}, len(params))
var fields map[string]json.RawMessage
var err error
@ -25,7 +26,7 @@ func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
}
}
for i, p := range md.Params {
for i, p := range params {
switch p.Name {
case "user_id":
if v := c.Value(UserIDKey); v != nil {
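
Editor's note: the hunk above swaps the exported `md.Params` field for the new `Params()` accessor (the field becomes unexported in `psql.Metadata`; see metadata.go further down). A minimal sketch of the calling pattern, with illustrative names standing in for the context lookup that argList actually performs:

package main

import "fmt"

// Param and Metadata mirror just enough of psql.Metadata for this sketch:
// the compiled params now sit behind a Params() accessor.
type Param struct{ Name string }

type Metadata struct{ params []Param }

func (md Metadata) Params() []Param { return md.params }

func main() {
	md := Metadata{params: []Param{{Name: "user_id"}, {Name: "cursor"}}}
	vars := make([]interface{}, len(md.Params()))
	for i, p := range md.Params() {
		vars[i] = "<" + p.Name + ">" // argList resolves these from the request context
	}
	fmt.Println(vars) // [<user_id> <cursor>]
}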

View File

@ -88,6 +88,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
stmts := make([]stmt, 0, len(sg.conf.Roles))
w := &bytes.Buffer{}
md := psql.Metadata{}
for i := 0; i < len(sg.conf.Roles); i++ {
role := &sg.conf.Roles[i]
@ -105,16 +106,18 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
stmts = append(stmts, stmt{role: role, qc: qc})
s := &stmts[len(stmts)-1]
s.md, err = sg.pc.Compile(w, qc, psql.Variables(vm))
md, err = sg.pc.CompileWithMetadata(w, qc, psql.Variables(vm), md)
if err != nil {
return nil, err
}
s.sql = w.String()
s.md = md
w.Reset()
}
sql, err := sg.renderUserQuery(stmts)
sql, err := sg.renderUserQuery(md, stmts)
if err != nil {
return nil, err
}
@ -124,7 +127,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
}
//nolint: errcheck
func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
func (sg *SuperGraph) renderUserQuery(md psql.Metadata, stmts []stmt) (string, error) {
w := &bytes.Buffer{}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
@ -142,7 +145,7 @@ func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
}
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, sg.conf.RolesQuery)
md.RenderVar(w, sg.conf.RolesQuery)
io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`)
@ -158,7 +161,7 @@ func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
}
io.WriteString(w, ` ELSE 'user' END) FROM (`)
io.WriteString(w, sg.conf.RolesQuery)
md.RenderVar(w, sg.conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
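
Editor's note: both `io.WriteString(w, sg.conf.RolesQuery)` calls become `md.RenderVar(w, sg.conf.RolesQuery)`, so a `$name` variable inside the configured roles query is emitted as a numbered placeholder that shares the statement's parameter index. A toy illustration of the effect, assuming a single trailing variable (the real method, in metadata.go below, handles the general case):

package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// renderVar is a stand-in for psql.Metadata.RenderVar: rewrite one trailing
// "$name" into the next positional placeholder, reusing indexes by name.
func renderVar(w io.Writer, pindex map[string]int, q string) {
	i := strings.IndexByte(q, '$')
	if i < 0 {
		io.WriteString(w, q)
		return
	}
	io.WriteString(w, q[:i])
	name := q[i+1:]
	if _, ok := pindex[name]; !ok {
		pindex[name] = len(pindex) + 1
	}
	fmt.Fprintf(w, "$%d", pindex[name])
}

func main() {
	w := &bytes.Buffer{}
	renderVar(w, map[string]int{}, `SELECT role FROM users WHERE users.id = $user_id`)
	fmt.Println(w.String()) // SELECT role FROM users WHERE users.id = $1
}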

View File

@ -197,30 +197,26 @@ func (c *Config) AddRoleTable(role string, table string, conf interface{}) error
// ReadInConfig function reads in the config file for the environment specified in the GO_ENV
// environment variable. This is the best way to create a new Super Graph config.
func ReadInConfig(configFile string) (*Config, error) {
cpath := path.Dir(configFile)
cfile := path.Base(configFile)
vi := newViper(cpath, cfile)
cp := path.Dir(configFile)
vi := newViper(cp, path.Base(configFile))
if err := vi.ReadInConfig(); err != nil {
return nil, err
}
inherits := vi.GetString("inherits")
if inherits != "" {
vi = newViper(cpath, inherits)
if pcf := vi.GetString("inherits"); pcf != "" {
cf := vi.ConfigFileUsed()
vi = newViper(cp, pcf)
if err := vi.ReadInConfig(); err != nil {
return nil, err
}
if vi.IsSet("inherits") {
return nil, fmt.Errorf("inherited config (%s) cannot itself inherit (%s)",
inherits,
vi.GetString("inherits"))
if v := vi.GetString("inherits"); v != "" {
return nil, fmt.Errorf("inherited config (%s) cannot itself inherit (%s)", pcf, v)
}
vi.SetConfigName(cfile)
vi.SetConfigFile(cf)
if err := vi.MergeInConfig(); err != nil {
return nil, err
@ -234,7 +230,7 @@ func ReadInConfig(configFile string) (*Config, error) {
}
if c.AllowListFile == "" {
c.AllowListFile = path.Join(cpath, "allow.list")
c.AllowListFile = path.Join(cp, "allow.list")
}
return c, nil
@ -248,7 +244,7 @@ func newViper(configPath, configFile string) *viper.Viper {
vi.AutomaticEnv()
if filepath.Ext(configFile) != "" {
vi.SetConfigFile(configFile)
vi.SetConfigFile(path.Join(configPath, configFile))
} else {
vi.SetConfigName(configFile)
vi.AddConfigPath(configPath)
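
Editor's note: a hedged usage sketch of the reworked inheritance flow — the parent named by `inherits` is read first, the original file is merged over it, and a parent that itself inherits is rejected. The file layout here is hypothetical:

package main

import (
	"fmt"
	"log"

	"github.com/dosco/super-graph/core"
)

func main() {
	// Hypothetical layout: config/dev.yml contains `inherits: base`, so
	// config/base.yml is read first and dev.yml is merged over it. If
	// base.yml itself set `inherits`, ReadInConfig would return an error.
	conf, err := core.ReadInConfig("./config/dev.yml")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(conf.AllowListFile) // defaults to config/allow.list
}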

View File

@ -125,7 +125,7 @@ func (c *scontext) execQuery() ([]byte, error) {
return nil, err
}
if len(data) == 0 || st.md.Skipped == 0 {
if len(data) == 0 || st.md.Skipped() == 0 {
return data, nil
}
@ -196,8 +196,6 @@ func (c *scontext) resolvePreparedSQL() ([]byte, *stmt, error) {
return nil, nil, err
}
fmt.Println(">>", varsList)
if useTx {
row = tx.Stmt(q.sd).QueryRow(varsList...)
} else {

View File

@ -75,13 +75,22 @@ func (sg *SuperGraph) initConfig() error {
if c.RolesQuery == "" {
sg.log.Printf("INF roles_query not defined: attribute based access control disabled")
} else {
n := 0
for k, v := range sg.roles {
if k == "user" || k == "anon" {
n++
} else if v.Match != "" {
n++
}
}
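// n counts the built-in 'user' and 'anon' roles plus every custom role that
// defines a match expression, so ABAC stays off unless at least one
// matchable custom role exists beyond the two built-ins.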
sg.abacEnabled = (n > 2)
if !sg.abacEnabled {
sg.log.Printf("WRN attribute based access control disabled: no custom roles found (with 'match' defined)")
}
}
_, userExists := sg.roles["user"]
_, sg.anonExists = sg.roles["anon"]
sg.abacEnabled = userExists && c.RolesQuery != ""
return nil
}

View File

@ -10,21 +10,23 @@ import (
"os"
"sort"
"strings"
"text/scanner"
"github.com/chirino/graphql/schema"
"github.com/dosco/super-graph/jsn"
)
const (
AL_QUERY int = iota + 1
AL_VARS
expComment = iota + 1
expVar
expQuery
)
type Item struct {
Name string
key string
Query string
Vars json.RawMessage
Vars string
Comment string
}
@ -126,121 +128,101 @@ func (al *List) Set(vars []byte, query, comment string) error {
return errors.New("empty query")
}
var q string
for i := 0; i < len(query); i++ {
c := query[i]
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
q = query
break
} else if c == '{' {
q = "query " + query
break
}
}
al.saveChan <- Item{
Comment: comment,
Query: q,
Vars: vars,
Query: query,
Vars: string(vars),
}
return nil
}
func (al *List) Load() ([]Item, error) {
var list []Item
varString := "variables"
b, err := ioutil.ReadFile(al.filepath)
if err != nil {
return list, err
return nil, err
}
if len(b) == 0 {
return list, nil
return parse(string(b), al.filepath)
}
func parse(b string, filename string) ([]Item, error) {
var items []Item
var s scanner.Scanner
s.Init(strings.NewReader(b))
s.Filename = filename
s.Mode ^= scanner.SkipComments
var op, sp scanner.Position
var item Item
newComment := false
st := expComment
for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() {
txt := s.TokenText()
switch {
case strings.HasPrefix(txt, "/*"):
if st == expQuery {
v := b[sp.Offset:s.Pos().Offset]
item.Query = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1])
items = append(items, item)
}
item = Item{Comment: strings.TrimSpace(txt[2 : len(txt)-2])}
sp = s.Pos()
st = expComment
newComment = true
case !newComment && strings.HasPrefix(txt, "#"):
if st == expQuery {
v := b[sp.Offset:s.Pos().Offset]
item.Query = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1])
items = append(items, item)
}
item = Item{}
sp = s.Pos()
st = expComment
case strings.HasPrefix(txt, "variables"):
if st == expComment {
v := b[sp.Offset:s.Pos().Offset]
item.Comment = strings.TrimSpace(v[:strings.IndexByte(v, '\n')])
}
sp = s.Pos()
st = expVar
case isGraphQL(txt):
if st == expVar {
v := b[sp.Offset:s.Pos().Offset]
item.Vars = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1])
}
sp = op
st = expQuery
}
op = s.Pos()
}
var comment bytes.Buffer
var varBytes []byte
itemMap := make(map[string]struct{})
s, e, c := 0, 0, 0
ty := 0
for {
fq := false
if c == 0 && b[e] == '#' {
s = e
for e < len(b) && b[e] != '\n' {
e++
}
if (e - s) > 2 {
comment.Write(b[(s + 1):(e + 1)])
}
}
if e >= len(b) {
break
}
if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") {
if c == 0 {
s = e
}
ty = AL_QUERY
} else if matchPrefix(b, e, varString) {
if c == 0 {
s = e + len(varString) + 1
}
ty = AL_VARS
} else if b[e] == '{' {
c++
} else if b[e] == '}' {
c--
if c == 0 {
if ty == AL_QUERY {
fq = true
} else if ty == AL_VARS {
varBytes = b[s:(e + 1)]
}
ty = 0
}
}
if fq {
query := string(b[s:(e + 1)])
name := QueryName(query)
key := strings.ToLower(name)
if _, ok := itemMap[key]; !ok {
v := Item{
Name: name,
key: key,
Query: query,
Vars: varBytes,
Comment: comment.String(),
}
list = append(list, v)
comment.Reset()
}
varBytes = nil
}
e++
if e >= len(b) {
break
}
if st == expQuery {
v := b[sp.Offset:s.Pos().Offset]
item.Query = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1])
items = append(items, item)
}
return list, nil
for i := range items {
items[i].Name = QueryName(items[i].Query)
items[i].key = strings.ToLower(items[i].Name)
}
return items, nil
}
func isGraphQL(s string) bool {
return strings.HasPrefix(s, "query") ||
strings.HasPrefix(s, "mutation") ||
strings.HasPrefix(s, "subscription")
}
func (al *List) save(item Item) error {
@ -297,57 +279,39 @@ func (al *List) save(item Item) error {
return strings.Compare(list[i].key, list[j].key) == -1
})
for _, v := range list {
cmtLines := strings.Split(v.Comment, "\n")
i := 0
for _, c := range cmtLines {
if c = strings.TrimSpace(c); c == "" {
continue
}
_, err := f.WriteString(fmt.Sprintf("# %s\n", c))
if err != nil {
return err
}
i++
}
if i != 0 {
if _, err := f.WriteString("\n"); err != nil {
return err
}
} else {
if _, err := f.WriteString(fmt.Sprintf("# Query named %s\n\n", v.Name)); err != nil {
return err
}
}
if len(v.Vars) != 0 && !bytes.Equal(v.Vars, []byte("{}")) {
for i, v := range list {
var vars string
if v.Vars != "" {
buf.Reset()
if err := jsn.Clear(&buf, v.Vars); err != nil {
return fmt.Errorf("failed to clean vars: %w", err)
if err := jsn.Clear(&buf, []byte(v.Vars)); err != nil {
continue
}
vj := json.RawMessage(buf.Bytes())
vj, err = json.MarshalIndent(vj, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal vars: %w", err)
if vj, err = json.MarshalIndent(vj, "", " "); err != nil {
continue
}
vars = string(vj)
}
list[i].Vars = vars
list[i].Comment = strings.TrimSpace(v.Comment)
}
_, err = f.WriteString(fmt.Sprintf("variables %s\n\n", vj))
for _, v := range list {
if v.Comment != "" {
f.WriteString(fmt.Sprintf("/* %s */\n\n", v.Comment))
} else {
f.WriteString(fmt.Sprintf("/* %s */\n\n", v.Name))
}
if v.Vars != "" {
_, err = f.WriteString(fmt.Sprintf("variables %s\n\n", v.Vars))
if err != nil {
return err
}
}
if v.Query[0] == '{' {
_, err = f.WriteString(fmt.Sprintf("query %s\n\n", v.Query))
} else {
_, err = f.WriteString(fmt.Sprintf("%s\n\n", v.Query))
}
_, err = f.WriteString(fmt.Sprintf("%s\n\n", v.Query))
if err != nil {
return err
}

View File

@ -82,3 +82,160 @@ func TestGQLName5(t *testing.T) {
t.Fatal("Name should be empty, not ", name)
}
}
func TestParse1(t *testing.T) {
var al = `
# Hello world
variables {
"data": {
"slug": "",
"body": "",
"post": {
"connect": {
"slug": ""
}
}
}
}
mutation createComment {
comment(insert: $data) {
slug
body
createdAt: created_at
totalVotes: cached_votes_total
totalReplies: cached_replies_total
vote: comment_vote(where: {user_id: {eq: $user_id}}) {
created_at
__typename
}
author: user {
slug
firstName: first_name
lastName: last_name
pictureURL: picture_url
bio
__typename
}
__typename
}
}
# Query named createPost
query createPost {
post(insert: $data) {
slug
body
published
createdAt: created_at
totalVotes: cached_votes_total
totalComments: cached_comments_total
vote: post_vote(where: {user_id: {eq: $user_id}}) {
created_at
__typename
}
author: user {
slug
firstName: first_name
lastName: last_name
pictureURL: picture_url
bio
__typename
}
__typename
}
}`
_, err := parse(al, "allow.list")
if err != nil {
t.Fatal(err)
}
}
func TestParse2(t *testing.T) {
var al = `
/* Hello world */
variables {
"data": {
"slug": "",
"body": "",
"post": {
"connect": {
"slug": ""
}
}
}
}
mutation createComment {
comment(insert: $data) {
slug
body
createdAt: created_at
totalVotes: cached_votes_total
totalReplies: cached_replies_total
vote: comment_vote(where: {user_id: {eq: $user_id}}) {
created_at
__typename
}
author: user {
slug
firstName: first_name
lastName: last_name
pictureURL: picture_url
bio
__typename
}
__typename
}
}
/*
Query named createPost
*/
variables {
"data": {
"thread": {
"connect": {
"slug": ""
}
},
"slug": "",
"published": false,
"body": ""
}
}
query createPost {
post(insert: $data) {
slug
body
published
createdAt: created_at
totalVotes: cached_votes_total
totalComments: cached_comments_total
vote: post_vote(where: {user_id: {eq: $user_id}}) {
created_at
__typename
}
author: user {
slug
firstName: first_name
lastName: last_name
pictureURL: picture_url
bio
__typename
}
__typename
}
}`
_, err := parse(al, "allow.list")
if err != nil {
t.Fatal(err)
}
}

View File

@ -1,4 +1,3 @@
//nolint:errcheck
package psql
import (
@ -112,15 +111,15 @@ func (c *compilerContext) renderColumnSearchRank(sel *qcode.Select, ti *DBTableI
c.renderComma(columnsRendered)
//fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_rank(`)
_, _ = io.WriteString(c.w, `ts_rank(`)
colWithTable(c.w, ti.Name, cn)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery(`)
_, _ = io.WriteString(c.w, `, websearch_to_tsquery(`)
} else {
io.WriteString(c.w, `, to_tsquery(`)
_, _ = io.WriteString(c.w, `, to_tsquery(`)
}
c.renderValueExp(Param{Name: arg.Val, Type: "string"})
io.WriteString(c.w, `))`)
c.md.renderValueExp(c.w, Param{Name: arg.Val, Type: "string"})
_, _ = io.WriteString(c.w, `))`)
alias(c.w, col.Name)
return nil
@ -137,15 +136,15 @@ func (c *compilerContext) renderColumnSearchHeadline(sel *qcode.Select, ti *DBTa
c.renderComma(columnsRendered)
//fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_headline(`)
_, _ = io.WriteString(c.w, `ts_headline(`)
colWithTable(c.w, ti.Name, cn)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery(`)
_, _ = io.WriteString(c.w, `, websearch_to_tsquery(`)
} else {
io.WriteString(c.w, `, to_tsquery(`)
_, _ = io.WriteString(c.w, `, to_tsquery(`)
}
c.renderValueExp(Param{Name: arg.Val, Type: "string"})
io.WriteString(c.w, `))`)
c.md.renderValueExp(c.w, Param{Name: arg.Val, Type: "string"})
_, _ = io.WriteString(c.w, `))`)
alias(c.w, col.Name)
return nil
@ -157,9 +156,9 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
}
c.renderComma(columnsRendered)
io.WriteString(c.w, `(`)
_, _ = io.WriteString(c.w, `(`)
squoted(c.w, ti.Name)
io.WriteString(c.w, ` :: text)`)
_, _ = io.WriteString(c.w, ` :: text)`)
alias(c.w, col.Name)
return nil
@ -169,9 +168,9 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
pl := funcPrefixLen(c.schema.fm, col.Name)
// if pl == 0 {
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
// io.WriteString(c.w, `'`)
// io.WriteString(c.w, col.Name)
// io.WriteString(c.w, ` not defined'`)
// _, _ = io.WriteString(c.w, `'`)
// _, _ = io.WriteString(c.w, col.Name)
// _, _ = io.WriteString(c.w, ` not defined'`)
// alias(c.w, col.Name)
// }
@ -190,10 +189,10 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
c.renderComma(columnsRendered)
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Name, cn, col.Name)
io.WriteString(c.w, fn)
io.WriteString(c.w, `(`)
_, _ = io.WriteString(c.w, fn)
_, _ = io.WriteString(c.w, `(`)
colWithTable(c.w, ti.Name, cn)
io.WriteString(c.w, `)`)
_, _ = io.WriteString(c.w, `)`)
alias(c.w, col.Name)
return nil
@ -201,7 +200,7 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
func (c *compilerContext) renderComma(columnsRendered int) {
if columnsRendered != 0 {
io.WriteString(c.w, `, `)
_, _ = io.WriteString(c.w, `, `)
}
}

View File

@ -25,7 +25,7 @@ func (c *compilerContext) renderInsert(
if insert[0] == '[' {
io.WriteString(c.w, `json_array_elements(`)
}
c.renderValueExp(Param{Name: qc.ActionVar, Type: "json"})
c.md.renderValueExp(c.w, Param{Name: qc.ActionVar, Type: "json"})
io.WriteString(c.w, ` :: json`)
if insert[0] == '[' {
io.WriteString(c.w, `)`)

View File

@ -0,0 +1,61 @@
package psql
import (
"io"
)
func (md *Metadata) RenderVar(w io.Writer, vv string) {
f, s := -1, 0
for i := range vv {
v := vv[i]
switch {
case (i > 0 && vv[i-1] != '\\' && v == '$') || (i == 0 && v == '$'): // unescaped '$' starts a variable name
if (i - s) > 0 {
_, _ = io.WriteString(w, vv[s:i])
}
f = i
case (v < 'a' || v > 'z') &&
(v < 'A' || v > 'Z') &&
(v < '0' || v > '9') &&
v != '_' &&
f != -1 &&
(i-f) > 1:
md.renderValueExp(w, Param{Name: vv[f+1 : i]})
s = i
f = -1
}
}
if f != -1 && (len(vv)-f) > 1 {
md.renderValueExp(w, Param{Name: vv[f+1:]})
} else {
_, _ = io.WriteString(w, vv[s:])
}
}
func (md *Metadata) renderValueExp(w io.Writer, p Param) {
_, _ = io.WriteString(w, `$`)
if v, ok := md.pindex[p.Name]; ok {
int32String(w, int32(v))
} else {
md.params = append(md.params, p)
n := len(md.params)
if md.pindex == nil {
md.pindex = make(map[string]int)
}
md.pindex[p.Name] = n
int32String(w, int32(n))
}
}
func (md Metadata) Skipped() uint32 {
return md.skipped
}
func (md Metadata) Params() []Param {
return md.params
}
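
Editor's note: the mechanics distilled — `renderValueExp` hands each distinct variable name a 1-based positional index and reuses it on later references, which is what lets a single `Metadata` value keep numbering consistent across several compiled statements. A self-contained mirror of the logic above (not the project's code):

package main

import (
	"bytes"
	"fmt"
	"io"
)

type Param struct{ Name string }

type Metadata struct {
	params []Param
	pindex map[string]int
}

// renderValueExp mirrors the method shown above: the first sight of a name
// appends a Param and records its 1-based index; later sights reuse it.
func (md *Metadata) renderValueExp(w io.Writer, p Param) {
	io.WriteString(w, "$")
	if v, ok := md.pindex[p.Name]; ok {
		fmt.Fprint(w, v)
		return
	}
	md.params = append(md.params, p)
	if md.pindex == nil {
		md.pindex = make(map[string]int)
	}
	md.pindex[p.Name] = len(md.params)
	fmt.Fprint(w, len(md.params))
}

func main() {
	w := &bytes.Buffer{}
	md := &Metadata{}
	for _, n := range []string{"user_id", "cursor", "user_id"} {
		md.renderValueExp(w, Param{Name: n})
		w.WriteByte(' ')
	}
	fmt.Println(w.String()) // $1 $2 $1 — repeated names share one index
}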

View File

@ -432,11 +432,11 @@ func (c *compilerContext) renderInsertUpdateColumns(
val := root.PresetMap[cn]
switch {
case ok && len(val) > 1 && val[0] == '$':
c.renderValueExp(Param{Name: val[1:], Type: col.Type})
c.md.renderValueExp(c.w, Param{Name: val[1:], Type: col.Type})
case ok && strings.HasPrefix(val, "sql:"):
io.WriteString(c.w, `(`)
c.renderVar(val[4:], c.renderValueExp)
c.md.RenderVar(c.w, val[4:])
io.WriteString(c.w, `)`)
case ok:

View File

@ -25,8 +25,8 @@ type Param struct {
}
type Metadata struct {
Skipped uint32
Params []Param
skipped uint32
params []Param
pindex map[string]int
}
@ -80,26 +80,30 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (Metadata, []byte
}
func (co *Compiler) Compile(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) {
return co.CompileWithMetadata(w, qc, vars, Metadata{})
}
func (co *Compiler) CompileWithMetadata(w io.Writer, qc *qcode.QCode, vars Variables, md Metadata) (Metadata, error) {
md.skipped = 0
if qc == nil {
return Metadata{}, fmt.Errorf("qcode is nil")
return md, fmt.Errorf("qcode is nil")
}
switch qc.Type {
case qcode.QTQuery:
return co.compileQuery(w, qc, vars)
return co.compileQueryWithMetadata(w, qc, vars, md)
case qcode.QTInsert,
qcode.QTUpdate,
qcode.QTDelete,
qcode.QTUpsert:
return co.compileMutation(w, qc, vars)
default:
return Metadata{}, fmt.Errorf("Unknown operation type %d", qc.Type)
}
return Metadata{}, fmt.Errorf("Unknown operation type %d", qc.Type)
}
func (co *Compiler) compileQuery(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) {
return co.compileQueryWithMetadata(w, qc, vars, Metadata{})
}
func (co *Compiler) compileQueryWithMetadata(
@ -176,7 +180,7 @@ func (co *Compiler) compileQueryWithMetadata(
}
for _, cid := range sel.Children {
if hasBit(c.md.Skipped, uint32(cid)) {
if hasBit(c.md.skipped, uint32(cid)) {
continue
}
child := &c.s[cid]
@ -354,7 +358,7 @@ func (c *compilerContext) initSelect(sel *qcode.Select, ti *DBTableInfo, vars Va
if _, ok := colmap[rel.Left.Col]; !ok {
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Left.Col, FieldName: rel.Right.Col})
colmap[rel.Left.Col] = struct{}{}
c.md.Skipped |= (1 << uint(id))
c.md.skipped |= (1 << uint(id))
}
default:
@ -622,7 +626,7 @@ func (c *compilerContext) renderJoinColumns(sel *qcode.Select, ti *DBTableInfo,
i := colsRendered
for _, id := range sel.Children {
if hasBit(c.md.Skipped, uint32(id)) {
if hasBit(c.md.skipped, uint32(id)) {
continue
}
childSel := &c.s[id]
@ -804,7 +808,7 @@ func (c *compilerContext) renderCursorCTE(sel *qcode.Select) error {
quoted(c.w, ob.Col)
}
io.WriteString(c.w, ` FROM string_to_array(`)
c.renderValueExp(Param{Name: "cursor", Type: "json"})
c.md.renderValueExp(c.w, Param{Name: "cursor", Type: "json"})
io.WriteString(c.w, `, ',') as a) `)
return nil
}
@ -1102,7 +1106,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, ti *DBTableInfo) error {
} else {
io.WriteString(c.w, `) @@ to_tsquery(`)
}
c.renderValueExp(Param{Name: ex.Val, Type: "string"})
c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: "string"})
io.WriteString(c.w, `))`)
return nil
@ -1191,7 +1195,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
switch {
case ok && strings.HasPrefix(val, "sql:"):
io.WriteString(c.w, `(`)
c.renderVar(val[4:], c.renderValueExp)
c.md.RenderVar(c.w, val[4:])
io.WriteString(c.w, `)`)
case ok:
@ -1199,7 +1203,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
case ex.Op == qcode.OpIn || ex.Op == qcode.OpNotIn:
io.WriteString(c.w, `(ARRAY(SELECT json_array_elements_text(`)
c.renderValueExp(Param{Name: ex.Val, Type: col.Type, IsArray: true})
c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: col.Type, IsArray: true})
io.WriteString(c.w, `))`)
io.WriteString(c.w, ` :: `)
@ -1208,7 +1212,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
return
default:
c.renderValueExp(Param{Name: ex.Val, Type: col.Type, IsArray: false})
c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: col.Type, IsArray: false})
}
case qcode.ValRef:
@ -1222,54 +1226,6 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
io.WriteString(c.w, col.Type)
}
func (c *compilerContext) renderValueExp(p Param) {
io.WriteString(c.w, `$`)
if v, ok := c.md.pindex[p.Name]; ok {
int32String(c.w, int32(v))
} else {
c.md.Params = append(c.md.Params, p)
n := len(c.md.Params)
if c.md.pindex == nil {
c.md.pindex = make(map[string]int)
}
c.md.pindex[p.Name] = n
int32String(c.w, int32(n))
}
}
func (c *compilerContext) renderVar(vv string, fn func(Param)) {
f, s := -1, 0
for i := range vv {
v := vv[i]
switch {
case (i > 0 && vv[i-1] != '\\' && v == '$') || v == '$':
if (i - s) > 0 {
io.WriteString(c.w, vv[s:i])
}
f = i
case (v < 'a' && v > 'z') &&
(v < 'A' && v > 'Z') &&
(v < '0' && v > '9') &&
v != '_' &&
f != -1 &&
(i-f) > 1:
fn(Param{Name: vv[f+1 : i]})
s = i
f = -1
}
}
if f != -1 && (len(vv)-f) > 1 {
fn(Param{Name: vv[f+1:]})
} else {
io.WriteString(c.w, vv[s:])
}
}
func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
switch {
case strings.HasPrefix(fn, "avg_"):
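
Editor's note: alongside the rename to the unexported `skipped` field, the skip check itself is unchanged — a child select whose ID bit is set is omitted from the generated SQL and later stitched in by the remote-join code via the new `md.Skipped()` accessor. A small sketch of the bitmask check, with `hasBit` mirroring the helper used above:

package main

import "fmt"

// hasBit mirrors the check in compileQueryWithMetadata: child selects whose
// IDs are set in the skipped bitmask are left out of the inline SQL.
func hasBit(n uint32, pos uint32) bool {
	return n&(1<<pos) != 0
}

func main() {
	var skipped uint32
	skipped |= 1 << uint(2) // initSelect marks remote child #2
	for id := uint32(0); id < 4; id++ {
		fmt.Printf("child %d skipped=%v\n", id, hasBit(skipped, id))
	}
}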

View File

@ -307,6 +307,80 @@ func multiRoot(t *testing.T) {
compileGQLToPSQL(t, gql, nil, "user")
}
func withFragment1(t *testing.T) {
gql := `
fragment userFields1 on user {
id
email
}
query {
users {
...userFields2
created_at
...userFields1
}
}
fragment userFields2 on user {
first_name
last_name
}`
compileGQLToPSQL(t, gql, nil, "anon")
}
func withFragment2(t *testing.T) {
gql := `
query {
users {
...userFields2
created_at
...userFields1
}
}
fragment userFields1 on user {
id
email
}
fragment userFields2 on user {
first_name
last_name
}`
compileGQLToPSQL(t, gql, nil, "anon")
}
func withFragment3(t *testing.T) {
gql := `
fragment userFields1 on user {
id
email
}
fragment userFields2 on user {
first_name
last_name
}
query {
users {
...userFields2
created_at
...userFields1
}
}
`
compileGQLToPSQL(t, gql, nil, "anon")
}
func withCursor(t *testing.T) {
gql := `query {
Products(
@ -400,6 +474,9 @@ func TestCompileQuery(t *testing.T) {
t.Run("queryWithVariables", queryWithVariables)
t.Run("withWhereOnRelations", withWhereOnRelations)
t.Run("multiRoot", multiRoot)
t.Run("withFragment1", withFragment1)
t.Run("withFragment2", withFragment2)
t.Run("withFragment3", withFragment3)
t.Run("jsonColumnAsTable", jsonColumnAsTable)
t.Run("withCursor", withCursor)
t.Run("nullForAuthRequiredInAnon", nullForAuthRequiredInAnon)

View File

@ -86,6 +86,12 @@ SELECT jsonb_build_object('product', "__sj_0"."json") as "__root" FROM (SELECT t
SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")) AND ((("products"."price") > '3' :: numeric(7,2))))) LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"
=== RUN TestCompileQuery/multiRoot
SELECT jsonb_build_object('customer', "__sj_0"."json", 'user', "__sj_1"."json", 'product', "__sj_2"."json") as "__root" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "products_2"."id" AS "id", "products_2"."name" AS "name", "__sj_3"."json" AS "customers", "__sj_4"."json" AS "customer" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE (((("products"."price") > '0' :: numeric(7,2)) AND (("products"."price") < '8' :: numeric(7,2)))) LIMIT ('1') :: integer) AS "products_2" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_4".*) AS "json"FROM (SELECT "customers_4"."email" AS "email" FROM (SELECT "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_2"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('1') :: integer) AS "customers_4") AS "__sr_4") AS "__sj_4" ON ('true') LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_3"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_3".*) AS "json"FROM (SELECT "customers_3"."email" AS "email" FROM (SELECT "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_2"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_3") AS "__sr_3") AS "__sj_3") AS "__sj_3" ON ('true')) AS "__sr_2") AS "__sj_2", (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "users_1"."id" AS "id", "users_1"."email" AS "email" FROM (SELECT "users"."id", "users"."email" FROM "users" LIMIT ('1') :: integer) AS "users_1") AS "__sr_1") AS "__sj_1", (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "customers_0"."id" AS "id" FROM (SELECT "customers"."id" FROM "customers" LIMIT ('1') :: integer) AS "customers_0") AS "__sr_0") AS "__sj_0"
=== RUN TestCompileQuery/withFragment1
SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."first_name" AS "first_name", "users_0"."last_name" AS "last_name", "users_0"."created_at" AS "created_at", "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT , "users"."created_at", "users"."id", "users"."email" FROM "users" GROUP BY "users"."created_at", "users"."id", "users"."email" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"
=== RUN TestCompileQuery/withFragment2
SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."first_name" AS "first_name", "users_0"."last_name" AS "last_name", "users_0"."created_at" AS "created_at", "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT , "users"."created_at", "users"."id", "users"."email" FROM "users" GROUP BY "users"."created_at", "users"."id", "users"."email" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"
=== RUN TestCompileQuery/withFragment3
SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."first_name" AS "first_name", "users_0"."last_name" AS "last_name", "users_0"."created_at" AS "created_at", "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT , "users"."created_at", "users"."id", "users"."email" FROM "users" GROUP BY "users"."created_at", "users"."id", "users"."email" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0"
=== RUN TestCompileQuery/jsonColumnAsTable
SELECT jsonb_build_object('products', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "__sj_1"."json" AS "tag_count" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "tag_count_1"."count" AS "count", "__sj_2"."json" AS "tags" FROM (SELECT "tag_count"."count", "tag_count"."tag_id" FROM "products", json_to_recordset("products"."tag_count") AS "tag_count"(tag_id bigint, count int) WHERE ((("products"."id") = ("products_0"."id"))) LIMIT ('1') :: integer) AS "tag_count_1" LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_2"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "tags_2"."name" AS "name" FROM (SELECT "tags"."name" FROM "tags" WHERE ((("tags"."id") = ("tag_count_1"."tag_id"))) LIMIT ('20') :: integer) AS "tags_2") AS "__sr_2") AS "__sj_2") AS "__sj_2" ON ('true')) AS "__sr_1") AS "__sj_1" ON ('true')) AS "__sr_0") AS "__sj_0") AS "__sj_0"
=== RUN TestCompileQuery/withCursor
@ -117,6 +123,9 @@ SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coa
--- PASS: TestCompileQuery/queryWithVariables (0.00s)
--- PASS: TestCompileQuery/withWhereOnRelations (0.00s)
--- PASS: TestCompileQuery/multiRoot (0.00s)
--- PASS: TestCompileQuery/withFragment1 (0.00s)
--- PASS: TestCompileQuery/withFragment2 (0.00s)
--- PASS: TestCompileQuery/withFragment3 (0.00s)
--- PASS: TestCompileQuery/jsonColumnAsTable (0.00s)
--- PASS: TestCompileQuery/withCursor (0.00s)
--- PASS: TestCompileQuery/nullForAuthRequiredInAnon (0.00s)
@ -151,4 +160,4 @@ WITH "_sg_input" AS (SELECT $1 :: json AS j), "_x_users" AS (SELECT * FROM (VALU
--- PASS: TestCompileUpdate/nestedUpdateOneToOneWithConnect (0.00s)
--- PASS: TestCompileUpdate/nestedUpdateOneToOneWithDisconnect (0.00s)
PASS
ok github.com/dosco/super-graph/core/internal/psql (cached)
ok github.com/dosco/super-graph/core/internal/psql 0.374s

View File

@ -22,7 +22,7 @@ func (c *compilerContext) renderUpdate(
}
io.WriteString(c.w, `WITH "_sg_input" AS (SELECT `)
c.renderValueExp(Param{Name: qc.ActionVar, Type: "json"})
c.md.renderValueExp(c.w, Param{Name: qc.ActionVar, Type: "json"})
// io.WriteString(c.w, qc.ActionVar)
io.WriteString(c.w, ` :: json AS j)`)

View File

@ -11,15 +11,18 @@ import (
var (
queryToken = []byte("query")
mutationToken = []byte("mutation")
fragmentToken = []byte("fragment")
subscriptionToken = []byte("subscription")
onToken = []byte("on")
trueToken = []byte("true")
falseToken = []byte("false")
quotesToken = []byte(`'"`)
signsToken = []byte(`+-`)
punctuatorToken = []byte(`!():=[]{|}`)
spreadToken = []byte(`...`)
digitToken = []byte(`0123456789`)
dotToken = []byte(`.`)
punctuatorToken = `!():=[]{|}`
)
// Pos represents a byte position in the original input text from which
@ -43,6 +46,8 @@ const (
itemName
itemQuery
itemMutation
itemFragment
itemOn
itemSub
itemPunctuator
itemArgsOpen
@ -263,11 +268,11 @@ func lexRoot(l *lexer) stateFn {
l.backup()
return lexString
case r == '.':
if len(l.input) >= 3 {
if equals(l.input, 0, 3, spreadToken) {
l.emit(itemSpread)
return lexRoot
}
l.acceptRun(dotToken)
s, e := l.current()
if equals(l.input, s, e, spreadToken) {
l.emit(itemSpread)
return lexRoot
}
fallthrough // '.' can start a number.
case r == '+' || r == '-' || ('0' <= r && r <= '9'):
@ -299,10 +304,14 @@ func lexName(l *lexer) stateFn {
switch {
case equals(l.input, s, e, queryToken):
l.emitL(itemQuery)
case equals(l.input, s, e, fragmentToken):
l.emitL(itemFragment)
case equals(l.input, s, e, mutationToken):
l.emitL(itemMutation)
case equals(l.input, s, e, subscriptionToken):
l.emitL(itemSub)
case equals(l.input, s, e, onToken):
l.emitL(itemOn)
case equals(l.input, s, e, trueToken):
l.emitL(itemBoolVal)
case equals(l.input, s, e, falseToken):
@ -396,31 +405,11 @@ func isAlphaNumeric(r rune) bool {
}
func equals(b []byte, s Pos, e Pos, val []byte) bool {
n := 0
for i := s; i < e; i++ {
if n >= len(val) {
return true
}
switch {
case b[i] >= 'A' && b[i] <= 'Z' && ('a'+(b[i]-'A')) != val[n]:
return false
case b[i] != val[n]:
return false
}
n++
}
return true
return bytes.EqualFold(b[s:e], val)
}
func contains(b []byte, s Pos, e Pos, val []byte) bool {
for i := s; i < e; i++ {
for n := 0; n < len(val); n++ {
if b[i] == val[n] {
return true
}
}
}
return false
func contains(b []byte, s Pos, e Pos, chars string) bool {
return bytes.ContainsAny(b[s:e], chars)
}
func lowercase(b []byte, s Pos, e Pos) {
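
Editor's note: the hand-rolled loops give way to stdlib calls. The semantics tighten slightly: the old `equals` returned true as soon as either side ran out, behaving like a case-folded prefix check, while `bytes.EqualFold` requires the whole spans to match. A quick demonstration:

package main

import (
	"bytes"
	"fmt"
)

func main() {
	spread := []byte("...")
	fmt.Println(bytes.EqualFold([]byte("..."), spread))  // true
	fmt.Println(bytes.EqualFold([]byte("...."), spread)) // false: lengths must match now
	fmt.Println(bytes.ContainsAny([]byte("a{b"), `!():=[]{|}`)) // true: '{' is a punctuator
}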

View File

@ -1,12 +1,12 @@
package qcode
import (
"encoding/binary"
"errors"
"fmt"
"hash/maphash"
"sync"
"unsafe"
"github.com/dosco/super-graph/core/internal/util"
)
var (
@ -35,8 +35,7 @@ const (
NodeVar
)
type Operation struct {
Type parserType
type SelectionSet struct {
Name string
Args []Arg
argsA [10]Arg
@ -44,12 +43,29 @@ type Operation struct {
fieldsA [10]Field
}
type Operation struct {
Type parserType
SelectionSet
}
var zeroOperation = Operation{}
func (o *Operation) Reset() {
*o = zeroOperation
}
type Fragment struct {
Name string
On string
SelectionSet
}
var zeroFragment = Fragment{}
func (f *Fragment) Reset() {
*f = zeroFragment
}
type Field struct {
ID int32
ParentID int32
@ -82,6 +98,8 @@ func (n *Node) Reset() {
}
type Parser struct {
frags map[uint64]*Fragment
h maphash.Hash
input []byte // the string being scanned
pos int
items []item
@ -96,12 +114,194 @@ var opPool = sync.Pool{
New: func() interface{} { return new(Operation) },
}
var fragPool = sync.Pool{
New: func() interface{} { return new(Fragment) },
}
var lexPool = sync.Pool{
New: func() interface{} { return new(lexer) },
}
func Parse(gql []byte) (*Operation, error) {
return parseSelectionSet(gql)
var err error
if len(gql) == 0 {
return nil, errors.New("blank query")
}
l := lexPool.Get().(*lexer)
l.Reset()
defer lexPool.Put(l)
if err = lex(l, gql); err != nil {
return nil, err
}
p := &Parser{
input: l.input,
pos: -1,
items: l.items,
}
op := opPool.Get().(*Operation)
op.Reset()
op.Fields = op.fieldsA[:0]
s := -1
qf := false
for {
if p.peek(itemEOF) {
p.ignore()
break
}
if p.peek(itemFragment) {
p.ignore()
if err = p.parseFragment(op); err != nil {
return nil, err
}
} else {
if !qf && p.peek(itemQuery, itemMutation, itemSub, itemObjOpen) {
s = p.pos
qf = true
}
p.ignore()
}
}
p.reset(s)
if err := p.parseOp(op); err != nil {
return nil, err
}
return op, nil
}
func (p *Parser) parseFragment(op *Operation) error {
frag := fragPool.Get().(*Fragment)
frag.Reset()
frag.Fields = frag.fieldsA[:0]
frag.Args = frag.argsA[:0]
if p.peek(itemName) {
frag.Name = p.val(p.next())
}
if p.peek(itemOn) {
p.ignore()
} else {
return errors.New("fragment: missing 'on' keyword")
}
if p.peek(itemName) {
frag.On = p.vall(p.next())
} else {
return errors.New("fragment: missing table name after 'on' keyword")
}
if p.peek(itemObjOpen) {
p.ignore()
} else {
return fmt.Errorf("fragment: expecting a '{', got: %s", p.next())
}
if err := p.parseSelectionSet(&frag.SelectionSet); err != nil {
return fmt.Errorf("fragment: %v", err)
}
if p.frags == nil {
p.frags = make(map[uint64]*Fragment)
}
_, _ = p.h.WriteString(frag.Name)
k := p.h.Sum64()
p.h.Reset()
p.frags[k] = frag
return nil
}
func (p *Parser) parseOp(op *Operation) error {
var err error
var typeSet bool
if p.peek(itemQuery, itemMutation, itemSub) {
err = p.parseOpTypeAndArgs(op)
if err != nil {
return fmt.Errorf("%s: %v", op.Type, err)
}
typeSet = true
}
if p.peek(itemObjOpen) {
p.ignore()
if !typeSet {
op.Type = opQuery
}
for {
if p.peek(itemEOF, itemFragment) {
p.ignore()
break
}
err = p.parseSelectionSet(&op.SelectionSet)
if err != nil {
return fmt.Errorf("%s: %v", op.Type, err)
}
}
} else {
return fmt.Errorf("expecting a query, mutation or subscription, got: %s", p.next())
}
return nil
}
func (p *Parser) parseOpTypeAndArgs(op *Operation) error {
item := p.next()
switch item._type {
case itemQuery:
op.Type = opQuery
case itemMutation:
op.Type = opMutate
case itemSub:
op.Type = opSub
}
op.Args = op.argsA[:0]
var err error
if p.peek(itemName) {
op.Name = p.val(p.next())
}
if p.peek(itemArgsOpen) {
p.ignore()
op.Args, err = p.parseOpParams(op.Args)
if err != nil {
return err
}
}
return nil
}
func (p *Parser) parseSelectionSet(selset *SelectionSet) error {
var err error
selset.Fields, err = p.parseFields(selset.Fields)
if err != nil {
return err
}
return nil
}
func ParseArgValue(argVal string) (*Node, error) {
@ -123,216 +323,137 @@ func ParseArgValue(argVal string) (*Node, error) {
return op, err
}
func parseSelectionSet(gql []byte) (*Operation, error) {
var err error
if len(gql) == 0 {
return nil, errors.New("blank query")
}
l := lexPool.Get().(*lexer)
l.Reset()
if err = lex(l, gql); err != nil {
return nil, err
}
p := &Parser{
input: l.input,
pos: -1,
items: l.items,
}
var op *Operation
if p.peek(itemObjOpen) {
p.ignore()
op, err = p.parseQueryOp()
} else {
op, err = p.parseOp()
}
if err != nil {
return nil, err
}
if p.peek(itemObjClose) {
p.ignore()
} else {
return nil, fmt.Errorf("operation missing closing '}'")
}
if !p.peek(itemEOF) {
p.ignore()
return nil, fmt.Errorf("invalid '%s' found after closing '}'", p.current())
}
lexPool.Put(l)
return op, err
}
func (p *Parser) next() item {
n := p.pos + 1
if n >= len(p.items) {
p.err = errEOT
return item{_type: itemEOF}
}
p.pos = n
return p.items[p.pos]
}
func (p *Parser) ignore() {
n := p.pos + 1
if n >= len(p.items) {
p.err = errEOT
return
}
p.pos = n
}
func (p *Parser) current() string {
item := p.items[p.pos]
return b2s(p.input[item.pos:item.end])
}
func (p *Parser) peek(types ...itemType) bool {
n := p.pos + 1
// if p.items[n]._type == itemEOF {
// return false
// }
if n >= len(p.items) {
return false
}
for i := 0; i < len(types); i++ {
if p.items[n]._type == types[i] {
return true
}
}
return false
}
func (p *Parser) parseOp() (*Operation, error) {
if !p.peek(itemQuery, itemMutation, itemSub) {
err := errors.New("expecting a query, mutation or subscription")
return nil, err
}
item := p.next()
op := opPool.Get().(*Operation)
op.Reset()
switch item._type {
case itemQuery:
op.Type = opQuery
case itemMutation:
op.Type = opMutate
case itemSub:
op.Type = opSub
}
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
var err error
if p.peek(itemName) {
op.Name = p.val(p.next())
}
if p.peek(itemArgsOpen) {
p.ignore()
op.Args, err = p.parseOpParams(op.Args)
if err != nil {
return nil, err
}
}
if p.peek(itemObjOpen) {
p.ignore()
for n := 0; n < 10; n++ {
if !p.peek(itemName) {
break
}
op.Fields, err = p.parseFields(op.Fields)
if err != nil {
return nil, err
}
}
}
return op, nil
}
func (p *Parser) parseQueryOp() (*Operation, error) {
op := opPool.Get().(*Operation)
op.Reset()
op.Type = opQuery
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
var err error
for n := 0; n < 10; n++ {
if !p.peek(itemName) {
break
}
op.Fields, err = p.parseFields(op.Fields)
if err != nil {
return nil, err
}
}
return op, nil
}
func (p *Parser) parseFields(fields []Field) ([]Field, error) {
st := util.NewStack()
st := NewStack()
if !p.peek(itemName, itemSpread) {
return nil, fmt.Errorf("unexpected token: %s", p.peekNext())
}
// fm := make(map[uint64]struct{})
for {
if p.peek(itemEOF) {
p.ignore()
return nil, errors.New("invalid query")
}
if p.peek(itemObjClose) {
p.ignore()
if st.Len() != 0 {
st.Pop()
continue
} else {
break
}
}
if len(fields) >= maxFields {
return nil, fmt.Errorf("too many fields (max %d)", maxFields)
}
if p.peek(itemEOF, itemObjClose) {
p.ignore()
st.Pop()
isFrag := false
if st.Len() == 0 {
break
} else {
continue
}
if p.peek(itemSpread) {
p.ignore()
isFrag = true
}
if !p.peek(itemName) {
return nil, errors.New("expecting an alias or field name")
if isFrag {
return nil, fmt.Errorf("expecting a fragment name, got: %s", p.next())
} else {
return nil, fmt.Errorf("expecting an alias or field name, got: %s", p.next())
}
}
fields = append(fields, Field{ID: int32(len(fields))})
var f *Field
f := &fields[(len(fields) - 1)]
f.Args = f.argsA[:0]
f.Children = f.childrenA[:0]
if isFrag {
name := p.val(p.next())
p.h.WriteString(name)
k := p.h.Sum64()
p.h.Reset()
// Parse the inside of the field's () parentheses,
// in short, parse the args like id, where, etc.
if err := p.parseField(f); err != nil {
return nil, err
}
fr, ok := p.frags[k]
if !ok {
return nil, fmt.Errorf("no fragment named '%s' defined", name)
}
n := int32(len(fields))
fields = append(fields, fr.Fields...)
for i := int(n); i < len(fields); i++ {
f := &fields[i]
f.ID = int32(i)
// var name string
// if f.Alias != "" {
// name = f.Alias
// } else {
// name = f.Name
// }
// if _, ok := fm[name]; ok {
// continue
// } else {
// fm[name] = struct{}{}
// }
// If this is the top level, point the parent to the parent of the
// previous field.
if f.ParentID == -1 {
pid := st.Peek()
f.ParentID = pid
if f.ParentID != -1 {
fields[pid].Children = append(fields[f.ParentID].Children, f.ID)
}
// Update all the other parents' IDs to our new place in this array
} else {
f.ParentID += n
}
// Update all the children as well, since their indices shifted.
for j := range f.Children {
f.Children[j] += n
}
}
intf := st.Peek()
if pid, ok := intf.(int32); ok {
f.ParentID = pid
fields[pid].Children = append(fields[pid].Children, f.ID)
} else {
f.ParentID = -1
fields = append(fields, Field{ID: int32(len(fields))})
f = &fields[(len(fields) - 1)]
f.Args = f.argsA[:0]
f.Children = f.childrenA[:0]
// Parse the field
if err := p.parseField(f); err != nil {
return nil, err
}
// var name string
// if f.Alias != "" {
// name = f.Alias
// } else {
// name = f.Name
// }
// if _, ok := fm[name]; ok {
// continue
// } else {
// fm[name] = struct{}{}
// }
if st.Len() == 0 {
f.ParentID = -1
} else {
pid := st.Peek()
f.ParentID = pid
fields[pid].Children = append(fields[pid].Children, f.ID)
}
}
// The first opening curly bracket after this
@ -340,13 +461,6 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) {
if p.peek(itemObjOpen) {
p.ignore()
st.Push(f.ID)
} else if p.peek(itemObjClose) {
if st.Len() == 0 {
break
} else {
continue
}
}
}
@ -546,6 +660,72 @@ func (p *Parser) vall(v item) string {
return b2s(p.input[v.pos:v.end])
}
func (p *Parser) peek(types ...itemType) bool {
n := p.pos + 1
l := len(types)
// if p.items[n]._type == itemEOF {
// return false
// }
if n >= len(p.items) {
return types[0] == itemEOF
}
if l == 1 {
return p.items[n]._type == types[0]
}
for i := 0; i < l; i++ {
if p.items[n]._type == types[i] {
return true
}
}
return false
}
func (p *Parser) next() item {
n := p.pos + 1
if n >= len(p.items) {
p.err = errEOT
return item{_type: itemEOF}
}
p.pos = n
return p.items[p.pos]
}
func (p *Parser) ignore() {
n := p.pos + 1
if n >= len(p.items) {
p.err = errEOT
return
}
p.pos = n
}
func (p *Parser) peekCurrent() string {
item := p.items[p.pos]
return b2s(p.input[item.pos:item.end])
}
func (p *Parser) peekNext() string {
item := p.items[p.pos+1]
return b2s(p.input[item.pos:item.end])
}
func (p *Parser) reset(to int) {
p.pos = to
}
func (p *Parser) fHash(name string, parentID int32) uint64 {
b := make([]byte, 4) // PutUint32 needs a 4-byte buffer; a nil slice would panic
binary.LittleEndian.PutUint32(b, uint32(parentID))
p.h.WriteString(name)
p.h.Write(b)
v := p.h.Sum64()
p.h.Reset()
return v
}
func b2s(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
@ -579,7 +759,7 @@ func (t parserType) String() string {
case NodeList:
v = "node-list"
}
return fmt.Sprintf("<%s>", v)
return v
}
// type Frees struct {
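
Editor's note: the core of the fragment support is the splice in `parseFields` — a resolved spread appends the fragment's pre-parsed fields at offset `n`, gives each copy its new slice index, attaches top-level copies to the select on top of the stack, and shifts nested `ParentID`/`Children` references by `n`. A self-contained mirror of that remapping (simplified: the parent is passed in rather than peeked from the stack):

package main

import "fmt"

type Field struct {
	ID, ParentID int32
	Name         string
	Children     []int32
}

// splice mirrors the remapping in parseFields: append the fragment's fields
// at offset n, rebase IDs/ParentIDs/Children, and hang roots off parent pid.
func splice(fields []Field, frag []Field, pid int32) []Field {
	n := int32(len(fields))
	fields = append(fields, frag...)
	for i := int(n); i < len(fields); i++ {
		f := &fields[i]
		f.ID = int32(i)
		if f.ParentID == -1 {
			f.ParentID = pid
			if pid != -1 {
				fields[pid].Children = append(fields[pid].Children, f.ID)
			}
		} else {
			f.ParentID += n
		}
		for j := range f.Children {
			f.Children[j] += n
		}
	}
	return fields
}

func main() {
	query := []Field{{ID: 0, ParentID: -1, Name: "users"}}
	frag := []Field{{ID: 0, ParentID: -1, Name: "id"}, {ID: 1, ParentID: -1, Name: "email"}}
	for _, f := range splice(query, frag, 0) {
		fmt.Printf("%s id=%d parent=%d children=%v\n", f.Name, f.ID, f.ParentID, f.Children)
	}
}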

View File

@ -2,8 +2,9 @@ package qcode
import (
"errors"
"github.com/chirino/graphql/schema"
"testing"
"github.com/chirino/graphql/schema"
)
func TestCompile1(t *testing.T) {
@ -120,7 +121,7 @@ updateThread {
}
}
}
}`
}}`
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(gql), "anon")
@ -130,6 +131,93 @@ updateThread {
}
func TestFragmentsCompile1(t *testing.T) {
gql := `
fragment userFields1 on user {
id
email
}
query {
users {
...userFields2
created_at
...userFields1
}
}
fragment userFields2 on user {
first_name
last_name
}
`
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(gql), "user")
if err != nil {
t.Fatal(err)
}
}
func TestFragmentsCompile2(t *testing.T) {
gql := `
query {
users {
...userFields2
created_at
...userFields1
}
}
fragment userFields1 on user {
id
email
}
fragment userFields2 on user {
first_name
last_name
}`
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(gql), "user")
if err != nil {
t.Fatal(err)
}
}
func TestFragmentsCompile3(t *testing.T) {
gql := `
fragment userFields1 on user {
id
email
}
fragment userFields2 on user {
first_name
last_name
}
query {
users {
...userFields2
created_at
...userFields1
}
}
`
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(gql), "user")
if err != nil {
t.Fatal(err)
}
}
var gql = []byte(`
{products(
# returns only 30 items
@ -151,6 +239,29 @@ var gql = []byte(`
price
}}`)
var gqlWithFragments = []byte(`
fragment userFields1 on user {
id
email
__typename
}
query {
users {
...userFields2
created_at
...userFields1
__typename
}
}
fragment userFields2 on user {
first_name
last_name
__typename
}`)
func BenchmarkQCompile(b *testing.B) {
qcompile, _ := NewCompiler(Config{})
@ -183,8 +294,22 @@ func BenchmarkQCompileP(b *testing.B) {
})
}
func BenchmarkParse(b *testing.B) {
func BenchmarkQCompileFragment(b *testing.B) {
qcompile, _ := NewCompiler(Config{})
b.ResetTimer()
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_, err := qcompile.Compile(gqlWithFragments, "user")
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkParse(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for n := 0; n < b.N; n++ {
@ -211,6 +336,18 @@ func BenchmarkParseP(b *testing.B) {
})
}
func BenchmarkParseFragment(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_, err := Parse(gqlWithFragments)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSchemaParse(b *testing.B) {
b.ResetTimer()

View File

@ -419,6 +419,7 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
com.AddFilters(qc, s, role)
s.Cols = make([]Column, 0, len(field.Children))
cm := make(map[string]struct{})
action = QTQuery
for _, cid := range field.Children {
@ -428,19 +429,28 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
continue
}
var fname string
if f.Alias != "" {
fname = f.Alias
} else {
fname = f.Name
}
if _, ok := cm[fname]; ok {
continue
} else {
cm[fname] = struct{}{}
}
if len(f.Children) != 0 {
val := f.ID | (s.ID << 16)
st.Push(val)
continue
}
col := Column{Name: f.Name}
col := Column{Name: f.Name, FieldName: fname}
if len(f.Alias) != 0 {
col.FieldName = f.Alias
} else {
col.FieldName = f.Name
}
s.Cols = append(s.Cols, col)
}
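
Editor's note: the compiler now dedupes columns by their output field name (alias if present, else the column name), which matters once overlapping fragments can contribute the same field twice. A toy run of the same bookkeeping:

package main

import "fmt"

type Column struct{ Name, FieldName string }

func main() {
	fields := []struct{ Name, Alias string }{
		{"created_at", "createdAt"}, {"id", ""}, {"id", ""}, // duplicate from overlapping fragments
	}
	cm := map[string]struct{}{}
	var cols []Column
	for _, f := range fields {
		fname := f.Name
		if f.Alias != "" {
			fname = f.Alias
		}
		if _, ok := cm[fname]; ok {
			continue // already rendered under this output name
		}
		cm[fname] = struct{}{}
		cols = append(cols, Column{Name: f.Name, FieldName: fname})
	}
	fmt.Println(cols) // [{created_at createdAt} {id id}]
}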

View File

@ -28,17 +28,18 @@ func (sg *SuperGraph) prepare(q *query, role string) {
var err error
qb := []byte(q.ai.Query)
vars := []byte(q.ai.Vars)
switch q.qt {
case qcode.QTQuery:
if sg.abacEnabled {
stmts, err = sg.buildMultiStmt(qb, q.ai.Vars)
stmts, err = sg.buildMultiStmt(qb, vars)
} else {
stmts, err = sg.buildRoleStmt(qb, q.ai.Vars, role)
stmts, err = sg.buildRoleStmt(qb, vars, role)
}
case qcode.QTMutation:
stmts, err = sg.buildRoleStmt(qb, q.ai.Vars, role)
stmts, err = sg.buildRoleStmt(qb, vars, role)
}
if err != nil {
@ -125,7 +126,7 @@ func (sg *SuperGraph) prepareRoleStmt() error {
}
io.WriteString(w, ` ELSE $2 END) FROM (`)
io.WriteString(w, sg.conf.RolesQuery)
io.WriteString(w, rq)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `)

View File

@ -22,7 +22,7 @@ func (sg *SuperGraph) execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]
// fetch the field names used within the db response json
// that are used to mark insertion points and the mapping between
// those field names and their select objects
fids, sfmap := sg.parentFieldIds(&h, sel, st.md.Skipped)
fids, sfmap := sg.parentFieldIds(&h, sel, st.md.Skipped())
// fetch the field values of the marked insertion points
// these values contain the id to be used with fetching remote data
@ -67,7 +67,7 @@ func (sg *SuperGraph) resolveRemote(
to := toA[:1]
// use the json key to find the related Select object
h.Write(field.Key)
_, _ = h.Write(field.Key)
k1 := h.Sum64()
s, ok := sfmap[k1]
@ -136,7 +136,7 @@ func (sg *SuperGraph) resolveRemotes(
for i, id := range from {
// use the json key to find the related Select object
h.Write(id.Key)
_, _ = h.Write(id.Key)
k1 := h.Sum64()
s, ok := sfmap[k1]
@ -230,7 +230,7 @@ func (sg *SuperGraph) parentFieldIds(h *maphash.Hash, sel []qcode.Select, skippe
fm[n] = r.IDField
n++
h.Write(r.IDField)
_, _ = h.Write(r.IDField)
sm[h.Sum64()] = s
}
}

View File

@ -86,7 +86,7 @@ func (sg *SuperGraph) initRemotes(t Table) error {
sg.rmap[mkkey(&h, r.Name, t.Name)] = rf
// index resolver obj by IDField
h.Write(rf.IDField)
_, _ = h.Write(rf.IDField)
sg.rmap[h.Sum64()] = rf
}

View File

@ -66,7 +66,7 @@ func newViper(configPath, configFile string) *viper.Viper {
vi.SetDefault("host_port", "0.0.0.0:8080")
vi.SetDefault("web_ui", false)
vi.SetDefault("enable_tracing", false)
vi.SetDefault("auth_fail_block", "always")
vi.SetDefault("auth_fail_block", false)
vi.SetDefault("seed_file", "seed.js")
vi.SetDefault("default_block", true)

View File

@ -32,6 +32,7 @@ type Auth struct {
Secret string
PubKeyFile string `mapstructure:"public_key_file"`
PubKeyType string `mapstructure:"public_key_type"`
Audience string `mapstructure:"audience"`
}
Header struct {

View File

@ -2,19 +2,32 @@ package auth
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
"strconv"
"strings"
"time"
jwt "github.com/dgrijalva/jwt-go"
"github.com/dosco/super-graph/core"
)
const (
authHeader = "Authorization"
jwtAuth0 int = iota + 1
authHeader = "Authorization"
jwtAuth0 int = iota + 1
jwtFirebase int = iota + 2
firebasePKEndpoint = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"
firebaseIssuerPrefix = "https://securetoken.google.com/"
)
type firebasePKCache struct {
PublicKeys map[string]string
Expiration time.Time
}
var firebasePublicKeys firebasePKCache
func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
var key interface{}
var jwtProvider int
@ -23,6 +36,8 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
if ac.JWT.Provider == "auth0" {
jwtProvider = jwtAuth0
} else if ac.JWT.Provider == "firebase" {
jwtProvider = jwtFirebase
}
secret := ac.JWT.Secret
@ -56,6 +71,7 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
}
return func(w http.ResponseWriter, r *http.Request) {
var tok string
if len(cookie) != 0 {
@ -74,9 +90,16 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
tok = ah[7:]
}
token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
return key, nil
})
var keyFunc jwt.Keyfunc
if jwtProvider == jwtFirebase {
keyFunc = firebaseKeyFunction
} else {
keyFunc = func(token *jwt.Token) (interface{}, error) {
return key, nil
}
}
token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, keyFunc)
if err != nil {
next.ServeHTTP(w, r)
@ -86,12 +109,20 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
if claims, ok := token.Claims.(*jwt.StandardClaims); ok {
ctx := r.Context()
if ac.JWT.Audience != "" && claims.Audience != ac.JWT.Audience {
next.ServeHTTP(w, r)
return
}
if jwtProvider == jwtAuth0 {
sub := strings.Split(claims.Subject, "|")
if len(sub) == 2 {
ctx = context.WithValue(ctx, core.UserIDProviderKey, sub[0])
ctx = context.WithValue(ctx, core.UserIDKey, sub[1])
}
} else if jwtProvider == jwtFirebase &&
claims.Issuer == firebaseIssuerPrefix+ac.JWT.Audience {
ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
} else {
ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
}
@ -103,3 +134,92 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
next.ServeHTTP(w, r)
}, nil
}
type firebaseKeyError struct {
Err error
Message string
}
func (e *firebaseKeyError) Error() string {
return e.Message + " " + e.Err.Error()
}
func firebaseKeyFunction(token *jwt.Token) (interface{}, error) {
kid, ok := token.Header["kid"]
if !ok {
return nil, &firebaseKeyError{
Message: "Error 'kid' header not found in token",
}
}
if firebasePublicKeys.Expiration.Before(time.Now()) {
resp, err := http.Get(firebasePKEndpoint)
if err != nil {
return nil, &firebaseKeyError{
Message: "Error connecting to firebase certificate server",
Err: err,
}
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, &firebaseKeyError{
Message: "Error reading firebase certificate server response",
Err: err,
}
}
cachePolicy := resp.Header.Get("cache-control")
ageIndex := strings.Index(cachePolicy, "max-age=")
if ageIndex < 0 {
return nil, &firebaseKeyError{
Message: "Error parsing cache-control header: 'max-age=' not found",
}
}
ageToEnd := cachePolicy[ageIndex+8:]
endIndex := strings.Index(ageToEnd, ",")
if endIndex < 0 {
endIndex = len(ageToEnd)
}
ageString := ageToEnd[:endIndex]
age, err := strconv.ParseInt(ageString, 10, 64)
if err != nil {
return nil, &firebaseKeyError{
Message: "Error parsing max-age cache policy",
Err: err,
}
}
expiration := time.Now().Add(time.Duration(age) * time.Second)
err = json.Unmarshal(data, &firebasePublicKeys.PublicKeys)
if err != nil {
firebasePublicKeys = firebasePKCache{}
return nil, &firebaseKeyError{
Message: "Error unmarshalling firebase public key json",
Err: err,
}
}
firebasePublicKeys.Expiration = expiration
}
if key, found := firebasePublicKeys.PublicKeys[kid.(string)]; found {
k, err := jwt.ParseRSAPublicKeyFromPEM([]byte(key))
return k, err
}
return nil, &firebaseKeyError{
Message: "Error no matching public key for kid supplied in jwt",
}
}
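
Editor's note: firebaseKeyFunction fetches Google's x509 certificates on demand, caches them until the `max-age` from the response's cache-control header lapses, and returns the PEM-parsed key matching the token's `kid` header. The cache-expiry arithmetic, distilled into a runnable sketch (the helper name is ours, not the project's):

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseMaxAge mirrors the cache-control handling above: find "max-age=",
// take the digits up to the next comma (or end of string), parse seconds.
func parseMaxAge(cachePolicy string) (int64, error) {
	i := strings.Index(cachePolicy, "max-age=")
	if i < 0 {
		return 0, fmt.Errorf("max-age not found in %q", cachePolicy)
	}
	rest := cachePolicy[i+len("max-age="):]
	if j := strings.Index(rest, ","); j >= 0 {
		rest = rest[:j]
	}
	return strconv.ParseInt(rest, 10, 64)
}

func main() {
	age, err := parseMaxAge("public, max-age=24406, must-revalidate, no-transform")
	fmt.Println(age, err) // 24406 <nil>
}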

View File

@ -12,7 +12,7 @@ func Filter(w *bytes.Buffer, b []byte, keys []string) error {
h := maphash.Hash{}
for i := range keys {
h.WriteString(keys[i])
_, _ = h.WriteString(keys[i])
kmap[h.Sum64()] = struct{}{}
h.Reset()
}
@ -134,7 +134,7 @@ func Filter(w *bytes.Buffer, b []byte, keys []string) error {
cb := b[s:(e + 1)]
e = 0
h.Write(k)
_, _ = h.Write(k)
_, ok := kmap[h.Sum64()]
h.Reset()

View File

@ -44,7 +44,7 @@ func Get(b []byte, keys [][]byte) []Field {
h := maphash.Hash{}
for i := range keys {
h.Write(keys[i])
_, _ = h.Write(keys[i])
kmap[h.Sum64()] = struct{}{}
h.Reset()
}
@ -144,7 +144,7 @@ func Get(b []byte, keys [][]byte) []Field {
}
if e != 0 {
h.Write(k)
_, _ = h.Write(k)
_, ok := kmap[h.Sum64()]
h.Reset()