emissary/internal/agent/controller/openwrt/sysupgrade_controller.go

179 lines
4.2 KiB
Go

package openwrt
import (
"context"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"forge.cadoles.com/Cadoles/emissary/internal/agent"
"forge.cadoles.com/Cadoles/emissary/internal/agent/controller/openwrt/spec/sysupgrade"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// SysUpgradeController is an agent controller that downloads a firmware image
// and hands it over to an external upgrade command when the firmware version
// requested by the sysupgrade spec differs from the one currently reported.
type SysUpgradeController struct {
	// client is the HTTP client used to download the firmware image.
	client *http.Client
	// command is the executable run to perform the upgrade.
	command string
	// args are the arguments passed to command; the first occurrence of
	// FirmwareFileTemplate in each argument is replaced by the path of the
	// downloaded firmware file.
	args []string
	// firmwareVersion reports the currently running firmware version.
	firmwareVersion FirmwareVersion
}
// Name returns the fixed name of this controller.
// It implements agent.Controller.
func (*SysUpgradeController) Name() string {
	const name = "sysupgrade-controller"

	return name
}
// Reconcile implements agent.Controller.
//
// It loads the sysupgrade spec from state and compares the spec's version with
// the version reported by the controller's firmware version provider. When
// they differ, the firmware is downloaded (and checksum-verified) into a
// temporary directory and the upgrade command is executed against it.
//
// Reconcile returns nil without doing anything when the spec is absent or when
// the versions already match. Any other failure is returned with a stack
// trace attached.
func (c *SysUpgradeController) Reconcile(ctx context.Context, state *agent.State) error {
	sysSpec := sysupgrade.NewSpec()

	if err := state.GetSpec(sysupgrade.Name, sysSpec); err != nil {
		if errors.Is(err, agent.ErrSpecNotFound) {
			logger.Info(ctx, "could not find sysupgrade spec, doing nothing")

			return nil
		}

		return errors.WithStack(err)
	}

	firmwareVersion, err := c.firmwareVersion.FirmwareVersion(ctx)
	if err != nil {
		return errors.WithStack(err)
	}

	// Attach both versions to every subsequent log entry.
	ctx = logger.With(ctx,
		logger.F("currentFirmwareVersion", firmwareVersion),
		logger.F("newFirmwareVersion", sysSpec.Version),
	)

	if firmwareVersion == sysSpec.Version {
		logger.Info(ctx, "firmware version did not change, doing nothing")

		return nil
	}

	downloadDir, err := os.MkdirTemp(os.TempDir(), "emissary_sysupgrade_*")
	if err != nil {
		return errors.WithStack(err)
	}

	// Best-effort cleanup: a failure to remove the directory is logged but
	// does not change the outcome of the reconciliation.
	defer func() {
		if rmErr := os.RemoveAll(downloadDir); rmErr != nil {
			logger.Error(
				ctx, "could not remove download directory",
				logger.CapturedE(errors.WithStack(rmErr)),
				logger.F("downloadDir", downloadDir),
			)
		}
	}()

	firmwareFile := filepath.Join(downloadDir, "firmware.bin")

	logger.Info(
		ctx, "downloading firmware",
		logger.F("url", sysSpec.URL), logger.F("sha256sum", sysSpec.SHA256Sum),
	)

	if err := c.downloadFile(ctx, sysSpec.URL, sysSpec.SHA256Sum, firmwareFile); err != nil {
		return errors.WithStack(err)
	}

	logger.Info(ctx, "upgrading firmware")

	if err := c.upgradeFirmware(ctx, firmwareFile); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// downloadFile fetches url with the controller's HTTP client and stores the
// response body at dest once its checksum matches sha256sum.
//
// The body is streamed to a temporary file created in dest's directory and is
// only renamed to dest after the checksum comparison succeeded, so dest never
// holds a partial or unverified download. A response whose status is not 2xx
// is rejected before anything is written to disk.
//
// Cleanup failures (closing the response body, closing or removing the
// temporary file) are returned through the named result when no earlier error
// occurred; they never panic.
func (c *SysUpgradeController) downloadFile(ctx context.Context, url string, sha256sum string, dest string) (err error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return errors.WithStack(err)
	}

	res, err := c.client.Do(req)
	if err != nil {
		return errors.WithStack(err)
	}

	defer func() {
		closeErr := res.Body.Close()
		if closeErr != nil && !errors.Is(closeErr, os.ErrClosed) && err == nil {
			err = errors.WithStack(closeErr)
		}
	}()

	// The HTTP client does not treat non-2xx responses as errors; without
	// this check an error page would be downloaded and reported as a
	// checksum mismatch.
	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
		return errors.Errorf("unexpected http response status '%s'", res.Status)
	}

	tmp, err := os.CreateTemp(filepath.Dir(dest), "download_")
	if err != nil {
		return errors.WithStack(err)
	}

	tmpPath := tmp.Name()

	defer func() {
		// On the happy path the file is already closed, which surfaces here
		// as os.ErrClosed and is ignored.
		if closeErr := tmp.Close(); closeErr != nil && !errors.Is(closeErr, os.ErrClosed) && err == nil {
			err = errors.WithStack(closeErr)
		}

		// After a successful rename the temporary path no longer exists.
		if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) && err == nil {
			err = errors.WithStack(rmErr)
		}
	}()

	if _, err := io.Copy(tmp, res.Body); err != nil {
		return errors.WithStack(err)
	}

	// Release the file descriptor before the file is hashed and renamed.
	if err := tmp.Close(); err != nil {
		return errors.WithStack(err)
	}

	tmpFileHash, err := hash(tmpPath)
	if err != nil {
		return errors.WithStack(err)
	}

	if tmpFileHash != sha256sum {
		return errors.Errorf("sha256 sum mismatch: expected '%s', got '%s'", sha256sum, tmpFileHash)
	}

	if err := os.Rename(tmpPath, dest); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// upgradeFirmware runs the configured upgrade command against firmwareFile.
//
// In each configured argument, the first occurrence of FirmwareFileTemplate is
// substituted with firmwareFile. The command is bound to ctx and writes to the
// current process' stdout and stderr. A failure to run the command is returned
// with a stack trace attached.
func (c *SysUpgradeController) upgradeFirmware(ctx context.Context, firmwareFile string) error {
	resolved := make([]string, 0, len(c.args))
	for _, arg := range c.args {
		resolved = append(resolved, strings.Replace(arg, FirmwareFileTemplate, firmwareFile, 1))
	}

	cmd := exec.CommandContext(ctx, c.command, resolved...)
	cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr

	logger.Debug(ctx, "executing command", logger.F("command", c.command), logger.F("args", resolved))

	runErr := cmd.Run()
	if runErr == nil {
		return nil
	}

	return errors.WithStack(runErr)
}
// NewSysUpgradeController builds a SysUpgradeController from the default
// sysupgrade options, after applying each of the given option functions to
// them in order.
func NewSysUpgradeController(funcs ...SysUpgradeOptionFunc) *SysUpgradeController {
	opts := defaultSysUpgradeOptions()
	for i := range funcs {
		funcs[i](opts)
	}

	ctrl := &SysUpgradeController{}
	ctrl.command = opts.Command
	ctrl.args = opts.Args
	ctrl.client = opts.Client
	ctrl.firmwareVersion = opts.FirmwareVersion

	return ctrl
}
// Compile-time check that *SysUpgradeController satisfies agent.Controller.
var _ agent.Controller = &SysUpgradeController{}