package discovery import ( "context" "log" "net" "strings" "time" "forge.cadoles.com/cadoles/go-emlid/reach/discovery/mdns" "github.com/pkg/errors" "github.com/wlynxg/anet" ) const ReachService = "_reach._tcp" type Resolver struct { } func NewResolver() *Resolver { r := &Resolver{} return r } func (r *Resolver) Scan(ctx context.Context, interval time.Duration) (chan Service, error) { found := make(chan Service) entries := make(chan *mdns.ServiceEntry) go r.listener(ctx, entries, found) ifaces, err := findMulticastInterfaces(ctx) if err != nil { return nil, errors.WithStack(err) } for _, iface := range ifaces { err := func(iface net.Interface) error { hasIPv4, _, err := retrieveSupportedProtocols(iface) if err != nil { return errors.WithStack(err) } if !hasIPv4 { return nil } err = r.queryIface(entries, iface, interval) if err != nil && !(errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)) { return errors.WithStack(err) } go func(iface net.Interface) { defer close(entries) if err := r.pollInterface(ctx, entries, iface, interval); err != nil { if err != nil && !(errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)) { log.Printf("[ERROR] %+v", errors.WithStack(err)) } } }(iface) return nil }(iface) if err != nil { return nil, errors.WithStack(err) } } return found, nil } func (r *Resolver) queryIface(entries chan *mdns.ServiceEntry, iface net.Interface, timeout time.Duration) error { err := mdns.Query(&mdns.QueryParam{ Service: ReachService, Domain: "local", Timeout: timeout, Entries: entries, Interface: &iface, DisableIPv6: true, DisableIPv4: false, }) if err != nil { return errors.WithStack(err) } return nil } func (r *Resolver) pollInterface(ctx context.Context, entries chan *mdns.ServiceEntry, iface net.Interface, interval time.Duration) error { ticker := time.NewTicker(interval) for { select { case <-ticker.C: if err := r.queryIface(entries, iface, interval); err != nil { if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { continue } return errors.WithStack(err) } case <-ctx.Done(): if err := ctx.Err(); err != nil { return errors.WithStack(err) } return nil } } } func (r *Resolver) listener(ctx context.Context, entries chan *mdns.ServiceEntry, found chan Service) { defer close(found) nameSeparator := "." + ReachService for { select { case entry, ok := <-entries: if !ok { return } if entry == nil { continue } name := strings.Split(entry.Name, nameSeparator) if len(name) < 2 { continue } info := decodeTxtRecord(entry.Info) srv := Service{ Name: name[0], Device: info["device"], AddrV4: entry.AddrV4, Port: entry.Port, } found <- srv case <-ctx.Done(): return } } } func decodeTxtRecord(txt string) map[string]string { m := make(map[string]string) s := strings.Split(txt, "|") for _, v := range s { s := strings.Split(v, "=") if len(s) == 2 { m[s[0]] = s[1] } } return m } func isIPv4(ip net.IP) bool { return strings.Count(ip.String(), ":") < 2 } func isIPv6(ip net.IP) bool { return strings.Count(ip.String(), ":") >= 2 } func findMulticastInterfaces(ctx context.Context) ([]net.Interface, error) { ifaces, err := anet.Interfaces() if err != nil { return nil, nil } multicastIfaces := make([]net.Interface, 0) for _, iface := range ifaces { if iface.Flags&net.FlagLoopback == net.FlagLoopback { continue } if iface.Flags&net.FlagRunning != net.FlagRunning { continue } if iface.Flags&net.FlagMulticast != net.FlagMulticast { continue } multicastIfaces = append(multicastIfaces, iface) } return multicastIfaces, nil } func retrieveSupportedProtocols(iface net.Interface) (bool, bool, error) { adresses, err := anet.InterfaceAddrsByInterface(&iface) if err != nil { return false, false, errors.WithStack(err) } hasIPv4 := false hasIPv6 := false for _, addr := range adresses { ip, _, err := net.ParseCIDR(addr.String()) if err != nil { return false, false, errors.WithStack(err) } if isIPv4(ip) { hasIPv4 = true } if isIPv6(ip) { hasIPv6 = true } if hasIPv4 && hasIPv6 { return hasIPv4, hasIPv6, nil } } return hasIPv4, hasIPv6, nil }