Refactor Super Graph into a library #26

This commit is contained in:
Vikram Rangnekar 2020-04-10 02:27:43 -04:00
parent e102da839e
commit 7831d27345
200 changed files with 3590 additions and 4447 deletions

1
.gitignore vendored
View File

@ -24,7 +24,6 @@
/demo/tmp /demo/tmp
.vscode .vscode
main
.DS_Store .DS_Store
.swp .swp
.release .release

View File

@ -7,7 +7,7 @@ rules:
- name: run - name: run
match: \.go$ match: \.go$
ignore: web|examples|docs|_test\.go$ ignore: web|examples|docs|_test\.go$
command: go run main.go serv command: go run cmd/main.go serv
- name: test - name: test
match: _test\.go$ match: _test\.go$
command: go test -cover {PKG} command: go test -cover {PKG}

View File

@ -1,7 +1,7 @@
# stage: 1 # stage: 1
FROM node:10 as react-build FROM node:10 as react-build
WORKDIR /web WORKDIR /web
COPY web/ ./ COPY /cmd/internal/serv/web/ ./
RUN yarn RUN yarn
RUN yarn build RUN yarn build
@ -22,8 +22,8 @@ RUN chmod 755 /usr/local/bin/sops
WORKDIR /app WORKDIR /app
COPY . /app COPY . /app
RUN mkdir -p /app/web/build RUN mkdir -p /app/cmd/internal/serv/web/build
COPY --from=react-build /web/build/ ./web/build/ COPY --from=react-build /web/build/ ./cmd/internal/serv/web/build
RUN go mod vendor RUN go mod vendor
RUN make build RUN make build
@ -41,7 +41,7 @@ RUN mkdir -p /config
COPY --from=go-build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ COPY --from=go-build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
COPY --from=go-build /app/config/* /config/ COPY --from=go-build /app/config/* /config/
COPY --from=go-build /app/super-graph . COPY --from=go-build /app/super-graph .
COPY --from=go-build /app/scripts/start.sh . COPY --from=go-build /app/cmd/scripts/start.sh .
COPY --from=go-build /usr/local/bin/sops . COPY --from=go-build /usr/local/bin/sops .
RUN chmod +x /super-graph RUN chmod +x /super-graph

View File

@ -28,14 +28,14 @@ BIN_DIR := $(GOPATH)/bin
GORICE := $(BIN_DIR)/rice GORICE := $(BIN_DIR)/rice
GOLANGCILINT := $(BIN_DIR)/golangci-lint GOLANGCILINT := $(BIN_DIR)/golangci-lint
GITCHGLOG := $(BIN_DIR)/git-chglog GITCHGLOG := $(BIN_DIR)/git-chglog
WEB_BUILD_DIR := ./web/build/manifest.json WEB_BUILD_DIR := ./cmd/internal/serv/web/build/manifest.json
$(GORICE): $(GORICE):
@GO111MODULE=off go get -u github.com/GeertJohan/go.rice/rice @GO111MODULE=off go get -u github.com/GeertJohan/go.rice/rice
$(WEB_BUILD_DIR): $(WEB_BUILD_DIR):
@echo "First install Yarn and create a build of the web UI found under ./web" @echo "First install Yarn and create a build of the web UI then re-run make install"
@echo "Command: cd web && yarn && yarn build" @echo "Run this command: yarn --cwd cmd/internal/serv/web/ build"
@exit 1 @exit 1
$(GITCHGLOG): $(GITCHGLOG):
@ -57,7 +57,7 @@ os = $(word 1, $@)
$(PLATFORMS): lint test $(PLATFORMS): lint test
@mkdir -p release @mkdir -p release
@GOOS=$(os) GOARCH=amd64 go build $(BUILD_FLAGS) -o release/$(BINARY)-$(BUILD_VERSION)-$(os)-amd64 @GOOS=$(os) GOARCH=amd64 go build $(BUILD_FLAGS) -o release/$(BINARY)-$(BUILD_VERSION)-$(os)-amd64 cmd/main.go
release: windows linux darwin release: windows linux darwin
@ -69,7 +69,7 @@ gen: $(GORICE) $(WEB_BUILD_DIR)
@go generate ./... @go generate ./...
$(BINARY): clean $(BINARY): clean
@go build $(BUILD_FLAGS) -o $(BINARY) @go build $(BUILD_FLAGS) -o $(BINARY) cmd/main.go
clean: clean:
@rm -f $(BINARY) @rm -f $(BINARY)
@ -81,7 +81,7 @@ install: gen
@echo $(GOPATH) @echo $(GOPATH)
@echo "Commit Hash: `git rev-parse HEAD`" @echo "Commit Hash: `git rev-parse HEAD`"
@echo "Old Hash: `shasum $(GOPATH)/bin/$(BINARY) 2>/dev/null | cut -c -32`" @echo "Old Hash: `shasum $(GOPATH)/bin/$(BINARY) 2>/dev/null | cut -c -32`"
@go install $(BUILD_FLAGS) @go install $(BUILD_FLAGS) cmd
@echo "New Hash:" `shasum $(GOPATH)/bin/$(BINARY) 2>/dev/null | cut -c -32` @echo "New Hash:" `shasum $(GOPATH)/bin/$(BINARY) 2>/dev/null | cut -c -32`
uninstall: clean uninstall: clean

View File

@ -3,11 +3,13 @@ package serv
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/dosco/super-graph/config"
) )
type actionFn func(w http.ResponseWriter, r *http.Request) error type actionFn func(w http.ResponseWriter, r *http.Request) error
func newAction(a configAction) (http.Handler, error) { func newAction(a *config.Action) (http.Handler, error) {
var fn actionFn var fn actionFn
var err error var err error
@ -23,17 +25,16 @@ func newAction(a configAction) (http.Handler, error) {
httpFn := func(w http.ResponseWriter, r *http.Request) { httpFn := func(w http.ResponseWriter, r *http.Request) {
if err := fn(w, r); err != nil { if err := fn(w, r); err != nil {
errlog.Error().Err(err).Send() renderErr(w, err, nil)
errorResp(w, err)
} }
} }
return http.HandlerFunc(httpFn), nil return http.HandlerFunc(httpFn), nil
} }
func newSQLAction(a configAction) (actionFn, error) { func newSQLAction(a *config.Action) (actionFn, error) {
fn := func(w http.ResponseWriter, r *http.Request) error { fn := func(w http.ResponseWriter, r *http.Request) error {
_, err := db.Exec(r.Context(), a.SQL) _, err := db.ExecContext(r.Context(), a.SQL)
return err return err
} }

View File

@ -1,17 +1,17 @@
package serv package serv
import ( import (
"database/sql"
"fmt" "fmt"
_log "log"
"os"
"runtime" "runtime"
"strings" "strings"
"github.com/dosco/super-graph/allow" "github.com/dosco/super-graph/config"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/rs/zerolog"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"go.uber.org/zap"
) )
//go:generate rice embed-go //go:generate rice embed-go
@ -29,21 +29,17 @@ var (
) )
var ( var (
logger zerolog.Logger // logger for everything but errors log *_log.Logger // logger
errlog zerolog.Logger // logger for errors includes line numbers zlog *zap.Logger // fast logger
conf *config // parsed config conf *config.Config // parsed config
confPath string // path to the config file confPath string // path to the config file
db *pgxpool.Pool // database connection pool db *sql.DB // database connection pool
schema *psql.DBSchema // database tables, columns and relationships
allowList *allow.List // allow.list is contains queries allowed in production
qcompile *qcode.Compiler // qcode compiler
pcompile *psql.Compiler // postgres sql compiler
secretKey [32]byte // encryption key secretKey [32]byte // encryption key
internalKey [32]byte // encryption key used for internal needs
) )
func Cmd() { func Cmd() {
initLog() log = _log.New(os.Stdout, "", 0)
zlog = zap.NewExample()
rootCmd := &cobra.Command{ rootCmd := &cobra.Command{
Use: "super-graph", Use: "super-graph",
@ -149,11 +145,11 @@ e.g. db:migrate -+1
Run: cmdVersion, Run: cmdVersion,
}) })
rootCmd.Flags().StringVar(&confPath, rootCmd.PersistentFlags().StringVar(&confPath,
"path", "./config", "path to config files") "path", "./config", "path to config files")
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
errlog.Fatal().Err(err).Send() log.Fatalf("ERR %s", err)
} }
} }

View File

@ -0,0 +1,29 @@
package serv
import (
"fmt"
"os"
"github.com/dosco/super-graph/config"
"github.com/spf13/cobra"
)
func cmdConfDump(cmd *cobra.Command, args []string) {
if len(args) != 1 {
cmd.Help() //nolint: errcheck
os.Exit(1)
}
fname := fmt.Sprintf("%s.%s", config.GetConfigName(), args[0])
conf, err := initConf()
if err != nil {
log.Fatalf("ERR failed to read config: %s", err)
}
if err := conf.WriteConfigAs(fname); err != nil {
log.Fatalf("ERR failed to write config: %s", err)
}
log.Printf("INF config dumped to ./%s", fname)
}

View File

@ -1,7 +1,6 @@
package serv package serv
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path" "path"
@ -10,7 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/dosco/super-graph/migrate" "github.com/dosco/super-graph/cmd/internal/serv/internal/migrate"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -27,7 +26,7 @@ func cmdDBSetup(cmd *cobra.Command, args []string) {
cmdDBCreate(cmd, []string{}) cmdDBCreate(cmd, []string{})
cmdDBMigrate(cmd, []string{"up"}) cmdDBMigrate(cmd, []string{"up"})
sfile := path.Join(confPath, conf.SeedFile) sfile := path.Join(conf.ConfigPathUsed(), conf.SeedFile)
_, err := os.Stat(sfile) _, err := os.Stat(sfile)
if err == nil { if err == nil {
@ -36,61 +35,59 @@ func cmdDBSetup(cmd *cobra.Command, args []string) {
} }
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
errlog.Fatal().Err(err).Msgf("unable to check if '%s' exists", sfile) log.Fatalf("ERR unable to check if '%s' exists: %s", sfile, err)
} }
logger.Warn().Msgf("failed to read seed file '%s'", sfile) log.Printf("WRN failed to read seed file '%s'", sfile)
} }
func cmdDBReset(cmd *cobra.Command, args []string) { func cmdDBReset(cmd *cobra.Command, args []string) {
initConfOnce() initConfOnce()
if conf.Production { if conf.Production {
errlog.Fatal().Msg("db:reset does not work in production") log.Fatalln("ERR db:reset does not work in production")
return
} }
cmdDBDrop(cmd, []string{}) cmdDBDrop(cmd, []string{})
cmdDBSetup(cmd, []string{}) cmdDBSetup(cmd, []string{})
} }
func cmdDBCreate(cmd *cobra.Command, args []string) { func cmdDBCreate(cmd *cobra.Command, args []string) {
initConfOnce() initConfOnce()
ctx := context.Background()
conn, err := initDB(conf, false) db, err := initDB(conf)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to connect to database") log.Fatalf("ERR failed to connect to database: %s", err)
} }
defer conn.Close(ctx) defer db.Close()
sql := fmt.Sprintf(`CREATE DATABASE "%s"`, conf.DB.DBName) sql := fmt.Sprintf(`CREATE DATABASE "%s"`, conf.DB.DBName)
_, err = conn.Exec(ctx, sql) _, err = db.Exec(sql)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to create database") log.Fatalf("ERR failed to create database: %s", err)
} }
logger.Info().Msgf("created database '%s'", conf.DB.DBName) log.Printf("INF created database '%s'", conf.DB.DBName)
} }
func cmdDBDrop(cmd *cobra.Command, args []string) { func cmdDBDrop(cmd *cobra.Command, args []string) {
initConfOnce() initConfOnce()
ctx := context.Background()
conn, err := initDB(conf, false) db, err := initDB(conf)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to connect to database") log.Fatalf("ERR failed to connect to database: %s", err)
} }
defer conn.Close(ctx) defer db.Close()
sql := fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, conf.DB.DBName) sql := fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, conf.DB.DBName)
_, err = conn.Exec(ctx, sql) _, err = db.Exec(sql)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to create database") log.Fatalf("ERR failed to drop database: %s", err)
} }
logger.Info().Msgf("dropped database '%s'", conf.DB.DBName) log.Printf("INF dropped database '%s'", conf.DB.DBName)
} }
func cmdDBNew(cmd *cobra.Command, args []string) { func cmdDBNew(cmd *cobra.Command, args []string) {
@ -104,8 +101,7 @@ func cmdDBNew(cmd *cobra.Command, args []string) {
m, err := migrate.FindMigrations(conf.MigrationsPath) m, err := migrate.FindMigrations(conf.MigrationsPath)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error loading migrations:\n %v\n", err) log.Fatalf("ERR error loading migrations: %s", err)
os.Exit(1)
} }
mname := fmt.Sprintf("%d_%s.sql", len(m), name) mname := fmt.Sprintf("%d_%s.sql", len(m), name)
@ -114,17 +110,16 @@ func cmdDBNew(cmd *cobra.Command, args []string) {
mpath := filepath.Join(conf.MigrationsPath, mname) mpath := filepath.Join(conf.MigrationsPath, mname)
mfile, err := os.OpenFile(mpath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0666) mfile, err := os.OpenFile(mpath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0666)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err) log.Fatalf("ERR %s", err)
os.Exit(1)
} }
defer mfile.Close() defer mfile.Close()
_, err = mfile.WriteString(newMigrationText) _, err = mfile.WriteString(newMigrationText)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err) log.Fatalf("ERR %s", err)
os.Exit(1)
} }
logger.Info().Msgf("created migration '%s'", mpath)
log.Printf("INR created migration '%s'", mpath)
} }
func cmdDBMigrate(cmd *cobra.Command, args []string) { func cmdDBMigrate(cmd *cobra.Command, args []string) {
@ -136,30 +131,30 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
initConfOnce() initConfOnce()
dest := args[0] dest := args[0]
conn, err := initDB(conf, true) conn, err := initDB(conf)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to connect to database") log.Fatalf("ERR failed to connect to database: %s", err)
} }
defer conn.Close(context.Background()) defer conn.Close()
m, err := migrate.NewMigrator(conn, "schema_version") m, err := migrate.NewMigrator(conn, "schema_version")
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to initializing migrator") log.Fatalf("ERR failed to initializing migrator: %s", err)
} }
m.Data = getMigrationVars() m.Data = getMigrationVars()
err = m.LoadMigrations(conf.MigrationsPath) err = m.LoadMigrations(path.Join(conf.ConfigPathUsed(), conf.MigrationsPath))
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to load migrations") log.Fatalf("ERR failed to load migrations: %s", err)
} }
if len(m.Migrations) == 0 { if len(m.Migrations) == 0 {
errlog.Fatal().Msg("No migrations found") log.Fatalf("ERR no migrations found")
} }
m.OnStart = func(sequence int32, name, direction, sql string) { m.OnStart = func(sequence int32, name, direction, sql string) {
logger.Info().Msgf("%s executing %s %s\n%s\n\n", log.Printf("INF %s executing %s %s\n%s\n\n",
time.Now().Format("2006-01-02 15:04:05"), name, direction, sql) time.Now().Format("2006-01-02 15:04:05"), name, direction, sql)
} }
@ -174,7 +169,7 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
var n int64 var n int64
n, err = strconv.ParseInt(d, 10, 32) n, err = strconv.ParseInt(d, 10, 32)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("invalid destination") log.Fatalf("ERR invalid destination: %s", err)
} }
return int32(n) return int32(n)
} }
@ -203,58 +198,56 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
} }
if err != nil { if err != nil {
logger.Fatal().Err(err).Send() log.Fatalf("ERR %s", err)
// if err, ok := err.(m.MigrationPgError); ok { // if err, ok := err.(m.MigrationPgError); ok {
// if err.Detail != "" { // if err.Detail != "" {
// info.Err(err).Msg(err.Detail) // log.Fatalf("ERR %s", err.Detail)
// } // }
// if err.Position != 0 { // if err.Position != 0 {
// ele, err := ExtractErrorLine(err.Sql, int(err.Position)) // ele, err := ExtractErrorLine(err.Sql, int(err.Position))
// if err != nil { // if err != nil {
// errlog.Fatal().Err(err).Send() // log.Fatalf("ERR %s", err)
// } // }
// prefix := fmt.Sprintf() // log.Fatalf("INF line %d, %s%s", ele.LineNum, ele.Text)
// logger.Info().Msgf("line %d, %s%s", ele.LineNum, prefix, ele.Text)
// } // }
// } // }
// os.Exit(1)
} }
logger.Info().Msg("migration done") log.Println("INF migration done")
} }
func cmdDBStatus(cmd *cobra.Command, args []string) { func cmdDBStatus(cmd *cobra.Command, args []string) {
initConfOnce() initConfOnce()
conn, err := initDB(conf, true) db, err := initDB(conf)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to connect to database") log.Fatalf("ERR failed to connect to database: %s", err)
} }
defer conn.Close(context.Background()) defer db.Close()
m, err := migrate.NewMigrator(conn, "schema_version") m, err := migrate.NewMigrator(db, "schema_version")
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to initialize migrator") log.Fatalf("ERR failed to initialize migrator: %s", err)
} }
m.Data = getMigrationVars() m.Data = getMigrationVars()
err = m.LoadMigrations(conf.MigrationsPath) err = m.LoadMigrations(conf.MigrationsPath)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to load migrations") log.Fatalf("ERR failed to load migrations: %s", err)
} }
if len(m.Migrations) == 0 { if len(m.Migrations) == 0 {
errlog.Fatal().Msg("no migrations found") log.Fatalf("ERR no migrations found")
} }
mver, err := m.GetCurrentVersion() mver, err := m.GetCurrentVersion()
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to retrieve migration") log.Fatalf("ERR failed to retrieve migration: %s", err)
} }
var status string var status string
@ -265,10 +258,8 @@ func cmdDBStatus(cmd *cobra.Command, args []string) {
status = "migration(s) pending" status = "migration(s) pending"
} }
fmt.Println("status: ", status) log.Printf("INF status: %s, version: %d of %d, host: %s, database: %s",
fmt.Printf("version: %d of %d\n", mver, len(m.Migrations)) status, mver, len(m.Migrations), conf.DB.Host, conf.DB.DBName)
fmt.Println("host: ", conf.DB.Host)
fmt.Println("database:", conf.DB.DBName)
} }
type ErrorLineExtract struct { type ErrorLineExtract struct {
@ -315,9 +306,12 @@ func getMigrationVars() map[string]interface{} {
func initConfOnce() { func initConfOnce() {
var err error var err error
if conf == nil { if conf != nil {
if conf, err = initConf(); err != nil { return
errlog.Fatal().Err(err).Msg("failed to read config") }
}
conf, err = initConf()
if err != nil {
log.Fatalf("ERR failed to read config: %s", err)
} }
} }

View File

@ -98,7 +98,7 @@ func cmdNew(cmd *cobra.Command, args []string) {
} }
}) })
logger.Info().Msgf("app '%s' initialized", name) log.Printf("INR app '%s' initialized", name)
} }
type Templ struct { type Templ struct {
@ -107,7 +107,7 @@ type Templ struct {
} }
func newTempl(data map[string]string) *Templ { func newTempl(data map[string]string) *Templ {
return &Templ{rice.MustFindBox("../tmpl"), data} return &Templ{rice.MustFindBox("./tmpl"), data}
} }
func (t *Templ) get(name string) ([]byte, error) { func (t *Templ) get(name string) ([]byte, error) {
@ -133,18 +133,18 @@ func ifNotExists(filePath string, doFn func(string) error) {
_, err := os.Stat(filePath) _, err := os.Stat(filePath)
if err == nil { if err == nil {
logger.Info().Err(err).Msgf("create skipped '%s' exists", filePath) log.Printf("ERR create skipped '%s' exists", filePath)
return return
} }
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
errlog.Fatal().Err(err).Msgf("unable to check if '%s' exists", filePath) log.Fatalf("ERR unable to check if '%s' exists", filePath)
} }
err = doFn(filePath) err = doFn(filePath)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msgf("unable to create '%s'", filePath) log.Fatalf("ERR unable to create '%s'", filePath)
} }
logger.Info().Msgf("created '%s'", filePath) log.Printf("INR created '%s'", filePath)
} }

View File

@ -1,7 +1,6 @@
package serv package serv
import ( import (
"bytes"
"context" "context"
"encoding/csv" "encoding/csv"
"encoding/json" "encoding/json"
@ -16,37 +15,43 @@ import (
"github.com/brianvoe/gofakeit" "github.com/brianvoe/gofakeit"
"github.com/dop251/goja" "github.com/dop251/goja"
"github.com/jackc/pgx/v4" "github.com/dosco/super-graph/core"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/valyala/fasttemplate"
) )
func cmdDBSeed(cmd *cobra.Command, args []string) { func cmdDBSeed(cmd *cobra.Command, args []string) {
var err error var err error
if conf, err = initConf(); err != nil { if conf, err = initConf(); err != nil {
errlog.Fatal().Err(err).Msg("failed to read config") log.Fatalf("ERR failed to read config: %s", err)
} }
conf.Production = false conf.Production = false
db, err = initDBPool(conf) db, err = initDB(conf)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to connect to database") log.Fatalf("ERR failed to connect to database: %s", err)
} }
initCompiler() sfile := path.Join(conf.ConfigPathUsed(), conf.SeedFile)
sfile := path.Join(confPath, conf.SeedFile) b, err := ioutil.ReadFile(sfile)
b, err := ioutil.ReadFile(path.Join(confPath, conf.SeedFile))
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msgf("failed to read seed file '%s'", sfile) log.Fatalf("ERR failed to read seed file %s: %s", sfile, err)
}
sg, err = core.NewSuperGraph(conf, db)
if err != nil {
log.Fatalf("ERR failed to initialize Super Graph: %s", err)
}
graphQLFn := func(query string, data interface{}, opt map[string]string) map[string]interface{} {
return graphQLFunc(sg, query, data, opt)
} }
vm := goja.New() vm := goja.New()
vm.Set("graphql", graphQLFunc) vm.Set("graphql", graphQLFn)
vm.Set("import_csv", importCSV) //vm.Set("import_csv", importCSV)
console := vm.NewObject() console := vm.NewObject()
console.Set("log", logFunc) //nolint: errcheck console.Set("log", logFunc) //nolint: errcheck
@ -58,77 +63,44 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
_, err = vm.RunScript("seed.js", string(b)) _, err = vm.RunScript("seed.js", string(b))
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to execute script") log.Fatalf("ERR failed to execute script: %s", err)
} }
logger.Info().Msg("seed script done") log.Println("INF seed script done")
} }
// func runFunc(call goja.FunctionCall) { // func runFunc(call goja.FunctionCall) {
func graphQLFunc(query string, data interface{}, opt map[string]string) map[string]interface{} { func graphQLFunc(sg *core.SuperGraph, query string, data interface{}, opt map[string]string) map[string]interface{} {
vars, err := json.Marshal(data) ct := context.Background()
if err != nil {
errlog.Fatal().Err(err).Send()
}
c := context.Background()
if v, ok := opt["user_id"]; ok && len(v) != 0 { if v, ok := opt["user_id"]; ok && len(v) != 0 {
c = context.WithValue(c, userIDKey, v) ct = context.WithValue(ct, core.UserIDKey, v)
} }
var role string // var role string
if v, ok := opt["role"]; ok && len(v) != 0 { // if v, ok := opt["role"]; ok && len(v) != 0 {
role = v // role = v
} else { // } else {
role = "user" // role = "user"
// }
var vars []byte
var err error
if vars, err = json.Marshal(data); err != nil {
log.Fatalf("ERR %s", err)
} }
stmts, err := buildRoleStmt([]byte(query), vars, role) res, err := sg.GraphQL(ct, query, vars)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("graphql query failed") log.Fatalf("ERR %s", err)
}
st := stmts[0]
buf := &bytes.Buffer{}
t := fasttemplate.New(st.sql, openVar, closeVar)
_, err = t.ExecuteFunc(buf, argMap(c, vars))
if err != nil {
errlog.Fatal().Err(err).Send()
}
finalSQL := buf.String()
tx, err := db.Begin(c)
if err != nil {
errlog.Fatal().Err(err).Send()
}
defer tx.Rollback(c) //nolint: errcheck
if conf.DB.SetUserID {
if err := setLocalUserID(c, tx); err != nil {
errlog.Fatal().Err(err).Send()
}
}
var root []byte
if err = tx.QueryRow(context.Background(), finalSQL).Scan(&root); err != nil {
errlog.Fatal().Err(err).Msg("sql query failed")
}
if err := tx.Commit(c); err != nil {
errlog.Fatal().Err(err).Send()
} }
val := make(map[string]interface{}) val := make(map[string]interface{})
err = json.Unmarshal(root, &val) if err = json.Unmarshal(res.Data, &val); err != nil {
if err != nil { log.Fatalf("ERR %s", err)
errlog.Fatal().Err(err).Send()
} }
return val return val
@ -203,36 +175,34 @@ func (c *csvSource) Err() error {
return nil return nil
} }
func importCSV(table, filename string) int64 { // func importCSV(table, filename string) int64 {
if filename[0] != '/' { // if filename[0] != '/' {
filename = path.Join(confPath, filename) // filename = path.Join(conf.ConfigPathUsed(), filename)
} // }
s, err := NewCSVSource(filename) // s, err := NewCSVSource(filename)
if err != nil { // if err != nil {
errlog.Fatal().Err(err).Send() // log.Fatalf("ERR %s", err)
} // }
var cols []string // var cols []string
colval, _ := s.Values() // colval, _ := s.Values()
for _, c := range colval { // for _, c := range colval {
cols = append(cols, c.(string)) // cols = append(cols, c.(string))
} // }
n, err := db.CopyFrom( // n, err := db.Exec(fmt.Sprintf("COPY %s FROM STDIN WITH "),
context.Background(), // cols,
pgx.Identifier{table}, // s)
cols,
s)
if err != nil { // if err != nil {
err = fmt.Errorf("%w (line no %d)", err, s.i) // err = fmt.Errorf("%w (line no %d)", err, s.i)
errlog.Fatal().Err(err).Send() // log.Fatalf("ERR %s", err)
} // }
return n // return n
} // }
//nolint: errcheck //nolint: errcheck
func logFunc(args ...interface{}) { func logFunc(args ...interface{}) {

View File

@ -0,0 +1,37 @@
package serv
import (
"github.com/dosco/super-graph/core"
"github.com/spf13/cobra"
)
var (
sg *core.SuperGraph
)
func cmdServ(cmd *cobra.Command, args []string) {
var err error
conf, err = initConf()
if err != nil {
fatalInProd(err, "failed to read config")
}
initWatcher()
db, err = initDB(conf)
if err != nil {
fatalInProd(err, "failed to connect to database")
}
// if conf != nil && db != nil {
// initResolvers()
// }
sg, err = core.NewSuperGraph(conf, db)
if err != nil {
fatalInProd(err, "failed to initialize Super Graph")
}
startHTTP()
}

View File

@ -0,0 +1,7 @@
package serv
// func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
// return nil
// }

View File

@ -0,0 +1,25 @@
package serv
import (
"context"
"net/http"
)
var healthyResponse = []byte("All's Well")
func health(w http.ResponseWriter, _ *http.Request) {
ct, cancel := context.WithTimeout(context.Background(), conf.DB.PingTimeout)
defer cancel()
if err := db.PingContext(ct); err != nil {
log.Printf("ERR error pinging database: %s", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if _, err := w.Write(healthyResponse); err != nil {
log.Printf("ERR error writing healthy response: %s", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}

125
cmd/internal/serv/http.go Normal file
View File

@ -0,0 +1,125 @@
package serv
import (
"encoding/json"
"errors"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/dosco/super-graph/cmd/internal/serv/internal/auth"
"github.com/dosco/super-graph/config"
"github.com/dosco/super-graph/core"
"github.com/rs/cors"
"go.uber.org/zap"
)
const (
maxReadBytes = 100000 // 100Kb
introspectionQuery = "IntrospectionQuery"
)
var (
errUnauthorized = errors.New("not authorized")
)
type gqlReq struct {
OpName string `json:"operationName"`
Query string `json:"query"`
Vars json.RawMessage `json:"variables"`
}
type errorResp struct {
Error error `json:"error"`
}
func apiV1Handler() http.Handler {
h, err := auth.WithAuth(http.HandlerFunc(apiV1), &conf.Auth)
if err != nil {
log.Fatalf("ERR %s", err)
}
if len(conf.AllowedOrigins) != 0 {
c := cors.New(cors.Options{
AllowedOrigins: conf.AllowedOrigins,
AllowCredentials: true,
Debug: conf.DebugCORS,
})
h = c.Handler(h)
}
return h
}
func apiV1(w http.ResponseWriter, r *http.Request) {
ct := r.Context()
//nolint: errcheck
if conf.AuthFailBlock && !auth.IsAuth(ct) {
renderErr(w, errUnauthorized, nil)
return
}
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
if err != nil {
renderErr(w, err, nil)
return
}
defer r.Body.Close()
req := gqlReq{}
err = json.Unmarshal(b, &req)
if err != nil {
renderErr(w, err, nil)
return
}
if strings.EqualFold(req.OpName, introspectionQuery) {
introspect(w)
return
}
res, err := sg.GraphQL(ct, req.Query, req.Vars)
if conf.LogLevel() >= config.LogLevelDebug {
log.Printf("DBG query:\n%s\nsql:\n%s", req.Query, res.SQL())
}
if err != nil {
renderErr(w, err, res)
return
}
json.NewEncoder(w).Encode(res)
if conf.LogLevel() >= config.LogLevelInfo {
zlog.Info("success",
zap.String("op", res.Operation()),
zap.String("name", res.QueryName()),
zap.String("role", res.Role()),
)
}
}
//nolint: errcheck
func renderErr(w http.ResponseWriter, err error, res *core.Result) {
if err == errUnauthorized {
w.WriteHeader(http.StatusUnauthorized)
}
json.NewEncoder(w).Encode(&errorResp{err})
if conf.LogLevel() >= config.LogLevelError {
if res != nil {
zlog.Error(err.Error(),
zap.String("op", res.Operation()),
zap.String("name", res.QueryName()),
zap.String("role", res.Role()),
)
} else {
zlog.Error(err.Error())
}
}
}

88
cmd/internal/serv/init.go Normal file
View File

@ -0,0 +1,88 @@
package serv
import (
"database/sql"
"fmt"
"time"
"github.com/dosco/super-graph/config"
_ "github.com/jackc/pgx/v4/stdlib"
)
func initConf() (*config.Config, error) {
return config.NewConfigWithLogger(confPath, log)
}
func initDB(c *config.Config) (*sql.DB, error) {
var db *sql.DB
var err error
cs := fmt.Sprintf("postgres://%s:%s@%s:%d/%s",
c.DB.User, c.DB.Password,
c.DB.Host, c.DB.Port, c.DB.DBName)
for i := 1; i < 10; i++ {
db, err = sql.Open("pgx", cs)
if err == nil {
break
}
time.Sleep(time.Duration(i*100) * time.Millisecond)
}
if err != nil {
return nil, err
}
return db, nil
// config, _ := pgxpool.ParseConfig("")
// config.ConnConfig.Host = c.DB.Host
// config.ConnConfig.Port = c.DB.Port
// config.ConnConfig.Database = c.DB.DBName
// config.ConnConfig.User = c.DB.User
// config.ConnConfig.Password = c.DB.Password
// config.ConnConfig.RuntimeParams = map[string]string{
// "application_name": c.AppName,
// "search_path": c.DB.Schema,
// }
// switch c.LogLevel {
// case "debug":
// config.ConnConfig.LogLevel = pgx.LogLevelDebug
// case "info":
// config.ConnConfig.LogLevel = pgx.LogLevelInfo
// case "warn":
// config.ConnConfig.LogLevel = pgx.LogLevelWarn
// case "error":
// config.ConnConfig.LogLevel = pgx.LogLevelError
// default:
// config.ConnConfig.LogLevel = pgx.LogLevelNone
// }
// config.ConnConfig.Logger = NewSQLLogger(logger)
// // if c.DB.MaxRetries != 0 {
// // opt.MaxRetries = c.DB.MaxRetries
// // }
// if c.DB.PoolSize != 0 {
// config.MaxConns = conf.DB.PoolSize
// }
// var db *pgxpool.Pool
// var err error
// for i := 1; i < 10; i++ {
// db, err = pgxpool.ConnectConfig(context.Background(), config)
// if err == nil {
// break
// }
// time.Sleep(time.Duration(i*100) * time.Millisecond)
// }
// if err != nil {
// return nil, err
// }
// return db, nil
}

View File

@ -0,0 +1,95 @@
package auth
import (
"context"
"fmt"
"net/http"
"github.com/dosco/super-graph/config"
"github.com/dosco/super-graph/core"
)
func SimpleHandler(ac *config.Auth, next http.Handler) (http.HandlerFunc, error) {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userIDProvider := r.Header.Get("X-User-ID-Provider")
if len(userIDProvider) != 0 {
ctx = context.WithValue(ctx, core.UserIDProviderKey, userIDProvider)
}
userID := r.Header.Get("X-User-ID")
if len(userID) != 0 {
ctx = context.WithValue(ctx, core.UserIDKey, userID)
}
userRole := r.Header.Get("X-User-Role")
if len(userRole) != 0 {
ctx = context.WithValue(ctx, core.UserRoleKey, userRole)
}
next.ServeHTTP(w, r.WithContext(ctx))
}, nil
}
func HeaderHandler(ac *config.Auth, next http.Handler) (http.HandlerFunc, error) {
hdr := ac.Header
if len(hdr.Name) == 0 {
return nil, fmt.Errorf("auth '%s': no header.name defined", ac.Name)
}
if !hdr.Exists && len(hdr.Value) == 0 {
return nil, fmt.Errorf("auth '%s': no header.value defined", ac.Name)
}
return func(w http.ResponseWriter, r *http.Request) {
var fo1 bool
value := r.Header.Get(hdr.Name)
switch {
case hdr.Exists:
fo1 = (len(value) == 0)
default:
fo1 = (value != hdr.Value)
}
if fo1 {
http.Error(w, "401 unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
}, nil
}
func WithAuth(next http.Handler, ac *config.Auth) (http.Handler, error) {
var err error
if ac.CredsInHeader {
next, err = SimpleHandler(ac, next)
}
if err != nil {
return nil, err
}
switch ac.Type {
case "rails":
return RailsHandler(ac, next)
case "jwt":
return JwtHandler(ac, next)
case "header":
return HeaderHandler(ac, next)
}
return next, nil
}
func IsAuth(ct context.Context) bool {
return ct.Value(core.UserIDKey) != nil
}

View File

@ -1,4 +1,4 @@
package serv package auth
import ( import (
"context" "context"
@ -7,6 +7,8 @@ import (
"strings" "strings"
jwt "github.com/dgrijalva/jwt-go" jwt "github.com/dgrijalva/jwt-go"
"github.com/dosco/super-graph/config"
"github.com/dosco/super-graph/core"
) )
const ( const (
@ -14,18 +16,18 @@ const (
jwtAuth0 int = iota + 1 jwtAuth0 int = iota + 1
) )
func jwtHandler(authc configAuth, next http.Handler) http.HandlerFunc { func JwtHandler(ac *config.Auth, next http.Handler) (http.HandlerFunc, error) {
var key interface{} var key interface{}
var jwtProvider int var jwtProvider int
cookie := authc.Cookie cookie := ac.Cookie
if authc.JWT.Provider == "auth0" { if ac.JWT.Provider == "auth0" {
jwtProvider = jwtAuth0 jwtProvider = jwtAuth0
} }
secret := authc.JWT.Secret secret := ac.JWT.Secret
publicKeyFile := authc.JWT.PubKeyFile publicKeyFile := ac.JWT.PubKeyFile
switch { switch {
case len(secret) != 0: case len(secret) != 0:
@ -34,10 +36,10 @@ func jwtHandler(authc configAuth, next http.Handler) http.HandlerFunc {
case len(publicKeyFile) != 0: case len(publicKeyFile) != 0:
kd, err := ioutil.ReadFile(publicKeyFile) kd, err := ioutil.ReadFile(publicKeyFile)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Send() return nil, err
} }
switch authc.JWT.PubKeyType { switch ac.JWT.PubKeyType {
case "ecdsa": case "ecdsa":
key, err = jwt.ParseECPublicKeyFromPEM(kd) key, err = jwt.ParseECPublicKeyFromPEM(kd)
@ -50,7 +52,7 @@ func jwtHandler(authc configAuth, next http.Handler) http.HandlerFunc {
} }
if err != nil { if err != nil {
errlog.Fatal().Err(err).Send() return nil, err
} }
} }
@ -88,11 +90,11 @@ func jwtHandler(authc configAuth, next http.Handler) http.HandlerFunc {
if jwtProvider == jwtAuth0 { if jwtProvider == jwtAuth0 {
sub := strings.Split(claims.Subject, "|") sub := strings.Split(claims.Subject, "|")
if len(sub) != 2 { if len(sub) != 2 {
ctx = context.WithValue(ctx, userIDProviderKey, sub[0]) ctx = context.WithValue(ctx, core.UserIDProviderKey, sub[0])
ctx = context.WithValue(ctx, userIDKey, sub[1]) ctx = context.WithValue(ctx, core.UserIDKey, sub[1])
} }
} else { } else {
ctx = context.WithValue(ctx, userIDKey, claims.Subject) ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
} }
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
@ -100,5 +102,5 @@ func jwtHandler(authc configAuth, next http.Handler) http.HandlerFunc {
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} }, nil
} }

View File

@ -1,4 +1,4 @@
package serv package auth
import ( import (
"context" "context"
@ -9,50 +9,54 @@ import (
"strings" "strings"
"github.com/bradfitz/gomemcache/memcache" "github.com/bradfitz/gomemcache/memcache"
"github.com/dosco/super-graph/rails" "github.com/dosco/super-graph/config"
"github.com/dosco/super-graph/core"
"github.com/dosco/super-graph/cmd/internal/serv/internal/rails"
"github.com/garyburd/redigo/redis" "github.com/garyburd/redigo/redis"
) )
func railsHandler(authc configAuth, next http.Handler) http.HandlerFunc { func RailsHandler(ac *config.Auth, next http.Handler) (http.HandlerFunc, error) {
ru := authc.Rails.URL ru := ac.Rails.URL
if strings.HasPrefix(ru, "memcache:") { if strings.HasPrefix(ru, "memcache:") {
return railsMemcacheHandler(authc, next) return RailsMemcacheHandler(ac, next)
} }
if strings.HasPrefix(ru, "redis:") { if strings.HasPrefix(ru, "redis:") {
return railsRedisHandler(authc, next) return RailsRedisHandler(ac, next)
} }
return railsCookieHandler(authc, next) return RailsCookieHandler(ac, next)
} }
func railsRedisHandler(authc configAuth, next http.Handler) http.HandlerFunc { func RailsRedisHandler(ac *config.Auth, next http.Handler) (http.HandlerFunc, error) {
cookie := authc.Cookie cookie := ac.Cookie
if len(cookie) == 0 { if len(cookie) == 0 {
errlog.Fatal().Msg("no auth.cookie defined") return nil, fmt.Errorf("no auth.cookie defined")
} }
if len(authc.Rails.URL) == 0 { if len(ac.Rails.URL) == 0 {
errlog.Fatal().Msg("no auth.rails.url defined") return nil, fmt.Errorf("no auth.rails.url defined")
} }
rp := &redis.Pool{ rp := &redis.Pool{
MaxIdle: authc.Rails.MaxIdle, MaxIdle: ac.Rails.MaxIdle,
MaxActive: authc.Rails.MaxActive, MaxActive: ac.Rails.MaxActive,
Dial: func() (redis.Conn, error) { Dial: func() (redis.Conn, error) {
c, err := redis.DialURL(authc.Rails.URL) c, err := redis.DialURL(ac.Rails.URL)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Send() return nil, err
} }
pwd := authc.Rails.Password pwd := ac.Rails.Password
if len(pwd) != 0 { if len(pwd) != 0 {
if _, err := c.Do("AUTH", pwd); err != nil { if _, err := c.Do("AUTH", pwd); err != nil {
errlog.Fatal().Err(err).Send() return nil, err
} }
} }
return c, err
return c, nil
}, },
} }
@ -76,24 +80,25 @@ func railsRedisHandler(authc configAuth, next http.Handler) http.HandlerFunc {
return return
} }
ctx := context.WithValue(r.Context(), userIDKey, userID) ctx := context.WithValue(r.Context(), core.UserIDKey, userID)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
} }, nil
} }
func railsMemcacheHandler(authc configAuth, next http.Handler) http.HandlerFunc { func RailsMemcacheHandler(ac *config.Auth, next http.Handler) (http.HandlerFunc, error) {
cookie := authc.Cookie cookie := ac.Cookie
if len(cookie) == 0 { if len(cookie) == 0 {
errlog.Fatal().Msg("no auth.cookie defined") return nil, fmt.Errorf("no auth.cookie defined")
} }
if len(authc.Rails.URL) == 0 { if len(ac.Rails.URL) == 0 {
errlog.Fatal().Msg("no auth.rails.url defined") return nil, fmt.Errorf("no auth.rails.url defined")
} }
rURL, err := url.Parse(authc.Rails.URL) rURL, err := url.Parse(ac.Rails.URL)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Send() return nil, err
} }
mc := memcache.New(rURL.Host) mc := memcache.New(rURL.Host)
@ -118,49 +123,49 @@ func railsMemcacheHandler(authc configAuth, next http.Handler) http.HandlerFunc
return return
} }
ctx := context.WithValue(r.Context(), userIDKey, userID) ctx := context.WithValue(r.Context(), core.UserIDKey, userID)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
} }, nil
} }
func railsCookieHandler(authc configAuth, next http.Handler) http.HandlerFunc { func RailsCookieHandler(ac *config.Auth, next http.Handler) (http.HandlerFunc, error) {
cookie := authc.Cookie cookie := ac.Cookie
if len(cookie) == 0 { if len(cookie) == 0 {
errlog.Fatal().Msg("no auth.cookie defined") return nil, fmt.Errorf("no auth.cookie defined")
} }
ra, err := railsAuth(authc) ra, err := railsAuth(ac)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Send() return nil, err
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ck, err := r.Cookie(cookie) ck, err := r.Cookie(cookie)
if err != nil || len(ck.Value) == 0 { if err != nil || len(ck.Value) == 0 {
logger.Warn().Err(err).Msg("rails cookie missing") // logger.Warn().Err(err).Msg("rails cookie missing")
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
userID, err := ra.ParseCookie(ck.Value) userID, err := ra.ParseCookie(ck.Value)
if err != nil { if err != nil {
logger.Warn().Err(err).Msg("failed to parse rails cookie") // logger.Warn().Err(err).Msg("failed to parse rails cookie")
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
ctx := context.WithValue(r.Context(), userIDKey, userID) ctx := context.WithValue(r.Context(), core.UserIDKey, userID)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
} }, nil
} }
func railsAuth(authc configAuth) (*rails.Auth, error) { func railsAuth(ac *config.Auth) (*rails.Auth, error) {
secret := authc.Rails.SecretKeyBase secret := ac.Rails.SecretKeyBase
if len(secret) == 0 { if len(secret) == 0 {
return nil, errors.New("no auth.rails.secret_key_base defined") return nil, errors.New("no auth.rails.secret_key_base defined")
} }
version := authc.Rails.Version version := ac.Rails.Version
if len(version) == 0 { if len(version) == 0 {
return nil, errors.New("no auth.rails.version defined") return nil, errors.New("no auth.rails.version defined")
} }
@ -170,16 +175,16 @@ func railsAuth(authc configAuth) (*rails.Auth, error) {
return nil, err return nil, err
} }
if len(authc.Rails.Salt) != 0 { if len(ac.Rails.Salt) != 0 {
ra.Salt = authc.Rails.Salt ra.Salt = ac.Rails.Salt
} }
if len(authc.Rails.SignSalt) != 0 { if len(ac.Rails.SignSalt) != 0 {
ra.SignSalt = authc.Rails.SignSalt ra.SignSalt = ac.Rails.SignSalt
} }
if len(authc.Rails.AuthSalt) != 0 { if len(ac.Rails.AuthSalt) != 0 {
ra.AuthSalt = authc.Rails.AuthSalt ra.AuthSalt = ac.Rails.AuthSalt
} }
return ra, nil return ra, nil

View File

@ -3,6 +3,7 @@ package migrate
import ( import (
"bytes" "bytes"
"context" "context"
"database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -12,7 +13,6 @@ import (
"strings" "strings"
"text/template" "text/template"
"github.com/jackc/pgx/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -62,7 +62,7 @@ type MigratorOptions struct {
} }
type Migrator struct { type Migrator struct {
conn *pgx.Conn db *sql.DB
versionTable string versionTable string
options *MigratorOptions options *MigratorOptions
Migrations []*Migration Migrations []*Migration
@ -70,12 +70,12 @@ type Migrator struct {
Data map[string]interface{} // Data available to use in migrations Data map[string]interface{} // Data available to use in migrations
} }
func NewMigrator(conn *pgx.Conn, versionTable string) (m *Migrator, err error) { func NewMigrator(db *sql.DB, versionTable string) (m *Migrator, err error) {
return NewMigratorEx(conn, versionTable, &MigratorOptions{MigratorFS: defaultMigratorFS{}}) return NewMigratorEx(db, versionTable, &MigratorOptions{MigratorFS: defaultMigratorFS{}})
} }
func NewMigratorEx(conn *pgx.Conn, versionTable string, opts *MigratorOptions) (m *Migrator, err error) { func NewMigratorEx(db *sql.DB, versionTable string, opts *MigratorOptions) (m *Migrator, err error) {
m = &Migrator{conn: conn, versionTable: versionTable, options: opts} m = &Migrator{db: db, versionTable: versionTable, options: opts}
err = m.ensureSchemaVersionTableExists() err = m.ensureSchemaVersionTableExists()
m.Migrations = make([]*Migration, 0) m.Migrations = make([]*Migration, 0)
m.Data = make(map[string]interface{}) m.Data = make(map[string]interface{})
@ -254,14 +254,13 @@ func (m *Migrator) Migrate() error {
// MigrateTo migrates to targetVersion // MigrateTo migrates to targetVersion
func (m *Migrator) MigrateTo(targetVersion int32) (err error) { func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
ctx := context.Background()
// Lock to ensure multiple migrations cannot occur simultaneously // Lock to ensure multiple migrations cannot occur simultaneously
lockNum := int64(9628173550095224) // arbitrary random number lockNum := int64(9628173550095224) // arbitrary random number
if _, lockErr := m.conn.Exec(ctx, "select pg_try_advisory_lock($1)", lockNum); lockErr != nil { if _, lockErr := m.db.Exec("select pg_try_advisory_lock($1)", lockNum); lockErr != nil {
return lockErr return lockErr
} }
defer func() { defer func() {
_, unlockErr := m.conn.Exec(ctx, "select pg_advisory_unlock($1)", lockNum) _, unlockErr := m.db.Exec("select pg_advisory_unlock($1)", lockNum)
if err == nil && unlockErr != nil { if err == nil && unlockErr != nil {
err = unlockErr err = unlockErr
} }
@ -310,11 +309,11 @@ func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
ctx := context.Background() ctx := context.Background()
tx, err := m.conn.Begin(ctx) tx, err := m.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback(ctx) //nolint: errcheck defer tx.Rollback() //nolint: errcheck
// Fire on start callback // Fire on start callback
if m.OnStart != nil { if m.OnStart != nil {
@ -322,7 +321,7 @@ func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
} }
// Execute the migration // Execute the migration
_, err = tx.Exec(ctx, sql) _, err = tx.Exec(sql)
if err != nil { if err != nil {
// if err, ok := err.(pgx.PgError); ok { // if err, ok := err.(pgx.PgError); ok {
// return MigrationPgError{Sql: sql, PgError: err} // return MigrationPgError{Sql: sql, PgError: err}
@ -336,12 +335,12 @@ func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
// } // }
// Add one to the version // Add one to the version
_, err = tx.Exec(ctx, "update "+m.versionTable+" set version=$1", sequence) _, err = tx.Exec("update "+m.versionTable+" set version=$1", sequence)
if err != nil { if err != nil {
return err return err
} }
err = tx.Commit(ctx) err = tx.Commit()
if err != nil { if err != nil {
return err return err
} }
@ -353,14 +352,13 @@ func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
} }
func (m *Migrator) GetCurrentVersion() (v int32, err error) { func (m *Migrator) GetCurrentVersion() (v int32, err error) {
err = m.conn.QueryRow(context.Background(), err = m.db.QueryRow("select version from " + m.versionTable).Scan(&v)
"select version from "+m.versionTable).Scan(&v)
return v, err return v, err
} }
func (m *Migrator) ensureSchemaVersionTableExists() (err error) { func (m *Migrator) ensureSchemaVersionTableExists() (err error) {
_, err = m.conn.Exec(context.Background(), fmt.Sprintf(` _, err = m.db.Exec(fmt.Sprintf(`
create table if not exists %s(version int4 not null); create table if not exists %s(version int4 not null);
insert into %s(version) insert into %s(version)

View File

@ -116,7 +116,7 @@ func Do(log func(string, ...interface{}), additional ...dir) error {
continue continue
} }
logger.Info().Msgf("Reloading, file changed detected '%s'", event) log("INF Reloading, file changed detected: %s", event)
var trigger bool var trigger bool
switch runtime.GOOS { switch runtime.GOOS {
@ -172,7 +172,7 @@ func Do(log func(string, ...interface{}), additional ...dir) error {
func ReExec() { func ReExec() {
err := syscall.Exec(binSelf, append([]string{binSelf}, os.Args[1:]...), os.Environ()) err := syscall.Exec(binSelf, append([]string{binSelf}, os.Args[1:]...), os.Environ())
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("cannot restart") log.Fatalf("ERR cannot restart: %s", err)
} }
} }

File diff suppressed because one or more lines are too long

View File

@ -11,49 +11,12 @@ import (
rice "github.com/GeertJohan/go.rice" rice "github.com/GeertJohan/go.rice"
"github.com/NYTimes/gziphandler" "github.com/NYTimes/gziphandler"
"github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/cmd/internal/serv/internal/auth"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/config"
) )
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { func initWatcher() {
di, err := psql.GetDBInfo(db) cpath := conf.ConfigPathUsed()
if err != nil {
return nil, nil, err
}
if err = addTables(c, di); err != nil {
return nil, nil, err
}
if err = addForeignKeys(c, di); err != nil {
return nil, nil, err
}
schema, err = psql.NewDBSchema(di, c.getAliasMap())
if err != nil {
return nil, nil, err
}
qc, err := qcode.NewCompiler(qcode.Config{
Blocklist: c.DB.Blocklist,
})
if err != nil {
return nil, nil, err
}
if err := addRoles(c, qc); err != nil {
return nil, nil, err
}
pc := psql.NewCompiler(psql.Config{
Schema: schema,
Vars: c.DB.Vars,
})
return qc, pc, nil
}
func initWatcher(cpath string) {
if conf != nil && !conf.WatchAndReload { if conf != nil && !conf.WatchAndReload {
return return
} }
@ -66,9 +29,9 @@ func initWatcher(cpath string) {
} }
go func() { go func() {
err := Do(logger.Printf, d) err := Do(log.Printf, d)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Send() log.Fatalf("ERR %s", err)
} }
}() }()
} }
@ -103,7 +66,7 @@ func startHTTP() {
routes, err := routeHandler() routes, err := routeHandler()
if err != nil { if err != nil {
errlog.Fatal().Err(err).Send() log.Fatalf("ERR %s", err)
} }
srv := &http.Server{ srv := &http.Server{
@ -121,7 +84,7 @@ func startHTTP() {
<-sigint <-sigint
if err := srv.Shutdown(context.Background()); err != nil { if err := srv.Shutdown(context.Background()); err != nil {
errlog.Error().Err(err).Msg("shutdown signal received") log.Fatalln("INF shutdown signal received")
} }
close(idleConnsClosed) close(idleConnsClosed)
}() }()
@ -130,16 +93,13 @@ func startHTTP() {
db.Close() db.Close()
}) })
logger.Info(). log.Printf("INF version: %s, git-branch: %s, host-port: %s, app-name: %s, env: %s\n",
Str("version", version). version, gitBranch, hostPort, appName, env)
Str("git_branch", gitBranch).
Str("host_post", hostPort). log.Printf("INF %s started\n", serverName)
Str("app_name", appName).
Str("env", env).
Msgf("%s listening", serverName)
if err := srv.ListenAndServe(); err != http.ErrServerClosed { if err := srv.ListenAndServe(); err != http.ErrServerClosed {
errlog.Error().Err(err).Msg("server closed") log.Fatalln("INF server closed")
} }
<-idleConnsClosed <-idleConnsClosed
@ -162,7 +122,7 @@ func routeHandler() (http.Handler, error) {
} }
if conf.WebUI { if conf.WebUI {
routes["/"] = http.FileServer(rice.MustFindBox("../web/build").HTTPBox()) routes["/"] = http.FileServer(rice.MustFindBox("./web/build").HTTPBox())
} }
if conf.HTTPGZip { if conf.HTTPGZip {
@ -190,29 +150,31 @@ func setActionRoutes(routes map[string]http.Handler) error {
for _, a := range conf.Actions { for _, a := range conf.Actions {
var fn http.Handler var fn http.Handler
fn, err = newAction(a) fn, err = newAction(&a)
if err != nil { if err != nil {
break break
} }
p := fmt.Sprintf("/api/v1/actions/%s", strings.ToLower(a.Name)) p := fmt.Sprintf("/api/v1/actions/%s", strings.ToLower(a.Name))
if authc, ok := findAuth(a.AuthName); ok { if ac := findAuth(a.AuthName); ac != nil {
routes[p] = withAuth(fn, authc) routes[p], err = auth.WithAuth(fn, ac)
} else { } else {
routes[p] = fn routes[p] = fn
} }
if err != nil {
return err
}
} }
return nil return nil
} }
func findAuth(name string) (configAuth, bool) { func findAuth(name string) *config.Auth {
var authc configAuth
for _, a := range conf.Auths { for _, a := range conf.Auths {
if strings.EqualFold(a.Name, name) { if strings.EqualFold(a.Name, name) {
return a, true return &a
} }
} }
return authc, false return nil
} }

View File

@ -0,0 +1,43 @@
package serv
// import (
// "context"
// "github.com/jackc/pgx/v4"
// "github.com/rs/zerolog"
// )
// type Logger struct {
// logger zerolog.Logger
// }
// // NewLogger accepts a zerolog.Logger as input and returns a new custom pgx
// // logging fascade as output.
// func NewSQLLogger(logger zerolog.Logger) *Logger {
// return &Logger{
// logger: // logger.With().Logger(),
// }
// }
// func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
// var zlevel zerolog.Level
// switch level {
// case pgx.LogLevelNone:
// zlevel = zerolog.NoLevel
// case pgx.LogLevelError:
// zlevel = zerolog.ErrorLevel
// case pgx.LogLevelWarn:
// zlevel = zerolog.WarnLevel
// case pgx.LogLevelDebug, pgx.LogLevelInfo:
// zlevel = zerolog.DebugLevel
// default:
// zlevel = zerolog.DebugLevel
// }
// if sql, ok := data["sql"]; ok {
// delete(data, "sql")
// pl.// logger.WithLevel(zlevel).Fields(data).Msg(sql.(string))
// } else {
// pl.// logger.WithLevel(zlevel).Fields(data).Msg(msg)
// }
// }

View File

@ -2,7 +2,7 @@ app_name: "{% app_name %} Development"
host_port: 0.0.0.0:8080 host_port: 0.0.0.0:8080
web_ui: true web_ui: true
# debug, info, warn, error, fatal, panic # debug, error, warn, info
log_level: "info" log_level: "info"
# enable or disable http compression (uses gzip) # enable or disable http compression (uses gzip)
@ -30,7 +30,8 @@ reload_on_config_change: true
# seed_file: seed.js # seed_file: seed.js
# Path pointing to where the migrations can be found # Path pointing to where the migrations can be found
migrations_path: ./config/migrations # this must be a relative path under the config path
migrations_path: ./migrations
# Secret key for general encryption operations like # Secret key for general encryption operations like
# encrypting the cursor data # encrypting the cursor data

View File

@ -6,7 +6,7 @@ app_name: "{% app_name %} Production"
host_port: 0.0.0.0:8080 host_port: 0.0.0.0:8080
web_ui: false web_ui: false
# debug, info, warn, error, fatal, panic, disable # debug, error, warn, info
log_level: "warn" log_level: "warn"
# enable or disable http compression (uses gzip) # enable or disable http compression (uses gzip)

View File

@ -5,6 +5,7 @@ import (
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"io" "io"
"os"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -23,14 +24,6 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
return v return v
} }
// nolint: errcheck
func stmtHash(name string, role string) string {
h := sha1.New()
io.WriteString(h, strings.ToLower(name))
io.WriteString(h, role)
return hex.EncodeToString(h.Sum(nil))
}
// nolint: errcheck // nolint: errcheck
func gqlHash(b string, vars []byte, role string) string { func gqlHash(b string, vars []byte, role string) string {
b = strings.TrimSpace(b) b = strings.TrimSpace(b)
@ -117,25 +110,19 @@ func al(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
} }
func findStmt(role string, stmts []stmt) *stmt {
for i := range stmts {
if stmts[i].role.Name != role {
continue
}
return &stmts[i]
}
return nil
}
func fatalInProd(err error, msg string) { func fatalInProd(err error, msg string) {
var wg sync.WaitGroup var wg sync.WaitGroup
if !isDev() { if !isDev() {
errlog.Fatal().Err(err).Msg(msg) log.Fatalf("ERR %s: %s", msg, err)
} }
errlog.Error().Err(err).Msg(msg) log.Printf("ERR %s: %s", msg, err)
wg.Add(1) wg.Add(1)
wg.Wait() wg.Wait()
} }
func isDev() bool {
return strings.HasPrefix(os.Getenv("GO_ENV"), "dev")
}

View File

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 15 KiB

View File

Before

Width:  |  Height:  |  Size: 2.6 KiB

After

Width:  |  Height:  |  Size: 2.6 KiB

45
cmd/main.go Normal file
View File

@ -0,0 +1,45 @@
// Main package for the Super Graph service and command line tooling
/*
Super Graph
For documentation, visit https://supergraph.dev
Commit SHA-1 :
Commit timestamp :
Branch :
Go version : go1.14
Licensed under the Apache Public License 2.0
Copyright 2020, Vikram Rangnekar.
Usage:
super-graph [command]
Available Commands:
conf:dump Dump config to file
db:create Create database
db:drop Drop database
db:migrate Migrate the database
db:new Generate a new migration
db:reset Reset database
db:seed Run the seed script to seed the database
db:setup Setup database
db:status Print current migration status
help Help about any command
new Create a new application
serv Run the super-graph service
version Super Graph binary version information
Flags:
-h, --help help for super-graph
--path string path to config files (default "./config")
Use "super-graph [command] --help" for more information about a command.
*/
package main
import "github.com/dosco/super-graph/cmd/internal/serv"
func main() {
serv.Cmd()
}

