fix: postgres schema name config value is not used
This commit is contained in:
parent
94fa51ffb2
commit
ab8566df03
12
core/api.go
12
core/api.go
@ -210,3 +210,15 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
|
||||
func (sg *SuperGraph) GraphQLSchema() (string, error) {
|
||||
return sg.ge.Schema.String(), nil
|
||||
}
|
||||
|
||||
// Operation function return the operation type from the query. It uses a very fast algorithm to
|
||||
// extract the operation without having to parse the query.
|
||||
func Operation(query string) OpType {
|
||||
return OpType(qcode.GetQType(query))
|
||||
}
|
||||
|
||||
// Name function return the operation name from the query. It uses a very fast algorithm to
|
||||
// extract the operation name without having to parse the query.
|
||||
func Name(query string) string {
|
||||
return allow.QueryName(query)
|
||||
}
|
||||
|
@ -61,6 +61,9 @@ type Config struct {
|
||||
// Inflections is to add additionally singular to plural mappings
|
||||
// to the engine (eg. sheep: sheep)
|
||||
Inflections map[string]string `mapstructure:"inflections"`
|
||||
|
||||
// Database schema name. Defaults to 'public'
|
||||
DBSchema string `mapstructure:"db_schema"`
|
||||
}
|
||||
|
||||
// Table struct defines a database table
|
||||
|
19
core/core.go
19
core/core.go
@ -14,8 +14,10 @@ import (
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
type OpType int
|
||||
|
||||
const (
|
||||
OpQuery int = iota
|
||||
OpQuery OpType = iota
|
||||
OpMutation
|
||||
)
|
||||
|
||||
@ -56,16 +58,27 @@ type scontext struct {
|
||||
|
||||
func (sg *SuperGraph) initCompilers() error {
|
||||
var err error
|
||||
var schema string
|
||||
|
||||
if sg.conf.DBSchema == "" {
|
||||
schema = "public"
|
||||
} else {
|
||||
schema = sg.conf.DBSchema
|
||||
}
|
||||
|
||||
// If sg.di is not null then it's probably set
|
||||
// for tests
|
||||
if sg.dbinfo == nil {
|
||||
sg.dbinfo, err = psql.GetDBInfo(sg.db)
|
||||
sg.dbinfo, err = psql.GetDBInfo(sg.db, schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(sg.dbinfo.Tables) == 0 {
|
||||
return fmt.Errorf("no tables found in database (schema: %s)", schema)
|
||||
}
|
||||
|
||||
if err = addTables(sg.conf, sg.dbinfo); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -334,7 +347,7 @@ func (c *scontext) executeRoleQuery(tx *sql.Tx) (string, error) {
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func (r *Result) Operation() int {
|
||||
func (r *Result) Operation() OpType {
|
||||
switch r.op {
|
||||
case qcode.QTQuery:
|
||||
return OpQuery
|
||||
|
@ -17,7 +17,7 @@ type DBInfo struct {
|
||||
colMap map[string]map[string]*DBColumn
|
||||
}
|
||||
|
||||
func GetDBInfo(db *sql.DB) (*DBInfo, error) {
|
||||
func GetDBInfo(db *sql.DB, schema string) (*DBInfo, error) {
|
||||
di := &DBInfo{}
|
||||
var version string
|
||||
|
||||
@ -31,13 +31,13 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
di.Tables, err = GetTables(db)
|
||||
di.Tables, err = GetTables(db, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, t := range di.Tables {
|
||||
cols, err := GetColumns(db, "public", t.Name)
|
||||
cols, err := GetColumns(db, schema, t.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -47,7 +47,7 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
|
||||
|
||||
di.colMap = newColMap(di.Tables, di.Columns)
|
||||
|
||||
di.Functions, err = GetFunctions(db)
|
||||
di.Functions, err = GetFunctions(db, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -96,7 +96,7 @@ type DBTable struct {
|
||||
Type string
|
||||
}
|
||||
|
||||
func GetTables(db *sql.DB) ([]DBTable, error) {
|
||||
func GetTables(db *sql.DB, schema string) ([]DBTable, error) {
|
||||
sqlStmt := `
|
||||
SELECT
|
||||
c.relname as "name",
|
||||
@ -108,14 +108,12 @@ SELECT
|
||||
FROM pg_catalog.pg_class c
|
||||
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE c.relkind IN ('r','v','m','f','')
|
||||
AND n.nspname <> ('pg_catalog')
|
||||
AND n.nspname <> ('information_schema')
|
||||
AND n.nspname !~ ('^pg_toast')
|
||||
AND pg_catalog.pg_table_is_visible(c.oid);`
|
||||
AND n.nspname = $1
|
||||
AND pg_catalog.pg_table_is_visible(c.oid);`
|
||||
|
||||
var tables []DBTable
|
||||
|
||||
rows, err := db.Query(sqlStmt)
|
||||
rows, err := db.Query(sqlStmt, schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error fetching tables: %s", err)
|
||||
}
|
||||
@ -264,7 +262,7 @@ type DBFuncParam struct {
|
||||
Type string
|
||||
}
|
||||
|
||||
func GetFunctions(db *sql.DB) ([]DBFunction, error) {
|
||||
func GetFunctions(db *sql.DB, schema string) ([]DBFunction, error) {
|
||||
sqlStmt := `
|
||||
SELECT
|
||||
routines.routine_name,
|
||||
@ -278,11 +276,11 @@ RIGHT JOIN
|
||||
information_schema.parameters
|
||||
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
|
||||
WHERE
|
||||
routines.specific_schema = 'public'
|
||||
routines.specific_schema = $1
|
||||
ORDER BY
|
||||
routines.routine_name, parameters.ordinal_position;`
|
||||
|
||||
rows, err := db.Query(sqlStmt)
|
||||
rows, err := db.Query(sqlStmt, schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error fetching functions: %s", err)
|
||||
}
|
||||
|
@ -946,6 +946,9 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
|
||||
ex.Op = OpDistinct
|
||||
ex.Val = node.Val
|
||||
default:
|
||||
if len(node.Children) == 0 {
|
||||
return nil, fmt.Errorf("[Where] invalid operation: %s", name)
|
||||
}
|
||||
pushChildren(st, node.exp, node)
|
||||
return nil, nil // skip node
|
||||
}
|
||||
@ -965,8 +968,9 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
|
||||
case NodeVar:
|
||||
ex.Type = ValVar
|
||||
default:
|
||||
return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type)
|
||||
return nil, fmt.Errorf("[Where] invalid values for: %s", name)
|
||||
}
|
||||
|
||||
setWhereColName(ex, node)
|
||||
}
|
||||
|
||||
@ -1015,6 +1019,7 @@ func setWhereColName(ex *Exp, node *Node) {
|
||||
ex.Col = list[listlen-1]
|
||||
ex.NestedCols = list[:listlen]
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func setOrderByColName(ob *OrderBy, node *Node) {
|
||||
|
@ -45,6 +45,16 @@ func initConf() (*Config, error) {
|
||||
logLevel = LogLevelNone
|
||||
}
|
||||
|
||||
// copy over db_schema from the core config
|
||||
if c.DB.Schema == "" {
|
||||
c.DB.Schema = c.DBSchema
|
||||
}
|
||||
|
||||
// set default database schema
|
||||
if c.DB.Schema == "" {
|
||||
c.DB.Schema = "public"
|
||||
}
|
||||
|
||||
// Auths: validate and sanitize
|
||||
am := make(map[string]struct{})
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user