Add support for prepared statements

This commit is contained in:
Vikram Rangnekar 2019-07-29 01:13:33 -04:00
parent 4c07ad1102
commit 2d8fc2b7e2
19 changed files with 493 additions and 63 deletions

View File

@ -9,7 +9,7 @@ RUN yarn build
FROM golang:1.12-alpine as go-build
RUN apk update && \
apk add --no-cache git && \
apk add --no-cache upx=3.95-r1
apk add --no-cache upx=3.95-r2
RUN go get -u github.com/dosco/esc && \
go get -u github.com/pilu/fresh

38
config/allow.list Normal file
View File

@ -0,0 +1,38 @@
# http://localhost:808
query {
me {
id
full_name
}
}
query {
customers {
id
email
payments {
customer_id
amount
billing_details
}
}
}
# http://localhost:8080/
query {
products(id: $PRODUCT_ID) {
name
}
}
query {
products(id: $PRODUCT_ID) {
name
image
}
}

View File

@ -5,13 +5,21 @@ web_ui: true
# debug, info, warn, error, fatal, panic
log_level: "debug"
# enabling tracing also disables query caching
enable_tracing: true
# Disable this in development to get a list of
# queries used. When enabled super graph
# will only allow queries from this list
# List saved to ./config/allow.list
use_allow_list: false
# Throw a 401 on auth failure for queries that need auth
# valid values: always, per_query, never
auth_fail_block: never
# Latency tracing for database queries and remote joins
# the resulting latency information is returned with the
# response
enable_tracing: true
# Postgres related environment Variables
# SG_DATABASE_HOST
# SG_DATABASE_PORT

View File

@ -5,13 +5,21 @@ web_ui: false
# debug, info, warn, error, fatal, panic, disable
log_level: "info"
# disabled tracing enables query caching
enable_tracing: false
# Disable this in development to get a list of
# queries used. When enabled super graph
# will only allow queries from this list
# List saved to ./config/allow.list
use_allow_list: true
# Throw a 401 on auth failure for queries that need auth
# valid values: always, per_query, never
auth_fail_block: always
# Latency tracing for database queries and remote joins
# the resulting latency information is returned with the
# response
enable_tracing: true
# Postgres related environment Variables
# SG_DATABASE_HOST
# SG_DATABASE_PORT

View File

