package cast import ( "context" "net" "regexp" "strconv" "strings" "sync" "time" "gitlab.com/wpetit/goweb/logger" "github.com/barnybug/go-cast" "github.com/barnybug/go-cast/log" "github.com/hashicorp/mdns" "github.com/pkg/errors" ) type Service struct { found chan *cast.Client entriesCh chan *mdns.ServiceEntry stopPeriodic chan struct{} } func NewService(ctx context.Context) *Service { s := &Service{ found: make(chan *cast.Client), entriesCh: make(chan *mdns.ServiceEntry, 10), } go s.listener(ctx) return s } func (d *Service) Run(ctx context.Context, interval time.Duration) error { ifaces, err := findMulticastInterfaces(ctx) if err != nil { return errors.WithStack(err) } var wg sync.WaitGroup for _, iface := range ifaces { hasIPv4, hasIPv6, err := retrieveSupportedProtocols(iface) if err != nil { return errors.WithStack(err) } if !hasIPv4 && !hasIPv6 { continue } if err := d.queryIface(iface, !hasIPv4, !hasIPv6); err != nil { return errors.WithStack(err) } pollCtx, cancel := context.WithCancel(ctx) defer cancel() wg.Add(1) go func(ctx context.Context, iface net.Interface) { defer wg.Done() if err := d.pollInterface(ctx, iface, interval, !hasIPv4, !hasIPv6); err != nil { if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { return } logger.Error( ctx, "could not poll interface", logger.E(errors.WithStack(err)), logger.F("iface", iface.Name), ) } }(pollCtx, iface) } wg.Wait() return nil } func (d *Service) queryIface(iface net.Interface, disableIPv4, disableIPv6 bool) error { err := mdns.Query(&mdns.QueryParam{ Service: "_googlecast._tcp", Domain: "local", Timeout: 3 * time.Second, Entries: d.entriesCh, Interface: &iface, DisableIPv6: disableIPv6, DisableIPv4: disableIPv4, }) if err != nil { return errors.WithStack(err) } return nil } func (d *Service) pollInterface(ctx context.Context, iface net.Interface, interval time.Duration, disableIPv4, disableIPv6 bool) error { ticker := time.NewTicker(interval) for { select { case <-ticker.C: if err := d.queryIface(iface, disableIPv4, disableIPv6); err != nil { return errors.WithStack(err) } case <-ctx.Done(): if err := ctx.Err(); err != nil { return errors.WithStack(err) } return nil } } } func (d *Service) Stop() { if d.stopPeriodic != nil { close(d.stopPeriodic) d.stopPeriodic = nil } } func (d *Service) Found() chan *cast.Client { return d.found } func (d *Service) listener(ctx context.Context) { for entry := range d.entriesCh { name := strings.Split(entry.Name, "._googlecast") // Skip everything that doesn't have googlecast in the fdqn if len(name) < 2 { continue } log.Printf("New entry: %#v\n", entry) client := cast.NewClient(entry.AddrV4, entry.Port) info := decodeTxtRecord(entry.Info) client.SetName(info["fn"]) client.SetInfo(info) select { case d.found <- client: case <-time.After(time.Second): case <-ctx.Done(): break } } } func decodeDnsEntry(text string) string { text = strings.Replace(text, `\.`, ".", -1) text = strings.Replace(text, `\ `, " ", -1) re := regexp.MustCompile(`([\\][0-9][0-9][0-9])`) text = re.ReplaceAllStringFunc(text, func(source string) string { i, err := strconv.Atoi(source[1:]) if err != nil { return "" } return string([]byte{byte(i)}) }) return text } func decodeTxtRecord(txt string) map[string]string { m := make(map[string]string) s := strings.Split(txt, "|") for _, v := range s { s := strings.Split(v, "=") if len(s) == 2 { m[s[0]] = s[1] } } return m } func isIPv4(ip net.IP) bool { return strings.Count(ip.String(), ":") < 2 } func isIPv6(ip net.IP) bool { return strings.Count(ip.String(), ":") >= 2 } func findMulticastInterfaces(ctx context.Context) ([]net.Interface, error) { ifaces, err := net.Interfaces() if err != nil { return nil, nil } multicastIfaces := make([]net.Interface, 0) for _, iface := range ifaces { if iface.Flags&net.FlagLoopback == net.FlagLoopback { continue } if iface.Flags&net.FlagRunning != net.FlagRunning { continue } if iface.Flags&net.FlagMulticast != net.FlagMulticast { continue } multicastIfaces = append(multicastIfaces, iface) } return multicastIfaces, nil } func retrieveSupportedProtocols(iface net.Interface) (bool, bool, error) { adresses, err := iface.Addrs() if err != nil { return false, false, errors.WithStack(err) } hasIPv4 := false hasIPv6 := false for _, addr := range adresses { ip, _, err := net.ParseCIDR(addr.String()) if err != nil { return false, false, errors.WithStack(err) } if isIPv4(ip) { hasIPv4 = true } if isIPv6(ip) { hasIPv6 = true } if hasIPv4 && hasIPv6 { return hasIPv4, hasIPv6, nil } } return hasIPv4, hasIPv6, nil }