edge/pkg/storage/driver/sqlite/sql.go

186 lines
3.4 KiB
Go
Raw Permalink Normal View History

2023-02-09 12:16:36 +01:00
package sqlite
import (
"context"
"database/sql"
"strings"
2023-02-09 12:16:36 +01:00
"sync"
"time"
2023-02-09 12:16:36 +01:00
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
2023-03-03 16:37:19 +01:00
"modernc.org/sqlite"
2023-03-03 16:37:19 +01:00
_ "modernc.org/sqlite"
sqlite3 "modernc.org/sqlite/lib"
2023-02-09 12:16:36 +01:00
)
2023-03-03 16:37:19 +01:00
// Open returns a *sql.DB handle that uses the "sqlite" driver for the given
// path (or DSN). The driver is registered by the modernc.org/sqlite import.
// On failure the returned handle is nil and the error mentions the path.
func Open(path string) (*sql.DB, error) {
	handle, openErr := sql.Open("sqlite", path)
	if openErr == nil {
		return handle, nil
	}

	return nil, errors.Wrapf(openErr, "could not open database with path '%s'", path)
}
// WithRetry runs fn in a transaction (see WithTx). While the attempt fails
// because SQLite reports SQLITE_BUSY, it retries the whole transaction with
// exponential backoff (16ms, 32ms, 64ms, ...).
//
// max is the maximum number of attempts. Values below 1 are treated as 1, so
// fn is always executed at least once.
//
// Return values:
//   - An error that is not SQLITE_BUSY is returned immediately.
//   - If all max attempts fail with SQLITE_BUSY, the last error is returned,
//     wrapped with the attempt count.
//   - If ctx is done while waiting between attempts, ctx.Err() is returned.
func WithRetry(ctx context.Context, db *sql.DB, max int, fn func(*sql.Tx) error) error {
	// Without this clamp the loop would give up before trying anything. With
	// errors.Wrapf(nil, ...) == nil, that would report success without fn
	// ever running.
	if max < 1 {
		max = 1
	}

	ctx = logger.With(ctx, logger.F("max", max))

	for attempts := 0; ; attempts++ {
		err := WithTx(ctx, db, fn)
		if err == nil {
			return nil
		}

		// Only lock contention is worth retrying; anything else is final.
		if !isSQLiteBusyError(err) {
			return errors.WithStack(err)
		}

		err = errors.WithStack(err)
		logger.Warn(ctx, "database is busy", logger.E(err))

		// Give up right after the last allowed attempt instead of sleeping
		// one more backoff period for nothing.
		if attempts+1 >= max {
			logger.Debug(ctx, "transaction retrying failed", logger.F("attempts", attempts+1))
			return errors.Wrapf(err, "transaction failed after %d attempts", max)
		}

		wait := time.Duration(8<<(attempts+1)) * time.Millisecond

		logger.Debug(
			ctx, "database is busy, waiting before retrying transaction",
			logger.F("wait", wait.String()),
			logger.F("attempts", attempts),
		)

		timer := time.NewTimer(wait)
		select {
		case <-timer.C:
		case <-ctx.Done():
			// Release the timer now rather than letting it run until it fires.
			timer.Stop()
			return errors.WithStack(ctx.Err())
		}
	}
}

// isSQLiteBusyError reports whether err signals SQLite lock contention
// (SQLITE_BUSY). It first looks for a typed *sqlite.Error in the error chain.
// It then falls back to the textual "(5) (SQLITE_BUSY)" marker, for errors
// whose type information was lost while wrapping.
func isSQLiteBusyError(err error) bool {
	var sqlErr *sqlite.Error
	if errors.As(err, &sqlErr) && sqlErr.Code() == sqlite3.SQLITE_BUSY {
		return true
	}

	return strings.Contains(err.Error(), "(5) (SQLITE_BUSY)")
}
// WithTx runs fn inside a database transaction. If fn returns nil, the
// transaction is committed. If fn returns an error, the transaction is rolled
// back and the error is returned.
//
// When fn fails with SQLITE_BUSY, the failed transaction is rolled back. After
// a short pause fn runs again in a brand new transaction. This repeats until
// one of the following happens:
//   - fn succeeds;
//   - fn fails with another error;
//   - ctx is done, in which case ctx.Err() is returned.
//
// fn is never re-run on a transaction where it already executed statements.
// Errors from beginning or committing the transaction are returned as-is and
// are not retried here; WithRetry retries those.
func WithTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
	// Pause between attempts so a busy database is not hammered in a tight
	// loop while the lock holder finishes its work.
	const busyRetryDelay = 10 * time.Millisecond

	for {
		busy, err := execTxOnce(ctx, db, fn)
		if !busy {
			return err
		}

		logger.Warn(ctx, "database busy, retrying transaction")

		timer := time.NewTimer(busyRetryDelay)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
		}

		if err := ctx.Err(); err != nil {
			logger.Error(ctx, "could not execute transaction", logger.CapturedE(errors.WithStack(err)))
			return errors.WithStack(err)
		}
	}
}

