feat(controller,sysupgrade): openwrt upgrade controller

This commit is contained in:
2023-03-24 23:17:55 +01:00
parent 97a60e2617
commit 0b783c374a
28 changed files with 531 additions and 28 deletions

View File

@ -28,8 +28,7 @@ import (
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/pkg/errors"
_ "forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd/argon2id"
_ "forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd/plain"
_ "forge.cadoles.com/Cadoles/emissary/internal/imports/passwd"
)
type Server struct {

View File

@ -0,0 +1,31 @@
package openwrt
import (
"crypto/sha256"
"encoding/hex"
"io"
"os"
"github.com/pkg/errors"
)
func hash(path string) (string, error) {
file, err := os.Open(path)
if err != nil {
return "", errors.WithStack(err)
}
hasher := sha256.New()
defer func() {
if err := file.Close(); err != nil {
panic(errors.WithStack(err))
}
}()
if _, err := io.Copy(hasher, file); err != nil {
return "", errors.WithStack(err)
}
return hex.EncodeToString(hasher.Sum(nil)), nil
}

View File

@ -0,0 +1,17 @@
package sysupgrade
import (
_ "embed"
"forge.cadoles.com/Cadoles/emissary/internal/spec"
"github.com/pkg/errors"
)
//go:embed schema.json
var schema []byte
func init() {
if err := spec.Register(Name, schema); err != nil {
panic(errors.WithStack(err))
}
}

View File

@ -0,0 +1,20 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "https://sysupgrade.openwrt.emissary.cadoles.com/spec.json",
"title": "SysUpgradeSpec",
"description": "Emissary 'SysUpgrade' specification",
"type": "object",
"properties": {
"url": {
"type": "string"
},
"sha256sum": {
"type": "string"
},
"version": {
"type": "string"
}
},
"required": ["url", "sha256sum", "version"],
"additionalProperties": false
}

View File

@ -0,0 +1,38 @@
package sysupgrade
import (
"forge.cadoles.com/Cadoles/emissary/internal/spec"
)
const Name spec.Name = "sysupgrade.openwrt.emissary.cadoles.com"
type Spec struct {
Revision int `json:"revision"`
URL string `json:"url"`
SHA256Sum string `json:"sha256sum"`
Version string `json:"version"`
}
func (s *Spec) SpecName() spec.Name {
return Name
}
func (s *Spec) SpecRevision() int {
return s.Revision
}
func (s *Spec) SpecData() map[string]any {
return map[string]any{
"url": s.URL,
"version": s.Version,
"sha256sum": s.SHA256Sum,
}
}
func NewSpec() *Spec {
return &Spec{
Revision: -1,
}
}
var _ spec.Spec = &Spec{}

View File

@ -0,0 +1,9 @@
{
"name": "sysupgrade.openwrt.emissary.cadoles.com",
"data": {
"url": "http://example.com/firmware.img",
"sha256sum": "58019192dacdae17755707719707db007e26dac856102280583fbd18427dd352",
"version": "0.0.0"
},
"revision": 0
}

View File

@ -0,0 +1,65 @@
package sysupgrade
import (
"context"
"encoding/json"
"io/ioutil"
"testing"
"forge.cadoles.com/Cadoles/emissary/internal/spec"
"github.com/pkg/errors"
)
type validatorTestCase struct {
Name string
Source string
ShouldFail bool
}
var validatorTestCases = []validatorTestCase{
{
Name: "SpecOK",
Source: "testdata/spec-ok.json",
ShouldFail: false,
},
}
func TestValidator(t *testing.T) {
t.Parallel()
validator := spec.NewValidator()
if err := validator.Register(Name, schema); err != nil {
t.Fatalf("+%v", errors.WithStack(err))
}
for _, tc := range validatorTestCases {
func(tc validatorTestCase) {
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
rawSpec, err := ioutil.ReadFile(tc.Source)
if err != nil {
t.Fatalf("+%v", errors.WithStack(err))
}
var spec spec.RawSpec
if err := json.Unmarshal(rawSpec, &spec); err != nil {
t.Fatalf("+%v", errors.WithStack(err))
}
ctx := context.Background()
err = validator.Validate(ctx, &spec)
if !tc.ShouldFail && err != nil {
t.Errorf("+%v", errors.WithStack(err))
}
if tc.ShouldFail && err == nil {
t.Error("validation should have failed")
}
})
}(tc)
}
}

