217 lines
4.4 KiB
Go
217 lines
4.4 KiB
Go
package serv
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v4"
|
|
"github.com/jackc/pgx/v4/stdlib"
|
|
//_ "github.com/jackc/pgx/v4/stdlib"
|
|
)
|
|
|
|
// PEM_SIG is the marker substring used to decide whether a certificate or key
// setting holds inline PEM data (it contains the marker) or a file path (it
// does not). The name predates Go's MixedCaps convention; it is exported, so
// it is kept as is.
const PEM_SIG = "--BEGIN "
|
|
|
|
func initConf() (*Config, error) {
|
|
cp, err := filepath.Abs(confPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c, err := ReadInConfig(path.Join(cp, GetConfigName()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch c.LogLevel {
|
|
case "debug":
|
|
logLevel = LogLevelDebug
|
|
case "error":
|
|
logLevel = LogLevelError
|
|
case "warn":
|
|
logLevel = LogLevelWarn
|
|
case "info":
|
|
logLevel = LogLevelInfo
|
|
default:
|
|
logLevel = LogLevelNone
|
|
}
|
|
|
|
// Auths: validate and sanitize
|
|
am := make(map[string]struct{})
|
|
|
|
for i := 0; i < len(c.Auths); i++ {
|
|
a := &c.Auths[i]
|
|
a.Name = sanitize(a.Name)
|
|
|
|
if _, ok := am[a.Name]; ok {
|
|
c.Auths = append(c.Auths[:i], c.Auths[i+1:]...)
|
|
log.Printf("WRN duplicate auth found: %s", a.Name)
|
|
}
|
|
am[a.Name] = struct{}{}
|
|
}
|
|
|
|
// Actions: validate and sanitize
|
|
axm := make(map[string]struct{})
|
|
|
|
for i := 0; i < len(c.Actions); i++ {
|
|
a := &c.Actions[i]
|
|
a.Name = sanitize(a.Name)
|
|
a.AuthName = sanitize(a.AuthName)
|
|
|
|
if _, ok := axm[a.Name]; ok {
|
|
c.Actions = append(c.Actions[:i], c.Actions[i+1:]...)
|
|
log.Printf("WRN duplicate action found: %s", a.Name)
|
|
}
|
|
|
|
if _, ok := am[a.AuthName]; !ok {
|
|
c.Actions = append(c.Actions[:i], c.Actions[i+1:]...)
|
|
log.Printf("WRN invalid auth_name '%s' for auth: %s", a.AuthName, a.Name)
|
|
}
|
|
axm[a.Name] = struct{}{}
|
|
}
|
|
|
|
var anonFound bool
|
|
|
|
for _, r := range c.Roles {
|
|
if sanitize(r.Name) == "anon" {
|
|
anonFound = true
|
|
}
|
|
}
|
|
|
|
if !anonFound {
|
|
log.Printf("WRN unauthenticated requests will be blocked. no role 'anon' defined")
|
|
c.AuthFailBlock = false
|
|
}
|
|
|
|
if c.AllowListFile == "" {
|
|
c.AllowListFile = c.relPath("./allow.list")
|
|
}
|
|
|
|
if c.Production {
|
|
c.UseAllowList = true
|
|
}
|
|
|
|
// In anon role block all tables that are not defined in the role
|
|
c.DefaultBlock = true
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func initDB(c *Config, useDB bool) (*sql.DB, error) {
|
|
var db *sql.DB
|
|
var err error
|
|
|
|
config, _ := pgx.ParseConfig("")
|
|
config.Host = c.DB.Host
|
|
config.Port = c.DB.Port
|
|
config.User = c.DB.User
|
|
config.Password = c.DB.Password
|
|
config.RuntimeParams = map[string]string{
|
|
"application_name": c.AppName,
|
|
"search_path": c.DB.Schema,
|
|
}
|
|
|
|
if useDB {
|
|
config.Database = c.DB.DBName
|
|
}
|
|
|
|
if c.DB.EnableTLS {
|
|
if len(c.DB.ServerName) == 0 {
|
|
return nil, errors.New("server_name is required")
|
|
}
|
|
if len(c.DB.ServerCert) == 0 {
|
|
return nil, errors.New("server_cert is required")
|
|
}
|
|
if len(c.DB.ClientCert) == 0 {
|
|
return nil, errors.New("client_cert is required")
|
|
}
|
|
if len(c.DB.ClientKey) == 0 {
|
|
return nil, errors.New("client_key is required")
|
|
}
|
|
|
|
rootCertPool := x509.NewCertPool()
|
|
var pem []byte
|
|
var err error
|
|
|
|
if strings.Contains(c.DB.ServerCert, PEM_SIG) {
|
|
pem = []byte(c.DB.ServerCert)
|
|
} else {
|
|
pem, err = ioutil.ReadFile(c.relPath(c.DB.ServerCert))
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db tls: %w", err)
|
|
}
|
|
|
|
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
|
return nil, errors.New("db tls: failed to append pem")
|
|
}
|
|
|
|
clientCert := make([]tls.Certificate, 0, 1)
|
|
var certs tls.Certificate
|
|
|
|
if strings.Contains(c.DB.ClientCert, PEM_SIG) {
|
|
certs, err = tls.X509KeyPair([]byte(c.DB.ClientCert), []byte(c.DB.ClientKey))
|
|
} else {
|
|
certs, err = tls.LoadX509KeyPair(c.relPath(c.DB.ClientCert), c.relPath(c.DB.ClientKey))
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db tls: %w", err)
|
|
}
|
|
|
|
clientCert = append(clientCert, certs)
|
|
config.TLSConfig = &tls.Config{
|
|
RootCAs: rootCertPool,
|
|
Certificates: clientCert,
|
|
ServerName: c.DB.ServerName,
|
|
}
|
|
}
|
|
|
|
// switch c.LogLevel {
|
|
// case "debug":
|
|
// config.LogLevel = pgx.LogLevelDebug
|
|
// case "info":
|
|
// config.LogLevel = pgx.LogLevelInfo
|
|
// case "warn":
|
|
// config.LogLevel = pgx.LogLevelWarn
|
|
// case "error":
|
|
// config.LogLevel = pgx.LogLevelError
|
|
// default:
|
|
// config.LogLevel = pgx.LogLevelNone
|
|
// }
|
|
|
|
//config.Logger = NewSQLLogger(logger)
|
|
|
|
// if c.DB.MaxRetries != 0 {
|
|
// opt.MaxRetries = c.DB.MaxRetries
|
|
// }
|
|
|
|
// if c.DB.PoolSize != 0 {
|
|
// config.MaxConns = conf.DB.PoolSize
|
|
// }
|
|
|
|
for i := 1; i < 10; i++ {
|
|
db = stdlib.OpenDB(*config)
|
|
if db == nil {
|
|
break
|
|
}
|
|
time.Sleep(time.Duration(i*100) * time.Millisecond)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return db, nil
|
|
}
|