2020-04-10 08:27:43 +02:00
|
|
|
package serv
|
|
|
|
|
|
|
|
import (
|
2020-04-16 06:26:32 +02:00
|
|
|
"crypto/tls"
|
|
|
|
"crypto/x509"
|
2020-04-10 08:27:43 +02:00
|
|
|
"database/sql"
|
2020-04-16 06:26:32 +02:00
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
2020-04-11 08:45:06 +02:00
|
|
|
"path"
|
2020-04-19 18:54:37 +02:00
|
|
|
"path/filepath"
|
2020-04-16 06:26:32 +02:00
|
|
|
"strings"
|
2020-04-10 08:27:43 +02:00
|
|
|
"time"
|
|
|
|
|
2020-05-22 22:49:54 +02:00
|
|
|
"contrib.go.opencensus.io/integrations/ocsql"
|
2020-04-12 16:09:37 +02:00
|
|
|
"github.com/jackc/pgx/v4"
|
|
|
|
"github.com/jackc/pgx/v4/stdlib"
|
2020-04-10 08:27:43 +02:00
|
|
|
)
|
|
|
|
|
2020-04-16 06:26:32 +02:00
|
|
|
// PEM_SIG is the marker substring looked for in certificate and key settings
// to decide whether a value holds inline PEM data rather than a file path.
const PEM_SIG = "--BEGIN "
|
|
|
|
|
2020-04-11 08:45:06 +02:00
|
|
|
func initConf() (*Config, error) {
|
2020-04-19 18:54:37 +02:00
|
|
|
cp, err := filepath.Abs(confPath)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
c, err := ReadInConfig(path.Join(cp, GetConfigName()))
|
2020-04-11 08:45:06 +02:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
switch c.LogLevel {
|
|
|
|
case "debug":
|
|
|
|
logLevel = LogLevelDebug
|
|
|
|
case "error":
|
|
|
|
logLevel = LogLevelError
|
|
|
|
case "warn":
|
|
|
|
logLevel = LogLevelWarn
|
|
|
|
case "info":
|
|
|
|
logLevel = LogLevelInfo
|
|
|
|
default:
|
|
|
|
logLevel = LogLevelNone
|
|
|
|
}
|
|
|
|
|
2020-05-20 06:03:05 +02:00
|
|
|
// copy over db_schema from the core config
|
|
|
|
if c.DB.Schema == "" {
|
|
|
|
c.DB.Schema = c.DBSchema
|
|
|
|
}
|
|
|
|
|
|
|
|
// set default database schema
|
|
|
|
if c.DB.Schema == "" {
|
|
|
|
c.DB.Schema = "public"
|
|
|
|
}
|
|
|
|
|
2020-04-11 08:45:06 +02:00
|
|
|
// Auths: validate and sanitize
|
|
|
|
am := make(map[string]struct{})
|
|
|
|
|
|
|
|
for i := 0; i < len(c.Auths); i++ {
|
|
|
|
a := &c.Auths[i]
|
|
|
|
a.Name = sanitize(a.Name)
|
|
|
|
|
|
|
|
if _, ok := am[a.Name]; ok {
|
|
|
|
c.Auths = append(c.Auths[:i], c.Auths[i+1:]...)
|
|
|
|
log.Printf("WRN duplicate auth found: %s", a.Name)
|
|
|
|
}
|
|
|
|
am[a.Name] = struct{}{}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Actions: validate and sanitize
|
|
|
|
axm := make(map[string]struct{})
|
|
|
|
|
|
|
|
for i := 0; i < len(c.Actions); i++ {
|
|
|
|
a := &c.Actions[i]
|
|
|
|
a.Name = sanitize(a.Name)
|
|
|
|
a.AuthName = sanitize(a.AuthName)
|
|
|
|
|
|
|
|
if _, ok := axm[a.Name]; ok {
|
|
|
|
c.Actions = append(c.Actions[:i], c.Actions[i+1:]...)
|
|
|
|
log.Printf("WRN duplicate action found: %s", a.Name)
|
|
|
|
}
|
|
|
|
|
|
|
|
if _, ok := am[a.AuthName]; !ok {
|
|
|
|
c.Actions = append(c.Actions[:i], c.Actions[i+1:]...)
|
|
|
|
log.Printf("WRN invalid auth_name '%s' for auth: %s", a.AuthName, a.Name)
|
|
|
|
}
|
|
|
|
axm[a.Name] = struct{}{}
|
|
|
|
}
|
|
|
|
|
|
|
|
var anonFound bool
|
|
|
|
|
|
|
|
for _, r := range c.Roles {
|
|
|
|
if sanitize(r.Name) == "anon" {
|
|
|
|
anonFound = true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if !anonFound {
|
|
|
|
log.Printf("WRN unauthenticated requests will be blocked. no role 'anon' defined")
|
|
|
|
c.AuthFailBlock = false
|
|
|
|
}
|
|
|
|
|
2020-05-17 09:11:56 +02:00
|
|
|
if c.AllowListFile == "" {
|
2020-04-19 18:54:37 +02:00
|
|
|
c.AllowListFile = c.relPath("./allow.list")
|
|
|
|
}
|
|
|
|
|
|
|
|
if c.Production {
|
|
|
|
c.UseAllowList = true
|
|
|
|
}
|
|
|
|
|
2020-04-11 08:45:06 +02:00
|
|
|
return c, nil
|
2020-04-10 08:27:43 +02:00
|
|
|
}
|
|
|
|
|
2020-05-23 22:53:39 +02:00
|
|
|
func initDB(c *Config, useDB, useTelemetry bool) (*sql.DB, error) {
|
2020-04-10 08:27:43 +02:00
|
|
|
var db *sql.DB
|
|
|
|
var err error
|
|
|
|
|
2020-04-12 16:09:37 +02:00
|
|
|
config, _ := pgx.ParseConfig("")
|
|
|
|
config.Host = c.DB.Host
|
|
|
|
config.Port = c.DB.Port
|
|
|
|
config.User = c.DB.User
|
|
|
|
config.Password = c.DB.Password
|
|
|
|
config.RuntimeParams = map[string]string{
|
|
|
|
"application_name": c.AppName,
|
|
|
|
"search_path": c.DB.Schema,
|
|
|
|
}
|
|
|
|
|
2020-04-13 06:43:18 +02:00
|
|
|
if useDB {
|
|
|
|
config.Database = c.DB.DBName
|
|
|
|
}
|
|
|
|
|
2020-04-16 06:26:32 +02:00
|
|
|
if c.DB.EnableTLS {
|
|
|
|
if len(c.DB.ServerName) == 0 {
|
|
|
|
return nil, errors.New("server_name is required")
|
|
|
|
}
|
|
|
|
if len(c.DB.ServerCert) == 0 {
|
|
|
|
return nil, errors.New("server_cert is required")
|
|
|
|
}
|
|
|
|
if len(c.DB.ClientCert) == 0 {
|
|
|
|
return nil, errors.New("client_cert is required")
|
|
|
|
}
|
|
|
|
if len(c.DB.ClientKey) == 0 {
|
|
|
|
return nil, errors.New("client_key is required")
|
|
|
|
}
|
|
|
|
|
|
|
|
rootCertPool := x509.NewCertPool()
|
|
|
|
var pem []byte
|
|
|
|
var err error
|
|
|
|
|
|
|
|
if strings.Contains(c.DB.ServerCert, PEM_SIG) {
|
|
|
|
pem = []byte(c.DB.ServerCert)
|
|
|
|
} else {
|
2020-04-17 16:56:26 +02:00
|
|
|
pem, err = ioutil.ReadFile(c.relPath(c.DB.ServerCert))
|
2020-04-16 06:26:32 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("db tls: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
|
|
|
return nil, errors.New("db tls: failed to append pem")
|
|
|
|
}
|
|
|
|
|
|
|
|
clientCert := make([]tls.Certificate, 0, 1)
|
|
|
|
var certs tls.Certificate
|
|
|
|
|
|
|
|
if strings.Contains(c.DB.ClientCert, PEM_SIG) {
|
|
|
|
certs, err = tls.X509KeyPair([]byte(c.DB.ClientCert), []byte(c.DB.ClientKey))
|
|
|
|
} else {
|
2020-04-17 16:56:26 +02:00
|
|
|
certs, err = tls.LoadX509KeyPair(c.relPath(c.DB.ClientCert), c.relPath(c.DB.ClientKey))
|
2020-04-16 06:26:32 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("db tls: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
clientCert = append(clientCert, certs)
|
|
|
|
config.TLSConfig = &tls.Config{
|
|
|
|
RootCAs: rootCertPool,
|
|
|
|
Certificates: clientCert,
|
|
|
|
ServerName: c.DB.ServerName,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-04-10 08:27:43 +02:00
|
|
|
// switch c.LogLevel {
|
|
|
|
// case "debug":
|
2020-04-12 16:09:37 +02:00
|
|
|
// config.LogLevel = pgx.LogLevelDebug
|
2020-04-10 08:27:43 +02:00
|
|
|
// case "info":
|
2020-04-12 16:09:37 +02:00
|
|
|
// config.LogLevel = pgx.LogLevelInfo
|
2020-04-10 08:27:43 +02:00
|
|
|
// case "warn":
|
2020-04-12 16:09:37 +02:00
|
|
|
// config.LogLevel = pgx.LogLevelWarn
|
2020-04-10 08:27:43 +02:00
|
|
|
// case "error":
|
2020-04-12 16:09:37 +02:00
|
|
|
// config.LogLevel = pgx.LogLevelError
|
2020-04-10 08:27:43 +02:00
|
|
|
// default:
|
2020-04-12 16:09:37 +02:00
|
|
|
// config.LogLevel = pgx.LogLevelNone
|
2020-04-10 08:27:43 +02:00
|
|
|
// }
|
|
|
|
|
2020-04-12 16:09:37 +02:00
|
|
|
//config.Logger = NewSQLLogger(logger)
|
2020-04-10 08:27:43 +02:00
|
|
|
|
2020-04-12 16:09:37 +02:00
|
|
|
// if c.DB.MaxRetries != 0 {
|
|
|
|
// opt.MaxRetries = c.DB.MaxRetries
|
|
|
|
// }
|
2020-04-10 08:27:43 +02:00
|
|
|
|
|
|
|
// if c.DB.PoolSize != 0 {
|
|
|
|
// config.MaxConns = conf.DB.PoolSize
|
|
|
|
// }
|
|
|
|
|
2020-05-22 22:49:54 +02:00
|
|
|
connString := stdlib.RegisterConnConfig(config)
|
|
|
|
driverName := "pgx"
|
|
|
|
// if db = stdlib.OpenDB(*config); db == nil {
|
|
|
|
// return errors.New("failed to open db")
|
|
|
|
// }
|
|
|
|
|
2020-05-23 22:37:15 +02:00
|
|
|
if useTelemetry && conf.telemetryEnabled() {
|
2020-05-24 08:24:24 +02:00
|
|
|
opts := ocsql.TraceOptions{
|
2020-05-24 16:44:00 +02:00
|
|
|
AllowRoot: true,
|
2020-05-24 08:24:24 +02:00
|
|
|
Ping: true,
|
|
|
|
RowsNext: true,
|
|
|
|
RowsClose: true,
|
|
|
|
RowsAffected: true,
|
|
|
|
LastInsertID: true,
|
|
|
|
Query: conf.Telemetry.Tracing.IncludeQuery,
|
|
|
|
QueryParams: conf.Telemetry.Tracing.IncludeParams,
|
|
|
|
}
|
|
|
|
opt := ocsql.WithOptions(opts)
|
|
|
|
name := ocsql.WithInstanceName(conf.AppName)
|
|
|
|
|
|
|
|
driverName, err = ocsql.Register(driverName, opt, name)
|
2020-05-22 22:49:54 +02:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("unable to register ocsql driver: %v", err)
|
|
|
|
}
|
|
|
|
ocsql.RegisterAllViews()
|
2020-05-23 22:37:15 +02:00
|
|
|
|
|
|
|
var interval time.Duration
|
|
|
|
|
|
|
|
if conf.Telemetry.Interval != nil {
|
|
|
|
interval = *conf.Telemetry.Interval
|
|
|
|
} else {
|
|
|
|
interval = 5 * time.Second
|
|
|
|
}
|
|
|
|
|
|
|
|
defer ocsql.RecordStats(db, interval)()
|
2020-05-22 22:49:54 +02:00
|
|
|
|
|
|
|
log.Println("INF OpenCensus telemetry enabled")
|
|
|
|
}
|
|
|
|
|
2020-04-12 16:09:37 +02:00
|
|
|
for i := 1; i < 10; i++ {
|
2020-05-22 22:49:54 +02:00
|
|
|
db, err = sql.Open(driverName, connString)
|
|
|
|
if err != nil {
|
|
|
|
continue
|
2020-04-12 16:09:37 +02:00
|
|
|
}
|
2020-05-22 22:49:54 +02:00
|
|
|
|
2020-04-12 16:09:37 +02:00
|
|
|
time.Sleep(time.Duration(i*100) * time.Millisecond)
|
|
|
|
}
|
2020-04-10 08:27:43 +02:00
|
|
|
|
2020-04-12 16:09:37 +02:00
|
|
|
if err != nil {
|
2020-05-22 22:49:54 +02:00
|
|
|
return nil, fmt.Errorf("unable to open db connection: %v", err)
|
|
|
|
}
|
|
|
|
|
2020-04-12 16:09:37 +02:00
|
|
|
return db, nil
|
2020-04-10 08:27:43 +02:00
|
|
|
}
|