feat(v2): allow override of dial func
This commit is contained in:
parent
b976bde363
commit
ebb516b02c
2
go.mod
2
go.mod
|
@ -3,7 +3,7 @@ module forge.cadoles.com/cadoles/go-emlid
|
||||||
go 1.22.5
|
go 1.22.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95
|
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20240805155359-f54949ba3a46
|
||||||
github.com/Masterminds/semver/v3 v3.2.1
|
github.com/Masterminds/semver/v3 v3.2.1
|
||||||
github.com/davecgh/go-spew v1.1.1
|
github.com/davecgh/go-spew v1.1.1
|
||||||
github.com/grandcat/zeroconf v1.0.0
|
github.com/grandcat/zeroconf v1.0.0
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -1,5 +1,7 @@
|
||||||
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95 h1:o3G5+9RjczCK1xAYFaRMknk1kY9Ule6PNfiW6N6hEpg=
|
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95 h1:o3G5+9RjczCK1xAYFaRMknk1kY9Ule6PNfiW6N6hEpg=
|
||||||
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95/go.mod h1:I6kYOFWNkFlNeQLI7ZqfTRz4NdPHZxX0Bzizmzgchs0=
|
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20180919100209-bb857ced6b95/go.mod h1:I6kYOFWNkFlNeQLI7ZqfTRz4NdPHZxX0Bzizmzgchs0=
|
||||||
|
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20240805155359-f54949ba3a46 h1:vLTYHA4+pYeI9mZvCMrc29AmnNjeGEpEG1mTwtCOoDI=
|
||||||
|
forge.cadoles.com/Pyxis/golang-socketio v0.0.0-20240805155359-f54949ba3a46/go.mod h1:bT+HWia42VRX1TzTUlEM645tPJEOtsEdzlKBiEqVchY=
|
||||||
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
|
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
|
||||||
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
|
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
|
||||||
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
|
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
|
||||||
|
|
|
@ -30,7 +30,11 @@ func (c *Client) Protocol(ctx context.Context) (protocol.Identifier, protocol.Op
|
||||||
|
|
||||||
func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) {
|
func (c *Client) getProtocol(ctx context.Context) (protocol.Identifier, protocol.Operations, error) {
|
||||||
c.getProtocolOnce.Do(func() {
|
c.getProtocolOnce.Do(func() {
|
||||||
availables, err := c.opts.Protocols.Availables(ctx, c.addr, c.opts.AvailableTimeout, protocol.WithProtocolLogger(c.opts.Logger))
|
availables, err := c.opts.Protocols.Availables(
|
||||||
|
ctx, c.addr, c.opts.AvailableTimeout,
|
||||||
|
protocol.WithProtocolLogger(c.opts.Logger),
|
||||||
|
protocol.WithProtocolDial(c.opts.Dial),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.getProtocolOnceErr = errors.WithStack(err)
|
c.getProtocolOnceErr = errors.WithStack(err)
|
||||||
return
|
return
|
||||||
|
|
|
@ -16,6 +16,7 @@ type Options struct {
|
||||||
FallbackProtocol protocol.Identifier
|
FallbackProtocol protocol.Identifier
|
||||||
AvailableTimeout time.Duration
|
AvailableTimeout time.Duration
|
||||||
Logger logger.Logger
|
Logger logger.Logger
|
||||||
|
Dial protocol.DialFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
type OptionFunc func(opts *Options)
|
type OptionFunc func(opts *Options)
|
||||||
|
@ -27,6 +28,7 @@ func NewOptions(funcs ...OptionFunc) *Options {
|
||||||
Protocols: protocol.DefaultRegistry(),
|
Protocols: protocol.DefaultRegistry(),
|
||||||
AvailableTimeout: 5 * time.Second,
|
AvailableTimeout: 5 * time.Second,
|
||||||
Logger: slog.Default(),
|
Logger: slog.Default(),
|
||||||
|
Dial: protocol.DefaultDialFunc,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, fn := range funcs {
|
for _, fn := range funcs {
|
||||||
|
@ -65,3 +67,9 @@ func WithAvailableTimeout(timeout time.Duration) OptionFunc {
|
||||||
opts.AvailableTimeout = timeout
|
opts.AvailableTimeout = timeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithDial(dial protocol.DialFunc) OptionFunc {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.Dial = dial
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,12 +3,16 @@ package protocol
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
|
||||||
"forge.cadoles.com/cadoles/go-emlid/reach/client/logger"
|
"forge.cadoles.com/cadoles/go-emlid/reach/client/logger"
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Identifier string
|
type Identifier string
|
||||||
|
|
||||||
|
type DialFunc func(network string, addr string) (net.Conn, error)
|
||||||
|
|
||||||
type Protocol interface {
|
type Protocol interface {
|
||||||
Identifier() Identifier
|
Identifier() Identifier
|
||||||
Available(ctx context.Context, addr string) (bool, error)
|
Available(ctx context.Context, addr string) (bool, error)
|
||||||
|
@ -17,6 +21,16 @@ type Protocol interface {
|
||||||
|
|
||||||
type ProtocolOptions struct {
|
type ProtocolOptions struct {
|
||||||
Logger logger.Logger
|
Logger logger.Logger
|
||||||
|
Dial DialFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultDialFunc = func(network, addr string) (net.Conn, error) {
|
||||||
|
conn, err := net.Dial(network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProtocolFactory func(opts *ProtocolOptions) (Protocol, error)
|
type ProtocolFactory func(opts *ProtocolOptions) (Protocol, error)
|
||||||
|
@ -40,3 +54,9 @@ func WithProtocolLogger(logger logger.Logger) ProtocolOptionFunc {
|
||||||
opts.Logger = logger
|
opts.Logger = logger
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithProtocolDial(dial DialFunc) ProtocolOptionFunc {
|
||||||
|
return func(opts *ProtocolOptions) {
|
||||||
|
opts.Dial = dial
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@ func (o *Operations) Connect(ctx context.Context) error {
|
||||||
o.client.Close()
|
o.client.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
endpoint, err := socketio.EndpointFromHAddr(o.addr)
|
endpoint, err := socketio.EndpointFromAddr(o.addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,9 @@ func (o *Operations) GetJSON(path string, dst any) error {
|
||||||
var res *http.Response
|
var res *http.Response
|
||||||
|
|
||||||
url := o.getURL(path)
|
url := o.getURL(path)
|
||||||
res, err := http.Get(url)
|
client := o.getHTTPClient()
|
||||||
|
|
||||||
|
res, err := client.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
@ -50,6 +52,7 @@ func (o *Operations) GetJSON(path string, dst any) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Operations) PostJSON(path string, data any, dst any) error {
|
func (o *Operations) PostJSON(path string, data any, dst any) error {
|
||||||
|
|
||||||
var res *http.Response
|
var res *http.Response
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
@ -60,7 +63,9 @@ func (o *Operations) PostJSON(path string, data any, dst any) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
url := o.getURL(path)
|
url := o.getURL(path)
|
||||||
res, err := http.Post(url, "application/json", &buf)
|
client := o.getHTTPClient()
|
||||||
|
|
||||||
|
res, err := client.Post(url, "application/json", &buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
@ -89,6 +94,18 @@ func (o *Operations) PostJSON(path string, data any, dst any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Operations) getHTTPClient() *http.Client {
|
||||||
|
o.getClientOnce.Do(func() {
|
||||||
|
o.httpClient = &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Dial: o.dial,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return o.httpClient
|
||||||
|
}
|
||||||
|
|
||||||
func (o *Operations) PostBaseCoordinates(ctx context.Context, base *model.Base) (*model.Base, error) {
|
func (o *Operations) PostBaseCoordinates(ctx context.Context, base *model.Base) (*model.Base, error) {
|
||||||
var updated model.Base
|
var updated model.Base
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package v2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"forge.cadoles.com/cadoles/go-emlid/reach/client/logger"
|
"forge.cadoles.com/cadoles/go-emlid/reach/client/logger"
|
||||||
|
@ -17,6 +18,10 @@ type Operations struct {
|
||||||
client *socketio.Client
|
client *socketio.Client
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
|
dial protocol.DialFunc
|
||||||
|
|
||||||
|
getClientOnce sync.Once
|
||||||
|
httpClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reboot implements protocol.Operations.
|
// Reboot implements protocol.Operations.
|
||||||
|
@ -161,12 +166,12 @@ func (o *Operations) Connect(ctx context.Context) error {
|
||||||
o.client.Close()
|
o.client.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
endpoint, err := socketio.EndpointFromHAddr(o.addr)
|
endpoint, err := socketio.EndpointFromAddr(o.addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := socketio.NewClient(endpoint)
|
client := socketio.NewClient(endpoint, socketio.WithDialFunc(socketio.DialFunc(o.dial)))
|
||||||
|
|
||||||
if err := client.Connect(); err != nil {
|
if err := client.Connect(); err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
|
|
|
@ -15,6 +15,7 @@ const compatibleVersionConstraint = ">= 32"
|
||||||
|
|
||||||
type Protocol struct {
|
type Protocol struct {
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
|
dial protocol.DialFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Available implements protocol.Protocol.
|
// Available implements protocol.Protocol.
|
||||||
|
@ -50,7 +51,11 @@ func (p *Protocol) Identifier() protocol.Identifier {
|
||||||
|
|
||||||
// Operations implements protocol.Protocol.
|
// Operations implements protocol.Protocol.
|
||||||
func (p *Protocol) Operations(addr string) protocol.Operations {
|
func (p *Protocol) Operations(addr string) protocol.Operations {
|
||||||
return &Operations{addr: addr, logger: p.logger}
|
return &Operations{
|
||||||
|
dial: p.dial,
|
||||||
|
addr: addr,
|
||||||
|
logger: p.logger,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ protocol.Protocol = &Protocol{}
|
var _ protocol.Protocol = &Protocol{}
|
||||||
|
|
|
@ -32,12 +32,15 @@ func (c *Client) Connect() error {
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
||||||
transport := &transport.WebsocketTransport{
|
transport := &Transport{
|
||||||
PingInterval: c.opts.PingInterval,
|
dial: c.opts.DialFunc,
|
||||||
PingTimeout: c.opts.PingTimeout,
|
ws: &transport.WebsocketTransport{
|
||||||
ReceiveTimeout: c.opts.ReceiveTimeout,
|
PingInterval: c.opts.PingInterval,
|
||||||
SendTimeout: c.opts.SendTimeout,
|
PingTimeout: c.opts.PingTimeout,
|
||||||
BufferSize: c.opts.BufferSize,
|
ReceiveTimeout: c.opts.ReceiveTimeout,
|
||||||
|
SendTimeout: c.opts.SendTimeout,
|
||||||
|
BufferSize: c.opts.BufferSize,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := gosocketio.Dial(c.endpoint, transport)
|
conn, err := gosocketio.Dial(c.endpoint, transport)
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func EndpointFromHAddr(addr string) (string, error) {
|
func EndpointFromAddr(addr string) (string, error) {
|
||||||
host, rawPort, err := net.SplitHostPort(addr)
|
host, rawPort, err := net.SplitHostPort(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var addrErr *net.AddrError
|
var addrErr *net.AddrError
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
package socketio
|
package socketio
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DialFunc func(network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
PingInterval time.Duration
|
PingInterval time.Duration
|
||||||
|
@ -8,6 +15,7 @@ type Options struct {
|
||||||
ReceiveTimeout time.Duration
|
ReceiveTimeout time.Duration
|
||||||
SendTimeout time.Duration
|
SendTimeout time.Duration
|
||||||
BufferSize int
|
BufferSize int
|
||||||
|
DialFunc DialFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
type OptionFunc func(opts *Options)
|
type OptionFunc func(opts *Options)
|
||||||
|
@ -19,6 +27,7 @@ func NewOptions(funcs ...OptionFunc) *Options {
|
||||||
ReceiveTimeout: 60 * time.Second,
|
ReceiveTimeout: 60 * time.Second,
|
||||||
SendTimeout: 60 * time.Second,
|
SendTimeout: 60 * time.Second,
|
||||||
BufferSize: 1024 * 32,
|
BufferSize: 1024 * 32,
|
||||||
|
DialFunc: DefaultDialFunc,
|
||||||
}
|
}
|
||||||
for _, fn := range funcs {
|
for _, fn := range funcs {
|
||||||
fn(opts)
|
fn(opts)
|
||||||
|
@ -60,3 +69,19 @@ func WithBufferSize(size int) OptionFunc {
|
||||||
opts.BufferSize = size
|
opts.BufferSize = size
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var DefaultDialFunc = func(network, addr string) (net.Conn, error) {
|
||||||
|
conn, err := net.Dial(network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDialFunc configures the client to use the given dial func
|
||||||
|
func WithDialFunc(dial DialFunc) OptionFunc {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.DialFunc = dial
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
package socketio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"forge.cadoles.com/Pyxis/golang-socketio/transport"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Transport struct {
|
||||||
|
dial DialFunc
|
||||||
|
ws *transport.WebsocketTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect implements transport.Transport.
|
||||||
|
func (t *Transport) Connect(url string) (conn transport.Connection, err error) {
|
||||||
|
if t.dial == nil {
|
||||||
|
return t.ws.Connect(url)
|
||||||
|
} else {
|
||||||
|
dialer := websocket.Dialer{
|
||||||
|
NetDial: func(network, addr string) (net.Conn, error) {
|
||||||
|
return t.dial(network, addr)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
socket, _, err := dialer.Dial(url, t.ws.RequestHeader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transport.NewWebsocketConnection(socket, t.ws), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleConnection implements transport.Transport.
|
||||||
|
func (t *Transport) HandleConnection(w http.ResponseWriter, r *http.Request) (conn transport.Connection, err error) {
|
||||||
|
return t.ws.HandleConnection(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve implements transport.Transport.
|
||||||
|
func (t *Transport) Serve(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.ws.Serve(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ transport.Transport = &Transport{}
|
Loading…
Reference in New Issue