diff --git a/core/api.go b/core/api.go index 17125ab..e33d1d7 100644 --- a/core/api.go +++ b/core/api.go @@ -93,6 +93,7 @@ type SuperGraph struct { anonExists bool qc *qcode.Compiler pc *psql.Compiler + ge *graphql.Engine } // NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its @@ -124,6 +125,10 @@ func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) { return nil, err } + if err := sg.initGraphQLEgine(); err != nil { + return nil, err + } + if len(conf.SecretKey) != 0 { sk := sha256.Sum256([]byte(conf.SecretKey)) conf.SecretKey = "" @@ -163,14 +168,9 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess // use the chirino/graphql library for introspection queries // disabled when allow list is enforced if !sg.conf.UseAllowList && res.name == "IntrospectionQuery" { - engine, err := sg.createGraphQLEgine() - if err != nil { - res.Error = err.Error() - return &res, err - } - - r := engine.ExecuteOne(&graphql.EngineRequest{Query: query}) + r := sg.ge.ExecuteOne(&graphql.EngineRequest{Query: query}) res.Data = r.Data + if r.Error() != nil { res.Error = r.Error().Error() } @@ -200,9 +200,5 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess } func (sg *SuperGraph) GraphQLSchema() (string, error) { - engine, err := sg.createGraphQLEgine() - if err != nil { - return "", err - } - return engine.Schema.String(), nil + return sg.ge.Schema.String(), nil } diff --git a/core/graph-schema.go b/core/graph-schema.go index 3d477b9..cd4e02e 100644 --- a/core/graph-schema.go +++ b/core/graph-schema.go @@ -1,8 +1,6 @@ package core import ( - "errors" - "regexp" "strings" "github.com/chirino/graphql" @@ -26,7 +24,7 @@ var typeMap map[string]string = map[string]string{ "boolean": "Boolean", } -func (sg *SuperGraph) createGraphQLEgine() (*graphql.Engine, error) { +func (sg *SuperGraph) initGraphQLEgine() error { engine := graphql.New() engineSchema := engine.Schema dbSchema := sg.schema @@ -63,15 +61,16 @@ enum OrderDirection { engineSchema.EntryPoints[schema.Query] = query engineSchema.EntryPoints[schema.Mutation] = mutation - validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`) + //validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`) scalarExpressionTypesNeeded := map[string]bool{} tableNames := dbSchema.GetTableNames() - for _, table := range tableNames { + funcs := dbSchema.GetFunctions() + for _, table := range tableNames { ti, err := dbSchema.GetTable(table) if err != nil { - return nil, err + return err } if !ti.IsSingular { @@ -79,13 +78,13 @@ enum OrderDirection { } singularName := ti.Singular - if !validGraphQLIdentifierRegex.MatchString(singularName) { - return nil, errors.New("table name is not a valid GraphQL identifier: " + singularName) - } + // if !validGraphQLIdentifierRegex.MatchString(singularName) { + // return errors.New("table name is not a valid GraphQL identifier: " + singularName) + // } pluralName := ti.Plural - if !validGraphQLIdentifierRegex.MatchString(pluralName) { - return nil, errors.New("table name is not a valid GraphQL identifier: " + pluralName) - } + // if !validGraphQLIdentifierRegex.MatchString(pluralName) { + // return errors.New("table name is not a valid GraphQL identifier: " + pluralName) + // } outputType := &schema.Object{ Name: singularName + "Output", @@ -127,9 +126,9 @@ enum OrderDirection { for _, col := range ti.Columns { colName := col.Name - if !validGraphQLIdentifierRegex.MatchString(colName) { - return nil, errors.New("column name is not a valid GraphQL identifier: " + colName) - } + // if !validGraphQLIdentifierRegex.MatchString(colName) { + // return errors.New("column name is not a valid GraphQL identifier: " + colName) + // } colType := gqltype(col) nullableColType := "" @@ -144,6 +143,16 @@ enum OrderDirection { Type: colType, }) + for _, f := range funcs { + if col.Type != f.Params[0].Type { + continue + } + outputType.Fields = append(outputType.Fields, &schema.Field{ + Name: f.Name + "_" + colName, + Type: colType, + }) + } + // If it's a numeric type... if nullableColType == "Float" || nullableColType == "Int" { outputType.Fields = append(outputType.Fields, &schema.Field{ @@ -464,7 +473,7 @@ enum OrderDirection { err := engineSchema.ResolveTypes() if err != nil { - return nil, err + return err } engine.Resolver = resolvers.Func(func(request *resolvers.ResolveRequest, next resolvers.Resolution) resolvers.Resolution { @@ -479,5 +488,7 @@ enum OrderDirection { return nil }) - return engine, nil + + sg.ge = engine + return nil } diff --git a/core/internal/psql/columns.go b/core/internal/psql/columns.go index f05d577..a9f6b83 100644 --- a/core/internal/psql/columns.go +++ b/core/internal/psql/columns.go @@ -167,7 +167,7 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf } func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInfo, col qcode.Column, columnsRendered int) error { - pl := funcPrefixLen(col.Name) + pl := funcPrefixLen(c.schema.fm, col.Name) // if pl == 0 { // //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) // io.WriteString(c.w, `'`) diff --git a/core/internal/psql/query.go b/core/internal/psql/query.go index 215d0bd..584d263 100644 --- a/core/internal/psql/query.go +++ b/core/internal/psql/query.go @@ -543,7 +543,7 @@ func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo, skip var cn string for _, col := range sel.Cols { - if n := funcPrefixLen(col.Name); n != 0 { + if n := funcPrefixLen(c.schema.fm, col.Name); n != 0 { if !sel.Functions { continue } @@ -1193,7 +1193,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col * io.WriteString(c.w, col.Type) } -func funcPrefixLen(fn string) int { +func funcPrefixLen(fm map[string]*DBFunction, fn string) int { switch { case strings.HasPrefix(fn, "avg_"): return 4 @@ -1218,6 +1218,14 @@ func funcPrefixLen(fn string) int { case strings.HasPrefix(fn, "var_samp_"): return 9 } + fnLen := len(fn) + + for k := range fm { + kLen := len(k) + if kLen < fnLen && k[0] == fn[0] && strings.HasPrefix(fn, k) && fn[kLen] == '_' { + return kLen + 1 + } + } return 0 } diff --git a/core/internal/psql/schema.go b/core/internal/psql/schema.go index 8727a89..d5a88e6 100644 --- a/core/internal/psql/schema.go +++ b/core/internal/psql/schema.go @@ -11,6 +11,7 @@ type DBSchema struct { ver int t map[string]*DBTableInfo rm map[string]map[string]*DBRel + fm map[string]*DBFunction } type DBTableInfo struct { @@ -58,6 +59,7 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) { schema := &DBSchema{ t: make(map[string]*DBTableInfo), rm: make(map[string]map[string]*DBRel), + fm: make(map[string]*DBFunction, len(info.Functions)), } for i, t := range info.Tables { @@ -81,6 +83,12 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) { } } + for k, f := range info.Functions { + if len(f.Params) == 1 { + schema.fm[strings.ToLower(f.Name)] = &info.Functions[k] + } + } + return schema, nil } @@ -439,3 +447,11 @@ func (s *DBSchema) GetRel(child, parent string) (*DBRel, error) { } return rel, nil } + +func (s *DBSchema) GetFunctions() []*DBFunction { + var funcs []*DBFunction + for _, f := range s.fm { + funcs = append(funcs, f) + } + return funcs +} diff --git a/core/internal/psql/tables.go b/core/internal/psql/tables.go index 6dfd62a..f7aa573 100644 --- a/core/internal/psql/tables.go +++ b/core/internal/psql/tables.go @@ -10,10 +10,11 @@ import ( ) type DBInfo struct { - Version int - Tables []DBTable - Columns [][]DBColumn - colmap map[string]map[string]*DBColumn + Version int + Tables []DBTable + Columns [][]DBColumn + Functions []DBFunction + colmap map[string]map[string]*DBColumn } func GetDBInfo(db *sql.DB) (*DBInfo, error) { @@ -51,6 +52,11 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) { } } + di.Functions, err = GetFunctions(db) + if err != nil { + return nil, err + } + return di, nil } @@ -237,6 +243,64 @@ ORDER BY id;` return cols, nil } +type DBFunction struct { + Name string + Params []DBFuncParam +} + +type DBFuncParam struct { + ID int + Name string + Type string +} + +func GetFunctions(db *sql.DB) ([]DBFunction, error) { + sqlStmt := ` +SELECT + routines.routine_name, + parameters.specific_name, + parameters.data_type, + parameters.parameter_name, + parameters.ordinal_position +FROM + information_schema.routines +RIGHT JOIN + information_schema.parameters + ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL) +WHERE + routines.specific_schema = 'public' +ORDER BY + routines.routine_name, parameters.ordinal_position;` + + rows, err := db.Query(sqlStmt) + if err != nil { + return nil, fmt.Errorf("Error fetching functions: %s", err) + } + defer rows.Close() + + var funcs []DBFunction + fm := make(map[string]int) + + for rows.Next() { + var fn, fid string + fp := DBFuncParam{} + + err = rows.Scan(&fn, &fid, &fp.Type, &fp.Name, &fp.ID) + if err != nil { + return nil, err + } + + if i, ok := fm[fid]; ok { + funcs[i].Params = append(funcs[i].Params, fp) + } else { + funcs = append(funcs, DBFunction{Name: fn, Params: []DBFuncParam{fp}}) + fm[fid] = len(funcs) - 1 + } + } + + return funcs, nil +} + // func GetValType(type string) qcode.ValType { // switch { // case "bigint", "integer", "smallint", "numeric", "bigserial": diff --git a/docs/guide/guide.md b/docs/guide/guide.md index 775aa66..c771d75 100644 --- a/docs/guide/guide.md +++ b/docs/guide/guide.md @@ -730,6 +730,32 @@ query { } ``` +### Custom Functions + +Any function defined in the database like the below `add_five` that adds 5 to any number given to it can be used +within your query. The one limitation is that it should be a function that only accepts a single argument. The function is used within you're GraphQL in similar way to how aggregrations are used above. Example below + +```grahql +query { + thread(id: 5) { + id + total_votes + add_five_total_votes + } +} +``` + +Postgres user-defined function `add_five` +``` +CREATE OR REPLACE FUNCTION add_five(a integer) RETURNS integer AS $$ +BEGIN + + RETURN a + 5; +END; +$$ LANGUAGE plpgsql; +``` + + In GraphQL mutations is the operation type for when you need to modify data. Super Graph supports the `insert`, `update`, `upsert` and `delete`. You can also do complex nested inserts and updates. When using mutations the data must be passed as variables since Super Graphs compiles the query into an prepared statement in the database for maximum speed. Prepared statements are are functions in your code when called they accept arguments and your variables are passed in as those arguments. diff --git a/internal/serv/actions.go b/internal/serv/actions.go index e04c845..ceb820f 100644 --- a/internal/serv/actions.go +++ b/internal/serv/actions.go @@ -23,7 +23,7 @@ func newAction(a *Action) (http.Handler, error) { httpFn := func(w http.ResponseWriter, r *http.Request) { if err := fn(w, r); err != nil { - renderErr(w, err, nil) + renderErr(w, err) } } diff --git a/internal/serv/http.go b/internal/serv/http.go index f73430d..8a512b8 100644 --- a/internal/serv/http.go +++ b/internal/serv/http.go @@ -7,7 +7,6 @@ import ( "io/ioutil" "net/http" - "github.com/dosco/super-graph/core" "github.com/dosco/super-graph/internal/serv/internal/auth" "github.com/rs/cors" "go.uber.org/zap" @@ -29,7 +28,7 @@ type gqlReq struct { } type errorResp struct { - Error error `json:"error"` + Error string `json:"error"` } func apiV1Handler() http.Handler { @@ -55,13 +54,13 @@ func apiV1(w http.ResponseWriter, r *http.Request) { //nolint: errcheck if conf.AuthFailBlock && !auth.IsAuth(ct) { - renderErr(w, errUnauthorized, nil) + renderErr(w, errUnauthorized) return } b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes)) if err != nil { - renderErr(w, err, nil) + renderErr(w, err) return } defer r.Body.Close() @@ -70,7 +69,7 @@ func apiV1(w http.ResponseWriter, r *http.Request) { err = json.Unmarshal(b, &req) if err != nil { - renderErr(w, err, nil) + renderErr(w, err) return } @@ -86,12 +85,11 @@ func apiV1(w http.ResponseWriter, r *http.Request) { } if err != nil { - renderErr(w, err, res) - return + renderErr(w, err) + } else { + json.NewEncoder(w).Encode(res) } - json.NewEncoder(w).Encode(res) - if doLog && logLevel >= LogLevelInfo { zlog.Info("success", zap.String("op", res.Operation()), @@ -102,22 +100,10 @@ func apiV1(w http.ResponseWriter, r *http.Request) { } //nolint: errcheck -func renderErr(w http.ResponseWriter, err error, res *core.Result) { +func renderErr(w http.ResponseWriter, err error) { if err == errUnauthorized { w.WriteHeader(http.StatusUnauthorized) } - json.NewEncoder(w).Encode(&errorResp{err}) - - if logLevel >= LogLevelError { - if res != nil { - zlog.Error(err.Error(), - zap.String("op", res.Operation()), - zap.String("name", res.QueryName()), - zap.String("role", res.Role()), - ) - } else { - zlog.Error(err.Error()) - } - } + json.NewEncoder(w).Encode(errorResp{err.Error()}) }