237 lines
4.6 KiB
Go
237 lines
4.6 KiB
Go
package chromecast
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
|
|
"github.com/barnybug/go-cast"
|
|
"github.com/barnybug/go-cast/log"
|
|
"github.com/hashicorp/mdns"
|
|
"github.com/pkg/errors"
|
|
"github.com/wlynxg/anet"
|
|
)
|
|
|
|
// serviceDiscoveryPollingInterval is the default delay between two successive
// mDNS queries on the same interface.
const serviceDiscoveryPollingInterval = 500 * time.Millisecond
|
|
|
|
type Discovery struct {
|
|
found chan *cast.Client
|
|
entriesCh chan *mdns.ServiceEntry
|
|
|
|
stopPeriodic chan struct{}
|
|
}
|
|
|
|
func NewDiscovery(ctx context.Context) *Discovery {
|
|
d := &Discovery{
|
|
found: make(chan *cast.Client),
|
|
entriesCh: make(chan *mdns.ServiceEntry, 10),
|
|
}
|
|
|
|
go d.listener(ctx)
|
|
return d
|
|
}
|
|
|
|
func (d *Discovery) Run(ctx context.Context, interval time.Duration) error {
|
|
ifaces, err := findMulticastInterfaces(ctx)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
for _, iface := range ifaces {
|
|
hasIPv4, hasIPv6, err := retrieveSupportedProtocols(iface)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
if !hasIPv4 && !hasIPv6 {
|
|
continue
|
|
}
|
|
|
|
if err := d.queryIface(iface, !hasIPv4, !hasIPv6); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
pollCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
wg.Add(1)
|
|
|
|
go func(ctx context.Context, iface net.Interface) {
|
|
defer wg.Done()
|
|
|
|
if err := d.pollInterface(ctx, iface, interval, !hasIPv4, !hasIPv6); err != nil {
|
|
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
|
return
|
|
}
|
|
|
|
logger.Error(
|
|
ctx, "could not poll interface",
|
|
logger.CapturedE(errors.WithStack(err)), logger.F("iface", iface.Name),
|
|
)
|
|
}
|
|
}(pollCtx, iface)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *Discovery) queryIface(iface net.Interface, disableIPv4, disableIPv6 bool) error {
|
|
err := mdns.Query(&mdns.QueryParam{
|
|
Service: "_googlecast._tcp",
|
|
Domain: "local",
|
|
Timeout: 3 * time.Second,
|
|
Entries: d.entriesCh,
|
|
Interface: &iface,
|
|
DisableIPv6: disableIPv6,
|
|
DisableIPv4: disableIPv4,
|
|
})
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *Discovery) pollInterface(ctx context.Context, iface net.Interface, interval time.Duration, disableIPv4, disableIPv6 bool) error {
|
|
ticker := time.NewTicker(interval)
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if err := d.queryIface(iface, disableIPv4, disableIPv6); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
case <-ctx.Done():
|
|
if err := ctx.Err(); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *Discovery) Stop() {
|
|
if d.stopPeriodic != nil {
|
|
close(d.stopPeriodic)
|
|
d.stopPeriodic = nil
|
|
}
|
|
}
|
|
|
|
func (d *Discovery) Found() chan *cast.Client {
|
|
return d.found
|
|
}
|
|
|
|
func (d *Discovery) listener(ctx context.Context) {
|
|
for entry := range d.entriesCh {
|
|
name := strings.Split(entry.Name, "._googlecast")
|
|
// Skip everything that doesn't have googlecast in the fdqn
|
|
if len(name) < 2 {
|
|
continue
|
|
}
|
|
|
|
log.Printf("New entry: %#v\n", entry)
|
|
client := cast.NewClient(entry.AddrV4, entry.Port)
|
|
info := decodeTxtRecord(entry.Info)
|
|
client.SetName(info["fn"])
|
|
client.SetInfo(info)
|
|
|
|
select {
|
|
case d.found <- client:
|
|
case <-time.After(time.Second):
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// decodeTxtRecord parses a flattened DNS-SD TXT record into a key/value map.
// txt is expected to hold "key=value" strings joined with "|". This presumably
// matches the mdns.ServiceEntry Info field it is called with; TODO confirm.
//
// Parsing follows RFC 6763 §6.4:
//   - The key ends at the FIRST '=', so values may contain '=' (e.g. "rs=a=b"
//     yields rs -> "a=b").
//   - Strings without '=' (boolean attributes) are ignored.
//   - Strings with an empty key are ignored.
//
// If a key repeats, the last value wins. An empty txt yields an empty,
// non-nil map.
func decodeTxtRecord(txt string) map[string]string {
	m := make(map[string]string)

	for _, field := range strings.Split(txt, "|") {
		key, value, found := strings.Cut(field, "=")
		if !found || key == "" {
			continue
		}

		m[key] = value
	}

	return m
}
|
|
|
|
// isIPv4 reports whether ip is an IPv4 address, in either 4-byte form or
// IPv4-mapped 16-byte form (e.g. ::ffff:10.0.0.1).
//
// A nil or malformed IP is not IPv4. The previous colon-counting approach
// classified these as IPv4, because their string forms ("<nil>", "?...")
// contain no colons. It also formatted the IP on every call.
func isIPv4(ip net.IP) bool {
	return ip.To4() != nil
}
|
|
|
|
// isIPv6 reports whether ip is a genuine IPv6 address: a valid 16-byte address
// that is not an IPv4-mapped one.
//
// nil and malformed IPs are neither IPv4 nor IPv6. The structural check avoids
// formatting the IP as a string and counting colons.
func isIPv6(ip net.IP) bool {
	return ip.To4() == nil && ip.To16() != nil
}
|
|
|
|
func findMulticastInterfaces(ctx context.Context) ([]net.Interface, error) {
|
|
ifaces, err := anet.Interfaces()
|
|
if err != nil {
|
|
return nil, nil
|
|
}
|
|
|
|
multicastIfaces := make([]net.Interface, 0)
|
|
|
|
for _, iface := range ifaces {
|
|
if iface.Flags&net.FlagLoopback == net.FlagLoopback {
|
|
continue
|
|
}
|
|
|
|
if iface.Flags&net.FlagRunning != net.FlagRunning {
|
|
continue
|
|
}
|
|
|
|
if iface.Flags&net.FlagMulticast != net.FlagMulticast {
|
|
continue
|
|
}
|
|
|
|
multicastIfaces = append(multicastIfaces, iface)
|
|
}
|
|
|
|
return multicastIfaces, nil
|
|
}
|
|
|
|
func retrieveSupportedProtocols(iface net.Interface) (bool, bool, error) {
|
|
adresses, err := anet.InterfaceAddrsByInterface(&iface)
|
|
if err != nil {
|
|
return false, false, errors.WithStack(err)
|
|
}
|
|
|
|
hasIPv4 := false
|
|
hasIPv6 := false
|
|
|
|
for _, addr := range adresses {
|
|
ip, _, err := net.ParseCIDR(addr.String())
|
|
if err != nil {
|
|
return false, false, errors.WithStack(err)
|
|
}
|
|
|
|
if isIPv4(ip) {
|
|
hasIPv4 = true
|
|
}
|
|
|
|
if isIPv6(ip) {
|
|
hasIPv6 = true
|
|
}
|
|
|
|
if hasIPv4 && hasIPv6 {
|
|
return hasIPv4, hasIPv6, nil
|
|
}
|
|
}
|
|
|
|
return hasIPv4, hasIPv6, nil
|
|
}
|