Get RBAC working for queries and mutations

This commit is contained in:
Vikram Rangnekar
2019-10-24 02:07:42 -04:00
parent c797deb4d0
commit 6bc66d28bc
19 changed files with 902 additions and 568 deletions

View File

@ -182,7 +182,7 @@ func (al *allowList) load() {
item.vars = varBytes
}
al.list[gqlHash(q, varBytes)] = item
al.list[gqlHash(q, varBytes, "")] = item
varBytes = nil
} else if ty == AL_VARS {
@ -203,7 +203,11 @@ func (al *allowList) save(item *allowItem) {
if al.active == false {
return
}
al.list[gqlHash(item.gql, item.vars)] = item
h := gqlHash(item.gql, item.vars, "")
if _, ok := al.list[h]; ok {
return
}
al.list[gqlHash(item.gql, item.vars, "")] = item
f, err := os.Create(al.filepath)
if err != nil {

View File

@ -9,26 +9,40 @@ import (
var (
userIDProviderKey = struct{}{}
userIDKey = struct{}{}
userRoleKey = struct{}{}
)
func headerAuth(r *http.Request, c *config) *http.Request {
if len(c.Auth.Header) == 0 {
return nil
}
func headerAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userID := r.Header.Get(c.Auth.Header)
if len(userID) != 0 {
ctx := context.WithValue(r.Context(), userIDKey, userID)
return r.WithContext(ctx)
}
userIDProvider := r.Header.Get("X-User-ID-Provider")
if len(userIDProvider) != 0 {
ctx = context.WithValue(ctx, userIDProviderKey, userIDProvider)
}
return nil
userID := r.Header.Get("X-User-ID")
if len(userID) != 0 {
ctx = context.WithValue(ctx, userIDKey, userID)
}
userRole := r.Header.Get("X-User-Role")
if len(userRole) != 0 {
ctx = context.WithValue(ctx, userRoleKey, userRole)
}
next.ServeHTTP(w, r.WithContext(ctx))
}
}
func withAuth(next http.HandlerFunc) http.HandlerFunc {
at := conf.Auth.Type
ru := conf.Auth.Rails.URL
if conf.Auth.CredsInHeader {
next = headerAuth(next)
}
switch at {
case "rails":
if strings.HasPrefix(ru, "memcache:") {

View File

@ -58,11 +58,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var tok string
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
if len(cookie) != 0 {
ck, err := r.Cookie(cookie)
if err != nil {
@ -102,7 +97,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
}
next.ServeHTTP(w, r.WithContext(ctx))
}
next.ServeHTTP(w, r)
}
}

View File

@ -42,11 +42,6 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
}
return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie)
if err != nil {
next.ServeHTTP(w, r)
@ -83,17 +78,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
rURL, err := url.Parse(conf.Auth.Rails.URL)
if err != nil {
logger.Fatal().Err(err)
logger.Fatal().Err(err).Send()
}
mc := memcache.New(rURL.Host)
return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie)
if err != nil {
next.ServeHTTP(w, r)
@ -126,25 +116,20 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
ra, err := railsAuth(conf)
if err != nil {
logger.Fatal().Err(err)
logger.Fatal().Err(err).Send()
}
return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie)
if err != nil {
logger.Error().Err(err)
logger.Warn().Err(err).Send()
next.ServeHTTP(w, r)
return
}
userID, err := ra.ParseCookie(ck.Value)
if err != nil {
logger.Error().Err(err)
logger.Warn().Err(err).Send()
next.ServeHTTP(w, r)
return
}

View File

@ -183,7 +183,32 @@ func initConf() (*config, error) {
}
zerolog.SetGlobalLevel(logLevel)
//fmt.Printf("%#v", c)
for k, v := range c.DB.Vars {
c.DB.Vars[k] = sanitize(v)
}
c.RolesQuery = sanitize(c.RolesQuery)
rolesMap := make(map[string]struct{})
for i := range c.Roles {
role := &c.Roles[i]
if _, ok := rolesMap[role.Name]; ok {
logger.Fatal().Msgf("duplicate role '%s' found", role.Name)
}
role.Name = sanitize(role.Name)
role.Match = sanitize(role.Match)
rolesMap[role.Name] = struct{}{}
}
if _, ok := rolesMap["user"]; !ok {
c.Roles = append(c.Roles, configRole{Name: "user"})
}
if _, ok := rolesMap["anon"]; !ok {
c.Roles = append(c.Roles, configRole{Name: "anon"})
}
return c, nil
}

View File

@ -66,8 +66,9 @@ func graphQLFunc(query string, data interface{}) map[string]interface{} {
c := &coreContext{Context: context.Background()}
c.req.Query = query
c.req.Vars = b
c.req.role = "user"
res, err := c.execQuery("user")
res, err := c.execQuery()
if err != nil {
logger.Fatal().Err(err).Msg("graphql query failed")
}

View File

@ -1,7 +1,9 @@
package serv
import (
"regexp"
"strings"
"unicode"
"github.com/spf13/viper"
)
@ -24,9 +26,9 @@ type config struct {
Inflections map[string]string
Auth struct {
Type string
Cookie string
Header string
Type string
Cookie string
CredsInHeader bool `mapstructure:"creds_in_header"`
Rails struct {
Version string
@ -60,7 +62,7 @@ type config struct {
MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"`
vars map[string][]byte `mapstructure:"variables"`
Vars map[string]string `mapstructure:"variables"`
Defaults struct {
Filter []string
@ -71,7 +73,9 @@ type config struct {
} `mapstructure:"database"`
Tables []configTable
Roles []configRoles
RolesQuery string `mapstructure:"roles_query"`
Roles []configRole
}
type configTable struct {
@ -94,8 +98,9 @@ type configRemote struct {
} `mapstructure:"set_headers"`
}
type configRoles struct {
type configRole struct {
Name string
Match string
Tables []struct {
Name string
@ -163,26 +168,6 @@ func newConfig() *viper.Viper {
return vi
}
func (c *config) getVariables() map[string]string {
vars := make(map[string]string, len(c.DB.vars))
for k, v := range c.DB.vars {
isVar := false
for i := range v {
if v[i] == '$' {
isVar = true
} else if v[i] == ' ' {
isVar = false
} else if isVar && v[i] >= 'a' && v[i] <= 'z' {
v[i] = 'A' + (v[i] - 'a')
}
}
vars[k] = string(v)
}
return vars
}
func (c *config) getAliasMap() map[string][]string {
m := make(map[string][]string, len(c.Tables))
@ -198,3 +183,21 @@ func (c *config) getAliasMap() map[string][]string {
}
return m
}
var varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
var varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`)
func sanitize(s string) string {
s0 := varRe1.ReplaceAllString(s, `{{$1}}`)
s1 := strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
return ' '
}
return r
}, s0)
return varRe2.ReplaceAllStringFunc(s1, func(m string) string {
return strings.ToLower(m)
})
}

View File

@ -13,8 +13,8 @@ import (
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgx/v4"
"github.com/valyala/fasttemplate"
)
@ -32,15 +32,13 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
c.req.ref = req.Referer()
c.req.hdr = req.Header
var role string
if authCheck(c) {
role = "user"
c.req.role = "user"
} else {
role = "anon"
c.req.role = "anon"
}
b, err := c.execQuery(role)
b, err := c.execQuery()
if err != nil {
return err
}
@ -48,18 +46,18 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
return c.render(w, b)
}
func (c *coreContext) execQuery(role string) ([]byte, error) {
func (c *coreContext) execQuery() ([]byte, error) {
var err error
var skipped uint32
var qc *qcode.QCode
var data []byte
logger.Debug().Str("role", role).Msg(c.req.Query)
logger.Debug().Str("role", c.req.role).Msg(c.req.Query)
if conf.UseAllowList {
var ps *preparedItem
data, ps, err = c.resolvePreparedSQL(c.req.Query)
data, ps, err = c.resolvePreparedSQL()
if err != nil {
return nil, err
}
@ -69,12 +67,7 @@ func (c *coreContext) execQuery(role string) ([]byte, error) {
} else {
qc, err = qcompile.Compile([]byte(c.req.Query), role)
if err != nil {
return nil, err
}
data, skipped, err = c.resolveSQL(qc)
data, skipped, err = c.resolveSQL()
if err != nil {
return nil, err
}
@ -122,6 +115,152 @@ func (c *coreContext) execQuery(role string) ([]byte, error) {
return ob.Bytes(), nil
}
func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, nil, err
}
}
var role string
useRoleQuery := len(conf.RolesQuery) != 0 && isMutation(c.req.Query)
if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil {
return nil, nil, err
}
} else if v := c.Value(userRoleKey); v != nil {
role = v.(string)
} else {
role = c.req.role
}
ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)]
if !ok {
return nil, nil, errUnauthorized
}
var root []byte
vars := varList(c, ps.args)
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
if err != nil {
return nil, nil, err
}
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
return root, ps, nil
}
func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
if useRoleQuery {
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
return nil, 0, err
}
} else if v := c.Value(userRoleKey); v != nil {
c.req.role = v.(string)
}
stmts, err := c.buildStmt()
if err != nil {
return nil, 0, err
}
var st *stmt
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
}
t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(c) == false {
return nil, 0, errUnauthorized
}
if err != nil {
return nil, 0, err
}
finalSQL := buf.String()
var stime time.Time
if conf.EnableTracing {
stime = time.Now()
}
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, 0, err
}
}
var root []byte
if mutation {
err = tx.QueryRow(c, finalSQL).Scan(&root)
} else {
err = tx.QueryRow(c, finalSQL).Scan(&c.req.role, &root)
}
if err != nil {
return nil, 0, err
}
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
}
if conf.EnableTracing && len(st.qc.Selects) != 0 {
c.addTrace(
st.qc.Selects,
st.qc.Selects[0].ID,
stime)
}
if conf.UseAllowList == false {
_allowList.add(&c.req)
}
return root, st.skipped, nil
}
func (c *coreContext) resolveRemote(
hdr http.Header,
h *xxhash.Digest,
@ -269,125 +408,15 @@ func (c *coreContext) resolveRemotes(
return to, cerr
}
func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) {
ps, ok := _preparedList[gqlHash(gql, c.req.Vars)]
if !ok {
return nil, nil, errUnauthorized
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
var role string
row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1)
if err := row.Scan(&role); err != nil {
return "", err
}
var root []byte
vars := varList(c, ps.args)
tx, err := db.Begin(c)
if err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, nil, err
}
}
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
if err != nil {
return nil, nil, err
}
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
fmt.Printf("PRE: %v\n", ps.stmt)
return root, ps, nil
}
func (c *coreContext) resolveSQL(qc *qcode.QCode) ([]byte, uint32, error) {
var vars map[string]json.RawMessage
stmt := &bytes.Buffer{}
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, 0, err
}
}
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
if err != nil {
return nil, 0, err
}
t := fasttemplate.New(stmt.String(), openVar, closeVar)
stmt.Reset()
_, err = t.ExecuteFunc(stmt, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(c) == false {
return nil, 0, errUnauthorized
}
if err != nil {
return nil, 0, err
}
finalSQL := stmt.String()
// if conf.LogLevel == "debug" {
// os.Stdout.WriteString(finalSQL)
// os.Stdout.WriteString("\n\n")
// }
var st time.Time
if conf.EnableTracing {
st = time.Now()
}
tx, err := db.Begin(c)
if err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, 0, err
}
}
//fmt.Printf("\nRAW: %#v\n", finalSQL)
var root []byte
err = tx.QueryRow(c, finalSQL).Scan(&root)
if err != nil {
return nil, 0, err
}
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
if conf.EnableTracing && len(qc.Selects) != 0 {
c.addTrace(
qc.Selects,
qc.Selects[0].ID,
st)
}
if conf.UseAllowList == false {
_allowList.add(&c.req)
}
return root, skipped, nil
return role, nil
}
func (c *coreContext) render(w io.Writer, data []byte) error {

144
serv/core_build.go Normal file
View File

@ -0,0 +1,144 @@
package serv
import (
"bytes"
"encoding/json"
"errors"
"io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
)
type stmt struct {
role *configRole
qc *qcode.QCode
skipped uint32
sql string
}
func (c *coreContext) buildStmt() ([]stmt, error) {
var vars map[string]json.RawMessage
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, err
}
}
gql := []byte(c.req.Query)
if len(conf.Roles) == 0 {
return nil, errors.New(`no roles found ('user' and 'anon' required)`)
}
qc, err := qcompile.Compile(gql, conf.Roles[0].Name)
if err != nil {
return nil, err
}
stmts := make([]stmt, 0, len(conf.Roles))
mutation := (qc.Type != qcode.QTQuery)
w := &bytes.Buffer{}
for i := range conf.Roles {
role := &conf.Roles[i]
if mutation && len(c.req.role) != 0 && role.Name != c.req.role {
continue
}
if i > 0 {
qc, err = qcompile.Compile(gql, role.Name)
if err != nil {
return nil, err
}
}
stmts = append(stmts, stmt{role: role, qc: qc})
if mutation {
skipped, err := pcompile.Compile(qc, w, psql.Variables(vars))
if err != nil {
return nil, err
}
s := &stmts[len(stmts)-1]
s.skipped = skipped
s.sql = w.String()
w.Reset()
}
}
if mutation {
return stmts, nil
}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
for _, s := range stmts {
io.WriteString(w, `WHEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `' THEN (`)
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
if err != nil {
return nil, err
}
io.WriteString(w, `) `)
}
io.WriteString(w, `END) FROM (`)
if len(conf.RolesQuery) == 0 {
v := c.Value(userRoleKey)
io.WriteString(w, `VALUES ("`)
if v != nil {
io.WriteString(w, v.(string))
} else {
io.WriteString(w, c.req.role)
}
io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`)
} else {
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`)
for _, s := range stmts {
if len(s.role.Match) == 0 {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, s.role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `'`)
}
if len(c.req.role) == 0 {
io.WriteString(w, ` ELSE 'anon' END) FROM (`)
} else {
io.WriteString(w, ` ELSE '`)
io.WriteString(w, c.req.role)
io.WriteString(w, `' END) FROM (`)
}
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`)
if len(c.req.role) == 0 {
io.WriteString(w, `anon`)
} else {
io.WriteString(w, c.req.role)
}
io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
}
stmts[0].sql = w.String()
stmts[0].role = nil
return stmts, nil
}

View File

@ -30,6 +30,7 @@ type gqlReq struct {
Query string `json:"query"`
Vars json.RawMessage `json:"variables"`
ref string
role string
hdr http.Header
}
@ -101,13 +102,11 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
err = ctx.handleReq(w, r)
if err == errUnauthorized {
err := "Not authorized"
logger.Debug().Msg(err)
http.Error(w, err, 401)
http.Error(w, "Not authorized", 401)
}
if err != nil {
logger.Err(err).Msg("Failed to handle request")
logger.Err(err).Msg("failed to handle request")
errorResp(w, err)
}
}

View File

@ -7,7 +7,6 @@ import (
"fmt"
"io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgconn"
"github.com/valyala/fasttemplate"
@ -27,55 +26,100 @@ var (
func initPreparedList() {
_preparedList = make(map[string]*preparedItem)
for k, v := range _allowList.list {
err := prepareStmt(k, v.gql, v.vars)
if err := prepareRoleStmt(); err != nil {
logger.Fatal().Err(err).Msg("failed to prepare get role statement")
}
for _, v := range _allowList.list {
err := prepareStmt(v.gql, v.vars)
if err != nil {
logger.Warn().Str("gql", v.gql).Err(err).Send()
}
}
}
func prepareStmt(key, gql string, varBytes json.RawMessage) error {
if len(gql) == 0 || len(key) == 0 {
func prepareStmt(gql string, varBytes json.RawMessage) error {
if len(gql) == 0 {
return nil
}
qc, err := qcompile.Compile([]byte(gql), "user")
c := &coreContext{Context: context.Background()}
c.req.Query = gql
c.req.Vars = varBytes
stmts, err := c.buildStmt()
if err != nil {
return err
}
var vars map[string]json.RawMessage
for _, s := range stmts {
if len(s.sql) == 0 {
continue
}
if len(varBytes) != 0 {
vars = make(map[string]json.RawMessage)
finalSQL, am := processTemplate(s.sql)
if err := json.Unmarshal(varBytes, &vars); err != nil {
ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
pstmt, err := tx.Prepare(ctx, "", finalSQL)
if err != nil {
return err
}
var key string
if s.role == nil {
key = gqlHash(gql, varBytes, "")
} else {
key = gqlHash(gql, varBytes, s.role.Name)
}
_preparedList[key] = &preparedItem{
stmt: pstmt,
args: am,
skipped: s.skipped,
qc: s.qc,
}
if err := tx.Commit(ctx); err != nil {
return err
}
}
buf := &bytes.Buffer{}
return nil
}
skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars))
if err != nil {
return err
func prepareRoleStmt() error {
if len(conf.RolesQuery) == 0 {
return nil
}
t := fasttemplate.New(buf.String(), `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
w := &bytes.Buffer{}
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
am = append(am, []byte(tag))
i++
return w.Write([]byte(fmt.Sprintf("$%d", i)))
})
if err != nil {
return err
io.WriteString(w, `SELECT (CASE`)
for _, role := range conf.Roles {
if len(role.Match) == 0 {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, role.Name)
io.WriteString(w, `'`)
}
io.WriteString(w, ` ELSE {{role}} END) FROM (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query"`)
roleSQL, _ := processTemplate(w.String())
ctx := context.Background()
tx, err := db.Begin(ctx)
@ -84,21 +128,28 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error {
}
defer tx.Rollback(ctx)
pstmt, err := tx.Prepare(ctx, "", finalSQL)
_, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
if err != nil {
return err
}
_preparedList[key] = &preparedItem{
stmt: pstmt,
args: am,
skipped: skipped,
qc: qc,
}
if err := tx.Commit(ctx); err != nil {
return err
}
return nil
}
func processTemplate(tmpl string) (string, [][]byte) {
t := fasttemplate.New(tmpl, `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
vmap := make(map[string]int)
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
if n, ok := vmap[tag]; ok {
return w.Write([]byte(fmt.Sprintf("$%d", n)))
}
am = append(am, []byte(tag))
i++
vmap[tag] = i
return w.Write([]byte(fmt.Sprintf("$%d", i)))
}), am
}

View File

@ -67,7 +67,7 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
pc := psql.NewCompiler(psql.Config{
Schema: schema,
Vars: c.getVariables(),
Vars: c.DB.Vars,
})
return qc, pc, nil

View File

@ -21,7 +21,7 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
return v
}
func gqlHash(b string, vars []byte) string {
func gqlHash(b string, vars []byte, role string) string {
b = strings.TrimSpace(b)
h := sha1.New()
@ -56,6 +56,10 @@ func gqlHash(b string, vars []byte) string {
}
}
if len(role) != 0 {
io.WriteString(h, role)
}
if vars == nil || len(vars) == 0 {
return hex.EncodeToString(h.Sum(nil))
}
@ -80,3 +84,26 @@ func ws(b byte) bool {
func al(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
}
func isMutation(sql string) bool {
for i := range sql {
b := sql[i]
if b == '{' {
return false
}
if al(b) {
return (b == 'm' || b == 'M')
}
}
return false
}
func findStmt(role string, stmts []stmt) *stmt {
for i := range stmts {
if stmts[i].role.Name != role {
continue
}
return &stmts[i]
}
return nil
}

View File

@ -11,17 +11,27 @@ import (
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) {
switch tag {
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
}
return 0, errNoUserID
case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil {
return stringVar(w, v.(string))
}
return 0, errNoUserID
io.WriteString(w, "null")
return 0, nil
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
case "user_role":
if v := ctx.Value(userRoleKey); v != nil {
return stringVar(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
}
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})