Block unauthorized requests when 'anon' role is not defined

This commit is contained in:
Vikram Rangnekar 2019-11-02 17:13:17 -04:00
parent 0deb3596c5
commit 77a51924a7
14 changed files with 80 additions and 82 deletions

View File

@ -159,7 +159,6 @@ query {
} }
} }
variables { variables {
"data": { "data": {
"email": "gfk@myspace.com", "email": "gfk@myspace.com",
@ -272,4 +271,20 @@ query {
} }
} }
query {
users {
id
email
picture: avatar
password
full_name
products(limit: 2, where: {price: {gt: 10}}) {
id
name
description
price
}
}
}

View File

@ -12,8 +12,7 @@ log_level: "debug"
use_allow_list: false use_allow_list: false
# Throw a 401 on auth failure for queries that need auth # Throw a 401 on auth failure for queries that need auth
# valid values: always, per_query, never auth_fail_block: false
auth_fail_block: never
# Latency tracing for database queries and remote joins # Latency tracing for database queries and remote joins
# the resulting latency information is returned with the # the resulting latency information is returned with the

View File

@ -16,8 +16,7 @@ log_level: "info"
use_allow_list: true use_allow_list: true
# Throw a 401 on auth failure for queries that need auth # Throw a 401 on auth failure for queries that need auth
# valid values: always, per_query, never auth_fail_block: true
auth_fail_block: always
# Latency tracing for database queries and remote joins # Latency tracing for database queries and remote joins
# the resulting latency information is returned with the # the resulting latency information is returned with the

View File

@ -2,7 +2,6 @@ package psql
import ( import (
"encoding/json" "encoding/json"
"fmt"
"testing" "testing"
) )
@ -261,8 +260,6 @@ func simpleUpdateWithPresets(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
fmt.Println(string(resSQL))
if string(resSQL) != sql { if string(resSQL) != sql {
t.Fatal(errNotExpected) t.Fatal(errNotExpected)
} }

View File

@ -211,12 +211,12 @@ func (al *allowList) save(item *allowItem) {
key := gqlHash(item.gql, item.vars, "") key := gqlHash(item.gql, item.vars, "")
if idx, ok := al.index[key]; ok { if _, ok := al.index[key]; ok {
al.list[idx] = item return
} else { }
al.list = append(al.list, item) al.list = append(al.list, item)
al.index[key] = len(al.list) - 1 al.index[key] = len(al.list) - 1
}
f, err := os.Create(al.filepath) f, err := os.Create(al.filepath)
if err != nil { if err != nil {

View File

@ -121,7 +121,7 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ck, err := r.Cookie(cookie) ck, err := r.Cookie(cookie)
if err != nil { if err != nil || len(ck.Value) == 0 {
logger.Warn().Err(err).Msg("rails cookie missing") logger.Warn().Err(err).Msg("rails cookie missing")
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return

View File

@ -19,10 +19,6 @@ import (
const ( const (
serverName = "Super Graph" serverName = "Super Graph"
authFailBlockAlways = iota + 1
authFailBlockPerQuery
authFailBlockNever
) )
var ( var (
@ -32,7 +28,6 @@ var (
db *pgxpool.Pool db *pgxpool.Pool
qcompile *qcode.Compiler qcompile *qcode.Compiler
pcompile *psql.Compiler pcompile *psql.Compiler
authFailBlock int
) )
func Init() { func Init() {
@ -179,8 +174,6 @@ func initConf() (*config, error) {
return nil, fmt.Errorf("unable to decode config, %v", err) return nil, fmt.Errorf("unable to decode config, %v", err)
} }
authFailBlock = getAuthFailBlock(c)
logLevel, err := zerolog.ParseLevel(c.LogLevel) logLevel, err := zerolog.ParseLevel(c.LogLevel)
if err != nil { if err != nil {
logger.Error().Err(err).Msg("error setting log_level") logger.Error().Err(err).Msg("error setting log_level")

View File

@ -24,7 +24,7 @@ type config struct {
EnableTracing bool `mapstructure:"enable_tracing"` EnableTracing bool `mapstructure:"enable_tracing"`
UseAllowList bool `mapstructure:"use_allow_list"` UseAllowList bool `mapstructure:"use_allow_list"`
WatchAndReload bool `mapstructure:"reload_on_config_change"` WatchAndReload bool `mapstructure:"reload_on_config_change"`
AuthFailBlock string `mapstructure:"auth_fail_block"` AuthFailBlock bool `mapstructure:"auth_fail_block"`
SeedFile string `mapstructure:"seed_file"` SeedFile string `mapstructure:"seed_file"`
MigrationsPath string `mapstructure:"migrations_path"` MigrationsPath string `mapstructure:"migrations_path"`
@ -103,36 +103,41 @@ type configRemote struct {
} `mapstructure:"set_headers"` } `mapstructure:"set_headers"`
} }
type configRoleTable struct { type configQuery struct {
Name string
Query struct {
Limit int Limit int
Filters []string Filters []string
Columns []string Columns []string
DisableFunctions bool `mapstructure:"disable_functions"` DisableFunctions bool `mapstructure:"disable_functions"`
Block bool Block bool
} }
Insert struct { type configInsert struct {
Filters []string Filters []string
Columns []string Columns []string
Presets map[string]string Presets map[string]string
Block bool Block bool
} }
Update struct { type configUpdate struct {
Filters []string Filters []string
Columns []string Columns []string
Presets map[string]string Presets map[string]string
Block bool Block bool
} }
Delete struct { type configDelete struct {
Filters []string Filters []string
Columns []string Columns []string
Block bool Block bool
} }
type configRoleTable struct {
Name string
Query configQuery
Insert configInsert
Update configUpdate
Delete configDelete
} }
type configRole struct { type configRole struct {
@ -213,7 +218,7 @@ func (c *config) Init(vi *viper.Viper) error {
rolesMap := make(map[string]struct{}) rolesMap := make(map[string]struct{})
for i := range c.Roles { for i := range c.Roles {
role := &c.Roles[i] role := c.Roles[i]
if _, ok := rolesMap[role.Name]; ok { if _, ok := rolesMap[role.Name]; ok {
logger.Fatal().Msgf("duplicate role '%s' found", role.Name) logger.Fatal().Msgf("duplicate role '%s' found", role.Name)
@ -228,7 +233,8 @@ func (c *config) Init(vi *viper.Viper) error {
} }
if _, ok := rolesMap["anon"]; !ok { if _, ok := rolesMap["anon"]; !ok {
c.Roles = append(c.Roles, configRole{Name: "anon"}) logger.Warn().Msg("unauthenticated requests will be blocked. no role 'anon' defined")
c.AuthFailBlock = true
} }
c.validate() c.validate()

View File

@ -208,10 +208,8 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, varMap(c)) _, err = t.ExecuteFunc(buf, varMap(c))
if err == errNoUserID && if err == errNoUserID {
authFailBlock == authFailBlockPerQuery && logger.Warn().Msg("no user id found. query requires an authenicated request")
authCheck(c) == false {
return nil, 0, errUnauthorized
} }
if err != nil { if err != nil {

View File

@ -70,10 +70,9 @@ type resolver struct {
func apiv1Http(w http.ResponseWriter, r *http.Request) { func apiv1Http(w http.ResponseWriter, r *http.Request) {
ctx := &coreContext{Context: r.Context()} ctx := &coreContext{Context: r.Context()}
if authFailBlock == authFailBlockAlways && authCheck(ctx) == false { if conf.AuthFailBlock && authCheck(ctx) == false {
err := "Not authorized" w.WriteHeader(http.StatusUnauthorized)
logger.Debug().Msg(err) json.NewEncoder(w).Encode(gqlResp{Error: errUnauthorized.Error()})
http.Error(w, err, 401)
return return
} }

View File

@ -107,6 +107,13 @@ func Do(log func(string, ...interface{}), additional ...dir) error {
case event := <-watcher.Events: case event := <-watcher.Events:
// Ensure that we use the correct events, as they are not uniform across // Ensure that we use the correct events, as they are not uniform across
// platforms. See https://github.com/fsnotify/fsnotify/issues/74 // platforms. See https://github.com/fsnotify/fsnotify/issues/74
if conf.UseAllowList == false && strings.HasSuffix(event.Name, "/allow.list") {
continue
}
logger.Info().Msgf("Reloading, file changed detected '%s'", event)
var trigger bool var trigger bool
switch runtime.GOOS { switch runtime.GOOS {
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly": case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":

View File

@ -204,16 +204,3 @@ func getConfigName() string {
return ge return ge
} }
func getAuthFailBlock(c *config) int {
switch c.AuthFailBlock {
case "always":
return authFailBlockAlways
case "per_query", "perquery", "query":
return authFailBlockPerQuery
case "never", "false":
return authFailBlockNever
}
return authFailBlockAlways
}

View File

@ -12,8 +12,7 @@ log_level: "debug"
use_allow_list: false use_allow_list: false
# Throw a 401 on auth failure for queries that need auth # Throw a 401 on auth failure for queries that need auth
# valid values: always, per_query, never auth_fail_block: false
auth_fail_block: never
# Latency tracing for database queries and remote joins # Latency tracing for database queries and remote joins
# the resulting latency information is returned with the # the resulting latency information is returned with the

View File

@ -16,8 +16,7 @@ log_level: "info"
use_allow_list: true use_allow_list: true
# Throw a 401 on auth failure for queries that need auth # Throw a 401 on auth failure for queries that need auth
# valid values: always, per_query, never auth_fail_block: true
auth_fail_block: always
# Latency tracing for database queries and remote joins # Latency tracing for database queries and remote joins
# the resulting latency information is returned with the # the resulting latency information is returned with the