feat: add support for single argument Postgres functions

This commit is contained in:
Vikram Rangnekar
2020-04-22 20:51:14 -04:00
parent 6293d37e73
commit ae7cde0433
9 changed files with 167 additions and 60 deletions

View File

@ -93,6 +93,7 @@ type SuperGraph struct {
anonExists bool
qc *qcode.Compiler
pc *psql.Compiler
ge *graphql.Engine
}
// NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its
@ -124,6 +125,10 @@ func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) {
return nil, err
}
if err := sg.initGraphQLEgine(); err != nil {
return nil, err
}
if len(conf.SecretKey) != 0 {
sk := sha256.Sum256([]byte(conf.SecretKey))
conf.SecretKey = ""
@ -163,14 +168,9 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
// use the chirino/graphql library for introspection queries
// disabled when allow list is enforced
if !sg.conf.UseAllowList && res.name == "IntrospectionQuery" {
engine, err := sg.createGraphQLEgine()
if err != nil {
res.Error = err.Error()
return &res, err
}
r := engine.ExecuteOne(&graphql.EngineRequest{Query: query})
r := sg.ge.ExecuteOne(&graphql.EngineRequest{Query: query})
res.Data = r.Data
if r.Error() != nil {
res.Error = r.Error().Error()
}
@ -200,9 +200,5 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
}
func (sg *SuperGraph) GraphQLSchema() (string, error) {
engine, err := sg.createGraphQLEgine()
if err != nil {
return "", err
}
return engine.Schema.String(), nil
return sg.ge.Schema.String(), nil
}

View File

@ -1,8 +1,6 @@
package core
import (
"errors"
"regexp"
"strings"
"github.com/chirino/graphql"
@ -26,7 +24,7 @@ var typeMap map[string]string = map[string]string{
"boolean": "Boolean",
}
func (sg *SuperGraph) createGraphQLEgine() (*graphql.Engine, error) {
func (sg *SuperGraph) initGraphQLEgine() error {
engine := graphql.New()
engineSchema := engine.Schema
dbSchema := sg.schema
@ -63,15 +61,16 @@ enum OrderDirection {
engineSchema.EntryPoints[schema.Query] = query
engineSchema.EntryPoints[schema.Mutation] = mutation
validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
//validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
scalarExpressionTypesNeeded := map[string]bool{}
tableNames := dbSchema.GetTableNames()
for _, table := range tableNames {
funcs := dbSchema.GetFunctions()
for _, table := range tableNames {
ti, err := dbSchema.GetTable(table)
if err != nil {
return nil, err
return err
}
if !ti.IsSingular {
@ -79,13 +78,13 @@ enum OrderDirection {
}
singularName := ti.Singular
if !validGraphQLIdentifierRegex.MatchString(singularName) {
return nil, errors.New("table name is not a valid GraphQL identifier: " + singularName)
}
// if !validGraphQLIdentifierRegex.MatchString(singularName) {
// return errors.New("table name is not a valid GraphQL identifier: " + singularName)
// }
pluralName := ti.Plural
if !validGraphQLIdentifierRegex.MatchString(pluralName) {
return nil, errors.New("table name is not a valid GraphQL identifier: " + pluralName)
}
// if !validGraphQLIdentifierRegex.MatchString(pluralName) {
// return errors.New("table name is not a valid GraphQL identifier: " + pluralName)
// }
outputType := &schema.Object{
Name: singularName + "Output",
@ -127,9 +126,9 @@ enum OrderDirection {
for _, col := range ti.Columns {
colName := col.Name
if !validGraphQLIdentifierRegex.MatchString(colName) {
return nil, errors.New("column name is not a valid GraphQL identifier: " + colName)
}
// if !validGraphQLIdentifierRegex.MatchString(colName) {
// return errors.New("column name is not a valid GraphQL identifier: " + colName)
// }
colType := gqltype(col)
nullableColType := ""
@ -144,6 +143,16 @@ enum OrderDirection {
Type: colType,
})
for _, f := range funcs {
if col.Type != f.Params[0].Type {
continue
}
outputType.Fields = append(outputType.Fields, &schema.Field{
Name: f.Name + "_" + colName,
Type: colType,
})
}
// If it's a numeric type...
if nullableColType == "Float" || nullableColType == "Int" {
outputType.Fields = append(outputType.Fields, &schema.Field{
@ -464,7 +473,7 @@ enum OrderDirection {
err := engineSchema.ResolveTypes()
if err != nil {
return nil, err
return err
}
engine.Resolver = resolvers.Func(func(request *resolvers.ResolveRequest, next resolvers.Resolution) resolvers.Resolution {
@ -479,5 +488,7 @@ enum OrderDirection {
return nil
})
return engine, nil
sg.ge = engine
return nil
}

View File

@ -167,7 +167,7 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
}
func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInfo, col qcode.Column, columnsRendered int) error {
pl := funcPrefixLen(col.Name)
pl := funcPrefixLen(c.schema.fm, col.Name)
// if pl == 0 {
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
// io.WriteString(c.w, `'`)

View File

@ -543,7 +543,7 @@ func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo, skip
var cn string
for _, col := range sel.Cols {
if n := funcPrefixLen(col.Name); n != 0 {
if n := funcPrefixLen(c.schema.fm, col.Name); n != 0 {
if !sel.Functions {
continue
}
@ -1193,7 +1193,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
io.WriteString(c.w, col.Type)
}
func funcPrefixLen(fn string) int {
func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
switch {
case strings.HasPrefix(fn, "avg_"):
return 4
@ -1218,6 +1218,14 @@ func funcPrefixLen(fn string) int {
case strings.HasPrefix(fn, "var_samp_"):
return 9
}
fnLen := len(fn)
for k := range fm {
kLen := len(k)
if kLen < fnLen && k[0] == fn[0] && strings.HasPrefix(fn, k) && fn[kLen] == '_' {
return kLen + 1
}
}
return 0
}

View File

@ -11,6 +11,7 @@ type DBSchema struct {
ver int
t map[string]*DBTableInfo
rm map[string]map[string]*DBRel
fm map[string]*DBFunction
}
type DBTableInfo struct {
@ -58,6 +59,7 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
schema := &DBSchema{
t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel),
fm: make(map[string]*DBFunction, len(info.Functions)),
}
for i, t := range info.Tables {
@ -81,6 +83,12 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
}
}
for k, f := range info.Functions {
if len(f.Params) == 1 {
schema.fm[strings.ToLower(f.Name)] = &info.Functions[k]
}
}
return schema, nil
}
@ -439,3 +447,11 @@ func (s *DBSchema) GetRel(child, parent string) (*DBRel, error) {
}
return rel, nil
}
func (s *DBSchema) GetFunctions() []*DBFunction {
var funcs []*DBFunction
for _, f := range s.fm {
funcs = append(funcs, f)
}
return funcs
}

View File

@ -10,10 +10,11 @@ import (
)
type DBInfo struct {
Version int
Tables []DBTable
Columns [][]DBColumn
colmap map[string]map[string]*DBColumn
Version int
Tables []DBTable
Columns [][]DBColumn
Functions []DBFunction
colmap map[string]map[string]*DBColumn
}
func GetDBInfo(db *sql.DB) (*DBInfo, error) {
@ -51,6 +52,11 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
}
}
di.Functions, err = GetFunctions(db)
if err != nil {
return nil, err
}
return di, nil
}
@ -237,6 +243,64 @@ ORDER BY id;`
return cols, nil
}
type DBFunction struct {
Name string
Params []DBFuncParam
}
type DBFuncParam struct {
ID int
Name string
Type string
}
func GetFunctions(db *sql.DB) ([]DBFunction, error) {
sqlStmt := `
SELECT
routines.routine_name,
parameters.specific_name,
parameters.data_type,
parameters.parameter_name,
parameters.ordinal_position
FROM
information_schema.routines
RIGHT JOIN
information_schema.parameters
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
WHERE
routines.specific_schema = 'public'
ORDER BY
routines.routine_name, parameters.ordinal_position;`
rows, err := db.Query(sqlStmt)
if err != nil {
return nil, fmt.Errorf("Error fetching functions: %s", err)
}
defer rows.Close()
var funcs []DBFunction
fm := make(map[string]int)
for rows.Next() {
var fn, fid string
fp := DBFuncParam{}
err = rows.Scan(&fn, &fid, &fp.Type, &fp.Name, &fp.ID)
if err != nil {
return nil, err
}
if i, ok := fm[fid]; ok {
funcs[i].Params = append(funcs[i].Params, fp)
} else {
funcs = append(funcs, DBFunction{Name: fn, Params: []DBFuncParam{fp}})
fm[fid] = len(funcs) - 1
}
}
return funcs, nil
}
// func GetValType(type string) qcode.ValType {
// switch {
// case "bigint", "integer", "smallint", "numeric", "bigserial":