Add REST API stitching
This commit is contained in:
114
serv/config.go
Normal file
114
serv/config.go
Normal file
@ -0,0 +1,114 @@
|
||||
package serv
|
||||
|
||||
type config struct {
|
||||
AppName string `mapstructure:"app_name"`
|
||||
Env string
|
||||
HostPort string `mapstructure:"host_port"`
|
||||
WebUI bool `mapstructure:"web_ui"`
|
||||
DebugLevel int `mapstructure:"debug_level"`
|
||||
EnableTracing bool `mapstructure:"enable_tracing"`
|
||||
AuthFailBlock string `mapstructure:"auth_fail_block"`
|
||||
Inflections map[string]string
|
||||
|
||||
Auth struct {
|
||||
Type string
|
||||
Cookie string
|
||||
Header string
|
||||
|
||||
Rails struct {
|
||||
Version string
|
||||
SecretKeyBase string `mapstructure:"secret_key_base"`
|
||||
URL string
|
||||
Password string
|
||||
MaxIdle int `mapstructure:"max_idle"`
|
||||
MaxActive int `mapstructure:"max_active"`
|
||||
Salt string
|
||||
SignSalt string `mapstructure:"sign_salt"`
|
||||
AuthSalt string `mapstructure:"auth_salt"`
|
||||
}
|
||||
|
||||
JWT struct {
|
||||
Provider string
|
||||
Secret string
|
||||
PubKeyFile string `mapstructure:"public_key_file"`
|
||||
PubKeyType string `mapstructure:"public_key_type"`
|
||||
}
|
||||
}
|
||||
|
||||
DB struct {
|
||||
Type string
|
||||
Host string
|
||||
Port string
|
||||
DBName string
|
||||
User string
|
||||
Password string
|
||||
Schema string
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
LogLevel string `mapstructure:"log_level"`
|
||||
|
||||
Variables map[string]string
|
||||
|
||||
Defaults struct {
|
||||
Filter []string
|
||||
Blacklist []string
|
||||
}
|
||||
|
||||
Fields []configTable
|
||||
Tables []configTable
|
||||
} `mapstructure:"database"`
|
||||
}
|
||||
|
||||
type configTable struct {
|
||||
Name string
|
||||
Filter []string
|
||||
Table string
|
||||
Blacklist []string
|
||||
Remotes []configRemote
|
||||
}
|
||||
|
||||
type configRemote struct {
|
||||
Name string
|
||||
ID string
|
||||
Path string
|
||||
URL string
|
||||
PassHeaders []string `mapstructure:"pass_headers"`
|
||||
SetHeaders []struct {
|
||||
Name string
|
||||
Value string
|
||||
} `mapstructure:"set_headers"`
|
||||
}
|
||||
|
||||
func (c *config) getAliasMap() map[string]string {
|
||||
m := make(map[string]string, len(c.DB.Tables))
|
||||
|
||||
for i := range c.DB.Tables {
|
||||
t := c.DB.Tables[i]
|
||||
|
||||
if len(t.Table) == 0 {
|
||||
continue
|
||||
}
|
||||
m[t.Name] = t.Table
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (c *config) getFilterMap() map[string][]string {
|
||||
m := make(map[string][]string, len(c.DB.Tables))
|
||||
|
||||
for i := range c.DB.Tables {
|
||||
t := c.DB.Tables[i]
|
||||
|
||||
if len(t.Filter) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if t.Filter[0] == "none" {
|
||||
m[t.Name] = []string{}
|
||||
} else {
|
||||
m[t.Name] = t.Filter
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
304
serv/core.go
304
serv/core.go
@ -1,100 +1,282 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/allegro/bigcache"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/dosco/super-graph/jsn"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
var (
|
||||
cache, _ = bigcache.NewBigCache(bigcache.DefaultConfig(24 * time.Hour))
|
||||
const (
|
||||
empty = ""
|
||||
)
|
||||
|
||||
func handleReq(ctx context.Context, w io.Writer, req *gqlReq) error {
|
||||
var key, finalSQL string
|
||||
var qc *qcode.QCode
|
||||
// var (
|
||||
// cache, _ = bigcache.NewBigCache(bigcache.DefaultConfig(24 * time.Hour))
|
||||
// )
|
||||
|
||||
var entry []byte
|
||||
type coreContext struct {
|
||||
req gqlReq
|
||||
res gqlResp
|
||||
context.Context
|
||||
}
|
||||
|
||||
func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||
var err error
|
||||
|
||||
cacheEnabled := (conf.EnableTracing == false)
|
||||
//cacheEnabled := (conf.EnableTracing == false)
|
||||
|
||||
if cacheEnabled {
|
||||
k := sha1.Sum([]byte(req.Query))
|
||||
key = string(k[:])
|
||||
entry, err = cache.Get(key)
|
||||
}
|
||||
|
||||
if len(entry) == 0 || err == bigcache.ErrEntryNotFound {
|
||||
qc, err = qcompile.CompileQuery(req.Query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sqlStmt strings.Builder
|
||||
|
||||
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
|
||||
sqlStmt.Reset()
|
||||
|
||||
_, err = t.Execute(&sqlStmt, varMap(ctx, req.Vars))
|
||||
|
||||
if err == errNoUserID &&
|
||||
authFailBlock == authFailBlockPerQuery &&
|
||||
authCheck(ctx) == false {
|
||||
return errUnauthorized
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finalSQL = sqlStmt.String()
|
||||
|
||||
} else if err != nil {
|
||||
qc, err := qcompile.CompileQuery(c.req.Query)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
} else {
|
||||
finalSQL = string(entry)
|
||||
}
|
||||
|
||||
vars := varMap(c)
|
||||
|
||||
data, skipped, err := c.resolveSQL(qc, vars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(data) == 0 || skipped == 0 {
|
||||
return c.render(w, data)
|
||||
}
|
||||
|
||||
sel := qc.Query.Selects
|
||||
h := xxhash.New()
|
||||
|
||||
// fetch the field name used within the db response json
|
||||
// that are used to mark insertion points and the mapping between
|
||||
// those field names and their select objects
|
||||
fids, sfmap := parentFieldIds(h, sel, skipped)
|
||||
|
||||
// fetch the field values of the marked insertion points
|
||||
// these values contain the id to be used with fetching remote data
|
||||
from := jsn.Get(data, fids)
|
||||
|
||||
// replacement data for the marked insertion points
|
||||
// key and value will be replaced by whats below
|
||||
to := make([]jsn.Field, 0, len(from))
|
||||
|
||||
for _, id := range from {
|
||||
// use the json key to find the related Select object
|
||||
k1 := xxhash.Sum64(id.Key)
|
||||
|
||||
s, ok := sfmap[k1]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
p := sel[s.ParentID]
|
||||
|
||||
// then use the Table nme in the Select and it's parent
|
||||
// to find the resolver to use for this relationship
|
||||
k2 := mkkey(h, s.Table, p.Table)
|
||||
|
||||
r, ok := rmap[k2]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
id := jsn.Value(id.Value)
|
||||
if len(id) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
b, err := r.Fn(req, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(r.Path) != 0 {
|
||||
b = jsn.Strip(b, r.Path)
|
||||
}
|
||||
|
||||
fils := []string{}
|
||||
for i := range s.Cols {
|
||||
fils = append(fils, s.Cols[i].Name)
|
||||
}
|
||||
|
||||
var ob bytes.Buffer
|
||||
|
||||
if err = jsn.Filter(&ob, b, fils); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f := jsn.Field{[]byte(s.FieldName), ob.Bytes()}
|
||||
to = append(to, f)
|
||||
}
|
||||
|
||||
var ob bytes.Buffer
|
||||
|
||||
err = jsn.Replace(&ob, data, from, to)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if cacheEnabled {
|
||||
// if err = cache.Set(key, []byte(finalSQL)); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
return c.render(w, ob.Bytes())
|
||||
}
|
||||
|
||||
func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) (
|
||||
[]byte, uint32, error) {
|
||||
//var entry []byte
|
||||
//var key string
|
||||
|
||||
//cacheEnabled := (conf.EnableTracing == false)
|
||||
|
||||
// if cacheEnabled {
|
||||
// k := sha1.Sum([]byte(req.Query))
|
||||
// key = string(k[:])
|
||||
// entry, err = cache.Get(key)
|
||||
|
||||
// if err != nil && err != bigcache.ErrEntryNotFound {
|
||||
// return emtpy, err
|
||||
// }
|
||||
|
||||
// if len(entry) != 0 && err == nil {
|
||||
// return entry, nil
|
||||
// }
|
||||
// }
|
||||
|
||||
skipped, stmts, err := pcompile.Compile(qc)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
t := fasttemplate.New(stmts[0], openVar, closeVar)
|
||||
|
||||
var sqlStmt strings.Builder
|
||||
_, err = t.Execute(&sqlStmt, vars)
|
||||
|
||||
if err == errNoUserID &&
|
||||
authFailBlock == authFailBlockPerQuery &&
|
||||
authCheck(c) == false {
|
||||
return nil, 0, errUnauthorized
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
finalSQL := sqlStmt.String()
|
||||
|
||||
if conf.DebugLevel > 0 {
|
||||
fmt.Println(finalSQL)
|
||||
}
|
||||
|
||||
st := time.Now()
|
||||
|
||||
var root json.RawMessage
|
||||
_, err = db.Query(pg.Scan(&root), finalSQL)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
et := time.Now()
|
||||
resp := gqlResp{Data: json.RawMessage(root)}
|
||||
|
||||
if cacheEnabled {
|
||||
if err = cache.Set(key, []byte(finalSQL)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if conf.EnableTracing {
|
||||
resp.Extensions = &extensions{newTrace(st, et, qc)}
|
||||
c.res.Extensions = &extensions{newTrace(st, time.Now(), qc)}
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
return nil
|
||||
return []byte(root), skipped, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) render(w io.Writer, data []byte) error {
|
||||
c.res.Data = json.RawMessage(data)
|
||||
return json.NewEncoder(w).Encode(c.res)
|
||||
}
|
||||
|
||||
func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
|
||||
[][]byte,
|
||||
map[uint64]*qcode.Select) {
|
||||
|
||||
c := 0
|
||||
for i := range sel {
|
||||
s := &sel[i]
|
||||
if isSkipped(skipped, s.ID) {
|
||||
c++
|
||||
}
|
||||
}
|
||||
|
||||
// list of keys (and it's related value) to extract from
|
||||
// the db json response
|
||||
fm := make([][]byte, c)
|
||||
|
||||
// mapping between the above extracted key and a Select
|
||||
// object
|
||||
sm := make(map[uint64]*qcode.Select, c)
|
||||
n := 0
|
||||
|
||||
for i := range sel {
|
||||
s := &sel[i]
|
||||
|
||||
if isSkipped(skipped, s.ID) == false {
|
||||
continue
|
||||
}
|
||||
|
||||
p := sel[s.ParentID]
|
||||
k := mkkey(h, s.Table, p.Table)
|
||||
|
||||
if r, ok := rmap[k]; ok {
|
||||
fm[n] = r.IDField
|
||||
n++
|
||||
|
||||
k := xxhash.Sum64(r.IDField)
|
||||
sm[k] = s
|
||||
}
|
||||
}
|
||||
|
||||
return fm, sm
|
||||
}
|
||||
|
||||
func isSkipped(n uint32, pos uint16) bool {
|
||||
return ((n & (1 << pos)) != 0)
|
||||
}
|
||||
|
||||
func authCheck(ctx *coreContext) bool {
|
||||
return (ctx.Value(userIDKey) != nil)
|
||||
}
|
||||
|
||||
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
|
||||
if len(qc.Query.Selects) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
du := et.Sub(et)
|
||||
sel := qc.Query.Selects[0]
|
||||
|
||||
t := &trace{
|
||||
Version: 1,
|
||||
StartTime: st,
|
||||
EndTime: et,
|
||||
Duration: du,
|
||||
Execution: execution{
|
||||
[]resolver{
|
||||
resolver{
|
||||
Path: []string{sel.Table},
|
||||
ParentType: "Query",
|
||||
FieldName: sel.Table,
|
||||
ReturnType: "object",
|
||||
StartOffset: 1,
|
||||
Duration: du,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
18
serv/http.go
18
serv/http.go
@ -65,7 +65,7 @@ type resolver struct {
|
||||
}
|
||||
|
||||
func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx := &coreContext{Context: r.Context()}
|
||||
|
||||
if authFailBlock == authFailBlockAlways && authCheck(ctx) == false {
|
||||
http.Error(w, "Not authorized", 401)
|
||||
@ -79,13 +79,12 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
req := &gqlReq{}
|
||||
if err := json.Unmarshal(b, req); err != nil {
|
||||
if err := json.Unmarshal(b, &ctx.req); err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.EqualFold(req.OpName, introspectionQuery) {
|
||||
if strings.EqualFold(ctx.req.OpName, introspectionQuery) {
|
||||
dat, err := ioutil.ReadFile("test.schema")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
@ -95,7 +94,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = handleReq(ctx, w, req)
|
||||
err = ctx.handleReq(w, r)
|
||||
|
||||
if err == errUnauthorized {
|
||||
http.Error(w, "Not authorized", 401)
|
||||
@ -105,3 +104,12 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
errorResp(w, err)
|
||||
}
|
||||
}
|
||||
|
||||
func errorResp(w http.ResponseWriter, err error) {
|
||||
if conf.DebugLevel > 0 {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(gqlResp{Error: err.Error()})
|
||||
}
|
||||
|
114
serv/reso.go
Normal file
114
serv/reso.go
Normal file
@ -0,0 +1,114 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/dosco/super-graph/psql"
|
||||
)
|
||||
|
||||
var (
|
||||
rmap map[uint64]*resolvFn
|
||||
)
|
||||
|
||||
type resolvFn struct {
|
||||
IDField []byte
|
||||
Path [][]byte
|
||||
Fn func(r *http.Request, id []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
func initResolvers() {
|
||||
rmap = make(map[uint64]*resolvFn)
|
||||
|
||||
for _, t := range conf.DB.Tables {
|
||||
initRemotes(t)
|
||||
}
|
||||
}
|
||||
|
||||
func initRemotes(t configTable) {
|
||||
h := xxhash.New()
|
||||
|
||||
for _, r := range t.Remotes {
|
||||
// defines the table column to be used as an id in the
|
||||
// remote request
|
||||
idcol := r.ID
|
||||
|
||||
// if no table column specified in the config then
|
||||
// use the primary key of the table as the id
|
||||
if len(idcol) == 0 {
|
||||
idcol = pcompile.IDColumn(t.Name)
|
||||
}
|
||||
idk := fmt.Sprintf("__%s_%s", t.Name, idcol)
|
||||
|
||||
// register a relationship between the remote data
|
||||
// and the database table
|
||||
key := psql.TTKey{strings.ToLower(r.Name), t.Name}
|
||||
val := &psql.DBRel{
|
||||
Type: psql.RelRemote,
|
||||
Col1: idcol,
|
||||
Col2: idk,
|
||||
}
|
||||
pcompile.AddRelationship(key, val)
|
||||
|
||||
// the function thats called to resolve this remote
|
||||
// data request
|
||||
fn := buildFn(r)
|
||||
|
||||
path := [][]byte{}
|
||||
for _, p := range strings.Split(r.Path, ".") {
|
||||
path = append(path, []byte(p))
|
||||
}
|
||||
|
||||
rf := &resolvFn{
|
||||
IDField: []byte(idk),
|
||||
Path: path,
|
||||
Fn: fn,
|
||||
}
|
||||
|
||||
// index resolver obj by parent and child names
|
||||
rmap[mkkey(h, r.Name, t.Name)] = rf
|
||||
|
||||
// index resolver obj by IDField
|
||||
rmap[xxhash.Sum64(rf.IDField)] = rf
|
||||
}
|
||||
}
|
||||
|
||||
func buildFn(r configRemote) func(*http.Request, []byte) ([]byte, error) {
|
||||
reqURL := strings.Replace(r.URL, "$id", "%s", 1)
|
||||
client := &http.Client{}
|
||||
h := make(http.Header, len(r.PassHeaders))
|
||||
|
||||
for _, v := range r.SetHeaders {
|
||||
h.Set(v.Name, v.Value)
|
||||
}
|
||||
|
||||
fn := func(inReq *http.Request, id []byte) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf(reqURL, id), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, v := range r.PassHeaders {
|
||||
h.Set(v, inReq.Header.Get(v))
|
||||
}
|
||||
req.Header = h
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
return fn
|
||||
}
|
175
serv/serv.go
175
serv/serv.go
@ -1,13 +1,16 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
@ -20,6 +23,8 @@ import (
|
||||
//go:generate esc -o static.go -ignore \\.DS_Store -prefix ../web/build -private -pkg serv ../web/build
|
||||
|
||||
const (
|
||||
serverName = "Super Graph"
|
||||
|
||||
authFailBlockAlways = iota + 1
|
||||
authFailBlockPerQuery
|
||||
authFailBlockNever
|
||||
@ -29,74 +34,11 @@ var (
|
||||
logger *logrus.Logger
|
||||
conf *config
|
||||
db *pg.DB
|
||||
pcompile *psql.Compiler
|
||||
qcompile *qcode.Compiler
|
||||
pcompile *psql.Compiler
|
||||
authFailBlock int
|
||||
)
|
||||
|
||||
type config struct {
|
||||
AppName string `mapstructure:"app_name"`
|
||||
Env string
|
||||
HostPort string `mapstructure:"host_port"`
|
||||
WebUI bool `mapstructure:"web_ui"`
|
||||
DebugLevel int `mapstructure:"debug_level"`
|
||||
EnableTracing bool `mapstructure:"enable_tracing"`
|
||||
AuthFailBlock string `mapstructure:"auth_fail_block"`
|
||||
Inflections map[string]string
|
||||
|
||||
Auth struct {
|
||||
Type string
|
||||
Cookie string
|
||||
Header string
|
||||
|
||||
Rails struct {
|
||||
Version string
|
||||
SecretKeyBase string `mapstructure:"secret_key_base"`
|
||||
URL string
|
||||
Password string
|
||||
MaxIdle int `mapstructure:"max_idle"`
|
||||
MaxActive int `mapstructure:"max_active"`
|
||||
Salt string
|
||||
SignSalt string `mapstructure:"sign_salt"`
|
||||
AuthSalt string `mapstructure:"auth_salt"`
|
||||
}
|
||||
|
||||
JWT struct {
|
||||
Provider string
|
||||
Secret string
|
||||
PubKeyFile string `mapstructure:"public_key_file"`
|
||||
PubKeyType string `mapstructure:"public_key_type"`
|
||||
}
|
||||
}
|
||||
|
||||
DB struct {
|
||||
Type string
|
||||
Host string
|
||||
Port string
|
||||
DBName string
|
||||
User string
|
||||
Password string
|
||||
Schema string
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
LogLevel string `mapstructure:"log_level"`
|
||||
|
||||
Variables map[string]string
|
||||
|
||||
Defaults struct {
|
||||
Filter []string
|
||||
Blacklist []string
|
||||
}
|
||||
|
||||
Fields []struct {
|
||||
Name string
|
||||
Filter []string
|
||||
Table string
|
||||
Blacklist []string
|
||||
}
|
||||
} `mapstructure:"database"`
|
||||
}
|
||||
|
||||
func initLog() *logrus.Logger {
|
||||
log := logrus.New()
|
||||
log.Formatter = new(logrus.TextFormatter)
|
||||
@ -153,6 +95,15 @@ func initConf() (*config, error) {
|
||||
flect.AddPlural(k, v)
|
||||
}
|
||||
|
||||
if len(c.DB.Tables) == 0 {
|
||||
c.DB.Tables = c.DB.Fields
|
||||
}
|
||||
|
||||
for i := range c.DB.Tables {
|
||||
t := c.DB.Tables[i]
|
||||
t.Name = flect.Pluralize(strings.ToLower(t.Name))
|
||||
}
|
||||
|
||||
authFailBlock = getAuthFailBlock(c)
|
||||
|
||||
//fmt.Printf("%#v", c)
|
||||
@ -196,50 +147,31 @@ func initDB(c *config) (*pg.DB, error) {
|
||||
}
|
||||
|
||||
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
|
||||
cdb := c.DB
|
||||
|
||||
fm := make(map[string][]string, len(cdb.Fields))
|
||||
tmap := make(map[string]string, len(cdb.Fields))
|
||||
|
||||
for i := range cdb.Fields {
|
||||
f := cdb.Fields[i]
|
||||
name := flect.Pluralize(strings.ToLower(f.Name))
|
||||
if len(f.Filter) != 0 {
|
||||
if f.Filter[0] == "none" {
|
||||
fm[name] = []string{}
|
||||
} else {
|
||||
fm[name] = f.Filter
|
||||
}
|
||||
}
|
||||
if len(f.Table) != 0 {
|
||||
tmap[name] = f.Table
|
||||
}
|
||||
}
|
||||
|
||||
qc, err := qcode.NewCompiler(qcode.Config{
|
||||
Filter: cdb.Defaults.Filter,
|
||||
FilterMap: fm,
|
||||
Blacklist: cdb.Defaults.Blacklist,
|
||||
})
|
||||
schema, err := psql.NewDBSchema(db)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
schema, err := psql.NewDBSchema(db)
|
||||
qc, err := qcode.NewCompiler(qcode.Config{
|
||||
DefaultFilter: c.DB.Defaults.Filter,
|
||||
FilterMap: c.getFilterMap(),
|
||||
Blacklist: c.DB.Defaults.Blacklist,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pc := psql.NewCompiler(psql.Config{
|
||||
Schema: schema,
|
||||
Vars: cdb.Variables,
|
||||
TableMap: tmap,
|
||||
Vars: c.DB.Variables,
|
||||
TableMap: c.getAliasMap(),
|
||||
})
|
||||
|
||||
return qc, pc, nil
|
||||
}
|
||||
|
||||
func InitAndListen() {
|
||||
func Init() {
|
||||
var err error
|
||||
|
||||
logger = initLog()
|
||||
@ -259,16 +191,61 @@ func InitAndListen() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
http.HandleFunc("/api/v1/graphql", withAuth(apiv1Http))
|
||||
initResolvers()
|
||||
|
||||
if conf.WebUI {
|
||||
http.Handle("/", http.FileServer(_escFS(false)))
|
||||
startHTTP()
|
||||
}
|
||||
|
||||
func startHTTP() {
|
||||
srv := &http.Server{
|
||||
Addr: conf.HostPort,
|
||||
Handler: routeHandler(),
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
}
|
||||
|
||||
fmt.Printf("Super-Graph listening on %s (%s)\n",
|
||||
conf.HostPort, conf.Env)
|
||||
idleConnsClosed := make(chan struct{})
|
||||
go func() {
|
||||
sigint := make(chan os.Signal, 1)
|
||||
signal.Notify(sigint, os.Interrupt)
|
||||
<-sigint
|
||||
|
||||
logger.Fatal(http.ListenAndServe(conf.HostPort, nil))
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
log.Printf("http: %v", err)
|
||||
}
|
||||
close(idleConnsClosed)
|
||||
}()
|
||||
|
||||
srv.RegisterOnShutdown(func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
})
|
||||
|
||||
fmt.Printf("%s listening on %s (%s)\n", serverName, conf.HostPort, conf.Env)
|
||||
|
||||
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
<-idleConnsClosed
|
||||
}
|
||||
|
||||
func routeHandler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.Handle("/api/v1/graphql", withAuth(apiv1Http))
|
||||
if conf.WebUI {
|
||||
mux.Handle("/", http.FileServer(_escFS(false)))
|
||||
}
|
||||
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Server", serverName)
|
||||
mux.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
func getConfigName() string {
|
||||
|
@ -1,44 +1,12 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
import "github.com/cespare/xxhash/v2"
|
||||
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
)
|
||||
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
|
||||
h.WriteString(k1)
|
||||
h.WriteString(k2)
|
||||
v := h.Sum64()
|
||||
h.Reset()
|
||||
|
||||
func errorResp(w http.ResponseWriter, err error) {
|
||||
b, _ := json.Marshal(gqlResp{Error: err.Error()})
|
||||
http.Error(w, string(b), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func authCheck(ctx context.Context) bool {
|
||||
return (ctx.Value(userIDKey) != nil)
|
||||
}
|
||||
|
||||
func newTrace(st, et time.Time, qc *qcode.QCode) *trace {
|
||||
du := et.Sub(et)
|
||||
|
||||
t := &trace{
|
||||
Version: 1,
|
||||
StartTime: st,
|
||||
EndTime: et,
|
||||
Duration: du,
|
||||
Execution: execution{
|
||||
[]resolver{
|
||||
resolver{
|
||||
Path: []string{qc.Query.Select.Table},
|
||||
ParentType: "Query",
|
||||
FieldName: qc.Query.Select.Table,
|
||||
ReturnType: "object",
|
||||
StartOffset: 1,
|
||||
Duration: du,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return t
|
||||
return v
|
||||
}
|
||||
|
13
serv/vars.go
13
serv/vars.go
@ -1,15 +1,13 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
func varMap(ctx context.Context, vars variables) variables {
|
||||
func varMap(ctx *coreContext) variables {
|
||||
userIDFn := func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
@ -34,7 +32,8 @@ func varMap(ctx context.Context, vars variables) variables {
|
||||
"USER_ID_PROVIDER": userIDProviderTag,
|
||||
}
|
||||
|
||||
for k, v := range vars {
|
||||
for k, v := range ctx.req.Vars {
|
||||
var buf []byte
|
||||
if _, ok := vm[k]; ok {
|
||||
continue
|
||||
}
|
||||
@ -42,11 +41,11 @@ func varMap(ctx context.Context, vars variables) variables {
|
||||
case string:
|
||||
vm[k] = val
|
||||
case int:
|
||||
vm[k] = strconv.Itoa(val)
|
||||
vm[k] = strconv.AppendInt(buf, int64(val), 10)
|
||||
case int64:
|
||||
vm[k] = strconv.FormatInt(val, 64)
|
||||
vm[k] = strconv.AppendInt(buf, val, 10)
|
||||
case float64:
|
||||
vm[k] = fmt.Sprintf("%.0f", val)
|
||||
vm[k] = strconv.AppendFloat(buf, val, 'f', -1, 64)
|
||||
}
|
||||
}
|
||||
return vm
|
||||
|
Reference in New Issue
Block a user