feat: add support for single argument Postgres functions
This commit is contained in:
parent
6293d37e73
commit
ae7cde0433
20
core/api.go
20
core/api.go
@ -93,6 +93,7 @@ type SuperGraph struct {
|
||||
anonExists bool
|
||||
qc *qcode.Compiler
|
||||
pc *psql.Compiler
|
||||
ge *graphql.Engine
|
||||
}
|
||||
|
||||
// NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its
|
||||
@ -124,6 +125,10 @@ func NewSuperGraph(conf *Config, db *sql.DB) (*SuperGraph, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := sg.initGraphQLEgine(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(conf.SecretKey) != 0 {
|
||||
sk := sha256.Sum256([]byte(conf.SecretKey))
|
||||
conf.SecretKey = ""
|
||||
@ -163,14 +168,9 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
|
||||
// use the chirino/graphql library for introspection queries
|
||||
// disabled when allow list is enforced
|
||||
if !sg.conf.UseAllowList && res.name == "IntrospectionQuery" {
|
||||
engine, err := sg.createGraphQLEgine()
|
||||
if err != nil {
|
||||
res.Error = err.Error()
|
||||
return &res, err
|
||||
}
|
||||
|
||||
r := engine.ExecuteOne(&graphql.EngineRequest{Query: query})
|
||||
r := sg.ge.ExecuteOne(&graphql.EngineRequest{Query: query})
|
||||
res.Data = r.Data
|
||||
|
||||
if r.Error() != nil {
|
||||
res.Error = r.Error().Error()
|
||||
}
|
||||
@ -200,9 +200,5 @@ func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMess
|
||||
}
|
||||
|
||||
func (sg *SuperGraph) GraphQLSchema() (string, error) {
|
||||
engine, err := sg.createGraphQLEgine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return engine.Schema.String(), nil
|
||||
return sg.ge.Schema.String(), nil
|
||||
}
|
||||
|
@ -1,8 +1,6 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/chirino/graphql"
|
||||
@ -26,7 +24,7 @@ var typeMap map[string]string = map[string]string{
|
||||
"boolean": "Boolean",
|
||||
}
|
||||
|
||||
func (sg *SuperGraph) createGraphQLEgine() (*graphql.Engine, error) {
|
||||
func (sg *SuperGraph) initGraphQLEgine() error {
|
||||
engine := graphql.New()
|
||||
engineSchema := engine.Schema
|
||||
dbSchema := sg.schema
|
||||
@ -63,15 +61,16 @@ enum OrderDirection {
|
||||
engineSchema.EntryPoints[schema.Query] = query
|
||||
engineSchema.EntryPoints[schema.Mutation] = mutation
|
||||
|
||||
validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
|
||||
//validGraphQLIdentifierRegex := regexp.MustCompile(`^[A-Za-z_][A-Za-z_0-9]*$`)
|
||||
|
||||
scalarExpressionTypesNeeded := map[string]bool{}
|
||||
tableNames := dbSchema.GetTableNames()
|
||||
for _, table := range tableNames {
|
||||
funcs := dbSchema.GetFunctions()
|
||||
|
||||
for _, table := range tableNames {
|
||||
ti, err := dbSchema.GetTable(table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if !ti.IsSingular {
|
||||
@ -79,13 +78,13 @@ enum OrderDirection {
|
||||
}
|
||||
|
||||
singularName := ti.Singular
|
||||
if !validGraphQLIdentifierRegex.MatchString(singularName) {
|
||||
return nil, errors.New("table name is not a valid GraphQL identifier: " + singularName)
|
||||
}
|
||||
// if !validGraphQLIdentifierRegex.MatchString(singularName) {
|
||||
// return errors.New("table name is not a valid GraphQL identifier: " + singularName)
|
||||
// }
|
||||
pluralName := ti.Plural
|
||||
if !validGraphQLIdentifierRegex.MatchString(pluralName) {
|
||||
return nil, errors.New("table name is not a valid GraphQL identifier: " + pluralName)
|
||||
}
|
||||
// if !validGraphQLIdentifierRegex.MatchString(pluralName) {
|
||||
// return errors.New("table name is not a valid GraphQL identifier: " + pluralName)
|
||||
// }
|
||||
|
||||
outputType := &schema.Object{
|
||||
Name: singularName + "Output",
|
||||
@ -127,9 +126,9 @@ enum OrderDirection {
|
||||
|
||||
for _, col := range ti.Columns {
|
||||
colName := col.Name
|
||||
if !validGraphQLIdentifierRegex.MatchString(colName) {
|
||||
return nil, errors.New("column name is not a valid GraphQL identifier: " + colName)
|
||||
}
|
||||
// if !validGraphQLIdentifierRegex.MatchString(colName) {
|
||||
// return errors.New("column name is not a valid GraphQL identifier: " + colName)
|
||||
// }
|
||||
|
||||
colType := gqltype(col)
|
||||
nullableColType := ""
|
||||
@ -144,6 +143,16 @@ enum OrderDirection {
|
||||
Type: colType,
|
||||
})
|
||||
|
||||
for _, f := range funcs {
|
||||
if col.Type != f.Params[0].Type {
|
||||
continue
|
||||
}
|
||||
outputType.Fields = append(outputType.Fields, &schema.Field{
|
||||
Name: f.Name + "_" + colName,
|
||||
Type: colType,
|
||||
})
|
||||
}
|
||||
|
||||
// If it's a numeric type...
|
||||
if nullableColType == "Float" || nullableColType == "Int" {
|
||||
outputType.Fields = append(outputType.Fields, &schema.Field{
|
||||
@ -464,7 +473,7 @@ enum OrderDirection {
|
||||
|
||||
err := engineSchema.ResolveTypes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
engine.Resolver = resolvers.Func(func(request *resolvers.ResolveRequest, next resolvers.Resolution) resolvers.Resolution {
|
||||
@ -479,5 +488,7 @@ enum OrderDirection {
|
||||
|
||||
return nil
|
||||
})
|
||||
return engine, nil
|
||||
|
||||
sg.ge = engine
|
||||
return nil
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInfo, col qcode.Column, columnsRendered int) error {
|
||||
pl := funcPrefixLen(col.Name)
|
||||
pl := funcPrefixLen(c.schema.fm, col.Name)
|
||||
// if pl == 0 {
|
||||
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
|
||||
// io.WriteString(c.w, `'`)
|
||||
|
@ -543,7 +543,7 @@ func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo, skip
|
||||
var cn string
|
||||
|
||||
for _, col := range sel.Cols {
|
||||
if n := funcPrefixLen(col.Name); n != 0 {
|
||||
if n := funcPrefixLen(c.schema.fm, col.Name); n != 0 {
|
||||
if !sel.Functions {
|
||||
continue
|
||||
}
|
||||
@ -1193,7 +1193,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
|
||||
io.WriteString(c.w, col.Type)
|
||||
}
|
||||
|
||||
func funcPrefixLen(fn string) int {
|
||||
func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
|
||||
switch {
|
||||
case strings.HasPrefix(fn, "avg_"):
|
||||
return 4
|
||||
@ -1218,6 +1218,14 @@ func funcPrefixLen(fn string) int {
|
||||
case strings.HasPrefix(fn, "var_samp_"):
|
||||
return 9
|
||||
}
|
||||
fnLen := len(fn)
|
||||
|
||||
for k := range fm {
|
||||
kLen := len(k)
|
||||
if kLen < fnLen && k[0] == fn[0] && strings.HasPrefix(fn, k) && fn[kLen] == '_' {
|
||||
return kLen + 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,7 @@ type DBSchema struct {
|
||||
ver int
|
||||
t map[string]*DBTableInfo
|
||||
rm map[string]map[string]*DBRel
|
||||
fm map[string]*DBFunction
|
||||
}
|
||||
|
||||
type DBTableInfo struct {
|
||||
@ -58,6 +59,7 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
|
||||
schema := &DBSchema{
|
||||
t: make(map[string]*DBTableInfo),
|
||||
rm: make(map[string]map[string]*DBRel),
|
||||
fm: make(map[string]*DBFunction, len(info.Functions)),
|
||||
}
|
||||
|
||||
for i, t := range info.Tables {
|
||||
@ -81,6 +83,12 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
|
||||
}
|
||||
}
|
||||
|
||||
for k, f := range info.Functions {
|
||||
if len(f.Params) == 1 {
|
||||
schema.fm[strings.ToLower(f.Name)] = &info.Functions[k]
|
||||
}
|
||||
}
|
||||
|
||||
return schema, nil
|
||||
}
|
||||
|
||||
@ -439,3 +447,11 @@ func (s *DBSchema) GetRel(child, parent string) (*DBRel, error) {
|
||||
}
|
||||
return rel, nil
|
||||
}
|
||||
|
||||
func (s *DBSchema) GetFunctions() []*DBFunction {
|
||||
var funcs []*DBFunction
|
||||
for _, f := range s.fm {
|
||||
funcs = append(funcs, f)
|
||||
}
|
||||
return funcs
|
||||
}
|
||||
|
@ -10,10 +10,11 @@ import (
|
||||
)
|
||||
|
||||
type DBInfo struct {
|
||||
Version int
|
||||
Tables []DBTable
|
||||
Columns [][]DBColumn
|
||||
colmap map[string]map[string]*DBColumn
|
||||
Version int
|
||||
Tables []DBTable
|
||||
Columns [][]DBColumn
|
||||
Functions []DBFunction
|
||||
colmap map[string]map[string]*DBColumn
|
||||
}
|
||||
|
||||
func GetDBInfo(db *sql.DB) (*DBInfo, error) {
|
||||
@ -51,6 +52,11 @@ func GetDBInfo(db *sql.DB) (*DBInfo, error) {
|
||||
}
|
||||
}
|
||||
|
||||
di.Functions, err = GetFunctions(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return di, nil
|
||||
}
|
||||
|
||||
@ -237,6 +243,64 @@ ORDER BY id;`
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
type DBFunction struct {
|
||||
Name string
|
||||
Params []DBFuncParam
|
||||
}
|
||||
|
||||
type DBFuncParam struct {
|
||||
ID int
|
||||
Name string
|
||||
Type string
|
||||
}
|
||||
|
||||
func GetFunctions(db *sql.DB) ([]DBFunction, error) {
|
||||
sqlStmt := `
|
||||
SELECT
|
||||
routines.routine_name,
|
||||
parameters.specific_name,
|
||||
parameters.data_type,
|
||||
parameters.parameter_name,
|
||||
parameters.ordinal_position
|
||||
FROM
|
||||
information_schema.routines
|
||||
RIGHT JOIN
|
||||
information_schema.parameters
|
||||
ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)
|
||||
WHERE
|
||||
routines.specific_schema = 'public'
|
||||
ORDER BY
|
||||
routines.routine_name, parameters.ordinal_position;`
|
||||
|
||||
rows, err := db.Query(sqlStmt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error fetching functions: %s", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var funcs []DBFunction
|
||||
fm := make(map[string]int)
|
||||
|
||||
for rows.Next() {
|
||||
var fn, fid string
|
||||
fp := DBFuncParam{}
|
||||
|
||||
err = rows.Scan(&fn, &fid, &fp.Type, &fp.Name, &fp.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if i, ok := fm[fid]; ok {
|
||||
funcs[i].Params = append(funcs[i].Params, fp)
|
||||
} else {
|
||||
funcs = append(funcs, DBFunction{Name: fn, Params: []DBFuncParam{fp}})
|
||||
fm[fid] = len(funcs) - 1
|
||||
}
|
||||
}
|
||||
|
||||
return funcs, nil
|
||||
}
|
||||
|
||||
// func GetValType(type string) qcode.ValType {
|
||||
// switch {
|
||||
// case "bigint", "integer", "smallint", "numeric", "bigserial":
|
||||
|
@ -730,6 +730,32 @@ query {
|
||||
}
|
||||
```
|
||||
|
||||
### Custom Functions
|
||||
|
||||
Any function defined in the database like the below `add_five` that adds 5 to any number given to it can be used
|
||||
within your query. The one limitation is that it should be a function that only accepts a single argument. The function is used within you're GraphQL in similar way to how aggregrations are used above. Example below
|
||||
|
||||
```grahql
|
||||
query {
|
||||
thread(id: 5) {
|
||||
id
|
||||
total_votes
|
||||
add_five_total_votes
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Postgres user-defined function `add_five`
|
||||
```
|
||||
CREATE OR REPLACE FUNCTION add_five(a integer) RETURNS integer AS $$
|
||||
BEGIN
|
||||
|
||||
RETURN a + 5;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
```
|
||||
|
||||
|
||||
In GraphQL mutations is the operation type for when you need to modify data. Super Graph supports the `insert`, `update`, `upsert` and `delete`. You can also do complex nested inserts and updates.
|
||||
|
||||
When using mutations the data must be passed as variables since Super Graphs compiles the query into an prepared statement in the database for maximum speed. Prepared statements are are functions in your code when called they accept arguments and your variables are passed in as those arguments.
|
||||
|
@ -23,7 +23,7 @@ func newAction(a *Action) (http.Handler, error) {
|
||||
|
||||
httpFn := func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := fn(w, r); err != nil {
|
||||
renderErr(w, err, nil)
|
||||
renderErr(w, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/dosco/super-graph/core"
|
||||
"github.com/dosco/super-graph/internal/serv/internal/auth"
|
||||
"github.com/rs/cors"
|
||||
"go.uber.org/zap"
|
||||
@ -29,7 +28,7 @@ type gqlReq struct {
|
||||
}
|
||||
|
||||
type errorResp struct {
|
||||
Error error `json:"error"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func apiV1Handler() http.Handler {
|
||||
@ -55,13 +54,13 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
//nolint: errcheck
|
||||
if conf.AuthFailBlock && !auth.IsAuth(ct) {
|
||||
renderErr(w, errUnauthorized, nil)
|
||||
renderErr(w, errUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
|
||||
if err != nil {
|
||||
renderErr(w, err, nil)
|
||||
renderErr(w, err)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
@ -70,7 +69,7 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = json.Unmarshal(b, &req)
|
||||
if err != nil {
|
||||
renderErr(w, err, nil)
|
||||
renderErr(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -86,12 +85,11 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
renderErr(w, err, res)
|
||||
return
|
||||
renderErr(w, err)
|
||||
} else {
|
||||
json.NewEncoder(w).Encode(res)
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(res)
|
||||
|
||||
if doLog && logLevel >= LogLevelInfo {
|
||||
zlog.Info("success",
|
||||
zap.String("op", res.Operation()),
|
||||
@ -102,22 +100,10 @@ func apiV1(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
//nolint: errcheck
|
||||
func renderErr(w http.ResponseWriter, err error, res *core.Result) {
|
||||
func renderErr(w http.ResponseWriter, err error) {
|
||||
if err == errUnauthorized {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(&errorResp{err})
|
||||
|
||||
if logLevel >= LogLevelError {
|
||||
if res != nil {
|
||||
zlog.Error(err.Error(),
|
||||
zap.String("op", res.Operation()),
|
||||
zap.String("name", res.QueryName()),
|
||||
zap.String("role", res.Role()),
|
||||
)
|
||||
} else {
|
||||
zlog.Error(err.Error())
|
||||
}
|
||||
}
|
||||
json.NewEncoder(w).Encode(errorResp{err.Error()})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user