feat: add support for single argument Postgres functions
This commit is contained in:
20
core/api.go
20
core/api.go
@ -93,6 +93,7 @@ type SuperGraph struct {
|
||||
anonExists bool
|
||||
qc *qcode.Compiler
|
||||
pc *psql.Compiler
|
||||
ge *graphql.Engine
|
||||
}
|
||||
|
||||
// NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its
|
||||
@ -124,6 +125,10 @@ func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := sg.initGraphQLEgine(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(conf.SecretKey) != 0 {
|
||||
sk := sha256.Sum256([]byte(conf.SecretKey))
|
||||
conf.SecretKey = ""
|
||||
@ -163,14 +168,9 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
|
||||
// use the chirino/graphql library for introspection queries
|
||||
// disabled when allow list is enforced
|
||||
if !sg.conf.UseAllowList && res.name == "IntrospectionQuery" {
|
||||
engine, err := sg.createGraphQLEgine()
|
||||
if err != nil {
|
||||
res.Error = err.Error()
|
||||
return &res, err
|
||||
}
|
||||
|
||||
r := engine.ExecuteOne(&graphql.EngineRequest{Query: query})
|
||||
r := sg.ge.ExecuteOne(&graphql.EngineRequest{Query: query})
|
||||
res.Data = r.Data
|
||||
|
||||
if r.Error() != nil {
|
||||
res.Error = r.Error().Error()
|
||||
}
|
||||
@ -200,9 +200,5 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
|
||||
}
|
||||
|
||||
func (sg *SuperGraph) GraphQLSchema() (string, error) {
|
||||
engine, err := sg.createGraphQLEgine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return engine.Schema.String(), nil
|
||||
return sg.ge.Schema.String(), nil
|
||||
}
|
||||
|
@ -1,8 +1,6 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/chirino/graphql"
|
||||
@ -26,7 +24,7 @@ var typeMap map[string]string = map[string]string{
|
||||
"boolean": "Boolean",
|
||||
}
|
||||
|
||||
func (sg *SuperGraph) createGraphQLEgine() (*graphql.Engine, error) {
|
||||
func (sg *SuperGraph) initGraphQLEgine() error {
|
||||
engine := graphql.New()
|
||||
engineSchema := engine.Schema
|
||||
dbSchema := sg.schema
|
||||
@ -63,15 +61,16 @@ enum OrderDirection {
|
||||
engineSchema.EntryPoints[schema.Query] = query
|
||||
engineSchema.EntryPoints[schema.Mutation] = mutation
|
||||
|
||||
validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
|
||||
//validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
|
||||
|
||||
scalarExpressionTypesNeeded := map[string]bool{}
|
||||
tableNames := dbSchema.GetTableNames()
|
||||
for _, table := range tableNames {
|
||||
funcs := dbSchema.GetFunctions()
|
||||
|
||||
for _, table := range tableNames {
|
||||
ti, err := dbSchema.GetTable(table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if !ti.IsSingular {
|
||||
@ -79,13 +78,13 @@ enum OrderDirection {
|
||||
}
|
||||
|
||||
singularName := ti.Singular
|
||||
if !validGraphQLIdentifierRegex.MatchString(singularName) {
|
||||
return nil, errors.New("table name is not a valid GraphQL identifier: " + singularName)
|
||||
}
|
||||
// if !validGraphQLIdentifierRegex.MatchString(singularName) {
|
||||
// return errors.New("table name is not a valid GraphQL identifier: " + singularName)
|
||||
// }
|
||||
pluralName := ti.Plural
|
||||
if !validGraphQLIdentifierRegex.MatchString(pluralName) {
|
||||
return nil, errors.New("table name is not a valid GraphQL identifier: " + pluralName)
|
||||
}
|
||||
// if !validGraphQLIdentifierRegex.MatchString(pluralName) {
|
||||
// return errors.New("table name is not a valid GraphQL identifier: " + pluralName)
|
||||
// }
|
||||
|
||||
outputType := &schema.Object{
|
||||
Name: singularName + "Output",
|
||||
@ -127,9 +126,9 @@ enum OrderDirection {
|
||||
|
||||
for _, col := range ti.Columns {
|
||||
colName := col.Name
|
||||
if !validGraphQLIdentifierRegex.MatchString(colName) {
|
||||
return nil, errors.New("column name is not a valid GraphQL identifier: " + colName)
|
||||
}
|
||||
// if !validGraphQLIdentifierRegex.MatchString(colName) {
|
||||
// return errors.New("column name is not a valid GraphQL identifier: " + colName)
|
||||
// }
|
||||
|
||||
colType := gqltype(col)
|
||||
nullableColType := ""
|
||||
@ -144,6 +143,16 @@ enum OrderDirection {
|
||||
Type: colType,
|
||||
})
|
||||
|
||||
for _, f := range funcs {
|
||||
if col.Type != f.Params[0].Type {
|
||||
continue
|
||||
}
|
||||
outputType.Fields = append(outputType.Fields, &schema.Field{
|
||||
Name: f.Name + "_" + colName,
|
||||
Type: colType,
|
||||
})
|
||||
}
|
||||
|
||||
// If it's a numeric type...
|
||||
if nullableColType == "Float" || nullableColType == "Int" {
|
||||
outputType.Fields = append(outputType.Fields, &schema.Field{
|
||||
@ -464,7 +473,7 @@ enum OrderDirection {
|
||||
|
||||
err := engineSchema.ResolveTypes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
engine.Resolver = resolvers.Func(func(request *resolvers.ResolveRequest, next resolvers.Resolution) resolvers.Resolution {
|
||||
@ -479,5 +488,7 @@ enum OrderDirection {
|
||||
|
||||
return nil
|
||||
})
|
||||
return engine, nil
|
||||
|
||||
sg.ge = engine
|
||||
return nil
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInfo, col qcode.Column, columnsRendered int) error {
|
||||
pl := funcPrefixLen(col.Name)
|
||||
pl := funcPrefixLen(c.schema.fm, col.Name)
|
||||
// if pl == 0 {
|
||||
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
|
||||
// io.WriteString(c.w, `'`)
|
||||
|
@ -543,7 +543,7 @@ func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo, skip
|
||||
var cn string
|
||||
|
||||
for _, col := range sel.Cols {
|
||||
if n := funcPrefixLen(col.Name); n != 0 {
|
||||
if n := funcPrefixLen(c.schema.fm, col.Name); n != 0 {
|
||||
if !sel.Functions {
|
||||
continue
|
||||
}
|
||||
@ -1193,7 +1193,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
|
||||
io.WriteString(c.w, col.Type)
|
||||
}
|
||||
|
||||
func funcPrefixLen(fn string) int {
|
||||
func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
|
||||
switch {
|
||||
case strings.HasPrefix(fn, "avg_"):
|
||||
return 4
|
||||
@ -1218,6 +1218,14 @@ func funcPrefixLen(fn string) int {
|
||||
case strings.HasPrefix(fn, "var_samp_"):
|
||||
return 9
|
||||
}
|
||||
fnLen := len(fn)
|
||||
|
||||
for k := range fm {
|
||||
kLen := len(k)
|
||||
if kLen < fnLen && k[0] == fn[0] && strings.HasPrefix(fn, k) && fn[kLen] == '_' {
|
||||
return kLen + 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,7 @@ type DBSchema struct {
|
||||
ver int
|
||||
t map[string]*DBTableInfo
|
||||
rm map[string]map[string]*DBRel
|
||||
fm map[string]*DBFunction
|
||||
}
|
||||
|
||||
type DBTableInfo struct {
|
||||
@ -58,6 +59,7 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
|
||||
schema := &DBSchema{
|
||||
t: make(map[string]*DBTableInfo),
|
||||
rm: make(map[string]map[string]*DBRel),
|
||||
fm: make(map[string]*DBFunction, len(info.Functions)),
|
||||
}
|
||||
|
||||
for i, t := range info.Tables {
|
||||
@ -81,6 +83,12 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
|
||||
}
|
||||
}
|
||||
|
||||
for k, f := range info.Functions {
|
||||
if len(f.Params) == 1 {
|
||||
schema.fm[strings.ToLower(f.Name)] = &info.Functions[k]
|
||||
}
|
||||
}
|
||||
|
||||
return schema, nil
|
||||
}
|
||||
|
||||
@ -439,3 +447,11 @@ func (s *DBSchema) GetRel(child, parent string) (*DBRel, error) {
|
||||
}
|
||||
return rel, nil
|
||||
}
|
||||
|
||||
func (s *DBSchema) GetFunctions() []*DBFunction {
|
||||
var funcs []*DBFunction
|
||||
for _, f := range s.fm {
|
||||
funcs = append(funcs, f)
|
||||
}
|
||||
return funcs
|
||||
}
|
||||
|
@ -10,10 +10,11 @@ import (
|
||||
)
|
||||
|
||||
type DBInfo struct {
|
||||
Version int
|
||||
Tables []DBTable
|
||||
Columns [][]DBColumn
|
||||
colmap map[string]map[string]*DBColumn
|
||||
Version int
|
||||
Tables []DBTable
|
||||
Columns [][]DBColumn
|
||||
Functions []DBFunction
|
||||
colmap map[string]map[string]*DBColumn
|
||||
}
|
||||
|
||||
func GetDBInfo(db *sql.DB) (*DBInfo, error) {
|
||||
@ -51,6 +52,11 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
|
||||
}
|
||||
}
|
||||
|
||||
di.Functions, err = GetFunctions(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return di, nil
|
||||
}
|
||||
|
||||
@ -237,6 +243,64 @@ ORDER BY id;`
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
type DBFunction struct {
|
||||
Name string
|
||||
Params []DBFuncParam
|
||||
}
|
||||
|
||||
type DBFuncParam struct {
|
||||
ID int
|
||||
Name string
|
||||
Type string
|
||||
}
|
||||
|
||||
func GetFunctions(db *sql.DB) ([]DBFunction, error) {
|
||||
sqlStmt := `
|
||||
SELECT
|
||||
routines.routine_name,
|
||||
parameters.specific_name,
|
||||
parameters.data_type,
|
||||
parameters.parameter_name,
|
||||
parameters.ordinal_position
|
||||
FROM
|
||||
information_schema.routines
|
||||
RIGHT JOIN
|
||||
information_schema.parameters
|
||||
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
|
||||
WHERE
|
||||
routines.specific_schema = 'public'
|
||||
ORDER BY
|
||||
routines.routine_name, parameters.ordinal_position;`
|
||||
|
||||
rows, err := db.Query(sqlStmt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error fetching functions: %s", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var funcs []DBFunction
|
||||
fm := make(map[string]int)
|
||||
|
||||
for rows.Next() {
|
||||
var fn, fid string
|
||||
fp := DBFuncParam{}
|
||||
|
||||
err = rows.Scan(&fn, &fid, &fp.Type, &fp.Name, &fp.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if i, ok := fm[fid]; ok {
|
||||
funcs[i].Params = append(funcs[i].Params, fp)
|
||||
} else {
|
||||
funcs = append(funcs, DBFunction{Name: fn, Params: []DBFuncParam{fp}})
|
||||
fm[fid] = len(funcs) - 1
|
||||
}
|
||||
}
|
||||
|
||||
return funcs, nil
|
||||
}
|
||||
|
||||
// func GetValType(type string) qcode.ValType {
|
||||
// switch {
|
||||
// case "bigint", "integer", "smallint", "numeric", "bigserial":
|
||||
|
Reference in New Issue
Block a user