package persistence import ( "context" "encoding/json" "fmt" "io/ioutil" "os" "path/filepath" "forge.cadoles.com/Cadoles/emissary/internal/agent" "forge.cadoles.com/Cadoles/emissary/internal/spec" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) type Controller struct { trackedSpecRevisions map[spec.Name]int filename string loaded bool } // Name implements node.Controller. func (c *Controller) Name() string { return "persistence-controller" } // Reconcile implements node.Controller. func (c *Controller) Reconcile(ctx context.Context, state *agent.State) error { specs := state.Specs() changed := c.specChanged(state.Specs()) switch { // If first cycle, load state from file system case !c.loaded: logger.Info(ctx, "first cycle, loading state", logger.F("stateFile", c.filename)) if err := c.loadState(ctx, state); err != nil { if errors.Is(err, os.ErrNotExist) { logger.Info(ctx, "state file not found", logger.F("stateFile", c.filename)) c.loaded = true return nil } return errors.WithStack(err) } c.trackSpecsRevisions(specs) c.loaded = true return nil // If specs did not change, return case !changed: logger.Info(ctx, "no changes detected, doing nothing") return nil // If specs has changed, save it case changed: logger.Info(ctx, "saving state", logger.F("stateFile", c.filename)) if err := c.writeState(ctx, state); err != nil { return errors.WithStack(err) } c.trackSpecsRevisions(specs) } return nil } func (c *Controller) specChanged(specs agent.Specs) bool { if len(specs) != len(c.trackedSpecRevisions) { return true } for name, spec := range specs { trackedRevision, exists := c.trackedSpecRevisions[name] if !exists { return true } if trackedRevision != spec.SpecRevision() { return true } } for trackedSpecName, trackedRevision := range c.trackedSpecRevisions { spec, exists := specs[trackedSpecName] if !exists { return true } if trackedRevision != spec.SpecRevision() { return true } } return false } func (c *Controller) trackSpecsRevisions(specs agent.Specs) { c.trackedSpecRevisions = make(map[spec.Name]int) for name, spec := range specs { c.trackedSpecRevisions[name] = spec.SpecRevision() } } func (c *Controller) loadState(ctx context.Context, state *agent.State) error { data, err := ioutil.ReadFile(c.filename) if err != nil { return errors.WithStack(err) } if err := json.Unmarshal(data, state); err != nil { return errors.WithStack(err) } return nil } func (c *Controller) writeState(ctx context.Context, state *agent.State) error { dir, file := filepath.Split(c.filename) if dir == "" { dir = "." } f, err := ioutil.TempFile(dir, file) if err != nil { return errors.Errorf("cannot create temp file: %v", err) } defer func() { if err == nil { return } if err := os.Remove(f.Name()); err != nil { if errors.Is(err, os.ErrNotExist) { return } err = errors.WithStack(err) logger.Error(ctx, "could not remove temporary file", logger.CapturedE(err)) } }() defer func() { if err := f.Close(); err != nil { if errors.Is(err, os.ErrClosed) { return } err = errors.WithStack(err) logger.Error(ctx, "could not close temporary file", logger.CapturedE(err)) } }() data, err := json.Marshal(state) if err != nil { return errors.WithStack(err) } name := f.Name() if err := ioutil.WriteFile(name, data, os.ModePerm); err != nil { return errors.Errorf("cannot write data to temporary file %q: %v", name, err) } if err := f.Sync(); err != nil { return errors.Errorf("can't flush temporary file %q: %v", name, err) } if err := f.Close(); err != nil { return errors.Errorf("can't close temporary file %q: %v", name, err) } // get the file mode from the original file and use that for the replacement // file, too. destInfo, err := os.Stat(c.filename) switch { case os.IsNotExist(err): // Do nothing case err != nil: return errors.WithStack(err) default: sourceInfo, err := os.Stat(name) if err != nil { return errors.WithStack(err) } if sourceInfo.Mode() != destInfo.Mode() { if err := os.Chmod(name, destInfo.Mode()); err != nil { return fmt.Errorf("can't set filemode on temporary file %q: %v", name, err) } } } if err := os.Rename(name, c.filename); err != nil { return fmt.Errorf("cannot replace %q with temporary file %q: %v", c.filename, name, err) } return nil } func NewController(filename string) *Controller { return &Controller{ filename: filename, trackedSpecRevisions: make(map[spec.Name]int), } } var _ agent.Controller = &Controller{}