Add REST API stitching

This commit is contained in:
Vikram Rangnekar
2019-05-12 19:27:26 -04:00
parent 6c9accb628
commit f16e95ef22
40 changed files with 1127 additions and 479 deletions

114
serv/config.go Normal file
View File

@ -0,0 +1,114 @@
package serv
// config is the application's top-level configuration, decoded with
// mapstructure (a field's tag gives the config-file key where it
// differs from the lowercased field name).
type config struct {
	AppName       string `mapstructure:"app_name"`
	Env           string
	HostPort      string `mapstructure:"host_port"`
	WebUI         bool   `mapstructure:"web_ui"`
	DebugLevel    int    `mapstructure:"debug_level"`
	EnableTracing bool   `mapstructure:"enable_tracing"`
	AuthFailBlock string `mapstructure:"auth_fail_block"`
	Inflections   map[string]string

	// Auth configures request authentication; Cookie/Header name the
	// credential source, with Rails-session and JWT specific settings
	// in their sub-structs.
	Auth struct {
		Type   string
		Cookie string
		Header string

		Rails struct {
			Version       string
			SecretKeyBase string `mapstructure:"secret_key_base"`
			URL           string
			Password      string
			MaxIdle       int `mapstructure:"max_idle"`
			MaxActive     int `mapstructure:"max_active"`
			Salt          string
			SignSalt      string `mapstructure:"sign_salt"`
			AuthSalt      string `mapstructure:"auth_salt"`
		}

		JWT struct {
			Provider   string
			Secret     string
			PubKeyFile string `mapstructure:"public_key_file"`
			PubKeyType string `mapstructure:"public_key_type"`
		}
	}

	// DB holds database connection settings plus per-table
	// customization; it lives under the "database" key in the
	// config file.
	DB struct {
		Type       string
		Host       string
		Port       string
		DBName     string
		User       string
		Password   string
		Schema     string
		PoolSize   int    `mapstructure:"pool_size"`
		MaxRetries int    `mapstructure:"max_retries"`
		LogLevel   string `mapstructure:"log_level"`
		Variables  map[string]string

		// Defaults apply when a table has no per-table setting.
		Defaults struct {
			Filter    []string
			Blacklist []string
		}

		// Fields is the legacy key for Tables; it is used as a
		// fallback when Tables is empty (see initConf).
		Fields []configTable
		Tables []configTable
	} `mapstructure:"database"`
}
// configTable customizes how one database table is exposed: a
// per-table filter, a field blacklist, an optional alias (Table) for
// the real table name, and any remote REST joins (Remotes).
type configTable struct {
	Name      string
	Filter    []string
	Table     string
	Blacklist []string
	Remotes   []configRemote
}
// configRemote describes a REST endpoint whose response is stitched
// into the GraphQL result for rows of the owning table.
type configRemote struct {
	Name string
	// ID names the table column whose value identifies the row in the
	// remote request; when empty the table's primary key is used
	// (see initRemotes).
	ID string
	// Path is a dot-separated path selecting the subtree of the remote
	// JSON response to keep.
	Path string
	// URL is the remote endpoint; the literal "$id" is replaced with
	// the row id when the request is made.
	URL string
	// PassHeaders lists request headers copied from the incoming
	// client request onto the remote request.
	PassHeaders []string `mapstructure:"pass_headers"`
	// SetHeaders are static name/value headers added to every remote
	// request.
	SetHeaders []struct {
		Name  string
		Value string
	} `mapstructure:"set_headers"`
}
// getAliasMap builds a lookup from a table's configured name to the
// real database table it aliases. Tables without a Table value set
// are omitted.
func (c *config) getAliasMap() map[string]string {
	aliases := make(map[string]string, len(c.DB.Tables))

	for _, tbl := range c.DB.Tables {
		if len(tbl.Table) != 0 {
			aliases[tbl.Name] = tbl.Table
		}
	}
	return aliases
}
// getFilterMap builds the per-table filter lookup. A table whose first
// filter entry is the sentinel "none" maps to an empty (but non-nil)
// slice, disabling filtering for that table; tables with no filter
// configured are omitted entirely.
func (c *config) getFilterMap() map[string][]string {
	filters := make(map[string][]string, len(c.DB.Tables))

	for _, tbl := range c.DB.Tables {
		switch {
		case len(tbl.Filter) == 0:
			// no per-table filter configured; fall back to defaults
		case tbl.Filter[0] == "none":
			filters[tbl.Name] = []string{}
		default:
			filters[tbl.Name] = tbl.Filter
		}
	}
	return filters
}

