emissary/internal/agent/controller/app/controller.go

303 lines
6.8 KiB
Go

package app
import (
"context"
"io"
"net/http"
"os"
"path/filepath"
"forge.cadoles.com/Cadoles/emissary/internal/agent"
"forge.cadoles.com/Cadoles/emissary/internal/spec/app"
"forge.cadoles.com/arcad/edge/pkg/bundle"
"forge.cadoles.com/arcad/edge/pkg/storage/sqlite"
"github.com/mitchellh/hashstructure/v2"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// serverEntry is the controller's bookkeeping record for one app server.
type serverEntry struct {
	// SpecHash is the hash of the app spec the server was last successfully
	// started with. It is 0 for a freshly created entry that has not been
	// started yet.
	SpecHash uint64
	// Server is the app server instance serving the app bundle.
	Server *Server
}
// Controller reconciles the apps declared in the agent's app spec with the
// app servers it runs locally.
type Controller struct {
	// client is the HTTP client used to download app bundles.
	client *http.Client
	// downloadDir is the directory where app bundles are stored.
	downloadDir string
	// dataDir is the parent directory of every app's data directory.
	dataDir string
	// servers maps app IDs to the server entries managed by the controller.
	// NOTE(review): this map is not guarded by a mutex; it assumes Reconcile is
	// never invoked concurrently — confirm.
	servers map[string]*serverEntry
}
// Name implements agent.Controller and returns the identifier of this
// controller.
func (c *Controller) Name() string {
	const controllerName = "app-controller"

	return controllerName
}
// Reconcile implements agent.Controller.
//
// It reads the app spec from the agent state. When the state holds no app
// spec, every app currently managed by the controller is stopped. Otherwise,
// the running apps are brought in line with the spec. Errors other than a
// missing spec are returned to the caller.
func (c *Controller) Reconcile(ctx context.Context, state *agent.State) error {
	spec := app.NewSpec()

	if err := state.GetSpec(app.NameApp, spec); err != nil {
		if !errors.Is(err, agent.ErrSpecNotFound) {
			return errors.WithStack(err)
		}

		logger.Info(ctx, "could not find app spec")
		c.stopAllApps(ctx, spec)

		return nil
	}

	logger.Info(
		ctx, "retrieved spec",
		logger.F("spec", spec.SpecName()),
		logger.F("revision", spec.SpecRevision()),
	)

	c.updateApps(ctx, spec)

	return nil
}
// stopAllApps stops every app server currently managed by the controller and
// removes it from the controller's bookkeeping.
//
// A failure to stop one app is logged and does not prevent the other apps
// from being stopped. The entry is dropped either way, so the next
// reconciliation starts from a clean state. The spec parameter is currently
// unused.
func (c *Controller) stopAllApps(ctx context.Context, spec *app.Spec) {
	if len(c.servers) > 0 {
		logger.Info(ctx, "stopping all apps")
	}

	for appID, entry := range c.servers {
		logger.Info(ctx, "stopping app", logger.F("appID", appID))

		if err := entry.Server.Stop(); err != nil {
			logger.Error(
				ctx, "error while stopping app",
				logger.F("appID", appID),
				logger.E(errors.WithStack(err)),
			)
		}

		// Deleting the current key while ranging over a map is safe in Go.
		delete(c.servers, appID)
	}
}
// updateApps brings the set of managed app servers in line with spec.
//
// Apps that are no longer declared in spec are stopped and forgotten. A stop
// failure is logged and the entry is dropped anyway. Every app declared in
// spec is then created, started or restarted as needed by updateApp. A
// failure on one app is logged and does not prevent the others from being
// updated.
func (c *Controller) updateApps(ctx context.Context, spec *app.Spec) {
	// Stop and remove obsolete apps.
	for appID, entry := range c.servers {
		if _, exists := spec.Apps[appID]; exists {
			continue
		}

		logger.Info(ctx, "stopping app", logger.F("appID", appID))

		if err := entry.Server.Stop(); err != nil {
			logger.Error(
				ctx, "error while stopping app",
				logger.F("appID", appID),
				logger.E(errors.WithStack(err)),
			)
		}

		// Deleting the current key while ranging over a map is safe in Go.
		delete(c.servers, appID)
	}

	// (Re)start apps declared in the spec.
	for appID, appSpec := range spec.Apps {
		// Pass the app-scoped context so that every log line emitted while
		// updating this app carries its ID.
		appCtx := logger.With(ctx, logger.F("appID", appID))

		if err := c.updateApp(appCtx, appID, appSpec, spec.Auth); err != nil {
			logger.Error(appCtx, "could not update app", logger.E(errors.WithStack(err)))
		}
	}
}
func (c *Controller) updateApp(ctx context.Context, appID string, appSpec app.AppEntry, auth *app.Auth) (err error) {
newAppSpecHash, err := hashstructure.Hash(appSpec, hashstructure.FormatV2, nil)
if err != nil {
return errors.WithStack(err)
}
bundle, sha256sum, err := c.ensureAppBundle(ctx, appID, appSpec)
if err != nil {
return errors.Wrap(err, "could not download app bundle")
}
dataDir, err := c.ensureAppDataDir(ctx, appID)
if err != nil {
return errors.Wrap(err, "could not retrieve app data dir")
}
var entry *serverEntry
entry, exists := c.servers[appID]
if !exists {
logger.Info(ctx, "app currently not running")
} else if sha256sum != appSpec.SHA256Sum {
logger.Info(
ctx, "bundle hash mismatch, stopping app",
logger.F("currentHash", sha256sum),
logger.F("specHash", appSpec.SHA256Sum),
)
if err := entry.Server.Stop(); err != nil {
return errors.Wrap(err, "could not stop app")
}
entry = nil
}
if entry == nil {
dbFile := filepath.Join(dataDir, appID+".sqlite")
db, err := sqlite.Open(dbFile)
if err != nil {
return errors.Wrapf(err, "could not opend database file '%s'", dbFile)
}
entry = &serverEntry{
Server: NewServer(bundle, db, auth),
SpecHash: 0,
}
c.servers[appID] = entry
}
specChanged := newAppSpecHash != entry.SpecHash
if entry.Server.Running() && !specChanged {
return nil
}
if specChanged && entry.SpecHash != 0 {
logger.Info(
ctx, "restarting app",
logger.F("address", appSpec.Address),
)
} else {
logger.Info(
ctx, "starting app",
logger.F("address", appSpec.Address),
)
}
if err := entry.Server.Start(ctx, appSpec.Address); err != nil {
delete(c.servers, appID)
return errors.Wrap(err, "could not start app")
}
entry.SpecHash = newAppSpecHash
return nil
}
// ensureAppBundle makes sure the bundle of the given app is present in the
// controller's download directory, stored as "<appID>.<spec.Format>", and
// loads it.
//
// A missing bundle is downloaded. A bundle whose checksum does not match
// spec.SHA256Sum is downloaded again.
//
// The returned checksum is the on-disk checksum when it matched
// spec.SHA256Sum. After a mismatch-triggered re-download, an empty checksum is
// returned instead. updateApp compares this value with the spec's checksum to
// detect that an already running app uses a stale bundle.
func (c *Controller) ensureAppBundle(ctx context.Context, appID string, spec app.AppEntry) (bundle.Bundle, string, error) {
	if err := os.MkdirAll(c.downloadDir, os.ModePerm); err != nil {
		return nil, "", errors.WithStack(err)
	}

	bundlePath := filepath.Join(c.downloadDir, appID+"."+spec.Format)

	_, statErr := os.Stat(bundlePath)

	switch {
	case statErr == nil:
		// The bundle is already on disk; its checksum is verified below.
	case errors.Is(statErr, os.ErrNotExist):
		if err := c.downloadFile(spec.URL, spec.SHA256Sum, bundlePath); err != nil {
			return nil, "", errors.WithStack(err)
		}
	default:
		return nil, "", errors.WithStack(statErr)
	}

	localSum, err := hash(bundlePath)
	if err != nil {
		return nil, "", errors.WithStack(err)
	}

	reportedSum := localSum

	if localSum != spec.SHA256Sum {
		logger.Info(ctx, "bundle hash mismatch, downloading app", logger.F("url", spec.URL))

		if err := c.downloadFile(spec.URL, spec.SHA256Sum, bundlePath); err != nil {
			return nil, "", errors.WithStack(err)
		}

		// Signal to the caller that the bundle on disk has been replaced.
		reportedSum = ""
	}

	loaded, err := bundle.FromPath(bundlePath)
	if err != nil {
		return nil, "", errors.WithStack(err)
	}

	return loaded, reportedSum, nil
}
// downloadFile downloads url into dest and checks that the checksum of the
// downloaded content, as computed by hash, equals sha256sum.
//
// The content is first written to a temporary file created in dest's
// directory. That file is only renamed to dest once its checksum matches, so
// dest never holds a partial or unverified download. The temporary file is
// removed on every path.
//
// A non-2xx HTTP response is reported as an error before anything is written.
// Failures to close the response body or remove the temporary file are
// returned as errors, but only when no earlier error occurred.
func (c *Controller) downloadFile(url string, sha256sum string, dest string) (err error) {
	res, err := c.client.Get(url)
	if err != nil {
		return errors.WithStack(err)
	}

	defer func() {
		closeErr := res.Body.Close()
		if closeErr != nil && !errors.Is(closeErr, os.ErrClosed) && err == nil {
			err = errors.WithStack(closeErr)
		}
	}()

	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
		return errors.Errorf("could not download '%s': unexpected status '%s'", url, res.Status)
	}

	tmp, err := os.CreateTemp(filepath.Dir(dest), "download_")
	if err != nil {
		return errors.WithStack(err)
	}

	tmpPath := tmp.Name()

	defer func() {
		// After a successful rename the temporary file no longer exists.
		removeErr := os.Remove(tmpPath)
		if removeErr != nil && !os.IsNotExist(removeErr) && err == nil {
			err = errors.WithStack(removeErr)
		}
	}()

	_, copyErr := io.Copy(tmp, res.Body)

	// Close the temporary file before hashing and renaming it, so its
	// descriptor is released whatever the outcome of the copy.
	closeErr := tmp.Close()

	if copyErr != nil {
		return errors.WithStack(copyErr)
	}

	if closeErr != nil {
		return errors.WithStack(closeErr)
	}

	tmpFileHash, err := hash(tmpPath)
	if err != nil {
		return errors.WithStack(err)
	}

	if tmpFileHash != sha256sum {
		return errors.Errorf("sha256 sum mismatch: expected '%s', got '%s'", sha256sum, tmpFileHash)
	}

	if err := os.Rename(tmpPath, dest); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// ensureAppDataDir returns the data directory of the given app. It is a
// subdirectory named after appID inside the controller's data directory, and
// is created, along with any missing parents, when needed. ctx is currently
// unused.
func (c *Controller) ensureAppDataDir(ctx context.Context, appID string) (string, error) {
	appDataDir := filepath.Join(c.dataDir, appID)

	err := os.MkdirAll(appDataDir, os.ModePerm)
	if err != nil {
		return "", errors.WithStack(err)
	}

	return appDataDir, nil
}
// NewController creates an app controller. It starts from the default options,
// applies the given option functions in order, and starts with no managed app
// servers.
func NewController(funcs ...OptionFunc) *Controller {
	cfg := defaultOptions()
	for idx := range funcs {
		funcs[idx](cfg)
	}

	ctrl := &Controller{
		servers: make(map[string]*serverEntry),
	}
	ctrl.client = cfg.Client
	ctrl.downloadDir = cfg.DownloadDir
	ctrl.dataDir = cfg.DataDir

	return ctrl
}
// Compile-time assertion that *Controller satisfies agent.Controller.
var _ agent.Controller = (*Controller)(nil)