// edge/pkg/module/cast/chromecast/discovery.go

package chromecast
import (
"context"
"net"
"strings"
"sync"
"time"
"gitlab.com/wpetit/goweb/logger"
"github.com/barnybug/go-cast"
"github.com/barnybug/go-cast/log"
"github.com/hashicorp/mdns"
"github.com/pkg/errors"
"github.com/wlynxg/anet"
)
// serviceDiscoveryPollingInterval is a 500ms polling period for service
// discovery.
const serviceDiscoveryPollingInterval = 500 * time.Millisecond
// Discovery looks for Google Cast devices through mDNS queries and exposes
// the clients built from the answers on a channel (see Found).
type Discovery struct {
// found carries one cast.Client per accepted mDNS entry; listener sends on
// it and Found returns it. NewDiscovery creates it unbuffered.
found chan *cast.Client
// entriesCh receives raw mDNS service entries: queryIface hands it to
// mdns.Query and listener reads from it. NewDiscovery gives it a buffer of 10.
entriesCh chan *mdns.ServiceEntry
// stopPeriodic is closed and reset to nil by Stop.
// NOTE(review): nothing in the visible code initializes this channel — confirm
// whether it is still in use.
stopPeriodic chan struct{}
}
// NewDiscovery allocates a Discovery with an unbuffered found channel and an
// entries channel buffered to 10, then starts the listener goroutine with ctx
// before returning the new value.
func NewDiscovery(ctx context.Context) *Discovery {
	discovery := new(Discovery)
	discovery.found = make(chan *cast.Client)
	discovery.entriesCh = make(chan *mdns.ServiceEntry, 10)

	go discovery.listener(ctx)

	return discovery
}
// Run queries every interface returned by findMulticastInterfaces for Google
// Cast devices: each interface that has at least one IPv4 or IPv6 address is
// queried once immediately and then polled every interval from its own
// goroutine.
//
// Run blocks until every polling goroutine has exited, which normally happens
// when ctx is cancelled. It returns an error when the interfaces cannot be
// listed, when an interface's addresses cannot be inspected, or when an
// initial query fails; pollers that were already started are then cancelled
// and waited for before Run returns. Errors raised by a poller (other than
// context cancellation or deadline) are logged, not returned.
func (d *Discovery) Run(ctx context.Context, interval time.Duration) error {
	ifaces, err := findMulticastInterfaces(ctx)
	if err != nil {
		return errors.WithStack(err)
	}

	// A single derived context is shared by all pollers so that one deferred
	// call stops them all, instead of stacking one defer per loop iteration.
	pollCtx, cancel := context.WithCancel(ctx)

	var wg sync.WaitGroup

	// On every exit path, stop the pollers and wait for them so that no
	// goroutine outlives Run.
	defer func() {
		cancel()
		wg.Wait()
	}()

	for _, iface := range ifaces {
		hasIPv4, hasIPv6, err := retrieveSupportedProtocols(iface)
		if err != nil {
			return errors.WithStack(err)
		}

		if !hasIPv4 && !hasIPv6 {
			continue
		}

		disableIPv4, disableIPv6 := !hasIPv4, !hasIPv6

		if err := d.queryIface(iface, disableIPv4, disableIPv6); err != nil {
			return errors.WithStack(err)
		}

		wg.Add(1)

		go func(iface net.Interface) {
			defer wg.Done()

			err := d.pollInterface(pollCtx, iface, interval, disableIPv4, disableIPv6)
			if err == nil || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
				return
			}

			logger.Error(
				pollCtx, "could not poll interface",
				logger.CapturedE(errors.WithStack(err)), logger.F("iface", iface.Name),
			)
		}(iface)
	}

	wg.Wait()

	return nil
}
// queryIface runs one mDNS query for the "_googlecast._tcp" service in the
// "local" domain on iface, with a 3 second timeout. Answers are delivered to
// d.entriesCh. disableIPv4 and disableIPv6 are forwarded to the query
// parameters unchanged. Any query error is returned with a stack trace.
func (d *Discovery) queryIface(iface net.Interface, disableIPv4, disableIPv6 bool) error {
	params := &mdns.QueryParam{
		Service:     "_googlecast._tcp",
		Domain:      "local",
		Timeout:     3 * time.Second,
		Entries:     d.entriesCh,
		Interface:   &iface,
		DisableIPv4: disableIPv4,
		DisableIPv6: disableIPv6,
	}

	if err := mdns.Query(params); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// pollInterface re-runs queryIface on iface every interval until ctx is done
// or a query fails.
//
// It returns the (stack-wrapped) query error on failure, the wrapped ctx
// error once ctx is done, or an error straight away when interval is not
// positive.
func (d *Discovery) pollInterface(ctx context.Context, iface net.Interface, interval time.Duration, disableIPv4, disableIPv6 bool) error {
	// time.NewTicker panics on a non-positive duration; report it as an error
	// instead of crashing the process from inside a goroutine.
	if interval <= 0 {
		return errors.New("polling interval must be positive")
	}

	ticker := time.NewTicker(interval)
	// Release the ticker's resources once polling ends.
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if err := d.queryIface(iface, disableIPv4, disableIPv6); err != nil {
				return errors.WithStack(err)
			}
		case <-ctx.Done():
			// ctx.Err() is always non-nil once Done is closed.
			return errors.WithStack(ctx.Err())
		}
	}
}
// Stop closes the stopPeriodic channel when it is set and then clears the
// field, so that calling Stop again does nothing.
func (d *Discovery) Stop() {
	ch := d.stopPeriodic
	if ch == nil {
		return
	}

	close(ch)
	d.stopPeriodic = nil
}
// Found returns the channel on which the listener goroutine publishes the
// cast clients it builds from discovered mDNS entries.
func (d *Discovery) Found() chan *cast.Client {
return d.found
}
// listener consumes mDNS service entries from entriesCh, turns each Google
// Cast entry into a cast.Client and offers it on the found channel.
//
// It returns when ctx is cancelled or entriesCh is closed. A client that no
// receiver takes within one second is dropped, so a slow or absent consumer
// cannot stall entry processing.
func (d *Discovery) listener(ctx context.Context) {
	for {
		var entry *mdns.ServiceEntry

		// Wait on ctx as well as on the channel: nothing visible here closes
		// entriesCh, so ranging over it alone would keep this goroutine blocked
		// after cancellation whenever no further entry arrives.
		select {
		case <-ctx.Done():
			return
		case e, ok := <-d.entriesCh:
			if !ok {
				return
			}

			entry = e
		}

		// Skip everything that doesn't have googlecast in the FQDN.
		if !strings.Contains(entry.Name, "._googlecast") {
			continue
		}

		log.Printf("New entry: %#v\n", entry)

		// NOTE(review): only the IPv4 address is used; an entry without one
		// presumably yields a client with a nil host — confirm.
		client := cast.NewClient(entry.AddrV4, entry.Port)
		info := decodeTxtRecord(entry.Info)
		client.SetName(info["fn"])
		client.SetInfo(info)

		select {
		case d.found <- client:
		case <-time.After(time.Second):
			// Nobody picked the client up in time: drop it.
		case <-ctx.Done():
			return
		}
	}
}
// decodeTxtRecord parses a "|"-separated list of "key=value" pairs into a map.
//
// Each pair is split on its first "=" only, so values that themselves contain
// "=" are preserved (RFC 6763 section 6.4: the first "=" delimits the key).
// Fields without any "=" are ignored. When a key appears more than once, the
// last occurrence wins. The returned map is never nil.
func decodeTxtRecord(txt string) map[string]string {
	m := make(map[string]string)

	for _, field := range strings.Split(txt, "|") {
		key, value, found := strings.Cut(field, "=")
		if !found {
			continue
		}

		m[key] = value
	}

	return m
}
// isIPv4 reports whether ip is an IPv4 address, either in 4-byte form or as
// an IPv4-mapped IPv6 address. A nil or malformed ip is not IPv4.
func isIPv4(ip net.IP) bool {
	return ip.To4() != nil
}
// isIPv6 reports whether ip is a valid IPv6 address that is not an
// IPv4-mapped one. A nil or malformed ip is not IPv6.
func isIPv6(ip net.IP) bool {
	return ip.To16() != nil && ip.To4() == nil
}
// findMulticastInterfaces lists the network interfaces that are running,
// multicast-capable and not loopback.
//
// The returned slice is never nil on success (it may be empty). An error from
// listing the interfaces is returned to the caller rather than being reported
// as "no interfaces". ctx is currently unused.
func findMulticastInterfaces(ctx context.Context) ([]net.Interface, error) {
	ifaces, err := anet.Interfaces()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// Flags an interface must carry, all at once, to be eligible.
	const required = net.FlagRunning | net.FlagMulticast

	multicastIfaces := make([]net.Interface, 0, len(ifaces))

	for _, iface := range ifaces {
		if iface.Flags&net.FlagLoopback != 0 {
			continue
		}

		if iface.Flags&required != required {
			continue
		}

		multicastIfaces = append(multicastIfaces, iface)
	}

	return multicastIfaces, nil
}
// retrieveSupportedProtocols inspects the addresses assigned to iface and
// reports, in that order, whether at least one is IPv4 and whether at least
// one is IPv6.
//
// Each address is parsed as CIDR notation; a lookup or parse failure is
// returned with a stack trace together with two false values.
func retrieveSupportedProtocols(iface net.Interface) (bool, bool, error) {
	addrs, err := anet.InterfaceAddrsByInterface(&iface)
	if err != nil {
		return false, false, errors.WithStack(err)
	}

	var v4, v6 bool

	for _, a := range addrs {
		ip, _, parseErr := net.ParseCIDR(a.String())
		if parseErr != nil {
			return false, false, errors.WithStack(parseErr)
		}

		v4 = v4 || isIPv4(ip)
		v6 = v6 || isIPv6(ip)

		// Both families seen: the remaining addresses cannot change the answer.
		if v4 && v6 {
			break
		}
	}

	return v4, v6, nil
}