Basic plugin system with centralized registry

This commit is contained in:
2020-11-17 09:19:15 +01:00
parent fe90d81b5d
commit ad5c02bd80
13 changed files with 333 additions and 1 deletions

16
plugin/error.go Normal file
View File

@ -0,0 +1,16 @@
package plugin
import "errors"
var (
// ErrInvalidRegisterFunc is returned when the plugin package
// could not find the expected RegisterPlugin func in the loaded
// plugin.
ErrInvalidRegisterFunc = errors.New("invalid register func")
// ErrInvalidPlugin is returned when a loaded plugin does
// not match the expected interface.
ErrInvalidPlugin = errors.New("invalid plugin")
// ErrPluginNotFound is returned when the given plugin could
// not be found in the registry.
ErrPluginNotFound = errors.New("plugin not found")
)

6
plugin/plugin.go Normal file
View File

@ -0,0 +1,6 @@
package plugin
type Plugin interface {
PluginName() string
PluginVersion() string
}

117
plugin/registry.go Normal file
View File

@ -0,0 +1,117 @@
package plugin
import (
"context"
"path/filepath"
"plugin"
"sync"
"github.com/pkg/errors"
)
type Registry struct {
plugins map[string]Plugin
mutex sync.RWMutex
}
func (r *Registry) Add(plg Plugin) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.plugins[plg.PluginName()] = plg
}
func (r *Registry) Get(name string) (Plugin, error) {
r.mutex.RLock()
defer r.mutex.RUnlock()
plg, exists := r.plugins[name]
if !exists {
return nil, errors.WithStack(ErrPluginNotFound)
}
return plg, nil
}
func (r *Registry) Load(ctx context.Context, path string) (Plugin, error) {
p, err := plugin.Open(path)
if err != nil {
return nil, errors.WithStack(err)
}
registerFuncSymbol, err := p.Lookup("RegisterPlugin")
if err != nil {
return nil, errors.WithStack(err)
}
register, ok := registerFuncSymbol.(func(context.Context) (Plugin, error))
if !ok {
return nil, errors.WithStack(ErrInvalidRegisterFunc)
}
plg, err := register(ctx)
if err != nil {
return nil, errors.WithStack(err)
}
if plg == nil {
return nil, errors.WithStack(ErrInvalidPlugin)
}
r.Add(plg)
return plg, nil
}
func (r *Registry) LoadAll(ctx context.Context, pattern string) ([]Plugin, error) {
extensions := make([]Plugin, 0)
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, errors.WithStack(err)
}
for _, m := range matches {
ext, err := r.Load(ctx, m)
if err != nil {
return nil, errors.WithStack(err)
}
extensions = append(extensions, ext)
}
return extensions, nil
}
func (r *Registry) Plugins() []Plugin {
r.mutex.RLock()
defer r.mutex.RUnlock()
plugins := make([]Plugin, 0, len(r.plugins))
for _, p := range r.plugins {
plugins = append(plugins, p)
}
return plugins
}
type IteratorFunc func(plg Plugin) error
func (r *Registry) Each(iterator IteratorFunc) error {
r.mutex.RLock()
defer r.mutex.RUnlock()
for _, p := range r.plugins {
if err := iterator(p); err != nil {
return errors.WithStack(err)
}
}
return nil
}
func NewRegistry() *Registry {
return &Registry{
plugins: make(map[string]Plugin),
}
}

94
plugin/registry_test.go Normal file
View File

@ -0,0 +1,94 @@
package plugin_test
import (
"testing"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/plugin"
)
type testPlugin struct {
name string
version string
}
func (p *testPlugin) PluginName() string {
return p.name
}
func (p *testPlugin) PluginVersion() string {
return p.version
}
func (p *testPlugin) Foo() string {
return "bar"
}
func TestRegistryEach(t *testing.T) {
t.Parallel()
reg := plugin.NewRegistry()
plugins := []*testPlugin{
{"plugin.a", "0.0.0"},
{"plugin.b", "0.0.1"},
}
for _, p := range plugins {
reg.Add(p)
}
total := 0
err := reg.Each(func(p plugin.Plugin) error {
total++
return nil
})
if err != nil {
t.Error(errors.WithStack(err))
}
if e, g := len(plugins), total; e != g {
t.Errorf("total: expected '%v', got '%v'", e, g)
}
}
func TestRegistryGet(t *testing.T) {
t.Parallel()
reg := plugin.NewRegistry()
plugins := []*testPlugin{
{"plugin.a", "0.0.0"},
{"plugin.b", "0.0.1"},
}
for _, p := range plugins {
reg.Add(p)
}
for _, p := range plugins {
plugin, err := reg.Get(p.name)
if err != nil {
t.Error(errors.WithStack(err))
}
if e, g := p.name, plugin.PluginName(); e != g {
t.Errorf("plugin.PluginName(): expected '%v', got '%v'", e, g)
}
if e, g := p.version, plugin.PluginVersion(); e != g {
t.Errorf("plugin.PluginVersion(): expected '%v', got '%v'", e, g)
}
}
p, err := reg.Get("plugin.c")
if !errors.Is(err, plugin.ErrPluginNotFound) {
t.Errorf("err: expected '%v', got '%v'", plugin.ErrPluginNotFound, err)
}
if p != nil {
t.Errorf("reg.Get(\"plugin.c\"): expected '%v', got '%v'", nil, p)
}
}