feat: captive portal middleware with ARP table based client identifier
This commit is contained in:
8
arp/error.go
Normal file
8
arp/error.go
Normal file
@ -0,0 +1,8 @@
|
||||
package arp
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNotFound = errors.New("not found")
|
||||
ErrIfaceNotFound = errors.New("iface not found")
|
||||
)
|
42
arp/identifier.go
Normal file
42
arp/identifier.go
Normal file
@ -0,0 +1,42 @@
|
||||
package arp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Identifier struct {
|
||||
arp *Table
|
||||
}
|
||||
|
||||
func (i *Identifier) Watch(ctx context.Context, config WatchConfig) error {
|
||||
return i.arp.Watch(ctx, config)
|
||||
}
|
||||
|
||||
func (i *Identifier) Identify(r *http.Request) (string, error) {
|
||||
ip := strings.SplitN(r.RemoteAddr, ":", 2)[0]
|
||||
|
||||
mac, err := i.arp.FindMACByIP(ip)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
return mac, nil
|
||||
}
|
||||
|
||||
func (i *Identifier) OnOffline(fn func(id string)) {
|
||||
i.arp.OnOffline(fn)
|
||||
}
|
||||
|
||||
func NewIdentifier() *Identifier {
|
||||
return &Identifier{
|
||||
arp: NewTable(),
|
||||
}
|
||||
}
|
153
arp/table.go
Normal file
153
arp/table.go
Normal file
@ -0,0 +1,153 @@
|
||||
package arp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/irai/arp"
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
mutex sync.RWMutex
|
||||
onOffline func(id string)
|
||||
entries map[string]string
|
||||
}
|
||||
|
||||
type WatchConfig struct {
|
||||
RouterIP string
|
||||
RouterNetwork string
|
||||
HostIP string
|
||||
NIC string
|
||||
OfflineDeadline time.Duration
|
||||
ProbeInterval time.Duration
|
||||
PurgeDeadline time.Duration
|
||||
}
|
||||
|
||||
func (t *Table) Watch(ctx context.Context, config WatchConfig) error {
|
||||
ipRouter := net.ParseIP(config.RouterIP).To4()
|
||||
ipHost := net.ParseIP(config.HostIP).To4()
|
||||
|
||||
_, lanNetwork, err := net.ParseCIDR(config.RouterNetwork)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
hostMac, err := t.getMACAddress(config.NIC)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
h, err := arp.New(arp.Config{
|
||||
HostMAC: hostMac,
|
||||
HostIP: ipHost,
|
||||
HomeLAN: *lanNetwork,
|
||||
NIC: config.NIC,
|
||||
RouterIP: ipRouter,
|
||||
OfflineDeadline: config.OfflineDeadline,
|
||||
ProbeInterval: config.ProbeInterval,
|
||||
PurgeDeadline: config.PurgeDeadline,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := h.ListenAndServe(ctx); err != nil {
|
||||
panic(errors.WithStack(err))
|
||||
}
|
||||
}()
|
||||
|
||||
notifications := make(chan arp.MACEntry)
|
||||
h.AddNotificationChannel(notifications)
|
||||
|
||||
go t.handleNotifications(ctx, notifications)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Table) FindMACByIP(ip string) (string, error) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
mac, exists := t.entries[ip]
|
||||
if !exists {
|
||||
return "", errors.WithStack(ErrNotFound)
|
||||
}
|
||||
|
||||
return mac, nil
|
||||
}
|
||||
|
||||
func (t *Table) Count() int {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
return len(t.entries)
|
||||
}
|
||||
|
||||
func (t *Table) handleNotifications(ctx context.Context, notifications chan arp.MACEntry) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case entry := <-notifications:
|
||||
logger.Debug(ctx, "arp notification", logger.F("entry", entry))
|
||||
|
||||
ip := entry.IP().String()
|
||||
|
||||
if entry.Online {
|
||||
t.add(ip, entry.MAC.String())
|
||||
} else {
|
||||
t.delete(ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Table) add(ip string, mac string) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
t.entries[ip] = mac
|
||||
}
|
||||
|
||||
func (t *Table) delete(ip string) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
mac, exists := t.entries[ip]
|
||||
|
||||
delete(t.entries, ip)
|
||||
|
||||
if exists && t.onOffline != nil {
|
||||
t.onOffline(mac)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Table) getMACAddress(nic string) (net.HardwareAddr, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
if iface.Name == nic {
|
||||
return iface.HardwareAddr, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.WithStack(ErrIfaceNotFound)
|
||||
}
|
||||
|
||||
func (t *Table) OnOffline(fn func(id string)) {
|
||||
t.onOffline = fn
|
||||
}
|
||||
|
||||
func NewTable() *Table {
|
||||
return &Table{
|
||||
entries: make(map[string]string),
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user