fix: postgres schema name config value is not used
This commit is contained in:
parent
94fa51ffb2
commit
ab8566df03
12
core/api.go
12
core/api.go
|
@ -210,3 +210,15 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
|
||||||
func (sg *SuperGraph) GraphQLSchema() (string, error) {
|
func (sg *SuperGraph) GraphQLSchema() (string, error) {
|
||||||
return sg.ge.Schema.String(), nil
|
return sg.ge.Schema.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Operation function return the operation type from the query. It uses a very fast algorithm to
|
||||||
|
// extract the operation without having to parse the query.
|
||||||
|
func Operation(query string) OpType {
|
||||||
|
return OpType(qcode.GetQType(query))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name function return the operation name from the query. It uses a very fast algorithm to
|
||||||
|
// extract the operation name without having to parse the query.
|
||||||
|
func Name(query string) string {
|
||||||
|
return allow.QueryName(query)
|
||||||
|
}
|
||||||
|
|
|
@ -61,6 +61,9 @@ type Config struct {
|
||||||
// Inflections is to add additionally singular to plural mappings
|
// Inflections is to add additionally singular to plural mappings
|
||||||
// to the engine (eg. sheep: sheep)
|
// to the engine (eg. sheep: sheep)
|
||||||
Inflections map[string]string `mapstructure:"inflections"`
|
Inflections map[string]string `mapstructure:"inflections"`
|
||||||
|
|
||||||
|
// Database schema name. Defaults to 'public'
|
||||||
|
DBSchema string `mapstructure:"db_schema"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Table struct defines a database table
|
// Table struct defines a database table
|
||||||
|
|
19
core/core.go
19
core/core.go
|
@ -14,8 +14,10 @@ import (
|
||||||
"github.com/valyala/fasttemplate"
|
"github.com/valyala/fasttemplate"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OpType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
OpQuery int = iota
|
OpQuery OpType = iota
|
||||||
OpMutation
|
OpMutation
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -56,16 +58,27 @@ type scontext struct {
|
||||||
|
|
||||||
func (sg *SuperGraph) initCompilers() error {
|
func (sg *SuperGraph) initCompilers() error {
|
||||||
var err error
|
var err error
|
||||||
|
var schema string
|
||||||
|
|
||||||
|
if sg.conf.DBSchema == "" {
|
||||||
|
schema = "public"
|
||||||
|
} else {
|
||||||
|
schema = sg.conf.DBSchema
|
||||||
|
}
|
||||||
|
|
||||||
// If sg.di is not null then it's probably set
|
// If sg.di is not null then it's probably set
|
||||||
// for tests
|
// for tests
|
||||||
if sg.dbinfo == nil {
|
if sg.dbinfo == nil {
|
||||||
sg.dbinfo, err = psql.GetDBInfo(sg.db)
|
sg.dbinfo, err = psql.GetDBInfo(sg.db, schema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(sg.dbinfo.Tables) == 0 {
|
||||||
|
return fmt.Errorf("no tables found in database (schema: %s)", schema)
|
||||||
|
}
|
||||||
|
|
||||||
if err = addTables(sg.conf, sg.dbinfo); err != nil {
|
if err = addTables(sg.conf, sg.dbinfo); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -334,7 +347,7 @@ func (c *scontext) executeRoleQuery(tx *sql.Tx) (string, error) {
|
||||||
return role, nil
|
return role, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Result) Operation() int {
|
func (r *Result) Operation() OpType {
|
||||||
switch r.op {
|
switch r.op {
|
||||||
case qcode.QTQuery:
|
case qcode.QTQuery:
|
||||||
return OpQuery
|
return OpQuery
|
||||||
|
|
|
@ -17,7 +17,7 @@ type DBInfo struct {
|
||||||
colMap map[string]map[string]*DBColumn
|
colMap map[string]map[string]*DBColumn
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetDBInfo(db *sql.DB) (*DBInfo, error) {
|
func GetDBInfo(db *sql.DB, schema string) (*DBInfo, error) {
|
||||||
di := &DBInfo{}
|
di := &DBInfo{}
|
||||||
var version string
|
var version string
|
||||||
|
|
||||||
|
@ -31,13 +31,13 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
di.Tables, err = GetTables(db)
|
di.Tables, err = GetTables(db, schema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range di.Tables {
|
for _, t := range di.Tables {
|
||||||
cols, err := GetColumns(db, "public", t.Name)
|
cols, err := GetColumns(db, schema, t.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,7 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
|
||||||
|
|
||||||
di.colMap = newColMap(di.Tables, di.Columns)
|
di.colMap = newColMap(di.Tables, di.Columns)
|
||||||
|
|
||||||
di.Functions, err = GetFunctions(db)
|
di.Functions, err = GetFunctions(db, schema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -96,7 +96,7 @@ type DBTable struct {
|
||||||
Type string
|
Type string
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTables(db *sql.DB) ([]DBTable, error) {
|
func GetTables(db *sql.DB, schema string) ([]DBTable, error) {
|
||||||
sqlStmt := `
|
sqlStmt := `
|
||||||
SELECT
|
SELECT
|
||||||
c.relname as "name",
|
c.relname as "name",
|
||||||
|
@ -108,14 +108,12 @@ SELECT
|
||||||
FROM pg_catalog.pg_class c
|
FROM pg_catalog.pg_class c
|
||||||
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
||||||
WHERE c.relkind IN ('r','v','m','f','')
|
WHERE c.relkind IN ('r','v','m','f','')
|
||||||
AND n.nspname <> ('pg_catalog')
|
AND n.nspname = $1
|
||||||
AND n.nspname <> ('information_schema')
|
AND pg_catalog.pg_table_is_visible(c.oid);`
|
||||||
AND n.nspname !~ ('^pg_toast')
|
|
||||||
AND pg_catalog.pg_table_is_visible(c.oid);`
|
|
||||||
|
|
||||||
var tables []DBTable
|
var tables []DBTable
|
||||||
|
|
||||||
rows, err := db.Query(sqlStmt)
|
rows, err := db.Query(sqlStmt, schema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error fetching tables: %s", err)
|
return nil, fmt.Errorf("Error fetching tables: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -264,7 +262,7 @@ type DBFuncParam struct {
|
||||||
Type string
|
Type string
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetFunctions(db *sql.DB) ([]DBFunction, error) {
|
func GetFunctions(db *sql.DB, schema string) ([]DBFunction, error) {
|
||||||
sqlStmt := `
|
sqlStmt := `
|
||||||
SELECT
|
SELECT
|
||||||
routines.routine_name,
|
routines.routine_name,
|
||||||
|
@ -278,11 +276,11 @@ RIGHT JOIN
|
||||||
information_schema.parameters
|
information_schema.parameters
|
||||||
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
|
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
|
||||||
WHERE
|
WHERE
|
||||||
routines.specific_schema = 'public'
|
routines.specific_schema = $1
|
||||||
ORDER BY
|
ORDER BY
|
||||||
routines.routine_name, parameters.ordinal_position;`
|
routines.routine_name, parameters.ordinal_position;`
|
||||||
|
|
||||||
rows, err := db.Query(sqlStmt)
|
rows, err := db.Query(sqlStmt, schema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error fetching functions: %s", err)
|
return nil, fmt.Errorf("Error fetching functions: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -946,6 +946,9 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
|
||||||
ex.Op = OpDistinct
|
ex.Op = OpDistinct
|
||||||
ex.Val = node.Val
|
ex.Val = node.Val
|
||||||
default:
|
default:
|
||||||
|
if len(node.Children) == 0 {
|
||||||
|
return nil, fmt.Errorf("[Where] invalid operation: %s", name)
|
||||||
|
}
|
||||||
pushChildren(st, node.exp, node)
|
pushChildren(st, node.exp, node)
|
||||||
return nil, nil // skip node
|
return nil, nil // skip node
|
||||||
}
|
}
|
||||||
|
@ -965,8 +968,9 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
|
||||||
case NodeVar:
|
case NodeVar:
|
||||||
ex.Type = ValVar
|
ex.Type = ValVar
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type)
|
return nil, fmt.Errorf("[Where] invalid values for: %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
setWhereColName(ex, node)
|
setWhereColName(ex, node)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1015,6 +1019,7 @@ func setWhereColName(ex *Exp, node *Node) {
|
||||||
ex.Col = list[listlen-1]
|
ex.Col = list[listlen-1]
|
||||||
ex.NestedCols = list[:listlen]
|
ex.NestedCols = list[:listlen]
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func setOrderByColName(ob *OrderBy, node *Node) {
|
func setOrderByColName(ob *OrderBy, node *Node) {
|
||||||
|
|
|
@ -45,6 +45,16 @@ func initConf() (*Config, error) {
|
||||||
logLevel = LogLevelNone
|
logLevel = LogLevelNone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copy over db_schema from the core config
|
||||||
|
if c.DB.Schema == "" {
|
||||||
|
c.DB.Schema = c.DBSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
// set default database schema
|
||||||
|
if c.DB.Schema == "" {
|
||||||
|
c.DB.Schema = "public"
|
||||||
|
}
|
||||||
|
|
||||||
// Auths: validate and sanitize
|
// Auths: validate and sanitize
|
||||||
am := make(map[string]struct{})
|
am := make(map[string]struct{})
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue