feat: only check preferred/fallback protocols availability
This commit is contained in:
parent
83288967e3
commit
593c062993
|
@ -30,58 +30,61 @@ func (c *Client) Protocol(ctx context.Context) (protocol.Identifier, protocol.Op
|
|||
|
||||
func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) {
|
||||
c.getProtocolOnce.Do(func() {
|
||||
availables, err := c.opts.Protocols.Availables(
|
||||
ctx, c.addr, c.opts.AvailableTimeout,
|
||||
opts := []protocol.ProtocolOptionFunc{
|
||||
protocol.WithProtocolLogger(c.opts.Logger),
|
||||
protocol.WithProtocolDial(c.opts.Dial),
|
||||
)
|
||||
}
|
||||
|
||||
preferred, err := c.opts.Protocols.Get(c.opts.PreferredProtocol, opts...)
|
||||
if err != nil && c.opts.FallbackProtocol == "" {
|
||||
c.getProtocolOnceErr = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
|
||||
if preferred != nil {
|
||||
preferredCtx, cancel := context.WithTimeout(ctx, c.opts.AvailableTimeout)
|
||||
defer cancel()
|
||||
|
||||
ok, err := preferred.Available(preferredCtx, c.addr)
|
||||
if err != nil && c.opts.FallbackProtocol == "" {
|
||||
c.getProtocolOnceErr = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
|
||||
if ok {
|
||||
c.proto = preferred.Identifier()
|
||||
c.ops = preferred.Operations(c.addr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if c.opts.FallbackProtocol == "" {
|
||||
c.getProtocolOnceErr = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
|
||||
fallback, err := c.opts.Protocols.Get(c.opts.FallbackProtocol, opts...)
|
||||
if err != nil {
|
||||
c.getProtocolOnceErr = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
|
||||
var preferred protocol.Protocol
|
||||
var fallback protocol.Protocol
|
||||
fallbackCtx, cancel := context.WithTimeout(ctx, c.opts.AvailableTimeout)
|
||||
defer cancel()
|
||||
|
||||
for _, proto := range availables {
|
||||
if proto.Identifier() == c.opts.FallbackProtocol {
|
||||
fallback = proto
|
||||
}
|
||||
|
||||
if proto.Identifier() == c.opts.PreferredProtocol {
|
||||
preferred = proto
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if preferred != nil {
|
||||
c.proto = preferred.Identifier()
|
||||
c.ops = preferred.Operations(c.addr)
|
||||
ok, err := fallback.Available(fallbackCtx, c.addr)
|
||||
if err != nil && c.opts.FallbackProtocol == "" {
|
||||
c.getProtocolOnceErr = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
|
||||
if fallback != nil {
|
||||
if ok {
|
||||
c.proto = fallback.Identifier()
|
||||
c.ops = fallback.Operations(c.addr)
|
||||
return
|
||||
}
|
||||
|
||||
for _, proto := range availables {
|
||||
if proto.Identifier() != c.opts.FallbackProtocol {
|
||||
continue
|
||||
}
|
||||
|
||||
fallback = proto
|
||||
break
|
||||
}
|
||||
|
||||
if fallback == nil {
|
||||
c.getProtocolOnceErr = errors.Errorf("neither preferred protocol '%v' or fallback '%v' are available", c.opts.PreferredProtocol, c.opts.FallbackProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
c.proto = fallback.Identifier()
|
||||
c.ops = fallback.Operations(c.addr)
|
||||
})
|
||||
if c.getProtocolOnceErr != nil {
|
||||
return "", nil, errors.WithStack(c.getProtocolOnceErr)
|
||||
|
|
|
@ -15,6 +15,22 @@ func (r *Registry) Register(identifier Identifier, factory ProtocolFactory) {
|
|||
r.protocols[identifier] = factory
|
||||
}
|
||||
|
||||
func (r *Registry) Get(identifier Identifier, funcs ...ProtocolOptionFunc) (Protocol, error) {
|
||||
factory, exists := r.protocols[identifier]
|
||||
if !exists {
|
||||
return nil, errors.WithStack(ErrNotFound)
|
||||
}
|
||||
|
||||
protocolOpts := NewProtocolOptions(funcs...)
|
||||
|
||||
proto, err := factory(protocolOpts)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return proto, nil
|
||||
}
|
||||
|
||||
func (r *Registry) Availables(ctx context.Context, addr string, timeout time.Duration, funcs ...ProtocolOptionFunc) ([]Protocol, error) {
|
||||
availables := make([]Protocol, 0)
|
||||
protocolOpts := NewProtocolOptions(funcs...)
|
||||
|
|
Loading…
Reference in New Issue