View File

@ -0,0 +1,177 @@
package openwrt
import (
"context"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"forge.cadoles.com/Cadoles/emissary/internal/agent"
"forge.cadoles.com/Cadoles/emissary/internal/agent/controller/openwrt/spec/sysupgrade"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
type SysUpgradeController struct {
client *http.Client
command string
args []string
firmwareVersion FirmwareVersion
}
// Name implements agent.Controller
func (*SysUpgradeController) Name() string {
return "sysupgrade-controller"
}
// Reconcile implements agent.Controller
func (c *SysUpgradeController) Reconcile(ctx context.Context, state *agent.State) error {
sysSpec := sysupgrade.NewSpec()
if err := state.GetSpec(sysupgrade.Name, sysSpec); err != nil {
if errors.Is(err, agent.ErrSpecNotFound) {
logger.Info(ctx, "could not find sysupgrade spec, doing nothing")
return nil
}
return errors.WithStack(err)
}
firmwareVersion, err := c.firmwareVersion.FirmwareVersion(ctx)
if err != nil {
return errors.WithStack(err)
}
ctx = logger.With(ctx,
logger.F("currentFirmwareVersion", firmwareVersion),
logger.F("newFirmwareVersion", sysSpec.Version),
)
if firmwareVersion == sysSpec.Version {
logger.Info(ctx, "firmware version did not change, doing nothing")
return nil
}
downloadDir, err := os.MkdirTemp(os.TempDir(), "emissary_sysupgrade_*")
if err != nil {
return errors.WithStack(err)
}
defer func() {
if err := os.RemoveAll(downloadDir); err != nil {
logger.Error(
ctx, "could not remove download direction",
logger.E(errors.WithStack(err)),
logger.F("downloadDir", downloadDir),
)
}
}()
firmwareFile := filepath.Join(downloadDir, "firmware.bin")
logger.Info(
ctx, "downloading firmware",
logger.F("url", sysSpec.URL), logger.F("sha256sum", sysSpec.SHA256Sum),
)
if err := c.downloadFile(ctx, sysSpec.URL, sysSpec.SHA256Sum, firmwareFile); err != nil {
return errors.WithStack(err)
}
logger.Info(ctx, "upgrading firmware")
if err := c.upgradeFirmware(ctx, firmwareFile); err != nil {
return errors.WithStack(err)
}
return nil
}
func (c *SysUpgradeController) downloadFile(ctx context.Context, url string, sha256sum string, dest string) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return errors.WithStack(err)
}
res, err := c.client.Do(req)
if err != nil {
return errors.WithStack(err)
}
defer func() {
if err := res.Body.Close(); err != nil && !errors.Is(err, os.ErrClosed) {
panic(errors.WithStack(err))
}
}()
tmp, err := os.CreateTemp(filepath.Dir(dest), "download_")
if err != nil {
return errors.WithStack(err)
}
defer func() {
if err := os.Remove(tmp.Name()); err != nil && !os.IsNotExist(err) {
panic(errors.WithStack(err))
}
}()
if _, err := io.Copy(tmp, res.Body); err != nil {
return errors.WithStack(err)
}
tmpFileHash, err := hash(tmp.Name())
if err != nil {
return errors.WithStack(err)
}
if tmpFileHash != sha256sum {
return errors.Errorf("sha256 sum mismatch: expected '%s', got '%s'", sha256sum, tmpFileHash)
}
if err := os.Rename(tmp.Name(), dest); err != nil {
return errors.WithStack(err)
}
return nil
}
func (c *SysUpgradeController) upgradeFirmware(ctx context.Context, firmwareFile string) error {
templatizedArgs := make([]string, len(c.args))
for i, a := range c.args {
templatizedArgs[i] = strings.Replace(a, FirmwareFileTemplate, firmwareFile, 1)
}
command := exec.CommandContext(ctx, c.command, templatizedArgs...)
command.Stdout = os.Stdout
command.Stderr = os.Stderr
logger.Debug(ctx, "executing command", logger.F("command", c.command), logger.F("args", templatizedArgs))
if err := command.Run(); err != nil {
return errors.WithStack(err)
}
return nil
}
func NewSysUpgradeController(funcs ...SysUpgradeOptionFunc) *SysUpgradeController {
opts := defaultSysUpgradeOptions()
for _, fn := range funcs {
fn(opts)
}
return &SysUpgradeController{
command: opts.Command,
args: opts.Args,
client: opts.Client,
firmwareVersion: opts.FirmwareVersion,
}
}
var _ agent.Controller = &SysUpgradeController{}

