From 33e16c8615c0b1902b43abc6276cb5228b0c5a82 Mon Sep 17 00:00:00 2001 From: William Petit Date: Tue, 6 Aug 2024 09:32:02 +0200 Subject: [PATCH] feat(discovery): android compatibility --- cmd/discover/main.go | 17 +- go.mod | 8 +- go.sum | 21 + reach/discovery/discovery.go | 68 --- reach/discovery/mdns/client.go | 431 ++++++++++++++++++ reach/discovery/resolver.go | 228 +++++++++ .../{discovery_test.go => resolver_test.go} | 22 +- reach/discovery/service.go | 10 + 8 files changed, 725 insertions(+), 80 deletions(-) delete mode 100644 reach/discovery/discovery.go create mode 100644 reach/discovery/mdns/client.go create mode 100644 reach/discovery/resolver.go rename reach/discovery/{discovery_test.go => resolver_test.go} (68%) create mode 100644 reach/discovery/service.go diff --git a/cmd/discover/main.go b/cmd/discover/main.go index 3a31ce5..c235c29 100644 --- a/cmd/discover/main.go +++ b/cmd/discover/main.go @@ -5,25 +5,30 @@ import ( "encoding/json" "fmt" "os" + "time" "forge.cadoles.com/cadoles/go-emlid/reach/discovery" "github.com/pkg/errors" ) func main() { - services, err := discovery.Watch(context.Background()) + resolver := discovery.NewResolver() + + found, err := resolver.Scan(context.Background(), 1*time.Second) if err != nil { fmt.Printf("[FATAL] %+v", errors.WithStack(err)) os.Exit(1) } - for srv := range services { + for srv := range found { data, err := json.MarshalIndent(struct { - Addr string `json:"addr"` - Name string `json:"name"` + Addr string `json:"addr"` + Name string `json:"name"` + Device string `json:"device"` }{ - Name: srv.Name, - Addr: fmt.Sprintf("%s:%d", srv.AddrV4.String(), srv.Port), + Name: srv.Name, + Addr: fmt.Sprintf("%s:%d", srv.AddrV4.String(), srv.Port), + Device: srv.Device, }, "", " ") if err != nil { fmt.Printf("[FATAL] %+v", errors.WithStack(err)) diff --git a/go.mod b/go.mod index 631b4df..58550c2 100644 --- a/go.mod +++ b/go.mod @@ -14,8 +14,10 @@ require ( require ( github.com/cenkalti/backoff v2.2.1+incompatible // indirect github.com/gorilla/websocket v1.5.3 // indirect - github.com/miekg/dns v1.1.27 // indirect + github.com/hashicorp/mdns v1.0.5 // indirect + github.com/miekg/dns v1.1.41 // indirect + github.com/wlynxg/anet v0.0.4-0.20240806025826-e684438fc7c6 // indirect golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 // indirect - golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa // indirect - golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe // indirect + golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1 // indirect + golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 // indirect ) diff --git a/go.sum b/go.sum index 0c5f228..9edb620 100644 --- a/go.sum +++ b/go.sum @@ -12,12 +12,20 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grandcat/zeroconf v1.0.0 h1:uHhahLBKqwWBV6WZUDAT71044vwOTL+McW0mBJvo6kE= github.com/grandcat/zeroconf v1.0.0/go.mod h1:lTKmG1zh86XyCoUeIHSA4FJMBwCJiQmGfcP2PdzytEs= +github.com/hashicorp/mdns v1.0.5 h1:1M5hW1cunYeoXOqHwEb/GBDDHAFo0Yqb/uz/beC6LbE= +github.com/hashicorp/mdns v1.0.5/go.mod h1:mtBihi+LeNXGtG8L9dX59gAEa12BDtBQSp4v/YAJqrc= github.com/miekg/dns v1.1.27 h1:aEH/kqUzUxGJ/UHcEKdJY+ugH6WEzsEBBSPa8zuy1aM= github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= +github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY= +github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg= +github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= +github.com/wlynxg/anet v0.0.4-0.20240806025826-e684438fc7c6 h1:c/wkXIJvpg2oot7iFqPESTBAO9UvhWTBnW97y9aPgyU= +github.com/wlynxg/anet v0.0.4-0.20240806025826-e684438fc7c6/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -27,12 +35,25 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1 h1:4qWs8cYYH6PoEFy4dfhDFgoMGkwAcETd+MmPdCPMzUc= +golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8= golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe h1:6fAMxZRR6sl1Uq8U61gxU+kPTs2tR8uOySCbBP7BN/M= golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 h1:Bli41pIlzTzf3KEY06n+xnzK/BESIg2ze4Pgfh/aI8c= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/reach/discovery/discovery.go b/reach/discovery/discovery.go deleted file mode 100644 index 7f2a723..0000000 --- a/reach/discovery/discovery.go +++ /dev/null @@ -1,68 +0,0 @@ -package discovery - -import ( - "context" - "net" - - "github.com/grandcat/zeroconf" - "github.com/pkg/errors" -) - -// Service is a ReachRS service discovered via MDNS-SD -type Service struct { - Name string - AddrV4 *net.IP - Port int -} - -// Discover tries to discover ReachRS services on the local network via mDNS-SD -func Discover(ctx context.Context) ([]Service, error) { - services := make([]Service, 0) - - watch, err := Watch(ctx) - if err != nil { - return nil, errors.WithStack(err) - } - - for srv := range watch { - services = append(services, srv) - } - - return services, nil -} - -// Watch watches ReachRS services on the local network via mDNS-SD -func Watch(ctx context.Context) (chan Service, error) { - out := make(chan Service, 0) - - resolver, err := zeroconf.NewResolver() - if err != nil { - return nil, errors.WithStack(err) - } - - entries := make(chan *zeroconf.ServiceEntry) - - go func() { - defer close(out) - - for e := range entries { - var addr *net.IP - if len(e.AddrIPv4) > 0 { - addr = &e.AddrIPv4[0] - } - srv := Service{ - Name: e.Instance, - AddrV4: addr, - Port: e.Port, - } - out <- srv - } - - }() - - if err = resolver.Browse(ctx, "_reach._tcp", ".local", entries); err != nil { - return nil, err - } - - return out, nil -} diff --git a/reach/discovery/mdns/client.go b/reach/discovery/mdns/client.go new file mode 100644 index 0000000..bbb2c96 --- /dev/null +++ b/reach/discovery/mdns/client.go @@ -0,0 +1,431 @@ +package mdns + +import ( + "fmt" + "log" + "net" + "strings" + "sync/atomic" + "time" + + "github.com/miekg/dns" + "github.com/pkg/errors" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +const ( + ipv4mdns = "224.0.0.251" + ipv6mdns = "ff02::fb" + mdnsPort = 5353 + forceUnicastResponses = false +) + +var ( + ipv4Addr = &net.UDPAddr{ + IP: net.ParseIP(ipv4mdns), + Port: mdnsPort, + } + ipv6Addr = &net.UDPAddr{ + IP: net.ParseIP(ipv6mdns), + Port: mdnsPort, + } +) + +// ServiceEntry is returned after we query for a service +type ServiceEntry struct { + Name string + Host string + AddrV4 net.IP + AddrV6 net.IP + Port int + Info string + InfoFields []string + + Addr net.IP // @Deprecated + + hasTXT bool + sent bool +} + +// complete is used to check if we have all the info we need +func (s *ServiceEntry) complete() bool { + return (s.AddrV4 != nil || s.AddrV6 != nil || s.Addr != nil) && s.Port != 0 && s.hasTXT +} + +// QueryParam is used to customize how a Lookup is performed +type QueryParam struct { + Service string // Service to lookup + Domain string // Lookup domain, default "local" + Timeout time.Duration // Lookup timeout, default 1 second + Interface *net.Interface // Multicast interface to use + Entries chan<- *ServiceEntry // Entries Channel + WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC + DisableIPv4 bool // Whether to disable usage of IPv4 for MDNS operations. Does not affect discovered addresses. + DisableIPv6 bool // Whether to disable usage of IPv6 for MDNS operations. Does not affect discovered addresses. +} + +// DefaultParams is used to return a default set of QueryParam's +func DefaultParams(service string) *QueryParam { + return &QueryParam{ + Service: service, + Domain: "local", + Timeout: time.Second, + Entries: make(chan *ServiceEntry), + WantUnicastResponse: false, // TODO(reddaly): Change this default. + DisableIPv4: false, + DisableIPv6: false, + } +} + +// Query looks up a given service, in a domain, waiting at most +// for a timeout before finishing the query. The results are streamed +// to a channel. Sends will not block, so clients should make sure to +// either read or buffer. +func Query(params *QueryParam) error { + // Create a new client + client, err := newClient(!params.DisableIPv4, !params.DisableIPv6) + if err != nil { + return err + } + defer client.Close() + + // Set the multicast interface + if params.Interface != nil { + if err := client.setInterface(params.Interface); err != nil { + return err + } + } + + // Ensure defaults are set + if params.Domain == "" { + params.Domain = "local" + } + if params.Timeout == 0 { + params.Timeout = time.Second + } + + // Run the query + return client.query(params) +} + +// Lookup is the same as Query, however it uses all the default parameters +func Lookup(service string, entries chan<- *ServiceEntry) error { + params := DefaultParams(service) + params.Entries = entries + return Query(params) +} + +// Client provides a query interface that can be used to +// search for service providers using mDNS +type client struct { + use_ipv4 bool + use_ipv6 bool + + ipv4UnicastConn *net.UDPConn + ipv6UnicastConn *net.UDPConn + + ipv4MulticastConn *net.UDPConn + ipv6MulticastConn *net.UDPConn + + closed int32 + closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used. +} + +// NewClient creates a new mdns Client that can be used to query +// for records +func newClient(v4 bool, v6 bool) (*client, error) { + if !v4 && !v6 { + return nil, fmt.Errorf("Must enable at least one of IPv4 and IPv6 querying") + } + + // TODO(reddaly): At least attempt to bind to the port required in the spec. + // Create a IPv4 listener + var uconn4 *net.UDPConn + var uconn6 *net.UDPConn + var mconn4 *net.UDPConn + var mconn6 *net.UDPConn + var err error + + if v4 { + uconn4, err = net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + log.Printf("[ERR] mdns: Failed to bind to udp4 port: %+v", errors.WithStack(err)) + } + } + + if v6 { + uconn6, err = net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + log.Printf("[ERR] mdns: Failed to bind to udp6 port: %+v", errors.WithStack(err)) + } + } + + if uconn4 == nil && uconn6 == nil { + return nil, fmt.Errorf("failed to bind to any unicast udp port") + } + + if v4 { + mconn4, err = net.ListenMulticastUDP("udp4", nil, ipv4Addr) + if err != nil { + log.Printf("[ERR] mdns: Failed to bind to udp4 port: %+v", errors.WithStack(err)) + } + } + if v6 { + mconn6, err = net.ListenMulticastUDP("udp6", nil, ipv6Addr) + if err != nil { + log.Printf("[ERR] mdns: Failed to bind to udp6 port: %+v", errors.WithStack(err)) + } + } + + if mconn4 == nil && mconn6 == nil { + return nil, fmt.Errorf("failed to bind to any multicast udp port") + } + + c := &client{ + use_ipv4: v4, + use_ipv6: v6, + ipv4MulticastConn: mconn4, + ipv6MulticastConn: mconn6, + ipv4UnicastConn: uconn4, + ipv6UnicastConn: uconn6, + closedCh: make(chan struct{}), + } + + return c, nil +} + +// Close is used to cleanup the client +func (c *client) Close() error { + if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) { + // something else already closed it + return nil + } + + close(c.closedCh) + + if c.ipv4UnicastConn != nil { + c.ipv4UnicastConn.Close() + } + if c.ipv6UnicastConn != nil { + c.ipv6UnicastConn.Close() + } + if c.ipv4MulticastConn != nil { + c.ipv4MulticastConn.Close() + } + if c.ipv6MulticastConn != nil { + c.ipv6MulticastConn.Close() + } + + return nil +} + +// setInterface is used to set the query interface, uses system +// default if not provided +func (c *client) setInterface(iface *net.Interface) error { + if c.use_ipv4 { + p := ipv4.NewPacketConn(c.ipv4UnicastConn) + if err := p.SetMulticastInterface(iface); err != nil { + return err + } + p = ipv4.NewPacketConn(c.ipv4MulticastConn) + if err := p.SetMulticastInterface(iface); err != nil { + return err + } + } + if c.use_ipv6 { + p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) + if err := p2.SetMulticastInterface(iface); err != nil { + return err + } + p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) + if err := p2.SetMulticastInterface(iface); err != nil { + return err + } + } + return nil +} + +// query is used to perform a lookup and stream results +func (c *client) query(params *QueryParam) error { + // Create the service name + serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) + + // Start listening for response packets + msgCh := make(chan *dns.Msg, 32) + if c.use_ipv4 { + go c.recv(c.ipv4UnicastConn, msgCh) + go c.recv(c.ipv4MulticastConn, msgCh) + } + if c.use_ipv6 { + go c.recv(c.ipv6UnicastConn, msgCh) + go c.recv(c.ipv6MulticastConn, msgCh) + } + + // Send the query + m := new(dns.Msg) + m.SetQuestion(serviceAddr, dns.TypePTR) + // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question + // Section + // + // In the Question Section of a Multicast DNS query, the top bit of the qclass + // field is used to indicate that unicast responses are preferred for this + // particular question. (See Section 5.4.) + if params.WantUnicastResponse { + m.Question[0].Qclass |= 1 << 15 + } + m.RecursionDesired = false + if err := c.sendQuery(m); err != nil { + return err + } + + // Map the in-progress responses + inprogress := make(map[string]*ServiceEntry) + + // Listen until we reach the timeout + finish := time.After(params.Timeout) + for { + select { + case resp := <-msgCh: + var inp *ServiceEntry + for _, answer := range append(resp.Answer, resp.Extra...) { + // TODO(reddaly): Check that response corresponds to serviceAddr? + switch rr := answer.(type) { + case *dns.PTR: + // Create new entry for this + inp = ensureName(inprogress, rr.Ptr) + + case *dns.SRV: + // Check for a target mismatch + if rr.Target != rr.Hdr.Name { + alias(inprogress, rr.Hdr.Name, rr.Target) + } + + // Get the port + inp = ensureName(inprogress, rr.Hdr.Name) + inp.Host = rr.Target + inp.Port = int(rr.Port) + + case *dns.TXT: + // Pull out the txt + inp = ensureName(inprogress, rr.Hdr.Name) + inp.Info = strings.Join(rr.Txt, "|") + inp.InfoFields = rr.Txt + inp.hasTXT = true + + case *dns.A: + // Pull out the IP + inp = ensureName(inprogress, rr.Hdr.Name) + inp.Addr = rr.A // @Deprecated + inp.AddrV4 = rr.A + + case *dns.AAAA: + // Pull out the IP + inp = ensureName(inprogress, rr.Hdr.Name) + inp.Addr = rr.AAAA // @Deprecated + inp.AddrV6 = rr.AAAA + } + } + + if inp == nil { + continue + } + + // Check if this entry is complete + if inp.complete() { + if inp.sent { + continue + } + inp.sent = true + select { + case params.Entries <- inp: + default: + } + } else { + // Fire off a node specific query + m := new(dns.Msg) + m.SetQuestion(inp.Name, dns.TypePTR) + m.RecursionDesired = false + if err := c.sendQuery(m); err != nil { + log.Printf("[ERR] mdns: Failed to query instance %s: %+v", inp.Name, errors.WithStack(err)) + } + } + case <-finish: + return nil + } + } +} + +// sendQuery is used to multicast a query out +func (c *client) sendQuery(q *dns.Msg) error { + buf, err := q.Pack() + if err != nil { + return err + } + if c.ipv4UnicastConn != nil { + _, err = c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr) + if err != nil { + return err + } + } + if c.ipv6UnicastConn != nil { + _, err = c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr) + if err != nil { + return err + } + } + return nil +} + +// recv is used to receive until we get a shutdown +func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) { + if l == nil { + return + } + buf := make([]byte, 65536) + for atomic.LoadInt32(&c.closed) == 0 { + n, err := l.Read(buf) + + if atomic.LoadInt32(&c.closed) == 1 { + return + } + + if err != nil { + log.Printf("[ERR] mdns: Failed to read packet: %+v", errors.WithStack(err)) + continue + } + msg := new(dns.Msg) + if err := msg.Unpack(buf[:n]); err != nil { + log.Printf("[ERR] mdns: Failed to unpack packet: %+v", errors.WithStack(err)) + continue + } + select { + case msgCh <- msg: + case <-c.closedCh: + return + } + } +} + +// ensureName is used to ensure the named node is in progress +func ensureName(inprogress map[string]*ServiceEntry, name string) *ServiceEntry { + if inp, ok := inprogress[name]; ok { + return inp + } + inp := &ServiceEntry{ + Name: name, + } + inprogress[name] = inp + return inp +} + +// alias is used to setup an alias between two entries +func alias(inprogress map[string]*ServiceEntry, src, dst string) { + srcEntry := ensureName(inprogress, src) + inprogress[dst] = srcEntry +} + +// trimDot is used to trim the dots from the start or end of a string +func trimDot(s string) string { + return strings.Trim(s, ".") +} diff --git a/reach/discovery/resolver.go b/reach/discovery/resolver.go new file mode 100644 index 0000000..1259695 --- /dev/null +++ b/reach/discovery/resolver.go @@ -0,0 +1,228 @@ +package discovery + +import ( + "context" + "log" + "net" + "strings" + "time" + + "forge.cadoles.com/cadoles/go-emlid/reach/discovery/mdns" + "github.com/pkg/errors" + "github.com/wlynxg/anet" +) + +const ReachService = "_reach._tcp" + +type Resolver struct { +} + +func NewResolver() *Resolver { + r := &Resolver{} + + return r +} + +func (r *Resolver) Scan(ctx context.Context, interval time.Duration) (chan Service, error) { + found := make(chan Service) + entries := make(chan *mdns.ServiceEntry) + + go r.listener(ctx, entries, found) + + ifaces, err := findMulticastInterfaces(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + + for _, iface := range ifaces { + err := func(iface net.Interface) error { + hasIPv4, _, err := retrieveSupportedProtocols(iface) + if err != nil { + return errors.WithStack(err) + } + + if !hasIPv4 { + return nil + } + + err = r.queryIface(entries, iface, interval) + if err != nil && !(errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)) { + return errors.WithStack(err) + } + + go func(iface net.Interface) { + defer close(entries) + + if err := r.pollInterface(ctx, entries, iface, interval); err != nil { + if err != nil && !(errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)) { + log.Printf("[ERROR] %+v", errors.WithStack(err)) + } + } + }(iface) + + return nil + }(iface) + if err != nil { + return nil, errors.WithStack(err) + } + } + + return found, nil +} + +func (r *Resolver) queryIface(entries chan *mdns.ServiceEntry, iface net.Interface, timeout time.Duration) error { + err := mdns.Query(&mdns.QueryParam{ + Service: ReachService, + Domain: "local", + Timeout: timeout, + Entries: entries, + Interface: &iface, + DisableIPv6: true, + DisableIPv4: false, + }) + if err != nil { + return errors.WithStack(err) + } + + return nil +} + +func (r *Resolver) pollInterface(ctx context.Context, entries chan *mdns.ServiceEntry, iface net.Interface, interval time.Duration) error { + ticker := time.NewTicker(interval) + for { + select { + case <-ticker.C: + if err := r.queryIface(entries, iface, interval); err != nil { + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + continue + } + + return errors.WithStack(err) + } + + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + return errors.WithStack(err) + } + + return nil + } + } +} + +func (r *Resolver) listener(ctx context.Context, entries chan *mdns.ServiceEntry, found chan Service) { + defer close(found) + + nameSeparator := "." + ReachService + + for { + select { + case entry, ok := <-entries: + if !ok { + return + } + + if entry == nil { + continue + } + + name := strings.Split(entry.Name, nameSeparator) + if len(name) < 2 { + continue + } + + info := decodeTxtRecord(entry.Info) + + srv := Service{ + Name: name[0], + Device: info["device"], + AddrV4: entry.AddrV4, + Port: entry.Port, + } + + found <- srv + case <-ctx.Done(): + return + } + } +} + +func decodeTxtRecord(txt string) map[string]string { + m := make(map[string]string) + + s := strings.Split(txt, "|") + for _, v := range s { + s := strings.Split(v, "=") + if len(s) == 2 { + m[s[0]] = s[1] + } + } + + return m +} + +func isIPv4(ip net.IP) bool { + return strings.Count(ip.String(), ":") < 2 +} + +func isIPv6(ip net.IP) bool { + return strings.Count(ip.String(), ":") >= 2 +} + +func findMulticastInterfaces(ctx context.Context) ([]net.Interface, error) { + ifaces, err := anet.Interfaces() + if err != nil { + return nil, nil + } + + multicastIfaces := make([]net.Interface, 0) + + for _, iface := range ifaces { + if iface.Flags&net.FlagLoopback == net.FlagLoopback { + continue + } + + if iface.Flags&net.FlagRunning != net.FlagRunning { + continue + } + + if iface.Flags&net.FlagMulticast != net.FlagMulticast { + continue + } + + multicastIfaces = append(multicastIfaces, iface) + } + + return multicastIfaces, nil +} + +func retrieveSupportedProtocols(iface net.Interface) (bool, bool, error) { + adresses, err := anet.InterfaceAddrsByInterface(&iface) + if err != nil { + return false, false, errors.WithStack(err) + } + + hasIPv4 := false + hasIPv6 := false + + for _, addr := range adresses { + ip, _, err := net.ParseCIDR(addr.String()) + if err != nil { + return false, false, errors.WithStack(err) + } + + if isIPv4(ip) { + hasIPv4 = true + } + + if isIPv6(ip) { + hasIPv6 = true + } + + if hasIPv4 && hasIPv6 { + return hasIPv4, hasIPv6, nil + } + } + + return hasIPv4, hasIPv6, nil +} diff --git a/reach/discovery/discovery_test.go b/reach/discovery/resolver_test.go similarity index 68% rename from reach/discovery/discovery_test.go rename to reach/discovery/resolver_test.go index acc400a..4d9082e 100644 --- a/reach/discovery/discovery_test.go +++ b/reach/discovery/resolver_test.go @@ -7,17 +7,33 @@ import ( "time" "forge.cadoles.com/cadoles/go-emlid/reach" + "github.com/davecgh/go-spew/spew" + "github.com/pkg/errors" ) -func TestDiscovery(t *testing.T) { +func TestResolver(t *testing.T) { reach.AssertIntegrationTests(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - services, err := Discover(ctx) + resolver := NewResolver() + + found, err := resolver.Scan(ctx, 500*time.Millisecond) if err != nil { - t.Fatal(err) + t.Fatalf("%+v", errors.WithStack(err)) + } + + services := make([]Service, 0) +OUTER: + for { + select { + case s := <-found: + t.Logf("%s", spew.Sdump(s)) + services = append(services, s) + case <-ctx.Done(): + break OUTER + } } if g, e := len(services), 1; g < e { diff --git a/reach/discovery/service.go b/reach/discovery/service.go new file mode 100644 index 0000000..a34e7a8 --- /dev/null +++ b/reach/discovery/service.go @@ -0,0 +1,10 @@ +package discovery + +import "net" + +type Service struct { + Name string + Device string + AddrV4 net.IP + Port int +}