View File

@ -1,100 +1,282 @@
package serv
import (
"bytes"
"context"
"crypto/sha1"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/allegro/bigcache"
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/valyala/fasttemplate"
)
var (
cache, _ = bigcache.NewBigCache(bigcache.DefaultConfig(24 * time.Hour))
const (
empty = ""
)
func handleReq(ctx context.Context, w io.Writer, req *gqlReq) error {
var key, finalSQL string
var qc *qcode.QCode
// var (
// cache, _ = bigcache.NewBigCache(bigcache.DefaultConfig(24 * time.Hour))
// )
var entry []byte
// coreContext carries one GraphQL request/response pair through the
// request-handling pipeline. It embeds the HTTP request's context so
// auth values (e.g. userIDKey) and cancellation flow through.
type coreContext struct {
	req gqlReq
	res gqlResp
	context.Context
}
func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
var err error
cacheEnabled := (conf.EnableTracing == false)
//cacheEnabled := (conf.EnableTracing == false)
if cacheEnabled {
k := sha1.Sum([]byte(req.Query))
key = string(k[:])
entry, err = cache.Get(key)
}
if len(entry) == 0 || err == bigcache.ErrEntryNotFound {
qc, err = qcompile.CompileQuery(req.Query)
if err != nil {
return err
}
var sqlStmt strings.Builder
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
return err
}
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
sqlStmt.Reset()
_, err = t.Execute(&sqlStmt, varMap(ctx, req.Vars))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(ctx) == false {
return errUnauthorized
}
if err != nil {
return err
}
finalSQL = sqlStmt.String()
} else if err != nil {
qc, err := qcompile.CompileQuery(c.req.Query)
if err != nil {
return err
} else {
finalSQL = string(entry)
}
vars := varMap(c)
data, skipped, err := c.resolveSQL(qc, vars)
if err != nil {
return err
}
if len(data) == 0 || skipped == 0 {
return c.render(w, data)
}
sel := qc.Query.Selects
h := xxhash.New()
// fetch the field name used within the db response json
// that are used to mark insertion points and the mapping between
// those field names and their select objects
fids, sfmap := parentFieldIds(h, sel, skipped)
// fetch the field values of the marked insertion points
// these values contain the id to be used with fetching remote data
from := jsn.Get(data, fids)
// replacement data for the marked insertion points
// key and value will be replaced by whats below
to := make([]jsn.Field, 0, len(from))
for _, id := range from {
// use the json key to find the related Select object
k1 := xxhash.Sum64(id.Key)
s, ok := sfmap[k1]
if !ok {
continue
}
p := sel[s.ParentID]
// then use the Table name in the Select and its parent
// to find the resolver to use for this relationship
k2 := mkkey(h, s.Table, p.Table)
r, ok := rmap[k2]
if !ok {
continue
}
id := jsn.Value(id.Value)
if len(id) == 0 {
continue
}
b, err := r.Fn(req, id)
if err != nil {
return err
}
if len(r.Path) != 0 {
b = jsn.Strip(b, r.Path)
}
fils := []string{}
for i := range s.Cols {
fils = append(fils, s.Cols[i].Name)
}
var ob bytes.Buffer
if err = jsn.Filter(&ob, b, fils); err != nil {
return err
}
f := jsn.Field{[]byte(s.FieldName), ob.Bytes()}
to = append(to, f)
}
var ob bytes.Buffer
err = jsn.Replace(&ob, data, from, to)
if err != nil {
return err
}
// if cacheEnabled {
// if err = cache.Set(key, []byte(finalSQL)); err != nil {
// return err
// }
// }
return c.render(w, ob.Bytes())
}
func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) (
[]byte, uint32, error) {
//var entry []byte
//var key string
//cacheEnabled := (conf.EnableTracing == false)
// if cacheEnabled {
// k := sha1.Sum([]byte(req.Query))
// key = string(k[:])
// entry, err = cache.Get(key)
// if err != nil && err != bigcache.ErrEntryNotFound {
// return empty, err
// }
// if len(entry) != 0 && err == nil {
// return entry, nil
// }
// }
skipped, stmts, err := pcompile.Compile(qc)
if err != nil {
return nil, 0, err
}
t := fasttemplate.New(stmts[0], openVar, closeVar)
var sqlStmt strings.Builder
_, err = t.Execute(&sqlStmt, vars)
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(c) == false {
return nil, 0, errUnauthorized
}
if err != nil {
return nil, 0, err
}
finalSQL := sqlStmt.String()
if conf.DebugLevel > 0 {
fmt.Println(finalSQL)
}
st := time.Now()
var root json.RawMessage
_, err = db.Query(pg.Scan(&root), finalSQL)
if err != nil {
return err
}
et := time.Now()
resp := gqlResp{Data: json.RawMessage(root)}
if cacheEnabled {
if err = cache.Set(key, []byte(finalSQL)); err != nil {
return err
}
return nil, 0, err
}
if conf.EnableTracing {
resp.Extensions = &extensions{newTrace(st, et, qc)}
c.res.Extensions = &extensions{newTrace(st, time.Now(), qc)}
}
json.NewEncoder(w).Encode(resp)
return nil
return []byte(root), skipped, nil
}
// render writes data to w as the JSON-encoded GraphQL response body,
// using the context's gqlResp envelope so any tracing extensions set
// earlier are included.
func (c *coreContext) render(w io.Writer, data []byte) error {
	c.res.Data = json.RawMessage(data)
	return json.NewEncoder(w).Encode(c.res)
}
// parentFieldIds collects, for every select that was skipped by the SQL
// compiler (left to be resolved remotely), the synthetic field name that
// marks its insertion point in the db JSON response, plus a mapping from
// the xxhash of that field name back to the owning Select.
func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
	[][]byte,
	map[uint64]*qcode.Select) {

	// count the skipped selects so both containers can be pre-sized
	c := 0
	for i := range sel {
		s := &sel[i]
		if isSkipped(skipped, s.ID) {
			c++
		}
	}

	// list of keys (and its related value) to extract from
	// the db json response
	fm := make([][]byte, c)

	// mapping between the above extracted key and a Select
	// object
	sm := make(map[uint64]*qcode.Select, c)
	n := 0

	for i := range sel {
		s := &sel[i]

		if isSkipped(skipped, s.ID) == false {
			continue
		}

		p := sel[s.ParentID]
		k := mkkey(h, s.Table, p.Table)

		// only selects that have a registered remote resolver
		// contribute an insertion point
		if r, ok := rmap[k]; ok {
			fm[n] = r.IDField
			n++

			k := xxhash.Sum64(r.IDField)
			sm[k] = s
		}
	}

	return fm, sm
}
// isSkipped reports whether bit pos is set in the skipped-select
// bitmask n.
func isSkipped(n uint32, pos uint16) bool {
	mask := uint32(1) << pos
	return n&mask != 0
}
// authCheck reports whether a user id is present on the request
// context (stored under userIDKey; presumably placed there by the
// auth middleware — see withAuth).
func authCheck(ctx *coreContext) bool {
	return (ctx.Value(userIDKey) != nil)
}
// newTrace builds a tracing extension payload (start/end/duration plus
// one resolver entry for the query's root select). It returns nil when
// the query has no selects to report on.
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
	if len(qc.Query.Selects) == 0 {
		return nil
	}

	// Bug fix: the duration was computed as et.Sub(et), which is always
	// zero; the trace must span from start (st) to end (et).
	du := et.Sub(st)
	sel := qc.Query.Selects[0]

	t := &trace{
		Version:   1,
		StartTime: st,
		EndTime:   et,
		Duration:  du,
		Execution: execution{
			[]resolver{
				resolver{
					Path:        []string{sel.Table},
					ParentType:  "Query",
					FieldName:   sel.Table,
					ReturnType:  "object",
					StartOffset: 1,
					Duration:    du,
				},
			},
		},
	}
	return t
}

