feat: add support for single argument Postgres functions

This commit is contained in:
Vikram Rangnekar 2020-04-22 20:51:14 -04:00
parent 6293d37e73
commit ae7cde0433
9 changed files with 167 additions and 60 deletions

View File

@ -93,6 +93,7 @@ type SuperGraph struct {
anonExists bool anonExists bool
qc *qcode.Compiler qc *qcode.Compiler
pc *psql.Compiler pc *psql.Compiler
ge *graphql.Engine
} }
// NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its // NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its
@ -124,6 +125,10 @@ func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) {
return nil, err return nil, err
} }
if err := sg.initGraphQLEgine(); err != nil {
return nil, err
}
if len(conf.SecretKey) != 0 { if len(conf.SecretKey) != 0 {
sk := sha256.Sum256([]byte(conf.SecretKey)) sk := sha256.Sum256([]byte(conf.SecretKey))
conf.SecretKey = "" conf.SecretKey = ""
@ -163,14 +168,9 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
// use the chirino/graphql library for introspection queries // use the chirino/graphql library for introspection queries
// disabled when allow list is enforced // disabled when allow list is enforced
if !sg.conf.UseAllowList && res.name == "IntrospectionQuery" { if !sg.conf.UseAllowList && res.name == "IntrospectionQuery" {
engine, err := sg.createGraphQLEgine() r := sg.ge.ExecuteOne(&graphql.EngineRequest{Query: query})
if err != nil {
res.Error = err.Error()
return &res, err
}
r := engine.ExecuteOne(&graphql.EngineRequest{Query: query})
res.Data = r.Data res.Data = r.Data
if r.Error() != nil { if r.Error() != nil {
res.Error = r.Error().Error() res.Error = r.Error().Error()
} }
@ -200,9 +200,5 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
} }
func (sg *SuperGraph) GraphQLSchema() (string, error) { func (sg *SuperGraph) GraphQLSchema() (string, error) {
engine, err := sg.createGraphQLEgine() return sg.ge.Schema.String(), nil
if err != nil {
return "", err
}
return engine.Schema.String(), nil
} }

View File

@ -1,8 +1,6 @@
package core package core
import ( import (
"errors"
"regexp"
"strings" "strings"
"github.com/chirino/graphql" "github.com/chirino/graphql"
@ -26,7 +24,7 @@ var typeMap map[string]string = map[string]string{
"boolean": "Boolean", "boolean": "Boolean",
} }
func (sg *SuperGraph) createGraphQLEgine() (*graphql.Engine, error) { func (sg *SuperGraph) initGraphQLEgine() error {
engine := graphql.New() engine := graphql.New()
engineSchema := engine.Schema engineSchema := engine.Schema
dbSchema := sg.schema dbSchema := sg.schema
@ -63,15 +61,16 @@ enum OrderDirection {
engineSchema.EntryPoints[schema.Query] = query engineSchema.EntryPoints[schema.Query] = query
engineSchema.EntryPoints[schema.Mutation] = mutation engineSchema.EntryPoints[schema.Mutation] = mutation
validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`) //validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
scalarExpressionTypesNeeded := map[string]bool{} scalarExpressionTypesNeeded := map[string]bool{}
tableNames := dbSchema.GetTableNames() tableNames := dbSchema.GetTableNames()
for _, table := range tableNames { funcs := dbSchema.GetFunctions()
for _, table := range tableNames {
ti, err := dbSchema.GetTable(table) ti, err := dbSchema.GetTable(table)
if err != nil { if err != nil {
return nil, err return err
} }
if !ti.IsSingular { if !ti.IsSingular {
@ -79,13 +78,13 @@ enum OrderDirection {
} }
singularName := ti.Singular singularName := ti.Singular
if !validGraphQLIdentifierRegex.MatchString(singularName) { // if !validGraphQLIdentifierRegex.MatchString(singularName) {
return nil, errors.New("table name is not a valid GraphQL identifier: " + singularName) // return errors.New("table name is not a valid GraphQL identifier: " + singularName)
} // }
pluralName := ti.Plural pluralName := ti.Plural
if !validGraphQLIdentifierRegex.MatchString(pluralName) { // if !validGraphQLIdentifierRegex.MatchString(pluralName) {
return nil, errors.New("table name is not a valid GraphQL identifier: " + pluralName) // return errors.New("table name is not a valid GraphQL identifier: " + pluralName)
} // }
outputType := &schema.Object{ outputType := &schema.Object{
Name: singularName + "Output", Name: singularName + "Output",
@ -127,9 +126,9 @@ enum OrderDirection {
for _, col := range ti.Columns { for _, col := range ti.Columns {
colName := col.Name colName := col.Name
if !validGraphQLIdentifierRegex.MatchString(colName) { // if !validGraphQLIdentifierRegex.MatchString(colName) {
return nil, errors.New("column name is not a valid GraphQL identifier: " + colName) // return errors.New("column name is not a valid GraphQL identifier: " + colName)
} // }
colType := gqltype(col) colType := gqltype(col)
nullableColType := "" nullableColType := ""
@ -144,6 +143,16 @@ enum OrderDirection {
Type: colType, Type: colType,
}) })
for _, f := range funcs {
if col.Type != f.Params[0].Type {
continue
}
outputType.Fields = append(outputType.Fields, &schema.Field{
Name: f.Name + "_" + colName,
Type: colType,
})
}
// If it's a numeric type... // If it's a numeric type...
if nullableColType == "Float" || nullableColType == "Int" { if nullableColType == "Float" || nullableColType == "Int" {
outputType.Fields = append(outputType.Fields, &schema.Field{ outputType.Fields = append(outputType.Fields, &schema.Field{
@ -464,7 +473,7 @@ enum OrderDirection {
err := engineSchema.ResolveTypes() err := engineSchema.ResolveTypes()
if err != nil { if err != nil {
return nil, err return err
} }
engine.Resolver = resolvers.Func(func(request *resolvers.ResolveRequest, next resolvers.Resolution) resolvers.Resolution { engine.Resolver = resolvers.Func(func(request *resolvers.ResolveRequest, next resolvers.Resolution) resolvers.Resolution {
@ -479,5 +488,7 @@ enum OrderDirection {
return nil return nil
}) })
return engine, nil
sg.ge = engine
return nil
} }

View File

@ -167,7 +167,7 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
} }
func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInfo, col qcode.Column, columnsRendered int) error { func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInfo, col qcode.Column, columnsRendered int) error {
pl := funcPrefixLen(col.Name) pl := funcPrefixLen(c.schema.fm, col.Name)
// if pl == 0 { // if pl == 0 {
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) // //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
// io.WriteString(c.w, `'`) // io.WriteString(c.w, `'`)

View File

@ -543,7 +543,7 @@ func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo, skip
var cn string var cn string
for _, col := range sel.Cols { for _, col := range sel.Cols {
if n := funcPrefixLen(col.Name); n != 0 { if n := funcPrefixLen(c.schema.fm, col.Name); n != 0 {
if !sel.Functions { if !sel.Functions {
continue continue
} }
@ -1193,7 +1193,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
io.WriteString(c.w, col.Type) io.WriteString(c.w, col.Type)
} }
func funcPrefixLen(fn string) int { func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
switch { switch {
case strings.HasPrefix(fn, "avg_"): case strings.HasPrefix(fn, "avg_"):
return 4 return 4
@ -1218,6 +1218,14 @@ func funcPrefixLen(fn string) int {
case strings.HasPrefix(fn, "var_samp_"): case strings.HasPrefix(fn, "var_samp_"):
return 9 return 9
} }
fnLen := len(fn)
for k := range fm {
kLen := len(k)
if kLen < fnLen && k[0] == fn[0] && strings.HasPrefix(fn, k) && fn[kLen] == '_' {
return kLen + 1
}
}
return 0 return 0
} }

View File

@ -11,6 +11,7 @@ type DBSchema struct {
ver int ver int
t map[string]*DBTableInfo t map[string]*DBTableInfo
rm map[string]map[string]*DBRel rm map[string]map[string]*DBRel
fm map[string]*DBFunction
} }
type DBTableInfo struct { type DBTableInfo struct {
@ -58,6 +59,7 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
schema := &DBSchema{ schema := &DBSchema{
t: make(map[string]*DBTableInfo), t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel), rm: make(map[string]map[string]*DBRel),
fm: make(map[string]*DBFunction, len(info.Functions)),
} }
for i, t := range info.Tables { for i, t := range info.Tables {
@ -81,6 +83,12 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
} }
} }
for k, f := range info.Functions {
if len(f.Params) == 1 {
schema.fm[strings.ToLower(f.Name)] = &info.Functions[k]
}
}
return schema, nil return schema, nil
} }
@ -439,3 +447,11 @@ func (s *DBSchema) GetRel(child, parent string) (*DBRel, error) {
} }
return rel, nil return rel, nil
} }
func (s *DBSchema) GetFunctions() []*DBFunction {
var funcs []*DBFunction
for _, f := range s.fm {
funcs = append(funcs, f)
}
return funcs
}

View File

@ -13,6 +13,7 @@ type DBInfo struct {
Version int Version int
Tables []DBTable Tables []DBTable
Columns [][]DBColumn Columns [][]DBColumn
Functions []DBFunction
colmap map[string]map[string]*DBColumn colmap map[string]map[string]*DBColumn
} }
@ -51,6 +52,11 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
} }
} }
di.Functions, err = GetFunctions(db)
if err != nil {
return nil, err
}
return di, nil return di, nil
} }
@ -237,6 +243,64 @@ ORDER BY id;`
return cols, nil return cols, nil
} }
type DBFunction struct {
Name string
Params []DBFuncParam
}
type DBFuncParam struct {
ID int
Name string
Type string
}
func GetFunctions(db *sql.DB) ([]DBFunction, error) {
sqlStmt := `
SELECT
routines.routine_name,
parameters.specific_name,
parameters.data_type,
parameters.parameter_name,
parameters.ordinal_position
FROM
information_schema.routines
RIGHT JOIN
information_schema.parameters
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
WHERE
routines.specific_schema = 'public'
ORDER BY
routines.routine_name, parameters.ordinal_position;`
rows, err := db.Query(sqlStmt)
if err != nil {
return nil, fmt.Errorf("Error fetching functions: %s", err)
}
defer rows.Close()
var funcs []DBFunction
fm := make(map[string]int)
for rows.Next() {
var fn, fid string
fp := DBFuncParam{}
err = rows.Scan(&fn, &fid, &fp.Type, &fp.Name, &fp.ID)
if err != nil {
return nil, err
}
if i, ok := fm[fid]; ok {
funcs[i].Params = append(funcs[i].Params, fp)
} else {
funcs = append(funcs, DBFunction{Name: fn, Params: []DBFuncParam{fp}})
fm[fid] = len(funcs) - 1
}
}
return funcs, nil
}
// func GetValType(type string) qcode.ValType { // func GetValType(type string) qcode.ValType {
// switch { // switch {
// case "bigint", "integer", "smallint", "numeric", "bigserial": // case "bigint", "integer", "smallint", "numeric", "bigserial":

View File

@ -730,6 +730,32 @@ query {
} }
``` ```
### Custom Functions
Any function defined in the database like the below `add_five` that adds 5 to any number given to it can be used
within your query. The one limitation is that it should be a function that only accepts a single argument. The function is used within you're GraphQL in similar way to how aggregrations are used above. Example below
```grahql
query {
thread(id: 5) {
id
total_votes
add_five_total_votes
}
}
```
Postgres user-defined function `add_five`
```
CREATE OR REPLACE FUNCTION add_five(a integer) RETURNS integer AS $$
BEGIN
RETURN a + 5;
END;
$$ LANGUAGE plpgsql;
```
In GraphQL mutations is the operation type for when you need to modify data. Super Graph supports the `insert`, `update`, `upsert` and `delete`. You can also do complex nested inserts and updates. In GraphQL mutations is the operation type for when you need to modify data. Super Graph supports the `insert`, `update`, `upsert` and `delete`. You can also do complex nested inserts and updates.
When using mutations the data must be passed as variables since Super Graphs compiles the query into an prepared statement in the database for maximum speed. Prepared statements are are functions in your code when called they accept arguments and your variables are passed in as those arguments. When using mutations the data must be passed as variables since Super Graphs compiles the query into an prepared statement in the database for maximum speed. Prepared statements are are functions in your code when called they accept arguments and your variables are passed in as those arguments.

View File

@ -23,7 +23,7 @@ func newAction(a *Action) (http.Handler, error) {
httpFn := func(w http.ResponseWriter, r *http.Request) { httpFn := func(w http.ResponseWriter, r *http.Request) {
if err := fn(w, r); err != nil { if err := fn(w, r); err != nil {
renderErr(w, err, nil) renderErr(w, err)
} }
} }

View File

@ -7,7 +7,6 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"github.com/dosco/super-graph/core"
"github.com/dosco/super-graph/internal/serv/internal/auth" "github.com/dosco/super-graph/internal/serv/internal/auth"
"github.com/rs/cors" "github.com/rs/cors"
"go.uber.org/zap" "go.uber.org/zap"
@ -29,7 +28,7 @@ type gqlReq struct {
} }
type errorResp struct { type errorResp struct {
Error error `json:"error"` Error string `json:"error"`
} }
func apiV1Handler() http.Handler { func apiV1Handler() http.Handler {
@ -55,13 +54,13 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
//nolint: errcheck //nolint: errcheck
if conf.AuthFailBlock && !auth.IsAuth(ct) { if conf.AuthFailBlock && !auth.IsAuth(ct) {
renderErr(w, errUnauthorized, nil) renderErr(w, errUnauthorized)
return return
} }
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes)) b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
if err != nil { if err != nil {
renderErr(w, err, nil) renderErr(w, err)
return return
} }
defer r.Body.Close() defer r.Body.Close()
@ -70,7 +69,7 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
err = json.Unmarshal(b, &req) err = json.Unmarshal(b, &req)
if err != nil { if err != nil {
renderErr(w, err, nil) renderErr(w, err)
return return
} }
@ -86,11 +85,10 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
} }
if err != nil { if err != nil {
renderErr(w, err, res) renderErr(w, err)
return } else {
}
json.NewEncoder(w).Encode(res) json.NewEncoder(w).Encode(res)
}
if doLog && logLevel >= LogLevelInfo { if doLog && logLevel >= LogLevelInfo {
zlog.Info("success", zlog.Info("success",
@ -102,22 +100,10 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
} }
//nolint: errcheck //nolint: errcheck
func renderErr(w http.ResponseWriter, err error, res *core.Result) { func renderErr(w http.ResponseWriter, err error) {
if err == errUnauthorized { if err == errUnauthorized {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
} }
json.NewEncoder(w).Encode(&errorResp{err}) json.NewEncoder(w).Encode(errorResp{err.Error()})
if logLevel >= LogLevelError {
if res != nil {
zlog.Error(err.Error(),
zap.String("op", res.Operation()),
zap.String("name", res.QueryName()),
zap.String("role", res.Role()),
)
} else {
zlog.Error(err.Error())
}
}
} }