Add session variable for user id

This commit is contained in:
Vikram Rangnekar 2019-09-08 01:54:38 -04:00
parent fe4d1107ac
commit 65921d4d42
9 changed files with 81 additions and 57 deletions

View File

@ -1,5 +1,30 @@
# http://localhost:8080/ # http://localhost:8080/
query {
customers {
id
email
payments {
customer_id
amount
billing_details
}
}
}
query {
products(id: $PRODUCT_ID) {
name
}
}
query {
products(id: $PRODUCT_ID) {
name
image
}
}
variables { variables {
"update": { "update": {
"name": "Hellooooo", "name": "Hellooooo",
@ -20,8 +45,8 @@ mutation {
variables { variables {
"update": { "update": {
"name": "Hellooooo", "name": "Helloo",
"description": "World !!!!!" "description": "World \u003c\u003e"
}, },
"user": 123 "user": 123
} }
@ -34,18 +59,6 @@ mutation {
} }
} }
variables {
"id": 5
}
{
products(id: $ID) {
id
name
description
}
}
variables { variables {
"update": { "update": {
@ -70,28 +83,3 @@ query {
} }
} }
query {
customers {
id
email
payments {
customer_id
amount
billing_details
}
}
}
query {
products(id: $PRODUCT_ID) {
name
}
}
query {
products(id: $PRODUCT_ID) {
name
image
}
}

View File

@ -93,7 +93,7 @@ database:
# filter: ["{ user_id: { eq: $user_id } }"] # filter: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block # Field and table names that you wish to block
blacklist: blocklist:
- ar_internal_metadata - ar_internal_metadata
- schema_migrations - schema_migrations
- secret - secret

View File

@ -91,7 +91,7 @@ database:
filter: ["{ user_id: { eq: $user_id } }"] filter: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block # Field and table names that you wish to block
blacklist: blocklist:
- ar_internal_metadata - ar_internal_metadata
- schema_migrations - schema_migrations
- secret - secret

View File

@ -4,7 +4,7 @@ build_name: runner-build
build_log: runner-build-errors.log build_log: runner-build-errors.log
valid_ext: .go, .tpl, .tmpl, .html, .yml, *.list valid_ext: .go, .tpl, .tmpl, .html, .yml, *.list
no_rebuild_ext: .tpl, .tmpl, .html no_rebuild_ext: .tpl, .tmpl, .html
ignored: web, tmp, vendor, rails-app, docs ignored: web, tmp, vendor, rails-app, docs, slides, bench, corpus
build_delay: 600 build_delay: 600
colors: 1 colors: 1
log_color_main: cyan log_color_main: cyan

View File

@ -38,7 +38,7 @@ func TestMain(m *testing.M) {
"{ id: { eq: $user_id } }", "{ id: { eq: $user_id } }",
}, },
}, },
Blacklist: []string{ Blocklist: []string{
"secret", "secret",
"password", "password",
"token", "token",

View File

@ -147,7 +147,7 @@ const (
type Config struct { type Config struct {
DefaultFilter []string DefaultFilter []string
FilterMap map[string][]string FilterMap map[string][]string
Blacklist []string Blocklist []string
KeepArgs bool KeepArgs bool
} }
@ -168,10 +168,10 @@ var expPool = sync.Pool{
} }
func NewCompiler(c Config) (*Compiler, error) { func NewCompiler(c Config) (*Compiler, error) {
bl := make(map[string]struct{}, len(c.Blacklist)) bl := make(map[string]struct{}, len(c.Blocklist))
for i := range c.Blacklist { for i := range c.Blocklist {
bl[c.Blacklist[i]] = struct{}{} bl[c.Blocklist[i]] = struct{}{}
} }
fl, err := compileFilter(c.DefaultFilter) fl, err := compileFilter(c.DefaultFilter)
@ -669,6 +669,9 @@ func (com *Compiler) compileArgAction(sel *Select, arg *Arg) error {
if arg.Val.Type != nodeBool { if arg.Val.Type != nodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name) return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
} }
if arg.Val.Val == "false" {
sel.Action = 0
}
default: default:
if arg.Val.Type != nodeVar { if arg.Val.Type != nodeVar {

View File

@ -60,10 +60,9 @@ type config struct {
Defaults struct { Defaults struct {
Filter []string Filter []string
Blacklist []string Blocklist []string
} }
Fields []configTable
Tables []configTable Tables []configTable
} `mapstructure:"database"` } `mapstructure:"database"`
} }
@ -72,7 +71,7 @@ type configTable struct {
Name string Name string
Filter []string Filter []string
Table string Table string
Blacklist []string Blocklist []string
Remotes []configRemote Remotes []configRemote
} }

View File

@ -262,10 +262,30 @@ func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, err
var root json.RawMessage var root json.RawMessage
vars := varList(c, ps.args) vars := varList(c, ps.args)
_, err := ps.stmt.QueryOne(pg.Scan(&root), vars...) tx, err := db.Begin()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer tx.Rollback()
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(`SET LOCAL SESSION "user.id" = ?`, v)
if err != nil {
return nil, nil, err
}
}
_, err = tx.Stmt(ps.stmt).QueryOne(pg.Scan(&root), vars...)
if err != nil {
return nil, nil, err
}
if err := tx.Commit(); err != nil {
return nil, nil, err
}
// w.WriteString(`SET LOCAL SESSION "user.id" = '{{user_id}}'; `)
fmt.Printf("PRE: %v\n", ps.stmt) fmt.Printf("PRE: %v\n", ps.stmt)
@ -314,15 +334,33 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
st = time.Now() st = time.Now()
} }
tx, err := db.Begin()
if err != nil {
return nil, 0, err
}
defer tx.Rollback()
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(`SET LOCAL SESSION "user.id" = ?`, v)
if err != nil {
return nil, 0, err
}
}
fmt.Printf("RAW: %#v\n", finalSQL) fmt.Printf("RAW: %#v\n", finalSQL)
var root json.RawMessage var root json.RawMessage
_, err = db.QueryOne(pg.Scan(&root), finalSQL) _, err = tx.QueryOne(pg.Scan(&root), finalSQL)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
if err := tx.Commit(); err != nil {
return nil, 0, err
}
if conf.EnableTracing && len(qc.Selects) != 0 { if conf.EnableTracing && len(qc.Selects) != 0 {
c.addTrace( c.addTrace(
qc.Selects, qc.Selects,

View File

@ -99,10 +99,6 @@ func initConf(path string) (*config, error) {
flect.AddPlural(k, v) flect.AddPlural(k, v)
} }
if len(c.DB.Tables) == 0 {
c.DB.Tables = c.DB.Fields
}
for i := range c.DB.Tables { for i := range c.DB.Tables {
t := c.DB.Tables[i] t := c.DB.Tables[i]
t.Name = flect.Pluralize(strings.ToLower(t.Name)) t.Name = flect.Pluralize(strings.ToLower(t.Name))
@ -159,7 +155,7 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
qc, err := qcode.NewCompiler(qcode.Config{ qc, err := qcode.NewCompiler(qcode.Config{
DefaultFilter: c.DB.Defaults.Filter, DefaultFilter: c.DB.Defaults.Filter,
FilterMap: c.getFilterMap(), FilterMap: c.getFilterMap(),
Blacklist: c.DB.Defaults.Blacklist, Blocklist: c.DB.Defaults.Blocklist,
KeepArgs: false, KeepArgs: false,
}) })