Compare commits

...

7 Commits

25 changed files with 617 additions and 185 deletions

View File

@ -125,23 +125,34 @@ func copyDir(writer *zip.Writer, baseDir string, zipBasePath string) error {
} }
func copyFile(writer *zip.Writer, srcPath string, zipPath string) error { func copyFile(writer *zip.Writer, srcPath string, zipPath string) error {
r, err := os.Open(srcPath) srcFile, err := os.Open(srcPath)
if err != nil {
return errors.WithStack(err)
}
srcStat, err := os.Stat(srcPath)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
defer func() { defer func() {
if err := r.Close(); err != nil { if err := srcFile.Close(); err != nil {
panic(errors.WithStack(err)) panic(errors.WithStack(err))
} }
}() }()
f, err := writer.Create(zipPath) fileHeader := &zip.FileHeader{
Name: zipPath,
Modified: srcStat.ModTime().UTC(),
Method: zip.Deflate,
}
file, err := writer.CreateHeader(fileHeader)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
if _, err = io.Copy(f, r); err != nil { if _, err = io.Copy(file, srcFile); err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }

1
go.mod
View File

@ -6,6 +6,7 @@ require modernc.org/sqlite v1.20.4
require ( require (
github.com/gogo/protobuf v0.0.0-20161014173244-50d1bd39ce4e // indirect github.com/gogo/protobuf v0.0.0-20161014173244-50d1bd39ce4e // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/hashicorp/go.net v0.0.0-20151006203346-104dcad90073 // indirect github.com/hashicorp/go.net v0.0.0-20151006203346-104dcad90073 // indirect
github.com/hashicorp/mdns v0.0.0-20151206042412-9d85cf22f9f8 // indirect github.com/hashicorp/mdns v0.0.0-20151206042412-9d85cf22f9f8 // indirect
github.com/miekg/dns v0.0.0-20161006100029-fc4e1e2843d8 // indirect github.com/miekg/dns v0.0.0-20161006100029-fc4e1e2843d8 // indirect

2
go.sum
View File

@ -110,6 +110,8 @@ github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyL
github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
github.com/gogo/protobuf v0.0.0-20161014173244-50d1bd39ce4e h1:eeyMpoxANuWNQ9O2auv4wXxJsrXzLUhdHaOmNWEGkRY= github.com/gogo/protobuf v0.0.0-20161014173244-50d1bd39ce4e h1:eeyMpoxANuWNQ9O2auv4wXxJsrXzLUhdHaOmNWEGkRY=
github.com/gogo/protobuf v0.0.0-20161014173244-50d1bd39ce4e/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v0.0.0-20161014173244-50d1bd39ce4e/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=

View File

@ -26,7 +26,7 @@ func defaultHandlerOptions() *HandlerOptions {
Bus: memory.NewBus(), Bus: memory.NewBus(),
SockJS: sockjsOptions, SockJS: sockjsOptions,
ServerModuleFactories: make([]app.ServerModuleFactory, 0), ServerModuleFactories: make([]app.ServerModuleFactory, 0),
UploadMaxFileSize: 1024 * 10, // 10Mb UploadMaxFileSize: 10 << (10 * 2), // 10Mb
} }
} }

View File

@ -1,28 +0,0 @@
package module
import (
"context"
"fmt"
"github.com/dop251/goja"
)
func assertType[T any](v goja.Value, rt *goja.Runtime) T {
if c, ok := v.Export().(T); ok {
return c
}
panic(rt.NewTypeError(fmt.Sprintf("expected value to be a '%T', got '%T'", new(T), v.Export())))
}
func assertContext(v goja.Value, r *goja.Runtime) context.Context {
return assertType[context.Context](v, r)
}
func assertObject(v goja.Value, r *goja.Runtime) map[string]any {
return assertType[map[string]any](v, r)
}
func assertString(v goja.Value, r *goja.Runtime) string {
return assertType[string](v, r)
}

84
pkg/module/auth/module.go Normal file
View File

@ -0,0 +1,84 @@
package auth
import (
"net/http"
"strings"
"forge.cadoles.com/arcad/edge/pkg/app"
"forge.cadoles.com/arcad/edge/pkg/module"
"forge.cadoles.com/arcad/edge/pkg/module/util"
"github.com/dop251/goja"
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
)
const (
AnonymousSubject = "anonymous"
)
type Module struct {
server *app.Server
keyFunc jwt.Keyfunc
}
func (m *Module) Name() string {
return "auth"
}
func (m *Module) Export(export *goja.Object) {
if err := export.Set("getSubject", m.getSubject); err != nil {
panic(errors.Wrap(err, "could not set 'getSubject' function"))
}
if err := export.Set("ANONYMOUS", AnonymousSubject); err != nil {
panic(errors.Wrap(err, "could not set 'ANONYMOUS_USER' property"))
}
}
func (m *Module) getSubject(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
ctx := util.AssertContext(call.Argument(0), rt)
req, ok := ctx.Value(module.ContextKeyOriginRequest).(*http.Request)
if !ok {
panic(errors.New("could not find http request in context"))
}
rawToken := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
if rawToken == "" {
rawToken = req.URL.Query().Get("token")
}
if rawToken == "" {
return rt.ToValue(AnonymousSubject)
}
token, err := jwt.Parse(rawToken, m.keyFunc)
if err != nil {
panic(errors.WithStack(err))
}
if !token.Valid {
panic(errors.Errorf("invalid jwt token: '%v'", token.Raw))
}
mapClaims, ok := token.Claims.(jwt.MapClaims)
if !ok {
panic(errors.Errorf("unexpected claims type '%T'", token.Claims))
}
subject, exists := mapClaims["sub"]
if !exists {
return rt.ToValue(AnonymousSubject)
}
return rt.ToValue(subject)
}
func ModuleFactory(keyFunc jwt.Keyfunc) app.ServerModuleFactory {
return func(server *app.Server) app.ServerModule {
return &Module{
server: server,
keyFunc: keyFunc,
}
}
}

View File

@ -0,0 +1,121 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"testing"
"time"
"cdr.dev/slog"
"forge.cadoles.com/arcad/edge/pkg/app"
"forge.cadoles.com/arcad/edge/pkg/module"
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
func TestAuthModule(t *testing.T) {
t.Parallel()
logger.SetLevel(slog.LevelDebug)
keyFunc, secret := getKeyFunc()
server := app.NewServer(
module.ConsoleModuleFactory(),
ModuleFactory(keyFunc),
)
data, err := ioutil.ReadFile("testdata/auth.js")
if err != nil {
t.Fatal(err)
}
if err := server.Load("testdata/auth.js", string(data)); err != nil {
t.Fatal(err)
}
if err := server.Start(); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
defer server.Stop()
req, err := http.NewRequest("GET", "/foo", nil)
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": "jdoe",
"nbf": time.Now().UTC().Unix(),
})
rawToken, err := token.SignedString(secret)
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
req.Header.Add("Authorization", "Bearer "+rawToken)
ctx := context.WithValue(context.Background(), module.ContextKeyOriginRequest, req)
if _, err := server.ExecFuncByName("testAuth", ctx); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
}
func TestAuthAnonymousModule(t *testing.T) {
t.Parallel()
logger.SetLevel(slog.LevelDebug)
keyFunc, _ := getKeyFunc()
server := app.NewServer(
module.ConsoleModuleFactory(),
ModuleFactory(keyFunc),
)
data, err := ioutil.ReadFile("testdata/auth_anonymous.js")
if err != nil {
t.Fatal(err)
}
if err := server.Load("testdata/auth_anonymous.js", string(data)); err != nil {
t.Fatal(err)
}
if err := server.Start(); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
defer server.Stop()
req, err := http.NewRequest("GET", "/foo", nil)
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
ctx := context.WithValue(context.Background(), module.ContextKeyOriginRequest, req)
if _, err := server.ExecFuncByName("testAuth", ctx); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
}
func getKeyFunc() (jwt.Keyfunc, []byte) {
secret := []byte("not_so_secret")
keyFunc := func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
}
return keyFunc, secret
}