View File

@ -1,755 +0,0 @@
# http://localhost:8080/
variables {
"data": [
{
"name": "Protect Ya Neck",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Enter the Wu-Tang",
"created_at": "now",
"updated_at": "now"
}
]
}
mutation {
products(insert: $data) {
id
name
}
}
variables {
"update": {
"name": "Wu-Tang",
"description": "No description needed"
},
"product_id": 1
}
mutation {
products(id: $product_id, update: $update) {
id
name
description
}
}
query {
users {
id
email
picture: avatar
products(limit: 2, where: {price: {gt: 10}}) {
id
name
description
}
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
mutation {
products(id: 199, delete: true) {
id
name
}
}
query {
products {
id
name
user {
email
}
}
}
variables {
"data": {
"product_id": 5
}
}
mutation {
products(id: $product_id, delete: true) {
id
name
}
}
query {
products {
id
name
price
users {
email
}
}
}
variables {
"data": {
"email": "gfk@myspace.com",
"full_name": "Ghostface Killah",
"created_at": "now",
"updated_at": "now"
}
}
mutation {
user(insert: $data) {
id
}
}
variables {
"update": {
"name": "Helloo",
"description": "World \u003c\u003e"
},
"user": 123
}
mutation {
products(id: 5, update: $update) {
id
name
description
}
}
variables {
"data": {
"name": "WOOO",
"price": 50.5
}
}
mutation {
products(insert: $data) {
id
name
}
}
query getProducts {
products {
id
name
price
description
}
}
query {
deals {
id
name
price
}
}
variables {
"beer": "smoke"
}
query beerSearch {
products(search: $beer) {
id
name
search_rank
search_headline_description
}
}
query {
user {
id
full_name
}
}
variables {
"data": {
"email": "goo1@rug.com",
"full_name": "The Dude",
"created_at": "now",
"updated_at": "now",
"product": {
"name": "Apple",
"price": 1.25,
"created_at": "now",
"updated_at": "now"
}
}
}
mutation {
user(insert: $data) {
id
full_name
email
product {
id
name
price
}
}
}
variables {
"data": {
"email": "goo12@rug.com",
"full_name": "The Dude",
"created_at": "now",
"updated_at": "now",
"product": [
{
"name": "Banana 1",
"price": 1.1,
"created_at": "now",
"updated_at": "now"
},
{
"name": "Banana 2",
"price": 2.2,
"created_at": "now",
"updated_at": "now"
}
]
}
}
mutation {
user(insert: $data) {
id
full_name
email
products {
id
name
price
}
}
}
variables {
"data": {
"name": "Banana 3",
"price": 1.1,
"created_at": "now",
"updated_at": "now",
"user": {
"email": "a2@a.com",
"full_name": "The Dude",
"created_at": "now",
"updated_at": "now"
}
}
}
mutation {
products(insert: $data) {
id
name
price
user {
id
full_name
email
}
}
}
variables {
"update": {
"name": "my_name",
"description": "my_desc"
}
}
mutation {
product(id: 15, update: $update, where: {id: {eq: 1}}) {
id
name
}
}
variables {
"update": {
"name": "my_name",
"description": "my_desc"
}
}
mutation {
product(update: $update, where: {id: {eq: 1}}) {
id
name
}
}
variables {
"update": {
"name": "my_name 2",
"description": "my_desc 2"
}
}
mutation {
product(update: $update, where: {id: {eq: 1}}) {
id
name
description
}
}
variables {
"data": {
"sale_type": "tuutuu",
"quantity": 5,
"due_date": "now",
"customer": {
"email": "thedude1@rug.com",
"full_name": "The Dude"
},
"product": {
"name": "Apple",
"price": 1.25
}
}
}
mutation {
purchase(update: $data, id: 5) {
sale_type
quantity
due_date
customer {
id
full_name
email
}
product {
id
name
price
}
}
}
variables {
"data": {
"email": "thedude@rug.com",
"full_name": "The Dude",
"created_at": "now",
"updated_at": "now",
"product": {
"where": {
"id": 2
},
"name": "Apple",
"price": 1.25,
"created_at": "now",
"updated_at": "now"
}
}
}
mutation {
user(update: $data, where: {id: {eq: 8}}) {
id
full_name
email
product {
id
name
price
}
}
}
variables {
"data": {
"email": "thedude@rug.com",
"full_name": "The Dude",
"created_at": "now",
"updated_at": "now",
"product": {
"where": {
"id": 2
},
"name": "Apple",
"price": 1.25,
"created_at": "now",
"updated_at": "now"
}
}
}
query {
user(where: {id: {eq: 8}}) {
id
product {
id
name
price
}
}
}
variables {
"data": {
"name": "Apple",
"price": 1.25,
"created_at": "now",
"updated_at": "now",
"user": {
"email": "thedude@rug.com"
}
}
}
query {
user {
email
}
}
variables {
"data": {
"name": "Apple",
"price": 1.25,
"created_at": "now",
"updated_at": "now",
"user": {
"email": "booboo@demo.com"
}
}
}
mutation {
product(update: $data, id: 6) {
id
name
user {
id
full_name
email
}
}
}
variables {
"data": {
"name": "Apple",
"price": 1.25,
"created_at": "now",
"updated_at": "now",
"user": {
"email": "booboo@demo.com"
}
}
}
query {
product(id: 6) {
id
name
user {
id
full_name
email
}
}
}
variables {
"data": {
"email": "thedude123@rug.com",
"full_name": "The Dude",
"created_at": "now",
"updated_at": "now",
"product": {
"connect": {
"id": 7
},
"disconnect": {
"id": 8
}
}
}
}
mutation {
user(update: $data, id: 6) {
id
full_name
email
product {
id
name
price
}
}
}
variables {
"data": {
"name": "Apple",
"price": 1.25,
"created_at": "now",
"updated_at": "now",
"user": {
"connect": {
"id": 5,
"email": "test@test.com"
}
}
}
}
mutation {
product(update: $data, id: 9) {
id
name
user {
id
full_name
email
}
}
}
variables {
"data": {
"email": "thed44ude@rug.com",
"full_name": "The Dude",
"created_at": "now",
"updated_at": "now",
"product": {
"connect": {
"id": 5
}
}
}
}
mutation {
user(insert: $data) {
id
full_name
email
product {
id
name
price
}
}
}
variables {
"data": {
"name": "Apple",
"price": 1.25,
"created_at": "now",
"updated_at": "now",
"user": {
"connect": {
"id": 5
}
}
}
}
mutation {
product(insert: $data) {
id
name
user {
id
full_name
email
}
}
}
variables {
"data": [
{
"name": "Apple",
"price": 1.25,
"created_at": "now",
"updated_at": "now",
"user": {
"connect": {
"id": 6
}
}
},
{
"name": "Coconut",
"price": 2.25,
"created_at": "now",
"updated_at": "now",
"user": {
"connect": {
"id": 3
}
}
}
]
}
mutation {
products(insert: $data) {
id
name
user {
id
full_name
email
}
}
}
variables {
"data": [
{
"name": "Apple",
"price": 1.25,
"created_at": "now",
"updated_at": "now"
},
{
"name": "Coconut",
"price": 2.25,
"created_at": "now",
"updated_at": "now"
}
]
}
mutation {
products(insert: $data) {
id
name
user {
id
full_name
email
}
}
}
variables {
"data": {
"name": "Apple",
"price": 1.25,
"user": {
"connect": {
"id": 5,
"email": "test@test.com"
}
}
}
}
mutation {
product(update: $data, id: 9) {
id
name
user {
id
full_name
email
}
}
}
variables {
"data": {
"name": "Apple",
"price": 1.25,
"user": {
"connect": {
"id": 5
}
}
}
}
mutation {
product(update: $data, id: 9) {
id
name
user {
id
full_name
email
}
}
}
variables {
"data": {
"name": "Apple",
"price": 1.25,
"user": {
"disconnect": {
"id": 5
}
}
}
}
mutation {
product(update: $data, id: 9) {
id
name
user_id
}
}
variables {
"data": {
"name": "Apple",
"price": 1.25,
"user": {
"disconnect": {
"id": 5
}
}
}
}
mutation {
product(update: $data, id: 2) {
id
name
user_id
}
}

505
config/config.go Normal file
View File

@ -0,0 +1,505 @@
// Package config provides the config values needed for Super Graph
// For detailed documentation visit https://supergraph.dev
package config
import (
"fmt"
"log"
"os"
"path"
"strings"
"time"
"github.com/gobuffalo/flect"
"github.com/spf13/viper"
)
const (
LogLevelNone int = iota
LogLevelInfo
LogLevelWarn
LogLevelError
LogLevelDebug
)
// Config struct holds the Super Graph config values
type Config struct {
Core `mapstructure:",squash"`
Serv `mapstructure:",squash"`
vi *viper.Viper
log *log.Logger
logLevel int
roles map[string]*Role
abacEnabled bool
valid bool
}
// Core struct contains core specific config value
type Core struct {
Env string
Production bool
LogLevel string `mapstructure:"log_level"`
SecretKey string `mapstructure:"secret_key"`
SetUserID bool `mapstructure:"set_user_id"`
Vars map[string]string `mapstructure:"variables"`
Blocklist []string
Tables []Table
RolesQuery string `mapstructure:"roles_query"`
Roles []Role
}
// Serv struct contains config values used by the Super Graph service
type Serv struct {
AppName string `mapstructure:"app_name"`
HostPort string `mapstructure:"host_port"`
Host string
Port string
HTTPGZip bool `mapstructure:"http_compress"`
WebUI bool `mapstructure:"web_ui"`
EnableTracing bool `mapstructure:"enable_tracing"`
UseAllowList bool `mapstructure:"use_allow_list"`
WatchAndReload bool `mapstructure:"reload_on_config_change"`
AuthFailBlock bool `mapstructure:"auth_fail_block"`
SeedFile string `mapstructure:"seed_file"`
MigrationsPath string `mapstructure:"migrations_path"`
AllowedOrigins []string `mapstructure:"cors_allowed_origins"`
DebugCORS bool `mapstructure:"cors_debug"`
Inflections map[string]string
Auth Auth
Auths []Auth
DB struct {
Type string
Host string
Port uint16
DBName string
User string
Password string
Schema string
PoolSize int32 `mapstructure:"pool_size"`
MaxRetries int `mapstructure:"max_retries"`
PingTimeout time.Duration `mapstructure:"ping_timeout"`
} `mapstructure:"database"`
Actions []Action
}
// Auth struct contains authentication related config values used by the Super Graph service
type Auth struct {
Name string
Type string
Cookie string
CredsInHeader bool `mapstructure:"creds_in_header"`
Rails struct {
Version string
SecretKeyBase string `mapstructure:"secret_key_base"`
URL string
Password string
MaxIdle int `mapstructure:"max_idle"`
MaxActive int `mapstructure:"max_active"`
Salt string
SignSalt string `mapstructure:"sign_salt"`
AuthSalt string `mapstructure:"auth_salt"`
}
JWT struct {
Provider string
Secret string
PubKeyFile string `mapstructure:"public_key_file"`
PubKeyType string `mapstructure:"public_key_type"`
}
Header struct {
Name string
Value string
Exists bool
}
}
// Column struct defines a database column
type Column struct {
Name string
Type string
ForeignKey string `mapstructure:"related_to"`
}
// Table struct defines a database table
type Table struct {
Name string
Table string
Blocklist []string
Remotes []Remote
Columns []Column
}
// Remote struct defines a remote API endpoint
type Remote struct {
Name string
ID string
Path string
URL string
Debug bool
PassHeaders []string `mapstructure:"pass_headers"`
SetHeaders []struct {
Name string
Value string
} `mapstructure:"set_headers"`
}
// Query struct contains access control values for query operations
type Query struct {
Limit int
Filters []string
Columns []string
DisableFunctions bool `mapstructure:"disable_functions"`
Block bool
}
// Insert struct contains access control values for insert operations
type Insert struct {
Filters []string
Columns []string
Presets map[string]string
Block bool
}
// Insert struct contains access control values for update operations
type Update struct {
Filters []string
Columns []string
Presets map[string]string
Block bool
}
// Delete struct contains access control values for delete operations
type Delete struct {
Filters []string
Columns []string
Block bool
}
// RoleTable struct contains role specific access control values for a database table
type RoleTable struct {
Name string
Query Query
Insert Insert
Update Update
Delete Delete
}
// Role struct contains role specific access control values for for all database tables
type Role struct {
Name string
Match string
Tables []RoleTable
tablesMap map[string]*RoleTable
}
// Action struct contains config values for a Super Graph service action
type Action struct {
Name string
SQL string
AuthName string `mapstructure:"auth_name"`
}
// NewConfig function reads in the config file for the environment specified in the GO_ENV
// environment variable. This is the best way to create a new Super Graph config.
func NewConfig(path string) (*Config, error) {
return NewConfigWithLogger(path, log.New(os.Stdout, "", 0))
}
// NewConfigWithLogger function reads in the config file for the environment specified in the GO_ENV
// environment variable. This is the best way to create a new Super Graph config.
func NewConfigWithLogger(path string, logger *log.Logger) (*Config, error) {
vi := newViper(path, GetConfigName())
if err := vi.ReadInConfig(); err != nil {
return nil, err
}
inherits := vi.GetString("inherits")
if len(inherits) != 0 {
vi = newViper(path, inherits)
if err := vi.ReadInConfig(); err != nil {
return nil, err
}
if vi.IsSet("inherits") {
return nil, fmt.Errorf("inherited config (%s) cannot itself inherit (%s)",
inherits,
vi.GetString("inherits"))
}
vi.SetConfigName(GetConfigName())
if err := vi.MergeInConfig(); err != nil {
return nil, err
}
}
c := &Config{log: logger, vi: vi}
if err := vi.Unmarshal(&c); err != nil {
return nil, fmt.Errorf("failed to decode config, %v", err)
}
if err := c.init(); err != nil {
return nil, fmt.Errorf("failed to initialize config: %w", err)
}
return c, nil
}
// NewConfigFrom function initializes a Config struct that you manually created
// so it can be used by Super Graph
func NewConfigFrom(c *Config, configPath string, logger *log.Logger) (*Config, error) {
c.vi = newViper(configPath, GetConfigName())
c.log = logger
if err := c.init(); err != nil {
return nil, err
}
return c, nil
}
func newViper(configPath, filename string) *viper.Viper {
vi := viper.New()
vi.SetEnvPrefix("SG")
vi.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
vi.AutomaticEnv()
vi.SetConfigName(filename)
vi.AddConfigPath(configPath)
vi.AddConfigPath("./config")
vi.SetDefault("host_port", "0.0.0.0:8080")
vi.SetDefault("web_ui", false)
vi.SetDefault("enable_tracing", false)
vi.SetDefault("auth_fail_block", "always")
vi.SetDefault("seed_file", "seed.js")
vi.SetDefault("database.type", "postgres")
vi.SetDefault("database.host", "localhost")
vi.SetDefault("database.port", 5432)
vi.SetDefault("database.user", "postgres")
vi.SetDefault("database.schema", "public")
vi.SetDefault("env", "development")
vi.BindEnv("env", "GO_ENV") //nolint: errcheck
vi.BindEnv("host", "HOST") //nolint: errcheck
vi.BindEnv("port", "PORT") //nolint: errcheck
vi.SetDefault("auth.rails.max_idle", 80)
vi.SetDefault("auth.rails.max_active", 12000)
return vi
}
func (c *Config) init() error {
switch c.Core.LogLevel {
case "debug":
c.logLevel = LogLevelDebug
case "error":
c.logLevel = LogLevelError
case "warn":
c.logLevel = LogLevelWarn
case "info":
c.logLevel = LogLevelInfo
default:
c.logLevel = LogLevelNone
}
if c.UseAllowList {
c.Production = true
}
for k, v := range c.Inflections {
flect.AddPlural(k, v)
}
// Tables: Validate and sanitize
tm := make(map[string]struct{})
for i := 0; i < len(c.Tables); i++ {
t := &c.Tables[i]
t.Name = flect.Pluralize(strings.ToLower(t.Name))
if _, ok := tm[t.Name]; ok {
c.Tables = append(c.Tables[:i], c.Tables[i+1:]...)
c.log.Printf("WRN duplicate table found: %s", t.Name)
}
tm[t.Name] = struct{}{}
t.Table = flect.Pluralize(strings.ToLower(t.Table))
}
// Variables: Validate and sanitize
for k, v := range c.Vars {
c.Vars[k] = sanitize(v)
}
// Roles: validate and sanitize
c.RolesQuery = sanitize(c.RolesQuery)
c.roles = make(map[string]*Role)
for i := 0; i < len(c.Roles); i++ {
r := &c.Roles[i]
r.Name = strings.ToLower(r.Name)
if _, ok := c.roles[r.Name]; ok {
c.Roles = append(c.Roles[:i], c.Roles[i+1:]...)
c.log.Printf("WRN duplicate role found: %s", r.Name)
}
r.Match = sanitize(r.Match)
r.tablesMap = make(map[string]*RoleTable)
for n, table := range r.Tables {
r.tablesMap[table.Name] = &r.Tables[n]
}
c.roles[r.Name] = r
}
if _, ok := c.roles["user"]; !ok {
u := Role{Name: "user"}
c.Roles = append(c.Roles, u)
c.roles["user"] = &u
}
if _, ok := c.roles["anon"]; !ok {
c.log.Printf("WRN unauthenticated requests will be blocked. no role 'anon' defined")
c.AuthFailBlock = true
}
if len(c.RolesQuery) == 0 {
c.log.Printf("WRN roles_query not defined: attribute based access control disabled")
}
if len(c.RolesQuery) == 0 {
c.abacEnabled = false
} else {
switch len(c.Roles) {
case 0, 1:
c.abacEnabled = false
case 2:
_, ok1 := c.roles["anon"]
_, ok2 := c.roles["user"]
c.abacEnabled = !(ok1 && ok2)
default:
c.abacEnabled = true
}
}
// Auths: validate and sanitize
am := make(map[string]struct{})
for i := 0; i < len(c.Auths); i++ {
a := &c.Auths[i]
a.Name = strings.ToLower(a.Name)
if _, ok := am[a.Name]; ok {
c.Auths = append(c.Auths[:i], c.Auths[i+1:]...)
c.log.Printf("WRN duplicate auth found: %s", a.Name)
}
am[a.Name] = struct{}{}
}
// Actions: validate and sanitize
axm := make(map[string]struct{})
for i := 0; i < len(c.Actions); i++ {
a := &c.Actions[i]
a.Name = strings.ToLower(a.Name)
a.AuthName = strings.ToLower(a.AuthName)
if _, ok := axm[a.Name]; ok {
c.Actions = append(c.Actions[:i], c.Actions[i+1:]...)
c.log.Printf("WRN duplicate action found: %s", a.Name)
}
if _, ok := am[a.AuthName]; !ok {
c.Actions = append(c.Actions[:i], c.Actions[i+1:]...)
c.log.Printf("WRN invalid auth_name '%s' for auth: %s", a.AuthName, a.Name)
}
axm[a.Name] = struct{}{}
}
c.valid = true
return nil
}
// GetDBTableAliases function returns a map with database tables as keys
// and a list of aliases as values
func (c *Config) GetDBTableAliases() map[string][]string {
m := make(map[string][]string, len(c.Tables))
for i := range c.Tables {
t := c.Tables[i]
if len(t.Table) == 0 || len(t.Columns) != 0 {
continue
}
m[t.Table] = append(m[t.Table], t.Name)
}
return m
}
// IsABACEnabled function returns true if attribute based access control is enabled
func (c *Config) IsABACEnabled() bool {
return c.abacEnabled
}
// IsAnonRoleDefined function returns true if the config has configuration for the `anon` role
func (c *Config) IsAnonRoleDefined() bool {
_, ok := c.roles["anon"]
return ok
}
// GetRole function returns returns the Role struct by name
func (c *Config) GetRole(name string) *Role {
role := c.roles[name]
return role
}
// ConfigPathUsed function returns the path to the current config file (excluding filename)
func (c *Config) ConfigPathUsed() string {
return path.Dir(c.vi.ConfigFileUsed())
}
// WriteConfigAs function writes the config to a file
// Format defined by extension (eg: .yml, .json)
func (c *Config) WriteConfigAs(fname string) error {
return c.vi.WriteConfigAs(fname)
}
// Log function returns the logger
func (c *Config) Log() *log.Logger {
return c.log
}
// LogLevel function returns the log level
func (c *Config) LogLevel() int {
return c.logLevel
}
// IsValid function returns true if the Config struct is initialized and valid
func (c *Config) IsValid() bool {
return c.valid
}
// GetTable function returns the RoleTable struct for a Role by table name
func (r *Role) GetTable(name string) *RoleTable {
table := r.tablesMap[name]
return table
}

View File

@ -1,11 +1,11 @@
package serv package config
import ( import (
"testing" "testing"
) )
func TestInitConf(t *testing.T) { func TestInitConf(t *testing.T) {
_, err := initConf() _, err := NewConfig("../examples/rails-app/config/supergraph")
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())

View File

@ -1,226 +0,0 @@
app_name: "Super Graph Development"
host_port: 0.0.0.0:8080
web_ui: true
# debug, info, warn, error, fatal, panic
log_level: "debug"
# enable or disable http compression (uses gzip)
http_compress: true
# When production mode is 'true' only queries
# from the allow list are permitted.
# When it's 'false' all queries are saved to the
# the allow list in ./config/allow.list
production: false
# Throw a 401 on auth failure for queries that need auth
auth_fail_block: false
# Latency tracing for database queries and remote joins
# the resulting latency information is returned with the
# response
enable_tracing: true
# Watch the config folder and reload Super Graph
# with the new configs when a change is detected
reload_on_config_change: true
# File that points to the database seeding script
# seed_file: seed.js
# Path pointing to where the migrations can be found
migrations_path: ./config/migrations
# Secret key for general encryption operations like
# encrypting the cursor data
secret_key: supercalifajalistics
# CORS: A list of origins a cross-domain request can be executed from.
# If the special * value is present in the list, all origins will be allowed.
# An origin may contain a wildcard (*) to replace 0 or more
# characters (i.e.: http://*.domain.com).
cors_allowed_origins: ["*"]
# Debug Cross Origin Resource Sharing requests
cors_debug: true
# Postgres related environment Variables
# SG_DATABASE_HOST
# SG_DATABASE_PORT
# SG_DATABASE_USER
# SG_DATABASE_PASSWORD
# Auth related environment Variables
# SG_AUTH_RAILS_COOKIE_SECRET_KEY_BASE
# SG_AUTH_RAILS_REDIS_URL
# SG_AUTH_RAILS_REDIS_PASSWORD
# SG_AUTH_JWT_PUBLIC_KEY_FILE
# inflections:
# person: people
# sheep: sheep
auth:
# Can be 'rails' or 'jwt'
type: rails
cookie: _app_session
# Comment this out if you want to disable setting
# the user_id via a header for testing.
# Disable in production
creds_in_header: true
rails:
# Rails version this is used for reading the
# various cookies formats.
version: 5.2
# Found in 'Rails.application.config.secret_key_base'
secret_key_base: 0a248500a64c01184edb4d7ad3a805488f8097ac761b76aaa6c17c01dcb7af03a2f18ba61b2868134b9c7b79a122bc0dadff4367414a2d173297bfea92be5566
# Remote cookie store. (memcache or redis)
# url: redis://redis:6379
# password: ""
# max_idle: 80
# max_active: 12000
# In most cases you don't need these
# salt: "encrypted cookie"
# sign_salt: "signed encrypted cookie"
# auth_salt: "authenticated encrypted cookie"
# jwt:
# provider: auth0
# secret: abc335bfcfdb04e50db5bb0a4d67ab9
# public_key_file: /secrets/public_key.pem
# public_key_type: ecdsa #rsa
database:
type: postgres
host: db
port: 5432
dbname: app_development
user: postgres
password: postgres
#schema: "public"
#pool_size: 10
#max_retries: 0
#log_level: "debug"
# Set session variable "user.id" to the user id
# Enable this if you need the user id in triggers, etc
set_user_id: false
# database ping timeout is used for db health checking
ping_timeout: 1m
# Define additional variables here to be used with filters
variables:
admin_account_id: "5"
# Field and table names that you wish to block
blocklist:
- ar_internal_metadata
- schema_migrations
- secret
- password
- encrypted
- token
tables:
- name: customers
remotes:
- name: payments
id: stripe_id
url: http://rails_app:3000/stripe/$id
path: data
# debug: true
pass_headers:
- cookie
set_headers:
- name: Host
value: 0.0.0.0
# - name: Authorization
# value: Bearer <stripe_api_key>
- # You can create new fields that have a
# real db table backing them
name: me
table: users
- name: deals
table: products
- name: users
columns:
- name: email
related_to: products.name
roles_query: "SELECT * FROM users WHERE id = $user_id"
roles:
- name: anon
tables:
- name: products
query:
limit: 10
columns: ["id", "name", "description" ]
aggregation: false
insert:
block: false
update:
block: false
delete:
block: false
- name: deals
query:
limit: 3
aggregation: false
- name: purchases
query:
limit: 3
aggregation: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
disable_functions: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
presets:
- user_id: "$user_id"
- created_at: "now"
- updated_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
presets:
- updated_at: "now"
delete:
block: true
- name: admin
match: id = 1000
tables:
- name: users
filters: []

View File

@ -1,67 +0,0 @@
# Inherit config from this other config file
# so I only need to overwrite some values
inherits: dev
app_name: "Super Graph Production"
host_port: 0.0.0.0:8080
web_ui: false
# debug, info, warn, error, fatal, panic, disable
log_level: "info"
# enable or disable http compression (uses gzip)
http_compress: true
# When production mode is 'true' only queries
# from the allow list are permitted.
# When it's 'false' all queries are saved to the
# the allow list in ./config/allow.list
production: true
# Throw a 401 on auth failure for queries that need auth
auth_fail_block: true
# Latency tracing for database queries and remote joins
# the resulting latency information is returned with the
# response
enable_tracing: true
# File that points to the database seeding script
# seed_file: seed.js
# Path pointing to where the migrations can be found
# migrations_path: migrations
# Secret key for general encryption operations like
# encrypting the cursor data
# secret_key: supercalifajalistics
# Postgres related environment Variables
# SG_DATABASE_HOST
# SG_DATABASE_PORT
# SG_DATABASE_USER
# SG_DATABASE_PASSWORD
# Auth related environment Variables
# SG_AUTH_RAILS_COOKIE_SECRET_KEY_BASE
# SG_AUTH_RAILS_REDIS_URL
# SG_AUTH_RAILS_REDIS_PASSWORD
# SG_AUTH_JWT_PUBLIC_KEY_FILE
database:
type: postgres
host: db
port: 5432
dbname: app_production
user: postgres
password: postgres
#pool_size: 10
#max_retries: 0
#log_level: "debug"
# Set session variable "user.id" to the user id
# Enable this if you need the user id in triggers, etc
set_user_id: false
# database ping timeout is used for db health checking
ping_timeout: 5m

View File

@ -1,116 +0,0 @@
var user_count = 10
customer_count = 100
product_count = 50
purchase_count = 100
var users = []
customers = []
products = []
for (i = 0; i < user_count; i++) {
var pwd = fake.password()
var data = {
full_name: fake.name(),
avatar: fake.avatar_url(200),
phone: fake.phone(),
email: fake.email(),
password: pwd,
password_confirmation: pwd,
created_at: "now",
updated_at: "now"
}
var res = graphql(" \
mutation { \
user(insert: $data) { \
id \
} \
}", { data: data })
users.push(res.user)
}
for (i = 0; i < product_count; i++) {
var n = Math.floor(Math.random() * users.length)
var user = users[n]
var desc = [
fake.beer_style(),
fake.beer_hop(),
fake.beer_yeast(),
fake.beer_ibu(),
fake.beer_alcohol(),
fake.beer_blg(),
].join(", ")
var data = {
name: fake.beer_name(),
description: desc,
price: fake.price()
//user_id: user.id,
//created_at: "now",
//updated_at: "now"
}
var res = graphql(" \
mutation { \
product(insert: $data) { \
id \
} \
}", { data: data }, {
user_id: 5
})
products.push(res.product)
}
for (i = 0; i < customer_count; i++) {
var pwd = fake.password()
var data = {
stripe_id: "CUS-" + fake.uuid(),
full_name: fake.name(),
phone: fake.phone(),
email: fake.email(),
password: pwd,
password_confirmation: pwd,
created_at: "now",
updated_at: "now"
}
var res = graphql(" \
mutation { \
customer(insert: $data) { \
id \
} \
}", { data: data })
customers.push(res.customer)
}
for (i = 0; i < purchase_count; i++) {
var sale_type = fake.rand_string(["rented", "bought"])
if (sale_type === "rented") {
var due_date = fake.date()
var returned = fake.date()
}
var data = {
customer_id: customers[Math.floor(Math.random() * customer_count)].id,
product_id: products[Math.floor(Math.random() * product_count)].id,
sale_type: sale_type,
quantity: Math.floor(Math.random() * 10),
due_date: due_date,
returned: returned,
created_at: "now",
updated_at: "now"
}
var res = graphql(" \
mutation { \
purchase(insert: $data) { \
id \
} \
}", { data: data })
console.log(res)
}

52
config/utils.go Normal file
View File

@ -0,0 +1,52 @@
package config
import (
"os"
"regexp"
"strings"
"unicode"
)
var (
varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`)
)
func sanitize(s string) string {
s0 := varRe1.ReplaceAllString(s, `{{$1}}`)
s1 := strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
return ' '
}
return r
}, s0)
return varRe2.ReplaceAllStringFunc(s1, func(m string) string {
return strings.ToLower(m)
})
}
func GetConfigName() string {
if len(os.Getenv("GO_ENV")) == 0 {
return "dev"
}
ge := strings.ToLower(os.Getenv("GO_ENV"))
switch {
case strings.HasPrefix(ge, "pro"):
return "prod"
case strings.HasPrefix(ge, "sta"):
return "stage"
case strings.HasPrefix(ge, "tes"):
return "test"
case strings.HasPrefix(ge, "dev"):
return "dev"
}
return ge
}

181
core/api.go Normal file
View File

@ -0,0 +1,181 @@
// Package core provides the primary API to include and use Super Graph with your own code.
// For detailed documentation visit https://supergraph.dev
//
// Example usage:
/*
package main
import (
"database/sql"
"fmt"
"time"
"github.com/dosco/super-graph/config"
"github.com/dosco/super-graph/core"
_ "github.com/jackc/pgx/v4/stdlib"
)
func main() {
db, err := sql.Open("pgx", "postgres://postgrs:@localhost:5432/example_db")
if err != nil {
log.Fatalf(err)
}
conf, err := config.NewConfig("./config")
if err != nil {
log.Fatalf(err)
}
sg, err = core.NewSuperGraph(conf, db)
if err != nil {
log.Fatalf(err)
}
query := `
query {
posts {
id
title
}
}`
res, err := sg.GraphQL(context.Background(), query, nil)
if err != nil {
log.Fatalf(err)
}
fmt.Println(string(res.Data))
}
*/
package core
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/json"
"fmt"
"log"
"github.com/dosco/super-graph/config"
"github.com/dosco/super-graph/core/internal/allow"
"github.com/dosco/super-graph/core/internal/crypto"
"github.com/dosco/super-graph/core/internal/psql"
"github.com/dosco/super-graph/core/internal/qcode"
)
type contextkey int
// Constants to set values on the context passed to the NewSuperGraph function
const (
// Name of the authentication provider. Eg. google, github, etc
UserIDProviderKey contextkey = iota
// User ID value for authenticated users
UserIDKey
// User role if pre-defined
UserRoleKey
)
// SuperGraph struct is an instance of the Super Graph engine it holds all the required information like
// datase schemas, relationships, etc that the GraphQL to SQL compiler would need to do it's job.
type SuperGraph struct {
conf *config.Config
db *sql.DB
schema *psql.DBSchema
allowList *allow.List
encKey [32]byte
prepared map[string]*preparedItem
getRole *sql.Stmt
qc *qcode.Compiler
pc *psql.Compiler
}
// NewConfig functions initializes config using a config.Core struct
func NewConfig(core config.Core, configPath string, logger *log.Logger) (*config.Config, error) {
c, err := config.NewConfigFrom(&config.Config{Core: core}, configPath, logger)
if err != nil {
return nil, err
}
return c, nil
}
// NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its
// schemas and relationships
func NewSuperGraph(conf *config.Config, db *sql.DB) (*SuperGraph, error) {
if !conf.IsValid() {
return nil, fmt.Errorf("invalid config")
}
sg := &SuperGraph{
conf: conf,
db: db,
}
if err := sg.initCompilers(); err != nil {
return nil, err
}
if err := sg.initAllowList(); err != nil {
return nil, err
}
if err := sg.initPrepared(); err != nil {
return nil, err
}
if len(conf.SecretKey) != 0 {
sk := sha256.Sum256([]byte(conf.SecretKey))
conf.SecretKey = ""
sg.encKey = sk
} else {
sg.encKey = crypto.NewEncryptionKey()
}
return sg, nil
}
// Result struct contains the output of the GraphQL function this includes resulting json from the
// database query and any error information
type Result struct {
op qcode.QType
name string
sql string
role string
Error string `json:"message,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
Extensions *extensions `json:"extensions,omitempty"`
}
// GraphQL function is called on the SuperGraph struct to convert the provided GraphQL query into an
// SQL query and execute it on the database. In production mode prepared statements are directly used
// and no query compiling takes places.
//
// In developer mode all names queries are saved into a file `allow.list` and in production mode only
// queries from this file can be run.
func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMessage) (*Result, error) {
ct := scontext{Context: c, sg: sg, query: query, vars: vars}
if len(vars) <= 2 {
ct.vars = nil
}
if keyExists(c, UserIDKey) {
ct.role = "user"
} else {
ct.role = "anon"
}
ct.res.op = qcode.GetQType(query)
ct.res.name = allow.QueryName(query)
data, err := ct.execQuery()
if err != nil {
return &ct.res, err
}
ct.res.Data = json.RawMessage(data)
return &ct.res, nil
}

View File

@ -1,8 +1,7 @@
package serv package core
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -10,29 +9,29 @@ import (
"github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/jsn"
) )
func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int, error) { func (c *scontext) argMap() func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) { return func(w io.Writer, tag string) (int, error) {
switch tag { switch tag {
case "user_id_provider": case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil { if v := c.Value(UserIDProviderKey); v != nil {
return io.WriteString(w, v.(string)) return io.WriteString(w, v.(string))
} }
return 0, argErr("user_id_provider") return 0, argErr("user_id_provider")
case "user_id": case "user_id":
if v := ctx.Value(userIDKey); v != nil { if v := c.Value(UserIDKey); v != nil {
return io.WriteString(w, v.(string)) return io.WriteString(w, v.(string))
} }
return 0, argErr("user_id") return 0, argErr("user_id")
case "user_role": case "user_role":
if v := ctx.Value(userRoleKey); v != nil { if v := c.Value(UserRoleKey); v != nil {
return io.WriteString(w, v.(string)) return io.WriteString(w, v.(string))
} }
return 0, argErr("user_role") return 0, argErr("user_role")
} }
fields := jsn.Get(vars, [][]byte{[]byte(tag)}) fields := jsn.Get(c.vars, [][]byte{[]byte(tag)})
if len(fields) == 0 { if len(fields) == 0 {
return 0, argErr(tag) return 0, argErr(tag)
@ -49,7 +48,7 @@ func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int
if bytes.EqualFold(v, []byte("null")) { if bytes.EqualFold(v, []byte("null")) {
return io.WriteString(w, ``) return io.WriteString(w, ``)
} }
v1, err := decrypt(string(fields[0].Value)) v1, err := c.sg.decrypt(string(fields[0].Value))
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -61,14 +60,14 @@ func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int
} }
} }
func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) { func (c *scontext) argList(args [][]byte) ([]interface{}, error) {
vars := make([]interface{}, len(args)) vars := make([]interface{}, len(args))
var fields map[string]json.RawMessage var fields map[string]json.RawMessage
var err error var err error
if len(ctx.req.Vars) != 0 { if len(c.vars) != 0 {
fields, _, err = jsn.Tree(ctx.req.Vars) fields, _, err = jsn.Tree(c.vars)
if err != nil { if err != nil {
return nil, err return nil, err
@ -79,21 +78,21 @@ func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) {
av := args[i] av := args[i]
switch { switch {
case bytes.Equal(av, []byte("user_id")): case bytes.Equal(av, []byte("user_id")):
if v := ctx.Value(userIDKey); v != nil { if v := c.Value(UserIDKey); v != nil {
vars[i] = v.(string) vars[i] = v.(string)
} else { } else {
return nil, argErr("user_id") return nil, argErr("user_id")
} }
case bytes.Equal(av, []byte("user_id_provider")): case bytes.Equal(av, []byte("user_id_provider")):
if v := ctx.Value(userIDProviderKey); v != nil { if v := c.Value(UserIDProviderKey); v != nil {
vars[i] = v.(string) vars[i] = v.(string)
} else { } else {
return nil, argErr("user_id_provider") return nil, argErr("user_id_provider")
} }
case bytes.Equal(av, []byte("user_role")): case bytes.Equal(av, []byte("user_role")):
if v := ctx.Value(userRoleKey); v != nil { if v := c.Value(UserRoleKey); v != nil {
vars[i] = v.(string) vars[i] = v.(string)
} else { } else {
return nil, argErr("user_role") return nil, argErr("user_role")
@ -101,7 +100,7 @@ func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) {
case bytes.Equal(av, []byte("cursor")): case bytes.Equal(av, []byte("cursor")):
if v, ok := fields["cursor"]; ok && v[0] == '"' { if v, ok := fields["cursor"]; ok && v[0] == '"' {
v1, err := decrypt(string(v[1 : len(v)-1])) v1, err := c.sg.decrypt(string(v[1 : len(v)-1]))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,4 +1,4 @@
package serv package core
import ( import (
"bytes" "bytes"
@ -7,42 +7,43 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/config"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/core/internal/psql"
"github.com/dosco/super-graph/core/internal/qcode"
) )
type stmt struct { type stmt struct {
role *configRole role *config.Role
qc *qcode.QCode qc *qcode.QCode
skipped uint32 skipped uint32
sql string sql string
} }
func buildStmt(qt qcode.QType, gql, vars []byte, role string) ([]stmt, error) { func (sg *SuperGraph) buildStmt(qt qcode.QType, query, vars []byte, role string) ([]stmt, error) {
switch qt { switch qt {
case qcode.QTMutation: case qcode.QTMutation:
return buildRoleStmt(gql, vars, role) return sg.buildRoleStmt(query, vars, role)
case qcode.QTQuery: case qcode.QTQuery:
if role == "anon" { if role == "anon" {
return buildRoleStmt(gql, vars, "anon") return sg.buildRoleStmt(query, vars, "anon")
} }
if conf.isABACEnabled() { if sg.conf.IsABACEnabled() {
return buildMultiStmt(gql, vars) return sg.buildMultiStmt(query, vars)
} }
return buildRoleStmt(gql, vars, "user") return sg.buildRoleStmt(query, vars, "user")
default: default:
return nil, fmt.Errorf("unknown query type '%d'", qt) return nil, fmt.Errorf("unknown query type '%d'", qt)
} }
} }
func buildRoleStmt(gql, vars []byte, role string) ([]stmt, error) { func (sg *SuperGraph) buildRoleStmt(query, vars []byte, role string) ([]stmt, error) {
ro, ok := conf.roles[role] ro := sg.conf.GetRole(role)
if !ok { if ro == nil {
return nil, fmt.Errorf(`roles '%s' not defined in config`, role) return nil, fmt.Errorf(`roles '%s' not defined in c.sg.config`, role)
} }
var vm map[string]json.RawMessage var vm map[string]json.RawMessage
@ -54,7 +55,7 @@ func buildRoleStmt(gql, vars []byte, role string) ([]stmt, error) {
} }
} }
qc, err := qcompile.Compile(gql, ro.Name) qc, err := sg.qc.Compile(query, ro.Name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -62,7 +63,7 @@ func buildRoleStmt(gql, vars []byte, role string) ([]stmt, error) {
stmts := []stmt{stmt{role: ro, qc: qc}} stmts := []stmt{stmt{role: ro, qc: qc}}
w := &bytes.Buffer{} w := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm)) skipped, err := sg.pc.Compile(qc, w, psql.Variables(vm))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -73,7 +74,7 @@ func buildRoleStmt(gql, vars []byte, role string) ([]stmt, error) {
return stmts, nil return stmts, nil
} }
func buildMultiStmt(gql, vars []byte) ([]stmt, error) { func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
var vm map[string]json.RawMessage var vm map[string]json.RawMessage
var err error var err error
@ -83,29 +84,29 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
} }
} }
if len(conf.RolesQuery) == 0 { if len(sg.conf.RolesQuery) == 0 {
return nil, errors.New("roles_query not defined") return nil, errors.New("roles_query not defined")
} }
stmts := make([]stmt, 0, len(conf.Roles)) stmts := make([]stmt, 0, len(sg.conf.Roles))
w := &bytes.Buffer{} w := &bytes.Buffer{}
for i := 0; i < len(conf.Roles); i++ { for i := 0; i < len(sg.conf.Roles); i++ {
role := &conf.Roles[i] role := &sg.conf.Roles[i]
// skip anon as it's not included in the combined multi-statement // skip anon as it's not included in the combined multi-statement
if role.Name == "anon" { if role.Name == "anon" {
continue continue
} }
qc, err := qcompile.Compile(gql, role.Name) qc, err := sg.qc.Compile(query, role.Name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
stmts = append(stmts, stmt{role: role, qc: qc}) stmts = append(stmts, stmt{role: role, qc: qc})
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm)) skipped, err := sg.pc.Compile(qc, w, psql.Variables(vm))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -116,7 +117,7 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
w.Reset() w.Reset()
} }
sql, err := renderUserQuery(stmts, vm) sql, err := sg.renderUserQuery(stmts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -126,8 +127,7 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
} }
//nolint: errcheck //nolint: errcheck
func renderUserQuery( func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
w := &bytes.Buffer{} w := &bytes.Buffer{}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `) io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
@ -145,7 +145,7 @@ func renderUserQuery(
} }
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`) io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, conf.RolesQuery) io.WriteString(w, sg.conf.RolesQuery)
io.WriteString(w, `) THEN `) io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`) io.WriteString(w, `(SELECT (CASE`)
@ -161,20 +161,21 @@ func renderUserQuery(
} }
io.WriteString(w, ` ELSE 'user' END) FROM (`) io.WriteString(w, ` ELSE 'user' END) FROM (`)
io.WriteString(w, conf.RolesQuery) io.WriteString(w, sg.conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `) io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `) io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
return w.String(), nil return w.String(), nil
} }
func hasTablesWithConfig(qc *qcode.QCode, role *configRole) bool { func (sg *SuperGraph) hasTablesWithConfig(qc *qcode.QCode, role *config.Role) bool {
for _, id := range qc.Roots { for _, id := range qc.Roots {
t, err := schema.GetTable(qc.Selects[id].Name) t, err := sg.schema.GetTable(qc.Selects[id].Name)
if err != nil { if err != nil {
return false return false
} }
if _, ok := role.tablesMap[t.Name]; !ok {
if r := role.GetTable(t.Name); r == nil {
return false return false
} }
} }

View File

@ -1,14 +1,15 @@
package serv package core
import ( import (
"fmt" "fmt"
"strings" "strings"
"github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/config"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/core/internal/psql"
"github.com/dosco/super-graph/core/internal/qcode"
) )
func addTables(c *config, di *psql.DBInfo) error { func addTables(c *config.Config, di *psql.DBInfo) error {
for _, t := range c.Tables { for _, t := range c.Tables {
if len(t.Table) == 0 || len(t.Columns) == 0 { if len(t.Table) == 0 || len(t.Columns) == 0 {
continue continue
@ -20,7 +21,7 @@ func addTables(c *config, di *psql.DBInfo) error {
return nil return nil
} }
func addTable(di *psql.DBInfo, cols []configColumn, t configTable) error { func addTable(di *psql.DBInfo, cols []config.Column, t config.Table) error {
bc, ok := di.GetColumn(t.Table, t.Name) bc, ok := di.GetColumn(t.Table, t.Name)
if !ok { if !ok {
return fmt.Errorf( return fmt.Errorf(
@ -57,7 +58,7 @@ func addTable(di *psql.DBInfo, cols []configColumn, t configTable) error {
return nil return nil
} }
func addForeignKeys(c *config, di *psql.DBInfo) error { func addForeignKeys(c *config.Config, di *psql.DBInfo) error {
for _, t := range c.Tables { for _, t := range c.Tables {
for _, c := range t.Columns { for _, c := range t.Columns {
if len(c.ForeignKey) == 0 { if len(c.ForeignKey) == 0 {
@ -71,18 +72,18 @@ func addForeignKeys(c *config, di *psql.DBInfo) error {
return nil return nil
} }
func addForeignKey(di *psql.DBInfo, c configColumn, t configTable) error { func addForeignKey(di *psql.DBInfo, c config.Column, t config.Table) error {
c1, ok := di.GetColumn(t.Name, c.Name) c1, ok := di.GetColumn(t.Name, c.Name)
if !ok { if !ok {
return fmt.Errorf( return fmt.Errorf(
"Invalid table '%s' or column '%s' in config", "Invalid table '%s' or column '%s' in config.Config",
t.Name, c.Name) t.Name, c.Name)
} }
v := strings.SplitN(c.ForeignKey, ".", 2) v := strings.SplitN(c.ForeignKey, ".", 2)
if len(v) != 2 { if len(v) != 2 {
return fmt.Errorf( return fmt.Errorf(
"Invalid foreign_key in config for table '%s' and column '%s", "Invalid foreign_key in config.Config for table '%s' and column '%s",
t.Name, c.Name) t.Name, c.Name)
} }
@ -90,7 +91,7 @@ func addForeignKey(di *psql.DBInfo, c configColumn, t configTable) error {
c2, ok := di.GetColumn(fkt, fkc) c2, ok := di.GetColumn(fkt, fkc)
if !ok { if !ok {
return fmt.Errorf( return fmt.Errorf(
"Invalid foreign_key in config for table '%s' and column '%s", "Invalid foreign_key in config.Config for table '%s' and column '%s",
t.Name, c.Name) t.Name, c.Name)
} }
@ -100,7 +101,7 @@ func addForeignKey(di *psql.DBInfo, c configColumn, t configTable) error {
return nil return nil
} }
func addRoles(c *config, qc *qcode.Compiler) error { func addRoles(c *config.Config, qc *qcode.Compiler) error {
for _, r := range c.Roles { for _, r := range c.Roles {
for _, t := range r.Tables { for _, t := range r.Tables {
if err := addRole(qc, r, t); err != nil { if err := addRole(qc, r, t); err != nil {
@ -112,7 +113,7 @@ func addRoles(c *config, qc *qcode.Compiler) error {
return nil return nil
} }
func addRole(qc *qcode.Compiler, r configRole, t configRoleTable) error { func addRole(qc *qcode.Compiler, r config.Role, t config.RoleTable) error {
blockFilter := []string{"false"} blockFilter := []string{"false"}
query := qcode.QueryConfig{ query := qcode.QueryConfig{

19
core/consts.go Normal file
View File

@ -0,0 +1,19 @@
package core
import (
"context"
"errors"
)
const (
openVar = "{{"
closeVar = "}}"
)
var (
errNotFound = errors.New("not found in prepared statements")
)
func keyExists(ct context.Context, key contextkey) bool {
return ct.Value(key) != nil
}

392
core/core.go Normal file
View File

@ -0,0 +1,392 @@
package core
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/dosco/super-graph/core/internal/psql"
"github.com/dosco/super-graph/core/internal/qcode"
"github.com/valyala/fasttemplate"
)
type extensions struct {
Tracing *trace `json:"tracing,omitempty"`
}
type trace struct {
Version int `json:"version"`
StartTime time.Time `json:"startTime"`
EndTime time.Time `json:"endTime"`
Duration time.Duration `json:"duration"`
Execution execution `json:"execution"`
}
type execution struct {
Resolvers []resolver `json:"resolvers"`
}
type resolver struct {
Path []string `json:"path"`
ParentType string `json:"parentType"`
FieldName string `json:"fieldName"`
ReturnType string `json:"returnType"`
StartOffset int `json:"startOffset"`
Duration time.Duration `json:"duration"`
}
type scontext struct {
context.Context
sg *SuperGraph
query string
vars json.RawMessage
role string
res Result
}
func (sg *SuperGraph) initCompilers() error {
di, err := psql.GetDBInfo(sg.db)
if err != nil {
return err
}
if err = addTables(sg.conf, di); err != nil {
return err
}
if err = addForeignKeys(sg.conf, di); err != nil {
return err
}
sg.schema, err = psql.NewDBSchema(di, sg.conf.GetDBTableAliases())
if err != nil {
return err
}
sg.qc, err = qcode.NewCompiler(qcode.Config{
Blocklist: sg.conf.Blocklist,
})
if err != nil {
return err
}
if err := addRoles(sg.conf, sg.qc); err != nil {
return err
}
sg.pc = psql.NewCompiler(psql.Config{
Schema: sg.schema,
Vars: sg.conf.Vars,
})
return nil
}
func (c *scontext) execQuery() ([]byte, error) {
var data []byte
// var st *stmt
var err error
if c.sg.conf.Production {
data, _, err = c.resolvePreparedSQL()
if err != nil {
return nil, err
}
} else {
data, _, err = c.resolveSQL()
if err != nil {
return nil, err
}
}
return data, nil
//return execRemoteJoin(st, data, c.req.hdr)
}
func (c *scontext) resolvePreparedSQL() ([]byte, *stmt, error) {
var tx *sql.Tx
var err error
mutation := (c.res.op == qcode.QTMutation)
useRoleQuery := c.sg.conf.IsABACEnabled() && mutation
useTx := useRoleQuery || c.sg.conf.SetUserID
if useTx {
if tx, err = c.sg.db.BeginTx(c, nil); err != nil {
return nil, nil, err
}
defer tx.Rollback() //nolint: errcheck
}
if c.sg.conf.SetUserID {
if err := setLocalUserID(c, tx); err != nil {
return nil, nil, err
}
}
var role string
if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil {
return nil, nil, err
}
} else if v := c.Value(UserRoleKey); v != nil {
role = v.(string)
} else {
role = c.role
}
c.res.role = role
ps, ok := prepared[stmtHash(c.res.name, role)]
if !ok {
return nil, nil, errNotFound
}
c.res.sql = ps.st.sql
var root []byte
var row *sql.Row
varsList, err := c.argList(ps.args)
if err != nil {
return nil, nil, err
}
if useTx {
row = tx.Stmt(ps.sd).QueryRow(varsList...)
} else {
row = ps.sd.QueryRow(varsList...)
}
if ps.roleArg {
err = row.Scan(&role, &root)
} else {
err = row.Scan(&root)
}
if err != nil {
return nil, nil, err
}
c.role = role
if useTx {
if err := tx.Commit(); err != nil {
return nil, nil, err
}
}
if root, err = c.sg.encryptCursor(ps.st.qc, root); err != nil {
return nil, nil, err
}
return root, &ps.st, nil
}
func (c *scontext) resolveSQL() ([]byte, *stmt, error) {
var tx *sql.Tx
var err error
mutation := (c.res.op == qcode.QTMutation)
useRoleQuery := c.sg.conf.IsABACEnabled() && mutation
useTx := useRoleQuery || c.sg.conf.SetUserID
if useTx {
if tx, err = c.sg.db.BeginTx(c, nil); err != nil {
return nil, nil, err
}
defer tx.Rollback() //nolint: errcheck
}
if c.sg.conf.SetUserID {
if err := setLocalUserID(c, tx); err != nil {
return nil, nil, err
}
}
if useRoleQuery {
if c.role, err = c.executeRoleQuery(tx); err != nil {
return nil, nil, err
}
} else if v := c.Value(UserRoleKey); v != nil {
c.role = v.(string)
}
stmts, err := c.sg.buildStmt(c.res.op, []byte(c.query), c.vars, c.role)
if err != nil {
return nil, nil, err
}
st := &stmts[0]
t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, c.argMap())
if err != nil {
return nil, nil, err
}
finalSQL := buf.String()
// var stime time.Time
// if c.sg.conf.EnableTracing {
// stime = time.Now()
// }
var root []byte
var role string
var row *sql.Row
// defaultRole := c.role
if useTx {
row = tx.QueryRow(finalSQL)
} else {
row = c.sg.db.QueryRow(finalSQL)
}
if len(stmts) > 1 {
err = row.Scan(&role, &root)
} else {
err = row.Scan(&root)
}
c.res.sql = finalSQL
if len(role) == 0 {
c.res.role = c.role
} else {
c.res.role = role
}
if err != nil {
return nil, nil, err
}
if useTx {
if err := tx.Commit(); err != nil {
return nil, nil, err
}
}
if root, err = c.sg.encryptCursor(st.qc, root); err != nil {
return nil, nil, err
}
if c.sg.allowList.IsPersist() {
if err := c.sg.allowList.Set(c.vars, c.query, ""); err != nil {
return nil, nil, err
}
}
if len(stmts) > 1 {
if st = findStmt(role, stmts); st == nil {
return nil, nil, fmt.Errorf("invalid role '%s' returned", role)
}
}
// if c.sg.conf.EnableTracing {
// for _, id := range st.qc.Roots {
// c.addTrace(st.qc.Selects, id, stime)
// }
// }
return root, st, nil
}
func (c *scontext) executeRoleQuery(tx *sql.Tx) (string, error) {
userID := c.Value(UserIDKey)
if userID == nil {
return "anon", nil
}
var role string
row := c.sg.getRole.QueryRow(userID, c.role)
if err := row.Scan(&role); err != nil {
return "", err
}
return role, nil
}
func (r *Result) Operation() string {
return r.op.String()
}
func (r *Result) QueryName() string {
return r.name
}
func (r *Result) Role() string {
return r.role
}
func (r *Result) SQL() string {
return r.sql
}
// func (c *scontext) addTrace(sel []qcode.Select, id int32, st time.Time) {
// et := time.Now()
// du := et.Sub(st)
// if c.res.Extensions == nil {
// c.res.Extensions = &extensions{&trace{
// Version: 1,
// StartTime: st,
// Execution: execution{},
// }}
// }
// c.res.Extensions.Tracing.EndTime = et
// c.res.Extensions.Tracing.Duration = du
// n := 1
// for i := id; i != -1; i = sel[i].ParentID {
// n++
// }
// path := make([]string, n)
// n--
// for i := id; ; i = sel[i].ParentID {
// path[n] = sel[i].Name
// if sel[i].ParentID == -1 {
// break
// }
// n--
// }
// tr := resolver{
// Path: path,
// ParentType: "Query",
// FieldName: sel[id].Name,
// ReturnType: "object",
// StartOffset: 1,
// Duration: du,
// }
// c.res.Extensions.Tracing.Execution.Resolvers =
// append(c.res.Extensions.Tracing.Execution.Resolvers, tr)
// }
func findStmt(role string, stmts []stmt) *stmt {
for i := range stmts {
if stmts[i].role.Name != role {
continue
}
return &stmts[i]
}
return nil
}

View File

@ -1,4 +1,4 @@
package serv package core
/* /*

View File

@ -1,15 +1,15 @@
package serv package core
import ( import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"github.com/dosco/super-graph/crypto" "github.com/dosco/super-graph/core/internal/crypto"
"github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/core/internal/qcode"
) )
func encryptCursor(qc *qcode.QCode, data []byte) ([]byte, error) { func (sg *SuperGraph) encryptCursor(qc *qcode.QCode, data []byte) ([]byte, error) {
var keys [][]byte var keys [][]byte
for _, s := range qc.Selects { for _, s := range qc.Selects {
@ -39,7 +39,7 @@ func encryptCursor(qc *qcode.QCode, data []byte) ([]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
if len(f.Value) > 2 { if len(f.Value) > 2 {
v, err := crypto.Encrypt(f.Value[1:len(f.Value)-1], &internalKey) v, err := crypto.Encrypt(f.Value[1:len(f.Value)-1], &sg.encKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,10 +63,10 @@ func encryptCursor(qc *qcode.QCode, data []byte) ([]byte, error) {
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func decrypt(data string) ([]byte, error) { func (sg *SuperGraph) decrypt(data string) ([]byte, error) {
v, err := base64.StdEncoding.DecodeString(data) v, err := base64.StdEncoding.DecodeString(data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return crypto.Decrypt(v, &internalKey) return crypto.Decrypt(v, &sg.encKey)
} }

15
core/db.go Normal file
View File

@ -0,0 +1,15 @@
package core
import (
"context"
"database/sql"
)
func setLocalUserID(c context.Context, tx *sql.Tx) error {
var err error
if v := c.Value(UserIDKey); v != nil {
_, err = tx.Exec(`SET LOCAL "user.id" = ?`, v)
}
return err
}

Some files were not shown because too many files have changed in this diff Show More