feat(protocol): allow override of dial func
This commit is contained in:
@ -6,6 +6,7 @@ func init() {
|
||||
protocol.Register(Identifier, func(opts *protocol.ProtocolOptions) (protocol.Protocol, error) {
|
||||
return &Protocol{
|
||||
logger: opts.Logger,
|
||||
dial: opts.Dial,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
@ -20,7 +20,9 @@ func (o *Operations) GetJSON(path string, dst any) error {
|
||||
var res *http.Response
|
||||
|
||||
url := o.getURL(path)
|
||||
res, err := http.Get(url)
|
||||
client := o.getHTTPClient()
|
||||
|
||||
res, err := client.Get(url)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
@ -50,6 +52,7 @@ func (o *Operations) GetJSON(path string, dst any) error {
|
||||
}
|
||||
|
||||
func (o *Operations) PostJSON(path string, data any, dst any) error {
|
||||
|
||||
var res *http.Response
|
||||
|
||||
var buf bytes.Buffer
|
||||
@ -60,7 +63,9 @@ func (o *Operations) PostJSON(path string, data any, dst any) error {
|
||||
}
|
||||
|
||||
url := o.getURL(path)
|
||||
res, err := http.Post(url, "application/json", &buf)
|
||||
client := o.getHTTPClient()
|
||||
|
||||
res, err := client.Post(url, "application/json", &buf)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
@ -89,6 +94,18 @@ func (o *Operations) PostJSON(path string, data any, dst any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Operations) getHTTPClient() *http.Client {
|
||||
o.getClientOnce.Do(func() {
|
||||
o.httpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Dial: o.dial,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
return o.httpClient
|
||||
}
|
||||
|
||||
func (o *Operations) PostBaseCoordinates(ctx context.Context, base *model.Base) (*model.Base, error) {
|
||||
var updated model.Base
|
||||
|
||||
|
@ -2,6 +2,7 @@ package v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"forge.cadoles.com/cadoles/go-emlid/reach/client/logger"
|
||||
@ -17,6 +18,10 @@ type Operations struct {
|
||||
client *socketio.Client
|
||||
mutex sync.RWMutex
|
||||
logger logger.Logger
|
||||
dial protocol.DialFunc
|
||||
|
||||
getClientOnce sync.Once
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// Reboot implements protocol.Operations.
|
||||
@ -161,12 +166,12 @@ func (o *Operations) Connect(ctx context.Context) error {
|
||||
o.client.Close()
|
||||
}
|
||||
|
||||
endpoint, err := socketio.EndpointFromHAddr(o.addr)
|
||||
endpoint, err := socketio.EndpointFromAddr(o.addr)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
client := socketio.NewClient(endpoint)
|
||||
client := socketio.NewClient(endpoint, socketio.WithDialFunc(socketio.DialFunc(o.dial)))
|
||||
|
||||
if err := client.Connect(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
|
@ -15,6 +15,7 @@ const compatibleVersionConstraint = ">= 32"
|
||||
|
||||
type Protocol struct {
|
||||
logger logger.Logger
|
||||
dial protocol.DialFunc
|
||||
}
|
||||
|
||||
// Available implements protocol.Protocol.
|
||||
@ -50,7 +51,11 @@ func (p *Protocol) Identifier() protocol.Identifier {
|
||||
|
||||
// Operations implements protocol.Protocol.
|
||||
func (p *Protocol) Operations(addr string) protocol.Operations {
|
||||
return &Operations{addr: addr, logger: p.logger}
|
||||
return &Operations{
|
||||
dial: p.dial,
|
||||
addr: addr,
|
||||
logger: p.logger,
|
||||
}
|
||||
}
|
||||
|
||||
var _ protocol.Protocol = &Protocol{}
|
||||
|
Reference in New Issue
Block a user