From 6842d4d88ac684e0b35db1e6f56e5e09342fb8cc Mon Sep 17 00:00:00 2001 From: William Petit Date: Mon, 5 Aug 2024 18:10:19 +0200 Subject: [PATCH] feat(protocol): allow override of dial func --- go.mod | 2 +- go.sum | 2 ++ reach/client/client.go | 6 +++- reach/client/options.go | 8 +++++ reach/client/protocol/protocol.go | 20 +++++++++++ reach/client/protocol/v1/init.go | 1 + reach/client/protocol/v1/operations.go | 6 ++-- reach/client/protocol/v1/protocol.go | 3 +- reach/client/protocol/v2/init.go | 1 + reach/client/protocol/v2/internal.go | 21 ++++++++++-- reach/client/protocol/v2/operations.go | 9 +++-- reach/client/protocol/v2/protocol.go | 7 +++- reach/client/socketio/client.go | 15 +++++---- reach/client/socketio/endpoint.go | 2 +- reach/client/socketio/options.go | 27 ++++++++++++++- reach/client/socketio/transport.go | 46 ++++++++++++++++++++++++++ 16 files changed, 158 insertions(+), 18 deletions(-) create mode 100644 reach/client/socketio/transport.go diff --git a/go.mod b/go.mod index 5038c2f..631b4df 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module forge.cadoles.com/cadoles/go-emlid go 1.22.5 require ( - forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95 + forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20240805155359-f54949ba3a46 github.com/Masterminds/semver/v3 v3.2.1 github.com/davecgh/go-spew v1.1.1 github.com/grandcat/zeroconf v1.0.0 diff --git a/go.sum b/go.sum index ca67ce8..0c5f228 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95 h1:o3G5+9RjczCK1xAYFaRMknk1kY9Ule6PNfiW6N6hEpg= forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95/go.mod h1:I6kYOFWNkFlNeQLI7ZqfTRz4NdPHZxX0Bzizmzgchs0= +forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20240805155359-f54949ba3a46 h1:vLTYHA4+pYeI9mZvCMrc29AmnNjeGEpEG1mTwtCOoDI= +forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20240805155359-f54949ba3a46/go.mod h1:bT+HWia42VRX1TzTUlEM645tPJEOtsEdzlKBiEqVchY= github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= diff --git a/reach/client/client.go b/reach/client/client.go index 33af6d3..1ce01cc 100644 --- a/reach/client/client.go +++ b/reach/client/client.go @@ -30,7 +30,11 @@ func (c *Client) Protocol(ctx context.Context) (protocol.Identifier, protocol.Op func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) { c.getProtocolOnce.Do(func() { - availables, err := c.opts.Protocols.Availables(ctx, c.addr, c.opts.AvailableTimeout, protocol.WithProtocolLogger(c.opts.Logger)) + availables, err := c.opts.Protocols.Availables( + ctx, c.addr, c.opts.AvailableTimeout, + protocol.WithProtocolLogger(c.opts.Logger), + protocol.WithProtocolDial(c.opts.Dial), + ) if err != nil { c.getProtocolOnceErr = errors.WithStack(err) return diff --git a/reach/client/options.go b/reach/client/options.go index 6d46539..6878bef 100644 --- a/reach/client/options.go +++ b/reach/client/options.go @@ -16,6 +16,7 @@ type Options struct { FallbackProtocol protocol.Identifier AvailableTimeout time.Duration Logger logger.Logger + Dial protocol.DialFunc } type OptionFunc func(opts *Options) @@ -27,6 +28,7 @@ func NewOptions(funcs ...OptionFunc) *Options { Protocols: protocol.DefaultRegistry(), AvailableTimeout: 5 * time.Second, Logger: slog.Default(), + Dial: protocol.DefaultDialFunc, } for _, fn := range funcs { @@ -65,3 +67,9 @@ func WithAvailableTimeout(timeout time.Duration) OptionFunc { opts.AvailableTimeout = timeout } } + +func WithDial(dial protocol.DialFunc) OptionFunc { + return func(opts *Options) { + opts.Dial = dial + } +} diff --git a/reach/client/protocol/protocol.go b/reach/client/protocol/protocol.go index 20f8aa1..0d413a3 100644 --- a/reach/client/protocol/protocol.go +++ b/reach/client/protocol/protocol.go @@ -3,12 +3,16 @@ package protocol import ( "context" "log/slog" + "net" "forge.cadoles.com/cadoles/go-emlid/reach/client/logger" + "github.com/pkg/errors" ) type Identifier string +type DialFunc func(network string, addr string) (net.Conn, error) + type Protocol interface { Identifier() Identifier Available(ctx context.Context, addr string) (bool, error) @@ -17,6 +21,16 @@ type Protocol interface { type ProtocolOptions struct { Logger logger.Logger + Dial DialFunc +} + +var DefaultDialFunc = func(network, addr string) (net.Conn, error) { + conn, err := net.Dial(network, addr) + if err != nil { + return nil, errors.WithStack(err) + } + + return conn, nil } type ProtocolFactory func(opts *ProtocolOptions) (Protocol, error) @@ -40,3 +54,9 @@ func WithProtocolLogger(logger logger.Logger) ProtocolOptionFunc { opts.Logger = logger } } + +func WithProtocolDial(dial DialFunc) ProtocolOptionFunc { + return func(opts *ProtocolOptions) { + opts.Dial = dial + } +} diff --git a/reach/client/protocol/v1/init.go b/reach/client/protocol/v1/init.go index c2f1488..522bb43 100644 --- a/reach/client/protocol/v1/init.go +++ b/reach/client/protocol/v1/init.go @@ -6,6 +6,7 @@ func init() { protocol.Register(Identifier, func(opts *protocol.ProtocolOptions) (protocol.Protocol, error) { return &Protocol{ logger: opts.Logger, + dial: opts.Dial, }, nil }) } diff --git a/reach/client/protocol/v1/operations.go b/reach/client/protocol/v1/operations.go index 7face3e..2e3689e 100644 --- a/reach/client/protocol/v1/operations.go +++ b/reach/client/protocol/v1/operations.go @@ -18,6 +18,8 @@ type Operations struct { client *socketio.Client mutex sync.RWMutex logger logger.Logger + + dial protocol.DialFunc } // Close implements protocol.Operations. @@ -49,12 +51,12 @@ func (o *Operations) Connect(ctx context.Context) error { o.client.Close() } - endpoint, err := socketio.EndpointFromHAddr(o.addr) + endpoint, err := socketio.EndpointFromAddr(o.addr) if err != nil { return errors.WithStack(err) } - client := socketio.NewClient(endpoint) + client := socketio.NewClient(endpoint, socketio.WithDialFunc(socketio.DialFunc(o.dial))) o.logger.Debug("connecting", logger.Attr("endpoint", endpoint)) diff --git a/reach/client/protocol/v1/protocol.go b/reach/client/protocol/v1/protocol.go index 683c303..c78dcf0 100644 --- a/reach/client/protocol/v1/protocol.go +++ b/reach/client/protocol/v1/protocol.go @@ -16,6 +16,7 @@ const compatibleVersionConstraint = "^2.24" type Protocol struct { logger logger.Logger + dial protocol.DialFunc } // Available implements protocol.Protocol. @@ -59,7 +60,7 @@ func (p *Protocol) Identifier() protocol.Identifier { // Operations implements protocol.Protocol. func (p *Protocol) Operations(addr string) protocol.Operations { - return &Operations{addr: addr, logger: p.logger} + return &Operations{addr: addr, logger: p.logger, dial: p.dial} } var _ protocol.Protocol = &Protocol{} diff --git a/reach/client/protocol/v2/init.go b/reach/client/protocol/v2/init.go index 0569767..fef579d 100644 --- a/reach/client/protocol/v2/init.go +++ b/reach/client/protocol/v2/init.go @@ -6,6 +6,7 @@ func init() { protocol.Register(Identifier, func(opts *protocol.ProtocolOptions) (protocol.Protocol, error) { return &Protocol{ logger: opts.Logger, + dial: opts.Dial, }, nil }) } diff --git a/reach/client/protocol/v2/internal.go b/reach/client/protocol/v2/internal.go index 8c9dc04..8367bb3 100644 --- a/reach/client/protocol/v2/internal.go +++ b/reach/client/protocol/v2/internal.go @@ -20,7 +20,9 @@ func (o *Operations) GetJSON(path string, dst any) error { var res *http.Response url := o.getURL(path) - res, err := http.Get(url) + client := o.getHTTPClient() + + res, err := client.Get(url) if err != nil { return errors.WithStack(err) } @@ -50,6 +52,7 @@ func (o *Operations) GetJSON(path string, dst any) error { } func (o *Operations) PostJSON(path string, data any, dst any) error { + var res *http.Response var buf bytes.Buffer @@ -60,7 +63,9 @@ func (o *Operations) PostJSON(path string, data any, dst any) error { } url := o.getURL(path) - res, err := http.Post(url, "application/json", &buf) + client := o.getHTTPClient() + + res, err := client.Post(url, "application/json", &buf) if err != nil { return errors.WithStack(err) } @@ -89,6 +94,18 @@ func (o *Operations) PostJSON(path string, data any, dst any) error { return nil } +func (o *Operations) getHTTPClient() *http.Client { + o.getClientOnce.Do(func() { + o.httpClient = &http.Client{ + Transport: &http.Transport{ + Dial: o.dial, + }, + } + }) + + return o.httpClient +} + func (o *Operations) PostBaseCoordinates(ctx context.Context, base *model.Base) (*model.Base, error) { var updated model.Base diff --git a/reach/client/protocol/v2/operations.go b/reach/client/protocol/v2/operations.go index 38cb9a2..18ea4c5 100644 --- a/reach/client/protocol/v2/operations.go +++ b/reach/client/protocol/v2/operations.go @@ -2,6 +2,7 @@ package v2 import ( "context" + "net/http" "sync" "forge.cadoles.com/cadoles/go-emlid/reach/client/logger" @@ -17,6 +18,10 @@ type Operations struct { client *socketio.Client mutex sync.RWMutex logger logger.Logger + dial protocol.DialFunc + + getClientOnce sync.Once + httpClient *http.Client } // Reboot implements protocol.Operations. @@ -161,12 +166,12 @@ func (o *Operations) Connect(ctx context.Context) error { o.client.Close() } - endpoint, err := socketio.EndpointFromHAddr(o.addr) + endpoint, err := socketio.EndpointFromAddr(o.addr) if err != nil { return errors.WithStack(err) } - client := socketio.NewClient(endpoint) + client := socketio.NewClient(endpoint, socketio.WithDialFunc(socketio.DialFunc(o.dial))) if err := client.Connect(); err != nil { return errors.WithStack(err) diff --git a/reach/client/protocol/v2/protocol.go b/reach/client/protocol/v2/protocol.go index 033c7c2..ecfb91d 100644 --- a/reach/client/protocol/v2/protocol.go +++ b/reach/client/protocol/v2/protocol.go @@ -15,6 +15,7 @@ const compatibleVersionConstraint = ">= 32" type Protocol struct { logger logger.Logger + dial protocol.DialFunc } // Available implements protocol.Protocol. @@ -50,7 +51,11 @@ func (p *Protocol) Identifier() protocol.Identifier { // Operations implements protocol.Protocol. func (p *Protocol) Operations(addr string) protocol.Operations { - return &Operations{addr: addr, logger: p.logger} + return &Operations{ + dial: p.dial, + addr: addr, + logger: p.logger, + } } var _ protocol.Protocol = &Protocol{} diff --git a/reach/client/socketio/client.go b/reach/client/socketio/client.go index 43bc6eb..96a9bb5 100644 --- a/reach/client/socketio/client.go +++ b/reach/client/socketio/client.go @@ -32,12 +32,15 @@ func (c *Client) Connect() error { wg.Add(1) - transport := &transport.WebsocketTransport{ - PingInterval: c.opts.PingInterval, - PingTimeout: c.opts.PingTimeout, - ReceiveTimeout: c.opts.ReceiveTimeout, - SendTimeout: c.opts.SendTimeout, - BufferSize: c.opts.BufferSize, + transport := &Transport{ + dial: c.opts.DialFunc, + ws: &transport.WebsocketTransport{ + PingInterval: c.opts.PingInterval, + PingTimeout: c.opts.PingTimeout, + ReceiveTimeout: c.opts.ReceiveTimeout, + SendTimeout: c.opts.SendTimeout, + BufferSize: c.opts.BufferSize, + }, } conn, err := gosocketio.Dial(c.endpoint, transport) diff --git a/reach/client/socketio/endpoint.go b/reach/client/socketio/endpoint.go index 33d8bec..f03fd68 100644 --- a/reach/client/socketio/endpoint.go +++ b/reach/client/socketio/endpoint.go @@ -9,7 +9,7 @@ import ( "github.com/pkg/errors" ) -func EndpointFromHAddr(addr string) (string, error) { +func EndpointFromAddr(addr string) (string, error) { host, rawPort, err := net.SplitHostPort(addr) if err != nil { var addrErr *net.AddrError diff --git a/reach/client/socketio/options.go b/reach/client/socketio/options.go index c7d4bb2..97d0d7e 100644 --- a/reach/client/socketio/options.go +++ b/reach/client/socketio/options.go @@ -1,6 +1,13 @@ package socketio -import "time" +import ( + "net" + "time" + + "github.com/pkg/errors" +) + +type DialFunc func(network, addr string) (net.Conn, error) type Options struct { PingInterval time.Duration @@ -8,6 +15,7 @@ type Options struct { ReceiveTimeout time.Duration SendTimeout time.Duration BufferSize int + DialFunc DialFunc } type OptionFunc func(opts *Options) @@ -19,6 +27,7 @@ func NewOptions(funcs ...OptionFunc) *Options { ReceiveTimeout: 60 * time.Second, SendTimeout: 60 * time.Second, BufferSize: 1024 * 32, + DialFunc: DefaultDialFunc, } for _, fn := range funcs { fn(opts) @@ -60,3 +69,19 @@ func WithBufferSize(size int) OptionFunc { opts.BufferSize = size } } + +var DefaultDialFunc = func(network, addr string) (net.Conn, error) { + conn, err := net.Dial(network, addr) + if err != nil { + return nil, errors.WithStack(err) + } + + return conn, nil +} + +// WithDialFunc configures the client to use the given dial func +func WithDialFunc(dial DialFunc) OptionFunc { + return func(opts *Options) { + opts.DialFunc = dial + } +} diff --git a/reach/client/socketio/transport.go b/reach/client/socketio/transport.go new file mode 100644 index 0000000..f868c10 --- /dev/null +++ b/reach/client/socketio/transport.go @@ -0,0 +1,46 @@ +package socketio + +import ( + "net" + "net/http" + + "forge.cadoles.com/Pyxis/golang-socketio/transport" + "github.com/gorilla/websocket" +) + +type Transport struct { + dial DialFunc + ws *transport.WebsocketTransport +} + +// Connect implements transport.Transport. +func (t *Transport) Connect(url string) (conn transport.Connection, err error) { + if t.dial == nil { + return t.ws.Connect(url) + } else { + dialer := websocket.Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + return t.dial(network, addr) + }, + } + + socket, _, err := dialer.Dial(url, t.ws.RequestHeader) + if err != nil { + return nil, err + } + + return transport.NewWebsocketConnection(socket, t.ws), nil + } +} + +// HandleConnection implements transport.Transport. +func (t *Transport) HandleConnection(w http.ResponseWriter, r *http.Request) (conn transport.Connection, err error) { + return t.ws.HandleConnection(w, r) +} + +// Serve implements transport.Transport. +func (t *Transport) Serve(w http.ResponseWriter, r *http.Request) { + t.ws.Serve(w, r) +} + +var _ transport.Transport = &Transport{}