feat: only check preferred/fallback protocols availability

This commit is contained in:
wpetit 2024-08-13 15:21:53 +02:00
parent 83288967e3
commit 593c062993
2 changed files with 55 additions and 36 deletions

View File

@ -30,58 +30,61 @@ func (c *Client) Protocol(ctx context.Context) (protocol.Identifier, protocol.Op
func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) { func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) {
c.getProtocolOnce.Do(func() { c.getProtocolOnce.Do(func() {
availables, err := c.opts.Protocols.Availables( opts := []protocol.ProtocolOptionFunc{
ctx, c.addr, c.opts.AvailableTimeout,
protocol.WithProtocolLogger(c.opts.Logger), protocol.WithProtocolLogger(c.opts.Logger),
protocol.WithProtocolDial(c.opts.Dial), protocol.WithProtocolDial(c.opts.Dial),
) }
preferred, err := c.opts.Protocols.Get(c.opts.PreferredProtocol, opts...)
if err != nil && c.opts.FallbackProtocol == "" {
c.getProtocolOnceErr = errors.WithStack(err)
return
}
if preferred != nil {
preferredCtx, cancel := context.WithTimeout(ctx, c.opts.AvailableTimeout)
defer cancel()
ok, err := preferred.Available(preferredCtx, c.addr)
if err != nil && c.opts.FallbackProtocol == "" {
c.getProtocolOnceErr = errors.WithStack(err)
return
}
if ok {
c.proto = preferred.Identifier()
c.ops = preferred.Operations(c.addr)
return
}
}
if c.opts.FallbackProtocol == "" {
c.getProtocolOnceErr = errors.WithStack(err)
return
}
fallback, err := c.opts.Protocols.Get(c.opts.FallbackProtocol, opts...)
if err != nil { if err != nil {
c.getProtocolOnceErr = errors.WithStack(err) c.getProtocolOnceErr = errors.WithStack(err)
return return
} }
var preferred protocol.Protocol fallbackCtx, cancel := context.WithTimeout(ctx, c.opts.AvailableTimeout)
var fallback protocol.Protocol defer cancel()
for _, proto := range availables { ok, err := fallback.Available(fallbackCtx, c.addr)
if proto.Identifier() == c.opts.FallbackProtocol { if err != nil && c.opts.FallbackProtocol == "" {
fallback = proto c.getProtocolOnceErr = errors.WithStack(err)
}
if proto.Identifier() == c.opts.PreferredProtocol {
preferred = proto
break
}
}
if preferred != nil {
c.proto = preferred.Identifier()
c.ops = preferred.Operations(c.addr)
return return
} }
if fallback != nil { if ok {
c.proto = fallback.Identifier() c.proto = fallback.Identifier()
c.ops = fallback.Operations(c.addr) c.ops = fallback.Operations(c.addr)
return return
} }
for _, proto := range availables {
if proto.Identifier() != c.opts.FallbackProtocol {
continue
}
fallback = proto
break
}
if fallback == nil {
c.getProtocolOnceErr = errors.Errorf("neither preferred protocol '%v' or fallback '%v' are available", c.opts.PreferredProtocol, c.opts.FallbackProtocol) c.getProtocolOnceErr = errors.Errorf("neither preferred protocol '%v' or fallback '%v' are available", c.opts.PreferredProtocol, c.opts.FallbackProtocol)
return
}
c.proto = fallback.Identifier()
c.ops = fallback.Operations(c.addr)
}) })
if c.getProtocolOnceErr != nil { if c.getProtocolOnceErr != nil {
return "", nil, errors.WithStack(c.getProtocolOnceErr) return "", nil, errors.WithStack(c.getProtocolOnceErr)

View File

@ -15,6 +15,22 @@ func (r *Registry) Register(identifier Identifier, factory ProtocolFactory) {
r.protocols[identifier] = factory r.protocols[identifier] = factory
} }
func (r *Registry) Get(identifier Identifier, funcs ...ProtocolOptionFunc) (Protocol, error) {
factory, exists := r.protocols[identifier]
if !exists {
return nil, errors.WithStack(ErrNotFound)
}
protocolOpts := NewProtocolOptions(funcs...)
proto, err := factory(protocolOpts)
if err != nil {
return nil, errors.WithStack(err)
}
return proto, nil
}
func (r *Registry) Availables(ctx context.Context, addr string, timeout time.Duration, funcs ...ProtocolOptionFunc) ([]Protocol, error) { func (r *Registry) Availables(ctx context.Context, addr string, timeout time.Duration, funcs ...ProtocolOptionFunc) ([]Protocol, error) {
availables := make([]Protocol, 0) availables := make([]Protocol, 0)
protocolOpts := NewProtocolOptions(funcs...) protocolOpts := NewProtocolOptions(funcs...)