From ab8566df0368b82cf98a26f4e279e8e4dfd0f104 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Wed, 20 May 2020 00:03:05 -0400 Subject: [PATCH] fix: postgres schema name config value is not used --- core/api.go | 12 ++++++++++++ core/config.go | 3 +++ core/core.go | 19 ++++++++++++++++--- core/internal/psql/tables.go | 24 +++++++++++------------- core/internal/qcode/qcode.go | 7 ++++++- internal/serv/init.go | 10 ++++++++++ 6 files changed, 58 insertions(+), 17 deletions(-) diff --git a/core/api.go b/core/api.go index 68d9667..453ba60 100644 --- a/core/api.go +++ b/core/api.go @@ -210,3 +210,15 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess func (sg *SuperGraph) GraphQLSchema() (string, error) { return sg.ge.Schema.String(), nil } + +// Operation function return the operation type from the query. It uses a very fast algorithm to +// extract the operation without having to parse the query. +func Operation(query string) OpType { + return OpType(qcode.GetQType(query)) +} + +// Name function return the operation name from the query. It uses a very fast algorithm to +// extract the operation name without having to parse the query. +func Name(query string) string { + return allow.QueryName(query) +} diff --git a/core/config.go b/core/config.go index 0a9c0e4..7420936 100644 --- a/core/config.go +++ b/core/config.go @@ -61,6 +61,9 @@ type Config struct { // Inflections is to add additionally singular to plural mappings // to the engine (eg. sheep: sheep) Inflections map[string]string `mapstructure:"inflections"` + + // Database schema name. Defaults to 'public' + DBSchema string `mapstructure:"db_schema"` } // Table struct defines a database table diff --git a/core/core.go b/core/core.go index 0cd72ed..989249b 100644 --- a/core/core.go +++ b/core/core.go @@ -14,8 +14,10 @@ import ( "github.com/valyala/fasttemplate" ) +type OpType int + const ( - OpQuery int = iota + OpQuery OpType = iota OpMutation ) @@ -56,16 +58,27 @@ type scontext struct { func (sg *SuperGraph) initCompilers() error { var err error + var schema string + + if sg.conf.DBSchema == "" { + schema = "public" + } else { + schema = sg.conf.DBSchema + } // If sg.di is not null then it's probably set // for tests if sg.dbinfo == nil { - sg.dbinfo, err = psql.GetDBInfo(sg.db) + sg.dbinfo, err = psql.GetDBInfo(sg.db, schema) if err != nil { return err } } + if len(sg.dbinfo.Tables) == 0 { + return fmt.Errorf("no tables found in database (schema: %s)", schema) + } + if err = addTables(sg.conf, sg.dbinfo); err != nil { return err } @@ -334,7 +347,7 @@ func (c *scontext) executeRoleQuery(tx *sql.Tx) (string, error) { return role, nil } -func (r *Result) Operation() int { +func (r *Result) Operation() OpType { switch r.op { case qcode.QTQuery: return OpQuery diff --git a/core/internal/psql/tables.go b/core/internal/psql/tables.go index ec04d0a..219587b 100644 --- a/core/internal/psql/tables.go +++ b/core/internal/psql/tables.go @@ -17,7 +17,7 @@ type DBInfo struct { colMap map[string]map[string]*DBColumn } -func GetDBInfo(db *sql.DB) (*DBInfo, error) { +func GetDBInfo(db *sql.DB, schema string) (*DBInfo, error) { di := &DBInfo{} var version string @@ -31,13 +31,13 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) { return nil, err } - di.Tables, err = GetTables(db) + di.Tables, err = GetTables(db, schema) if err != nil { return nil, err } for _, t := range di.Tables { - cols, err := GetColumns(db, "public", t.Name) + cols, err := GetColumns(db, schema, t.Name) if err != nil { return nil, err } @@ -47,7 +47,7 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) { di.colMap = newColMap(di.Tables, di.Columns) - di.Functions, err = GetFunctions(db) + di.Functions, err = GetFunctions(db, schema) if err != nil { return nil, err } @@ -96,7 +96,7 @@ type DBTable struct { Type string } -func GetTables(db *sql.DB) ([]DBTable, error) { +func GetTables(db *sql.DB, schema string) ([]DBTable, error) { sqlStmt := ` SELECT c.relname as "name", @@ -108,14 +108,12 @@ SELECT FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE c.relkind IN ('r','v','m','f','') - AND n.nspname <> ('pg_catalog') - AND n.nspname <> ('information_schema') - AND n.nspname !~ ('^pg_toast') -AND pg_catalog.pg_table_is_visible(c.oid);` + AND n.nspname = $1 + AND pg_catalog.pg_table_is_visible(c.oid);` var tables []DBTable - rows, err := db.Query(sqlStmt) + rows, err := db.Query(sqlStmt, schema) if err != nil { return nil, fmt.Errorf("Error fetching tables: %s", err) } @@ -264,7 +262,7 @@ type DBFuncParam struct { Type string } -func GetFunctions(db *sql.DB) ([]DBFunction, error) { +func GetFunctions(db *sql.DB, schema string) ([]DBFunction, error) { sqlStmt := ` SELECT routines.routine_name, @@ -278,11 +276,11 @@ RIGHT JOIN information_schema.parameters ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL) WHERE - routines.specific_schema = 'public' + routines.specific_schema = $1 ORDER BY routines.routine_name, parameters.ordinal_position;` - rows, err := db.Query(sqlStmt) + rows, err := db.Query(sqlStmt, schema) if err != nil { return nil, fmt.Errorf("Error fetching functions: %s", err) } diff --git a/core/internal/qcode/qcode.go b/core/internal/qcode/qcode.go index c3d45ca..f5147aa 100644 --- a/core/internal/qcode/qcode.go +++ b/core/internal/qcode/qcode.go @@ -946,6 +946,9 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { ex.Op = OpDistinct ex.Val = node.Val default: + if len(node.Children) == 0 { + return nil, fmt.Errorf("[Where] invalid operation: %s", name) + } pushChildren(st, node.exp, node) return nil, nil // skip node } @@ -965,8 +968,9 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { case NodeVar: ex.Type = ValVar default: - return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type) + return nil, fmt.Errorf("[Where] invalid values for: %s", name) } + setWhereColName(ex, node) } @@ -1015,6 +1019,7 @@ func setWhereColName(ex *Exp, node *Node) { ex.Col = list[listlen-1] ex.NestedCols = list[:listlen] } + } func setOrderByColName(ob *OrderBy, node *Node) { diff --git a/internal/serv/init.go b/internal/serv/init.go index c3d9c5b..2073483 100644 --- a/internal/serv/init.go +++ b/internal/serv/init.go @@ -45,6 +45,16 @@ func initConf() (*Config, error) { logLevel = LogLevelNone } + // copy over db_schema from the core config + if c.DB.Schema == "" { + c.DB.Schema = c.DBSchema + } + + // set default database schema + if c.DB.Schema == "" { + c.DB.Schema = "public" + } + // Auths: validate and sanitize am := make(map[string]struct{})