package mdns

import (
	"fmt"
	"log"
	"net"
	"strings"
	"sync/atomic"
	"time"

	"github.com/miekg/dns"
	"github.com/pkg/errors"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

const (
	// Multicast group addresses and UDP port that queries are sent to.
	ipv4mdns = "224.0.0.251"
	ipv6mdns = "ff02::fb"
	mdnsPort = 5353

	// NOTE(review): not referenced in this file; presumably consumed
	// elsewhere in the package — confirm before removing.
	forceUnicastResponses = false
)

var (
	// ipv4Addr is the IPv4 multicast destination used for queries.
	ipv4Addr = &net.UDPAddr{
		IP:   net.ParseIP(ipv4mdns),
		Port: mdnsPort,
	}
	// ipv6Addr is the IPv6 multicast destination used for queries.
	ipv6Addr = &net.UDPAddr{
		IP:   net.ParseIP(ipv6mdns),
		Port: mdnsPort,
	}
)

// ServiceEntry is returned after we query for a service.
type ServiceEntry struct {
	Name       string   // Name the entry is tracked under (PTR target or record owner name)
	Host       string   // Target of the SRV record
	AddrV4     net.IP   // Address taken from an A record
	AddrV6     net.IP   // Address taken from an AAAA record
	Port       int      // Port taken from the SRV record
	Info       string   // TXT strings joined with "|"
	InfoFields []string // TXT strings as received

	// Deprecated: use AddrV4 or AddrV6 instead. Addr holds whichever A or
	// AAAA address was recorded last.
	Addr net.IP

	hasTXT bool // set once a TXT record has been recorded
	sent   bool // set once the entry has been offered to the Entries channel
}

// complete is used to check if we have all the info we need: at least one
// address, a non-zero port and a TXT record.
func (s *ServiceEntry) complete() bool {
	return (s.AddrV4 != nil || s.AddrV6 != nil || s.Addr != nil) && s.Port != 0 && s.hasTXT
}

// QueryParam is used to customize how a Lookup is performed.
type QueryParam struct {
	Service             string               // Service to lookup
	Domain              string               // Lookup domain, default "local"
	Timeout             time.Duration        // Lookup timeout, default 1 second
	Interface           *net.Interface       // Multicast interface to use
	Entries             chan<- *ServiceEntry // Entries Channel
	WantUnicastResponse bool                 // Unicast response desired, as per 5.4 in RFC
	DisableIPv4         bool                 // Whether to disable usage of IPv4 for MDNS operations. Does not affect discovered addresses.
	DisableIPv6         bool                 // Whether to disable usage of IPv6 for MDNS operations. Does not affect discovered addresses.
}

// DefaultParams is used to return a default set of QueryParam's.
//
// The returned Entries channel is unbuffered and its receive side is not
// exposed, so callers are expected to replace it with their own channel
// (as Lookup does).
func DefaultParams(service string) *QueryParam {
	return &QueryParam{
		Service:             service,
		Domain:              "local",
		Timeout:             time.Second,
		Entries:             make(chan *ServiceEntry),
		WantUnicastResponse: false, // TODO(reddaly): Change this default.
		DisableIPv4:         false,
		DisableIPv6:         false,
	}
}

// Query looks up a given service, in a domain, waiting at most
// for a timeout before finishing the query. The results are streamed
// to a channel. Sends will not block, so clients should make sure to
// either read or buffer.
//
// Empty Domain and zero Timeout fields of params are replaced in place by
// their defaults ("local" and one second). A nil params is rejected with an
// error.
func Query(params *QueryParam) error {
	if params == nil {
		return fmt.Errorf("mdns: nil QueryParam")
	}

	// Create a new client
	client, err := newClient(!params.DisableIPv4, !params.DisableIPv6)
	if err != nil {
		return err
	}
	defer client.Close()

	// Set the multicast interface
	if params.Interface != nil {
		if err := client.setInterface(params.Interface); err != nil {
			return err
		}
	}

	// Ensure defaults are set
	if params.Domain == "" {
		params.Domain = "local"
	}
	if params.Timeout == 0 {
		params.Timeout = time.Second
	}

	// Run the query
	return client.query(params)
}

// Lookup is the same as Query, however it uses all the default parameters.
func Lookup(service string, entries chan<- *ServiceEntry) error {
	params := DefaultParams(service)
	params.Entries = entries
	return Query(params)
}

// client provides a query interface that can be used to
// search for service providers using mDNS.
type client struct {
	use_ipv4 bool
	use_ipv6 bool

	// Any of the connections may be nil when the corresponding bind failed
	// in newClient; every user must check before use.
	ipv4UnicastConn *net.UDPConn
	ipv6UnicastConn *net.UDPConn

	ipv4MulticastConn *net.UDPConn
	ipv6MulticastConn *net.UDPConn

	closed   int32         // 0 while open, 1 after Close; accessed atomically
	closedCh chan struct{} // closed by Close; lets recv abandon a blocked send
}

// newClient creates a new mdns client that can be used to query
// for records. v4 and v6 select which address families to bind; at least
// one must be true. Individual bind failures are logged, and an error is
// returned only when no unicast socket or no multicast socket could be
// bound at all.
func newClient(v4 bool, v6 bool) (*client, error) {
	if !v4 && !v6 {
		return nil, fmt.Errorf("Must enable at least one of IPv4 and IPv6 querying")
	}

	// TODO(reddaly): At least attempt to bind to the port required in the spec.
	var (
		uconn4, uconn6 *net.UDPConn
		mconn4, mconn6 *net.UDPConn
		err            error
	)

	// Unicast sockets on an ephemeral port; these are also used for sending.
	if v4 {
		uconn4, err = net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
		if err != nil {
			log.Printf("[ERR] mdns: Failed to bind to udp4 port: %+v", errors.WithStack(err))
		}
	}
	if v6 {
		uconn6, err = net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
		if err != nil {
			log.Printf("[ERR] mdns: Failed to bind to udp6 port: %+v", errors.WithStack(err))
		}
	}
	if uconn4 == nil && uconn6 == nil {
		return nil, fmt.Errorf("failed to bind to any unicast udp port")
	}

	// Multicast listeners joined to the mDNS groups.
	if v4 {
		mconn4, err = net.ListenMulticastUDP("udp4", nil, ipv4Addr)
		if err != nil {
			log.Printf("[ERR] mdns: Failed to bind to udp4 port: %+v", errors.WithStack(err))
		}
	}
	if v6 {
		mconn6, err = net.ListenMulticastUDP("udp6", nil, ipv6Addr)
		if err != nil {
			log.Printf("[ERR] mdns: Failed to bind to udp6 port: %+v", errors.WithStack(err))
		}
	}
	if mconn4 == nil && mconn6 == nil {
		// Do not leak the unicast sockets that were bound successfully above.
		for _, conn := range []*net.UDPConn{uconn4, uconn6} {
			if conn != nil {
				_ = conn.Close() // best effort; the bind failure is what gets reported
			}
		}
		return nil, fmt.Errorf("failed to bind to any multicast udp port")
	}

	c := &client{
		use_ipv4:          v4,
		use_ipv6:          v6,
		ipv4MulticastConn: mconn4,
		ipv6MulticastConn: mconn6,
		ipv4UnicastConn:   uconn4,
		ipv6UnicastConn:   uconn6,
		closedCh:          make(chan struct{}),
	}
	return c, nil
}

// Close is used to cleanup the client. It is safe to call more than once;
// only the first call closes the sockets. It always returns nil.
func (c *client) Close() error {
	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
		// something else already closed it
		return nil
	}

	close(c.closedCh)

	conns := []*net.UDPConn{
		c.ipv4UnicastConn,
		c.ipv6UnicastConn,
		c.ipv4MulticastConn,
		c.ipv6MulticastConn,
	}
	for _, conn := range conns {
		if conn != nil {
			_ = conn.Close() // best-effort teardown; nothing useful to do on error
		}
	}
	return nil
}

