feat: only check preferred/fallback protocols availability
This commit is contained in:
parent
83288967e3
commit
593c062993
|
@ -30,58 +30,61 @@ func (c *Client) Protocol(ctx context.Context) (protocol.Identifier, protocol.Op
|
||||||
|
|
||||||
func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) {
|
func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) {
|
||||||
c.getProtocolOnce.Do(func() {
|
c.getProtocolOnce.Do(func() {
|
||||||
availables, err := c.opts.Protocols.Availables(
|
opts := []protocol.ProtocolOptionFunc{
|
||||||
ctx, c.addr, c.opts.AvailableTimeout,
|
|
||||||
protocol.WithProtocolLogger(c.opts.Logger),
|
protocol.WithProtocolLogger(c.opts.Logger),
|
||||||
protocol.WithProtocolDial(c.opts.Dial),
|
protocol.WithProtocolDial(c.opts.Dial),
|
||||||
)
|
}
|
||||||
|
|
||||||
|
preferred, err := c.opts.Protocols.Get(c.opts.PreferredProtocol, opts...)
|
||||||
|
if err != nil && c.opts.FallbackProtocol == "" {
|
||||||
|
c.getProtocolOnceErr = errors.WithStack(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if preferred != nil {
|
||||||
|
preferredCtx, cancel := context.WithTimeout(ctx, c.opts.AvailableTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ok, err := preferred.Available(preferredCtx, c.addr)
|
||||||
|
if err != nil && c.opts.FallbackProtocol == "" {
|
||||||
|
c.getProtocolOnceErr = errors.WithStack(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
c.proto = preferred.Identifier()
|
||||||
|
c.ops = preferred.Operations(c.addr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.opts.FallbackProtocol == "" {
|
||||||
|
c.getProtocolOnceErr = errors.WithStack(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fallback, err := c.opts.Protocols.Get(c.opts.FallbackProtocol, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.getProtocolOnceErr = errors.WithStack(err)
|
c.getProtocolOnceErr = errors.WithStack(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var preferred protocol.Protocol
|
fallbackCtx, cancel := context.WithTimeout(ctx, c.opts.AvailableTimeout)
|
||||||
var fallback protocol.Protocol
|
defer cancel()
|
||||||
|
|
||||||
for _, proto := range availables {
|
ok, err := fallback.Available(fallbackCtx, c.addr)
|
||||||
if proto.Identifier() == c.opts.FallbackProtocol {
|
if err != nil && c.opts.FallbackProtocol == "" {
|
||||||
fallback = proto
|
c.getProtocolOnceErr = errors.WithStack(err)
|
||||||
}
|
|
||||||
|
|
||||||
if proto.Identifier() == c.opts.PreferredProtocol {
|
|
||||||
preferred = proto
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if preferred != nil {
|
|
||||||
c.proto = preferred.Identifier()
|
|
||||||
c.ops = preferred.Operations(c.addr)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if fallback != nil {
|
if ok {
|
||||||
c.proto = fallback.Identifier()
|
c.proto = fallback.Identifier()
|
||||||
c.ops = fallback.Operations(c.addr)
|
c.ops = fallback.Operations(c.addr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, proto := range availables {
|
|
||||||
if proto.Identifier() != c.opts.FallbackProtocol {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
fallback = proto
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if fallback == nil {
|
|
||||||
c.getProtocolOnceErr = errors.Errorf("neither preferred protocol '%v' or fallback '%v' are available", c.opts.PreferredProtocol, c.opts.FallbackProtocol)
|
c.getProtocolOnceErr = errors.Errorf("neither preferred protocol '%v' or fallback '%v' are available", c.opts.PreferredProtocol, c.opts.FallbackProtocol)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.proto = fallback.Identifier()
|
|
||||||
c.ops = fallback.Operations(c.addr)
|
|
||||||
})
|
})
|
||||||
if c.getProtocolOnceErr != nil {
|
if c.getProtocolOnceErr != nil {
|
||||||
return "", nil, errors.WithStack(c.getProtocolOnceErr)
|
return "", nil, errors.WithStack(c.getProtocolOnceErr)
|
||||||
|
|
|
@ -15,6 +15,22 @@ func (r *Registry) Register(identifier Identifier, factory ProtocolFactory) {
|
||||||
r.protocols[identifier] = factory
|
r.protocols[identifier] = factory
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Registry) Get(identifier Identifier, funcs ...ProtocolOptionFunc) (Protocol, error) {
|
||||||
|
factory, exists := r.protocols[identifier]
|
||||||
|
if !exists {
|
||||||
|
return nil, errors.WithStack(ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
protocolOpts := NewProtocolOptions(funcs...)
|
||||||
|
|
||||||
|
proto, err := factory(protocolOpts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return proto, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Registry) Availables(ctx context.Context, addr string, timeout time.Duration, funcs ...ProtocolOptionFunc) ([]Protocol, error) {
|
func (r *Registry) Availables(ctx context.Context, addr string, timeout time.Duration, funcs ...ProtocolOptionFunc) ([]Protocol, error) {
|
||||||
availables := make([]Protocol, 0)
|
availables := make([]Protocol, 0)
|
||||||
protocolOpts := NewProtocolOptions(funcs...)
|
protocolOpts := NewProtocolOptions(funcs...)
|
||||||
|
|
Loading…
Reference in New Issue