Get RBAC working for queries and mutations

This commit is contained in:
Vikram Rangnekar 2019-10-24 02:07:42 -04:00
parent c797deb4d0
commit 6bc66d28bc
19 changed files with 902 additions and 568 deletions

View File

@ -22,7 +22,7 @@ enable_tracing: true
# Watch the config folder and reload Super Graph # Watch the config folder and reload Super Graph
# with the new configs when a change is detected # with the new configs when a change is detected
reload_on_config_change: false reload_on_config_change: true
# File that points to the database seeding script # File that points to the database seeding script
# seed_file: seed.js # seed_file: seed.js
@ -53,7 +53,7 @@ auth:
# Comment this out if you want to disable setting # Comment this out if you want to disable setting
# the user_id via a header. Good for testing # the user_id via a header. Good for testing
header: X-User-ID creds_in_header: true
rails: rails:
# Rails version this is used for reading the # Rails version this is used for reading the
@ -143,6 +143,8 @@ tables:
name: me name: me
table: users table: users
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
roles: roles:
- name: anon - name: anon
tables: tables:
@ -164,6 +166,10 @@ roles:
- name: user - name: user
tables: tables:
- name: users
query:
filter: ["{ id: { _eq: $user_id } }"]
- name: products - name: products
query: query:
@ -189,9 +195,10 @@ roles:
delete: delete:
deny: true deny: true
- name: manager - name: admin
match: id = 1
tables: tables:
- name: users - name: users
select: # select:
filter: ["{ account_id: { _eq: $account_id } }"] # filter: ["{ account_id: { _eq: $account_id } }"]

View File

@ -47,10 +47,6 @@ auth:
type: rails type: rails
cookie: _app_session cookie: _app_session
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
header: X-User-ID
rails: rails:
# Rails version this is used for reading the # Rails version this is used for reading the
# various cookies formats. # various cookies formats.

View File

@ -1,7 +1,6 @@
package psql package psql
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -12,7 +11,7 @@ import (
var noLimit = qcode.Paging{NoLimit: true} var noLimit = qcode.Paging{NoLimit: true}
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { func (co *Compiler) compileMutation(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) {
if len(qc.Selects) == 0 { if len(qc.Selects) == 0 {
return 0, errors.New("empty query") return 0, errors.New("empty query")
} }
@ -25,9 +24,9 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
return 0, err return 0, err
} }
c.w.WriteString(`WITH `) io.WriteString(c.w, `WITH `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
c.w.WriteString(` AS `) io.WriteString(c.w, ` AS `)
switch qc.Type { switch qc.Type {
case qcode.QTInsert: case qcode.QTInsert:
@ -67,7 +66,7 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
return c.compileQuery(qc, w) return c.compileQuery(qc, w)
} }
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderInsert(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
insert, ok := vars[qc.ActionVar] insert, ok := vars[qc.ActionVar]
@ -80,32 +79,32 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
return 0, err return 0, err
} }
c.w.WriteString(`(WITH "input" AS (SELECT {{`) io.WriteString(c.w, `(WITH "input" AS (SELECT {{`)
c.w.WriteString(qc.ActionVar) io.WriteString(c.w, qc.ActionVar)
c.w.WriteString(`}}::json AS j) INSERT INTO `) io.WriteString(c.w, `}}::json AS j) INSERT INTO `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
io.WriteString(c.w, ` (`) io.WriteString(c.w, ` (`)
c.renderInsertUpdateColumns(qc, w, jt, ti) c.renderInsertUpdateColumns(qc, w, jt, ti)
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
c.w.WriteString(` SELECT `) io.WriteString(c.w, ` SELECT `)
c.renderInsertUpdateColumns(qc, w, jt, ti) c.renderInsertUpdateColumns(qc, w, jt, ti)
c.w.WriteString(` FROM input i, `) io.WriteString(c.w, ` FROM input i, `)
if array { if array {
c.w.WriteString(`json_populate_recordset`) io.WriteString(c.w, `json_populate_recordset`)
} else { } else {
c.w.WriteString(`json_populate_record`) io.WriteString(c.w, `json_populate_record`)
} }
c.w.WriteString(`(NULL::`) io.WriteString(c.w, `(NULL::`)
c.w.WriteString(ti.Name) io.WriteString(c.w, ti.Name)
c.w.WriteString(`, i.j) t`) io.WriteString(c.w, `, i.j) t`)
return 0, nil return 0, nil
} }
func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w io.Writer,
jt map[string]interface{}, ti *DBTableInfo) (uint32, error) { jt map[string]interface{}, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0] root := &qc.Selects[0]
@ -122,14 +121,14 @@ func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Bu
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
c.w.WriteString(cn) io.WriteString(c.w, cn)
i++ i++
} }
return 0, nil return 0, nil
} }
func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderUpdate(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0] root := &qc.Selects[0]
@ -143,26 +142,26 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
return 0, err return 0, err
} }
c.w.WriteString(`(WITH "input" AS (SELECT {{`) io.WriteString(c.w, `(WITH "input" AS (SELECT {{`)
c.w.WriteString(qc.ActionVar) io.WriteString(c.w, qc.ActionVar)
c.w.WriteString(`}}::json AS j) UPDATE `) io.WriteString(c.w, `}}::json AS j) UPDATE `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
io.WriteString(c.w, ` SET (`) io.WriteString(c.w, ` SET (`)
c.renderInsertUpdateColumns(qc, w, jt, ti) c.renderInsertUpdateColumns(qc, w, jt, ti)
c.w.WriteString(`) = (SELECT `) io.WriteString(c.w, `) = (SELECT `)
c.renderInsertUpdateColumns(qc, w, jt, ti) c.renderInsertUpdateColumns(qc, w, jt, ti)
c.w.WriteString(` FROM input i, `) io.WriteString(c.w, ` FROM input i, `)
if array { if array {
c.w.WriteString(`json_populate_recordset`) io.WriteString(c.w, `json_populate_recordset`)
} else { } else {
c.w.WriteString(`json_populate_record`) io.WriteString(c.w, `json_populate_record`)
} }
c.w.WriteString(`(NULL::`) io.WriteString(c.w, `(NULL::`)
c.w.WriteString(ti.Name) io.WriteString(c.w, ti.Name)
c.w.WriteString(`, i.j) t)`) io.WriteString(c.w, `, i.j) t)`)
io.WriteString(c.w, ` WHERE `) io.WriteString(c.w, ` WHERE `)
@ -173,11 +172,11 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
return 0, nil return 0, nil
} }
func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderDelete(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0] root := &qc.Selects[0]
c.w.WriteString(`(DELETE FROM `) io.WriteString(c.w, `(DELETE FROM `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
io.WriteString(c.w, ` WHERE `) io.WriteString(c.w, ` WHERE `)
@ -188,7 +187,7 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer,
return 0, nil return 0, nil
} }
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
upsert, ok := vars[qc.ActionVar] upsert, ok := vars[qc.ActionVar]
@ -205,7 +204,7 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
return 0, err return 0, err
} }
c.w.WriteString(` ON CONFLICT DO (`) io.WriteString(c.w, ` ON CONFLICT DO (`)
i := 0 i := 0
for _, cn := range ti.ColumnNames { for _, cn := range ti.ColumnNames {
@ -220,15 +219,15 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
c.w.WriteString(cn) io.WriteString(c.w, cn)
i++ i++
} }
if i == 0 { if i == 0 {
c.w.WriteString(ti.PrimaryCol) io.WriteString(c.w, ti.PrimaryCol)
} }
c.w.WriteString(`) DO `) io.WriteString(c.w, `) DO `)
c.w.WriteString(`UPDATE `) io.WriteString(c.w, `UPDATE `)
io.WriteString(c.w, ` SET `) io.WriteString(c.w, ` SET `)
i = 0 i = 0
@ -239,17 +238,17 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
c.w.WriteString(cn) io.WriteString(c.w, cn)
io.WriteString(c.w, ` = EXCLUDED.`) io.WriteString(c.w, ` = EXCLUDED.`)
c.w.WriteString(cn) io.WriteString(c.w, cn)
i++ i++
} }
return 0, nil return 0, nil
} }
func quoted(w *bytes.Buffer, identifier string) { func quoted(w io.Writer, identifier string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(identifier) io.WriteString(w, identifier)
w.WriteString(`"`) io.WriteString(w, `"`)
} }