9
pkg/module/auth/testdata/auth.js vendored Normal file
View File

@ -0,0 +1,9 @@
function testAuth(ctx) {
var subject = auth.getSubject(ctx);
if (subject !== "jdoe") {
throw new Error("subject: expected 'jdoe', got '"+subject+"'");
}
}

View File

@ -0,0 +1,9 @@
function testAuth(ctx) {
var subject = auth.getSubject(ctx);
if (subject !== auth.ANONYMOUS) {
throw new Error("subject: expected '"+auth.ANONYMOUS+"', got '"+subject+"'");
}
}

View File

@ -12,26 +12,26 @@ import (
) )
type Device struct { type Device struct {
UUID string `goja:"uuid"` UUID string `goja:"uuid" json:"uuid"`
Host net.IP `goja:"host"` Host net.IP `goja:"host" json:"host"`
Port int `goja:"port"` Port int `goja:"port" json:"port"`
Name string `goja:"name"` Name string `goja:"name" json:"name"`
} }
type DeviceStatus struct { type DeviceStatus struct {
CurrentApp DeviceStatusCurrentApp `goja:"currentApp"` CurrentApp DeviceStatusCurrentApp `goja:"currentApp" json:"currentApp"`
Volume DeviceStatusVolume `goja:"volume"` Volume DeviceStatusVolume `goja:"volume" json:"volume"`
} }
type DeviceStatusCurrentApp struct { type DeviceStatusCurrentApp struct {
ID string `goja:"id"` ID string `goja:"id" json:"id"`
DisplayName string `goja:"displayName"` DisplayName string `goja:"displayName" json:"displayName"`
StatusText string `goja:"statusText"` StatusText string `goja:"statusText" json:"statusText"`
} }
type DeviceStatusVolume struct { type DeviceStatusVolume struct {
Level float64 `goja:"level"` Level float64 `goja:"level" json:"level"`
Muted bool `goja:"muted"` Muted bool `goja:"muted" json:"muted"`
} }
const ( const (

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/app"
"forge.cadoles.com/arcad/edge/pkg/module/util"
"github.com/dop251/goja" "github.com/dop251/goja"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -26,8 +27,8 @@ func (m *ContextModule) new(call goja.FunctionCall, rt *goja.Runtime) goja.Value
} }
func (m *ContextModule) with(call goja.FunctionCall, rt *goja.Runtime) goja.Value { func (m *ContextModule) with(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
ctx := assertContext(call.Argument(0), rt) ctx := util.AssertContext(call.Argument(0), rt)
rawValues := assertObject(call.Argument(1), rt) rawValues := util.AssertObject(call.Argument(1), rt)
values := make(map[ContextKey]any) values := make(map[ContextKey]any)
for k, v := range rawValues { for k, v := range rawValues {
@ -40,8 +41,8 @@ func (m *ContextModule) with(call goja.FunctionCall, rt *goja.Runtime) goja.Valu
} }
func (m *ContextModule) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value { func (m *ContextModule) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
ctx := assertContext(call.Argument(0), rt) ctx := util.AssertContext(call.Argument(0), rt)
rawKey := assertString(call.Argument(1), rt) rawKey := util.AssertString(call.Argument(1), rt)
value := ctx.Value(ContextKey(rawKey)) value := ctx.Value(ContextKey(rawKey))

View File

@ -5,6 +5,7 @@ import (
"forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/app"
"forge.cadoles.com/arcad/edge/pkg/bus" "forge.cadoles.com/arcad/edge/pkg/bus"
"forge.cadoles.com/arcad/edge/pkg/module/util"
"github.com/dop251/goja" "github.com/dop251/goja"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -58,7 +59,7 @@ func (m *NetModule) send(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
ContextKeySessionID: sessionID, ContextKeySessionID: sessionID,
}) })
} else { } else {
ctx = assertContext(firstArg, rt) ctx = util.AssertContext(firstArg, rt)
} }
data := call.Argument(1).Export() data := call.Argument(1).Export()

