// go-captiveportal/arp/table.go
// (155 lines, 2.9 KiB, Go)

package arp
import (
"context"
"net"
"sync"
"time"
"github.com/irai/arp"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// Table maps IP addresses (string form) to the MAC address (string form)
// reported for them by ARP notifications received in Watch. Its methods take
// the internal lock, so they may be called from multiple goroutines. Create
// one with NewTable.
type Table struct {
	// entries holds the current IP -> MAC mappings.
	entries map[string]string

	// onOffline, when non-nil, is called with the MAC address of an entry
	// that delete removed from the table.
	onOffline func(id string)

	// mutex protects entries; delete also reads onOffline while holding it.
	mutex sync.RWMutex
}
type WatchConfig struct {
RouterIP string
RouterNetwork string
HostIP string
NIC string
OfflineDeadline time.Duration
ProbeInterval time.Duration
PurgeDeadline time.Duration
}
// Watch starts an ARP handler on config.NIC and keeps the table in sync with
// its online/offline notifications until ctx is cancelled or the handler's
// ListenAndServe returns.
//
// It returns an error if RouterIP or HostIP is not a valid IPv4 address, if
// RouterNetwork is not valid CIDR, if the interface cannot be found, or if the
// ARP handler fails to be created or to serve.
//
// Watch does not return until the goroutine that applies notifications has
// exited, so no table update (and no onOffline callback triggered by one)
// happens after Watch returns.
func (t *Table) Watch(ctx context.Context, config WatchConfig) error {
	// ARP is IPv4-only. To4 returns nil both for unparsable input and for
	// IPv6 addresses, so reject those here instead of handing nil to arp.New.
	ipRouter := net.ParseIP(config.RouterIP).To4()
	if ipRouter == nil {
		return errors.Errorf("invalid IPv4 router address %q", config.RouterIP)
	}

	ipHost := net.ParseIP(config.HostIP).To4()
	if ipHost == nil {
		return errors.Errorf("invalid IPv4 host address %q", config.HostIP)
	}

	_, lanNetwork, err := net.ParseCIDR(config.RouterNetwork)
	if err != nil {
		return errors.WithStack(err)
	}

	hostMac, err := t.getMACAddress(config.NIC)
	if err != nil {
		return errors.WithStack(err)
	}

	h, err := arp.New(arp.Config{
		HostMAC:         hostMac,
		HostIP:          ipHost,
		HomeLAN:         *lanNetwork,
		NIC:             config.NIC,
		RouterIP:        ipRouter,
		OfflineDeadline: config.OfflineDeadline,
		ProbeInterval:   config.ProbeInterval,
		PurgeDeadline:   config.PurgeDeadline,
	})
	if err != nil {
		return errors.WithStack(err)
	}

	ctx, cancel := context.WithCancel(ctx)

	var wg sync.WaitGroup
	// Deferred calls run last-in first-out: cancel runs first so the
	// notification goroutine observes ctx.Done, then we wait for it to exit.
	defer wg.Wait()
	defer cancel()

	notifications := make(chan arp.MACEntry)
	h.AddNotificationChannel(notifications)

	wg.Add(1)
	go func() {
		defer wg.Done()
		t.handleNotifications(ctx, notifications)
	}()

	if err := h.ListenAndServe(ctx); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// FindMACByIP returns the MAC address currently recorded for ip. When the
// table holds no entry for ip it returns an empty string and an error
// wrapping ErrNotFound.
func (t *Table) FindMACByIP(ip string) (string, error) {
	t.mutex.RLock()
	found, ok := t.entries[ip]
	t.mutex.RUnlock()

	if ok {
		return found, nil
	}
	return "", errors.WithStack(ErrNotFound)
}
// Count reports how many IP -> MAC entries the table currently holds.
func (t *Table) Count() int {
	t.mutex.RLock()
	size := len(t.entries)
	t.mutex.RUnlock()
	return size
}
// handleNotifications applies ARP notifications to the table until ctx is
// cancelled or the notifications channel is closed. An online entry records
// its IP -> MAC mapping; an offline entry removes the mapping for its IP,
// which may trigger the table's onOffline callback.
//
// The parameter is receive-only: callers passing a bidirectional channel are
// unaffected, and this function can never send on or close it.
func (t *Table) handleNotifications(ctx context.Context, notifications <-chan arp.MACEntry) {
	for {
		select {
		case <-ctx.Done():
			return
		case entry, ok := <-notifications:
			if !ok {
				// A closed channel yields zero values forever; stop instead of
				// spinning and treating them as real notifications.
				return
			}

			logger.Debug(ctx, "arp notification", logger.F("entry", entry))

			ip := entry.IP().String()
			if !entry.Online {
				t.delete(ip)
				continue
			}
			t.add(ip, entry.MAC.String())
		}
	}
}
// add records mac as the hardware address for ip, replacing any previous
// mapping. The entries map is allocated on first use so that a zero-value
// Table does not panic here. Reads and deletes already tolerate a nil map.
func (t *Table) add(ip string, mac string) {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	if t.entries == nil {
		t.entries = make(map[string]string)
	}
	t.entries[ip] = mac
}
// delete removes the entry for ip, if any. When an entry was removed and an
// offline callback is registered, the callback is invoked with the removed
// MAC address after the table lock has been released.
//
// Calling the callback outside the lock matters because sync.RWMutex is not
// reentrant. A callback that calls FindMACByIP or Count would otherwise
// deadlock, and a slow callback would block every reader of the table.
func (t *Table) delete(ip string) {
	t.mutex.Lock()
	mac, exists := t.entries[ip]
	delete(t.entries, ip)
	// Snapshot the callback while still holding the lock.
	onOffline := t.onOffline
	t.mutex.Unlock()

	if exists && onOffline != nil {
		onOffline(mac)
	}
}
// getMACAddress searches the host's network interfaces for one named nic and
// returns its hardware address. It returns an error wrapping
// ErrIfaceNotFound when no interface has that name, or the wrapped error from
// net.Interfaces if the interfaces cannot be listed.
func (t *Table) getMACAddress(nic string) (net.HardwareAddr, error) {
	interfaces, err := net.Interfaces()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	for i := range interfaces {
		if interfaces[i].Name != nic {
			continue
		}
		return interfaces[i].HardwareAddr, nil
	}

	return nil, errors.WithStack(ErrIfaceNotFound)
}
// OnOffline registers fn to be called with the MAC address of an entry when
// that entry is removed from the table. It replaces any previously
// registered callback; passing nil disables it.
//
// The assignment is made under the table lock, which is the lock delete holds
// when it reads the callback. This makes it safe to call while Watch is
// running.
func (t *Table) OnOffline(fn func(id string)) {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	t.onOffline = fn
}
// NewTable returns an empty Table with its entries map allocated, ready to be
// passed to Watch.
func NewTable() *Table {
	table := new(Table)
	table.entries = make(map[string]string)
	return table
}