View File

@ -49,7 +49,7 @@ func (c *Compiler) IDColumn(table string) (string, error) {
} }
type compilerContext struct { type compilerContext struct {
w *bytes.Buffer w io.Writer
s []qcode.Select s []qcode.Select
*Compiler *Compiler
} }
@ -60,7 +60,7 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte,
return skipped, w.Bytes(), err return skipped, w.Bytes(), err
} }
func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { func (co *Compiler) Compile(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) {
switch qc.Type { switch qc.Type {
case qcode.QTQuery: case qcode.QTQuery:
return co.compileQuery(qc, w) return co.compileQuery(qc, w)
@ -71,7 +71,7 @@ func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (u
return 0, fmt.Errorf("Unknown operation type %d", qc.Type) return 0, fmt.Errorf("Unknown operation type %d", qc.Type)
} }
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
if len(qc.Selects) == 0 { if len(qc.Selects) == 0 {
return 0, errors.New("empty query") return 0, errors.New("empty query")
} }
@ -90,17 +90,17 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro
//fmt.Fprintf(w, `SELECT json_object_agg('%s', %s) FROM (`, //fmt.Fprintf(w, `SELECT json_object_agg('%s', %s) FROM (`,
//root.FieldName, root.Table) //root.FieldName, root.Table)
c.w.WriteString(`SELECT json_object_agg('`) io.WriteString(c.w, `SELECT json_object_agg('`)
c.w.WriteString(root.FieldName) io.WriteString(c.w, root.FieldName)
c.w.WriteString(`', `) io.WriteString(c.w, `', `)
if ti.Singular == false { if ti.Singular == false {
c.w.WriteString(root.Table) io.WriteString(c.w, root.Table)
} else { } else {
c.w.WriteString("sel_json_") io.WriteString(c.w, "sel_json_")
int2string(c.w, root.ID) int2string(c.w, root.ID)
} }
c.w.WriteString(`) FROM (`) io.WriteString(c.w, `) FROM (`)
var ignored uint32 var ignored uint32
@ -161,9 +161,8 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro
} }
} }
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
alias(c.w, `done_1337`) alias(c.w, `done_1337`)
c.w.WriteString(`;`)
return ignored, nil return ignored, nil
} }
@ -219,10 +218,10 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
// SELECT // SELECT
if ti.Singular == false { if ti.Singular == false {
//fmt.Fprintf(w, `SELECT coalesce(json_agg("%s"`, c.sel.Table) //fmt.Fprintf(w, `SELECT coalesce(json_agg("%s"`, c.sel.Table)
c.w.WriteString(`SELECT coalesce(json_agg("`) io.WriteString(c.w, `SELECT coalesce(json_agg("`)
c.w.WriteString("sel_json_") io.WriteString(c.w, "sel_json_")
int2string(c.w, sel.ID) int2string(c.w, sel.ID)
c.w.WriteString(`"`) io.WriteString(c.w, `"`)
if hasOrder { if hasOrder {
err := c.renderOrderBy(sel, ti) err := c.renderOrderBy(sel, ti)
@ -232,24 +231,24 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
} }
//fmt.Fprintf(w, `), '[]') AS "%s" FROM (`, c.sel.Table) //fmt.Fprintf(w, `), '[]') AS "%s" FROM (`, c.sel.Table)
c.w.WriteString(`), '[]')`) io.WriteString(c.w, `), '[]')`)
alias(c.w, sel.Table) alias(c.w, sel.Table)
c.w.WriteString(` FROM (`) io.WriteString(c.w, ` FROM (`)
} }
// ROW-TO-JSON // ROW-TO-JSON
c.w.WriteString(`SELECT `) io.WriteString(c.w, `SELECT `)
if len(sel.DistinctOn) != 0 { if len(sel.DistinctOn) != 0 {
c.renderDistinctOn(sel, ti) c.renderDistinctOn(sel, ti)
} }
c.w.WriteString(`row_to_json((`) io.WriteString(c.w, `row_to_json((`)
//fmt.Fprintf(w, `SELECT "sel_%d" FROM (SELECT `, c.sel.ID) //fmt.Fprintf(w, `SELECT "sel_%d" FROM (SELECT `, c.sel.ID)
c.w.WriteString(`SELECT "sel_`) io.WriteString(c.w, `SELECT "sel_`)
int2string(c.w, sel.ID) int2string(c.w, sel.ID)
c.w.WriteString(`" FROM (SELECT `) io.WriteString(c.w, `" FROM (SELECT `)
// Combined column names // Combined column names
c.renderColumns(sel, ti) c.renderColumns(sel, ti)
@ -262,11 +261,11 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
} }
//fmt.Fprintf(w, `) AS "sel_%d"`, c.sel.ID) //fmt.Fprintf(w, `) AS "sel_%d"`, c.sel.ID)
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
aliasWithID(c.w, "sel", sel.ID) aliasWithID(c.w, "sel", sel.ID)
//fmt.Fprintf(w, `)) AS "%s"`, c.sel.Table) //fmt.Fprintf(w, `)) AS "%s"`, c.sel.Table)
c.w.WriteString(`))`) io.WriteString(c.w, `))`)
aliasWithID(c.w, "sel_json", sel.ID) aliasWithID(c.w, "sel_json", sel.ID)
// END-ROW-TO-JSON // END-ROW-TO-JSON
@ -301,27 +300,27 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
case len(sel.Paging.Limit) != 0: case len(sel.Paging.Limit) != 0:
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`) io.WriteString(c.w, ` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit) io.WriteString(c.w, sel.Paging.Limit)
c.w.WriteString(`') :: integer`) io.WriteString(c.w, `') :: integer`)
case ti.Singular: case ti.Singular:
c.w.WriteString(` LIMIT ('1') :: integer`) io.WriteString(c.w, ` LIMIT ('1') :: integer`)
default: default:
c.w.WriteString(` LIMIT ('20') :: integer`) io.WriteString(c.w, ` LIMIT ('20') :: integer`)
} }
if len(sel.Paging.Offset) != 0 { if len(sel.Paging.Offset) != 0 {
//fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset) //fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset)
c.w.WriteString(`OFFSET ('`) io.WriteString(c.w, `OFFSET ('`)
c.w.WriteString(sel.Paging.Offset) io.WriteString(c.w, sel.Paging.Offset)
c.w.WriteString(`') :: integer`) io.WriteString(c.w, `') :: integer`)
} }
if ti.Singular == false { if ti.Singular == false {
//fmt.Fprintf(w, `) AS "sel_json_agg_%d"`, c.sel.ID) //fmt.Fprintf(w, `) AS "sel_json_agg_%d"`, c.sel.ID)
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
aliasWithID(c.w, "sel_json_agg", sel.ID) aliasWithID(c.w, "sel_json_agg", sel.ID)
} }
@ -329,15 +328,15 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
} }
func (c *compilerContext) renderJoin(sel *qcode.Select) error { func (c *compilerContext) renderJoin(sel *qcode.Select) error {
c.w.WriteString(` LEFT OUTER JOIN LATERAL (`) io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`)
return nil return nil
} }
func (c *compilerContext) renderJoinClose(sel *qcode.Select) error { func (c *compilerContext) renderJoinClose(sel *qcode.Select) error {
//fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID) //fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID)
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join") aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join")
c.w.WriteString(` ON ('true')`) io.WriteString(c.w, ` ON ('true')`)
return nil return nil
} }
@ -360,13 +359,13 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
//fmt.Fprintf(w, ` LEFT OUTER JOIN "%s" ON (("%s"."%s") = ("%s_%d"."%s"))`, //fmt.Fprintf(w, ` LEFT OUTER JOIN "%s" ON (("%s"."%s") = ("%s_%d"."%s"))`,
//rel.Through, rel.Through, rel.ColT, c.parent.Table, c.parent.ID, rel.Col1) //rel.Through, rel.Through, rel.ColT, c.parent.Table, c.parent.ID, rel.Col1)
c.w.WriteString(` LEFT OUTER JOIN "`) io.WriteString(c.w, ` LEFT OUTER JOIN "`)
c.w.WriteString(rel.Through) io.WriteString(c.w, rel.Through)
c.w.WriteString(`" ON ((`) io.WriteString(c.w, `" ON ((`)
colWithTable(c.w, rel.Through, rel.ColT) colWithTable(c.w, rel.Through, rel.ColT)
c.w.WriteString(`) = (`) io.WriteString(c.w, `) = (`)
colWithTableID(c.w, pt.Name, parent.ID, rel.Col1) colWithTableID(c.w, pt.Name, parent.ID, rel.Col1)
c.w.WriteString(`))`) io.WriteString(c.w, `))`)
return nil return nil
} }
@ -443,11 +442,11 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo
//fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`, //fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`,
//s.Table, s.ID, s.Table, s.FieldName) //s.Table, s.ID, s.Table, s.FieldName)
if cti.Singular { if cti.Singular {
c.w.WriteString(`"sel_json_`) io.WriteString(c.w, `"sel_json_`)
int2string(c.w, childSel.ID) int2string(c.w, childSel.ID)
c.w.WriteString(`" AS "`) io.WriteString(c.w, `" AS "`)
c.w.WriteString(childSel.FieldName) io.WriteString(c.w, childSel.FieldName)
c.w.WriteString(`"`) io.WriteString(c.w, `"`)
} else { } else {
colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID, colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID,
@ -467,7 +466,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
isSearch := sel.Args["search"] != nil isSearch := sel.Args["search"] != nil
isAgg := false isAgg := false
c.w.WriteString(` FROM (SELECT `) io.WriteString(c.w, ` FROM (SELECT `)
i := 0 i := 0
for n, col := range sel.Cols { for n, col := range sel.Cols {
@ -483,15 +482,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
arg := sel.Args["search"] arg := sel.Args["search"]
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, //fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
//c.sel.Table, cn, arg.Val, col.Name) //c.sel.Table, cn, arg.Val, col.Name)
c.w.WriteString(`ts_rank(`) io.WriteString(c.w, `ts_rank(`)
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn)
c.w.WriteString(`, to_tsquery('`) io.WriteString(c.w, `, to_tsquery('`)
c.w.WriteString(arg.Val) io.WriteString(c.w, arg.Val)
c.w.WriteString(`')`) io.WriteString(c.w, `')`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++ i++
@ -500,15 +499,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
arg := sel.Args["search"] arg := sel.Args["search"]
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, //fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
//c.sel.Table, cn, arg.Val, col.Name) //c.sel.Table, cn, arg.Val, col.Name)
c.w.WriteString(`ts_headlinek(`) io.WriteString(c.w, `ts_headlinek(`)
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn)
c.w.WriteString(`, to_tsquery('`) io.WriteString(c.w, `, to_tsquery('`)
c.w.WriteString(arg.Val) io.WriteString(c.w, arg.Val)
c.w.WriteString(`')`) io.WriteString(c.w, `')`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++ i++
@ -517,12 +516,12 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
pl := funcPrefixLen(cn) pl := funcPrefixLen(cn)
if pl == 0 { if pl == 0 {
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
c.w.WriteString(`'`) io.WriteString(c.w, `'`)
c.w.WriteString(cn) io.WriteString(c.w, cn)
c.w.WriteString(` not defined'`) io.WriteString(c.w, ` not defined'`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++ i++
@ -532,16 +531,16 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
continue continue
} }
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
fn := cn[0 : pl-1] fn := cn[0 : pl-1]
isAgg = true isAgg = true
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name) //fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name)
c.w.WriteString(fn) io.WriteString(c.w, fn)
c.w.WriteString(`(`) io.WriteString(c.w, `(`)
colWithTable(c.w, ti.Name, cn1) colWithTable(c.w, ti.Name, cn1)
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++ i++
@ -551,7 +550,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
groupBy = append(groupBy, n) groupBy = append(groupBy, n)
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn) //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn)
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn)
i++ i++
@ -561,7 +560,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
for _, col := range childCols { for _, col := range childCols {
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name) //fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name)
@ -569,29 +568,29 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
i++ i++
} }
c.w.WriteString(` FROM `) io.WriteString(c.w, ` FROM `)
//fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table) //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table)
c.w.WriteString(`"`) io.WriteString(c.w, `"`)
c.w.WriteString(ti.Name) io.WriteString(c.w, ti.Name)
c.w.WriteString(`"`) io.WriteString(c.w, `"`)
// if tn, ok := c.tmap[sel.Table]; ok { // if tn, ok := c.tmap[sel.Table]; ok {
// //fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, c.sel.Table) // //fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, c.sel.Table)
// tableWithAlias(c.w, ti.Name, sel.Table) // tableWithAlias(c.w, ti.Name, sel.Table)
// } else { // } else {
// //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table) // //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table)
// c.w.WriteString(`"`) // io.WriteString(c.w, `"`)
// c.w.WriteString(sel.Table) // io.WriteString(c.w, sel.Table)
// c.w.WriteString(`"`) // io.WriteString(c.w, `"`)
// } // }
if isRoot && isFil { if isRoot && isFil {
c.w.WriteString(` WHERE (`) io.WriteString(c.w, ` WHERE (`)
if err := c.renderWhere(sel, ti); err != nil { if err := c.renderWhere(sel, ti); err != nil {
return err return err
} }
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
} }
if !isRoot { if !isRoot {
@ -599,28 +598,28 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
return err return err
} }
c.w.WriteString(` WHERE (`) io.WriteString(c.w, ` WHERE (`)
if err := c.renderRelationship(sel, ti); err != nil { if err := c.renderRelationship(sel, ti); err != nil {
return err return err
} }
if isFil { if isFil {
c.w.WriteString(` AND `) io.WriteString(c.w, ` AND `)
if err := c.renderWhere(sel, ti); err != nil { if err := c.renderWhere(sel, ti); err != nil {
return err return err
} }
} }
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
} }
if isAgg { if isAgg {
if len(groupBy) != 0 { if len(groupBy) != 0 {
c.w.WriteString(` GROUP BY `) io.WriteString(c.w, ` GROUP BY `)
for i, id := range groupBy { for i, id := range groupBy {
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, c.sel.Cols[id].Name) //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, c.sel.Cols[id].Name)
colWithTable(c.w, ti.Name, sel.Cols[id].Name) colWithTable(c.w, ti.Name, sel.Cols[id].Name)
@ -634,26 +633,26 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
case len(sel.Paging.Limit) != 0: case len(sel.Paging.Limit) != 0:
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`) io.WriteString(c.w, ` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit) io.WriteString(c.w, sel.Paging.Limit)
c.w.WriteString(`') :: integer`) io.WriteString(c.w, `') :: integer`)
case ti.Singular: case ti.Singular:
c.w.WriteString(` LIMIT ('1') :: integer`) io.WriteString(c.w, ` LIMIT ('1') :: integer`)
default: default:
c.w.WriteString(` LIMIT ('20') :: integer`) io.WriteString(c.w, ` LIMIT ('20') :: integer`)
} }
if len(sel.Paging.Offset) != 0 { if len(sel.Paging.Offset) != 0 {
//fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset) //fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset)
c.w.WriteString(` OFFSET ('`) io.WriteString(c.w, ` OFFSET ('`)
c.w.WriteString(sel.Paging.Offset) io.WriteString(c.w, sel.Paging.Offset)
c.w.WriteString(`') :: integer`) io.WriteString(c.w, `') :: integer`)
} }
//fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Table, c.sel.ID) //fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Table, c.sel.ID)
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
aliasWithID(c.w, ti.Name, sel.ID) aliasWithID(c.w, ti.Name, sel.ID)
return nil return nil
} }
@ -664,7 +663,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
for i := range sel.OrderBy { for i := range sel.OrderBy {
if colsRendered { if colsRendered {
//io.WriteString(w, ", ") //io.WriteString(w, ", ")
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
col := sel.OrderBy[i].Col col := sel.OrderBy[i].Col
@ -672,7 +671,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
//c.sel.Table, c.sel.ID, c, //c.sel.Table, c.sel.ID, c,
//c.sel.Table, c.sel.ID, c) //c.sel.Table, c.sel.ID, c)
colWithTableID(c.w, ti.Name, sel.ID, col) colWithTableID(c.w, ti.Name, sel.ID, col)
c.w.WriteString(` AS `) io.WriteString(c.w, ` AS `)
tableIDColSuffix(c.w, sel.Table, sel.ID, col, "_ob") tableIDColSuffix(c.w, sel.Table, sel.ID, col, "_ob")
} }
} }
@ -689,29 +688,29 @@ func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo)
case RelBelongTo: case RelBelongTo:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
c.w.WriteString(`((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1) colWithTable(c.w, ti.Name, rel.Col1)
c.w.WriteString(`) = (`) io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
c.w.WriteString(`))`) io.WriteString(c.w, `))`)
case RelOneToMany: case RelOneToMany:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
c.w.WriteString(`((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1) colWithTable(c.w, ti.Name, rel.Col1)
c.w.WriteString(`) = (`) io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
c.w.WriteString(`))`) io.WriteString(c.w, `))`)
case RelOneToManyThrough: case RelOneToManyThrough:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`,
//c.sel.Table, rel.Col1, rel.Through, rel.Col2) //c.sel.Table, rel.Col1, rel.Through, rel.Col2)
c.w.WriteString(`((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1) colWithTable(c.w, ti.Name, rel.Col1)
c.w.WriteString(`) = (`) io.WriteString(c.w, `) = (`)
colWithTable(c.w, rel.Through, rel.Col2) colWithTable(c.w, rel.Through, rel.Col2)
c.w.WriteString(`))`) io.WriteString(c.w, `))`)
} }
return nil return nil
@ -735,11 +734,11 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
case qcode.ExpOp: case qcode.ExpOp:
switch val { switch val {
case qcode.OpAnd: case qcode.OpAnd:
c.w.WriteString(` AND `) io.WriteString(c.w, ` AND `)
case qcode.OpOr: case qcode.OpOr:
c.w.WriteString(` OR `) io.WriteString(c.w, ` OR `)
case qcode.OpNot: case qcode.OpNot:
c.w.WriteString(`NOT `) io.WriteString(c.w, `NOT `)
default: default:
return fmt.Errorf("11: unexpected value %v (%t)", intf, intf) return fmt.Errorf("11: unexpected value %v (%t)", intf, intf)
} }
@ -763,62 +762,62 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
default: default:
if val.NestedCol { if val.NestedCol {
//fmt.Fprintf(w, `(("%s") `, val.Col) //fmt.Fprintf(w, `(("%s") `, val.Col)
c.w.WriteString(`(("`) io.WriteString(c.w, `(("`)
c.w.WriteString(val.Col) io.WriteString(c.w, val.Col)
c.w.WriteString(`") `) io.WriteString(c.w, `") `)
} else if len(val.Col) != 0 { } else if len(val.Col) != 0 {
//fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col) //fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col)
c.w.WriteString(`((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, val.Col) colWithTable(c.w, ti.Name, val.Col)
c.w.WriteString(`) `) io.WriteString(c.w, `) `)
} }
valExists := true valExists := true
switch val.Op { switch val.Op {
case qcode.OpEquals: case qcode.OpEquals:
c.w.WriteString(`=`) io.WriteString(c.w, `=`)
case qcode.OpNotEquals: case qcode.OpNotEquals:
c.w.WriteString(`!=`) io.WriteString(c.w, `!=`)
case qcode.OpGreaterOrEquals: case qcode.OpGreaterOrEquals:
c.w.WriteString(`>=`) io.WriteString(c.w, `>=`)
case qcode.OpLesserOrEquals: case qcode.OpLesserOrEquals:
c.w.WriteString(`<=`) io.WriteString(c.w, `<=`)
case qcode.OpGreaterThan: case qcode.OpGreaterThan:
c.w.WriteString(`>`) io.WriteString(c.w, `>`)
case qcode.OpLesserThan: case qcode.OpLesserThan:
c.w.WriteString(`<`) io.WriteString(c.w, `<`)
case qcode.OpIn: case qcode.OpIn:
c.w.WriteString(`IN`) io.WriteString(c.w, `IN`)
case qcode.OpNotIn: case qcode.OpNotIn:
c.w.WriteString(`NOT IN`) io.WriteString(c.w, `NOT IN`)
case qcode.OpLike: case qcode.OpLike:
c.w.WriteString(`LIKE`) io.WriteString(c.w, `LIKE`)
case qcode.OpNotLike: case qcode.OpNotLike:
c.w.WriteString(`NOT LIKE`) io.WriteString(c.w, `NOT LIKE`)
case qcode.OpILike: case qcode.OpILike:
c.w.WriteString(`ILIKE`) io.WriteString(c.w, `ILIKE`)
case qcode.OpNotILike: case qcode.OpNotILike:
c.w.WriteString(`NOT ILIKE`) io.WriteString(c.w, `NOT ILIKE`)
case qcode.OpSimilar: case qcode.OpSimilar:
c.w.WriteString(`SIMILAR TO`) io.WriteString(c.w, `SIMILAR TO`)
case qcode.OpNotSimilar: case qcode.OpNotSimilar:
c.w.WriteString(`NOT SIMILAR TO`) io.WriteString(c.w, `NOT SIMILAR TO`)
case qcode.OpContains: case qcode.OpContains:
c.w.WriteString(`@>`) io.WriteString(c.w, `@>`)
case qcode.OpContainedIn: case qcode.OpContainedIn:
c.w.WriteString(`<@`) io.WriteString(c.w, `<@`)
case qcode.OpHasKey: case qcode.OpHasKey:
c.w.WriteString(`?`) io.WriteString(c.w, `?`)
case qcode.OpHasKeyAny: case qcode.OpHasKeyAny:
c.w.WriteString(`?|`) io.WriteString(c.w, `?|`)
case qcode.OpHasKeyAll: case qcode.OpHasKeyAll:
c.w.WriteString(`?&`) io.WriteString(c.w, `?&`)
case qcode.OpIsNull: case qcode.OpIsNull:
if strings.EqualFold(val.Val, "true") { if strings.EqualFold(val.Val, "true") {
c.w.WriteString(`IS NULL)`) io.WriteString(c.w, `IS NULL)`)
} else { } else {
c.w.WriteString(`IS NOT NULL)`) io.WriteString(c.w, `IS NOT NULL)`)
} }
valExists = false valExists = false
case qcode.OpEqID: case qcode.OpEqID:
@ -826,20 +825,20 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
return fmt.Errorf("no primary key column defined for %s", ti.Name) return fmt.Errorf("no primary key column defined for %s", ti.Name)
} }
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol) //fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
c.w.WriteString(`((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.PrimaryCol) colWithTable(c.w, ti.Name, ti.PrimaryCol)
//c.w.WriteString(ti.PrimaryCol) //io.WriteString(c.w, ti.PrimaryCol)
c.w.WriteString(`) =`) io.WriteString(c.w, `) =`)
case qcode.OpTsQuery: case qcode.OpTsQuery:
if len(ti.TSVCol) == 0 { if len(ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", ti.Name) return fmt.Errorf("no tsv column defined for %s", ti.Name)
} }
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val) //fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
c.w.WriteString(`(("`) io.WriteString(c.w, `(("`)
c.w.WriteString(ti.TSVCol) io.WriteString(c.w, ti.TSVCol)
c.w.WriteString(`") @@ to_tsquery('`) io.WriteString(c.w, `") @@ to_tsquery('`)
c.w.WriteString(val.Val) io.WriteString(c.w, val.Val)
c.w.WriteString(`'))`) io.WriteString(c.w, `'))`)
valExists = false valExists = false
default: default:
@ -852,7 +851,7 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
} else { } else {
c.renderVal(val, c.vars) c.renderVal(val, c.vars)
} }
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
} }
qcode.FreeExp(val) qcode.FreeExp(val)
@ -868,10 +867,10 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
} }
func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error { func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error {
c.w.WriteString(` ORDER BY `) io.WriteString(c.w, ` ORDER BY `)
for i := range sel.OrderBy { for i := range sel.OrderBy {
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
ob := sel.OrderBy[i] ob := sel.OrderBy[i]
@ -879,27 +878,27 @@ func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) erro
case qcode.OrderAsc: case qcode.OrderAsc:
//fmt.Fprintf(w, `"%s_%d.ob.%s" ASC`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `"%s_%d.ob.%s" ASC`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` ASC`) io.WriteString(c.w, ` ASC`)
case qcode.OrderDesc: case qcode.OrderDesc:
//fmt.Fprintf(w, `"%s_%d.ob.%s" DESC`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `"%s_%d.ob.%s" DESC`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` DESC`) io.WriteString(c.w, ` DESC`)
case qcode.OrderAscNullsFirst: case qcode.OrderAscNullsFirst:
//fmt.Fprintf(w, `"%s_%d.ob.%s" ASC NULLS FIRST`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `"%s_%d.ob.%s" ASC NULLS FIRST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` ASC NULLS FIRST`) io.WriteString(c.w, ` ASC NULLS FIRST`)
case qcode.OrderDescNullsFirst: case qcode.OrderDescNullsFirst:
//fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS FIRST`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS FIRST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` DESC NULLLS FIRST`) io.WriteString(c.w, ` DESC NULLLS FIRST`)
case qcode.OrderAscNullsLast: case qcode.OrderAscNullsLast:
//fmt.Fprintf(w, `"%s_%d.ob.%s ASC NULLS LAST`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `"%s_%d.ob.%s ASC NULLS LAST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` ASC NULLS LAST`) io.WriteString(c.w, ` ASC NULLS LAST`)
case qcode.OrderDescNullsLast: case qcode.OrderDescNullsLast:
//fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS LAST`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS LAST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` DESC NULLS LAST`) io.WriteString(c.w, ` DESC NULLS LAST`)
default: default:
return fmt.Errorf("13: unexpected value %v", ob.Order) return fmt.Errorf("13: unexpected value %v", ob.Order)
} }
@ -911,30 +910,30 @@ func (c *compilerContext) renderDistinctOn(sel *qcode.Select, ti *DBTableInfo) {
io.WriteString(c.w, `DISTINCT ON (`) io.WriteString(c.w, `DISTINCT ON (`)
for i := range sel.DistinctOn { for i := range sel.DistinctOn {
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `"%s_%d.ob.%s"`, c.sel.Table, c.sel.ID, c.sel.DistinctOn[i]) //fmt.Fprintf(w, `"%s_%d.ob.%s"`, c.sel.Table, c.sel.ID, c.sel.DistinctOn[i])
tableIDColSuffix(c.w, ti.Name, sel.ID, sel.DistinctOn[i], "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, sel.DistinctOn[i], "_ob")
} }
c.w.WriteString(`) `) io.WriteString(c.w, `) `)
} }
func (c *compilerContext) renderList(ex *qcode.Exp) { func (c *compilerContext) renderList(ex *qcode.Exp) {
io.WriteString(c.w, ` (`) io.WriteString(c.w, ` (`)
for i := range ex.ListVal { for i := range ex.ListVal {
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
switch ex.ListType { switch ex.ListType {
case qcode.ValBool, qcode.ValInt, qcode.ValFloat: case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
c.w.WriteString(ex.ListVal[i]) io.WriteString(c.w, ex.ListVal[i])
case qcode.ValStr: case qcode.ValStr:
c.w.WriteString(`'`) io.WriteString(c.w, `'`)
c.w.WriteString(ex.ListVal[i]) io.WriteString(c.w, ex.ListVal[i])
c.w.WriteString(`'`) io.WriteString(c.w, `'`)
} }
} }
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
} }
func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) { func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
@ -943,27 +942,27 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
switch ex.Type { switch ex.Type {
case qcode.ValBool, qcode.ValInt, qcode.ValFloat: case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
if len(ex.Val) != 0 { if len(ex.Val) != 0 {
c.w.WriteString(ex.Val) io.WriteString(c.w, ex.Val)
} else { } else {
c.w.WriteString(`''`) io.WriteString(c.w, `''`)
} }
case qcode.ValStr: case qcode.ValStr:
c.w.WriteString(`'`) io.WriteString(c.w, `'`)
c.w.WriteString(ex.Val) io.WriteString(c.w, ex.Val)
c.w.WriteString(`'`) io.WriteString(c.w, `'`)
case qcode.ValVar: case qcode.ValVar:
if val, ok := vars[ex.Val]; ok { if val, ok := vars[ex.Val]; ok {
c.w.WriteString(val) io.WriteString(c.w, val)
} else { } else {
//fmt.Fprintf(w, `'{{%s}}'`, ex.Val) //fmt.Fprintf(w, `'{{%s}}'`, ex.Val)
c.w.WriteString(`{{`) io.WriteString(c.w, `{{`)
c.w.WriteString(ex.Val) io.WriteString(c.w, ex.Val)
c.w.WriteString(`}}`) io.WriteString(c.w, `}}`)
} }
} }
//c.w.WriteString(`)`) //io.WriteString(c.w, `)`)
} }
func funcPrefixLen(fn string) int { func funcPrefixLen(fn string) int {
@ -999,105 +998,105 @@ func hasBit(n uint32, pos uint32) bool {
return (val > 0) return (val > 0)
} }
func alias(w *bytes.Buffer, alias string) { func alias(w io.Writer, alias string) {
w.WriteString(` AS "`) io.WriteString(w, ` AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func aliasWithID(w *bytes.Buffer, alias string, id int32) { func aliasWithID(w io.Writer, alias string, id int32) {
w.WriteString(` AS "`) io.WriteString(w, ` AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func aliasWithIDSuffix(w *bytes.Buffer, alias string, id int32, suffix string) { func aliasWithIDSuffix(w io.Writer, alias string, id int32, suffix string) {
w.WriteString(` AS "`) io.WriteString(w, ` AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(suffix) io.WriteString(w, suffix)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func colWithAlias(w *bytes.Buffer, col, alias string) { func colWithAlias(w io.Writer, col, alias string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(`" AS "`) io.WriteString(w, `" AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func tableWithAlias(w *bytes.Buffer, table, alias string) { func tableWithAlias(w io.Writer, table, alias string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`" AS "`) io.WriteString(w, `" AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func colWithTable(w *bytes.Buffer, table, col string) { func colWithTable(w io.Writer, table, col string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`"."`) io.WriteString(w, `"."`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func colWithTableID(w *bytes.Buffer, table string, id int32, col string) { func colWithTableID(w io.Writer, table string, id int32, col string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(`"."`) io.WriteString(w, `"."`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func colWithTableIDAlias(w *bytes.Buffer, table string, id int32, col, alias string) { func colWithTableIDAlias(w io.Writer, table string, id int32, col, alias string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(`"."`) io.WriteString(w, `"."`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(`" AS "`) io.WriteString(w, `" AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func colWithTableIDSuffixAlias(w *bytes.Buffer, table string, id int32, func colWithTableIDSuffixAlias(w io.Writer, table string, id int32,
suffix, col, alias string) { suffix, col, alias string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(suffix) io.WriteString(w, suffix)
w.WriteString(`"."`) io.WriteString(w, `"."`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(`" AS "`) io.WriteString(w, `" AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func tableIDColSuffix(w *bytes.Buffer, table string, id int32, col, suffix string) { func tableIDColSuffix(w io.Writer, table string, id int32, col, suffix string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(`_`) io.WriteString(w, `_`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(suffix) io.WriteString(w, suffix)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
const charset = "0123456789" const charset = "0123456789"
func int2string(w *bytes.Buffer, val int32) { func int2string(w io.Writer, val int32) {
if val < 10 { if val < 10 {
w.WriteByte(charset[val]) w.Write([]byte{charset[val]})
return return
} }
@ -1113,7 +1112,7 @@ func int2string(w *bytes.Buffer, val int32) {
for val3 > 0 { for val3 > 0 {
d := val3 % 10 d := val3 % 10
val3 /= 10 val3 /= 10
w.WriteByte(charset[d]) w.Write([]byte{charset[d]})
} }
} }

View File

@ -182,7 +182,7 @@ func (al *allowList) load() {
item.vars = varBytes item.vars = varBytes
} }
al.list[gqlHash(q, varBytes)] = item al.list[gqlHash(q, varBytes, "")] = item
varBytes = nil varBytes = nil
} else if ty == AL_VARS { } else if ty == AL_VARS {
@ -203,7 +203,11 @@ func (al *allowList) save(item *allowItem) {
if al.active == false { if al.active == false {
return return
} }
al.list[gqlHash(item.gql, item.vars)] = item h := gqlHash(item.gql, item.vars, "")
if _, ok := al.list[h]; ok {
return
}
al.list[gqlHash(item.gql, item.vars, "")] = item
f, err := os.Create(al.filepath) f, err := os.Create(al.filepath)
if err != nil { if err != nil {

View File

@ -9,26 +9,40 @@ import (
var ( var (
userIDProviderKey = struct{}{} userIDProviderKey = struct{}{}
userIDKey = struct{}{} userIDKey = struct{}{}
userRoleKey = struct{}{}
) )
func headerAuth(r *http.Request, c *config) *http.Request { func headerAuth(next http.HandlerFunc) http.HandlerFunc {
if len(c.Auth.Header) == 0 { return func(w http.ResponseWriter, r *http.Request) {
return nil ctx := r.Context()
}
userID := r.Header.Get(c.Auth.Header) userIDProvider := r.Header.Get("X-User-ID-Provider")
if len(userID) != 0 { if len(userIDProvider) != 0 {
ctx := context.WithValue(r.Context(), userIDKey, userID) ctx = context.WithValue(ctx, userIDProviderKey, userIDProvider)
return r.WithContext(ctx) }
}
return nil userID := r.Header.Get("X-User-ID")
if len(userID) != 0 {
ctx = context.WithValue(ctx, userIDKey, userID)
}
userRole := r.Header.Get("X-User-Role")
if len(userRole) != 0 {
ctx = context.WithValue(ctx, userRoleKey, userRole)
}
next.ServeHTTP(w, r.WithContext(ctx))
}
} }
func withAuth(next http.HandlerFunc) http.HandlerFunc { func withAuth(next http.HandlerFunc) http.HandlerFunc {
at := conf.Auth.Type at := conf.Auth.Type
ru := conf.Auth.Rails.URL ru := conf.Auth.Rails.URL
if conf.Auth.CredsInHeader {
next = headerAuth(next)
}
switch at { switch at {
case "rails": case "rails":
if strings.HasPrefix(ru, "memcache:") { if strings.HasPrefix(ru, "memcache:") {

View File

@ -58,11 +58,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var tok string var tok string
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
if len(cookie) != 0 { if len(cookie) != 0 {
ck, err := r.Cookie(cookie) ck, err := r.Cookie(cookie)
if err != nil { if err != nil {
@ -102,7 +97,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
} }
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} }
} }

View File

@ -42,11 +42,6 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie) ck, err := r.Cookie(cookie)
if err != nil { if err != nil {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@ -83,17 +78,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
rURL, err := url.Parse(conf.Auth.Rails.URL) rURL, err := url.Parse(conf.Auth.Rails.URL)
if err != nil { if err != nil {
logger.Fatal().Err(err) logger.Fatal().Err(err).Send()
} }
mc := memcache.New(rURL.Host) mc := memcache.New(rURL.Host)
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie) ck, err := r.Cookie(cookie)
if err != nil { if err != nil {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@ -126,25 +116,20 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
ra, err := railsAuth(conf) ra, err := railsAuth(conf)
if err != nil { if err != nil {
logger.Fatal().Err(err) logger.Fatal().Err(err).Send()
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie) ck, err := r.Cookie(cookie)
if err != nil { if err != nil {
logger.Error().Err(err) logger.Warn().Err(err).Send()
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
userID, err := ra.ParseCookie(ck.Value) userID, err := ra.ParseCookie(ck.Value)
if err != nil { if err != nil {
logger.Error().Err(err) logger.Warn().Err(err).Send()
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }

View File

@ -183,7 +183,32 @@ func initConf() (*config, error) {
} }
zerolog.SetGlobalLevel(logLevel) zerolog.SetGlobalLevel(logLevel)
//fmt.Printf("%#v", c) for k, v := range c.DB.Vars {
c.DB.Vars[k] = sanitize(v)
}
c.RolesQuery = sanitize(c.RolesQuery)
rolesMap := make(map[string]struct{})
for i := range c.Roles {
role := &c.Roles[i]
if _, ok := rolesMap[role.Name]; ok {
logger.Fatal().Msgf("duplicate role '%s' found", role.Name)
}
role.Name = sanitize(role.Name)
role.Match = sanitize(role.Match)
rolesMap[role.Name] = struct{}{}
}
if _, ok := rolesMap["user"]; !ok {
c.Roles = append(c.Roles, configRole{Name: "user"})
}
if _, ok := rolesMap["anon"]; !ok {
c.Roles = append(c.Roles, configRole{Name: "anon"})
}
return c, nil return c, nil
} }

View File

@ -66,8 +66,9 @@ func graphQLFunc(query string, data interface{}) map[string]interface{} {
c := &coreContext{Context: context.Background()} c := &coreContext{Context: context.Background()}
c.req.Query = query c.req.Query = query
c.req.Vars = b c.req.Vars = b
c.req.role = "user"
res, err := c.execQuery("user") res, err := c.execQuery()
if err != nil { if err != nil {
logger.Fatal().Err(err).Msg("graphql query failed") logger.Fatal().Err(err).Msg("graphql query failed")
} }

View File

@ -1,7 +1,9 @@
package serv package serv
import ( import (
"regexp"
"strings" "strings"
"unicode"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@ -24,9 +26,9 @@ type config struct {
Inflections map[string]string Inflections map[string]string
Auth struct { Auth struct {
Type string Type string
Cookie string Cookie string
Header string CredsInHeader bool `mapstructure:"creds_in_header"`
Rails struct { Rails struct {
Version string Version string
@ -60,7 +62,7 @@ type config struct {
MaxRetries int `mapstructure:"max_retries"` MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"` LogLevel string `mapstructure:"log_level"`
vars map[string][]byte `mapstructure:"variables"` Vars map[string]string `mapstructure:"variables"`
Defaults struct { Defaults struct {
Filter []string Filter []string
@ -71,7 +73,9 @@ type config struct {
} `mapstructure:"database"` } `mapstructure:"database"`
Tables []configTable Tables []configTable
Roles []configRoles
RolesQuery string `mapstructure:"roles_query"`
Roles []configRole
} }
type configTable struct { type configTable struct {
@ -94,8 +98,9 @@ type configRemote struct {
} `mapstructure:"set_headers"` } `mapstructure:"set_headers"`
} }
type configRoles struct { type configRole struct {
Name string Name string
Match string
Tables []struct { Tables []struct {
Name string Name string
@ -163,26 +168,6 @@ func newConfig() *viper.Viper {
return vi return vi
} }
func (c *config) getVariables() map[string]string {
vars := make(map[string]string, len(c.DB.vars))
for k, v := range c.DB.vars {
isVar := false
for i := range v {
if v[i] == '$' {
isVar = true
} else if v[i] == ' ' {
isVar = false
} else if isVar && v[i] >= 'a' && v[i] <= 'z' {
v[i] = 'A' + (v[i] - 'a')
}
}
vars[k] = string(v)
}
return vars
}
func (c *config) getAliasMap() map[string][]string { func (c *config) getAliasMap() map[string][]string {
m := make(map[string][]string, len(c.Tables)) m := make(map[string][]string, len(c.Tables))
@ -198,3 +183,21 @@ func (c *config) getAliasMap() map[string][]string {
} }
return m return m
} }
var varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
var varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`)
func sanitize(s string) string {
s0 := varRe1.ReplaceAllString(s, `{{$1}}`)
s1 := strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
return ' '
}
return r
}, s0)
return varRe2.ReplaceAllStringFunc(s1, func(m string) string {
return strings.ToLower(m)
})
}

View File

@ -13,8 +13,8 @@ import (
"github.com/cespare/xxhash/v2" "github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
"github.com/jackc/pgx/v4"
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
) )
@ -32,15 +32,13 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
c.req.ref = req.Referer() c.req.ref = req.Referer()
c.req.hdr = req.Header c.req.hdr = req.Header
var role string
if authCheck(c) { if authCheck(c) {
role = "user" c.req.role = "user"
} else { } else {
role = "anon" c.req.role = "anon"
} }
b, err := c.execQuery(role) b, err := c.execQuery()
if err != nil { if err != nil {
return err return err
} }
@ -48,18 +46,18 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
return c.render(w, b) return c.render(w, b)
} }
func (c *coreContext) execQuery(role string) ([]byte, error) { func (c *coreContext) execQuery() ([]byte, error) {
var err error var err error
var skipped uint32 var skipped uint32
var qc *qcode.QCode var qc *qcode.QCode
var data []byte var data []byte
logger.Debug().Str("role", role).Msg(c.req.Query) logger.Debug().Str("role", c.req.role).Msg(c.req.Query)
if conf.UseAllowList { if conf.UseAllowList {
var ps *preparedItem var ps *preparedItem
data, ps, err = c.resolvePreparedSQL(c.req.Query) data, ps, err = c.resolvePreparedSQL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -69,12 +67,7 @@ func (c *coreContext) execQuery(role string) ([]byte, error) {
} else { } else {
qc, err = qcompile.Compile([]byte(c.req.Query), role) data, skipped, err = c.resolveSQL()
if err != nil {
return nil, err
}
data, skipped, err = c.resolveSQL(qc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -122,6 +115,152 @@ func (c *coreContext) execQuery(role string) ([]byte, error) {
return ob.Bytes(), nil return ob.Bytes(), nil
} }
func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, nil, err
}
}
var role string
useRoleQuery := len(conf.RolesQuery) != 0 && isMutation(c.req.Query)
if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil {
return nil, nil, err
}
} else if v := c.Value(userRoleKey); v != nil {
role = v.(string)
} else {
role = c.req.role
}
ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)]
if !ok {
return nil, nil, errUnauthorized
}
var root []byte
vars := varList(c, ps.args)
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
if err != nil {
return nil, nil, err
}
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
return root, ps, nil
}
func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
if useRoleQuery {
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
return nil, 0, err
}
} else if v := c.Value(userRoleKey); v != nil {
c.req.role = v.(string)
}
stmts, err := c.buildStmt()
if err != nil {
return nil, 0, err
}
var st *stmt
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
}
t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(c) == false {
return nil, 0, errUnauthorized
}
if err != nil {
return nil, 0, err
}
finalSQL := buf.String()
var stime time.Time
if conf.EnableTracing {
stime = time.Now()
}
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, 0, err
}
}
var root []byte
if mutation {
err = tx.QueryRow(c, finalSQL).Scan(&root)
} else {
err = tx.QueryRow(c, finalSQL).Scan(&c.req.role, &root)
}
if err != nil {
return nil, 0, err
}
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
}
if conf.EnableTracing && len(st.qc.Selects) != 0 {
c.addTrace(
st.qc.Selects,
st.qc.Selects[0].ID,
stime)
}
if conf.UseAllowList == false {
_allowList.add(&c.req)
}
return root, st.skipped, nil
}
func (c *coreContext) resolveRemote( func (c *coreContext) resolveRemote(
hdr http.Header, hdr http.Header,
h *xxhash.Digest, h *xxhash.Digest,
@ -269,125 +408,15 @@ func (c *coreContext) resolveRemotes(
return to, cerr return to, cerr
} }
func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) { func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
ps, ok := _preparedList[gqlHash(gql, c.req.Vars)] var role string
if !ok { row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1)
return nil, nil, errUnauthorized
if err := row.Scan(&role); err != nil {
return "", err
} }
var root []byte return role, nil
vars := varList(c, ps.args)
tx, err := db.Begin(c)
if err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, nil, err
}
}
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
if err != nil {
return nil, nil, err
}
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
fmt.Printf("PRE: %v\n", ps.stmt)
return root, ps, nil
}
func (c *coreContext) resolveSQL(qc *qcode.QCode) ([]byte, uint32, error) {
var vars map[string]json.RawMessage
stmt := &bytes.Buffer{}
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, 0, err
}
}
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
if err != nil {
return nil, 0, err
}
t := fasttemplate.New(stmt.String(), openVar, closeVar)
stmt.Reset()
_, err = t.ExecuteFunc(stmt, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(c) == false {
return nil, 0, errUnauthorized
}
if err != nil {
return nil, 0, err
}
finalSQL := stmt.String()
// if conf.LogLevel == "debug" {
// os.Stdout.WriteString(finalSQL)
// os.Stdout.WriteString("\n\n")
// }
var st time.Time
if conf.EnableTracing {
st = time.Now()
}
tx, err := db.Begin(c)
if err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, 0, err
}
}
//fmt.Printf("\nRAW: %#v\n", finalSQL)
var root []byte
err = tx.QueryRow(c, finalSQL).Scan(&root)
if err != nil {
return nil, 0, err
}
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
if conf.EnableTracing && len(qc.Selects) != 0 {
c.addTrace(
qc.Selects,
qc.Selects[0].ID,
st)
}
if conf.UseAllowList == false {
_allowList.add(&c.req)
}
return root, skipped, nil
} }
func (c *coreContext) render(w io.Writer, data []byte) error { func (c *coreContext) render(w io.Writer, data []byte) error {

144
serv/core_build.go Normal file
View File

@ -0,0 +1,144 @@
package serv
import (
"bytes"
"encoding/json"
"errors"
"io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
)
type stmt struct {
role *configRole
qc *qcode.QCode
skipped uint32
sql string
}
func (c *coreContext) buildStmt() ([]stmt, error) {
var vars map[string]json.RawMessage
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, err
}
}
gql := []byte(c.req.Query)
if len(conf.Roles) == 0 {
return nil, errors.New(`no roles found ('user' and 'anon' required)`)
}
qc, err := qcompile.Compile(gql, conf.Roles[0].Name)
if err != nil {
return nil, err
}
stmts := make([]stmt, 0, len(conf.Roles))
mutation := (qc.Type != qcode.QTQuery)
w := &bytes.Buffer{}
for i := range conf.Roles {
role := &conf.Roles[i]
if mutation && len(c.req.role) != 0 && role.Name != c.req.role {
continue
}
if i > 0 {
qc, err = qcompile.Compile(gql, role.Name)
if err != nil {
return nil, err
}
}
stmts = append(stmts, stmt{role: role, qc: qc})
if mutation {
skipped, err := pcompile.Compile(qc, w, psql.Variables(vars))
if err != nil {
return nil, err
}
s := &stmts[len(stmts)-1]
s.skipped = skipped
s.sql = w.String()
w.Reset()
}
}
if mutation {
return stmts, nil
}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
for _, s := range stmts {
io.WriteString(w, `WHEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `' THEN (`)
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
if err != nil {
return nil, err
}
io.WriteString(w, `) `)
}
io.WriteString(w, `END) FROM (`)
if len(conf.RolesQuery) == 0 {
v := c.Value(userRoleKey)
io.WriteString(w, `VALUES ("`)
if v != nil {
io.WriteString(w, v.(string))
} else {
io.WriteString(w, c.req.role)
}
io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`)
} else {
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`)
for _, s := range stmts {
if len(s.role.Match) == 0 {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, s.role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `'`)
}
if len(c.req.role) == 0 {
io.WriteString(w, ` ELSE 'anon' END) FROM (`)
} else {
io.WriteString(w, ` ELSE '`)
io.WriteString(w, c.req.role)
io.WriteString(w, `' END) FROM (`)
}
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`)
if len(c.req.role) == 0 {
io.WriteString(w, `anon`)
} else {
io.WriteString(w, c.req.role)
}
io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
}
stmts[0].sql = w.String()
stmts[0].role = nil
return stmts, nil
}

View File

@ -30,6 +30,7 @@ type gqlReq struct {
Query string `json:"query"` Query string `json:"query"`
Vars json.RawMessage `json:"variables"` Vars json.RawMessage `json:"variables"`
ref string ref string
role string
hdr http.Header hdr http.Header
} }
@ -101,13 +102,11 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
err = ctx.handleReq(w, r) err = ctx.handleReq(w, r)
if err == errUnauthorized { if err == errUnauthorized {
err := "Not authorized" http.Error(w, "Not authorized", 401)
logger.Debug().Msg(err)
http.Error(w, err, 401)
} }
if err != nil { if err != nil {
logger.Err(err).Msg("Failed to handle request") logger.Err(err).Msg("failed to handle request")
errorResp(w, err) errorResp(w, err)
} }
} }

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
@ -27,55 +26,100 @@ var (
func initPreparedList() { func initPreparedList() {
_preparedList = make(map[string]*preparedItem) _preparedList = make(map[string]*preparedItem)
for k, v := range _allowList.list { if err := prepareRoleStmt(); err != nil {
err := prepareStmt(k, v.gql, v.vars) logger.Fatal().Err(err).Msg("failed to prepare get role statement")
}
for _, v := range _allowList.list {
err := prepareStmt(v.gql, v.vars)
if err != nil { if err != nil {
logger.Warn().Str("gql", v.gql).Err(err).Send() logger.Warn().Str("gql", v.gql).Err(err).Send()
} }
} }
} }
func prepareStmt(key, gql string, varBytes json.RawMessage) error { func prepareStmt(gql string, varBytes json.RawMessage) error {
if len(gql) == 0 || len(key) == 0 { if len(gql) == 0 {
return nil return nil
} }
qc, err := qcompile.Compile([]byte(gql), "user") c := &coreContext{Context: context.Background()}
c.req.Query = gql
c.req.Vars = varBytes
stmts, err := c.buildStmt()
if err != nil { if err != nil {
return err return err
} }
var vars map[string]json.RawMessage for _, s := range stmts {
if len(s.sql) == 0 {
continue
}
if len(varBytes) != 0 { finalSQL, am := processTemplate(s.sql)
vars = make(map[string]json.RawMessage)
if err := json.Unmarshal(varBytes, &vars); err != nil { ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
pstmt, err := tx.Prepare(ctx, "", finalSQL)
if err != nil {
return err
}
var key string
if s.role == nil {
key = gqlHash(gql, varBytes, "")
} else {
key = gqlHash(gql, varBytes, s.role.Name)
}
_preparedList[key] = &preparedItem{
stmt: pstmt,
args: am,
skipped: s.skipped,
qc: s.qc,
}
if err := tx.Commit(ctx); err != nil {
return err return err
} }
} }
buf := &bytes.Buffer{} return nil
}
skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars)) func prepareRoleStmt() error {
if err != nil { if len(conf.RolesQuery) == 0 {
return err return nil
} }
t := fasttemplate.New(buf.String(), `{{`, `}}`) w := &bytes.Buffer{}
am := make([][]byte, 0, 5)
i := 0
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) { io.WriteString(w, `SELECT (CASE`)
am = append(am, []byte(tag)) for _, role := range conf.Roles {
i++ if len(role.Match) == 0 {
return w.Write([]byte(fmt.Sprintf("$%d", i))) continue
}) }
io.WriteString(w, ` WHEN `)
if err != nil { io.WriteString(w, role.Match)
return err io.WriteString(w, ` THEN '`)
io.WriteString(w, role.Name)
io.WriteString(w, `'`)
} }
io.WriteString(w, ` ELSE {{role}} END) FROM (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query"`)
roleSQL, _ := processTemplate(w.String())
ctx := context.Background() ctx := context.Background()
tx, err := db.Begin(ctx) tx, err := db.Begin(ctx)
@ -84,21 +128,28 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error {
} }
defer tx.Rollback(ctx) defer tx.Rollback(ctx)
pstmt, err := tx.Prepare(ctx, "", finalSQL) _, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
if err != nil { if err != nil {
return err return err
} }
_preparedList[key] = &preparedItem{
stmt: pstmt,
args: am,
skipped: skipped,
qc: qc,
}
if err := tx.Commit(ctx); err != nil {
return err
}
return nil return nil
} }
func processTemplate(tmpl string) (string, [][]byte) {
t := fasttemplate.New(tmpl, `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
vmap := make(map[string]int)
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
if n, ok := vmap[tag]; ok {
return w.Write([]byte(fmt.Sprintf("$%d", n)))
}
am = append(am, []byte(tag))
i++
vmap[tag] = i
return w.Write([]byte(fmt.Sprintf("$%d", i)))
}), am
}

View File

@ -67,7 +67,7 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
pc := psql.NewCompiler(psql.Config{ pc := psql.NewCompiler(psql.Config{
Schema: schema, Schema: schema,
Vars: c.getVariables(), Vars: c.DB.Vars,
}) })
return qc, pc, nil return qc, pc, nil

View File

@ -21,7 +21,7 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
return v return v
} }
func gqlHash(b string, vars []byte) string { func gqlHash(b string, vars []byte, role string) string {
b = strings.TrimSpace(b) b = strings.TrimSpace(b)
h := sha1.New() h := sha1.New()
@ -56,6 +56,10 @@ func gqlHash(b string, vars []byte) string {
} }
} }
if len(role) != 0 {
io.WriteString(h, role)
}
if vars == nil || len(vars) == 0 { if vars == nil || len(vars) == 0 {
return hex.EncodeToString(h.Sum(nil)) return hex.EncodeToString(h.Sum(nil))
} }
@ -80,3 +84,26 @@ func ws(b byte) bool {
func al(b byte) bool { func al(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
} }
func isMutation(sql string) bool {
for i := range sql {
b := sql[i]
if b == '{' {
return false
}
if al(b) {
return (b == 'm' || b == 'M')
}
}
return false
}
func findStmt(role string, stmts []stmt) *stmt {
for i := range stmts {
if stmts[i].role.Name != role {
continue
}
return &stmts[i]
}
return nil
}

View File

@ -11,17 +11,27 @@ import (
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) { func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) { return func(w io.Writer, tag string) (int, error) {
switch tag { switch tag {
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
}
return 0, errNoUserID
case "user_id_provider": case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil { if v := ctx.Value(userIDProviderKey); v != nil {
return stringVar(w, v.(string)) return stringVar(w, v.(string))
} }
return 0, errNoUserID io.WriteString(w, "null")
return 0, nil
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
case "user_role":
if v := ctx.Value(userRoleKey); v != nil {
return stringVar(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
} }
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)}) fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})

View File

@ -47,10 +47,6 @@ auth:
type: rails type: rails
cookie: _app_session cookie: _app_session
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
header: X-User-ID
rails: rails:
# Rails version this is used for reading the # Rails version this is used for reading the
# various cookies formats. # various cookies formats.
@ -106,42 +102,93 @@ database:
- token - token
tables: tables:
- name: users - name: users
# This filter will overwrite defaults.filter # This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"] # filter: ["{ id: { eq: $user_id } }"]
# filter_query: ["{ id: { eq: $user_id } }"]
filter_update: ["{ id: { eq: $user_id } }"]
filter_delete: ["{ id: { eq: $user_id } }"]
- name: products # - name: products
# Multiple filters are AND'd together # # Multiple filters are AND'd together
filter: [ # filter: [
"{ price: { gt: 0 } }", # "{ price: { gt: 0 } }",
"{ price: { lt: 8 } }" # "{ price: { lt: 8 } }"
] # ]
- name: customers - name: customers
# No filter is used for this field not remotes:
# even defaults.filter - name: payments
filter: none id: stripe_id
url: http://rails_app:3000/stripe/$id
path: data
# debug: true
pass_headers:
- cookie
set_headers:
- name: Host
value: 0.0.0.0
# - name: Authorization
# value: Bearer <stripe_api_key>
# remotes: - # You can create new fields that have a
# - name: payments # real db table backing them
# id: stripe_id name: me
# url: http://rails_app:3000/stripe/$id table: users
# path: data
# # pass_headers:
# # - cookie
# # - host
# set_headers:
# - name: Authorization
# value: Bearer <stripe_api_key>
- # You can create new fields that have a roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
# real db table backing them
name: me
table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts roles:
# filter: ["{ account_id: { _eq: $account_id } }"] - name: anon
tables:
- name: products
limit: 10
query:
columns: ["id", "name", "description" ]
aggregation: false
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filter: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filter: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filter: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filter: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# select:
# filter: ["{ account_id: { _eq: $account_id } }"]