package cast

import (
	"context"
	"sync"

	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
)

// Service discovers devices of a single DeviceType and creates clients for
// them. Services are combined by a Registry.
type Service interface {
	// Scan returns the devices discovered by this service.
	Scan(ctx context.Context) ([]Device, error)
	// Find returns the device identified by deviceID. Registry.Find treats a
	// nil Device with a nil error as "not known to this service".
	Find(ctx context.Context, deviceID string) (Device, error)
	// NewClient returns a Client for device.
	NewClient(ctx context.Context, device Device) (Client, error)
}

// Registry dispatches Service calls by DeviceType.
//
// The internal index is not guarded by a lock: register every service before
// calling Scan, Find or NewClient from other goroutines.
type Registry struct {
	index map[DeviceType]Service
}

// NewClient returns a Client for device, created by the Service registered
// for device.DeviceType(). When no Service is registered for that type, the
// returned error wraps ErrUnknownDeviceType.
func (r *Registry) NewClient(ctx context.Context, device Device) (Client, error) {
	deviceType := device.DeviceType()

	srv, exists := r.index[deviceType]
	if !exists {
		return nil, errors.Wrapf(ErrUnknownDeviceType, "device type '%s' is not registered", deviceType)
	}

	client, err := srv.NewClient(ctx, device)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return client, nil
}

// Find asks each registered Service for the device identified by deviceID and
// returns the first non-nil result. Services are queried in unspecified (map
// iteration) order.
//
// An error from an individual Service is logged and the search continues with
// the next one. When no Service returns a device, the returned error wraps
// ErrDeviceNotFound.
func (r *Registry) Find(ctx context.Context, deviceID string) (Device, error) {
	for _, srv := range r.index {
		device, err := srv.Find(ctx, deviceID)
		if err != nil {
			logger.Error(ctx, "could not get device", logger.CapturedE(errors.WithStack(err)))
			continue
		}

		if device != nil {
			return device, nil
		}
	}

	return nil, errors.WithStack(ErrDeviceNotFound)
}

// Scan runs Scan on every registered Service concurrently and returns the
// concatenation of the devices they report. The order of the result is
// nondeterministic.
//
// Scanning is best-effort. A failing Service is logged and does not abort the
// others, and any devices it returned alongside its error are still included.
// The returned error is currently always nil. The returned slice is empty,
// never nil, when nothing is found.
func (r *Registry) Scan(ctx context.Context) ([]Device, error) {
	var (
		lock    sync.Mutex
		wg      sync.WaitGroup
		results = make([]Device, 0)
		errs    []error
	)

	wg.Add(len(r.index))
	for _, srv := range r.index {
		// Pass srv explicitly so each goroutine scans its own service even when
		// built with Go < 1.22, where the range variable is shared across
		// iterations.
		go func(srv Service) {
			defer wg.Done()

			devices, err := srv.Scan(ctx)

			lock.Lock()
			defer lock.Unlock()

			if err != nil {
				errs = append(errs, errors.WithStack(err))
			}
			results = append(results, devices...)
		}(srv)
	}
	wg.Wait()

	// Errors already carry a stack from the goroutine; log them as-is.
	for _, err := range errs {
		logger.Error(ctx, "error occured while scanning", logger.CapturedE(err))
	}

	return results, nil
}

// Register associates service with deviceType, replacing any Service
// previously registered for that type. It is not safe to call concurrently
// with other Registry methods.
func (r *Registry) Register(deviceType DeviceType, service Service) {
	if r.index == nil {
		// Make a zero-value Registry usable instead of panicking on a nil map.
		r.index = make(map[DeviceType]Service)
	}

	r.index[deviceType] = service
}

// NewRegistry returns an empty Registry.
func NewRegistry() *Registry {
	return &Registry{
		index: make(map[DeviceType]Service),
	}
}

// defaultRegistry backs the package-level NewClient, Scan, Find and Register
// functions.
var defaultRegistry = NewRegistry()

// NewClient calls NewClient on the package's default Registry.
func NewClient(ctx context.Context, device Device) (Client, error) {
	client, err := defaultRegistry.NewClient(ctx, device)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return client, nil
}

// Scan calls Scan on the package's default Registry.
func Scan(ctx context.Context) ([]Device, error) {
	devices, err := defaultRegistry.Scan(ctx)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return devices, nil
}

// Find calls Find on the package's default Registry.
func Find(ctx context.Context, deviceID string) (Device, error) {
	device, err := defaultRegistry.Find(ctx, deviceID)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return device, nil
}

// Register registers service for deviceType on the package's default
// Registry. Like Registry.Register, it is not safe to call concurrently with
// the other package-level functions; call it during initialization.
func Register(deviceType DeviceType, service Service) {
	defaultRegistry.Register(deviceType, service)
}