178 lines
4.1 KiB
Go
178 lines
4.1 KiB
Go
|
package openwrt
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"io"
|
||
|
"net/http"
|
||
|
"os"
|
||
|
"os/exec"
|
||
|
"path/filepath"
|
||
|
"strings"
|
||
|
|
||
|
"forge.cadoles.com/Cadoles/emissary/internal/agent"
|
||
|
"forge.cadoles.com/Cadoles/emissary/internal/agent/controller/openwrt/spec/sysupgrade"
|
||
|
"github.com/pkg/errors"
|
||
|
"gitlab.com/wpetit/goweb/logger"
|
||
|
)
|
||
|
|
||
|
type SysUpgradeController struct {
|
||
|
client *http.Client
|
||
|
command string
|
||
|
args []string
|
||
|
firmwareVersion FirmwareVersion
|
||
|
}
|
||
|
|
||
|
// Name implements agent.Controller
|
||
|
func (*SysUpgradeController) Name() string {
|
||
|
return "sysupgrade-controller"
|
||
|
}
|
||
|
|
||
|
// Reconcile implements agent.Controller
|
||
|
func (c *SysUpgradeController) Reconcile(ctx context.Context, state *agent.State) error {
|
||
|
sysSpec := sysupgrade.NewSpec()
|
||
|
|
||
|
if err := state.GetSpec(sysupgrade.Name, sysSpec); err != nil {
|
||
|
if errors.Is(err, agent.ErrSpecNotFound) {
|
||
|
logger.Info(ctx, "could not find sysupgrade spec, doing nothing")
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
firmwareVersion, err := c.firmwareVersion.FirmwareVersion(ctx)
|
||
|
if err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
ctx = logger.With(ctx,
|
||
|
logger.F("currentFirmwareVersion", firmwareVersion),
|
||
|
logger.F("newFirmwareVersion", sysSpec.Version),
|
||
|
)
|
||
|
|
||
|
if firmwareVersion == sysSpec.Version {
|
||
|
logger.Info(ctx, "firmware version did not change, doing nothing")
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
downloadDir, err := os.MkdirTemp(os.TempDir(), "emissary_sysupgrade_*")
|
||
|
if err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
if err := os.RemoveAll(downloadDir); err != nil {
|
||
|
logger.Error(
|
||
|
ctx, "could not remove download direction",
|
||
|
logger.E(errors.WithStack(err)),
|
||
|
logger.F("downloadDir", downloadDir),
|
||
|
)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
firmwareFile := filepath.Join(downloadDir, "firmware.bin")
|
||
|
|
||
|
logger.Info(
|
||
|
ctx, "downloading firmware",
|
||
|
logger.F("url", sysSpec.URL), logger.F("sha256sum", sysSpec.SHA256Sum),
|
||
|
)
|
||
|
|
||
|
if err := c.downloadFile(ctx, sysSpec.URL, sysSpec.SHA256Sum, firmwareFile); err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
logger.Info(ctx, "upgrading firmware")
|
||
|
|
||
|
if err := c.upgradeFirmware(ctx, firmwareFile); err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *SysUpgradeController) downloadFile(ctx context.Context, url string, sha256sum string, dest string) error {
|
||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||
|
if err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
res, err := c.client.Do(req)
|
||
|
if err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
if err := res.Body.Close(); err != nil && !errors.Is(err, os.ErrClosed) {
|
||
|
panic(errors.WithStack(err))
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
tmp, err := os.CreateTemp(filepath.Dir(dest), "download_")
|
||
|
if err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
if err := os.Remove(tmp.Name()); err != nil && !os.IsNotExist(err) {
|
||
|
panic(errors.WithStack(err))
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
if _, err := io.Copy(tmp, res.Body); err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
tmpFileHash, err := hash(tmp.Name())
|
||
|
if err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
if tmpFileHash != sha256sum {
|
||
|
return errors.Errorf("sha256 sum mismatch: expected '%s', got '%s'", sha256sum, tmpFileHash)
|
||
|
}
|
||
|
|
||
|
if err := os.Rename(tmp.Name(), dest); err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *SysUpgradeController) upgradeFirmware(ctx context.Context, firmwareFile string) error {
|
||
|
templatizedArgs := make([]string, len(c.args))
|
||
|
for i, a := range c.args {
|
||
|
templatizedArgs[i] = strings.Replace(a, FirmwareFileTemplate, firmwareFile, 1)
|
||
|
}
|
||
|
|
||
|
command := exec.CommandContext(ctx, c.command, templatizedArgs...)
|
||
|
|
||
|
command.Stdout = os.Stdout
|
||
|
command.Stderr = os.Stderr
|
||
|
|
||
|
logger.Debug(ctx, "executing command", logger.F("command", c.command), logger.F("args", templatizedArgs))
|
||
|
|
||
|
if err := command.Run(); err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func NewSysUpgradeController(funcs ...SysUpgradeOptionFunc) *SysUpgradeController {
|
||
|
opts := defaultSysUpgradeOptions()
|
||
|
for _, fn := range funcs {
|
||
|
fn(opts)
|
||
|
}
|
||
|
|
||
|
return &SysUpgradeController{
|
||
|
command: opts.Command,
|
||
|
args: opts.Args,
|
||
|
client: opts.Client,
|
||
|
firmwareVersion: opts.FirmwareVersion,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Compile-time check that *SysUpgradeController satisfies agent.Controller.
var _ agent.Controller = (*SysUpgradeController)(nil)
|