View File

@ -7,6 +7,7 @@ import (
"forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/app"
"forge.cadoles.com/arcad/edge/pkg/bus" "forge.cadoles.com/arcad/edge/pkg/bus"
"forge.cadoles.com/arcad/edge/pkg/module/util"
"github.com/dop251/goja" "github.com/dop251/goja"
"github.com/pkg/errors" "github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
@ -51,7 +52,7 @@ func (m *RPCModule) Export(export *goja.Object) {
} }
func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value { func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
fnName := assertString(call.Argument(0), rt) fnName := util.AssertString(call.Argument(0), rt)
var ( var (
callable goja.Callable callable goja.Callable
@ -78,7 +79,7 @@ func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Valu
} }
func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value { func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
fnName := assertString(call.Argument(0), rt) fnName := util.AssertString(call.Argument(0), rt)
m.callbacks.Delete(fnName) m.callbacks.Delete(fnName)

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/app"
"forge.cadoles.com/arcad/edge/pkg/module/util"
"forge.cadoles.com/arcad/edge/pkg/storage" "forge.cadoles.com/arcad/edge/pkg/storage"
"forge.cadoles.com/arcad/edge/pkg/storage/filter" "forge.cadoles.com/arcad/edge/pkg/storage/filter"
"github.com/dop251/goja" "github.com/dop251/goja"
@ -47,7 +48,7 @@ func (m *StoreModule) Export(export *goja.Object) {
} }
func (m *StoreModule) upsert(call goja.FunctionCall, rt *goja.Runtime) goja.Value { func (m *StoreModule) upsert(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
ctx := assertContext(call.Argument(0), rt) ctx := util.AssertContext(call.Argument(0), rt)
collection := m.assertCollection(call.Argument(1), rt) collection := m.assertCollection(call.Argument(1), rt)
document := m.assertDocument(call.Argument(2), rt) document := m.assertDocument(call.Argument(2), rt)
@ -60,7 +61,7 @@ func (m *StoreModule) upsert(call goja.FunctionCall, rt *goja.Runtime) goja.Valu
} }
func (m *StoreModule) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value { func (m *StoreModule) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
ctx := assertContext(call.Argument(0), rt) ctx := util.AssertContext(call.Argument(0), rt)
collection := m.assertCollection(call.Argument(1), rt) collection := m.assertCollection(call.Argument(1), rt)
documentID := m.assertDocumentID(call.Argument(2), rt) documentID := m.assertDocumentID(call.Argument(2), rt)
@ -84,29 +85,31 @@ type queryOptions struct {
} }
func (m *StoreModule) query(call goja.FunctionCall, rt *goja.Runtime) goja.Value { func (m *StoreModule) query(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
ctx := assertContext(call.Argument(0), rt) ctx := util.AssertContext(call.Argument(0), rt)
collection := m.assertCollection(call.Argument(1), rt) collection := m.assertCollection(call.Argument(1), rt)
filter := m.assertFilter(call.Argument(2), rt) filter := m.assertFilter(call.Argument(2), rt)
queryOptions := m.assertQueryOptions(call.Argument(3), rt) queryOptions := m.assertQueryOptions(call.Argument(3), rt)
queryOptionsFuncs := make([]storage.QueryOptionFunc, 0) queryOptionsFuncs := make([]storage.QueryOptionFunc, 0)
if queryOptions.Limit != nil { if queryOptions != nil {
queryOptionsFuncs = append(queryOptionsFuncs, storage.WithLimit(*queryOptions.Limit)) if queryOptions.Limit != nil {
} queryOptionsFuncs = append(queryOptionsFuncs, storage.WithLimit(*queryOptions.Limit))
}
if queryOptions.OrderBy != nil { if queryOptions.OrderBy != nil {
queryOptionsFuncs = append(queryOptionsFuncs, storage.WithOrderBy(*queryOptions.OrderBy)) queryOptionsFuncs = append(queryOptionsFuncs, storage.WithOrderBy(*queryOptions.OrderBy))
} }
if queryOptions.Offset != nil { if queryOptions.Offset != nil {
queryOptionsFuncs = append(queryOptionsFuncs, storage.WithOffset(*queryOptions.Limit)) queryOptionsFuncs = append(queryOptionsFuncs, storage.WithOffset(*queryOptions.Limit))
} }
if queryOptions.OrderDirection != nil { if queryOptions.OrderDirection != nil {
queryOptionsFuncs = append(queryOptionsFuncs, storage.WithOrderDirection( queryOptionsFuncs = append(queryOptionsFuncs, storage.WithOrderDirection(
storage.OrderDirection(*queryOptions.OrderDirection), storage.OrderDirection(*queryOptions.OrderDirection),
)) ))
}
} }
documents, err := m.store.Query(ctx, collection, filter, queryOptionsFuncs...) documents, err := m.store.Query(ctx, collection, filter, queryOptionsFuncs...)
@ -123,7 +126,7 @@ func (m *StoreModule) query(call goja.FunctionCall, rt *goja.Runtime) goja.Value
} }
func (m *StoreModule) delete(call goja.FunctionCall, rt *goja.Runtime) goja.Value { func (m *StoreModule) delete(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
ctx := assertContext(call.Argument(0), rt) ctx := util.AssertContext(call.Argument(0), rt)
collection := m.assertCollection(call.Argument(1), rt) collection := m.assertCollection(call.Argument(1), rt)
documentID := m.assertDocumentID(call.Argument(2), rt) documentID := m.assertDocumentID(call.Argument(2), rt)
@ -144,6 +147,10 @@ func (m *StoreModule) assertCollection(value goja.Value, rt *goja.Runtime) strin
} }
func (m *StoreModule) assertFilter(value goja.Value, rt *goja.Runtime) *filter.Filter { func (m *StoreModule) assertFilter(value goja.Value, rt *goja.Runtime) *filter.Filter {
if value.Export() == nil {
return nil
}
rawFilter, ok := value.Export().(map[string]interface{}) rawFilter, ok := value.Export().(map[string]interface{})
if !ok { if !ok {
panic(rt.NewTypeError(fmt.Sprintf("filter must be an object, got '%T'", value.Export()))) panic(rt.NewTypeError(fmt.Sprintf("filter must be an object, got '%T'", value.Export())))
@ -172,6 +179,10 @@ func (m *StoreModule) assertDocumentID(value goja.Value, rt *goja.Runtime) stora
} }
func (m *StoreModule) assertQueryOptions(value goja.Value, rt *goja.Runtime) *queryOptions { func (m *StoreModule) assertQueryOptions(value goja.Value, rt *goja.Runtime) *queryOptions {
if value.Export() == nil {
return nil
}
rawQueryOptions, ok := value.Export().(map[string]interface{}) rawQueryOptions, ok := value.Export().(map[string]interface{})
if !ok { if !ok {
panic(rt.NewTypeError(fmt.Sprintf("query options must be an object, got '%T'", value.Export()))) panic(rt.NewTypeError(fmt.Sprintf("query options must be an object, got '%T'", value.Export())))

28
pkg/module/util/assert.go Normal file
View File

@ -0,0 +1,28 @@
package util
import (
"context"
"fmt"
"github.com/dop251/goja"
)
func AssertType[T any](v goja.Value, rt *goja.Runtime) T {
if c, ok := v.Export().(T); ok {
return c
}
panic(rt.NewTypeError(fmt.Sprintf("expected value to be a '%T', got '%T'", new(T), v.Export())))
}
func AssertContext(v goja.Value, r *goja.Runtime) context.Context {
return AssertType[context.Context](v, r)
}
func AssertObject(v goja.Value, r *goja.Runtime) map[string]any {
return AssertType[map[string]any](v, r)
}
func AssertString(v goja.Value, r *goja.Runtime) string {
return AssertType[string](v, r)
}

View File

@ -31,12 +31,17 @@ func (d Document) ID() (DocumentID, bool) {
return "", false return "", false
} }
id, ok := rawID.(string) strID, ok := rawID.(string)
if ok { if ok {
return "", false return DocumentID(strID), true
} }
return DocumentID(id), true docID, ok := rawID.(DocumentID)
if ok {
return docID, true
}
return "", false
} }
func (d Document) CreatedAt() (time.Time, bool) { func (d Document) CreatedAt() (time.Time, bool) {
@ -54,7 +59,7 @@ func (d Document) timeAttr(attr string) (time.Time, bool) {
} }
t, ok := rawTime.(time.Time) t, ok := rawTime.(time.Time)
if ok { if !ok {
return time.Time{}, false return time.Time{}, false
} }

View File

@ -0,0 +1,15 @@
package sql
const (
OpIn = "IN"
OpLesserThan = "<"
OpLesserThanEqual = "<="
OpEqual = "="
OpNotEqual = "!="
OpSuperiorThan = ">"
OpSuperiorThanEqual = ">="
OpAnd = "AND"
OpOr = "OR"
OpLike = "LIKE"
OpNot = "NOT"
)

View File

@ -71,6 +71,12 @@ func WithDefaultTransform() OptionFunc {
} }
} }
func WithTransform(transform TransformFunc) OptionFunc {
return func(opt *Option) {
opt.Transform = transform
}
}
func WithNoOpValueTransform() OptionFunc { func WithNoOpValueTransform() OptionFunc {
return WithValueTransform(func(value interface{}) interface{} { return WithValueTransform(func(value interface{}) interface{} {
return value return value

View File

@ -60,7 +60,7 @@ func transformAndOperator(op filter.Operator, option *Option) (string, []interfa
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenAnd, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenAnd, op.Token())
} }
return aggregatorToSQL("AND", option, andOp.Children()...) return aggregatorToSQL(OpAnd, option, andOp.Children()...)
} }
func transformOrOperator(op filter.Operator, option *Option) (string, []interface{}, error) { func transformOrOperator(op filter.Operator, option *Option) (string, []interface{}, error) {
@ -69,7 +69,7 @@ func transformOrOperator(op filter.Operator, option *Option) (string, []interfac
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenOr, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenOr, op.Token())
} }
return aggregatorToSQL("OR", option, orOp.Children()...) return aggregatorToSQL(OpOr, option, orOp.Children()...)
} }
func transformEqOperator(op filter.Operator, option *Option) (string, []interface{}, error) { func transformEqOperator(op filter.Operator, option *Option) (string, []interface{}, error) {
@ -78,7 +78,7 @@ func transformEqOperator(op filter.Operator, option *Option) (string, []interfac
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenEq, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenEq, op.Token())
} }
return fieldsToSQL("=", false, eqOp.Fields(), option) return fieldsToSQL(OpEqual, false, eqOp.Fields(), option)
} }
func transformNeqOperator(op filter.Operator, option *Option) (string, []interface{}, error) { func transformNeqOperator(op filter.Operator, option *Option) (string, []interface{}, error) {
@ -87,7 +87,7 @@ func transformNeqOperator(op filter.Operator, option *Option) (string, []interfa
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenNeq, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenNeq, op.Token())
} }
return fieldsToSQL("!=", false, eqOp.Fields(), option) return fieldsToSQL(OpNotEqual, false, eqOp.Fields(), option)
} }
func transformGtOperator(op filter.Operator, option *Option) (string, []interface{}, error) { func transformGtOperator(op filter.Operator, option *Option) (string, []interface{}, error) {
@ -96,7 +96,7 @@ func transformGtOperator(op filter.Operator, option *Option) (string, []interfac
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenGt, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenGt, op.Token())
} }
return fieldsToSQL(">", false, gtOp.Fields(), option) return fieldsToSQL(OpSuperiorThan, false, gtOp.Fields(), option)
} }
func transformGteOperator(op filter.Operator, option *Option) (string, []interface{}, error) { func transformGteOperator(op filter.Operator, option *Option) (string, []interface{}, error) {
@ -105,7 +105,7 @@ func transformGteOperator(op filter.Operator, option *Option) (string, []interfa
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenGte, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenGte, op.Token())
} }
return fieldsToSQL(">=", false, gteOp.Fields(), option) return fieldsToSQL(OpSuperiorThanEqual, false, gteOp.Fields(), option)
} }
func transformLtOperator(op filter.Operator, option *Option) (string, []interface{}, error) { func transformLtOperator(op filter.Operator, option *Option) (string, []interface{}, error) {
@ -114,7 +114,7 @@ func transformLtOperator(op filter.Operator, option *Option) (string, []interfac
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenLt, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenLt, op.Token())
} }
return fieldsToSQL("<", false, ltOp.Fields(), option) return fieldsToSQL(OpLesserThan, false, ltOp.Fields(), option)
} }
func transformLteOperator(op filter.Operator, option *Option) (string, []interface{}, error) { func transformLteOperator(op filter.Operator, option *Option) (string, []interface{}, error) {
@ -123,7 +123,7 @@ func transformLteOperator(op filter.Operator, option *Option) (string, []interfa
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenLte, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenLte, op.Token())
} }
return fieldsToSQL("<=", false, lteOp.Fields(), option) return fieldsToSQL(OpLesserThanEqual, false, lteOp.Fields(), option)
} }
func transformInOperator(op filter.Operator, option *Option) (string, []interface{}, error) { func transformInOperator(op filter.Operator, option *Option) (string, []interface{}, error) {
@ -132,7 +132,7 @@ func transformInOperator(op filter.Operator, option *Option) (string, []interfac
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenIn, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenIn, op.Token())
} }
return fieldsToSQL("IN", true, inOp.Fields(), option) return fieldsToSQL(OpIn, true, inOp.Fields(), option)
} }
func transformLikeOperator(op filter.Operator, option *Option) (string, []interface{}, error) { func transformLikeOperator(op filter.Operator, option *Option) (string, []interface{}, error) {
@ -141,7 +141,7 @@ func transformLikeOperator(op filter.Operator, option *Option) (string, []interf
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenLike, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenLike, op.Token())
} }
return fieldsToSQL("LIKE", false, likeOp.Fields(), option) return fieldsToSQL(OpLike, false, likeOp.Fields(), option)
} }
func transformNotOperator(op filter.Operator, option *Option) (string, []interface{}, error) { func transformNotOperator(op filter.Operator, option *Option) (string, []interface{}, error) {
@ -150,10 +150,10 @@ func transformNotOperator(op filter.Operator, option *Option) (string, []interfa
return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenNot, op.Token()) return "", nil, errors.Wrapf(filter.ErrUnexpectedOperator, "expected '%s', got '%s'", filter.TokenNot, op.Token())
} }
sql, args, err := aggregatorToSQL("AND", option, notOp.Children()...) sql, args, err := aggregatorToSQL(OpAnd, option, notOp.Children()...)
if err != nil { if err != nil {
return "", nil, errors.WithStack(err) return "", nil, errors.WithStack(err)
} }
return "NOT " + sql, args, nil return OpNot + " " + sql, args, nil
} }

