package openwrt import ( "context" "io" "net/http" "os" "os/exec" "path/filepath" "strings" "forge.cadoles.com/Cadoles/emissary/internal/agent" "forge.cadoles.com/Cadoles/emissary/internal/agent/controller/openwrt/spec/sysupgrade" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) type SysUpgradeController struct { client *http.Client command string args []string firmwareVersion FirmwareVersion } // Name implements agent.Controller func (*SysUpgradeController) Name() string { return "sysupgrade-controller" } // Reconcile implements agent.Controller func (c *SysUpgradeController) Reconcile(ctx context.Context, state *agent.State) error { sysSpec := sysupgrade.NewSpec() if err := state.GetSpec(sysupgrade.Name, sysSpec); err != nil { if errors.Is(err, agent.ErrSpecNotFound) { logger.Info(ctx, "could not find sysupgrade spec, doing nothing") return nil } return errors.WithStack(err) } firmwareVersion, err := c.firmwareVersion.FirmwareVersion(ctx) if err != nil { return errors.WithStack(err) } ctx = logger.With(ctx, logger.F("currentFirmwareVersion", firmwareVersion), logger.F("newFirmwareVersion", sysSpec.Version), ) if firmwareVersion == sysSpec.Version { logger.Info(ctx, "firmware version did not change, doing nothing") return nil } downloadDir, err := os.MkdirTemp(os.TempDir(), "emissary_sysupgrade_*") if err != nil { return errors.WithStack(err) } defer func() { if err := os.RemoveAll(downloadDir); err != nil { logger.Error( ctx, "could not remove download direction", logger.E(errors.WithStack(err)), logger.F("downloadDir", downloadDir), ) } }() firmwareFile := filepath.Join(downloadDir, "firmware.bin") logger.Info( ctx, "downloading firmware", logger.F("url", sysSpec.URL), logger.F("sha256sum", sysSpec.SHA256Sum), ) if err := c.downloadFile(ctx, sysSpec.URL, sysSpec.SHA256Sum, firmwareFile); err != nil { return errors.WithStack(err) } logger.Info(ctx, "upgrading firmware") if err := c.upgradeFirmware(ctx, firmwareFile); err != nil { return errors.WithStack(err) } return nil } func (c *SysUpgradeController) downloadFile(ctx context.Context, url string, sha256sum string, dest string) error { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return errors.WithStack(err) } res, err := c.client.Do(req) if err != nil { return errors.WithStack(err) } defer func() { if err := res.Body.Close(); err != nil && !errors.Is(err, os.ErrClosed) { panic(errors.WithStack(err)) } }() tmp, err := os.CreateTemp(filepath.Dir(dest), "download_") if err != nil { return errors.WithStack(err) } defer func() { if err := os.Remove(tmp.Name()); err != nil && !os.IsNotExist(err) { panic(errors.WithStack(err)) } }() if _, err := io.Copy(tmp, res.Body); err != nil { return errors.WithStack(err) } tmpFileHash, err := hash(tmp.Name()) if err != nil { return errors.WithStack(err) } if tmpFileHash != sha256sum { return errors.Errorf("sha256 sum mismatch: expected '%s', got '%s'", sha256sum, tmpFileHash) } if err := os.Rename(tmp.Name(), dest); err != nil { return errors.WithStack(err) } return nil } func (c *SysUpgradeController) upgradeFirmware(ctx context.Context, firmwareFile string) error { templatizedArgs := make([]string, len(c.args)) for i, a := range c.args { templatizedArgs[i] = strings.Replace(a, FirmwareFileTemplate, firmwareFile, 1) } command := exec.CommandContext(ctx, c.command, templatizedArgs...) command.Stdout = os.Stdout command.Stderr = os.Stderr logger.Debug(ctx, "executing command", logger.F("command", c.command), logger.F("args", templatizedArgs)) if err := command.Run(); err != nil { return errors.WithStack(err) } return nil } func NewSysUpgradeController(funcs ...SysUpgradeOptionFunc) *SysUpgradeController { opts := defaultSysUpgradeOptions() for _, fn := range funcs { fn(opts) } return &SysUpgradeController{ command: opts.Command, args: opts.Args, client: opts.Client, firmwareVersion: opts.FirmwareVersion, } } var _ agent.Controller = &SysUpgradeController{}