From 593c06299311f56efdbb5c668bcb788a70e20ec2 Mon Sep 17 00:00:00 2001 From: William Petit Date: Tue, 13 Aug 2024 15:21:53 +0200 Subject: [PATCH] feat: only check preferred/fallback protocols availability --- reach/client/client.go | 75 ++++++++++++++++--------------- reach/client/protocol/registry.go | 16 +++++++ 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/reach/client/client.go b/reach/client/client.go index 1ce01cc..ee667e1 100644 --- a/reach/client/client.go +++ b/reach/client/client.go @@ -30,58 +30,61 @@ func (c *Client) Protocol(ctx context.Context) (protocol.Identifier, protocol.Op func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) { c.getProtocolOnce.Do(func() { - availables, err := c.opts.Protocols.Availables( - ctx, c.addr, c.opts.AvailableTimeout, + opts := []protocol.ProtocolOptionFunc{ protocol.WithProtocolLogger(c.opts.Logger), protocol.WithProtocolDial(c.opts.Dial), - ) + } + + preferred, err := c.opts.Protocols.Get(c.opts.PreferredProtocol, opts...) + if err != nil && c.opts.FallbackProtocol == "" { + c.getProtocolOnceErr = errors.WithStack(err) + return + } + + if preferred != nil { + preferredCtx, cancel := context.WithTimeout(ctx, c.opts.AvailableTimeout) + defer cancel() + + ok, err := preferred.Available(preferredCtx, c.addr) + if err != nil && c.opts.FallbackProtocol == "" { + c.getProtocolOnceErr = errors.WithStack(err) + return + } + + if ok { + c.proto = preferred.Identifier() + c.ops = preferred.Operations(c.addr) + return + } + } + + if c.opts.FallbackProtocol == "" { + c.getProtocolOnceErr = errors.WithStack(err) + return + } + + fallback, err := c.opts.Protocols.Get(c.opts.FallbackProtocol, opts...) if err != nil { c.getProtocolOnceErr = errors.WithStack(err) return } - var preferred protocol.Protocol - var fallback protocol.Protocol + fallbackCtx, cancel := context.WithTimeout(ctx, c.opts.AvailableTimeout) + defer cancel() - for _, proto := range availables { - if proto.Identifier() == c.opts.FallbackProtocol { - fallback = proto - } - - if proto.Identifier() == c.opts.PreferredProtocol { - preferred = proto - break - } - } - - if preferred != nil { - c.proto = preferred.Identifier() - c.ops = preferred.Operations(c.addr) + ok, err := fallback.Available(fallbackCtx, c.addr) + if err != nil && c.opts.FallbackProtocol == "" { + c.getProtocolOnceErr = errors.WithStack(err) return } - if fallback != nil { + if ok { c.proto = fallback.Identifier() c.ops = fallback.Operations(c.addr) return } - for _, proto := range availables { - if proto.Identifier() != c.opts.FallbackProtocol { - continue - } - - fallback = proto - break - } - - if fallback == nil { - c.getProtocolOnceErr = errors.Errorf("neither preferred protocol '%v' or fallback '%v' are available", c.opts.PreferredProtocol, c.opts.FallbackProtocol) - return - } - - c.proto = fallback.Identifier() - c.ops = fallback.Operations(c.addr) + c.getProtocolOnceErr = errors.Errorf("neither preferred protocol '%v' or fallback '%v' are available", c.opts.PreferredProtocol, c.opts.FallbackProtocol) }) if c.getProtocolOnceErr != nil { return "", nil, errors.WithStack(c.getProtocolOnceErr) diff --git a/reach/client/protocol/registry.go b/reach/client/protocol/registry.go index 7edab13..d8afc48 100644 --- a/reach/client/protocol/registry.go +++ b/reach/client/protocol/registry.go @@ -15,6 +15,22 @@ func (r *Registry) Register(identifier Identifier, factory ProtocolFactory) { r.protocols[identifier] = factory } +func (r *Registry) Get(identifier Identifier, funcs ...ProtocolOptionFunc) (Protocol, error) { + factory, exists := r.protocols[identifier] + if !exists { + return nil, errors.WithStack(ErrNotFound) + } + + protocolOpts := NewProtocolOptions(funcs...) + + proto, err := factory(protocolOpts) + if err != nil { + return nil, errors.WithStack(err) + } + + return proto, nil +} + func (r *Registry) Availables(ctx context.Context, addr string, timeout time.Duration, funcs ...ProtocolOptionFunc) ([]Protocol, error) { availables := make([]Protocol, 0) protocolOpts := NewProtocolOptions(funcs...)