View File

@ -65,7 +65,7 @@ type resolver struct {
}
func apiv1Http(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx := &coreContext{Context: r.Context()}
if authFailBlock == authFailBlockAlways && authCheck(ctx) == false {
http.Error(w, "Not authorized", 401)
@ -79,13 +79,12 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
return
}
req := &gqlReq{}
if err := json.Unmarshal(b, req); err != nil {
if err := json.Unmarshal(b, &ctx.req); err != nil {
errorResp(w, err)
return
}
if strings.EqualFold(req.OpName, introspectionQuery) {
if strings.EqualFold(ctx.req.OpName, introspectionQuery) {
dat, err := ioutil.ReadFile("test.schema")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -95,7 +94,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
return
}
err = handleReq(ctx, w, req)
err = ctx.handleReq(w, r)
if err == errUnauthorized {
http.Error(w, "Not authorized", 401)
@ -105,3 +104,12 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
errorResp(w, err)
}
}
// errorResp reports err back to the client as a 400 response with a
// JSON gqlResp body. When debugging is enabled the error is also
// logged server-side.
func errorResp(w http.ResponseWriter, err error) {
	if conf.DebugLevel > 0 {
		logger.Error(err.Error())
	}
	w.WriteHeader(http.StatusBadRequest)
	json.NewEncoder(w).Encode(gqlResp{Error: err.Error()})
}

114
serv/reso.go Normal file
View File

@ -0,0 +1,114 @@
package serv
import (
	"fmt"
	"io/ioutil"
	"net/http"
	"strings"
	"time"

	"github.com/cespare/xxhash/v2"
	"github.com/dosco/super-graph/psql"
)
var (
	// rmap indexes remote resolvers two ways: by mkkey(remoteName,
	// tableName) and by the xxhash of the resolver's IDField
	// (see initRemotes).
	rmap map[uint64]*resolvFn
)

// resolvFn describes how to fetch remote data for one configured
// remote relationship.
type resolvFn struct {
	// IDField is the synthetic "__<table>_<idcol>" field name used to
	// mark insertion points in the db response.
	IDField []byte
	// Path is the split form of configRemote.Path, used to strip the
	// remote response down to the wanted subtree.
	Path [][]byte
	// Fn performs the remote HTTP request for a given row id.
	Fn func(r *http.Request, id []byte) ([]byte, error)
}
// initResolvers (re)builds the global rmap of remote resolvers from
// the tables configured under database.tables.
func initResolvers() {
	rmap = make(map[uint64]*resolvFn)

	for _, t := range conf.DB.Tables {
		initRemotes(t)
	}
}
// initRemotes registers, for every remote configured on table t, a
// psql relationship (so the SQL compiler knows about the remote join)
// and a resolvFn in rmap (so the request handler can fetch and stitch
// the remote data).
func initRemotes(t configTable) {
	h := xxhash.New()

	for _, r := range t.Remotes {
		// defines the table column to be used as an id in the
		// remote request
		idcol := r.ID

		// if no table column specified in the config then
		// use the primary key of the table as the id
		if len(idcol) == 0 {
			idcol = pcompile.IDColumn(t.Name)
		}

		// synthetic field name marking insertion points in the
		// db response json
		idk := fmt.Sprintf("__%s_%s", t.Name, idcol)

		// register a relationship between the remote data
		// and the database table
		key := psql.TTKey{strings.ToLower(r.Name), t.Name}
		val := &psql.DBRel{
			Type: psql.RelRemote,
			Col1: idcol,
			Col2: idk,
		}
		pcompile.AddRelationship(key, val)

		// the function thats called to resolve this remote
		// data request
		fn := buildFn(r)

		path := [][]byte{}
		for _, p := range strings.Split(r.Path, ".") {
			path = append(path, []byte(p))
		}

		rf := &resolvFn{
			IDField: []byte(idk),
			Path:    path,
			Fn:      fn,
		}

		// index resolver obj by parent and child names
		rmap[mkkey(h, r.Name, t.Name)] = rf

		// index resolver obj by IDField
		rmap[xxhash.Sum64(rf.IDField)] = rf
	}
}
// buildFn returns the resolver function used to fetch remote data for
// the remote relationship r. The returned function issues a GET to
// r.URL with "$id" substituted by the row id, applies the static
// headers from r.SetHeaders, and forwards the headers listed in
// r.PassHeaders from the incoming client request.
func buildFn(r configRemote) func(*http.Request, []byte) ([]byte, error) {
	reqURL := strings.Replace(r.URL, "$id", "%s", 1)

	// A timeout keeps a slow or hung remote from stalling request
	// handling indefinitely.
	client := &http.Client{Timeout: 10 * time.Second}

	// Static headers shared by every outgoing request; built once.
	base := make(http.Header, len(r.SetHeaders))
	for _, v := range r.SetHeaders {
		base.Set(v.Name, v.Value)
	}

	fn := func(inReq *http.Request, id []byte) ([]byte, error) {
		req, err := http.NewRequest("GET", fmt.Sprintf(reqURL, id), nil)
		if err != nil {
			return nil, err
		}

		// Bug fix: the previous code mutated one shared http.Header
		// for every call, which is a data race under concurrent
		// requests and leaks pass-through header values from one
		// request into the next. Copy the static headers per call
		// instead.
		h := make(http.Header, len(base)+len(r.PassHeaders))
		for k, vals := range base {
			h[k] = vals
		}
		for _, v := range r.PassHeaders {
			h.Set(v, inReq.Header.Get(v))
		}
		req.Header = h

		res, err := client.Do(req)
		if err != nil {
			return nil, err
		}
		defer res.Body.Close()

		b, err := ioutil.ReadAll(res.Body)
		if err != nil {
			return nil, err
		}

		return b, nil
	}

	return fn
}

View File

@ -1,13 +1,16 @@
package serv
import (
"context"
"errors"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"strings"
"time"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
@ -20,6 +23,8 @@ import (
//go:generate esc -o static.go -ignore \\.DS_Store -prefix ../web/build -private -pkg serv ../web/build
const (
serverName = "Super Graph"
authFailBlockAlways = iota + 1
authFailBlockPerQuery
authFailBlockNever
@ -29,74 +34,11 @@ var (
logger *logrus.Logger
conf *config
db *pg.DB
pcompile *psql.Compiler
qcompile *qcode.Compiler
pcompile *psql.Compiler
authFailBlock int
)
type config struct {
AppName string `mapstructure:"app_name"`
Env string
HostPort string `mapstructure:"host_port"`
WebUI bool `mapstructure:"web_ui"`
DebugLevel int `mapstructure:"debug_level"`
EnableTracing bool `mapstructure:"enable_tracing"`
AuthFailBlock string `mapstructure:"auth_fail_block"`
Inflections map[string]string
Auth struct {
Type string
Cookie string
Header string
Rails struct {
Version string
SecretKeyBase string `mapstructure:"secret_key_base"`
URL string
Password string
MaxIdle int `mapstructure:"max_idle"`
MaxActive int `mapstructure:"max_active"`
Salt string
SignSalt string `mapstructure:"sign_salt"`
AuthSalt string `mapstructure:"auth_salt"`
}
JWT struct {
Provider string
Secret string
PubKeyFile string `mapstructure:"public_key_file"`
PubKeyType string `mapstructure:"public_key_type"`
}
}
DB struct {
Type string
Host string
Port string
DBName string
User string
Password string
Schema string
PoolSize int `mapstructure:"pool_size"`
MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"`
Variables map[string]string
Defaults struct {
Filter []string
Blacklist []string
}
Fields []struct {
Name string
Filter []string
Table string
Blacklist []string
}
} `mapstructure:"database"`
}
func initLog() *logrus.Logger {
log := logrus.New()
log.Formatter = new(logrus.TextFormatter)
@ -153,6 +95,15 @@ func initConf() (*config, error) {
flect.AddPlural(k, v)
}
if len(c.DB.Tables) == 0 {
c.DB.Tables = c.DB.Fields
}
for i := range c.DB.Tables {
t := c.DB.Tables[i]
t.Name = flect.Pluralize(strings.ToLower(t.Name))
}
authFailBlock = getAuthFailBlock(c)
//fmt.Printf("%#v", c)
@ -196,50 +147,31 @@ func initDB(c *config) (*pg.DB, error) {
}
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
cdb := c.DB
fm := make(map[string][]string, len(cdb.Fields))
tmap := make(map[string]string, len(cdb.Fields))
for i := range cdb.Fields {
f := cdb.Fields[i]
name := flect.Pluralize(strings.ToLower(f.Name))
if len(f.Filter) != 0 {
if f.Filter[0] == "none" {
fm[name] = []string{}
} else {
fm[name] = f.Filter
}
}
if len(f.Table) != 0 {
tmap[name] = f.Table
}
}
qc, err := qcode.NewCompiler(qcode.Config{
Filter: cdb.Defaults.Filter,
FilterMap: fm,
Blacklist: cdb.Defaults.Blacklist,
})
schema, err := psql.NewDBSchema(db)
if err != nil {
return nil, nil, err
}
schema, err := psql.NewDBSchema(db)
qc, err := qcode.NewCompiler(qcode.Config{
DefaultFilter: c.DB.Defaults.Filter,
FilterMap: c.getFilterMap(),
Blacklist: c.DB.Defaults.Blacklist,
})
if err != nil {
return nil, nil, err
}
pc := psql.NewCompiler(psql.Config{
Schema: schema,
Vars: cdb.Variables,
TableMap: tmap,
Vars: c.DB.Variables,
TableMap: c.getAliasMap(),
})
return qc, pc, nil
}
func InitAndListen() {
func Init() {
var err error
logger = initLog()
@ -259,16 +191,61 @@ func InitAndListen() {
log.Fatal(err)
}
http.HandleFunc("/api/v1/graphql", withAuth(apiv1Http))
initResolvers()
if conf.WebUI {
http.Handle("/", http.FileServer(_escFS(false)))
startHTTP()
}
func startHTTP() {
srv := &http.Server{
Addr: conf.HostPort,
Handler: routeHandler(),
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
}
fmt.Printf("Super-Graph listening on %s (%s)\n",
conf.HostPort, conf.Env)
idleConnsClosed := make(chan struct{})
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt)
<-sigint
logger.Fatal(http.ListenAndServe(conf.HostPort, nil))
if err := srv.Shutdown(context.Background()); err != nil {
log.Printf("http: %v", err)
}
close(idleConnsClosed)
}()
srv.RegisterOnShutdown(func() {
if err := db.Close(); err != nil {
log.Println(err)
}
})
fmt.Printf("%s listening on %s (%s)\n", serverName, conf.HostPort, conf.Env)
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
fmt.Println(err)
}
<-idleConnsClosed
}
// routeHandler builds the HTTP routing mux: the GraphQL endpoint
// wrapped in the auth middleware, plus the embedded web UI when
// enabled. Every response carries a Server header naming this service.
func routeHandler() http.Handler {
	mux := http.NewServeMux()
	mux.Handle("/api/v1/graphql", withAuth(apiv1Http))

	if conf.WebUI {
		mux.Handle("/", http.FileServer(_escFS(false)))
	}

	fn := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Server", serverName)
		mux.ServeHTTP(w, r)
	}
	return http.HandlerFunc(fn)
}
func getConfigName() string {

View File

@ -1,44 +1,12 @@
package serv
import (
"context"
"encoding/json"
"net/http"
"time"
import "github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/qcode"
)
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
h.WriteString(k1)
h.WriteString(k2)
v := h.Sum64()
h.Reset()
func errorResp(w http.ResponseWriter, err error) {
b, _ := json.Marshal(gqlResp{Error: err.Error()})
http.Error(w, string(b), http.StatusBadRequest)
}
func authCheck(ctx context.Context) bool {
return (ctx.Value(userIDKey) != nil)
}
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
du := et.Sub(et)
t := &trace{
Version: 1,
StartTime: st,
EndTime: et,
Duration: du,
Execution: execution{
[]resolver{
resolver{
Path: []string{qc.Query.Select.Table},
ParentType: "Query",
FieldName: qc.Query.Select.Table,
ReturnType: "object",
StartOffset: 1,
Duration: du,
},
},
},
}
return t
return v
}

View File

@ -1,15 +1,13 @@
package serv
import (
"context"
"fmt"
"io"
"strconv"
"github.com/valyala/fasttemplate"
)
func varMap(ctx context.Context, vars variables) variables {
func varMap(ctx *coreContext) variables {
userIDFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string)))
@ -34,7 +32,8 @@ func varMap(ctx context.Context, vars variables) variables {
"USER_ID_PROVIDER": userIDProviderTag,
}
for k, v := range vars {
for k, v := range ctx.req.Vars {
var buf []byte
if _, ok := vm[k]; ok {
continue
}
@ -42,11 +41,11 @@ func varMap(ctx context.Context, vars variables) variables {
case string:
vm[k] = val
case int:
vm[k] = strconv.Itoa(val)
vm[k] = strconv.AppendInt(buf, int64(val), 10)
case int64:
vm[k] = strconv.FormatInt(val, 64)
vm[k] = strconv.AppendInt(buf, val, 10)
case float64:
vm[k] = fmt.Sprintf("%.0f", val)
vm[k] = strconv.AppendFloat(buf, val, 'f', -1, 64)
}
}
return vm