// setInterface is used to set the query interface, uses system
// default if not provided. Connections that failed to bind in newClient
// are nil and are skipped rather than reported as an error.
func (c *client) setInterface(iface *net.Interface) error {
	if c.use_ipv4 {
		for _, conn := range []*net.UDPConn{c.ipv4UnicastConn, c.ipv4MulticastConn} {
			if conn == nil {
				continue
			}
			if err := ipv4.NewPacketConn(conn).SetMulticastInterface(iface); err != nil {
				return err
			}
		}
	}
	if c.use_ipv6 {
		for _, conn := range []*net.UDPConn{c.ipv6UnicastConn, c.ipv6MulticastConn} {
			if conn == nil {
				continue
			}
			if err := ipv6.NewPacketConn(conn).SetMulticastInterface(iface); err != nil {
				return err
			}
		}
	}
	return nil
}

// query is used to perform a lookup and stream results. It sends a PTR
// question for "<Service>.<Domain>.", then assembles ServiceEntry values
// from the PTR/SRV/TXT/A/AAAA records of every response until
// params.Timeout elapses. Each entry is offered to params.Entries at most
// once, with a non-blocking send. The only error returned is a failure to
// send the initial question.
func (c *client) query(params *QueryParam) error {
	// Create the service name
	serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))

	// Start listening for response packets
	msgCh := make(chan *dns.Msg, 32)
	if c.use_ipv4 {
		go c.recv(c.ipv4UnicastConn, msgCh)
		go c.recv(c.ipv4MulticastConn, msgCh)
	}
	if c.use_ipv6 {
		go c.recv(c.ipv6UnicastConn, msgCh)
		go c.recv(c.ipv6MulticastConn, msgCh)
	}

	// Send the query
	m := new(dns.Msg)
	m.SetQuestion(serviceAddr, dns.TypePTR)
	// RFC 6762, section 18.12.  Repurposing of Top Bit of qclass in Question
	// Section
	//
	// In the Question Section of a Multicast DNS query, the top bit of the qclass
	// field is used to indicate that unicast responses are preferred for this
	// particular question.  (See Section 5.4.)
	if params.WantUnicastResponse {
		m.Question[0].Qclass |= 1 << 15
	}
	m.RecursionDesired = false
	if err := c.sendQuery(m); err != nil {
		return err
	}

	// Map the in-progress responses
	inprogress := make(map[string]*ServiceEntry)

	// Listen until we reach the timeout
	finish := time.After(params.Timeout)
	for {
		select {
		case resp := <-msgCh:
			// Entries touched by this packet, in first-seen order. A single
			// packet may carry records for several services, so all of them
			// are examined afterwards, not just the last one.
			var touched []*ServiceEntry
			seen := make(map[*ServiceEntry]struct{})

			// Walk the sections separately instead of appending them, which
			// could write into resp.Answer's backing array.
			for _, section := range [][]dns.RR{resp.Answer, resp.Extra} {
				for _, answer := range section {
					// TODO(reddaly): Check that response corresponds to serviceAddr?
					var inp *ServiceEntry
					switch rr := answer.(type) {
					case *dns.PTR:
						// Create new entry for this
						inp = ensureName(inprogress, rr.Ptr)

					case *dns.SRV:
						// Check for a target mismatch
						if rr.Target != rr.Hdr.Name {
							alias(inprogress, rr.Hdr.Name, rr.Target)
						}

						// Get the port
						inp = ensureName(inprogress, rr.Hdr.Name)
						inp.Host = rr.Target
						inp.Port = int(rr.Port)

					case *dns.TXT:
						// Pull out the txt
						inp = ensureName(inprogress, rr.Hdr.Name)
						inp.Info = strings.Join(rr.Txt, "|")
						inp.InfoFields = rr.Txt
						inp.hasTXT = true

					case *dns.A:
						// Pull out the IP
						inp = ensureName(inprogress, rr.Hdr.Name)
						inp.Addr = rr.A // Deprecated field, kept for compatibility
						inp.AddrV4 = rr.A

					case *dns.AAAA:
						// Pull out the IP
						inp = ensureName(inprogress, rr.Hdr.Name)
						inp.Addr = rr.AAAA // Deprecated field, kept for compatibility
						inp.AddrV6 = rr.AAAA

					default:
						// Record types we do not track.
						continue
					}

					if _, dup := seen[inp]; !dup {
						seen[inp] = struct{}{}
						touched = append(touched, inp)
					}
				}
			}

			for _, inp := range touched {
				// An entry replaced by an alias later in the same packet is
				// no longer reachable through the map; its data was folded
				// into the aliasing entry, so ignore it.
				if inprogress[inp.Name] != inp {
					continue
				}

				if !inp.complete() {
					// Fire off a node specific query
					followUp := new(dns.Msg)
					followUp.SetQuestion(inp.Name, dns.TypePTR)
					followUp.RecursionDesired = false
					if err := c.sendQuery(followUp); err != nil {
						log.Printf("[ERR] mdns: Failed to query instance %s: %+v", inp.Name, errors.WithStack(err))
					}
					continue
				}

				if inp.sent {
					continue
				}
				inp.sent = true
				// Non-blocking by contract: the entry is dropped if the
				// receiver is not ready and the channel has no free buffer.
				select {
				case params.Entries <- inp:
				default:
				}
			}

		case <-finish:
			return nil
		}
	}
}

