fix: postgres schema name config value is not used

This commit is contained in:
Vikram Rangnekar 2020-05-20 00:03:05 -04:00
parent 94fa51ffb2
commit ab8566df03
6 changed files with 58 additions and 17 deletions

View File

@ -210,3 +210,15 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
func (sg *SuperGraph) GraphQLSchema() (string, error) {
return sg.ge.Schema.String(), nil
}
// Operation function return the operation type from the query. It uses a very fast algorithm to
// extract the operation without having to parse the query.
func Operation(query string) OpType {
return OpType(qcode.GetQType(query))
}
// Name function return the operation name from the query. It uses a very fast algorithm to
// extract the operation name without having to parse the query.
func Name(query string) string {
return allow.QueryName(query)
}

View File

@ -61,6 +61,9 @@ type Config struct {
// Inflections is to add additionally singular to plural mappings
// to the engine (eg. sheep: sheep)
Inflections map[string]string `mapstructure:"inflections"`
// Database schema name. Defaults to 'public'
DBSchema string `mapstructure:"db_schema"`
}
// Table struct defines a database table

View File

@ -14,8 +14,10 @@ import (
"github.com/valyala/fasttemplate"
)
type OpType int
const (
OpQuery int = iota
OpQuery OpType = iota
OpMutation
)
@ -56,16 +58,27 @@ type scontext struct {
func (sg *SuperGraph) initCompilers() error {
var err error
var schema string
if sg.conf.DBSchema == "" {
schema = "public"
} else {
schema = sg.conf.DBSchema
}
// If sg.di is not null then it's probably set
// for tests
if sg.dbinfo == nil {
sg.dbinfo, err = psql.GetDBInfo(sg.db)
sg.dbinfo, err = psql.GetDBInfo(sg.db, schema)
if err != nil {
return err
}
}
if len(sg.dbinfo.Tables) == 0 {
return fmt.Errorf("no tables found in database (schema: %s)", schema)
}
if err = addTables(sg.conf, sg.dbinfo); err != nil {
return err
}
@ -334,7 +347,7 @@ func (c *scontext) executeRoleQuery(tx *sql.Tx) (string, error) {
return role, nil
}
func (r *Result) Operation() int {
func (r *Result) Operation() OpType {
switch r.op {
case qcode.QTQuery:
return OpQuery

View File

@ -17,7 +17,7 @@ type DBInfo struct {
colMap map[string]map[string]*DBColumn
}
func GetDBInfo(db *sql.DB) (*DBInfo, error) {
func GetDBInfo(db *sql.DB, schema string) (*DBInfo, error) {
di := &DBInfo{}
var version string
@ -31,13 +31,13 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
return nil, err
}
di.Tables, err = GetTables(db)
di.Tables, err = GetTables(db, schema)
if err != nil {
return nil, err
}
for _, t := range di.Tables {
cols, err := GetColumns(db, "public", t.Name)
cols, err := GetColumns(db, schema, t.Name)
if err != nil {
return nil, err
}
@ -47,7 +47,7 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
di.colMap = newColMap(di.Tables, di.Columns)
di.Functions, err = GetFunctions(db)
di.Functions, err = GetFunctions(db, schema)
if err != nil {
return nil, err
}
@ -96,7 +96,7 @@ type DBTable struct {
Type string
}
func GetTables(db *sql.DB) ([]DBTable, error) {
func GetTables(db *sql.DB, schema string) ([]DBTable, error) {
sqlStmt := `
SELECT
c.relname as "name",
@ -108,14 +108,12 @@ SELECT
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('r','v','m','f','')
AND n.nspname <> ('pg_catalog')
AND n.nspname <> ('information_schema')
AND n.nspname !~ ('^pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid);`
AND n.nspname = $1
AND pg_catalog.pg_table_is_visible(c.oid);`
var tables []DBTable
rows, err := db.Query(sqlStmt)
rows, err := db.Query(sqlStmt, schema)
if err != nil {
return nil, fmt.Errorf("Error fetching tables: %s", err)
}
@ -264,7 +262,7 @@ type DBFuncParam struct {
Type string
}
func GetFunctions(db *sql.DB) ([]DBFunction, error) {
func GetFunctions(db *sql.DB, schema string) ([]DBFunction, error) {
sqlStmt := `
SELECT
routines.routine_name,
@ -278,11 +276,11 @@ RIGHT JOIN
information_schema.parameters
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
WHERE
routines.specific_schema = 'public'
routines.specific_schema = $1
ORDER BY
routines.routine_name, parameters.ordinal_position;`
rows, err := db.Query(sqlStmt)
rows, err := db.Query(sqlStmt, schema)
if err != nil {
return nil, fmt.Errorf("Error fetching functions: %s", err)
}

View File

@ -946,6 +946,9 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
ex.Op = OpDistinct
ex.Val = node.Val
default:
if len(node.Children) == 0 {
return nil, fmt.Errorf("[Where] invalid operation: %s", name)
}
pushChildren(st, node.exp, node)
return nil, nil // skip node
}
@ -965,8 +968,9 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
case NodeVar:
ex.Type = ValVar
default:
return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type)
return nil, fmt.Errorf("[Where] invalid values for: %s", name)
}
setWhereColName(ex, node)
}
@ -1015,6 +1019,7 @@ func setWhereColName(ex *Exp, node *Node) {
ex.Col = list[listlen-1]
ex.NestedCols = list[:listlen]
}
}
func setOrderByColName(ob *OrderBy, node *Node) {

View File

@ -45,6 +45,16 @@ func initConf() (*Config, error) {
logLevel = LogLevelNone
}
// copy over db_schema from the core config
if c.DB.Schema == "" {
c.DB.Schema = c.DBSchema
}
// set default database schema
if c.DB.Schema == "" {
c.DB.Schema = "public"
}
// Auths: validate and sanitize
am := make(map[string]struct{})