feat(v2): allow override of dial func

This commit is contained in:
wpetit 2024-08-05 18:10:19 +02:00
parent b976bde363
commit ebb516b02c
13 changed files with 151 additions and 16 deletions

2
go.mod
View File

@ -3,7 +3,7 @@ module forge.cadoles.com/cadoles/go-emlid
go 1.22.5 go 1.22.5
require ( require (
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95 forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20240805155359-f54949ba3a46
github.com/Masterminds/semver/v3 v3.2.1 github.com/Masterminds/semver/v3 v3.2.1
github.com/davecgh/go-spew v1.1.1 github.com/davecgh/go-spew v1.1.1
github.com/grandcat/zeroconf v1.0.0 github.com/grandcat/zeroconf v1.0.0

2
go.sum
View File

@ -1,5 +1,7 @@
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95 h1:o3G5+9RjczCK1xAYFaRMknk1kY9Ule6PNfiW6N6hEpg= forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95 h1:o3G5+9RjczCK1xAYFaRMknk1kY9Ule6PNfiW6N6hEpg=
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95/go.mod h1:I6kYOFWNkFlNeQLI7ZqfTRz4NdPHZxX0Bzizmzgchs0= forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95/go.mod h1:I6kYOFWNkFlNeQLI7ZqfTRz4NdPHZxX0Bzizmzgchs0=
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20240805155359-f54949ba3a46 h1:vLTYHA4+pYeI9mZvCMrc29AmnNjeGEpEG1mTwtCOoDI=
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20240805155359-f54949ba3a46/go.mod h1:bT+HWia42VRX1TzTUlEM645tPJEOtsEdzlKBiEqVchY=
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=

View File

@ -30,7 +30,11 @@ func (c *Client) Protocol(ctx context.Context) (protocol.Identifier, protocol.Op
func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) { func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) {
c.getProtocolOnce.Do(func() { c.getProtocolOnce.Do(func() {
availables, err := c.opts.Protocols.Availables(ctx, c.addr, c.opts.AvailableTimeout, protocol.WithProtocolLogger(c.opts.Logger)) availables, err := c.opts.Protocols.Availables(
ctx, c.addr, c.opts.AvailableTimeout,
protocol.WithProtocolLogger(c.opts.Logger),
protocol.WithProtocolDial(c.opts.Dial),
)
if err != nil { if err != nil {
c.getProtocolOnceErr = errors.WithStack(err) c.getProtocolOnceErr = errors.WithStack(err)
return return

View File

@ -16,6 +16,7 @@ type Options struct {
FallbackProtocol protocol.Identifier FallbackProtocol protocol.Identifier
AvailableTimeout time.Duration AvailableTimeout time.Duration
Logger logger.Logger Logger logger.Logger
Dial protocol.DialFunc
} }
type OptionFunc func(opts *Options) type OptionFunc func(opts *Options)
@ -27,6 +28,7 @@ func NewOptions(funcs ...OptionFunc) *Options {
Protocols: protocol.DefaultRegistry(), Protocols: protocol.DefaultRegistry(),
AvailableTimeout: 5 * time.Second, AvailableTimeout: 5 * time.Second,
Logger: slog.Default(), Logger: slog.Default(),
Dial: protocol.DefaultDialFunc,
} }
for _, fn := range funcs { for _, fn := range funcs {
@ -65,3 +67,9 @@ func WithAvailableTimeout(timeout time.Duration) OptionFunc {
opts.AvailableTimeout = timeout opts.AvailableTimeout = timeout
} }
} }
func WithDial(dial protocol.DialFunc) OptionFunc {
return func(opts *Options) {
opts.Dial = dial
}
}

View File

@ -3,12 +3,16 @@ package protocol
import ( import (
"context" "context"
"log/slog" "log/slog"
"net"
"forge.cadoles.com/cadoles/go-emlid/reach/client/logger" "forge.cadoles.com/cadoles/go-emlid/reach/client/logger"
"github.com/pkg/errors"
) )
type Identifier string type Identifier string
type DialFunc func(network string, addr string) (net.Conn, error)
type Protocol interface { type Protocol interface {
Identifier() Identifier Identifier() Identifier
Available(ctx context.Context, addr string) (bool, error) Available(ctx context.Context, addr string) (bool, error)
@ -17,6 +21,16 @@ type Protocol interface {
type ProtocolOptions struct { type ProtocolOptions struct {
Logger logger.Logger Logger logger.Logger
Dial DialFunc
}
var DefaultDialFunc = func(network, addr string) (net.Conn, error) {
conn, err := net.Dial(network, addr)
if err != nil {
return nil, errors.WithStack(err)
}
return conn, nil
} }
type ProtocolFactory func(opts *ProtocolOptions) (Protocol, error) type ProtocolFactory func(opts *ProtocolOptions) (Protocol, error)
@ -40,3 +54,9 @@ func WithProtocolLogger(logger logger.Logger) ProtocolOptionFunc {
opts.Logger = logger opts.Logger = logger
} }
} }
func WithProtocolDial(dial DialFunc) ProtocolOptionFunc {
return func(opts *ProtocolOptions) {
opts.Dial = dial
}
}

View File

@ -49,7 +49,7 @@ func (o *Operations) Connect(ctx context.Context) error {
o.client.Close() o.client.Close()
} }
endpoint, err := socketio.EndpointFromHAddr(o.addr) endpoint, err := socketio.EndpointFromAddr(o.addr)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }

View File

@ -20,7 +20,9 @@ func (o *Operations) GetJSON(path string, dst any) error {
var res *http.Response var res *http.Response
url := o.getURL(path) url := o.getURL(path)
res, err := http.Get(url) client := o.getHTTPClient()
res, err := client.Get(url)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
@ -50,6 +52,7 @@ func (o *Operations) GetJSON(path string, dst any) error {
} }
func (o *Operations) PostJSON(path string, data any, dst any) error { func (o *Operations) PostJSON(path string, data any, dst any) error {
var res *http.Response var res *http.Response
var buf bytes.Buffer var buf bytes.Buffer
@ -60,7 +63,9 @@ func (o *Operations) PostJSON(path string, data any, dst any) error {
} }
url := o.getURL(path) url := o.getURL(path)
res, err := http.Post(url, "application/json", &buf) client := o.getHTTPClient()
res, err := client.Post(url, "application/json", &buf)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
@ -89,6 +94,18 @@ func (o *Operations) PostJSON(path string, data any, dst any) error {
return nil return nil
} }
func (o *Operations) getHTTPClient() *http.Client {
o.getClientOnce.Do(func() {
o.httpClient = &http.Client{
Transport: &http.Transport{
Dial: o.dial,
},
}
})
return o.httpClient
}
func (o *Operations) PostBaseCoordinates(ctx context.Context, base *model.Base) (*model.Base, error) { func (o *Operations) PostBaseCoordinates(ctx context.Context, base *model.Base) (*model.Base, error) {
var updated model.Base var updated model.Base

View File

@ -2,6 +2,7 @@ package v2
import ( import (
"context" "context"
"net/http"
"sync" "sync"
"forge.cadoles.com/cadoles/go-emlid/reach/client/logger" "forge.cadoles.com/cadoles/go-emlid/reach/client/logger"
@ -17,6 +18,10 @@ type Operations struct {
client *socketio.Client client *socketio.Client
mutex sync.RWMutex mutex sync.RWMutex
logger logger.Logger logger logger.Logger
dial protocol.DialFunc
getClientOnce sync.Once
httpClient *http.Client
} }
// Reboot implements protocol.Operations. // Reboot implements protocol.Operations.
@ -161,12 +166,12 @@ func (o *Operations) Connect(ctx context.Context) error {
o.client.Close() o.client.Close()
} }
endpoint, err := socketio.EndpointFromHAddr(o.addr) endpoint, err := socketio.EndpointFromAddr(o.addr)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
client := socketio.NewClient(endpoint) client := socketio.NewClient(endpoint, socketio.WithDialFunc(socketio.DialFunc(o.dial)))
if err := client.Connect(); err != nil { if err := client.Connect(); err != nil {
return errors.WithStack(err) return errors.WithStack(err)

View File

@ -15,6 +15,7 @@ const compatibleVersionConstraint = ">= 32"
type Protocol struct { type Protocol struct {
logger logger.Logger logger logger.Logger
dial protocol.DialFunc
} }
// Available implements protocol.Protocol. // Available implements protocol.Protocol.
@ -50,7 +51,11 @@ func (p *Protocol) Identifier() protocol.Identifier {
// Operations implements protocol.Protocol. // Operations implements protocol.Protocol.
func (p *Protocol) Operations(addr string) protocol.Operations { func (p *Protocol) Operations(addr string) protocol.Operations {
return &Operations{addr: addr, logger: p.logger} return &Operations{
dial: p.dial,
addr: addr,
logger: p.logger,
}
} }
var _ protocol.Protocol = &Protocol{} var _ protocol.Protocol = &Protocol{}

View File

@ -32,12 +32,15 @@ func (c *Client) Connect() error {
wg.Add(1) wg.Add(1)
transport := &transport.WebsocketTransport{ transport := &Transport{
dial: c.opts.DialFunc,
ws: &transport.WebsocketTransport{
PingInterval: c.opts.PingInterval, PingInterval: c.opts.PingInterval,
PingTimeout: c.opts.PingTimeout, PingTimeout: c.opts.PingTimeout,
ReceiveTimeout: c.opts.ReceiveTimeout, ReceiveTimeout: c.opts.ReceiveTimeout,
SendTimeout: c.opts.SendTimeout, SendTimeout: c.opts.SendTimeout,
BufferSize: c.opts.BufferSize, BufferSize: c.opts.BufferSize,
},
} }
conn, err := gosocketio.Dial(c.endpoint, transport) conn, err := gosocketio.Dial(c.endpoint, transport)

View File

@ -9,7 +9,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func EndpointFromHAddr(addr string) (string, error) { func EndpointFromAddr(addr string) (string, error) {
host, rawPort, err := net.SplitHostPort(addr) host, rawPort, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
var addrErr *net.AddrError var addrErr *net.AddrError

View File

@ -1,6 +1,13 @@
package socketio package socketio
import "time" import (
"net"
"time"
"github.com/pkg/errors"
)
type DialFunc func(network, addr string) (net.Conn, error)
type Options struct { type Options struct {
PingInterval time.Duration PingInterval time.Duration
@ -8,6 +15,7 @@ type Options struct {
ReceiveTimeout time.Duration ReceiveTimeout time.Duration
SendTimeout time.Duration SendTimeout time.Duration
BufferSize int BufferSize int
DialFunc DialFunc
} }
type OptionFunc func(opts *Options) type OptionFunc func(opts *Options)
@ -19,6 +27,7 @@ func NewOptions(funcs ...OptionFunc) *Options {
ReceiveTimeout: 60 * time.Second, ReceiveTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second, SendTimeout: 60 * time.Second,
BufferSize: 1024 * 32, BufferSize: 1024 * 32,
DialFunc: DefaultDialFunc,
} }
for _, fn := range funcs { for _, fn := range funcs {
fn(opts) fn(opts)
@ -60,3 +69,19 @@ func WithBufferSize(size int) OptionFunc {
opts.BufferSize = size opts.BufferSize = size
} }
} }
var DefaultDialFunc = func(network, addr string) (net.Conn, error) {
conn, err := net.Dial(network, addr)
if err != nil {
return nil, errors.WithStack(err)
}
return conn, nil
}
// WithDialFunc configures the client to use the given dial func
func WithDialFunc(dial DialFunc) OptionFunc {
return func(opts *Options) {
opts.DialFunc = dial
}
}

View File

@ -0,0 +1,46 @@
package socketio
import (
"net"
"net/http"
"forge.cadoles.com/Pyxis/golang-socketio/transport"
"github.com/gorilla/websocket"
)
type Transport struct {
dial DialFunc
ws *transport.WebsocketTransport
}
// Connect implements transport.Transport.
func (t *Transport) Connect(url string) (conn transport.Connection, err error) {
if t.dial == nil {
return t.ws.Connect(url)
} else {
dialer := websocket.Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
return t.dial(network, addr)
},
}
socket, _, err := dialer.Dial(url, t.ws.RequestHeader)
if err != nil {
return nil, err
}
return transport.NewWebsocketConnection(socket, t.ws), nil
}
}
// HandleConnection implements transport.Transport.
func (t *Transport) HandleConnection(w http.ResponseWriter, r *http.Request) (conn transport.Connection, err error) {
return t.ws.HandleConnection(w, r)
}
// Serve implements transport.Transport.
func (t *Transport) Serve(w http.ResponseWriter, r *http.Request) {
t.ws.Serve(w, r)
}
var _ transport.Transport = &Transport{}