diff --git a/internal/proxy/director/layer/authn/network/authenticator.go b/internal/proxy/director/layer/authn/network/authenticator.go index f32cd7c..37e619f 100644 --- a/internal/proxy/director/layer/authn/network/authenticator.go +++ b/internal/proxy/director/layer/authn/network/authenticator.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/http" + "strings" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/authn" "forge.cadoles.com/cadoles/bouncer/internal/store" @@ -49,9 +50,15 @@ func (a *Authenticator) Authenticate(w http.ResponseWriter, r *http.Request, lay } func (a *Authenticator) matchAnyAuthorizedCIDRs(ctx context.Context, remoteHostPort string, CIDRs []string) (bool, error) { - remoteHost, _, err := net.SplitHostPort(remoteHostPort) - if err != nil { - return false, errors.WithStack(err) + var remoteHost string + if strings.Contains(remoteHostPort, ":") { + var err error + remoteHost, _, err = net.SplitHostPort(remoteHostPort) + if err != nil { + return false, errors.WithStack(err) + } + } else { + remoteHost = remoteHostPort } remoteAddr := net.ParseIP(remoteHost) diff --git a/internal/proxy/director/layer/authn/network/authenticator_test.go b/internal/proxy/director/layer/authn/network/authenticator_test.go new file mode 100644 index 0000000..395668a --- /dev/null +++ b/internal/proxy/director/layer/authn/network/authenticator_test.go @@ -0,0 +1,60 @@ +package network + +import ( + "context" + "fmt" + "testing" + + "github.com/pkg/errors" +) + +func TestMatchAuthorizedCIDRs(t *testing.T) { + + type testCase struct { + RemoteHostPort string + AuthorizedCIDRs []string + ExpectedResult bool + ExpectedError error + } + + testCases := []testCase{ + { + RemoteHostPort: "192.168.1.15", + AuthorizedCIDRs: []string{ + "192.168.1.0/24", + }, + ExpectedResult: true, + }, + { + RemoteHostPort: "192.168.1.15:43349", + AuthorizedCIDRs: []string{ + "192.168.1.0/24", + }, + ExpectedResult: true, + }, + { + RemoteHostPort: "192.168.1.15:43349", + AuthorizedCIDRs: []string{ + "192.168.1.5/32", + }, + ExpectedResult: false, + }, + } + + auth := Authenticator{} + ctx := context.Background() + + for idx, tc := range testCases { + t.Run(fmt.Sprintf("Case #%d", idx), func(t *testing.T) { + result, err := auth.matchAnyAuthorizedCIDRs(ctx, tc.RemoteHostPort, tc.AuthorizedCIDRs) + + if g, e := result, tc.ExpectedResult; e != g { + t.Errorf("result: expected '%v', got '%v'", e, g) + } + + if e, g := tc.ExpectedError, err; !errors.Is(err, tc.ExpectedError) { + t.Errorf("err: expected '%v', got '%v'", e, g) + } + }) + } +}