@ -5,7 +5,7 @@ module.exports = {
themeConfig: {
logo: '/logo.svg',
nav: [
{ text: 'Guide', link: '/guide' },
{ text: 'Docs', link: '/guide' },
{ text: 'Install', link: '/install' },
{ text: 'Github', link: 'https://github.com/dosco/super-graph' },
{ text: 'Docker', link: 'https://hub.docker.com/r/dosco/super-graph/builds' },

View File

@ -9,7 +9,7 @@ features:
- title: Simple
details: Easy config file, quick to deploy, No code needed. It just works.
- title: High Performance
details: Converts your GraphQL query into a fast SQL one.
details: Compiles your GraphQL into a fast SQL query in realtime.
- title: Written in GO
details: Go is a language created at Google to build secure and fast web services.
footer: MIT Licensed | Copyright © 2018-present Vikram Rangnekar

View File

@ -7,14 +7,15 @@ sidebar: auto
Without writing a line of code get an instant high-performance GraphQL API for your Ruby-on-Rails app. Super Graph will automatically understand your apps database and expose a secure, fast and complete GraphQL API for it. Built in support for Rails authentication and JWT tokens.
## Features
- Automatically learns Postgres schemas and relationships
- Supports Belongs-To, One-To-Many and Many-To-Many table relationships
- Works with Rails database schemas
- Automatically learns schemas and relationships
- Belongs-To, One-To-Many and Many-To-Many table relationships
- Full text search and Aggregations
- Full text search and aggregations
- Rails Auth supported (Redis, Memcache, Cookie)
- JWT tokens supported (Auth0, etc)
- Join with remote REST APIs
- Highly optimized and fast Postgres SQL queries
- Join database queries with remote data sources (APIs like Stripe, Twitter, etc)
- Generates highly optimized and fast Postgres SQL queries
- Uses prepared statements for very fast Postgres queries
- Configure with a simple config file
- High performance GO codebase
- Tiny docker image and low memory requirements
@ -451,8 +452,6 @@ auth:
```
#### Memcache session store
```yaml
@ -514,12 +513,23 @@ host_port: 0.0.0.0:8080
web_ui: true
debug_level: 1
# enabling tracing also disables query caching
enable_tracing: true
# debug, info, warn, error, fatal, panic, disable
log_level: "info"
# Disable this in development to get a list of
# queries used. When enabled super graph
# will only allow queries from this list
# List saved to ./config/allow.list
use_allow_list: true
# Throw a 401 on auth failure for queries that need auth
# valid values: always, per_query, never
auth_fail_block: never
auth_fail_block: always
# Latency tracing for database queries and remote joins
# the resulting latency information is returned with the
# response
enable_tracing: true
# Postgres related environment Variables
# SG_DATABASE_HOST
@ -674,7 +684,7 @@ brew install yarn
go generate ./...
# do this the only the time to setup the database
docker-compose run rails_app rake db:create db:migrate
docker-compose run rails_app rake db:create db:migrate db:seed
# start super graph in development mode with a change watcher
docker-compose up

View File

@ -2,7 +2,7 @@ root: .
tmp_path: ./tmp
build_name: runner-build
build_log: runner-build-errors.log
valid_ext: .go, .tpl, .tmpl, .html, .yml
valid_ext: .go, .tpl, .tmpl, .html, .yml, *.list
no_rebuild_ext: .tpl, .tmpl, .html
ignored: web, tmp, vendor, rails-app, docs
build_delay: 600

View File

@ -434,7 +434,7 @@ func queryWithVariables(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("product"."price") = ('{{PRODUCT_PRICE}}')) AND (("id") = ('{{PRODUCT_ID}}'))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("product"."price") = ('{{product_price}}')) AND (("id") = ('{{product_id}}'))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {

View File

@ -248,6 +248,8 @@ func lexRoot(l *lexer) stateFn {
case r == '$':
l.ignore()
if l.acceptAlphaNum() {
s, e := l.current()
lowercase(l.input, s, e)
l.emit(itemVariable)
}
case contains(l.input, l.start, l.pos, punctuatorToken):

132
serv/allow.go Normal file
View File

@ -0,0 +1,132 @@
package serv
import (
"fmt"
"io/ioutil"
"log"
"os"
"sort"
"strings"
)
type allowItem struct {
uri string
gql string
}
var _allowList allowList
type allowList struct {
list map[string]*allowItem
saveChan chan *allowItem
}
func initAllowList() {
_allowList = allowList{
list: make(map[string]*allowItem),
saveChan: make(chan *allowItem),
}
_allowList.load()
go func() {
for v := range _allowList.saveChan {
_allowList.save(v)
}
}()
}
func (al *allowList) add(req *gqlReq) {
if len(req.ref) == 0 || len(req.Query) == 0 {
return
}
al.saveChan <- &allowItem{
uri: req.ref,
gql: req.Query,
}
}
func (al *allowList) load() {
filename := "./config/allow.list"
if _, err := os.Stat(filename); os.IsNotExist(err) {
return
}
b, err := ioutil.ReadFile(filename)
if err != nil {
log.Fatal(err)
}
if len(b) == 0 {
return
}
var uri string
s, e, c := 0, 0, 0
for {
if c == 0 && b[e] == '#' {
s = e
for ; b[e] != '\n' && e < len(b); e++ {
if (e - s) > 2 {
uri = strings.TrimSpace(string(b[s+1 : e]))
}
}
}
if b[e] == '{' {
if c == 0 {
s = e
}
c++
} else if b[e] == '}' {
c--
if c == 0 {
q := b[s:(e + 1)]
al.list[relaxHash(q)] = &allowItem{
uri: uri,
gql: string(q),
}
}
}
e++
if e >= len(b) {
break
}
}
}
func (al *allowList) save(item *allowItem) {
al.list[relaxHash([]byte(item.gql))] = item
f, err := os.Create("./config/allow.list")
if err != nil {
panic(err)
}
defer f.Close()
keys := []string{}
urlMap := make(map[string][]string)
for _, v := range al.list {
urlMap[v.uri] = append(urlMap[v.uri], v.gql)
}
for k := range urlMap {
keys = append(keys, k)
}
sort.Strings(keys)
for i := range keys {
k := keys[i]
v := urlMap[k]
f.WriteString(fmt.Sprintf("# %s\n\n", k))
for i := range v {
f.WriteString(fmt.Sprintf("query %s\n\n", v[i]))
}
}
}

View File

@ -13,6 +13,7 @@ type config struct {
WebUI bool `mapstructure:"web_ui"`
LogLevel string `mapstructure:"log_level"`
EnableTracing bool `mapstructure:"enable_tracing"`
UseAllowList bool `mapstructure:"use_allow_list"`
AuthFailBlock string `mapstructure:"auth_fail_block"`
Inflections map[string]string
@ -53,7 +54,7 @@ type config struct {
MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"`
Variables map[string]string
vars map[string][]byte `mapstructure:"variables"`
Defaults struct {
Filter []string
@ -86,6 +87,26 @@ type configRemote struct {
} `mapstructure:"set_headers"`
}
func (c *config) getVariables() map[string]string {
vars := make(map[string]string, len(c.DB.vars))
for k, v := range c.DB.vars {
isVar := false
for i := range v {
if v[i] == '$' {
isVar = true
} else if v[i] == ' ' {
isVar = false
} else if isVar && v[i] >= 'a' && v[i] <= 'z' {
v[i] = 'A' + (v[i] - 'a')
}
}
vars[k] = string(v)
}
return vars
}
func (c *config) getAliasMap() map[string][]string {
m := make(map[string][]string, len(c.DB.Tables))

View File

@ -23,10 +23,6 @@ const (
empty = ""
)
// var (
// cache, _ = bigcache.NewBigCache(bigcache.DefaultConfig(24 * time.Hour))
// )
type coreContext struct {
req gqlReq
res gqlResp
@ -35,17 +31,36 @@ type coreContext struct {
func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
var err error
var skipped uint32
var qc *qcode.QCode
var data []byte
qc, err := qcompile.CompileQuery([]byte(c.req.Query))
if err != nil {
return err
}
c.req.ref = req.Referer()
vars := varMap(c)
//conf.UseAllowList = true
data, skipped, err := c.resolveSQL(qc, vars)
if err != nil {
return err
if conf.UseAllowList {
var ps *preparedItem
data, ps, err = c.resolvePreparedSQL([]byte(c.req.Query))
if err != nil {
return err
}
skipped = ps.skipped
qc = ps.qc
} else {
qc, err = qcompile.CompileQuery([]byte(c.req.Query))
if err != nil {
return err
}
data, skipped, err = c.resolveSQL(qc)
if err != nil {
return err
}
}
if len(data) == 0 || skipped == 0 {
@ -237,28 +252,28 @@ func (c *coreContext) resolveRemotes(
return to, cerr
}
func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) (
func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, error) {
ps, ok := _preparedList[relaxHash(gql)]
if !ok {
return nil, nil, errUnauthorized
}
var root json.RawMessage
vars := varList(c, ps.args)
_, err := ps.stmt.QueryOne(pg.Scan(&root), vars...)
if err != nil {
return nil, nil, err
}
fmt.Printf("PRE: %#v %#v\n", ps.stmt, vars)
return []byte(root), ps, nil
}
func (c *coreContext) resolveSQL(qc *qcode.QCode) (
[]byte, uint32, error) {
// var entry []byte
// var key string
// cacheEnabled := (conf.EnableTracing == false)
// if cacheEnabled {
// k := sha1.Sum([]byte(req.Query))
// key = string(k[:])
// entry, err = cache.Get(key)
// if err != nil && err != bigcache.ErrEntryNotFound {
// return emtpy, err
// }
// if len(entry) != 0 && err == nil {
// return entry, nil
// }
// }
stmt := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, stmt)
@ -269,7 +284,7 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) (
t := fasttemplate.New(stmt.String(), openVar, closeVar)
stmt.Reset()
_, err = t.Execute(stmt, vars)
_, err = t.Execute(stmt, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
@ -287,20 +302,16 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) (
os.Stdout.WriteString(finalSQL)
}
// if cacheEnabled {
// if err = cache.Set(key, finalSQL); err != nil {
// return err
// }
// }
var st time.Time
if conf.EnableTracing {
st = time.Now()
}
fmt.Printf("RAW: %#v\n", finalSQL)
var root json.RawMessage
_, err = db.Query(pg.Scan(&root), finalSQL)
_, err = db.QueryOne(pg.Scan(&root), finalSQL)
if err != nil {
return nil, 0, err
@ -313,6 +324,10 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) (
st)
}
if conf.UseAllowList == false {
_allowList.add(&c.req)
}
return []byte(root), skipped, nil
}

View File

@ -29,6 +29,7 @@ type gqlReq struct {
OpName string `json:"operationName"`
Query string `json:"query"`
Vars variables `json:"variables"`
ref string
}
type variables map[string]interface{}

79
serv/prepare.go Normal file
View File

@ -0,0 +1,79 @@
package serv
import (
"bytes"
"fmt"
"io"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/valyala/fasttemplate"
)
type preparedItem struct {
stmt *pg.Stmt
args []string
skipped uint32
qc *qcode.QCode
}
var (
_preparedList map[string]*preparedItem
)
func initPreparedList() {
_preparedList = make(map[string]*preparedItem)
for k, v := range _allowList.list {
err := prepareStmt(k, v.gql)
if err != nil {
panic(err)
}
}
}
func prepareStmt(key, gql string) error {
if len(gql) == 0 || len(key) == 0 {
return nil
}
qc, err := qcompile.CompileQuery([]byte(gql))
if err != nil {
return err
}
buf := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, buf)
if err != nil {
return err
}
t := fasttemplate.New(buf.String(), `('{{`, `}}')`)
am := make([]string, 0, 5)
i := 0
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
am = append(am, tag)
i++
return w.Write([]byte(fmt.Sprintf("$%d", i)))
})
if err != nil {
return err
}
pstmt, err := db.Prepare(finalSQL)
if err != nil {
return err
}
_preparedList[key] = &preparedItem{
stmt: pstmt,
args: am,
skipped: skipped,
qc: qc,
}
return nil
}

View File

@ -170,7 +170,7 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
pc := psql.NewCompiler(psql.Config{
Schema: schema,
Vars: c.DB.Variables,
Vars: c.getVariables(),
})
return qc, pc, nil
@ -206,6 +206,9 @@ func Init() {
logger.Fatal().Err(err).Msg("failed to initialized resolvers")
}
initAllowList()
initPreparedList()
startHTTP()
}

View File

@ -1,6 +1,12 @@
package serv
import "github.com/cespare/xxhash/v2"
import (
"bytes"
"crypto/sha1"
"encoding/hex"
"github.com/cespare/xxhash/v2"
)
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
h.WriteString(k1)
@ -10,3 +16,33 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
return v
}
func relaxHash(b []byte) string {
h := sha1.New()
s, e := 0, 0
for {
if e == (len(b) - 1) {
if s != 0 {
e++
h.Write(bytes.ToLower(b[s:e]))
}
break
} else if ws(b[e]) == false && ws(b[(e+1)]) {
e++
h.Write(bytes.ToLower(b[s:e]))
s = 0
} else if ws(b[e]) && ws(b[(e+1)]) == false {
e++
s = e
} else {
e++
}
}
return hex.EncodeToString(h.Sum(nil))
}
func ws(b byte) bool {
return b == ' ' || b == '\n' || b == '\t'
}

34
serv/utils_test.go Normal file
View File

@ -0,0 +1,34 @@
package serv
import (
"strings"
"testing"
)
func TestRelaxHash(t *testing.T) {
var v1 = []byte(`
products(
limit: 30,
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
}`)
var v2 = []byte(`
products
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
} `)
h1 := relaxHash(v1)
h2 := relaxHash(v2)
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}

View File

@ -3,6 +3,7 @@ package serv
import (
"io"
"strconv"
"strings"
"github.com/valyala/fasttemplate"
)
@ -34,9 +35,12 @@ func varMap(ctx *coreContext) variables {
for k, v := range ctx.req.Vars {
var buf []byte
k = strings.ToLower(k)
if _, ok := vm[k]; ok {
continue
}
switch val := v.(type) {
case string:
vm[k] = val
@ -50,3 +54,42 @@ func varMap(ctx *coreContext) variables {
}
return vm
}
func varList(ctx *coreContext, args []string) []interface{} {
vars := make([]interface{}, 0, len(args))
for k, v := range ctx.req.Vars {
ctx.req.Vars[strings.ToLower(k)] = v
}
for i := range args {
arg := strings.ToLower(args[i])
if arg == "user_id" {
if v := ctx.Value(userIDKey); v != nil {
vars = append(vars, v.(string))
}
}
if arg == "user_id_provider" {
if v := ctx.Value(userIDProviderKey); v != nil {
vars = append(vars, v.(string))
}
}
if v, ok := ctx.req.Vars[arg]; ok {
switch val := v.(type) {
case string:
vars = append(vars, val)
case int:
vars = append(vars, strconv.FormatInt(int64(val), 10))
case int64:
vars = append(vars, strconv.FormatInt(int64(val), 10))
case float64:
vars = append(vars, strconv.FormatFloat(val, 'f', -1, 64))
}
}
}
return vars
}