diff --git a/allow/allow.go b/allow/allow.go new file mode 100644 index 0000000..ef66a99 --- /dev/null +++ b/allow/allow.go @@ -0,0 +1,337 @@ +package allow + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "path" + "sort" + "strings" +) + +const ( + AL_QUERY int = iota + 1 + AL_VARS +) + +type Item struct { + Name string + key string + URI string + Query string + Vars json.RawMessage +} + +type List struct { + filepath string + saveChan chan Item +} + +type Config struct { + CreateIfNotExists bool + Persist bool +} + +func New(cpath string, conf Config) (*List, error) { + al := List{} + + if len(cpath) != 0 { + fp := path.Join(cpath, "allow.list") + + if _, err := os.Stat(fp); err == nil { + al.filepath = fp + } else if !os.IsNotExist(err) { + return nil, err + } + } + + if len(al.filepath) == 0 { + fp := "./allow.list" + + if _, err := os.Stat(fp); err == nil { + al.filepath = fp + } else if !os.IsNotExist(err) { + return nil, err + } + } + + if len(al.filepath) == 0 { + fp := "./config/allow.list" + + if _, err := os.Stat(fp); err == nil { + al.filepath = fp + } else if !os.IsNotExist(err) { + return nil, err + } + } + + if len(al.filepath) == 0 { + if !conf.CreateIfNotExists { + return nil, errors.New("allow.list not found") + } + + if len(cpath) == 0 { + al.filepath = "./config/allow.list" + } else { + al.filepath = path.Join(cpath, "allow.list") + } + } + + var err error + + if conf.Persist { + al.saveChan = make(chan Item) + + go func() { + for v := range al.saveChan { + if err = al.save(v); err != nil { + break + } + } + }() + } + + if err != nil { + return nil, err + } + + return &al, nil +} + +func (al *List) IsPersist() bool { + return al.saveChan != nil +} + +func (al *List) Add(vars []byte, query, uri string) error { + if al.saveChan == nil { + return errors.New("allow.list is read-only") + } + + if len(query) == 0 { + return errors.New("empty query") + } + + var q string + + for i := 0; i < len(query); i++ { + c := query[i] + if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' { + q = query + break + + } else if c == '{' { + q = "query " + query + break + } + } + + al.saveChan <- Item{ + URI: uri, + Query: q, + Vars: vars, + } + + return nil +} + +func (al *List) Load() ([]Item, error) { + var list []Item + + b, err := ioutil.ReadFile(al.filepath) + if err != nil { + return list, err + } + + if len(b) == 0 { + return list, nil + } + + var uri string + var varBytes []byte + + itemMap := make(map[string]struct{}) + + s, e, c := 0, 0, 0 + ty := 0 + + for { + fq := false + + if c == 0 && b[e] == '#' { + s = e + for e < len(b) && b[e] != '\n' { + e++ + } + if (e - s) > 2 { + uri = strings.TrimSpace(string(b[(s + 1):e])) + } + } + + if e >= len(b) { + break + } + + if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") { + if c == 0 { + s = e + } + ty = AL_QUERY + } else if matchPrefix(b, e, "variables") { + if c == 0 { + s = e + len("variables") + 1 + } + ty = AL_VARS + } else if b[e] == '{' { + c++ + + } else if b[e] == '}' { + c-- + + if c == 0 { + if ty == AL_QUERY { + fq = true + } else if ty == AL_VARS { + varBytes = b[s:(e + 1)] + } + ty = 0 + } + } + + if fq { + query := string(b[s:(e + 1)]) + name := QueryName(query) + key := strings.ToLower(name) + + if _, ok := itemMap[key]; !ok { + v := Item{ + Name: name, + key: key, + URI: uri, + Query: query, + Vars: varBytes, + } + list = append(list, v) + } + + varBytes = nil + + } + + e++ + if e >= len(b) { + break + } + } + + return list, nil +} + +func (al *List) save(item Item) error { + item.Name = QueryName(item.Query) + item.key = strings.ToLower(item.Name) + + if len(item.Name) == 0 { + return nil + } + + list, err := al.Load() + if err != nil { + return err + } + + index := -1 + + for i, v := range list { + if strings.EqualFold(v.Name, item.Name) { + index = i + break + } + } + + if index != -1 { + list[index] = item + } else { + list = append(list, item) + } + + f, err := os.Create(al.filepath) + if err != nil { + return err + } + + defer f.Close() + + sort.Slice(list, func(i, j int) bool { + return strings.Compare(list[i].key, list[j].key) == -1 + }) + + for _, v := range list { + _, err := f.WriteString(fmt.Sprintf("# %s\n\n", v.URI)) + if err != nil { + return err + } + + if len(v.Vars) != 0 && !bytes.Equal(v.Vars, []byte("{}")) { + vj, err := json.MarshalIndent(v.Vars, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal vars: %v", err) + } + + _, err = f.WriteString(fmt.Sprintf("variables %s\n\n", vj)) + if err != nil { + return err + } + } + + if v.Query[0] == '{' { + _, err = f.WriteString(fmt.Sprintf("query %s\n\n", v.Query)) + } else { + _, err = f.WriteString(fmt.Sprintf("%s\n\n", v.Query)) + } + + if err != nil { + return err + } + } + + return nil +} + +func matchPrefix(b []byte, i int, s string) bool { + if (len(b) - i) < len(s) { + return false + } + for n := 0; n < len(s); n++ { + if b[(i+n)] != s[n] { + return false + } + } + return true +} + +func QueryName(b string) string { + state, s := 0, 0 + + for i := 0; i < len(b); i++ { + switch { + case state == 2 && b[i] == '{': + return b[s:i] + case state == 2 && b[i] == ' ': + return b[s:i] + case state == 1 && b[i] == '{': + return "" + case state == 1 && b[i] != ' ': + s = i + state = 2 + case state == 1 && b[i] == ' ': + continue + case i != 0 && b[i] == ' ' && (b[i-1] == 'n' || b[i-1] == 'y'): + state = 1 + } + } + + return "" +} diff --git a/allow/allow_test.go b/allow/allow_test.go new file mode 100644 index 0000000..f92dc76 --- /dev/null +++ b/allow/allow_test.go @@ -0,0 +1,82 @@ +package allow + +import ( + "testing" +) + +func TestGQLName1(t *testing.T) { + var q = ` + query { + products( + distinct: [price] + where: { id: { and: { greater_or_equals: 20, lt: 28 } } } + ) { id name } }` + + name := QueryName(q) + + if len(name) != 0 { + t.Fatal("Name should be empty, not ", name) + } +} + +func TestGQLName2(t *testing.T) { + var q = ` + query hakuna_matata { + products( + distinct: [price] + where: { id: { and: { greater_or_equals: 20, lt: 28 } } } + ) { + id + name + } + }` + + name := QueryName(q) + + if name != "hakuna_matata" { + t.Fatal("Name should be 'hakuna_matata', not ", name) + } +} + +func TestGQLName3(t *testing.T) { + var q = ` + mutation means{ users { id } }` + + // var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` + + name := QueryName(q) + + if name != "means" { + t.Fatal("Name should be 'means', not ", name) + } +} + +func TestGQLName4(t *testing.T) { + var q = ` + query no_worries + users { + id + } + }` + + name := QueryName(q) + + if name != "no_worries" { + t.Fatal("Name should be 'no_worries', not ", name) + } +} + +func TestGQLName5(t *testing.T) { + var q = ` + { + users { + id + } + }` + + name := QueryName(q) + + if len(name) != 0 { + t.Fatal("Name should be empty, not ", name) + } +} diff --git a/allow/fuzz_test.go b/allow/fuzz_test.go new file mode 100644 index 0000000..ba3b143 --- /dev/null +++ b/allow/fuzz_test.go @@ -0,0 +1,15 @@ +package allow + +import "testing" + +func TestFuzzCrashers(t *testing.T) { + var crashers = []string{ + "query", + "q", + "que", + } + + for _, f := range crashers { + _ = QueryName(f) + } +} diff --git a/docs/guide.md b/docs/guide.md index f32ff9a..92e8788 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -651,8 +651,6 @@ query { } ``` -## Mutations - In GraphQL mutations is the operation type for when you need to modify data. Super Graph supports the `insert`, `update`, `upsert` and `delete`. You can also do complex nested inserts and updates. When using mutations the data must be passed as variables since Super Graphs compiles the query into an prepared statement in the database for maximum speed. Prepared statements are are functions in your code when called they accept arguments and your variables are passed in as those arguments. @@ -836,8 +834,6 @@ mutation { } ``` -## Nested Mutations - Often you will need to create or update multiple related items at the same time. This can be done using nested mutations. For example you might need to create a product and assign it to a user, or create a user and his products at the same time. You just have to use simple json to define you mutation and Super Graph takes care of the rest. ### Nested Insert @@ -988,6 +984,40 @@ fetch('http://localhost:8080/api/v1/graphql', { .then(res => console.log(res.data)); ``` +## GraphQL with React + +This is a quick simple example using `graphql.js` [https://github.com/f/graphql.js/](https://github.com/f/graphql.js/) + +```js +import React, { useState, useEffect } from 'react' +import graphql from 'graphql.js' + +// Create a GraphQL client pointing to Super Graph +var graph = graphql("http://localhost:3000/api/v1/graphql", { asJSON: true }) + +const App = () => { + const [user, setUser] = useState(null) + + useEffect(() => { + async function action() { + // Use the GraphQL client to execute a graphQL query + // The second argument to the client are the variables you need to pass + const result = await graph(`{ user { id first_name last_name picture_url } }`)() + setUser(result) + } + action() + }, []); + + return ( +
+

{ JSON.stringify(user) }

+
+ ); +} +``` + +export default App; + ## Advanced Columns The ablity to have `JSON/JSONB` and `Array` columns is often considered in the top most useful features of Postgres. There are many cases where using an array or a json column saves space and reduces complexity in your app. The only issue with these columns is the really that your SQL queries can get harder to write and maintain. @@ -1137,45 +1167,43 @@ class AddSearchColumn < ActiveRecord::Migration[5.1] end ``` -## GraphQL with React +## API Security -This is a quick simple example using `graphql.js` [https://github.com/f/graphql.js/](https://github.com/f/graphql.js/) +One of the the most common questions I get asked if what happens if a user out on the internet issues queries +that we don't want issued. For example how do we stop him from fetching all users or the emails of users. Our answer to this is that it is not an issue as this cannot happen, let me explain. -```js -import React, { useState, useEffect } from 'react' -import graphql from 'graphql.js' +Super Graph runs in one of two modes `development` or `production`, this is controlled via the config value `production: false` when it's false it's running in development mode and when true, production. In development mode all the **named** quries (including mutations) you run are saved into the allow list (`./config/allow.list`). I production mode when Super Graph starts only the queries from this allow list file are registered with the database as (prepared statements)[https://stackoverflow.com/questions/8263371/how-can-prepared-statements-protect-from-sql-injection-attacks]. Prepared statements are designed by databases to be fast and secure. They protect against all kinds of sql injection attacks and since they are pre-processed and pre-planned they are much faster to run then raw sql queries. Also there's no GraphQL to SQL compiling happening in production mode which makes your queries lighting fast as they directly goto the database with almost no overhead. -// Create a GraphQL client pointing to Super Graph -var graph = graphql("http://localhost:3000/api/v1/graphql", { asJSON: true }) +In short in production only queries listed in the allow list file (`./config/allow.list`) can be used all other queries will be blocked. -const App = () => { - const [user, setUser] = useState(null) - - useEffect(() => { - async function action() { - // Use the GraphQL client to execute a graphQL query - // The second argument to the client are the variables you need to pass - const result = await graph(`{ user { id first_name last_name picture_url } }`)() - setUser(result) +::: tip How to think about the allow list? +The allow list file is essentially a list of all your exposed API calls and the data thats passes within them in plain text. It's very easy to build tooling to do things like parsing this file within your tests to ensure fields like `credit_card_no` are not accidently leaked. It's a great way to build compliance tooling and ensure your user data is always safe. +::: + +This is an example of a named query `getUserWithProducts` is the name you've given to this query it can be anything you like but should be unique across all you're queries. Only named queries are saved in the allow list in development mode the allow list is not modified in production mode. + + +```graphql +query getUserWithProducts { + users { + id + name + products { + id + name + price } - action() - }, []); - - return ( -
-

{ JSON.stringify(user) }

-
- ); + } } ``` -export default App; + ## Authentication You can only have one type of auth enabled. You can either pick Rails or JWT. -### Rails Auth (Devise / Warden) +### Ruby on Rails Almost all Rails apps use Devise or Warden for authentication. Once the user is authenticated a session is created with the users ID. The session can either be @@ -1261,7 +1289,6 @@ The `user` role can be divided up into further roles based on attributes in the Super Graph allows you to create roles dynamically using a `roles_query` and ` match` config values. - ### Configure RBAC ```yaml diff --git a/main.go b/main.go index 5adcc4c..5a3cff2 100644 --- a/main.go +++ b/main.go @@ -5,5 +5,5 @@ import ( ) func main() { - serv.Init() + serv.Cmd() } diff --git a/serv/allow.go b/serv/allow.go deleted file mode 100644 index 82e2917..0000000 --- a/serv/allow.go +++ /dev/null @@ -1,320 +0,0 @@ -package serv - -import ( - "bytes" - "encoding/json" - "fmt" - "io/ioutil" - "log" - "os" - "path" - "sort" - "strings" -) - -const ( - AL_QUERY int = iota + 1 - AL_VARS -) - -type allowItem struct { - name string - hash string - uri string - gql string - vars json.RawMessage -} - -var _allowList allowList - -type allowList struct { - list []*allowItem - index map[string]int - filepath string - saveChan chan *allowItem - active bool -} - -func initAllowList(cpath string) { - _allowList = allowList{ - index: make(map[string]int), - saveChan: make(chan *allowItem), - active: true, - } - - if len(cpath) != 0 { - fp := path.Join(cpath, "allow.list") - - if _, err := os.Stat(fp); err == nil { - _allowList.filepath = fp - } else if !os.IsNotExist(err) { - errlog.Fatal().Err(err).Send() - } - } - - if len(_allowList.filepath) == 0 { - fp := "./allow.list" - - if _, err := os.Stat(fp); err == nil { - _allowList.filepath = fp - } else if !os.IsNotExist(err) { - errlog.Fatal().Err(err).Send() - } - } - - if len(_allowList.filepath) == 0 { - fp := "./config/allow.list" - - if _, err := os.Stat(fp); err == nil { - _allowList.filepath = fp - } else if !os.IsNotExist(err) { - errlog.Fatal().Err(err).Send() - } - } - - if len(_allowList.filepath) == 0 { - if conf.Production { - errlog.Fatal().Msg("allow.list not found") - } - - if len(cpath) == 0 { - _allowList.filepath = "./config/allow.list" - } else { - _allowList.filepath = path.Join(cpath, "allow.list") - } - - logger.Warn().Msg("allow.list not found") - } else { - _allowList.load() - } - - go func() { - for v := range _allowList.saveChan { - _allowList.save(v) - } - }() -} - -func (al *allowList) add(req *gqlReq) { - if al.saveChan == nil || len(req.ref) == 0 || len(req.Query) == 0 { - return - } - - var query string - - for i := 0; i < len(req.Query); i++ { - c := req.Query[i] - if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' { - query = req.Query - break - - } else if c == '{' { - query = "query " + req.Query - break - } - } - - al.saveChan <- &allowItem{ - uri: req.ref, - gql: query, - vars: req.Vars, - } -} - -func (al *allowList) upsert(query, vars []byte, uri string) { - q := string(query) - hash := gqlHash(q, vars, "") - name := gqlName(q) - - var key string - - if len(name) != 0 { - key = name - } else { - key = hash - } - - if i, ok := al.index[key]; !ok { - al.list = append(al.list, &allowItem{ - name: name, - hash: hash, - uri: uri, - gql: q, - vars: vars, - }) - al.index[key] = len(al.list) - 1 - } else { - item := al.list[i] - item.name = name - item.hash = hash - item.gql = q - item.vars = vars - - } -} - -func (al *allowList) load() { - b, err := ioutil.ReadFile(al.filepath) - if err != nil { - log.Fatal(err) - } - - if len(b) == 0 { - return - } - - var uri string - var varBytes []byte - - s, e, c := 0, 0, 0 - ty := 0 - - for { - if c == 0 && b[e] == '#' { - s = e - for e < len(b) && b[e] != '\n' { - e++ - } - if (e - s) > 2 { - uri = strings.TrimSpace(string(b[(s + 1):e])) - } - } - - if e >= len(b) { - break - } - - if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") { - if c == 0 { - s = e - } - ty = AL_QUERY - } else if matchPrefix(b, e, "variables") { - if c == 0 { - s = e + len("variables") + 1 - } - ty = AL_VARS - } else if b[e] == '{' { - c++ - - } else if b[e] == '}' { - c-- - - if c == 0 { - if ty == AL_QUERY { - al.upsert(b[s:(e+1)], varBytes, uri) - varBytes = nil - - } else if ty == AL_VARS { - varBytes = b[s:(e + 1)] - } - ty = 0 - } - } - - e++ - if e >= len(b) { - break - } - } -} - -func (al *allowList) save(item *allowItem) { - var err error - - item.hash = gqlHash(item.gql, item.vars, "") - item.name = gqlName(item.gql) - - if len(item.name) == 0 { - key := item.hash - - if _, ok := al.index[key]; ok { - return - } - - al.list = append(al.list, item) - al.index[key] = len(al.list) - 1 - - } else { - key := item.name - - if i, ok := al.index[key]; ok { - if al.list[i].hash == item.hash { - return - } - al.list[i] = item - } else { - al.list = append(al.list, item) - al.index[key] = len(al.list) - 1 - } - } - - f, err := os.Create(al.filepath) - if err != nil { - logger.Warn().Err(err).Msgf("Failed to write allow list: %s", al.filepath) - return - } - - defer f.Close() - - keys := []string{} - urlMap := make(map[string][]*allowItem) - - for _, v := range al.list { - urlMap[v.uri] = append(urlMap[v.uri], v) - } - - for k := range urlMap { - keys = append(keys, k) - } - sort.Strings(keys) - - for i := range keys { - k := keys[i] - v := urlMap[k] - - if _, err := f.WriteString(fmt.Sprintf("# %s\n\n", k)); err != nil { - logger.Error().Err(err).Send() - return - } - - for i := range v { - if len(v[i].vars) != 0 && !bytes.Equal(v[i].vars, []byte("{}")) { - vj, err := json.MarshalIndent(v[i].vars, "", " ") - if err != nil { - logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file") - continue - } - - _, err = f.WriteString(fmt.Sprintf("variables %s\n\n", vj)) - if err != nil { - logger.Error().Err(err).Send() - return - } - } - - if v[i].gql[0] == '{' { - _, err = f.WriteString(fmt.Sprintf("query %s\n\n", v[i].gql)) - } else { - _, err = f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql)) - } - - if err != nil { - logger.Error().Err(err).Send() - return - } - } - } -} - -func matchPrefix(b []byte, i int, s string) bool { - if (len(b) - i) < len(s) { - return false - } - for n := 0; n < len(s); n++ { - if b[(i+n)] != s[n] { - return false - } - } - return true -} diff --git a/serv/cmd.go b/serv/cmd.go index 7486fa4..c8a0098 100644 --- a/serv/cmd.go +++ b/serv/cmd.go @@ -1,15 +1,13 @@ package serv import ( - "context" "fmt" - "os" "runtime" "strings" + "github.com/dosco/super-graph/allow" "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" - "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/rs/zerolog" "github.com/spf13/cobra" @@ -31,17 +29,18 @@ var ( ) var ( - logger zerolog.Logger // logger for everything but errors - errlog zerolog.Logger // logger for errors includes line numbers - conf *config // parsed config - confPath string // path to the config file - db *pgxpool.Pool // database connection pool - schema *psql.DBSchema // database tables, columns and relationships - qcompile *qcode.Compiler // qcode compiler - pcompile *psql.Compiler // postgres sql compiler + logger zerolog.Logger // logger for everything but errors + errlog zerolog.Logger // logger for errors includes line numbers + conf *config // parsed config + confPath string // path to the config file + db *pgxpool.Pool // database connection pool + schema *psql.DBSchema // database tables, columns and relationships + allowList *allow.List // allow.list is contains queries allowed in production + qcompile *qcode.Compiler // qcode compiler + pcompile *psql.Compiler // postgres sql compiler ) -func Init() { +func Cmd() { initLog() rootCmd := &cobra.Command{ @@ -156,159 +155,6 @@ e.g. db:migrate -+1 } } -func initLog() { - out := zerolog.ConsoleWriter{Out: os.Stderr} - logger = zerolog.New(out).With().Timestamp().Logger() - errlog = logger.With().Caller().Logger() -} - -func initConf() (*config, error) { - vi := newConfig(getConfigName()) - - if err := vi.ReadInConfig(); err != nil { - return nil, err - } - - inherits := vi.GetString("inherits") - if len(inherits) != 0 { - vi = newConfig(inherits) - - if err := vi.ReadInConfig(); err != nil { - return nil, err - } - - if vi.IsSet("inherits") { - errlog.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)", - inherits, - vi.GetString("inherits")) - } - - vi.SetConfigName(getConfigName()) - - if err := vi.MergeInConfig(); err != nil { - return nil, err - } - } - - c := &config{} - - if err := c.Init(vi); err != nil { - return nil, fmt.Errorf("unable to decode config, %v", err) - } - - logLevel, err := zerolog.ParseLevel(c.LogLevel) - if err != nil { - errlog.Error().Err(err).Msg("error setting log_level") - } - zerolog.SetGlobalLevel(logLevel) - - return c, nil -} - -func initDB(c *config, useDB bool) (*pgx.Conn, error) { - config, _ := pgx.ParseConfig("") - config.Host = c.DB.Host - config.Port = c.DB.Port - config.User = c.DB.User - config.Password = c.DB.Password - config.RuntimeParams = map[string]string{ - "application_name": c.AppName, - "search_path": c.DB.Schema, - } - - if useDB { - config.Database = c.DB.DBName - } - - switch c.LogLevel { - case "debug": - config.LogLevel = pgx.LogLevelDebug - case "info": - config.LogLevel = pgx.LogLevelInfo - case "warn": - config.LogLevel = pgx.LogLevelWarn - case "error": - config.LogLevel = pgx.LogLevelError - default: - config.LogLevel = pgx.LogLevelNone - } - - config.Logger = NewSQLLogger(logger) - - db, err := pgx.ConnectConfig(context.Background(), config) - if err != nil { - return nil, err - } - - return db, nil -} - -func initDBPool(c *config) (*pgxpool.Pool, error) { - config, _ := pgxpool.ParseConfig("") - config.ConnConfig.Host = c.DB.Host - config.ConnConfig.Port = c.DB.Port - config.ConnConfig.Database = c.DB.DBName - config.ConnConfig.User = c.DB.User - config.ConnConfig.Password = c.DB.Password - config.ConnConfig.RuntimeParams = map[string]string{ - "application_name": c.AppName, - "search_path": c.DB.Schema, - } - - switch c.LogLevel { - case "debug": - config.ConnConfig.LogLevel = pgx.LogLevelDebug - case "info": - config.ConnConfig.LogLevel = pgx.LogLevelInfo - case "warn": - config.ConnConfig.LogLevel = pgx.LogLevelWarn - case "error": - config.ConnConfig.LogLevel = pgx.LogLevelError - default: - config.ConnConfig.LogLevel = pgx.LogLevelNone - } - - config.ConnConfig.Logger = NewSQLLogger(logger) - - // if c.DB.MaxRetries != 0 { - // opt.MaxRetries = c.DB.MaxRetries - // } - - if c.DB.PoolSize != 0 { - config.MaxConns = conf.DB.PoolSize - } - - db, err := pgxpool.ConnectConfig(context.Background(), config) - if err != nil { - return nil, err - } - - return db, nil -} - -func initCompiler() { - var err error - - qcompile, pcompile, err = initCompilers(conf) - if err != nil { - errlog.Fatal().Err(err).Msg("failed to initialize compilers") - } - - if err := initResolvers(); err != nil { - errlog.Fatal().Err(err).Msg("failed to initialized resolvers") - } -} - -func initConfOnce() { - var err error - - if conf == nil { - if conf, err = initConf(); err != nil { - errlog.Fatal().Err(err).Msg("failed to read config") - } - } -} - func cmdVersion(cmd *cobra.Command, args []string) { fmt.Printf("%s\n", BuildDetails()) } diff --git a/serv/cmd_serv.go b/serv/cmd_serv.go index 2a8062a..37ac757 100644 --- a/serv/cmd_serv.go +++ b/serv/cmd_serv.go @@ -17,7 +17,7 @@ func cmdServ(cmd *cobra.Command, args []string) { if err == nil { initCompiler() initAllowList(confPath) - initPreparedList() + initPreparedList(confPath) } else { fatalInProd(err, "failed to connect to database") } diff --git a/serv/core.go b/serv/core.go index fc708a9..5302a91 100644 --- a/serv/core.go +++ b/serv/core.go @@ -11,6 +11,7 @@ import ( "time" "github.com/cespare/xxhash/v2" + "github.com/dosco/super-graph/allow" "github.com/dosco/super-graph/qcode" "github.com/jackc/pgx/v4" "github.com/valyala/fasttemplate" @@ -107,7 +108,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) { } - ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)] + ps, ok := _preparedList[stmtHash(allow.QueryName(c.req.Query), role)] if !ok { return nil, nil, errUnauthorized } @@ -240,8 +241,10 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) { } } - if !conf.Production { - _allowList.add(&c.req) + if allowList.IsPersist() { + if err := allowList.Add(c.req.Vars, c.req.Query, c.req.ref); err != nil { + return nil, nil, err + } } if len(stmts) > 1 { diff --git a/serv/fuzz.go b/serv/fuzz.go index 34ef656..17ff153 100644 --- a/serv/fuzz.go +++ b/serv/fuzz.go @@ -4,7 +4,7 @@ package serv func Fuzz(data []byte) int { gql := string(data) - gqlName(gql) + QueryName(gql) gqlHash(gql, nil, "") return 1 diff --git a/serv/fuzz_test.go b/serv/fuzz_test.go index 8c4b715..68fe2c6 100644 --- a/serv/fuzz_test.go +++ b/serv/fuzz_test.go @@ -10,7 +10,6 @@ func TestFuzzCrashers(t *testing.T) { } for _, f := range crashers { - _ = gqlName(f) gqlHash(f, nil, "") } } diff --git a/serv/init.go b/serv/init.go new file mode 100644 index 0000000..c155c31 --- /dev/null +++ b/serv/init.go @@ -0,0 +1,179 @@ +package serv + +import ( + "context" + "fmt" + "os" + + "github.com/dosco/super-graph/allow" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/rs/zerolog" +) + +func initLog() { + out := zerolog.ConsoleWriter{Out: os.Stderr} + logger = zerolog.New(out).With().Timestamp().Logger() + errlog = logger.With().Caller().Logger() +} + +func initConf() (*config, error) { + vi := newConfig(getConfigName()) + + if err := vi.ReadInConfig(); err != nil { + return nil, err + } + + inherits := vi.GetString("inherits") + if len(inherits) != 0 { + vi = newConfig(inherits) + + if err := vi.ReadInConfig(); err != nil { + return nil, err + } + + if vi.IsSet("inherits") { + errlog.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)", + inherits, + vi.GetString("inherits")) + } + + vi.SetConfigName(getConfigName()) + + if err := vi.MergeInConfig(); err != nil { + return nil, err + } + } + + c := &config{} + + if err := c.Init(vi); err != nil { + return nil, fmt.Errorf("unable to decode config, %v", err) + } + + logLevel, err := zerolog.ParseLevel(c.LogLevel) + if err != nil { + errlog.Error().Err(err).Msg("error setting log_level") + } + zerolog.SetGlobalLevel(logLevel) + + return c, nil +} + +func initDB(c *config, useDB bool) (*pgx.Conn, error) { + config, _ := pgx.ParseConfig("") + config.Host = c.DB.Host + config.Port = c.DB.Port + config.User = c.DB.User + config.Password = c.DB.Password + config.RuntimeParams = map[string]string{ + "application_name": c.AppName, + "search_path": c.DB.Schema, + } + + if useDB { + config.Database = c.DB.DBName + } + + switch c.LogLevel { + case "debug": + config.LogLevel = pgx.LogLevelDebug + case "info": + config.LogLevel = pgx.LogLevelInfo + case "warn": + config.LogLevel = pgx.LogLevelWarn + case "error": + config.LogLevel = pgx.LogLevelError + default: + config.LogLevel = pgx.LogLevelNone + } + + config.Logger = NewSQLLogger(logger) + + db, err := pgx.ConnectConfig(context.Background(), config) + if err != nil { + return nil, err + } + + return db, nil +} + +func initDBPool(c *config) (*pgxpool.Pool, error) { + config, _ := pgxpool.ParseConfig("") + config.ConnConfig.Host = c.DB.Host + config.ConnConfig.Port = c.DB.Port + config.ConnConfig.Database = c.DB.DBName + config.ConnConfig.User = c.DB.User + config.ConnConfig.Password = c.DB.Password + config.ConnConfig.RuntimeParams = map[string]string{ + "application_name": c.AppName, + "search_path": c.DB.Schema, + } + + switch c.LogLevel { + case "debug": + config.ConnConfig.LogLevel = pgx.LogLevelDebug + case "info": + config.ConnConfig.LogLevel = pgx.LogLevelInfo + case "warn": + config.ConnConfig.LogLevel = pgx.LogLevelWarn + case "error": + config.ConnConfig.LogLevel = pgx.LogLevelError + default: + config.ConnConfig.LogLevel = pgx.LogLevelNone + } + + config.ConnConfig.Logger = NewSQLLogger(logger) + + // if c.DB.MaxRetries != 0 { + // opt.MaxRetries = c.DB.MaxRetries + // } + + if c.DB.PoolSize != 0 { + config.MaxConns = conf.DB.PoolSize + } + + db, err := pgxpool.ConnectConfig(context.Background(), config) + if err != nil { + return nil, err + } + + return db, nil +} + +func initCompiler() { + var err error + + qcompile, pcompile, err = initCompilers(conf) + if err != nil { + errlog.Fatal().Err(err).Msg("failed to initialize compilers") + } + + if err := initResolvers(); err != nil { + errlog.Fatal().Err(err).Msg("failed to initialized resolvers") + } +} + +func initConfOnce() { + var err error + + if conf == nil { + if conf, err = initConf(); err != nil { + errlog.Fatal().Err(err).Msg("failed to read config") + } + } +} + +func initAllowList(cpath string) { + var ac allow.Config + var err error + + if !conf.Production { + ac = allow.Config{CreateIfNotExists: true, Persist: true} + } + + allowList, err = allow.New(cpath, ac) + if err != nil { + errlog.Fatal().Err(err).Msg("failed to initialize allow list") + } +} diff --git a/serv/prepare.go b/serv/prepare.go index b756144..3f04e1d 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -6,6 +6,7 @@ import ( "fmt" "io" + "github.com/dosco/super-graph/allow" "github.com/dosco/super-graph/qcode" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" @@ -23,7 +24,10 @@ var ( _preparedList map[string]*preparedItem ) -func initPreparedList() { +func initPreparedList(cpath string) { + if allowList.IsPersist() { + return + } _preparedList = make(map[string]*preparedItem) tx, err := db.Begin(context.Background()) @@ -43,30 +47,38 @@ func initPreparedList() { success := 0 - for _, v := range _allowList.list { - if len(v.gql) == 0 { + list, err := allowList.Load() + if err != nil { + errlog.Fatal().Err(err).Send() + } + + for _, v := range list { + if len(v.Query) == 0 { continue } - err := prepareStmt(v.gql, v.vars) + err := prepareStmt(v) if err == nil { success++ continue } - if len(v.vars) == 0 { - logger.Warn().Err(err).Msg(v.gql) + if len(v.Vars) == 0 { + logger.Warn().Err(err).Msg(v.Query) } else { - logger.Warn().Err(err).Msgf("%s %s", v.vars, v.gql) + logger.Warn().Err(err).Msgf("%s %s", v.Vars, v.Query) } } logger.Info(). Msgf("Registered %d of %d queries from allow.list as prepared statements", - success, len(_allowList.list)) + success, len(list)) } -func prepareStmt(gql string, vars []byte) error { +func prepareStmt(item allow.Item) error { + gql := item.Query + vars := item.Vars + qt := qcode.GetQType(gql) q := []byte(gql) @@ -99,7 +111,7 @@ func prepareStmt(gql string, vars []byte) error { logger.Debug().Msg("Prepared statement role: user") - err = prepare(tx, stmts1, gqlHash(gql, vars, "user")) + err = prepare(tx, stmts1, stmtHash(item.Name, "user")) if err != nil { return err } @@ -112,7 +124,7 @@ func prepareStmt(gql string, vars []byte) error { return err } - err = prepare(tx, stmts2, gqlHash(gql, vars, "anon")) + err = prepare(tx, stmts2, stmtHash(item.Name, "anon")) if err != nil { return err } @@ -127,7 +139,7 @@ func prepareStmt(gql string, vars []byte) error { return err } - err = prepare(tx, stmts, gqlHash(gql, vars, role.Name)) + err = prepare(tx, stmts, stmtHash(item.Name, role.Name)) if err != nil { return err } diff --git a/serv/utils.go b/serv/utils.go index d6e814e..c96c76b 100644 --- a/serv/utils.go +++ b/serv/utils.go @@ -22,6 +22,14 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 { return v } +// nolint: errcheck +func stmtHash(name string, role string) string { + h := sha1.New() + io.WriteString(h, strings.ToLower(name)) + io.WriteString(h, role) + return hex.EncodeToString(h.Sum(nil)) +} + // nolint: errcheck func gqlHash(b string, vars []byte, role string) string { b = strings.TrimSpace(b) @@ -108,30 +116,6 @@ func al(b byte) bool { return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') } -func gqlName(b string) string { - state, s := 0, 0 - - for i := 0; i < len(b); i++ { - switch { - case state == 2 && b[i] == '{': - return b[s:i] - case state == 2 && b[i] == ' ': - return b[s:i] - case state == 1 && b[i] == '{': - return "" - case state == 1 && b[i] != ' ': - s = i - state = 2 - case state == 1 && b[i] == ' ': - continue - case i != 0 && b[i] == ' ' && (b[i-1] == 'n' || b[i-1] == 'y'): - state = 1 - } - } - - return "" -} - func findStmt(role string, stmts []stmt) *stmt { for i := range stmts { if stmts[i].role.Name != role { diff --git a/serv/utils_test.go b/serv/utils_test.go index 952eb61..b8babeb 100644 --- a/serv/utils_test.go +++ b/serv/utils_test.go @@ -229,80 +229,3 @@ func TestGQLHashWithVars2(t *testing.T) { t.Fatal("Hashes don't match they should") } } - -func TestGQLName1(t *testing.T) { - var q = ` - query { - products( - distinct: [price] - where: { id: { and: { greater_or_equals: 20, lt: 28 } } } - ) { id name } }` - - name := gqlName(q) - - if len(name) != 0 { - t.Fatal("Name should be empty, not ", name) - } -} - -func TestGQLName2(t *testing.T) { - var q = ` - query hakuna_matata { - products( - distinct: [price] - where: { id: { and: { greater_or_equals: 20, lt: 28 } } } - ) { - id - name - } - }` - - name := gqlName(q) - - if name != "hakuna_matata" { - t.Fatal("Name should be 'hakuna_matata', not ", name) - } -} - -func TestGQLName3(t *testing.T) { - var q = ` - mutation means{ users { id } }` - - // var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` - - name := gqlName(q) - - if name != "means" { - t.Fatal("Name should be 'means', not ", name) - } -} - -func TestGQLName4(t *testing.T) { - var q = ` - query no_worries - users { - id - } - }` - - name := gqlName(q) - - if name != "no_worries" { - t.Fatal("Name should be 'no_worries', not ", name) - } -} - -func TestGQLName5(t *testing.T) { - var q = ` - { - users { - id - } - }` - - name := gqlName(q) - - if len(name) != 0 { - t.Fatal("Name should be empty, not ", name) - } -}