diff --git a/core/config.go b/core/config.go index 7e60c58..9284c0b 100644 --- a/core/config.go +++ b/core/config.go @@ -197,30 +197,26 @@ func (c *Config) AddRoleTable(role string, table string, conf interface{}) error // ReadInConfig function reads in the config file for the environment specified in the GO_ENV // environment variable. This is the best way to create a new Super Graph config. func ReadInConfig(configFile string) (*Config, error) { - cpath := path.Dir(configFile) - cfile := path.Base(configFile) - vi := newViper(cpath, cfile) + cp := path.Dir(configFile) + vi := newViper(cp, path.Base(configFile)) if err := vi.ReadInConfig(); err != nil { return nil, err } - inherits := vi.GetString("inherits") - - if inherits != "" { - vi = newViper(cpath, inherits) + if pcf := vi.GetString("inherits"); pcf != "" { + cf := vi.ConfigFileUsed() + vi = newViper(cp, pcf) if err := vi.ReadInConfig(); err != nil { return nil, err } - if vi.IsSet("inherits") { - return nil, fmt.Errorf("inherited config (%s) cannot itself inherit (%s)", - inherits, - vi.GetString("inherits")) + if v := vi.GetString("inherits"); v != "" { + return nil, fmt.Errorf("inherited config (%s) cannot itself inherit (%s)", pcf, v) } - vi.SetConfigName(cfile) + vi.SetConfigFile(cf) if err := vi.MergeInConfig(); err != nil { return nil, err @@ -234,7 +230,7 @@ func ReadInConfig(configFile string) (*Config, error) { } if c.AllowListFile == "" { - c.AllowListFile = path.Join(cpath, "allow.list") + c.AllowListFile = path.Join(cp, "allow.list") } return c, nil @@ -248,7 +244,7 @@ func newViper(configPath, configFile string) *viper.Viper { vi.AutomaticEnv() if filepath.Ext(configFile) != "" { - vi.SetConfigFile(configFile) + vi.SetConfigFile(path.Join(configPath, configFile)) } else { vi.SetConfigName(configFile) vi.AddConfigPath(configPath) diff --git a/core/internal/allow/allow.go b/core/internal/allow/allow.go index 2121b54..0e0c821 100644 --- a/core/internal/allow/allow.go +++ b/core/internal/allow/allow.go @@ -10,21 +10,23 @@ import ( "os" "sort" "strings" + "text/scanner" "github.com/chirino/graphql/schema" "github.com/dosco/super-graph/jsn" ) const ( - AL_QUERY int = iota + 1 - AL_VARS + expComment = iota + 1 + expVar + expQuery ) type Item struct { Name string key string Query string - Vars json.RawMessage + Vars string Comment string } @@ -126,121 +128,101 @@ func (al *List) Set(vars []byte, query, comment string) error { return errors.New("empty query") } - var q string - - for i := 0; i < len(query); i++ { - c := query[i] - if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' { - q = query - break - - } else if c == '{' { - q = "query " + query - break - } - } - al.saveChan <- Item{ Comment: comment, - Query: q, - Vars: vars, + Query: query, + Vars: string(vars), } return nil } func (al *List) Load() ([]Item, error) { - var list []Item - varString := "variables" - b, err := ioutil.ReadFile(al.filepath) if err != nil { - return list, err + return nil, err } - if len(b) == 0 { - return list, nil + return parse(string(b), al.filepath) +} + +func parse(b string, filename string) ([]Item, error) { + var items []Item + + var s scanner.Scanner + s.Init(strings.NewReader(b)) + s.Filename = filename + s.Mode ^= scanner.SkipComments + + var op, sp scanner.Position + var item Item + + newComment := false + st := expComment + + for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() { + txt := s.TokenText() + + switch { + case strings.HasPrefix(txt, "/*"): + if st == expQuery { + v := b[sp.Offset:s.Pos().Offset] + item.Query = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1]) + items = append(items, item) + } + item = Item{Comment: strings.TrimSpace(txt[2 : len(txt)-2])} + sp = s.Pos() + st = expComment + newComment = true + + case !newComment && strings.HasPrefix(txt, "#"): + if st == expQuery { + v := b[sp.Offset:s.Pos().Offset] + item.Query = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1]) + items = append(items, item) + } + item = Item{} + sp = s.Pos() + st = expComment + + case strings.HasPrefix(txt, "variables"): + if st == expComment { + v := b[sp.Offset:s.Pos().Offset] + item.Comment = strings.TrimSpace(v[:strings.IndexByte(v, '\n')]) + } + sp = s.Pos() + st = expVar + + case isGraphQL(txt): + if st == expVar { + v := b[sp.Offset:s.Pos().Offset] + item.Vars = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1]) + } + sp = op + st = expQuery + + } + op = s.Pos() } - var comment bytes.Buffer - var varBytes []byte - - itemMap := make(map[string]struct{}) - - s, e, c := 0, 0, 0 - ty := 0 - - for { - fq := false - - if c == 0 && b[e] == '#' { - s = e - for e < len(b) && b[e] != '\n' { - e++ - } - if (e - s) > 2 { - comment.Write(b[(s + 1):(e + 1)]) - } - } - - if e >= len(b) { - break - } - - if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") { - if c == 0 { - s = e - } - ty = AL_QUERY - } else if matchPrefix(b, e, varString) { - if c == 0 { - s = e + len(varString) + 1 - } - ty = AL_VARS - } else if b[e] == '{' { - c++ - - } else if b[e] == '}' { - c-- - - if c == 0 { - if ty == AL_QUERY { - fq = true - } else if ty == AL_VARS { - varBytes = b[s:(e + 1)] - } - ty = 0 - } - } - - if fq { - query := string(b[s:(e + 1)]) - name := QueryName(query) - key := strings.ToLower(name) - - if _, ok := itemMap[key]; !ok { - v := Item{ - Name: name, - key: key, - Query: query, - Vars: varBytes, - Comment: comment.String(), - } - list = append(list, v) - comment.Reset() - } - - varBytes = nil - - } - - e++ - if e >= len(b) { - break - } + if st == expQuery { + v := b[sp.Offset:s.Pos().Offset] + item.Query = strings.TrimSpace(v[:strings.LastIndexByte(v, '}')+1]) + items = append(items, item) } - return list, nil + for i := range items { + items[i].Name = QueryName(items[i].Query) + items[i].key = strings.ToLower(items[i].Name) + } + + return items, nil +} + +func isGraphQL(s string) bool { + return strings.HasPrefix(s, "query") || + strings.HasPrefix(s, "mutation") || + strings.HasPrefix(s, "subscription") } func (al *List) save(item Item) error { @@ -297,57 +279,39 @@ func (al *List) save(item Item) error { return strings.Compare(list[i].key, list[j].key) == -1 }) - for _, v := range list { - cmtLines := strings.Split(v.Comment, "\n") - - i := 0 - for _, c := range cmtLines { - if c = strings.TrimSpace(c); c == "" { - continue - } - - _, err := f.WriteString(fmt.Sprintf("# %s\n", c)) - if err != nil { - return err - } - i++ - } - - if i != 0 { - if _, err := f.WriteString("\n"); err != nil { - return err - } - } else { - if _, err := f.WriteString(fmt.Sprintf("# Query named %s\n\n", v.Name)); err != nil { - return err - } - } - - if len(v.Vars) != 0 && !bytes.Equal(v.Vars, []byte("{}")) { + for i, v := range list { + var vars string + if v.Vars != "" { buf.Reset() - - if err := jsn.Clear(&buf, v.Vars); err != nil { - return fmt.Errorf("failed to clean vars: %w", err) + if err := jsn.Clear(&buf, []byte(v.Vars)); err != nil { + continue } vj := json.RawMessage(buf.Bytes()) - vj, err = json.MarshalIndent(vj, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal vars: %w", err) + if vj, err = json.MarshalIndent(vj, "", " "); err != nil { + continue } + vars = string(vj) + } + list[i].Vars = vars + list[i].Comment = strings.TrimSpace(v.Comment) + } - _, err = f.WriteString(fmt.Sprintf("variables %s\n\n", vj)) + for _, v := range list { + if v.Comment != "" { + f.WriteString(fmt.Sprintf("/* %s */\n\n", v.Comment)) + } else { + f.WriteString(fmt.Sprintf("/* %s */\n\n", v.Name)) + } + + if v.Vars != "" { + _, err = f.WriteString(fmt.Sprintf("variables %s\n\n", v.Vars)) if err != nil { return err } } - if v.Query[0] == '{' { - _, err = f.WriteString(fmt.Sprintf("query %s\n\n", v.Query)) - } else { - _, err = f.WriteString(fmt.Sprintf("%s\n\n", v.Query)) - } - + _, err = f.WriteString(fmt.Sprintf("%s\n\n", v.Query)) if err != nil { return err } diff --git a/core/internal/allow/allow_test.go b/core/internal/allow/allow_test.go index 07f81b6..9c96780 100644 --- a/core/internal/allow/allow_test.go +++ b/core/internal/allow/allow_test.go @@ -82,3 +82,160 @@ func TestGQLName5(t *testing.T) { t.Fatal("Name should be empty, not ", name) } } + +func TestParse1(t *testing.T) { + var al = ` + # Hello world + + variables { + "data": { + "slug": "", + "body": "", + "post": { + "connect": { + "slug": "" + } + } + } + } + + mutation createComment { + comment(insert: $data) { + slug + body + createdAt: created_at + totalVotes: cached_votes_total + totalReplies: cached_replies_total + vote: comment_vote(where: {user_id: {eq: $user_id}}) { + created_at + __typename + } + author: user { + slug + firstName: first_name + lastName: last_name + pictureURL: picture_url + bio + __typename + } + __typename + } + } + + # Query named createPost + + query createPost { + post(insert: $data) { + slug + body + published + createdAt: created_at + totalVotes: cached_votes_total + totalComments: cached_comments_total + vote: post_vote(where: {user_id: {eq: $user_id}}) { + created_at + __typename + } + author: user { + slug + firstName: first_name + lastName: last_name + pictureURL: picture_url + bio + __typename + } + __typename + } + }` + + _, err := parse(al, "allow.list") + if err != nil { + t.Fatal(err) + } +} + +func TestParse2(t *testing.T) { + var al = ` + /* Hello world */ + + variables { + "data": { + "slug": "", + "body": "", + "post": { + "connect": { + "slug": "" + } + } + } + } + + mutation createComment { + comment(insert: $data) { + slug + body + createdAt: created_at + totalVotes: cached_votes_total + totalReplies: cached_replies_total + vote: comment_vote(where: {user_id: {eq: $user_id}}) { + created_at + __typename + } + author: user { + slug + firstName: first_name + lastName: last_name + pictureURL: picture_url + bio + __typename + } + __typename + } + } + + /* + Query named createPost + */ + + variables { + "data": { + "thread": { + "connect": { + "slug": "" + } + }, + "slug": "", + "published": false, + "body": "" + } + } + + query createPost { + post(insert: $data) { + slug + body + published + createdAt: created_at + totalVotes: cached_votes_total + totalComments: cached_comments_total + vote: post_vote(where: {user_id: {eq: $user_id}}) { + created_at + __typename + } + author: user { + slug + firstName: first_name + lastName: last_name + pictureURL: picture_url + bio + __typename + } + __typename + } + }` + + _, err := parse(al, "allow.list") + if err != nil { + t.Fatal(err) + } +} diff --git a/core/internal/qcode/parse.go b/core/internal/qcode/parse.go index 2dd1cd2..5d1b87e 100644 --- a/core/internal/qcode/parse.go +++ b/core/internal/qcode/parse.go @@ -1,6 +1,7 @@ package qcode import ( + "encoding/binary" "errors" "fmt" "hash/maphash" @@ -329,6 +330,8 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) { return nil, fmt.Errorf("unexpected token: %s", p.peekNext()) } + // fm := make(map[uint64]struct{}) + for { if p.peek(itemEOF) { p.ignore() @@ -385,6 +388,20 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) { f := &fields[i] f.ID = int32(i) + // var name string + + // if f.Alias != "" { + // name = f.Alias + // } else { + // name = f.Name + // } + + // if _, ok := fm[name]; ok { + // continue + // } else { + // fm[name] = struct{}{} + // } + // If this is the top-level point the parent to the parent of the // previous field. if f.ParentID == -1 { @@ -416,6 +433,20 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) { return nil, err } + // var name string + + // if f.Alias != "" { + // name = f.Alias + // } else { + // name = f.Name + // } + + // if _, ok := fm[name]; ok { + // continue + // } else { + // fm[name] = struct{}{} + // } + if st.Len() == 0 { f.ParentID = -1 } else { @@ -685,6 +716,16 @@ func (p *Parser) reset(to int) { p.pos = to } +func (p *Parser) fHash(name string, parentID int32) uint64 { + var b []byte + binary.LittleEndian.PutUint32(b, uint32(parentID)) + p.h.WriteString(name) + p.h.Write(b) + v := p.h.Sum64() + p.h.Reset() + return v +} + func b2s(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } diff --git a/core/internal/qcode/parse_test.go b/core/internal/qcode/parse_test.go index 29d70e4..396b297 100644 --- a/core/internal/qcode/parse_test.go +++ b/core/internal/qcode/parse_test.go @@ -239,6 +239,29 @@ var gql = []byte(` price }}`) +var gqlWithFragments = []byte(` +fragment userFields1 on user { + id + email + __typename +} + +query { + users { + ...userFields2 + + created_at + ...userFields1 + __typename + } +} + +fragment userFields2 on user { + first_name + last_name + __typename +}`) + func BenchmarkQCompile(b *testing.B) { qcompile, _ := NewCompiler(Config{}) @@ -271,6 +294,21 @@ func BenchmarkQCompileP(b *testing.B) { }) } +func BenchmarkQCompileFragment(b *testing.B) { + qcompile, _ := NewCompiler(Config{}) + + b.ResetTimer() + b.ReportAllocs() + for n := 0; n < b.N; n++ { + _, err := qcompile.Compile(gqlWithFragments, "user") + + if err != nil { + b.Fatal(err) + } + } + +} + func BenchmarkParse(b *testing.B) { b.ResetTimer() b.ReportAllocs() @@ -298,6 +336,18 @@ func BenchmarkParseP(b *testing.B) { }) } +func BenchmarkParseFragment(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for n := 0; n < b.N; n++ { + _, err := Parse(gqlWithFragments) + + if err != nil { + b.Fatal(err) + } + } +} + func BenchmarkSchemaParse(b *testing.B) { b.ResetTimer() diff --git a/core/internal/qcode/qcode.go b/core/internal/qcode/qcode.go index 6f9b9ea..c65d0e6 100644 --- a/core/internal/qcode/qcode.go +++ b/core/internal/qcode/qcode.go @@ -419,6 +419,7 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error { com.AddFilters(qc, s, role) s.Cols = make([]Column, 0, len(field.Children)) + cm := make(map[string]struct{}) action = QTQuery for _, cid := range field.Children { @@ -428,19 +429,28 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error { continue } + var fname string + + if f.Alias != "" { + fname = f.Alias + } else { + fname = f.Name + } + + if _, ok := cm[fname]; ok { + continue + } else { + cm[fname] = struct{}{} + } + if len(f.Children) != 0 { val := f.ID | (s.ID << 16) st.Push(val) continue } - col := Column{Name: f.Name} + col := Column{Name: f.Name, FieldName: fname} - if len(f.Alias) != 0 { - col.FieldName = f.Alias - } else { - col.FieldName = f.Name - } s.Cols = append(s.Cols, col) } diff --git a/core/prepare.go b/core/prepare.go index a2f57ac..a3588cd 100644 --- a/core/prepare.go +++ b/core/prepare.go @@ -28,17 +28,18 @@ func (sg *SuperGraph) prepare(q *query, role string) { var err error qb := []byte(q.ai.Query) + vars := []byte(q.ai.Vars) switch q.qt { case qcode.QTQuery: if sg.abacEnabled { - stmts, err = sg.buildMultiStmt(qb, q.ai.Vars) + stmts, err = sg.buildMultiStmt(qb, vars) } else { - stmts, err = sg.buildRoleStmt(qb, q.ai.Vars, role) + stmts, err = sg.buildRoleStmt(qb, vars, role) } case qcode.QTMutation: - stmts, err = sg.buildRoleStmt(qb, q.ai.Vars, role) + stmts, err = sg.buildRoleStmt(qb, vars, role) } if err != nil { diff --git a/internal/serv/config.go b/internal/serv/config.go index c1161f0..dc06756 100644 --- a/internal/serv/config.go +++ b/internal/serv/config.go @@ -66,7 +66,7 @@ func newViper(configPath, configFile string) *viper.Viper { vi.SetDefault("host_port", "0.0.0.0:8080") vi.SetDefault("web_ui", false) vi.SetDefault("enable_tracing", false) - vi.SetDefault("auth_fail_block", "always") + vi.SetDefault("auth_fail_block", false) vi.SetDefault("seed_file", "seed.js") vi.SetDefault("default_block", true)