feat: add support for single argument Postgres functions

This commit is contained in:
Vikram Rangnekar 2020-04-22 20:51:14 -04:00
parent 6293d37e73
commit ae7cde0433
9 changed files with 167 additions and 60 deletions

View File

@ -93,6 +93,7 @@ type SuperGraph struct {
anonExists bool
qc *qcode.Compiler
pc *psql.Compiler
ge *graphql.Engine
}
// NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its
@ -124,6 +125,10 @@ func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) {
return nil, err
}
if err := sg.initGraphQLEgine(); err != nil {
return nil, err
}
if len(conf.SecretKey) != 0 {
sk := sha256.Sum256([]byte(conf.SecretKey))
conf.SecretKey = ""
@ -163,14 +168,9 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
// use the chirino/graphql library for introspection queries
// disabled when allow list is enforced
if !sg.conf.UseAllowList && res.name == "IntrospectionQuery" {
engine, err := sg.createGraphQLEgine()
if err != nil {
res.Error = err.Error()
return &res, err
}
r := engine.ExecuteOne(&graphql.EngineRequest{Query: query})
r := sg.ge.ExecuteOne(&graphql.EngineRequest{Query: query})
res.Data = r.Data
if r.Error() != nil {
res.Error = r.Error().Error()
}
@ -200,9 +200,5 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
}
func (sg *SuperGraph) GraphQLSchema() (string, error) {
engine, err := sg.createGraphQLEgine()
if err != nil {
return "", err
}
return engine.Schema.String(), nil
return sg.ge.Schema.String(), nil
}

View File

@ -1,8 +1,6 @@
package core
import (
"errors"
"regexp"
"strings"
"github.com/chirino/graphql"
@ -26,7 +24,7 @@ var typeMap map[string]string = map[string]string{
"boolean": "Boolean",
}
func (sg *SuperGraph) createGraphQLEgine() (*graphql.Engine, error) {
func (sg *SuperGraph) initGraphQLEgine() error {
engine := graphql.New()
engineSchema := engine.Schema
dbSchema := sg.schema
@ -63,15 +61,16 @@ enum OrderDirection {
engineSchema.EntryPoints[schema.Query] = query
engineSchema.EntryPoints[schema.Mutation] = mutation
validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
//validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
scalarExpressionTypesNeeded := map[string]bool{}
tableNames := dbSchema.GetTableNames()
for _, table := range tableNames {
funcs := dbSchema.GetFunctions()
for _, table := range tableNames {
ti, err := dbSchema.GetTable(table)
if err != nil {
return nil, err
return err
}
if !ti.IsSingular {
@ -79,13 +78,13 @@ enum OrderDirection {
}
singularName := ti.Singular
if !validGraphQLIdentifierRegex.MatchString(singularName) {
return nil, errors.New("table name is not a valid GraphQL identifier: " + singularName)
}
// if !validGraphQLIdentifierRegex.MatchString(singularName) {
// return errors.New("table name is not a valid GraphQL identifier: " + singularName)
// }
pluralName := ti.Plural
if !validGraphQLIdentifierRegex.MatchString(pluralName) {
return nil, errors.New("table name is not a valid GraphQL identifier: " + pluralName)
}
// if !validGraphQLIdentifierRegex.MatchString(pluralName) {
// return errors.New("table name is not a valid GraphQL identifier: " + pluralName)
// }
outputType := &schema.Object{
Name: singularName + "Output",
@ -127,9 +126,9 @@ enum OrderDirection {
for _, col := range ti.Columns {
colName := col.Name
if !validGraphQLIdentifierRegex.MatchString(colName) {
return nil, errors.New("column name is not a valid GraphQL identifier: " + colName)
}
// if !validGraphQLIdentifierRegex.MatchString(colName) {
// return errors.New("column name is not a valid GraphQL identifier: " + colName)
// }
colType := gqltype(col)
nullableColType := ""
@ -144,6 +143,16 @@ enum OrderDirection {
Type: colType,
})
for _, f := range funcs {
if col.Type != f.Params[0].Type {
continue
}
outputType.Fields = append(outputType.Fields, &schema.Field{
Name: f.Name + "_" + colName,
Type: colType,
})
}
// If it's a numeric type...
if nullableColType == "Float" || nullableColType == "Int" {
outputType.Fields = append(outputType.Fields, &schema.Field{
@ -464,7 +473,7 @@ enum OrderDirection {
err := engineSchema.ResolveTypes()
if err != nil {
return nil, err
return err
}
engine.Resolver = resolvers.Func(func(request *resolvers.ResolveRequest, next resolvers.Resolution) resolvers.Resolution {
@ -479,5 +488,7 @@ enum OrderDirection {
return nil
})
return engine, nil
sg.ge = engine
return nil
}

View File

@ -167,7 +167,7 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
}
func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInfo, col qcode.Column, columnsRendered int) error {
pl := funcPrefixLen(col.Name)
pl := funcPrefixLen(c.schema.fm, col.Name)
// if pl == 0 {
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
// io.WriteString(c.w, `'`)

View File

@ -543,7 +543,7 @@ func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo, skip
var cn string
for _, col := range sel.Cols {
if n := funcPrefixLen(col.Name); n != 0 {
if n := funcPrefixLen(c.schema.fm, col.Name); n != 0 {
if !sel.Functions {
continue
}
@ -1193,7 +1193,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
io.WriteString(c.w, col.Type)
}
func funcPrefixLen(fn string) int {
func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
switch {
case strings.HasPrefix(fn, "avg_"):
return 4
@ -1218,6 +1218,14 @@ func funcPrefixLen(fn string) int {
case strings.HasPrefix(fn, "var_samp_"):
return 9
}
fnLen := len(fn)
for k := range fm {
kLen := len(k)
if kLen < fnLen && k[0] == fn[0] && strings.HasPrefix(fn, k) && fn[kLen] == '_' {
return kLen + 1
}
}
return 0
}

View File

@ -11,6 +11,7 @@ type DBSchema struct {
ver int
t map[string]*DBTableInfo
rm map[string]map[string]*DBRel
fm map[string]*DBFunction
}
type DBTableInfo struct {
@ -58,6 +59,7 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
schema := &DBSchema{
t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel),
fm: make(map[string]*DBFunction, len(info.Functions)),
}
for i, t := range info.Tables {
@ -81,6 +83,12 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
}
}
for k, f := range info.Functions {
if len(f.Params) == 1 {
schema.fm[strings.ToLower(f.Name)] = &info.Functions[k]
}
}
return schema, nil
}
@ -439,3 +447,11 @@ func (s *DBSchema) GetRel(child, parent string) (*DBRel, error) {
}
return rel, nil
}
func (s *DBSchema) GetFunctions() []*DBFunction {
var funcs []*DBFunction
for _, f := range s.fm {
funcs = append(funcs, f)
}
return funcs
}

View File

@ -10,10 +10,11 @@ import (
)
type DBInfo struct {
Version int
Tables []DBTable
Columns [][]DBColumn
colmap map[string]map[string]*DBColumn
Version int
Tables []DBTable
Columns [][]DBColumn
Functions []DBFunction
colmap map[string]map[string]*DBColumn
}
func GetDBInfo(db *sql.DB) (*DBInfo, error) {
@ -51,6 +52,11 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
}
}
di.Functions, err = GetFunctions(db)
if err != nil {
return nil, err
}
return di, nil
}
@ -237,6 +243,64 @@ ORDER BY id;`
return cols, nil
}
type DBFunction struct {
Name string
Params []DBFuncParam
}
type DBFuncParam struct {
ID int
Name string
Type string
}
func GetFunctions(db *sql.DB) ([]DBFunction, error) {
sqlStmt := `
SELECT
routines.routine_name,
parameters.specific_name,
parameters.data_type,
parameters.parameter_name,
parameters.ordinal_position
FROM
information_schema.routines
RIGHT JOIN
information_schema.parameters
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
WHERE
routines.specific_schema = 'public'
ORDER BY
routines.routine_name, parameters.ordinal_position;`
rows, err := db.Query(sqlStmt)
if err != nil {
return nil, fmt.Errorf("Error fetching functions: %s", err)
}
defer rows.Close()
var funcs []DBFunction
fm := make(map[string]int)
for rows.Next() {
var fn, fid string
fp := DBFuncParam{}
err = rows.Scan(&fn, &fid, &fp.Type, &fp.Name, &fp.ID)
if err != nil {
return nil, err
}
if i, ok := fm[fid]; ok {
funcs[i].Params = append(funcs[i].Params, fp)
} else {
funcs = append(funcs, DBFunction{Name: fn, Params: []DBFuncParam{fp}})
fm[fid] = len(funcs) - 1
}
}
return funcs, nil
}
// func GetValType(type string) qcode.ValType {
// switch {
// case "bigint", "integer", "smallint", "numeric", "bigserial":

View File

@ -730,6 +730,32 @@ query {
}
```
### Custom Functions
Any function defined in the database like the below `add_five` that adds 5 to any number given to it can be used
within your query. The one limitation is that it should be a function that only accepts a single argument. The function is used within you're GraphQL in similar way to how aggregrations are used above. Example below
```grahql
query {
thread(id: 5) {
id
total_votes
add_five_total_votes
}
}
```
Postgres user-defined function `add_five`
```
CREATE OR REPLACE FUNCTION add_five(a integer) RETURNS integer AS $$
BEGIN
RETURN a + 5;
END;
$$ LANGUAGE plpgsql;
```
In GraphQL mutations is the operation type for when you need to modify data. Super Graph supports the `insert`, `update`, `upsert` and `delete`. You can also do complex nested inserts and updates.
When using mutations the data must be passed as variables since Super Graphs compiles the query into an prepared statement in the database for maximum speed. Prepared statements are are functions in your code when called they accept arguments and your variables are passed in as those arguments.

View File

@ -23,7 +23,7 @@ func newAction(a *Action) (http.Handler, error) {
httpFn := func(w http.ResponseWriter, r *http.Request) {
if err := fn(w, r); err != nil {
renderErr(w, err, nil)
renderErr(w, err)
}
}

View File

@ -7,7 +7,6 @@ import (
"io/ioutil"
"net/http"
"github.com/dosco/super-graph/core"
"github.com/dosco/super-graph/internal/serv/internal/auth"
"github.com/rs/cors"
"go.uber.org/zap"
@ -29,7 +28,7 @@ type gqlReq struct {
}
type errorResp struct {
Error error `json:"error"`
Error string `json:"error"`
}
func apiV1Handler() http.Handler {
@ -55,13 +54,13 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
//nolint: errcheck
if conf.AuthFailBlock && !auth.IsAuth(ct) {
renderErr(w, errUnauthorized, nil)
renderErr(w, errUnauthorized)
return
}
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
if err != nil {
renderErr(w, err, nil)
renderErr(w, err)
return
}
defer r.Body.Close()
@ -70,7 +69,7 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
err = json.Unmarshal(b, &req)
if err != nil {
renderErr(w, err, nil)
renderErr(w, err)
return
}
@ -86,12 +85,11 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
}
if err != nil {
renderErr(w, err, res)
return
renderErr(w, err)
} else {
json.NewEncoder(w).Encode(res)
}
json.NewEncoder(w).Encode(res)
if doLog && logLevel >= LogLevelInfo {
zlog.Info("success",
zap.String("op", res.Operation()),
@ -102,22 +100,10 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
}
//nolint: errcheck
func renderErr(w http.ResponseWriter, err error, res *core.Result) {
func renderErr(w http.ResponseWriter, err error) {
if err == errUnauthorized {
w.WriteHeader(http.StatusUnauthorized)
}
json.NewEncoder(w).Encode(&errorResp{err})
if logLevel >= LogLevelError {
if res != nil {
zlog.Error(err.Error(),
zap.String("op", res.Operation()),
zap.String("name", res.QueryName()),
zap.String("role", res.Role()),
)
} else {
zlog.Error(err.Error())
}
}
json.NewEncoder(w).Encode(errorResp{err.Error()})
}