super-graph/internal/serv/init.go
2020-04-23 21:25:09 -04:00

217 lines
4.4 KiB
Go

package serv
import (
"crypto/tls"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"io/ioutil"
"path"
"path/filepath"
"strings"
"time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
//_ "github.com/jackc/pgx/v4/stdlib"
)
const (
PEM_SIG = "--BEGIN "
)
func initConf() (*Config, error) {
cp, err := filepath.Abs(confPath)
if err != nil {
return nil, err
}
c, err := ReadInConfig(path.Join(cp, GetConfigName()))
if err != nil {
return nil, err
}
switch c.LogLevel {
case "debug":
logLevel = LogLevelDebug
case "error":
logLevel = LogLevelError
case "warn":
logLevel = LogLevelWarn
case "info":
logLevel = LogLevelInfo
default:
logLevel = LogLevelNone
}
// Auths: validate and sanitize
am := make(map[string]struct{})
for i := 0; i < len(c.Auths); i++ {
a := &c.Auths[i]
a.Name = sanitize(a.Name)
if _, ok := am[a.Name]; ok {
c.Auths = append(c.Auths[:i], c.Auths[i+1:]...)
log.Printf("WRN duplicate auth found: %s", a.Name)
}
am[a.Name] = struct{}{}
}
// Actions: validate and sanitize
axm := make(map[string]struct{})
for i := 0; i < len(c.Actions); i++ {
a := &c.Actions[i]
a.Name = sanitize(a.Name)
a.AuthName = sanitize(a.AuthName)
if _, ok := axm[a.Name]; ok {
c.Actions = append(c.Actions[:i], c.Actions[i+1:]...)
log.Printf("WRN duplicate action found: %s", a.Name)
}
if _, ok := am[a.AuthName]; !ok {
c.Actions = append(c.Actions[:i], c.Actions[i+1:]...)
log.Printf("WRN invalid auth_name '%s' for auth: %s", a.AuthName, a.Name)
}
axm[a.Name] = struct{}{}
}
var anonFound bool
for _, r := range c.Roles {
if sanitize(r.Name) == "anon" {
anonFound = true
}
}
if !anonFound {
log.Printf("WRN unauthenticated requests will be blocked. no role 'anon' defined")
c.AuthFailBlock = false
}
if len(c.AllowListFile) == 0 {
c.AllowListFile = c.relPath("./allow.list")
}
if c.Production {
c.UseAllowList = true
}
// In anon role block all tables that are not defined in the role
c.DefaultBlock = true
return c, nil
}
func initDB(c *Config, useDB bool) (*sql.DB, error) {
var db *sql.DB
var err error
config, _ := pgx.ParseConfig("")
config.Host = c.DB.Host
config.Port = c.DB.Port
config.User = c.DB.User
config.Password = c.DB.Password
config.RuntimeParams = map[string]string{
"application_name": c.AppName,
"search_path": c.DB.Schema,
}
if useDB {
config.Database = c.DB.DBName
}
if c.DB.EnableTLS {
if len(c.DB.ServerName) == 0 {
return nil, errors.New("server_name is required")
}
if len(c.DB.ServerCert) == 0 {
return nil, errors.New("server_cert is required")
}
if len(c.DB.ClientCert) == 0 {
return nil, errors.New("client_cert is required")
}
if len(c.DB.ClientKey) == 0 {
return nil, errors.New("client_key is required")
}
rootCertPool := x509.NewCertPool()
var pem []byte
var err error
if strings.Contains(c.DB.ServerCert, PEM_SIG) {
pem = []byte(c.DB.ServerCert)
} else {
pem, err = ioutil.ReadFile(c.relPath(c.DB.ServerCert))
}
if err != nil {
return nil, fmt.Errorf("db tls: %w", err)
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return nil, errors.New("db tls: failed to append pem")
}
clientCert := make([]tls.Certificate, 0, 1)
var certs tls.Certificate
if strings.Contains(c.DB.ClientCert, PEM_SIG) {
certs, err = tls.X509KeyPair([]byte(c.DB.ClientCert), []byte(c.DB.ClientKey))
} else {
certs, err = tls.LoadX509KeyPair(c.relPath(c.DB.ClientCert), c.relPath(c.DB.ClientKey))
}
if err != nil {
return nil, fmt.Errorf("db tls: %w", err)
}
clientCert = append(clientCert, certs)
config.TLSConfig = &tls.Config{
RootCAs: rootCertPool,
Certificates: clientCert,
ServerName: c.DB.ServerName,
}
}
// switch c.LogLevel {
// case "debug":
// config.LogLevel = pgx.LogLevelDebug
// case "info":
// config.LogLevel = pgx.LogLevelInfo
// case "warn":
// config.LogLevel = pgx.LogLevelWarn
// case "error":
// config.LogLevel = pgx.LogLevelError
// default:
// config.LogLevel = pgx.LogLevelNone
// }
//config.Logger = NewSQLLogger(logger)
// if c.DB.MaxRetries != 0 {
// opt.MaxRetries = c.DB.MaxRetries
// }
// if c.DB.PoolSize != 0 {
// config.MaxConns = conf.DB.PoolSize
// }
for i := 1; i < 10; i++ {
db = stdlib.OpenDB(*config)
if db == nil {
break
}
time.Sleep(time.Duration(i*100) * time.Millisecond)
}
if err != nil {
return nil, err
}
return db, nil
}