diff --git a/config/prod.yml b/config/prod.yml index 86b5563..9d5df4c 100644 --- a/config/prod.yml +++ b/config/prod.yml @@ -1,6 +1,6 @@ # Inherit config from this other config file # so I only need to overwrite some values -inherit: dev +inherits: dev app_name: "Super Graph Production" host_port: 0.0.0.0:8080 diff --git a/docs/guide.md b/docs/guide.md index 29b9b2f..a2c084f 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -1042,7 +1042,7 @@ We're tried to ensure that the config file is self documenting and easy to work ```yaml # Inherit config from this other config file # so I only need to overwrite some values -inherit: base +inherits: base app_name: "Super Graph Development" host_port: 0.0.0.0:8080 diff --git a/serv/cmd.go b/serv/cmd.go index c271280..77730b1 100644 --- a/serv/cmd.go +++ b/serv/cmd.go @@ -8,7 +8,6 @@ import ( "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" - "github.com/gobuffalo/flect" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/rs/zerolog" @@ -156,37 +155,30 @@ func initConf() (*config, error) { return nil, err } - inherit := vi.GetString("inherit") - if len(inherit) != 0 { - vi = newConfig(inherit) + inherits := vi.GetString("inherits") + if len(inherits) != 0 { + vi = newConfig(inherits) if err := vi.ReadInConfig(); err != nil { return nil, err } + if vi.IsSet("inherits") { + logger.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)", + inherits, + vi.GetString("inherits")) + } + vi.SetConfigName(getConfigName()) vi.MergeInConfig() } - c := &config{Viper: vi} + c := &config{} - if err := vi.Unmarshal(c); err != nil { + if err := c.Init(vi); err != nil { return nil, fmt.Errorf("unable to decode config, %v", err) } - if len(c.Tables) == 0 { - c.Tables = c.DB.Tables - } - - for k, v := range c.Inflections { - flect.AddPlural(k, v) - } - - for i := range c.Tables { - t := c.Tables[i] - t.Name = flect.Pluralize(strings.ToLower(t.Name)) - } - authFailBlock = getAuthFailBlock(c) logLevel, err := zerolog.ParseLevel(c.LogLevel) @@ -195,35 +187,6 @@ func initConf() (*config, error) { } zerolog.SetGlobalLevel(logLevel) - for k, v := range c.DB.Vars { - c.DB.Vars[k] = sanitize(v) - } - - c.RolesQuery = sanitize(c.RolesQuery) - - rolesMap := make(map[string]struct{}) - - for i := range c.Roles { - role := &c.Roles[i] - - if _, ok := rolesMap[role.Name]; ok { - logger.Fatal().Msgf("duplicate role '%s' found", role.Name) - } - role.Name = sanitize(role.Name) - role.Match = sanitize(role.Match) - rolesMap[role.Name] = struct{}{} - } - - if _, ok := rolesMap["user"]; !ok { - c.Roles = append(c.Roles, configRole{Name: "user"}) - } - - if _, ok := rolesMap["anon"]; !ok { - c.Roles = append(c.Roles, configRole{Name: "anon"}) - } - - c.Validate() - return c, nil } diff --git a/serv/config.go b/serv/config.go index f67e74a..ba64b8d 100644 --- a/serv/config.go +++ b/serv/config.go @@ -1,10 +1,13 @@ package serv import ( + "fmt" + "os" "regexp" "strings" "unicode" + "github.com/gobuffalo/flect" "github.com/spf13/viper" ) @@ -100,40 +103,42 @@ type configRemote struct { } `mapstructure:"set_headers"` } +type configRoleTable struct { + Name string + + Query struct { + Limit int + Filters []string + Columns []string + DisableFunctions bool `mapstructure:"disable_functions"` + Block bool + } + + Insert struct { + Filters []string + Columns []string + Presets map[string]string + Block bool + } + + Update struct { + Filters []string + Columns []string + Presets map[string]string + Block bool + } + + Delete struct { + Filters []string + Columns []string + Block bool + } +} + type configRole struct { Name string Match string - Tables []struct { - Name string - - Query struct { - Limit int - Filters []string - Columns []string - DisableFunctions bool `mapstructure:"disable_functions"` - Block bool - } - - Insert struct { - Filters []string - Columns []string - Presets map[string]string - Block bool - } - - Update struct { - Filters []string - Columns []string - Presets map[string]string - Block bool - } - - Delete struct { - Filters []string - Columns []string - Block bool - } - } + Tables []configRoleTable } func newConfig(name string) *viper.Viper { @@ -147,6 +152,10 @@ func newConfig(name string) *viper.Viper { vi.AddConfigPath(confPath) vi.AddConfigPath("./config") + if dir, _ := os.Getwd(); strings.HasSuffix(dir, "/serv") { + vi.AddConfigPath("../config") + } + vi.SetDefault("host_port", "0.0.0.0:8080") vi.SetDefault("web_ui", false) vi.SetDefault("enable_tracing", false) @@ -170,11 +179,69 @@ func newConfig(name string) *viper.Viper { return vi } -func (c *config) Validate() { +func (c *config) Init(vi *viper.Viper) error { + if err := vi.Unmarshal(c); err != nil { + return fmt.Errorf("unable to decode config, %v", err) + } + c.Viper = vi + + if len(c.Tables) == 0 { + c.Tables = c.DB.Tables + } + + for k, v := range c.Inflections { + flect.AddPlural(k, v) + } + + for i := range c.Tables { + t := c.Tables[i] + t.Name = flect.Pluralize(strings.ToLower(t.Name)) + t.Table = flect.Pluralize(strings.ToLower(t.Table)) + } + + for i := range c.Roles { + r := c.Roles[i] + r.Name = strings.ToLower(r.Name) + } + + for k, v := range c.DB.Vars { + c.DB.Vars[k] = sanitize(v) + } + + c.RolesQuery = sanitize(c.RolesQuery) + + rolesMap := make(map[string]struct{}) + + for i := range c.Roles { + role := &c.Roles[i] + + if _, ok := rolesMap[role.Name]; ok { + logger.Fatal().Msgf("duplicate role '%s' found", role.Name) + } + role.Name = sanitize(role.Name) + role.Match = sanitize(role.Match) + rolesMap[role.Name] = struct{}{} + } + + if _, ok := rolesMap["user"]; !ok { + c.Roles = append(c.Roles, configRole{Name: "user"}) + } + + if _, ok := rolesMap["anon"]; !ok { + c.Roles = append(c.Roles, configRole{Name: "anon"}) + } + + c.validate() + + return nil +} + +func (c *config) validate() { rm := make(map[string]struct{}) for i := range c.Roles { - name := strings.ToLower(c.Roles[i].Name) + name := c.Roles[i].Name + if _, ok := rm[name]; ok { logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name) } @@ -184,7 +251,8 @@ func (c *config) Validate() { tm := make(map[string]struct{}) for i := range c.Tables { - name := strings.ToLower(c.Tables[i].Name) + name := c.Tables[i].Name + if _, ok := tm[name]; ok { logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name) } @@ -206,8 +274,7 @@ func (c *config) getAliasMap() map[string][]string { continue } - k := strings.ToLower(t.Table) - m[k] = append(m[k], strings.ToLower(t.Name)) + m[t.Table] = append(m[t.Table], t.Name) } return m } diff --git a/serv/config_test.go b/serv/config_test.go new file mode 100644 index 0000000..f12e194 --- /dev/null +++ b/serv/config_test.go @@ -0,0 +1,13 @@ +package serv + +import ( + "testing" +) + +func TestInitConf(t *testing.T) { + _, err := initConf() + + if err != nil { + t.Fatal(err.Error()) + } +} diff --git a/tmpl/prod.yml b/tmpl/prod.yml index a63d146..d474995 100644 --- a/tmpl/prod.yml +++ b/tmpl/prod.yml @@ -1,6 +1,6 @@ # Inherit config from this other config file # so I only need to overwrite some values -inherit: dev +inherits: dev app_name: "{{app_name}} Production" host_port: 0.0.0.0:8080