fix: postgres schema name config value is not used

This commit is contained in:
Vikram Rangnekar 2020-05-20 00:03:05 -04:00
parent 94fa51ffb2
commit ab8566df03
6 changed files with 58 additions and 17 deletions

View File

@ -210,3 +210,15 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
func (sg *SuperGraph) GraphQLSchema() (string, error) { func (sg *SuperGraph) GraphQLSchema() (string, error) {
return sg.ge.Schema.String(), nil return sg.ge.Schema.String(), nil
} }
// Operation function return the operation type from the query. It uses a very fast algorithm to
// extract the operation without having to parse the query.
func Operation(query string) OpType {
return OpType(qcode.GetQType(query))
}
// Name function return the operation name from the query. It uses a very fast algorithm to
// extract the operation name without having to parse the query.
func Name(query string) string {
return allow.QueryName(query)
}

View File

@ -61,6 +61,9 @@ type Config struct {
// Inflections is to add additionally singular to plural mappings // Inflections is to add additionally singular to plural mappings
// to the engine (eg. sheep: sheep) // to the engine (eg. sheep: sheep)
Inflections map[string]string `mapstructure:"inflections"` Inflections map[string]string `mapstructure:"inflections"`
// Database schema name. Defaults to 'public'
DBSchema string `mapstructure:"db_schema"`
} }
// Table struct defines a database table // Table struct defines a database table

View File

@ -14,8 +14,10 @@ import (
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
) )
type OpType int
const ( const (
OpQuery int = iota OpQuery OpType = iota
OpMutation OpMutation
) )
@ -56,16 +58,27 @@ type scontext struct {
func (sg *SuperGraph) initCompilers() error { func (sg *SuperGraph) initCompilers() error {
var err error var err error
var schema string
if sg.conf.DBSchema == "" {
schema = "public"
} else {
schema = sg.conf.DBSchema
}
// If sg.di is not null then it's probably set // If sg.di is not null then it's probably set
// for tests // for tests
if sg.dbinfo == nil { if sg.dbinfo == nil {
sg.dbinfo, err = psql.GetDBInfo(sg.db) sg.dbinfo, err = psql.GetDBInfo(sg.db, schema)
if err != nil { if err != nil {
return err return err
} }
} }
if len(sg.dbinfo.Tables) == 0 {
return fmt.Errorf("no tables found in database (schema: %s)", schema)
}
if err = addTables(sg.conf, sg.dbinfo); err != nil { if err = addTables(sg.conf, sg.dbinfo); err != nil {
return err return err
} }
@ -334,7 +347,7 @@ func (c *scontext) executeRoleQuery(tx *sql.Tx) (string, error) {
return role, nil return role, nil
} }
func (r *Result) Operation() int { func (r *Result) Operation() OpType {
switch r.op { switch r.op {
case qcode.QTQuery: case qcode.QTQuery:
return OpQuery return OpQuery

View File

@ -17,7 +17,7 @@ type DBInfo struct {
colMap map[string]map[string]*DBColumn colMap map[string]map[string]*DBColumn
} }
func GetDBInfo(db *sql.DB) (*DBInfo, error) { func GetDBInfo(db *sql.DB, schema string) (*DBInfo, error) {
di := &DBInfo{} di := &DBInfo{}
var version string var version string
@ -31,13 +31,13 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
return nil, err return nil, err
} }
di.Tables, err = GetTables(db) di.Tables, err = GetTables(db, schema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, t := range di.Tables { for _, t := range di.Tables {
cols, err := GetColumns(db, "public", t.Name) cols, err := GetColumns(db, schema, t.Name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -47,7 +47,7 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
di.colMap = newColMap(di.Tables, di.Columns) di.colMap = newColMap(di.Tables, di.Columns)
di.Functions, err = GetFunctions(db) di.Functions, err = GetFunctions(db, schema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -96,7 +96,7 @@ type DBTable struct {
Type string Type string
} }
func GetTables(db *sql.DB) ([]DBTable, error) { func GetTables(db *sql.DB, schema string) ([]DBTable, error) {
sqlStmt := ` sqlStmt := `
SELECT SELECT
c.relname as "name", c.relname as "name",
@ -108,14 +108,12 @@ SELECT
FROM pg_catalog.pg_class c FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('r','v','m','f','') WHERE c.relkind IN ('r','v','m','f','')
AND n.nspname <> ('pg_catalog') AND n.nspname = $1
AND n.nspname <> ('information_schema')
AND n.nspname !~ ('^pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid);` AND pg_catalog.pg_table_is_visible(c.oid);`
var tables []DBTable var tables []DBTable
rows, err := db.Query(sqlStmt) rows, err := db.Query(sqlStmt, schema)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error fetching tables: %s", err) return nil, fmt.Errorf("Error fetching tables: %s", err)
} }
@ -264,7 +262,7 @@ type DBFuncParam struct {
Type string Type string
} }
func GetFunctions(db *sql.DB) ([]DBFunction, error) { func GetFunctions(db *sql.DB, schema string) ([]DBFunction, error) {
sqlStmt := ` sqlStmt := `
SELECT SELECT
routines.routine_name, routines.routine_name,
@ -278,11 +276,11 @@ RIGHT JOIN
information_schema.parameters information_schema.parameters
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL) ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
WHERE WHERE
routines.specific_schema = 'public' routines.specific_schema = $1
ORDER BY ORDER BY
routines.routine_name, parameters.ordinal_position;` routines.routine_name, parameters.ordinal_position;`
rows, err := db.Query(sqlStmt) rows, err := db.Query(sqlStmt, schema)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error fetching functions: %s", err) return nil, fmt.Errorf("Error fetching functions: %s", err)
} }

View File

@ -946,6 +946,9 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
ex.Op = OpDistinct ex.Op = OpDistinct
ex.Val = node.Val ex.Val = node.Val
default: default:
if len(node.Children) == 0 {
return nil, fmt.Errorf("[Where] invalid operation: %s", name)
}
pushChildren(st, node.exp, node) pushChildren(st, node.exp, node)
return nil, nil // skip node return nil, nil // skip node
} }
@ -965,8 +968,9 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
case NodeVar: case NodeVar:
ex.Type = ValVar ex.Type = ValVar
default: default:
return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type) return nil, fmt.Errorf("[Where] invalid values for: %s", name)
} }
setWhereColName(ex, node) setWhereColName(ex, node)
} }
@ -1015,6 +1019,7 @@ func setWhereColName(ex *Exp, node *Node) {
ex.Col = list[listlen-1] ex.Col = list[listlen-1]
ex.NestedCols = list[:listlen] ex.NestedCols = list[:listlen]
} }
} }
func setOrderByColName(ob *OrderBy, node *Node) { func setOrderByColName(ob *OrderBy, node *Node) {

View File

@ -45,6 +45,16 @@ func initConf() (*Config, error) {
logLevel = LogLevelNone logLevel = LogLevelNone
} }
// copy over db_schema from the core config
if c.DB.Schema == "" {
c.DB.Schema = c.DBSchema
}
// set default database schema
if c.DB.Schema == "" {
c.DB.Schema = "public"
}
// Auths: validate and sanitize // Auths: validate and sanitize
am := make(map[string]struct{}) am := make(map[string]struct{})