go-emlid/reach/discovery/resolver.go

230 lines
4.4 KiB
Go
Raw Normal View History

2024-08-06 09:32:02 +02:00
package discovery
import (
	"context"
	"log"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/davecgh/go-spew/spew"
	"github.com/hashicorp/mdns"
	"github.com/pkg/errors"
	"github.com/wlynxg/anet"
)
// ReachService is the DNS-SD service type queried (in the "local" domain)
// when scanning for devices.
const ReachService = "_reach._tcp"
// Resolver discovers devices advertising ReachService via mDNS.
// It holds no state, so the zero value is equally usable.
type Resolver struct{}

// NewResolver returns a new, ready-to-use Resolver.
func NewResolver() *Resolver {
	return &Resolver{}
}
func (r *Resolver) Scan(ctx context.Context, interval time.Duration) (chan Service, error) {
found := make(chan Service)
entries := make(chan *mdns.ServiceEntry)
go r.listener(ctx, entries, found)
ifaces, err := findMulticastInterfaces(ctx)
if err != nil {
return nil, errors.WithStack(err)
}
for _, iface := range ifaces {
err := func(iface net.Interface) error {
hasIPv4, hasIPv6, err := retrieveSupportedProtocols(iface)
if err != nil {
return errors.WithStack(err)
}
if !hasIPv4 && !hasIPv6 {
return nil
}
err = r.queryIface(entries, iface, !hasIPv4, !hasIPv6, interval)
if err != nil && !(errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)) {
return errors.WithStack(err)
}
go func(iface net.Interface) {
defer close(entries)
if err := r.pollInterface(ctx, entries, iface, interval, !hasIPv4, !hasIPv6); err != nil {
log.Printf("[ERROR] %+v", errors.WithStack(err))
}
}(iface)
return nil
}(iface)
if err != nil {
return nil, errors.WithStack(err)
}
}
return found, nil
}
// queryIface issues one mDNS query for ReachService in the "local" domain on
// iface, passing answers to entries. timeout is used as the query's Timeout,
// and the disable flags turn off the corresponding protocol family.
// Any error from mdns.Query is returned wrapped with a stack trace.
func (r *Resolver) queryIface(entries chan *mdns.ServiceEntry, iface net.Interface, disableIPv4, disableIPv6 bool, timeout time.Duration) error {
	params := &mdns.QueryParam{
		Service:     ReachService,
		Domain:      "local",
		Interface:   &iface,
		Entries:     entries,
		Timeout:     timeout,
		DisableIPv4: disableIPv4,
		DisableIPv6: disableIPv6,
	}
	if err := mdns.Query(params); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// pollInterface re-queries iface on every tick of an interval ticker until ctx
// is done; each query also uses interval as its timeout.
//
// Context-related errors from an individual query are ignored and polling
// continues. Any other query error stops polling and is returned. When ctx is
// done, ctx.Err() is returned wrapped with a stack trace; it is never nil.
func (r *Resolver) pollInterface(ctx context.Context, entries chan *mdns.ServiceEntry, iface net.Interface, interval time.Duration, disableIPv4, disableIPv6 bool) error {
	ticker := time.NewTicker(interval)
	// Release the ticker once polling stops.
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return errors.WithStack(ctx.Err())
		case <-ticker.C:
			// select chooses randomly among ready cases, so a tick may win over
			// a cancellation; don't start a new query in that case.
			if err := ctx.Err(); err != nil {
				return errors.WithStack(err)
			}
			if err := r.queryIface(entries, iface, disableIPv4, disableIPv6, interval); err != nil {
				if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
					continue
				}
				return errors.WithStack(err)
			}
		}
	}
}
func (r *Resolver) listener(ctx context.Context, entries chan *mdns.ServiceEntry, found chan Service) {
defer close(found)
nameSeparator := "." + ReachService
for {
select {
case entry, ok := <-entries:
if !ok {
return
}
if entry == nil {
continue
}
name := strings.Split(entry.Name, nameSeparator)
if len(name) < 2 {
continue
}
info := decodeTxtRecord(entry.Info)
srv := Service{
Name: name[0],
Device: info["device"],
AddrV4: entry.AddrV4,
Port: entry.Port,
}
spew.Sdump(srv)
found <- srv
case <-ctx.Done():
return
}
}
}
// decodeTxtRecord parses key=value attributes from txt, whose fields are
// separated by '|'. It assumes txt holds the TXT strings joined by '|'
// (presumably how mdns.ServiceEntry.Info is built; confirm).
//
// Following RFC 6763 §6:
//   - a field splits at its first '=', so values may contain '=';
//   - fields without '=' (boolean attributes) and fields with an empty key
//     are skipped;
//   - when a key repeats, only its first occurrence is kept.
//
// The returned map is never nil.
func decodeTxtRecord(txt string) map[string]string {
	m := make(map[string]string)
	for _, field := range strings.Split(txt, "|") {
		key, value, ok := strings.Cut(field, "=")
		if !ok || key == "" {
			continue
		}
		if _, seen := m[key]; seen {
			continue
		}
		m[key] = value
	}
	return m
}
// isIPv4 reports whether ip is an IPv4 address, in either 4-byte or
// IPv4-mapped 16-byte form. Nil or malformed values report false; the
// previous String()-based check misreported them as IPv4.
func isIPv4(ip net.IP) bool {
	return ip.To4() != nil
}
// isIPv6 reports whether ip is a native IPv6 address: a valid 16-byte address
// that is not an IPv4 (or IPv4-mapped) address. Nil or malformed values
// report false.
func isIPv6(ip net.IP) bool {
	if ip.To4() != nil {
		return false
	}
	return ip.To16() != nil
}
func findMulticastInterfaces(ctx context.Context) ([]net.Interface, error) {
ifaces, err := anet.Interfaces()
if err != nil {
return nil, nil
}
multicastIfaces := make([]net.Interface, 0)
for _, iface := range ifaces {
if iface.Flags&net.FlagLoopback == net.FlagLoopback {
continue
}
if iface.Flags&net.FlagRunning != net.FlagRunning {
continue
}
if iface.Flags&net.FlagMulticast != net.FlagMulticast {
continue
}
multicastIfaces = append(multicastIfaces, iface)
}
return multicastIfaces, nil
}
func retrieveSupportedProtocols(iface net.Interface) (bool, bool, error) {
adresses, err := anet.InterfaceAddrsByInterface(&iface)
if err != nil {
return false, false, errors.WithStack(err)
}
hasIPv4 := false
hasIPv6 := false
for _, addr := range adresses {
ip, _, err := net.ParseCIDR(addr.String())
if err != nil {
return false, false, errors.WithStack(err)
}
if isIPv4(ip) {
hasIPv4 = true
}
if isIPv6(ip) {
hasIPv6 = true
}
if hasIPv4 && hasIPv6 {
return hasIPv4, hasIPv6, nil
}
}
return hasIPv4, hasIPv6, nil
}