View File

@ -0,0 +1,46 @@
package openwrt
import (
"bytes"
"context"
"os"
"os/exec"
"strings"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
type ShellFirmwareVersion struct {
command string
args []string
}
// FirmwareVersion implements FirmwareVersion
func (fv *ShellFirmwareVersion) FirmwareVersion(ctx context.Context) (string, error) {
command := exec.CommandContext(ctx, fv.command, fv.args...)
var buf bytes.Buffer
command.Stdout = &buf
command.Stderr = os.Stderr
logger.Debug(ctx, "executing command", logger.F("command", fv.command), logger.F("args", fv.args))
if err := command.Run(); err != nil {
return "", errors.WithStack(err)
}
version := strings.TrimSpace(buf.String())
return version, nil
}
func NewShellFirmwareVersion(command string, args ...string) *ShellFirmwareVersion {
return &ShellFirmwareVersion{
command: command,
args: args,
}
}
var _ FirmwareVersion = &ShellFirmwareVersion{}

View File

@ -0,0 +1,58 @@
package openwrt
import (
"context"
"net/http"
"time"
)
const FirmwareFileTemplate = "%FIRMWARE%"
type FirmwareVersion interface {
FirmwareVersion(context.Context) (string, error)
}
type SysUpgradeOptions struct {
Command string
Args []string
FirmwareVersion FirmwareVersion
Client *http.Client
}
func defaultSysUpgradeOptions() *SysUpgradeOptions {
return &SysUpgradeOptions{
Command: `echo`,
Args: []string{`[DUMMY UPGRADE]`, FirmwareFileTemplate},
Client: &http.Client{
Timeout: 30 * time.Second,
},
FirmwareVersion: NewShellFirmwareVersion(`echo`, "0.0.0-dummy"),
}
}
type SysUpgradeOptionFunc func(*SysUpgradeOptions)
func WithSysUpgradeCommand(command string, args ...string) SysUpgradeOptionFunc {
return func(opts *SysUpgradeOptions) {
opts.Command = command
opts.Args = args
}
}
func WithSysUpgradeFirmwareVersion(firmwareVersion FirmwareVersion) SysUpgradeOptionFunc {
return func(opts *SysUpgradeOptions) {
opts.FirmwareVersion = firmwareVersion
}
}
func WithSysUpgradeShellFirmwareVersion(command string, args ...string) SysUpgradeOptionFunc {
return func(opts *SysUpgradeOptions) {
opts.FirmwareVersion = NewShellFirmwareVersion(command, args...)
}
}
func WithSysUpgradeClient(client *http.Client) SysUpgradeOptionFunc {
return func(opts *SysUpgradeOptions) {
opts.Client = client
}
}