// sendQuery is used to multicast a query out. The packed message is written
// from each bound unicast socket to the matching multicast group; the first
// pack or write error is returned.
func (c *client) sendQuery(q *dns.Msg) error {
	buf, err := q.Pack()
	if err != nil {
		return err
	}
	if c.ipv4UnicastConn != nil {
		if _, err = c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr); err != nil {
			return err
		}
	}
	if c.ipv6UnicastConn != nil {
		if _, err = c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr); err != nil {
			return err
		}
	}
	return nil
}

// recv is used to receive until we get a shutdown. It reads packets from l,
// unpacks them and forwards the messages on msgCh. A nil l returns
// immediately. Read and unpack errors are logged and the loop continues.
func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) {
	if l == nil {
		return
	}
	buf := make([]byte, 65536)
	for atomic.LoadInt32(&c.closed) == 0 {
		n, err := l.Read(buf)

		// Close sets the flag before closing the sockets, so a read that
		// fails because of shutdown is recognised here and not logged.
		if atomic.LoadInt32(&c.closed) == 1 {
			return
		}

		if err != nil {
			log.Printf("[ERR] mdns: Failed to read packet: %+v", errors.WithStack(err))
			continue
		}
		msg := new(dns.Msg)
		if err := msg.Unpack(buf[:n]); err != nil {
			log.Printf("[ERR] mdns: Failed to unpack packet: %+v", errors.WithStack(err))
			continue
		}
		select {
		case msgCh <- msg:
		case <-c.closedCh:
			// Nobody is draining msgCh any more; do not block forever.
			return
		}
	}
}

// ensureName is used to ensure the named node is in progress. It returns
// the existing entry for name, or creates, stores and returns a new one.
func ensureName(inprogress map[string]*ServiceEntry, name string) *ServiceEntry {
	if inp, ok := inprogress[name]; ok {
		return inp
	}
	inp := &ServiceEntry{
		Name: name,
	}
	inprogress[name] = inp
	return inp
}

// alias is used to setup an alias between two entries: after the call,
// lookups of dst resolve to the entry for src. Addresses already recorded
// under dst (for example from A/AAAA records that preceded the SRV record)
// are carried over to the src entry when it has none of its own, so they
// are not lost when dst is repointed.
func alias(inprogress map[string]*ServiceEntry, src, dst string) {
	srcEntry := ensureName(inprogress, src)
	if prev, ok := inprogress[dst]; ok && prev != srcEntry {
		if srcEntry.AddrV4 == nil {
			srcEntry.AddrV4 = prev.AddrV4
		}
		if srcEntry.AddrV6 == nil {
			srcEntry.AddrV6 = prev.AddrV6
		}
		if srcEntry.Addr == nil {
			srcEntry.Addr = prev.Addr
		}
	}
	inprogress[dst] = srcEntry
}

// trimDot is used to trim the dots from the start or end of a string.
func trimDot(s string) string {
	return strings.Trim(s, ".")
}