// execTxOnce runs fn in a single fresh transaction and commits it on success.
//
// busy is true only when fn itself failed with SQLITE_BUSY. In that case the
// transaction has already been rolled back and the caller may retry.
func execTxOnce(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) (bool, error) {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return false, errors.WithStack(err)
	}

	defer func() {
		// Once Commit has run (or ctx cancelled the transaction), Rollback
		// returns sql.ErrTxDone, which is expected. Any other failure is
		// logged rather than panicking, since the caller already has the
		// outcome it cares about.
		if err := tx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) {
			logger.Error(ctx, "could not rollback transaction", logger.CapturedE(errors.WithStack(err)))
		}
	}()

	if err := fn(tx); err != nil {
		var sqlErr *sqlite.Error
		busy := errors.As(err, &sqlErr) && sqlErr.Code() == sqlite3.SQLITE_BUSY
		return busy, errors.WithStack(err)
	}

	if err := tx.Commit(); err != nil {
		return false, errors.WithStack(err)
	}

	return false, nil
}
// GetDBFunc provides a *sql.DB on demand. It returns either a usable handle
// or an error explaining why the database could not be obtained.
type GetDBFunc func(context.Context) (*sql.DB, error)
2023-02-09 12:16:36 +01:00
// NewGetDBFunc returns a GetDBFunc that lazily opens the SQLite database
// identified by dsn. The first successful call also runs initFunc on the new
// handle.
//
// The handle is cached and shared by all later calls. If opening or
// initialization fails, the error is returned, nothing is cached, and the
// next call tries again. A handle whose initialization failed is closed.
func NewGetDBFunc(dsn string, initFunc func(ctx context.Context, db *sql.DB) error) GetDBFunc {
	var (
		db    *sql.DB
		mutex sync.RWMutex
	)

	return func(ctx context.Context) (*sql.DB, error) {
		// Fast path: database already opened and initialized.
		mutex.RLock()
		if db != nil {
			defer mutex.RUnlock()
			return db, nil
		}
		mutex.RUnlock()

		mutex.Lock()
		defer mutex.Unlock()

		// Another caller may have opened the database between our RUnlock
		// and Lock. Reuse its handle instead of opening (and leaking) a
		// second one and running initFunc twice.
		if db != nil {
			return db, nil
		}

		logger.Debug(ctx, "opening database", logger.F("dsn", dsn))

		newDB, err := sql.Open("sqlite", dsn)
		if err != nil {
			return nil, errors.WithStack(err)
		}

		logger.Debug(ctx, "initializing database")

		if err := initFunc(ctx, newDB); err != nil {
			// This handle is discarded, so release its resources.
			if closeErr := newDB.Close(); closeErr != nil {
				logger.Error(ctx, "could not close database", logger.CapturedE(errors.WithStack(closeErr)))
			}

			return nil, errors.WithStack(err)
		}

		db = newDB

		return db, nil
	}
}
// NewGetDBFuncFromDB wraps an already opened database handle in a GetDBFunc.
//
// initFunc runs exactly once, during the first call, using that call's ctx.
// Its result is remembered:
//   - if it succeeded, every call returns db;
//   - if it failed, every call returns that error (wrapped with a stack) and
//     a nil handle.
func NewGetDBFuncFromDB(db *sql.DB, initFunc func(ctx context.Context, db *sql.DB) error) GetDBFunc {
	var (
		once    sync.Once
		initErr error
	)

	return func(ctx context.Context) (*sql.DB, error) {
		once.Do(func() {
			logger.Debug(ctx, "initializing database")
			initErr = initFunc(ctx, db)
		})

		if initErr == nil {
			return db, nil
		}

		return nil, errors.WithStack(initErr)
	}
}