Fix bug with compiling anon queries
This commit is contained in:
parent
7583326d21
commit
f518d5fc69
|
@ -73,43 +73,6 @@ mutation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
variables {
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"name": "Gumbo1",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Gumbo2",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
query {
|
|
||||||
products {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
variables {
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"name": "Gumbo1",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Gumbo2",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
query {
|
query {
|
||||||
products {
|
products {
|
||||||
id
|
id
|
||||||
|
@ -133,21 +96,6 @@ mutation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
variables {
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"name": "Gumbo1",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Gumbo2",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
query {
|
query {
|
||||||
products {
|
products {
|
||||||
id
|
id
|
||||||
|
@ -174,39 +122,6 @@ mutation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
variables {
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"name": "Gumbo1",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Gumbo2",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
query {
|
|
||||||
products {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
users {
|
|
||||||
email
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
query {
|
|
||||||
me {
|
|
||||||
id
|
|
||||||
email
|
|
||||||
full_name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
variables {
|
variables {
|
||||||
"update": {
|
"update": {
|
||||||
"name": "Helloo",
|
"name": "Helloo",
|
||||||
|
@ -223,70 +138,6 @@ mutation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
variables {
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"name": "Gumbo1",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Gumbo2",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
query {
|
|
||||||
product {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
variables {
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"name": "Gumbo1",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Gumbo2",
|
|
||||||
"created_at": "now",
|
|
||||||
"updated_at": "now"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
query {
|
|
||||||
products {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
description
|
|
||||||
users {
|
|
||||||
email
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
query {
|
|
||||||
users {
|
|
||||||
id
|
|
||||||
email
|
|
||||||
picture: avatar
|
|
||||||
password
|
|
||||||
full_name
|
|
||||||
products(limit: 2, where: {price: {gt: 10}}) {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
description
|
|
||||||
price
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
variables {
|
variables {
|
||||||
"data": {
|
"data": {
|
||||||
"name": "WOOO",
|
"name": "WOOO",
|
||||||
|
@ -301,4 +152,11 @@ mutation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
query {
|
||||||
|
products {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -101,18 +101,14 @@ database:
|
||||||
variables:
|
variables:
|
||||||
admin_account_id: "5"
|
admin_account_id: "5"
|
||||||
|
|
||||||
# Define defaults to for the field key and values below
|
# Field and table names that you wish to block
|
||||||
defaults:
|
blocklist:
|
||||||
# filters: ["{ user_id: { eq: $user_id } }"]
|
- ar_internal_metadata
|
||||||
|
- schema_migrations
|
||||||
# Field and table names that you wish to block
|
- secret
|
||||||
blocklist:
|
- password
|
||||||
- ar_internal_metadata
|
- encrypted
|
||||||
- schema_migrations
|
- token
|
||||||
- secret
|
|
||||||
- password
|
|
||||||
- encrypted
|
|
||||||
- token
|
|
||||||
|
|
||||||
tables:
|
tables:
|
||||||
- name: customers
|
- name: customers
|
||||||
|
@ -140,6 +136,7 @@ roles_query: "SELECT * FROM users WHERE id = $user_id"
|
||||||
roles:
|
roles:
|
||||||
- name: anon
|
- name: anon
|
||||||
tables:
|
tables:
|
||||||
|
- name: users
|
||||||
- name: products
|
- name: products
|
||||||
limit: 10
|
limit: 10
|
||||||
|
|
||||||
|
|
|
@ -1275,18 +1275,14 @@ database:
|
||||||
variables:
|
variables:
|
||||||
admin_account_id: "5"
|
admin_account_id: "5"
|
||||||
|
|
||||||
# Define defaults to for the field key and values below
|
# Field and table names that you wish to block
|
||||||
defaults:
|
blocklist:
|
||||||
# filters: ["{ user_id: { eq: $user_id } }"]
|
- ar_internal_metadata
|
||||||
|
- schema_migrations
|
||||||
# Field and table names that you wish to block
|
- secret
|
||||||
blocklist:
|
- password
|
||||||
- ar_internal_metadata
|
- encrypted
|
||||||
- schema_migrations
|
- token
|
||||||
- secret
|
|
||||||
- password
|
|
||||||
- encrypted
|
|
||||||
- token
|
|
||||||
|
|
||||||
tables:
|
tables:
|
||||||
- name: customers
|
- name: customers
|
||||||
|
|
|
@ -500,7 +500,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||||
var groupBy []int
|
var groupBy []int
|
||||||
|
|
||||||
isRoot := sel.ParentID == -1
|
isRoot := sel.ParentID == -1
|
||||||
isFil := sel.Where != nil
|
isFil := (sel.Where != nil && sel.Where.Op != qcode.OpNop)
|
||||||
isSearch := sel.Args["search"] != nil
|
isSearch := sel.Args["search"] != nil
|
||||||
isAgg := false
|
isAgg := false
|
||||||
|
|
||||||
|
@ -880,6 +880,10 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
|
||||||
var col *DBColumn
|
var col *DBColumn
|
||||||
var ok bool
|
var ok bool
|
||||||
|
|
||||||
|
if ex.Op == qcode.OpNop {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if len(ex.Col) != 0 {
|
if len(ex.Col) != 0 {
|
||||||
if col, ok = ti.Columns[ex.Col]; !ok {
|
if col, ok = ti.Columns[ex.Col]; !ok {
|
||||||
return fmt.Errorf("no column '%s' found ", ex.Col)
|
return fmt.Errorf("no column '%s' found ", ex.Col)
|
||||||
|
|
|
@ -4,6 +4,8 @@ package qcode
|
||||||
|
|
||||||
// FuzzerEntrypoint for Fuzzbuzz
|
// FuzzerEntrypoint for Fuzzbuzz
|
||||||
func Fuzz(data []byte) int {
|
func Fuzz(data []byte) int {
|
||||||
|
GetQType(string(data))
|
||||||
|
|
||||||
qcompile, _ := NewCompiler(Config{})
|
qcompile, _ := NewCompiler(Config{})
|
||||||
_, err := qcompile.Compile(data, "user")
|
_, err := qcompile.Compile(data, "user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -20,6 +20,7 @@ const (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
QTQuery QType = iota + 1
|
QTQuery QType = iota + 1
|
||||||
|
QTMutation
|
||||||
QTInsert
|
QTInsert
|
||||||
QTUpdate
|
QTUpdate
|
||||||
QTDelete
|
QTDelete
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
package qcode
|
||||||
|
|
||||||
|
func GetQType(gql string) QType {
|
||||||
|
for i := range gql {
|
||||||
|
b := gql[i]
|
||||||
|
if b == '{' {
|
||||||
|
return QTQuery
|
||||||
|
}
|
||||||
|
if al(b) {
|
||||||
|
switch b {
|
||||||
|
case 'm', 'M':
|
||||||
|
return QTMutation
|
||||||
|
case 'q', 'Q':
|
||||||
|
return QTQuery
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func al(b byte) bool {
|
||||||
|
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
|
||||||
|
}
|
|
@ -46,7 +46,7 @@ func initAllowList(cpath string) {
|
||||||
if _, err := os.Stat(fp); err == nil {
|
if _, err := os.Stat(fp); err == nil {
|
||||||
_allowList.filepath = fp
|
_allowList.filepath = fp
|
||||||
} else if !os.IsNotExist(err) {
|
} else if !os.IsNotExist(err) {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ func initAllowList(cpath string) {
|
||||||
if _, err := os.Stat(fp); err == nil {
|
if _, err := os.Stat(fp); err == nil {
|
||||||
_allowList.filepath = fp
|
_allowList.filepath = fp
|
||||||
} else if !os.IsNotExist(err) {
|
} else if !os.IsNotExist(err) {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,13 +66,13 @@ func initAllowList(cpath string) {
|
||||||
if _, err := os.Stat(fp); err == nil {
|
if _, err := os.Stat(fp); err == nil {
|
||||||
_allowList.filepath = fp
|
_allowList.filepath = fp
|
||||||
} else if !os.IsNotExist(err) {
|
} else if !os.IsNotExist(err) {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(_allowList.filepath) == 0 {
|
if len(_allowList.filepath) == 0 {
|
||||||
if conf.Production {
|
if conf.Production {
|
||||||
logger.Fatal().Msg("allow.list not found")
|
errlog.Fatal().Msg("allow.list not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cpath) == 0 {
|
if len(cpath) == 0 {
|
||||||
|
@ -187,7 +187,6 @@ func (al *allowList) load() {
|
||||||
item.gql = q
|
item.gql = q
|
||||||
item.vars = varBytes
|
item.vars = varBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
varBytes = nil
|
varBytes = nil
|
||||||
|
|
||||||
} else if ty == AL_VARS {
|
} else if ty == AL_VARS {
|
||||||
|
|
29
serv/args.go
29
serv/args.go
|
@ -2,44 +2,46 @@ package serv
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/dosco/super-graph/jsn"
|
"github.com/dosco/super-graph/jsn"
|
||||||
)
|
)
|
||||||
|
|
||||||
func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int, error) {
|
||||||
return func(w io.Writer, tag string) (int, error) {
|
return func(w io.Writer, tag string) (int, error) {
|
||||||
switch tag {
|
switch tag {
|
||||||
case "user_id_provider":
|
case "user_id_provider":
|
||||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||||
return io.WriteString(w, v.(string))
|
return io.WriteString(w, v.(string))
|
||||||
}
|
}
|
||||||
return io.WriteString(w, "null")
|
return 0, errors.New("query requires variable $user_id_provider")
|
||||||
|
|
||||||
case "user_id":
|
case "user_id":
|
||||||
if v := ctx.Value(userIDKey); v != nil {
|
if v := ctx.Value(userIDKey); v != nil {
|
||||||
return io.WriteString(w, v.(string))
|
return io.WriteString(w, v.(string))
|
||||||
}
|
}
|
||||||
return io.WriteString(w, "null")
|
return 0, errors.New("query requires variable $user_id")
|
||||||
|
|
||||||
case "user_role":
|
case "user_role":
|
||||||
if v := ctx.Value(userRoleKey); v != nil {
|
if v := ctx.Value(userRoleKey); v != nil {
|
||||||
return io.WriteString(w, v.(string))
|
return io.WriteString(w, v.(string))
|
||||||
}
|
}
|
||||||
return io.WriteString(w, "null")
|
return 0, errors.New("query requires variable $user_role")
|
||||||
}
|
}
|
||||||
|
|
||||||
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
|
fields := jsn.Get(vars, [][]byte{[]byte(tag)})
|
||||||
if len(fields) == 0 {
|
if len(fields) == 0 {
|
||||||
return 0, fmt.Errorf("variable '%s' not found", tag)
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return w.Write(fields[0].Value)
|
return w.Write(fields[0].Value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func argList(ctx *coreContext, args [][]byte) []interface{} {
|
func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) {
|
||||||
vars := make([]interface{}, len(args))
|
vars := make([]interface{}, len(args))
|
||||||
|
|
||||||
var fields map[string]interface{}
|
var fields map[string]interface{}
|
||||||
|
@ -49,7 +51,7 @@ func argList(ctx *coreContext, args [][]byte) []interface{} {
|
||||||
fields, _, err = jsn.Tree(ctx.req.Vars)
|
fields, _, err = jsn.Tree(ctx.req.Vars)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn().Err(err).Msg("Failed to parse variables")
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,24 +62,33 @@ func argList(ctx *coreContext, args [][]byte) []interface{} {
|
||||||
case bytes.Equal(av, []byte("user_id")):
|
case bytes.Equal(av, []byte("user_id")):
|
||||||
if v := ctx.Value(userIDKey); v != nil {
|
if v := ctx.Value(userIDKey); v != nil {
|
||||||
vars[i] = v.(string)
|
vars[i] = v.(string)
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("query requires variable $user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
case bytes.Equal(av, []byte("user_id_provider")):
|
case bytes.Equal(av, []byte("user_id_provider")):
|
||||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||||
vars[i] = v.(string)
|
vars[i] = v.(string)
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("query requires variable $user_id_provider")
|
||||||
}
|
}
|
||||||
|
|
||||||
case bytes.Equal(av, []byte("user_role")):
|
case bytes.Equal(av, []byte("user_role")):
|
||||||
if v := ctx.Value(userRoleKey); v != nil {
|
if v := ctx.Value(userRoleKey); v != nil {
|
||||||
vars[i] = v.(string)
|
vars[i] = v.(string)
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("query requires variable $user_role")
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
if v, ok := fields[string(av)]; ok {
|
if v, ok := fields[string(av)]; ok {
|
||||||
vars[i] = v
|
vars[i] = v
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("query requires variable $%s", string(av))
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return vars
|
return vars, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||||
case len(publicKeyFile) != 0:
|
case len(publicKeyFile) != 0:
|
||||||
kd, err := ioutil.ReadFile(publicKeyFile)
|
kd, err := ioutil.ReadFile(publicKeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
|
|
||||||
switch conf.Auth.JWT.PubKeyType {
|
switch conf.Auth.JWT.PubKeyType {
|
||||||
|
@ -51,7 +51,7 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,11 +15,11 @@ import (
|
||||||
func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||||
cookie := conf.Auth.Cookie
|
cookie := conf.Auth.Cookie
|
||||||
if len(cookie) == 0 {
|
if len(cookie) == 0 {
|
||||||
logger.Fatal().Msg("no auth.cookie defined")
|
errlog.Fatal().Msg("no auth.cookie defined")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(conf.Auth.Rails.URL) == 0 {
|
if len(conf.Auth.Rails.URL) == 0 {
|
||||||
logger.Fatal().Msg("no auth.rails.url defined")
|
errlog.Fatal().Msg("no auth.rails.url defined")
|
||||||
}
|
}
|
||||||
|
|
||||||
rp := &redis.Pool{
|
rp := &redis.Pool{
|
||||||
|
@ -28,13 +28,13 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||||
Dial: func() (redis.Conn, error) {
|
Dial: func() (redis.Conn, error) {
|
||||||
c, err := redis.DialURL(conf.Auth.Rails.URL)
|
c, err := redis.DialURL(conf.Auth.Rails.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
|
|
||||||
pwd := conf.Auth.Rails.Password
|
pwd := conf.Auth.Rails.Password
|
||||||
if len(pwd) != 0 {
|
if len(pwd) != 0 {
|
||||||
if _, err := c.Do("AUTH", pwd); err != nil {
|
if _, err := c.Do("AUTH", pwd); err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return c, err
|
return c, err
|
||||||
|
@ -69,16 +69,16 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||||
func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
|
func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||||
cookie := conf.Auth.Cookie
|
cookie := conf.Auth.Cookie
|
||||||
if len(cookie) == 0 {
|
if len(cookie) == 0 {
|
||||||
logger.Fatal().Msg("no auth.cookie defined")
|
errlog.Fatal().Msg("no auth.cookie defined")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(conf.Auth.Rails.URL) == 0 {
|
if len(conf.Auth.Rails.URL) == 0 {
|
||||||
logger.Fatal().Msg("no auth.rails.url defined")
|
errlog.Fatal().Msg("no auth.rails.url defined")
|
||||||
}
|
}
|
||||||
|
|
||||||
rURL, err := url.Parse(conf.Auth.Rails.URL)
|
rURL, err := url.Parse(conf.Auth.Rails.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
|
|
||||||
mc := memcache.New(rURL.Host)
|
mc := memcache.New(rURL.Host)
|
||||||
|
@ -111,12 +111,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||||
func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
|
func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||||
cookie := conf.Auth.Cookie
|
cookie := conf.Auth.Cookie
|
||||||
if len(cookie) == 0 {
|
if len(cookie) == 0 {
|
||||||
logger.Fatal().Msg("no auth.cookie defined")
|
errlog.Fatal().Msg("no auth.cookie defined")
|
||||||
}
|
}
|
||||||
|
|
||||||
ra, err := railsAuth(conf)
|
ra, err := railsAuth(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
33
serv/cmd.go
33
serv/cmd.go
|
@ -22,16 +22,18 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
logger *zerolog.Logger
|
logger zerolog.Logger
|
||||||
|
errlog zerolog.Logger
|
||||||
conf *config
|
conf *config
|
||||||
confPath string
|
confPath string
|
||||||
db *pgxpool.Pool
|
db *pgxpool.Pool
|
||||||
|
schema *psql.DBSchema
|
||||||
qcompile *qcode.Compiler
|
qcompile *qcode.Compiler
|
||||||
pcompile *psql.Compiler
|
pcompile *psql.Compiler
|
||||||
)
|
)
|
||||||
|
|
||||||
func Init() {
|
func Init() {
|
||||||
logger = initLog()
|
initLog()
|
||||||
|
|
||||||
rootCmd := &cobra.Command{
|
rootCmd := &cobra.Command{
|
||||||
Use: "super-graph",
|
Use: "super-graph",
|
||||||
|
@ -135,19 +137,14 @@ e.g. db:migrate -+1
|
||||||
"path", "./config", "path to config files")
|
"path", "./config", "path to config files")
|
||||||
|
|
||||||
if err := rootCmd.Execute(); err != nil {
|
if err := rootCmd.Execute(); err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func initLog() *zerolog.Logger {
|
func initLog() {
|
||||||
out := zerolog.ConsoleWriter{Out: os.Stderr}
|
out := zerolog.ConsoleWriter{Out: os.Stderr}
|
||||||
logger := zerolog.New(out).
|
logger = zerolog.New(out).With().Timestamp().Logger()
|
||||||
With().
|
errlog = logger.With().Caller().Logger()
|
||||||
Timestamp().
|
|
||||||
Caller().
|
|
||||||
Logger()
|
|
||||||
|
|
||||||
return &logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func initConf() (*config, error) {
|
func initConf() (*config, error) {
|
||||||
|
@ -166,7 +163,7 @@ func initConf() (*config, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if vi.IsSet("inherits") {
|
if vi.IsSet("inherits") {
|
||||||
logger.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)",
|
errlog.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)",
|
||||||
inherits,
|
inherits,
|
||||||
vi.GetString("inherits"))
|
vi.GetString("inherits"))
|
||||||
}
|
}
|
||||||
|
@ -183,7 +180,7 @@ func initConf() (*config, error) {
|
||||||
|
|
||||||
logLevel, err := zerolog.ParseLevel(c.LogLevel)
|
logLevel, err := zerolog.ParseLevel(c.LogLevel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error().Err(err).Msg("error setting log_level")
|
errlog.Error().Err(err).Msg("error setting log_level")
|
||||||
}
|
}
|
||||||
zerolog.SetGlobalLevel(logLevel)
|
zerolog.SetGlobalLevel(logLevel)
|
||||||
|
|
||||||
|
@ -218,7 +215,7 @@ func initDB(c *config, useDB bool) (*pgx.Conn, error) {
|
||||||
config.LogLevel = pgx.LogLevelNone
|
config.LogLevel = pgx.LogLevelNone
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Logger = NewSQLLogger(*logger)
|
config.Logger = NewSQLLogger(logger)
|
||||||
|
|
||||||
db, err := pgx.ConnectConfig(context.Background(), config)
|
db, err := pgx.ConnectConfig(context.Background(), config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -253,7 +250,7 @@ func initDBPool(c *config) (*pgxpool.Pool, error) {
|
||||||
config.ConnConfig.LogLevel = pgx.LogLevelNone
|
config.ConnConfig.LogLevel = pgx.LogLevelNone
|
||||||
}
|
}
|
||||||
|
|
||||||
config.ConnConfig.Logger = NewSQLLogger(*logger)
|
config.ConnConfig.Logger = NewSQLLogger(logger)
|
||||||
|
|
||||||
// if c.DB.MaxRetries != 0 {
|
// if c.DB.MaxRetries != 0 {
|
||||||
// opt.MaxRetries = c.DB.MaxRetries
|
// opt.MaxRetries = c.DB.MaxRetries
|
||||||
|
@ -276,11 +273,11 @@ func initCompiler() {
|
||||||
|
|
||||||
qcompile, pcompile, err = initCompilers(conf)
|
qcompile, pcompile, err = initCompilers(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to initialize compilers")
|
errlog.Fatal().Err(err).Msg("failed to initialize compilers")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := initResolvers(); err != nil {
|
if err := initResolvers(); err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to initialized resolvers")
|
errlog.Fatal().Err(err).Msg("failed to initialized resolvers")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,7 +286,7 @@ func initConfOnce() {
|
||||||
|
|
||||||
if conf == nil {
|
if conf == nil {
|
||||||
if conf, err = initConf(); err != nil {
|
if conf, err = initConf(); err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to read config")
|
errlog.Fatal().Err(err).Msg("failed to read config")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,11 +17,11 @@ func cmdConfDump(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
conf, err := initConf()
|
conf, err := initConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to read config")
|
errlog.Fatal().Err(err).Msg("failed to read config")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conf.Viper.WriteConfigAs(fname); err != nil {
|
if err := conf.Viper.WriteConfigAs(fname); err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info().Msgf("config dumped to ./%s", fname)
|
logger.Info().Msgf("config dumped to ./%s", fname)
|
||||||
|
|
|
@ -49,7 +49,7 @@ func cmdDBSetup(cmd *cobra.Command, args []string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if os.IsNotExist(err) == false {
|
if os.IsNotExist(err) == false {
|
||||||
logger.Fatal().Err(err).Msgf("unable to check if '%s' exists", sfile)
|
errlog.Fatal().Err(err).Msgf("unable to check if '%s' exists", sfile)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Warn().Msgf("failed to read seed file '%s'", sfile)
|
logger.Warn().Msgf("failed to read seed file '%s'", sfile)
|
||||||
|
@ -59,7 +59,7 @@ func cmdDBReset(cmd *cobra.Command, args []string) {
|
||||||
initConfOnce()
|
initConfOnce()
|
||||||
|
|
||||||
if conf.Production {
|
if conf.Production {
|
||||||
logger.Fatal().Msg("db:reset does not work in production")
|
errlog.Fatal().Msg("db:reset does not work in production")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cmdDBDrop(cmd, []string{})
|
cmdDBDrop(cmd, []string{})
|
||||||
|
@ -72,7 +72,7 @@ func cmdDBCreate(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
conn, err := initDB(conf, false)
|
conn, err := initDB(conf, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||||
}
|
}
|
||||||
defer conn.Close(ctx)
|
defer conn.Close(ctx)
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ func cmdDBCreate(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
_, err = conn.Exec(ctx, sql)
|
_, err = conn.Exec(ctx, sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to create database")
|
errlog.Fatal().Err(err).Msg("failed to create database")
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info().Msgf("created database '%s'", conf.DB.DBName)
|
logger.Info().Msgf("created database '%s'", conf.DB.DBName)
|
||||||
|
@ -92,7 +92,7 @@ func cmdDBDrop(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
conn, err := initDB(conf, false)
|
conn, err := initDB(conf, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||||
}
|
}
|
||||||
defer conn.Close(ctx)
|
defer conn.Close(ctx)
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ func cmdDBDrop(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
_, err = conn.Exec(ctx, sql)
|
_, err = conn.Exec(ctx, sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to create database")
|
errlog.Fatal().Err(err).Msg("failed to create database")
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info().Msgf("dropped database '%s'", conf.DB.DBName)
|
logger.Info().Msgf("dropped database '%s'", conf.DB.DBName)
|
||||||
|
@ -151,24 +151,24 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
conn, err := initDB(conf, true)
|
conn, err := initDB(conf, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||||
}
|
}
|
||||||
defer conn.Close(context.Background())
|
defer conn.Close(context.Background())
|
||||||
|
|
||||||
m, err := migrate.NewMigrator(conn, "schema_version")
|
m, err := migrate.NewMigrator(conn, "schema_version")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to initializing migrator")
|
errlog.Fatal().Err(err).Msg("failed to initializing migrator")
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Data = getMigrationVars()
|
m.Data = getMigrationVars()
|
||||||
|
|
||||||
err = m.LoadMigrations(conf.MigrationsPath)
|
err = m.LoadMigrations(conf.MigrationsPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to load migrations")
|
errlog.Fatal().Err(err).Msg("failed to load migrations")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.Migrations) == 0 {
|
if len(m.Migrations) == 0 {
|
||||||
logger.Fatal().Msg("No migrations found")
|
errlog.Fatal().Msg("No migrations found")
|
||||||
}
|
}
|
||||||
|
|
||||||
m.OnStart = func(sequence int32, name, direction, sql string) {
|
m.OnStart = func(sequence int32, name, direction, sql string) {
|
||||||
|
@ -187,7 +187,7 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
|
||||||
var n int64
|
var n int64
|
||||||
n, err = strconv.ParseInt(d, 10, 32)
|
n, err = strconv.ParseInt(d, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("invalid destination")
|
errlog.Fatal().Err(err).Msg("invalid destination")
|
||||||
}
|
}
|
||||||
return int32(n)
|
return int32(n)
|
||||||
}
|
}
|
||||||
|
@ -218,17 +218,15 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info().Err(err).Send()
|
logger.Info().Err(err).Send()
|
||||||
|
|
||||||
// logger.Info().Err(err).Send()
|
|
||||||
|
|
||||||
// if err, ok := err.(m.MigrationPgError); ok {
|
// if err, ok := err.(m.MigrationPgError); ok {
|
||||||
// if err.Detail != "" {
|
// if err.Detail != "" {
|
||||||
// logger.Info().Err(err).Msg(err.Detail)
|
// info.Err(err).Msg(err.Detail)
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// if err.Position != 0 {
|
// if err.Position != 0 {
|
||||||
// ele, err := ExtractErrorLine(err.Sql, int(err.Position))
|
// ele, err := ExtractErrorLine(err.Sql, int(err.Position))
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// logger.Fatal().Err(err).Send()
|
// errlog.Fatal().Err(err).Send()
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// prefix := fmt.Sprintf()
|
// prefix := fmt.Sprintf()
|
||||||
|
@ -247,29 +245,29 @@ func cmdDBStatus(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
conn, err := initDB(conf, true)
|
conn, err := initDB(conf, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||||
}
|
}
|
||||||
defer conn.Close(context.Background())
|
defer conn.Close(context.Background())
|
||||||
|
|
||||||
m, err := migrate.NewMigrator(conn, "schema_version")
|
m, err := migrate.NewMigrator(conn, "schema_version")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to initialize migrator")
|
errlog.Fatal().Err(err).Msg("failed to initialize migrator")
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Data = getMigrationVars()
|
m.Data = getMigrationVars()
|
||||||
|
|
||||||
err = m.LoadMigrations(conf.MigrationsPath)
|
err = m.LoadMigrations(conf.MigrationsPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to load migrations")
|
errlog.Fatal().Err(err).Msg("failed to load migrations")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.Migrations) == 0 {
|
if len(m.Migrations) == 0 {
|
||||||
logger.Fatal().Msg("no migrations found")
|
errlog.Fatal().Msg("no migrations found")
|
||||||
}
|
}
|
||||||
|
|
||||||
mver, err := m.GetCurrentVersion()
|
mver, err := m.GetCurrentVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to retrieve migration")
|
errlog.Fatal().Err(err).Msg("failed to retrieve migration")
|
||||||
}
|
}
|
||||||
|
|
||||||
var status string
|
var status string
|
||||||
|
|
|
@ -134,12 +134,12 @@ func ifNotExists(filePath string, doFn func(string) error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if os.IsNotExist(err) == false {
|
if os.IsNotExist(err) == false {
|
||||||
logger.Fatal().Err(err).Msgf("unable to check if '%s' exists", filePath)
|
errlog.Fatal().Err(err).Msgf("unable to check if '%s' exists", filePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = doFn(filePath)
|
err = doFn(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msgf("unable to create '%s'", filePath)
|
errlog.Fatal().Err(err).Msgf("unable to create '%s'", filePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info().Msgf("created '%s'", filePath)
|
logger.Info().Msgf("created '%s'", filePath)
|
||||||
|
|
|
@ -20,14 +20,14 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if conf, err = initConf(); err != nil {
|
if conf, err = initConf(); err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to read config")
|
errlog.Fatal().Err(err).Msg("failed to read config")
|
||||||
}
|
}
|
||||||
|
|
||||||
conf.Production = false
|
conf.Production = false
|
||||||
|
|
||||||
db, err = initDBPool(conf)
|
db, err = initDBPool(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||||
}
|
}
|
||||||
|
|
||||||
initCompiler()
|
initCompiler()
|
||||||
|
@ -36,7 +36,7 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
b, err := ioutil.ReadFile(path.Join(confPath, conf.SeedFile))
|
b, err := ioutil.ReadFile(path.Join(confPath, conf.SeedFile))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msgf("failed to read seed file '%s'", sfile)
|
errlog.Fatal().Err(err).Msgf("failed to read seed file '%s'", sfile)
|
||||||
}
|
}
|
||||||
|
|
||||||
vm := goja.New()
|
vm := goja.New()
|
||||||
|
@ -52,7 +52,7 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
_, err = vm.RunScript("seed.js", string(b))
|
_, err = vm.RunScript("seed.js", string(b))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to execute script")
|
errlog.Fatal().Err(err).Msg("failed to execute script")
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info().Msg("seed script done")
|
logger.Info().Msg("seed script done")
|
||||||
|
@ -60,15 +60,15 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
//func runFunc(call goja.FunctionCall) {
|
//func runFunc(call goja.FunctionCall) {
|
||||||
func graphQLFunc(query string, data interface{}, opt map[string]string) map[string]interface{} {
|
func graphQLFunc(query string, data interface{}, opt map[string]string) map[string]interface{} {
|
||||||
b, err := json.Marshal(data)
|
vars, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
c := context.Background()
|
||||||
|
|
||||||
if v, ok := opt["user_id"]; ok && len(v) != 0 {
|
if v, ok := opt["user_id"]; ok && len(v) != 0 {
|
||||||
ctx = context.WithValue(ctx, userIDKey, v)
|
c = context.WithValue(c, userIDKey, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
var role string
|
var role string
|
||||||
|
@ -79,62 +79,50 @@ func graphQLFunc(query string, data interface{}, opt map[string]string) map[stri
|
||||||
role = "user"
|
role = "user"
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &coreContext{Context: ctx}
|
stmts, err := buildRoleStmt([]byte(query), vars, role)
|
||||||
c.req.Query = query
|
|
||||||
c.req.Vars = b
|
|
||||||
|
|
||||||
st, err := c.buildStmtByRole(role)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("graphql query failed")
|
errlog.Fatal().Err(err).Msg("graphql query failed")
|
||||||
}
|
}
|
||||||
|
st := stmts[0]
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
|
||||||
t := fasttemplate.New(st.sql, openVar, closeVar)
|
t := fasttemplate.New(st.sql, openVar, closeVar)
|
||||||
_, err = t.ExecuteFunc(buf, argMap(c))
|
_, err = t.ExecuteFunc(buf, argMap(c, vars))
|
||||||
|
|
||||||
if err == errNoUserID {
|
|
||||||
logger.Fatal().Err(err).Msg("query requires a user_id")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
|
|
||||||
finalSQL := buf.String()
|
finalSQL := buf.String()
|
||||||
|
|
||||||
tx, err := db.Begin(c)
|
tx, err := db.Begin(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
defer tx.Rollback(c)
|
defer tx.Rollback(c)
|
||||||
|
|
||||||
if conf.DB.SetUserID {
|
if conf.DB.SetUserID {
|
||||||
if err := c.setLocalUserID(tx); err != nil {
|
if err := setLocalUserID(c, tx); err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var root []byte
|
var root []byte
|
||||||
|
|
||||||
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil {
|
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil {
|
||||||
logger.Fatal().Err(err).Msg("sql query failed")
|
errlog.Fatal().Err(err).Msg("sql query failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(c); err != nil {
|
if err := tx.Commit(c); err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
|
||||||
|
|
||||||
res, err := c.execRemoteJoin(st.qc, st.skipped, root)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal().Err(err).Send()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val := make(map[string]interface{})
|
val := make(map[string]interface{})
|
||||||
|
|
||||||
err = json.Unmarshal(res, &val)
|
err = json.Unmarshal(root, &val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
|
|
||||||
return val
|
return val
|
||||||
|
|
|
@ -8,12 +8,12 @@ func cmdServ(cmd *cobra.Command, args []string) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if conf, err = initConf(); err != nil {
|
if conf, err = initConf(); err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to read config")
|
errlog.Fatal().Err(err).Msg("failed to read config")
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err = initDBPool(conf)
|
db, err = initDBPool(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||||
}
|
}
|
||||||
|
|
||||||
initCompiler()
|
initCompiler()
|
||||||
|
|
|
@ -68,12 +68,8 @@ type config struct {
|
||||||
MaxRetries int `mapstructure:"max_retries"`
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
SetUserID bool `mapstructure:"set_user_id"`
|
SetUserID bool `mapstructure:"set_user_id"`
|
||||||
|
|
||||||
Vars map[string]string `mapstructure:"variables"`
|
Vars map[string]string `mapstructure:"variables"`
|
||||||
|
Blocklist []string
|
||||||
Defaults struct {
|
|
||||||
Filters []string
|
|
||||||
Blocklist []string
|
|
||||||
}
|
|
||||||
|
|
||||||
Tables []configTable
|
Tables []configTable
|
||||||
} `mapstructure:"database"`
|
} `mapstructure:"database"`
|
||||||
|
@ -82,6 +78,7 @@ type config struct {
|
||||||
|
|
||||||
RolesQuery string `mapstructure:"roles_query"`
|
RolesQuery string `mapstructure:"roles_query"`
|
||||||
Roles []configRole
|
Roles []configRole
|
||||||
|
roles map[string]*configRole
|
||||||
}
|
}
|
||||||
|
|
||||||
type configTable struct {
|
type configTable struct {
|
||||||
|
@ -220,16 +217,15 @@ func (c *config) Init(vi *viper.Viper) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.RolesQuery = sanitize(c.RolesQuery)
|
c.RolesQuery = sanitize(c.RolesQuery)
|
||||||
|
c.roles = make(map[string]*configRole)
|
||||||
rolesMap := make(map[string]struct{})
|
|
||||||
|
|
||||||
for i := range c.Roles {
|
for i := range c.Roles {
|
||||||
role := &c.Roles[i]
|
role := &c.Roles[i]
|
||||||
|
|
||||||
if _, ok := rolesMap[role.Name]; ok {
|
if _, ok := c.roles[role.Name]; ok {
|
||||||
logger.Fatal().Msgf("duplicate role '%s' found", role.Name)
|
errlog.Fatal().Msgf("duplicate role '%s' found", role.Name)
|
||||||
}
|
}
|
||||||
role.Name = sanitize(role.Name)
|
role.Name = strings.ToLower(role.Name)
|
||||||
role.Match = sanitize(role.Match)
|
role.Match = sanitize(role.Match)
|
||||||
role.tablesMap = make(map[string]*configRoleTable)
|
role.tablesMap = make(map[string]*configRoleTable)
|
||||||
|
|
||||||
|
@ -237,14 +233,16 @@ func (c *config) Init(vi *viper.Viper) error {
|
||||||
role.tablesMap[table.Name] = &role.Tables[n]
|
role.tablesMap[table.Name] = &role.Tables[n]
|
||||||
}
|
}
|
||||||
|
|
||||||
rolesMap[role.Name] = struct{}{}
|
c.roles[role.Name] = role
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := rolesMap["user"]; !ok {
|
if _, ok := c.roles["user"]; !ok {
|
||||||
c.Roles = append(c.Roles, configRole{Name: "user"})
|
u := configRole{Name: "user"}
|
||||||
|
c.Roles = append(c.Roles, u)
|
||||||
|
c.roles["user"] = &u
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := rolesMap["anon"]; !ok {
|
if _, ok := c.roles["anon"]; !ok {
|
||||||
logger.Warn().Msg("unauthenticated requests will be blocked. no role 'anon' defined")
|
logger.Warn().Msg("unauthenticated requests will be blocked. no role 'anon' defined")
|
||||||
c.AuthFailBlock = true
|
c.AuthFailBlock = true
|
||||||
}
|
}
|
||||||
|
@ -261,7 +259,7 @@ func (c *config) validate() {
|
||||||
name := c.Roles[i].Name
|
name := c.Roles[i].Name
|
||||||
|
|
||||||
if _, ok := rm[name]; ok {
|
if _, ok := rm[name]; ok {
|
||||||
logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name)
|
errlog.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name)
|
||||||
}
|
}
|
||||||
rm[name] = struct{}{}
|
rm[name] = struct{}{}
|
||||||
}
|
}
|
||||||
|
@ -272,7 +270,7 @@ func (c *config) validate() {
|
||||||
name := c.Tables[i].Name
|
name := c.Tables[i].Name
|
||||||
|
|
||||||
if _, ok := tm[name]; ok {
|
if _, ok := tm[name]; ok {
|
||||||
logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name)
|
errlog.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name)
|
||||||
}
|
}
|
||||||
tm[name] = struct{}{}
|
tm[name] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
352
serv/core.go
352
serv/core.go
|
@ -8,11 +8,9 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
"github.com/dosco/super-graph/jsn"
|
|
||||||
"github.com/dosco/super-graph/qcode"
|
"github.com/dosco/super-graph/qcode"
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v4"
|
||||||
"github.com/valyala/fasttemplate"
|
"github.com/valyala/fasttemplate"
|
||||||
|
@ -32,6 +30,10 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||||
c.req.ref = req.Referer()
|
c.req.ref = req.Referer()
|
||||||
c.req.hdr = req.Header
|
c.req.hdr = req.Header
|
||||||
|
|
||||||
|
if len(c.req.Vars) == 2 {
|
||||||
|
c.req.Vars = nil
|
||||||
|
}
|
||||||
|
|
||||||
if authCheck(c) {
|
if authCheck(c) {
|
||||||
c.req.role = "user"
|
c.req.role = "user"
|
||||||
} else {
|
} else {
|
||||||
|
@ -47,83 +49,38 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *coreContext) execQuery() ([]byte, error) {
|
func (c *coreContext) execQuery() ([]byte, error) {
|
||||||
var err error
|
|
||||||
var skipped uint32
|
|
||||||
var qc *qcode.QCode
|
|
||||||
var data []byte
|
var data []byte
|
||||||
|
var st *stmt
|
||||||
|
var err error
|
||||||
|
|
||||||
if conf.Production {
|
if conf.Production {
|
||||||
var ps *preparedItem
|
data, st, err = c.resolvePreparedSQL()
|
||||||
|
|
||||||
data, ps, err = c.resolvePreparedSQL()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
logger.Error().
|
||||||
}
|
Err(err).
|
||||||
|
Str("default_role", c.req.role).
|
||||||
|
Msg(c.req.Query)
|
||||||
|
|
||||||
skipped = ps.skipped
|
return nil, errors.New("query failed. check logs for error")
|
||||||
qc = ps.qc
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
if data, st, err = c.resolveSQL(); err != nil {
|
||||||
data, skipped, err = c.resolveSQL()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.execRemoteJoin(qc, skipped, data)
|
return execRemoteJoin(st, data, c.req.hdr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *coreContext) execRemoteJoin(qc *qcode.QCode, skipped uint32, data []byte) ([]byte, error) {
|
func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
|
||||||
var err error
|
|
||||||
|
|
||||||
if len(data) == 0 || skipped == 0 {
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sel := qc.Selects
|
|
||||||
h := xxhash.New()
|
|
||||||
|
|
||||||
// fetch the field name used within the db response json
|
|
||||||
// that are used to mark insertion points and the mapping between
|
|
||||||
// those field names and their select objects
|
|
||||||
fids, sfmap := parentFieldIds(h, sel, skipped)
|
|
||||||
|
|
||||||
// fetch the field values of the marked insertion points
|
|
||||||
// these values contain the id to be used with fetching remote data
|
|
||||||
from := jsn.Get(data, fids)
|
|
||||||
|
|
||||||
var to []jsn.Field
|
|
||||||
switch {
|
|
||||||
case len(from) == 1:
|
|
||||||
to, err = c.resolveRemote(c.req.hdr, h, from[0], sel, sfmap)
|
|
||||||
|
|
||||||
case len(from) > 1:
|
|
||||||
to, err = c.resolveRemotes(c.req.hdr, h, from, sel, sfmap)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, errors.New("something wrong no remote ids found in db response")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var ob bytes.Buffer
|
|
||||||
|
|
||||||
err = jsn.Replace(&ob, data, from, to)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ob.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
|
||||||
var tx pgx.Tx
|
var tx pgx.Tx
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
mutation := isMutation(c.req.Query)
|
qt := qcode.GetQType(c.req.Query)
|
||||||
|
mutation := (qt == qcode.QTMutation)
|
||||||
|
anonQuery := (qt == qcode.QTQuery && c.req.role == "anon")
|
||||||
|
|
||||||
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
|
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
|
||||||
useTx := useRoleQuery || conf.DB.SetUserID
|
useTx := useRoleQuery || conf.DB.SetUserID
|
||||||
|
|
||||||
|
@ -135,7 +92,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.DB.SetUserID {
|
if conf.DB.SetUserID {
|
||||||
if err := c.setLocalUserID(tx); err != nil {
|
if err := setLocalUserID(c, tx); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -150,7 +107,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||||
} else if v := c.Value(userRoleKey); v != nil {
|
} else if v := c.Value(userRoleKey); v != nil {
|
||||||
role = v.(string)
|
role = v.(string)
|
||||||
|
|
||||||
} else if mutation {
|
} else {
|
||||||
role = c.req.role
|
role = c.req.role
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -162,21 +119,29 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||||
|
|
||||||
var root []byte
|
var root []byte
|
||||||
var row pgx.Row
|
var row pgx.Row
|
||||||
vars := argList(c, ps.args)
|
|
||||||
|
|
||||||
if useTx {
|
vars, err := argList(c, ps.args)
|
||||||
row = tx.QueryRow(c, ps.stmt.SQL, vars...)
|
if err != nil {
|
||||||
} else {
|
return nil, nil, err
|
||||||
row = db.QueryRow(c, ps.stmt.SQL, vars...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if mutation {
|
if useTx {
|
||||||
|
row = tx.QueryRow(c, ps.sd.SQL, vars...)
|
||||||
|
} else {
|
||||||
|
row = db.QueryRow(c, ps.sd.SQL, vars...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mutation || anonQuery {
|
||||||
err = row.Scan(&root)
|
err = row.Scan(&root)
|
||||||
} else {
|
} else {
|
||||||
err = row.Scan(&role, &root)
|
err = row.Scan(&role, &root)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query)
|
if len(role) == 0 {
|
||||||
|
logger.Debug().Str("default_role", c.req.role).Msg(c.req.Query)
|
||||||
|
} else {
|
||||||
|
logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -190,65 +155,55 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return root, ps, nil
|
return root, ps.st, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
|
func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||||
var tx pgx.Tx
|
var tx pgx.Tx
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
mutation := isMutation(c.req.Query)
|
qt := qcode.GetQType(c.req.Query)
|
||||||
|
mutation := (qt == qcode.QTMutation)
|
||||||
|
//anonQuery := (qt == qcode.QTQuery && c.req.role == "anon")
|
||||||
|
|
||||||
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
|
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
|
||||||
useTx := useRoleQuery || conf.DB.SetUserID
|
useTx := useRoleQuery || conf.DB.SetUserID
|
||||||
|
|
||||||
if useTx {
|
if useTx {
|
||||||
if tx, err = db.Begin(c); err != nil {
|
if tx, err = db.Begin(c); err != nil {
|
||||||
return nil, 0, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback(c)
|
defer tx.Rollback(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.DB.SetUserID {
|
if conf.DB.SetUserID {
|
||||||
if err := c.setLocalUserID(tx); err != nil {
|
if err := setLocalUserID(c, tx); err != nil {
|
||||||
return nil, 0, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if useRoleQuery {
|
if useRoleQuery {
|
||||||
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
|
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
|
||||||
return nil, 0, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if v := c.Value(userRoleKey); v != nil {
|
} else if v := c.Value(userRoleKey); v != nil {
|
||||||
c.req.role = v.(string)
|
c.req.role = v.(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
stmts, err := c.buildStmt()
|
stmts, err := buildStmt(qt, []byte(c.req.Query), c.req.Vars, c.req.role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, nil, err
|
||||||
}
|
|
||||||
|
|
||||||
var st *stmt
|
|
||||||
|
|
||||||
if mutation {
|
|
||||||
st = findStmt(c.req.role, stmts)
|
|
||||||
} else {
|
|
||||||
st = &stmts[0]
|
|
||||||
}
|
}
|
||||||
|
st := &stmts[0]
|
||||||
|
|
||||||
t := fasttemplate.New(st.sql, openVar, closeVar)
|
t := fasttemplate.New(st.sql, openVar, closeVar)
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
_, err = t.ExecuteFunc(buf, argMap(c))
|
|
||||||
|
|
||||||
if err == errNoUserID {
|
|
||||||
logger.Warn().Msg("no user id found. query requires an authenicated request")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
_, err = t.ExecuteFunc(buf, argMap(c, c.req.Vars))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
finalSQL := buf.String()
|
finalSQL := buf.String()
|
||||||
|
|
||||||
var stime time.Time
|
var stime time.Time
|
||||||
|
@ -258,195 +213,56 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var root []byte
|
var root []byte
|
||||||
var role, defaultRole string
|
var role string
|
||||||
var row pgx.Row
|
var row pgx.Row
|
||||||
|
|
||||||
|
defaultRole := c.req.role
|
||||||
|
|
||||||
if useTx {
|
if useTx {
|
||||||
row = tx.QueryRow(c, finalSQL)
|
row = tx.QueryRow(c, finalSQL)
|
||||||
} else {
|
} else {
|
||||||
row = db.QueryRow(c, finalSQL)
|
row = db.QueryRow(c, finalSQL)
|
||||||
}
|
}
|
||||||
|
|
||||||
if mutation {
|
if len(stmts) == 1 {
|
||||||
err = row.Scan(&root)
|
err = row.Scan(&root)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
err = row.Scan(&role, &root)
|
err = row.Scan(&role, &root)
|
||||||
defaultRole = c.req.role
|
|
||||||
c.req.role = role
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug().Str("default_role", defaultRole).Str("role", role).Msg(c.req.Query)
|
if len(role) == 0 {
|
||||||
|
logger.Debug().Str("default_role", defaultRole).Msg(c.req.Query)
|
||||||
|
} else {
|
||||||
|
logger.Debug().Str("default_role", defaultRole).Str("role", role).Msg(c.req.Query)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if useTx {
|
if useTx {
|
||||||
if err := tx.Commit(c); err != nil {
|
if err := tx.Commit(c); err != nil {
|
||||||
return nil, 0, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.EnableTracing && len(st.qc.Selects) != 0 {
|
// if conf.Production == false {
|
||||||
|
// _allowList.add(&c.req)
|
||||||
|
// }
|
||||||
|
|
||||||
|
if len(stmts) > 1 {
|
||||||
|
if st = findStmt(role, stmts); st == nil {
|
||||||
|
return nil, nil, fmt.Errorf("invalid role '%s' returned", role)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.EnableTracing {
|
||||||
for _, id := range st.qc.Roots {
|
for _, id := range st.qc.Roots {
|
||||||
c.addTrace(st.qc.Selects, id, stime)
|
c.addTrace(st.qc.Selects, id, stime)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.Production == false {
|
return root, st, nil
|
||||||
_allowList.add(&c.req)
|
|
||||||
}
|
|
||||||
|
|
||||||
return root, st.skipped, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *coreContext) resolveRemote(
|
|
||||||
hdr http.Header,
|
|
||||||
h *xxhash.Digest,
|
|
||||||
field jsn.Field,
|
|
||||||
sel []qcode.Select,
|
|
||||||
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
|
|
||||||
|
|
||||||
// replacement data for the marked insertion points
|
|
||||||
// key and value will be replaced by whats below
|
|
||||||
toA := [1]jsn.Field{}
|
|
||||||
to := toA[:1]
|
|
||||||
|
|
||||||
// use the json key to find the related Select object
|
|
||||||
k1 := xxhash.Sum64(field.Key)
|
|
||||||
|
|
||||||
s, ok := sfmap[k1]
|
|
||||||
if !ok {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
p := sel[s.ParentID]
|
|
||||||
|
|
||||||
// then use the Table nme in the Select and it's parent
|
|
||||||
// to find the resolver to use for this relationship
|
|
||||||
k2 := mkkey(h, s.Table, p.Table)
|
|
||||||
|
|
||||||
r, ok := rmap[k2]
|
|
||||||
if !ok {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
id := jsn.Value(field.Value)
|
|
||||||
if len(id) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
st := time.Now()
|
|
||||||
|
|
||||||
b, err := r.Fn(hdr, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.EnableTracing {
|
|
||||||
c.addTrace(sel, s.ID, st)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(r.Path) != 0 {
|
|
||||||
b = jsn.Strip(b, r.Path)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ob bytes.Buffer
|
|
||||||
|
|
||||||
if len(s.Cols) != 0 {
|
|
||||||
err = jsn.Filter(&ob, b, colsToList(s.Cols))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
ob.WriteString("null")
|
|
||||||
}
|
|
||||||
|
|
||||||
to[0] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()}
|
|
||||||
return to, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *coreContext) resolveRemotes(
|
|
||||||
hdr http.Header,
|
|
||||||
h *xxhash.Digest,
|
|
||||||
from []jsn.Field,
|
|
||||||
sel []qcode.Select,
|
|
||||||
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
|
|
||||||
|
|
||||||
// replacement data for the marked insertion points
|
|
||||||
// key and value will be replaced by whats below
|
|
||||||
to := make([]jsn.Field, len(from))
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(len(from))
|
|
||||||
|
|
||||||
var cerr error
|
|
||||||
|
|
||||||
for i, id := range from {
|
|
||||||
|
|
||||||
// use the json key to find the related Select object
|
|
||||||
k1 := xxhash.Sum64(id.Key)
|
|
||||||
|
|
||||||
s, ok := sfmap[k1]
|
|
||||||
if !ok {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
p := sel[s.ParentID]
|
|
||||||
|
|
||||||
// then use the Table nme in the Select and it's parent
|
|
||||||
// to find the resolver to use for this relationship
|
|
||||||
k2 := mkkey(h, s.Table, p.Table)
|
|
||||||
|
|
||||||
r, ok := rmap[k2]
|
|
||||||
if !ok {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
id := jsn.Value(id.Value)
|
|
||||||
if len(id) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(n int, id []byte, s *qcode.Select) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
st := time.Now()
|
|
||||||
|
|
||||||
b, err := r.Fn(hdr, id)
|
|
||||||
if err != nil {
|
|
||||||
cerr = fmt.Errorf("%s: %s", s.Table, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if conf.EnableTracing {
|
|
||||||
c.addTrace(sel, s.ID, st)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(r.Path) != 0 {
|
|
||||||
b = jsn.Strip(b, r.Path)
|
|
||||||
}
|
|
||||||
|
|
||||||
var ob bytes.Buffer
|
|
||||||
|
|
||||||
if len(s.Cols) != 0 {
|
|
||||||
err = jsn.Filter(&ob, b, colsToList(s.Cols))
|
|
||||||
if err != nil {
|
|
||||||
cerr = fmt.Errorf("%s: %s", s.Table, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
ob.WriteString("null")
|
|
||||||
}
|
|
||||||
|
|
||||||
to[n] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()}
|
|
||||||
}(i, id, s)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
return to, cerr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
||||||
|
@ -460,15 +276,6 @@ func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
||||||
return role, nil
|
return role, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *coreContext) setLocalUserID(tx pgx.Tx) error {
|
|
||||||
var err error
|
|
||||||
if v := c.Value(userIDKey); v != nil {
|
|
||||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *coreContext) render(w io.Writer, data []byte) error {
|
func (c *coreContext) render(w io.Writer, data []byte) error {
|
||||||
c.res.Data = json.RawMessage(data)
|
c.res.Data = json.RawMessage(data)
|
||||||
return json.NewEncoder(w).Encode(c.res)
|
return json.NewEncoder(w).Encode(c.res)
|
||||||
|
@ -560,6 +367,15 @@ func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
|
||||||
return fm, sm
|
return fm, sm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setLocalUserID(c context.Context, tx pgx.Tx) error {
|
||||||
|
var err error
|
||||||
|
if v := c.Value(userIDKey); v != nil {
|
||||||
|
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func isSkipped(n uint32, pos uint32) bool {
|
func isSkipped(n uint32, pos uint32) bool {
|
||||||
return ((n & (1 << pos)) != 0)
|
return ((n & (1 << pos)) != 0)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/dosco/super-graph/psql"
|
"github.com/dosco/super-graph/psql"
|
||||||
|
@ -17,172 +18,171 @@ type stmt struct {
|
||||||
sql string
|
sql string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *coreContext) buildStmt() ([]stmt, error) {
|
func buildStmt(qt qcode.QType, gql, vars []byte, role string) ([]stmt, error) {
|
||||||
var vars map[string]json.RawMessage
|
switch qt {
|
||||||
|
case qcode.QTMutation:
|
||||||
|
return buildRoleStmt(gql, vars, role)
|
||||||
|
|
||||||
if len(c.req.Vars) != 0 {
|
case qcode.QTQuery:
|
||||||
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
|
switch {
|
||||||
|
case role == "anon":
|
||||||
|
return buildRoleStmt(gql, vars, role)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return buildMultiStmt(gql, vars)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown query type '%d'", qt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildRoleStmt(gql, vars []byte, role string) ([]stmt, error) {
|
||||||
|
ro, ok := conf.roles[role]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf(`roles '%s' not defined in config`, role)
|
||||||
|
}
|
||||||
|
|
||||||
|
var vm map[string]json.RawMessage
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if len(vars) != 0 {
|
||||||
|
if err := json.Unmarshal(vars, &vm); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
gql := []byte(c.req.Query)
|
qc, err := qcompile.Compile(gql, ro.Name)
|
||||||
|
|
||||||
if len(conf.Roles) == 0 {
|
|
||||||
return nil, errors.New(`no roles found ('user' and 'anon' required)`)
|
|
||||||
}
|
|
||||||
|
|
||||||
qc, err := qcompile.Compile(gql, conf.Roles[0].Name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
stmts := make([]stmt, 0, len(conf.Roles))
|
// For the 'anon' role in production only compile
|
||||||
mutation := (qc.Type != qcode.QTQuery)
|
// queries for tables defined in the config file.
|
||||||
|
if conf.Production &&
|
||||||
|
ro.Name == "anon" &&
|
||||||
|
hasTablesWithConfig(qc, ro) == false {
|
||||||
|
return nil, errors.New("query contains tables with no 'anon' role config")
|
||||||
|
}
|
||||||
|
|
||||||
|
stmts := []stmt{stmt{role: ro, qc: qc}}
|
||||||
w := &bytes.Buffer{}
|
w := &bytes.Buffer{}
|
||||||
|
|
||||||
for i := 1; i < len(conf.Roles); i++ {
|
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
stmts[0].skipped = skipped
|
||||||
|
stmts[0].sql = w.String()
|
||||||
|
|
||||||
|
return stmts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
|
||||||
|
var vm map[string]json.RawMessage
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if len(vars) != 0 {
|
||||||
|
if err := json.Unmarshal(vars, &vm); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(conf.RolesQuery) == 0 {
|
||||||
|
return buildRoleStmt(gql, vars, "user")
|
||||||
|
}
|
||||||
|
|
||||||
|
stmts := make([]stmt, 0, len(conf.Roles))
|
||||||
|
w := &bytes.Buffer{}
|
||||||
|
|
||||||
|
for i := 0; i < len(conf.Roles); i++ {
|
||||||
role := &conf.Roles[i]
|
role := &conf.Roles[i]
|
||||||
|
|
||||||
// For mutations only render sql for a single role from the request
|
qc, err := qcompile.Compile(gql, role.Name)
|
||||||
if mutation && len(c.req.role) != 0 && role.Name != c.req.role {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
qc, err = qcompile.Compile(gql, role.Name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.Production && role.Name == "anon" {
|
|
||||||
for _, id := range qc.Roots {
|
|
||||||
root := qc.Selects[id]
|
|
||||||
if _, ok := role.tablesMap[root.Table]; !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stmts = append(stmts, stmt{role: role, qc: qc})
|
stmts = append(stmts, stmt{role: role, qc: qc})
|
||||||
|
|
||||||
if mutation {
|
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm))
|
||||||
skipped, err := pcompile.Compile(qc, w, psql.Variables(vars))
|
if err != nil {
|
||||||
if err != nil {
|
return nil, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &stmts[len(stmts)-1]
|
|
||||||
s.skipped = skipped
|
|
||||||
s.sql = w.String()
|
|
||||||
w.Reset()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s := &stmts[len(stmts)-1]
|
||||||
|
s.skipped = skipped
|
||||||
|
s.sql = w.String()
|
||||||
|
w.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
if mutation {
|
sql, err := renderUserQuery(stmts, vm)
|
||||||
return stmts, nil
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stmts[0].sql = sql
|
||||||
|
return stmts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderUserQuery(
|
||||||
|
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
|
||||||
|
|
||||||
|
var err error
|
||||||
|
w := &bytes.Buffer{}
|
||||||
|
|
||||||
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
|
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
|
||||||
|
|
||||||
for _, s := range stmts {
|
for _, s := range stmts {
|
||||||
|
if len(s.role.Match) == 0 &&
|
||||||
|
s.role.Name != "user" && s.role.Name != "anon" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
io.WriteString(w, `WHEN '`)
|
io.WriteString(w, `WHEN '`)
|
||||||
io.WriteString(w, s.role.Name)
|
io.WriteString(w, s.role.Name)
|
||||||
io.WriteString(w, `' THEN (`)
|
io.WriteString(w, `' THEN (`)
|
||||||
|
|
||||||
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
|
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
io.WriteString(w, `) `)
|
io.WriteString(w, `) `)
|
||||||
}
|
}
|
||||||
io.WriteString(w, `END) FROM (`)
|
|
||||||
|
|
||||||
if len(conf.RolesQuery) == 0 {
|
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
|
||||||
v := c.Value(userRoleKey)
|
io.WriteString(w, conf.RolesQuery)
|
||||||
|
io.WriteString(w, `) THEN `)
|
||||||
|
|
||||||
io.WriteString(w, `VALUES ("`)
|
io.WriteString(w, `(SELECT (CASE`)
|
||||||
if v != nil {
|
for _, s := range stmts {
|
||||||
io.WriteString(w, v.(string))
|
if len(s.role.Match) == 0 {
|
||||||
} else {
|
continue
|
||||||
io.WriteString(w, c.req.role)
|
|
||||||
}
|
}
|
||||||
io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`)
|
io.WriteString(w, ` WHEN `)
|
||||||
|
io.WriteString(w, s.role.Match)
|
||||||
} else {
|
io.WriteString(w, ` THEN '`)
|
||||||
|
io.WriteString(w, s.role.Name)
|
||||||
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
|
io.WriteString(w, `'`)
|
||||||
io.WriteString(w, conf.RolesQuery)
|
|
||||||
io.WriteString(w, `) THEN `)
|
|
||||||
|
|
||||||
io.WriteString(w, `(SELECT (CASE`)
|
|
||||||
for _, s := range stmts {
|
|
||||||
if len(s.role.Match) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
io.WriteString(w, ` WHEN `)
|
|
||||||
io.WriteString(w, s.role.Match)
|
|
||||||
io.WriteString(w, ` THEN '`)
|
|
||||||
io.WriteString(w, s.role.Name)
|
|
||||||
io.WriteString(w, `'`)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(c.req.role) == 0 {
|
|
||||||
io.WriteString(w, ` ELSE 'anon' END) FROM (`)
|
|
||||||
} else {
|
|
||||||
io.WriteString(w, ` ELSE '`)
|
|
||||||
io.WriteString(w, c.req.role)
|
|
||||||
io.WriteString(w, `' END) FROM (`)
|
|
||||||
}
|
|
||||||
|
|
||||||
io.WriteString(w, conf.RolesQuery)
|
|
||||||
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`)
|
|
||||||
if len(c.req.role) == 0 {
|
|
||||||
io.WriteString(w, `anon`)
|
|
||||||
} else {
|
|
||||||
io.WriteString(w, c.req.role)
|
|
||||||
}
|
|
||||||
io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
stmts[0].sql = w.String()
|
io.WriteString(w, ` ELSE 'user' END) FROM (`)
|
||||||
stmts[0].role = nil
|
io.WriteString(w, conf.RolesQuery)
|
||||||
|
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
|
||||||
|
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
|
||||||
|
|
||||||
return stmts, nil
|
return w.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *coreContext) buildStmtByRole(role string) (stmt, error) {
|
func hasTablesWithConfig(qc *qcode.QCode, role *configRole) bool {
|
||||||
var st stmt
|
for _, id := range qc.Roots {
|
||||||
var err error
|
t, err := schema.GetTable(qc.Selects[id].Table)
|
||||||
|
if err != nil {
|
||||||
if len(role) == 0 {
|
return false
|
||||||
return st, errors.New(`no role defined`)
|
}
|
||||||
}
|
if _, ok := role.tablesMap[t.Name]; !ok {
|
||||||
|
return false
|
||||||
var vars map[string]json.RawMessage
|
|
||||||
|
|
||||||
if len(c.req.Vars) != 0 {
|
|
||||||
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
|
|
||||||
return st, err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return true
|
||||||
gql := []byte(c.req.Query)
|
|
||||||
|
|
||||||
st.qc, err = qcompile.Compile(gql, role)
|
|
||||||
if err != nil {
|
|
||||||
return st, err
|
|
||||||
}
|
|
||||||
|
|
||||||
w := &bytes.Buffer{}
|
|
||||||
|
|
||||||
st.skipped, err = pcompile.Compile(st.qc, w, psql.Variables(vars))
|
|
||||||
if err != nil {
|
|
||||||
return st, err
|
|
||||||
}
|
|
||||||
|
|
||||||
st.sql = w.String()
|
|
||||||
|
|
||||||
return st, nil
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,197 @@
|
||||||
|
package serv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/cespare/xxhash/v2"
|
||||||
|
"github.com/dosco/super-graph/jsn"
|
||||||
|
"github.com/dosco/super-graph/qcode"
|
||||||
|
)
|
||||||
|
|
||||||
|
func execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]byte, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if len(data) == 0 || st.skipped == 0 {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sel := st.qc.Selects
|
||||||
|
h := xxhash.New()
|
||||||
|
|
||||||
|
// fetch the field name used within the db response json
|
||||||
|
// that are used to mark insertion points and the mapping between
|
||||||
|
// those field names and their select objects
|
||||||
|
fids, sfmap := parentFieldIds(h, sel, st.skipped)
|
||||||
|
|
||||||
|
// fetch the field values of the marked insertion points
|
||||||
|
// these values contain the id to be used with fetching remote data
|
||||||
|
from := jsn.Get(data, fids)
|
||||||
|
var to []jsn.Field
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(from) == 1:
|
||||||
|
to, err = resolveRemote(hdr, h, from[0], sel, sfmap)
|
||||||
|
|
||||||
|
case len(from) > 1:
|
||||||
|
to, err = resolveRemotes(hdr, h, from, sel, sfmap)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errors.New("something wrong no remote ids found in db response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ob bytes.Buffer
|
||||||
|
|
||||||
|
err = jsn.Replace(&ob, data, from, to)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ob.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveRemote(
|
||||||
|
hdr http.Header,
|
||||||
|
h *xxhash.Digest,
|
||||||
|
field jsn.Field,
|
||||||
|
sel []qcode.Select,
|
||||||
|
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
|
||||||
|
|
||||||
|
// replacement data for the marked insertion points
|
||||||
|
// key and value will be replaced by whats below
|
||||||
|
toA := [1]jsn.Field{}
|
||||||
|
to := toA[:1]
|
||||||
|
|
||||||
|
// use the json key to find the related Select object
|
||||||
|
k1 := xxhash.Sum64(field.Key)
|
||||||
|
|
||||||
|
s, ok := sfmap[k1]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
p := sel[s.ParentID]
|
||||||
|
|
||||||
|
// then use the Table nme in the Select and it's parent
|
||||||
|
// to find the resolver to use for this relationship
|
||||||
|
k2 := mkkey(h, s.Table, p.Table)
|
||||||
|
|
||||||
|
r, ok := rmap[k2]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
id := jsn.Value(field.Value)
|
||||||
|
if len(id) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//st := time.Now()
|
||||||
|
|
||||||
|
b, err := r.Fn(hdr, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Path) != 0 {
|
||||||
|
b = jsn.Strip(b, r.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ob bytes.Buffer
|
||||||
|
|
||||||
|
if len(s.Cols) != 0 {
|
||||||
|
err = jsn.Filter(&ob, b, colsToList(s.Cols))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
ob.WriteString("null")
|
||||||
|
}
|
||||||
|
|
||||||
|
to[0] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()}
|
||||||
|
return to, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveRemotes(
|
||||||
|
hdr http.Header,
|
||||||
|
h *xxhash.Digest,
|
||||||
|
from []jsn.Field,
|
||||||
|
sel []qcode.Select,
|
||||||
|
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
|
||||||
|
|
||||||
|
// replacement data for the marked insertion points
|
||||||
|
// key and value will be replaced by whats below
|
||||||
|
to := make([]jsn.Field, len(from))
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(len(from))
|
||||||
|
|
||||||
|
var cerr error
|
||||||
|
|
||||||
|
for i, id := range from {
|
||||||
|
|
||||||
|
// use the json key to find the related Select object
|
||||||
|
k1 := xxhash.Sum64(id.Key)
|
||||||
|
|
||||||
|
s, ok := sfmap[k1]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
p := sel[s.ParentID]
|
||||||
|
|
||||||
|
// then use the Table nme in the Select and it's parent
|
||||||
|
// to find the resolver to use for this relationship
|
||||||
|
k2 := mkkey(h, s.Table, p.Table)
|
||||||
|
|
||||||
|
r, ok := rmap[k2]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
id := jsn.Value(id.Value)
|
||||||
|
if len(id) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
go func(n int, id []byte, s *qcode.Select) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
//st := time.Now()
|
||||||
|
|
||||||
|
b, err := r.Fn(hdr, id)
|
||||||
|
if err != nil {
|
||||||
|
cerr = fmt.Errorf("%s: %s", s.Table, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Path) != 0 {
|
||||||
|
b = jsn.Strip(b, r.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ob bytes.Buffer
|
||||||
|
|
||||||
|
if len(s.Cols) != 0 {
|
||||||
|
err = jsn.Filter(&ob, b, colsToList(s.Cols))
|
||||||
|
if err != nil {
|
||||||
|
cerr = fmt.Errorf("%s: %s", s.Table, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
ob.WriteString("null")
|
||||||
|
}
|
||||||
|
|
||||||
|
to[n] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()}
|
||||||
|
}(i, id, s)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
return to, cerr
|
||||||
|
}
|
|
@ -4,7 +4,6 @@ package serv
|
||||||
|
|
||||||
func Fuzz(data []byte) int {
|
func Fuzz(data []byte) int {
|
||||||
gql := string(data)
|
gql := string(data)
|
||||||
isMutation(gql)
|
|
||||||
gqlHash(gql, nil, "")
|
gqlHash(gql, nil, "")
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
|
|
|
@ -10,7 +10,6 @@ func TestFuzzCrashers(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range crashers {
|
for _, f := range crashers {
|
||||||
isMutation(f)
|
|
||||||
gqlHash(f, nil, "")
|
gqlHash(f, nil, "")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,6 @@ const (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
upgrader = websocket.Upgrader{}
|
upgrader = websocket.Upgrader{}
|
||||||
errNoUserID = errors.New("no user_id available")
|
|
||||||
errUnauthorized = errors.New("not authorized")
|
errUnauthorized = errors.New("not authorized")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -78,7 +77,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
|
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Err(err).Msg("failed to read request body")
|
errlog.Error().Err(err).Msg("failed to read request body")
|
||||||
errorResp(w, err)
|
errorResp(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -86,7 +85,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
err = json.Unmarshal(b, &ctx.req)
|
err = json.Unmarshal(b, &ctx.req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Err(err).Msg("failed to decode json request body")
|
errlog.Error().Err(err).Msg("failed to decode json request body")
|
||||||
errorResp(w, err)
|
errorResp(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -105,7 +104,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Err(err).Msg("failed to handle request")
|
errlog.Error().Err(err).Msg("failed to handle request")
|
||||||
errorResp(w, err)
|
errorResp(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
158
serv/prepare.go
158
serv/prepare.go
|
@ -3,7 +3,6 @@ package serv
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
@ -14,10 +13,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type preparedItem struct {
|
type preparedItem struct {
|
||||||
stmt *pgconn.StatementDescription
|
sd *pgconn.StatementDescription
|
||||||
args [][]byte
|
args [][]byte
|
||||||
skipped uint32
|
st *stmt
|
||||||
qc *qcode.QCode
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -25,85 +23,119 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func initPreparedList() {
|
func initPreparedList() {
|
||||||
ctx := context.Background()
|
c := context.Background()
|
||||||
|
|
||||||
tx, err := db.Begin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal().Err(err).Send()
|
|
||||||
}
|
|
||||||
defer tx.Rollback(ctx)
|
|
||||||
|
|
||||||
_preparedList = make(map[string]*preparedItem)
|
_preparedList = make(map[string]*preparedItem)
|
||||||
|
|
||||||
if err := prepareRoleStmt(ctx, tx); err != nil {
|
tx, err := db.Begin(c)
|
||||||
logger.Fatal().Err(err).Msg("failed to prepare get role statement")
|
if err != nil {
|
||||||
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
|
defer tx.Rollback(c)
|
||||||
|
|
||||||
|
err = prepareRoleStmt(c, tx)
|
||||||
|
if err != nil {
|
||||||
|
errlog.Fatal().Err(err).Msg("failed to prepare get role statement")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(c); err != nil {
|
||||||
|
errlog.Fatal().Err(err).Send()
|
||||||
|
}
|
||||||
|
|
||||||
|
success := 0
|
||||||
|
|
||||||
for _, v := range _allowList.list {
|
for _, v := range _allowList.list {
|
||||||
err := prepareStmt(ctx, tx, v.gql, v.vars)
|
if len(v.gql) == 0 {
|
||||||
if err != nil {
|
|
||||||
logger.Warn().Str("gql", v.gql).Err(err).Send()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit(ctx); err != nil {
|
|
||||||
logger.Fatal().Err(err).Send()
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info().Msgf("Registered %d queries from allow.list as prepared statements", len(_allowList.list))
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareStmt(ctx context.Context, tx pgx.Tx, gql string, varBytes json.RawMessage) error {
|
|
||||||
if len(gql) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &coreContext{Context: context.Background()}
|
|
||||||
c.req.Query = gql
|
|
||||||
c.req.Vars = varBytes
|
|
||||||
|
|
||||||
stmts, err := c.buildStmt()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery {
|
|
||||||
c.req.Vars = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, s := range stmts {
|
|
||||||
if len(s.sql) == 0 {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
finalSQL, am := processTemplate(s.sql)
|
err := prepareStmt(c, v.gql, v.vars)
|
||||||
|
if err == nil {
|
||||||
|
success++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
pstmt, err := tx.Prepare(c.Context, "", finalSQL)
|
if len(v.vars) == 0 {
|
||||||
|
logger.Warn().Err(err).Msg(v.gql)
|
||||||
|
} else {
|
||||||
|
logger.Warn().Err(err).Msgf("%s %s", v.vars, v.gql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info().
|
||||||
|
Msgf("Registered %d of %d queries from allow.list as prepared statements",
|
||||||
|
success, len(_allowList.list))
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||||
|
qt := qcode.GetQType(gql)
|
||||||
|
q := []byte(gql)
|
||||||
|
|
||||||
|
tx, err := db.Begin(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback(c)
|
||||||
|
|
||||||
|
switch qt {
|
||||||
|
case qcode.QTQuery:
|
||||||
|
stmts1, err := buildMultiStmt(q, vars)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var key string
|
err = prepare(c, tx, &stmts1[0], gqlHash(gql, vars, "user"))
|
||||||
|
if err != nil {
|
||||||
if s.role == nil {
|
return err
|
||||||
key = gqlHash(gql, c.req.Vars, "")
|
|
||||||
} else {
|
|
||||||
key = gqlHash(gql, c.req.Vars, s.role.Name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_preparedList[key] = &preparedItem{
|
stmts2, err := buildRoleStmt(q, vars, "anon")
|
||||||
stmt: pstmt,
|
if err != nil {
|
||||||
args: am,
|
return err
|
||||||
skipped: s.skipped,
|
|
||||||
qc: s.qc,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = prepare(c, tx, &stmts2[0], gqlHash(gql, vars, "anon"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case qcode.QTMutation:
|
||||||
|
for _, role := range conf.Roles {
|
||||||
|
stmts, err := buildRoleStmt(q, vars, role.Name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = prepare(c, tx, &stmts[0], gqlHash(gql, vars, role.Name))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(c); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareRoleStmt(ctx context.Context, tx pgx.Tx) error {
|
func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
|
||||||
|
finalSQL, am := processTemplate(st.sql)
|
||||||
|
|
||||||
|
sd, err := tx.Prepare(c, "", finalSQL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_preparedList[key] = &preparedItem{
|
||||||
|
sd: sd,
|
||||||
|
args: am,
|
||||||
|
st: st,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
|
||||||
if len(conf.RolesQuery) == 0 {
|
if len(conf.RolesQuery) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -128,7 +160,7 @@ func prepareRoleStmt(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
|
||||||
roleSQL, _ := processTemplate(w.String())
|
roleSQL, _ := processTemplate(w.String())
|
||||||
|
|
||||||
_, err := tx.Prepare(ctx, "_sg_get_role", roleSQL)
|
_, err := tx.Prepare(c, "_sg_get_role", roleSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -168,7 +168,7 @@ func Do(log func(string, ...interface{}), additional ...dir) error {
|
||||||
func ReExec() {
|
func ReExec() {
|
||||||
err := syscall.Exec(binSelf, append([]string{binSelf}, os.Args[1:]...), os.Environ())
|
err := syscall.Exec(binSelf, append([]string{binSelf}, os.Args[1:]...), os.Environ())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Msg("cannot restart")
|
errlog.Fatal().Err(err).Msg("cannot restart")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -117,7 +117,7 @@ func buildFn(r configRemote) func(http.Header, []byte) ([]byte, error) {
|
||||||
|
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error().Err(err).Msgf("Failed to connect to: %s", uri)
|
errlog.Error().Err(err).Msgf("Failed to connect to: %s", uri)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
26
serv/serv.go
26
serv/serv.go
|
@ -15,13 +15,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
|
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
|
||||||
schema, err := psql.NewDBSchema(db, c.getAliasMap())
|
var err error
|
||||||
|
|
||||||
|
schema, err = psql.NewDBSchema(db, c.getAliasMap())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
conf := qcode.Config{
|
conf := qcode.Config{
|
||||||
Blocklist: c.DB.Defaults.Blocklist,
|
Blocklist: c.DB.Blocklist,
|
||||||
KeepArgs: false,
|
KeepArgs: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,7 +108,7 @@ func initWatcher(cpath string) {
|
||||||
go func() {
|
go func() {
|
||||||
err := Do(logger.Printf, d)
|
err := Do(logger.Printf, d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -139,7 +141,7 @@ func startHTTP() {
|
||||||
<-sigint
|
<-sigint
|
||||||
|
|
||||||
if err := srv.Shutdown(context.Background()); err != nil {
|
if err := srv.Shutdown(context.Background()); err != nil {
|
||||||
logger.Error().Err(err).Msg("shutdown signal received")
|
errlog.Error().Err(err).Msg("shutdown signal received")
|
||||||
}
|
}
|
||||||
close(idleConnsClosed)
|
close(idleConnsClosed)
|
||||||
}()
|
}()
|
||||||
|
@ -148,18 +150,14 @@ func startHTTP() {
|
||||||
db.Close()
|
db.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
var ident string
|
logger.Info().
|
||||||
|
Str("host_post", hostPort).
|
||||||
if len(conf.AppName) == 0 {
|
Str("app_name", conf.AppName).
|
||||||
ident = conf.Env
|
Str("env", conf.Env).
|
||||||
} else {
|
Msgf("%s listening", serverName)
|
||||||
ident = conf.AppName
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("%s listening on %s (%s)\n", serverName, hostPort, ident)
|
|
||||||
|
|
||||||
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
|
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
|
||||||
logger.Error().Err(err).Msg("server closed")
|
errlog.Error().Err(err).Msg("server closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
<-idleConnsClosed
|
<-idleConnsClosed
|
||||||
|
|
|
@ -106,19 +106,6 @@ func al(b byte) bool {
|
||||||
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
|
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
|
||||||
}
|
}
|
||||||
|
|
||||||
func isMutation(sql string) bool {
|
|
||||||
for i := range sql {
|
|
||||||
b := sql[i]
|
|
||||||
if b == '{' {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if al(b) {
|
|
||||||
return (b == 'm' || b == 'M')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func findStmt(role string, stmts []stmt) *stmt {
|
func findStmt(role string, stmts []stmt) *stmt {
|
||||||
for i := range stmts {
|
for i := range stmts {
|
||||||
if stmts[i].role.Name != role {
|
if stmts[i].role.Name != role {
|
||||||
|
|
20
tmpl/dev.yml
20
tmpl/dev.yml
|
@ -101,18 +101,14 @@ database:
|
||||||
variables:
|
variables:
|
||||||
admin_account_id: "5"
|
admin_account_id: "5"
|
||||||
|
|
||||||
# Define defaults to for the field key and values below
|
# Field and table names that you wish to block
|
||||||
defaults:
|
blocklist:
|
||||||
# filters: ["{ user_id: { eq: $user_id } }"]
|
- ar_internal_metadata
|
||||||
|
- schema_migrations
|
||||||
# Field and table names that you wish to block
|
- secret
|
||||||
blocklist:
|
- password
|
||||||
- ar_internal_metadata
|
- encrypted
|
||||||
- schema_migrations
|
- token
|
||||||
- secret
|
|
||||||
- password
|
|
||||||
- encrypted
|
|
||||||
- token
|
|
||||||
|
|
||||||
tables:
|
tables:
|
||||||
- name: customers
|
- name: customers
|
||||||
|
|
Loading…
Reference in New Issue