View File

@ -32,7 +32,7 @@ func DefaultTransform(operator string, invert bool, key string, value interface{
return "", nil, errors.WithStack(err) return "", nil, errors.WithStack(err)
} }
if _, err := sb.WriteString(key); err != nil { if _, err := sb.WriteString(option.KeyTransform(key)); err != nil {
return "", nil, errors.WithStack(err) return "", nil, errors.WithStack(err)
} }
} else { } else {

View File

@ -93,15 +93,23 @@ func (s *DocumentStore) Query(ctx context.Context, collection string, filter *fi
var documents []storage.Document var documents []storage.Document
err := s.withTx(ctx, func(tx *sql.Tx) error { err := s.withTx(ctx, func(tx *sql.Tx) error {
criteria, args, err := filterSQL.ToSQL( criteria := "1 = 1"
filter.Root(), args := make([]any, 0)
filterSQL.WithPreparedParameter("$", 2),
filterSQL.WithKeyTransform(func(key string) string { var err error
return fmt.Sprintf("json_extract(data, '$.%s')", key)
}), if filter != nil {
) criteria, args, err = filterSQL.ToSQL(
if err != nil { filter.Root(),
return errors.WithStack(err) filterSQL.WithPreparedParameter("$", 2),
filterSQL.WithTransform(transformOperator),
filterSQL.WithKeyTransform(func(key string) string {
return fmt.Sprintf("json_extract(data, '$.%s')", key)
}),
)
if err != nil {
return errors.WithStack(err)
}
} }
query := ` query := `
@ -180,10 +188,6 @@ func (s *DocumentStore) Upsert(ctx context.Context, collection string, document
id = storage.NewDocumentID() id = storage.NewDocumentID()
} }
delete(document, storage.DocumentAttrID)
delete(document, storage.DocumentAttrCreatedAt)
delete(document, storage.DocumentAttrUpdatedAt)
args := []any{id, collection, JSONMap(document), now, now} args := []any{id, collection, JSONMap(document), now, now}
row := tx.QueryRowContext(ctx, query, args...) row := tx.QueryRowContext(ctx, query, args...)

View File

@ -0,0 +1,24 @@
package sqlite
import (
"fmt"
"forge.cadoles.com/arcad/edge/pkg/storage/filter/sql"
)
func transformOperator(operator string, invert bool, key string, value any, option *sql.Option) (string, any, error) {
switch operator {
case sql.OpIn:
return transformInOperator(key, value, option)
default:
return sql.DefaultTransform(operator, invert, key, value, option)
}
}
func transformInOperator(key string, value any, option *sql.Option) (string, any, error) {
return fmt.Sprintf(
"EXISTS (SELECT 1 FROM json_each(json_extract(data, \"$.%v\")) WHERE value = %v)",
key,
option.PreparedParameter(),
), option.ValueTransform(value), nil
}

View File

@ -7,8 +7,8 @@ import (
) )
func TestDocumentStore(t *testing.T, store storage.DocumentStore) { func TestDocumentStore(t *testing.T, store storage.DocumentStore) {
t.Run("Query", func(t *testing.T) { t.Run("Ops", func(t *testing.T) {
// t.Parallel() // t.Parallel()
testDocumentStoreQuery(t, store) testDocumentStoreOps(t, store)
}) })
} }

View File

@ -0,0 +1,212 @@
package testsuite
import (
"context"
"testing"
"forge.cadoles.com/arcad/edge/pkg/storage"
"forge.cadoles.com/arcad/edge/pkg/storage/filter"
"github.com/davecgh/go-spew/spew"
"github.com/pkg/errors"
)
type documentStoreOpsTestCase struct {
Name string
Run func(ctx context.Context, store storage.DocumentStore) error
}
var documentStoreOpsTestCases = []documentStoreOpsTestCase{
{
Name: "Basic query",
Run: func(ctx context.Context, store storage.DocumentStore) error {
collection := "simple_select"
docs := []storage.Document{
{
"attr1": "Foo",
},
{
"attr1": "Bar",
},
}
for _, d := range docs {
if _, err := store.Upsert(ctx, collection, d); err != nil {
return errors.WithStack(err)
}
}
filter := filter.New(
filter.NewEqOperator(map[string]interface{}{
"attr1": "Foo",
}),
)
results, err := store.Query(ctx, collection, filter, nil)
if err != nil {
return errors.WithStack(err)
}
if e, g := 1, len(results); e != g {
return errors.Errorf("len(results): expected '%v', got '%v'", e, g)
}
if e, g := "Foo", results[0]["attr1"]; e != g {
return errors.Errorf("results[0][\"Attr1\"]: expected '%v', got '%v'", e, g)
}
return nil
},
},
{
Name: "Query with 'IN' operator",
Run: func(ctx context.Context, store storage.DocumentStore) error {
docs := []storage.Document{
{
"counter": 1,
"tags": []string{"foo", "bar"},
},
{
"counter": 1,
"tags": []string{"nope"},
},
}
collection := "in_operator"
for _, doc := range docs {
if _, err := store.Upsert(ctx, collection, doc); err != nil {
return errors.WithStack(err)
}
}
filter := filter.New(
filter.NewAndOperator(
filter.NewEqOperator(map[string]any{
"counter": 1,
}),
filter.NewInOperator(map[string]any{
"tags": "foo",
}),
),
)
results, err := store.Query(ctx, collection, filter, nil)
if err != nil {
return errors.WithStack(err)
}
if e, g := 1, len(results); e != g {
return errors.Errorf("len(results): expected '%v', got '%v'", e, g)
}
return nil
},
},
{
Name: "Double upsert",
Run: func(ctx context.Context, store storage.DocumentStore) error {
collection := "double_upsert"
oriDoc := storage.Document{
"attr1": "Foo",
}
// Upsert document for the first time
upsertedDoc, err := store.Upsert(ctx, collection, oriDoc)
if err != nil {
return errors.WithStack(err)
}
id, exists := upsertedDoc.ID()
if !exists {
return errors.New("id, exists := upsertedDoc.ID(): 'exists' should be true")
}
if id == storage.DocumentID("") {
return errors.New("id, exists := upsertedDoc.ID(): 'id' should not be an empty string")
}
createdAt, exists := upsertedDoc.CreatedAt()
if !exists {
return errors.New("createdAt, exists := upsertedDoc.CreatedAt(): 'exists' should be true")
}
if createdAt.IsZero() {
return errors.New("createdAt, exists := upsertedDoc.CreatedAt(): 'createdAt' should not be zero time")
}
updatedAt, exists := upsertedDoc.UpdatedAt()
if !exists {
return errors.New("updatedAt, exists := upsertedDoc.UpdatedAt(): 'exists' should be true")
}
if updatedAt.IsZero() {
return errors.New("updatedAt, exists := upsertedDoc.UpdatedAt(): 'updatedAt' should not be zero time")
}
if e, g := oriDoc["attr1"], upsertedDoc["attr1"]; e != g {
return errors.Errorf("upsertedDoc[\"attr1\"]: expected '%v', got '%v'", e, g)
}
// Check that document does not have unexpected properties
if e, g := 4, len(upsertedDoc); e != g {
return errors.Errorf("len(upsertedDoc): expected '%v', got '%v'", e, g)
}
// Upsert document for the second time
upsertedDoc2, err := store.Upsert(ctx, collection, upsertedDoc)
if err != nil {
return errors.WithStack(err)
}
spew.Dump(upsertedDoc, upsertedDoc2)
prevID, _ := upsertedDoc.ID()
newID, _ := upsertedDoc2.ID()
if e, g := prevID, newID; e != g {
return errors.Errorf("newID: expected '%v', got '%v'", e, g)
}
createdAt1, _ := upsertedDoc.CreatedAt()
createdAt2, _ := upsertedDoc2.CreatedAt()
if e, g := createdAt1, createdAt2; e != g {
return errors.Errorf("upsertedDoc2.CreatedAt(): expected '%v', got '%v'", e, g)
}
updatedAt1, _ := upsertedDoc.UpdatedAt()
updatedAt2, _ := upsertedDoc2.UpdatedAt()
if e, g := updatedAt1, updatedAt2; e == g {
return errors.New("upsertedDoc2.UpdatedAt() should have been different than upsertedDoc.UpdatedAt()")
}
// Verify that there is no additional created document in the collection
results, err := store.Query(ctx, collection, nil, nil)
if err != nil {
return errors.WithStack(err)
}
if e, g := 1, len(results); e != g {
return errors.Errorf("len(results): expected '%v', got '%v'", e, g)
}
return nil
},
},
}
func testDocumentStoreOps(t *testing.T, store storage.DocumentStore) {
for _, tc := range documentStoreOpsTestCases {
func(tc documentStoreOpsTestCase) {
t.Run(tc.Name, func(t *testing.T) {
if err := tc.Run(context.Background(), store); err != nil {
t.Errorf("%+v", errors.WithStack(err))
}
})
}(tc)
}
}

View File

@ -1,85 +0,0 @@
package testsuite
import (
"context"
"testing"
"forge.cadoles.com/arcad/edge/pkg/storage"
"forge.cadoles.com/arcad/edge/pkg/storage/filter"
"github.com/pkg/errors"
)
type documentStoreQueryTestCase struct {
Name string
Before func(ctx context.Context, store storage.DocumentStore) error
Collection string
Filter *filter.Filter
QueryOptionsFuncs []storage.QueryOptionFunc
After func(t *testing.T, results []storage.Document, err error)
}
var documentStoreQueryTestCases = []documentStoreQueryTestCase{
{
Name: "Simple select",
Before: func(ctx context.Context, store storage.DocumentStore) error {
doc1 := storage.Document{
"attr1": "Foo",
}
if _, err := store.Upsert(ctx, "simple_select", doc1); err != nil {
return errors.WithStack(err)
}
doc2 := storage.Document{
"attr1": "Bar",
}
if _, err := store.Upsert(ctx, "simple_select", doc2); err != nil {
return errors.WithStack(err)
}
return nil
},
Collection: "simple_select",
Filter: filter.New(
filter.NewEqOperator(map[string]interface{}{
"attr1": "Foo",
}),
),
After: func(t *testing.T, results []storage.Document, err error) {
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
if e, g := 1, len(results); e != g {
t.Errorf("len(results): expected '%v', got '%v'", e, g)
}
if e, g := "Foo", results[0]["attr1"]; e != g {
t.Errorf("results[0][\"Attr1\"]: expected '%v', got '%v'", e, g)
}
},
},
}
func testDocumentStoreQuery(t *testing.T, store storage.DocumentStore) {
for _, tc := range documentStoreQueryTestCases {
func(tc documentStoreQueryTestCase) {
t.Run(tc.Name, func(t *testing.T) {
// t.Parallel()
ctx := context.Background()
if tc.Before != nil {
if err := tc.Before(ctx, store); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
}
documents, err := store.Query(ctx, tc.Collection, tc.Filter, tc.QueryOptionsFuncs...)
tc.After(t, documents, err)
})
}(tc)
}
}