hydra-werther/internal/ldapclient/ldapclient_test.go

589 lines
16 KiB
Go

package ldapclient
import (
"context"
"fmt"
"reflect"
"strings"
"testing"
"time"
"github.com/pkg/errors"
)
// Sentinel errors that a testConn can be configured to return; tests compare
// errors.Cause of the client's result against them.
var (
	errBindUser    = fmt.Errorf("bind user error")
	errSearchUser  = fmt.Errorf("search user error")
	errSearchRoles = fmt.Errorf("search user roles error")
)

// users is the fake LDAP directory served by testConn. testConn looks entries
// up by "dn", checks "pass" in Bind, and returns the optional "roles" list from
// SearchUserRoles. The roles of user4, user5 and user6 are incomplete on
// purpose (missing "dn", missing role attribute, no application OU) for the
// role-skipping cases of TestFindOIDCClaims. The last entry is the service
// account used as BindDN/BindPass.
var users = []map[string]interface{}{
	{"dn": "user1", "pass": "user1", "a": "valA", "b": "valB", "c": "valC"},
	{
		"dn": "user2", "pass": "user2", "a": "valA", "b": "valB", "c": "valC",
		"roles": []map[string]interface{}{
			{"dn": "CN=role1,OU=app1,OU=test,DC=local", "test-roles-attr": "r1"},
			{"dn": "CN=role2,OU=app1,OU=test,DC=local", "test-roles-attr": "r2"},
		},
	},
	{
		"dn": "user3", "pass": "user3", "a": "valA", "b": "valB", "c": "valC",
		"roles": []map[string]interface{}{
			{"dn": "CN=role1,OU=app1,OU=test,DC=local", "test-roles-attr": "r1"},
			{"dn": "CN=role2,OU=app1,OU=test,DC=local", "test-roles-attr": "r2"},
			{"dn": "CN=role3,OU=app2,OU=test,DC=local", "test-roles-attr": "r3"},
			{"dn": "CN=role4,OU=app2,OU=test,DC=local", "test-roles-attr": "r4"},
		},
	},
	{
		"dn": "user4", "pass": "user4", "a": "valA", "b": "valB", "c": "valC",
		"roles": []map[string]interface{}{
			{"dn": "CN=role1,OU=app1,OU=test,DC=local", "test-roles-attr": "r1"},
			{"test-roles-attr": "r2"}, // no "dn"
		},
	},
	{
		"dn": "user5", "pass": "user5", "a": "valA", "b": "valB", "c": "valC",
		"roles": []map[string]interface{}{
			{"dn": "CN=role1,OU=app1,OU=test,DC=local", "test-roles-attr": "r1"},
			{"dn": "CN=role2,OU=app1,OU=test,DC=local"}, // no role attribute
		},
	},
	{
		"dn": "user6", "pass": "user6", "a": "valA", "b": "valB", "c": "valC",
		"roles": []map[string]interface{}{
			{"dn": "CN=role1,OU=test,DC=local", "test-roles-attr": "r1"}, // no application OU
		},
	},
	{"dn": "serviceUser", "pass": "servicePass"},
}
// TestAuthenticate runs Authenticate against a single fake endpoint. Each case
// checks the boolean auth result and, when an error is expected, the root
// cause (errors.Cause) of the returned error.
func TestAuthenticate(t *testing.T) {
	type authCase struct {
		name      string
		connector *testConnector
		bindDN    string
		bindPass  string
		user      string
		pass      string
		wantErr   error
		wantAuth  bool
	}
	cases := []authCase{
		{
			name:      "username is empty",
			connector: newTestConnector("ep1", &testConn{users: users}),
		},
		{
			name:      "password is empty",
			connector: newTestConnector("ep1", &testConn{users: users}),
			user:      "user1",
		},
		{
			name:      "connection timeout",
			connector: newTestConnector("ep1", fmt.Errorf("failed to connect to endpoint")),
			user:      "user1",
			pass:      "user1",
			wantErr:   errConnectionTimeout,
		},
		{
			name:      "search user error",
			connector: newTestConnector("ep1", &testConn{userErr: errSearchUser}),
			user:      "user1",
			pass:      "user1",
			wantErr:   errSearchUser,
		},
		{
			name:      "user is not found",
			connector: newTestConnector("ep1", &testConn{}),
			user:      "user1",
			pass:      "user1",
		},
		{
			name:      "authentication error",
			connector: newTestConnector("ep1", &testConn{users: users, bindErr: errBindUser}),
			user:      "user1",
			pass:      "user1",
			wantErr:   errBindUser,
		},
		{
			name:      "invalid password",
			connector: newTestConnector("ep1", &testConn{users: users}),
			user:      "user1",
			pass:      "invalid",
		},
		{
			name:      "success auth",
			connector: newTestConnector("ep1", &testConn{users: users}),
			user:      "user1",
			pass:      "user1",
			wantAuth:  true,
		},
		{
			name:      "auth with invalid service account",
			connector: newTestConnector("ep1", &testConn{users: users}),
			bindDN:    "serviceUser",
			bindPass:  "invalid",
			user:      "user1",
			pass:      "user1",
			wantErr:   errInvalidCredentials,
		},
		{
			name:      "auth with valid service account",
			connector: newTestConnector("ep1", &testConn{users: users}),
			bindDN:    "serviceUser",
			bindPass:  "servicePass",
			user:      "user1",
			pass:      "user1",
			wantAuth:  true,
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			cli := New(Config{
				Endpoints: tt.connector.Endpoints(),
				BindDN:    tt.bindDN,
				BindPass:  tt.bindPass,
			})
			cli.connector = tt.connector

			authed, err := cli.Authenticate(context.Background(), tt.user, tt.pass)
			if authed != tt.wantAuth {
				t.Errorf("got auth: %t, want auth: %t", authed, tt.wantAuth)
			}

			switch {
			case tt.wantErr == nil:
				if err != nil {
					t.Fatalf("\ngot error:\n\t%s\nwant no errors", err)
				}
			case err == nil:
				t.Fatalf("\ngot no errors\nwant error:\n\t%s", tt.wantErr)
			default:
				// The client may wrap errors; compare the original cause.
				if cause := errors.Cause(err); cause != tt.wantErr {
					t.Fatalf("\ngot error:\n\t%s\nwant error:\n\t%s", cause, tt.wantErr)
				}
			}
		})
	}
}
// TestAuthenticateWhenMultipleEndpointsFailed checks that Authenticate reports
// errConnectionTimeout (as the root cause) when every endpoint fails to connect.
func TestAuthenticateWhenMultipleEndpointsFailed(t *testing.T) {
	failing := newTestConnector("ep1", fmt.Errorf("error"), "ep2", fmt.Errorf("error"))
	cli := New(Config{Endpoints: failing.Endpoints()})
	cli.connector = failing

	_, err := cli.Authenticate(context.Background(), "user1", "user1")
	if err == nil {
		t.Fatalf("\ngot no errors\nwant error:\n\t%s", errConnectionTimeout)
	}
	if cause := errors.Cause(err); cause != errConnectionTimeout {
		t.Fatalf("\ngot error:\n\t%s\nwant error:\n\t%s", cause, errConnectionTimeout)
	}
}
// TestAuthenticateWhenOneEndpointFailedAndOneSuccess checks that a failing
// endpoint does not prevent authentication through a healthy one, and that the
// healthy endpoint actually received the bind.
func TestAuthenticateWhenOneEndpointFailedAndOneSuccess(t *testing.T) {
	healthy := &testConn{users: users}
	mixed := newTestConnector("ep1", fmt.Errorf("error"), "ep2", healthy)
	cli := New(Config{Endpoints: mixed.Endpoints()})
	cli.connector = mixed

	authed, err := cli.Authenticate(context.Background(), "user1", "user1")
	switch {
	case err != nil:
		t.Fatalf("\ngot error:\n\t%s\nwant no errors", err)
	case !authed:
		t.Errorf("got auth: %t, want auth: true", authed)
	}

	if !healthy.authRequest {
		t.Error("\ngot: endpoint \"ep2\" is not called, want: endpoint \"ep2\" is called")
	}
}
// TestAuthenticateWhenMultipleEndpointsSuccess checks that, with several
// reachable endpoints, Authenticate binds through exactly one of them and every
// opened connection ends up closed.
func TestAuthenticateWhenMultipleEndpointsSuccess(t *testing.T) {
	ep1 := &testConn{users: users}
	ep2 := &testConn{users: users}
	connector := newTestConnector("ep1", ep1, "ep2", ep2)
	client := New(Config{Endpoints: connector.Endpoints()})
	client.connector = connector
	ok, err := client.Authenticate(context.Background(), "user1", "user1")

	// Wait for closing all opened LDAP connections.
	// NOTE(review): this assumes the client finishes closing within 100ms, and
	// the testConn probes below are read without synchronization; if the client
	// closes connections from background goroutines this may be flaky or be
	// flagged by -race — confirm against the client implementation.
	time.Sleep(100 * time.Millisecond)

	if err != nil {
		t.Fatalf("\ngot error:\n\t%s\nwant no errors", err)
	}
	if !ok {
		t.Errorf("got auth: %t, want auth: true", ok)
	}

	switch {
	case ep1.authRequest && ep2.authRequest:
		t.Error("got: every endpoint is called, want: only one endpoint is called")
	case !ep1.authRequest && !ep2.authRequest:
		// Reported when neither endpoint received a bind.
		t.Error("got: no endpoint is called, want: only one endpoint is called")
	}

	var notClosed []string
	if !ep1.closed {
		notClosed = append(notClosed, "ep1")
	}
	if !ep2.closed {
		notClosed = append(notClosed, "ep2")
	}
	if len(notClosed) > 0 {
		t.Errorf("got: endpoints %s are not closed, want: all endpoints are closed", strings.Join(notClosed, ", "))
	}
}
// TestFindOIDCClaims checks how FindOIDCClaims maps a user's LDAP attributes to
// OIDC claims (via AttrClaims) and groups the user's roles per application
// under the configured role claim, as well as the errors it reports.
//
// The client is configured with RoleBaseDN "OU=test,DC=local" and RoleAttr
// "test-roles-attr"; the expected roles follow from the users fixture.
func TestFindOIDCClaims(t *testing.T) {
	// roleClaim is shared by the client config and the expectations so that the
	// claim name cannot differ between them.
	const roleClaim = "test-roles-claim"
	// noRoles is the expected role claim when no role of the user is valid.
	// NOTE(review): assumes FindOIDCClaims always sets the role claim, with an
	// empty (or nil) map when nothing survives filtering — confirm.
	noRoles := map[string][]string{}

	testCases := []struct {
		name       string
		connector  *testConnector
		bindDN     string
		bindPass   string
		user       string
		attrClaims map[string]string
		wantErr    error
		want       map[string]interface{}
	}{
		{
			name:      "username is empty",
			connector: newTestConnector("ep1", &testConn{users: users}),
			wantErr:   errMissedUsername,
		},
		{
			name:      "connection timeout",
			connector: newTestConnector("ep1", fmt.Errorf("failed to connect to endpoint")),
			user:      "user1",
			wantErr:   errConnectionTimeout,
		},
		{
			name:      "search user error",
			connector: newTestConnector("ep1", &testConn{userErr: errSearchUser}),
			user:      "user1",
			wantErr:   errSearchUser,
		},
		{
			name:      "user is not found",
			connector: newTestConnector("ep1", &testConn{}),
			user:      "user1",
			wantErr:   errUnknownUsername,
		},
		{
			name:      "search roles error",
			connector: newTestConnector("ep1", &testConn{users: users, rolesErr: errSearchRoles}),
			user:      "user1",
			wantErr:   errSearchRoles,
		},
		{
			name:       "extra attributes is filtered from claims",
			connector:  newTestConnector("ep1", &testConn{users: users}),
			user:       "user1",
			attrClaims: map[string]string{"dn": "name", "a": "claimA", "b": "claimB"},
			want:       map[string]interface{}{"name": "user1", "claimA": "valA", "claimB": "valB", roleClaim: noRoles},
		},
		{
			name:       "skip claim if no attribute",
			connector:  newTestConnector("ep1", &testConn{users: users}),
			user:       "user1",
			attrClaims: map[string]string{"dn": "name", "a": "claimA", "d": "claimD"},
			want:       map[string]interface{}{"name": "user1", "claimA": "valA", roleClaim: noRoles},
		},
		{
			name:       "claims with roles for one application",
			connector:  newTestConnector("ep1", &testConn{users: users}),
			user:       "user2",
			attrClaims: map[string]string{"dn": "name"},
			want:       map[string]interface{}{"name": "user2", roleClaim: map[string][]string{"app1": {"r1", "r2"}}},
		},
		{
			name:       "claims with roles for multiple applications",
			connector:  newTestConnector("ep1", &testConn{users: users}),
			user:       "user3",
			attrClaims: map[string]string{"dn": "name"},
			want:       map[string]interface{}{"name": "user3", roleClaim: map[string][]string{"app1": {"r1", "r2"}, "app2": {"r3", "r4"}}},
		},
		{
			name:       "skip role without DN",
			connector:  newTestConnector("ep1", &testConn{users: users}),
			user:       "user4",
			attrClaims: map[string]string{"dn": "name"},
			want:       map[string]interface{}{"name": "user4", roleClaim: map[string][]string{"app1": {"r1"}}},
		},
		{
			name:       "skip role without role attribute",
			connector:  newTestConnector("ep1", &testConn{users: users}),
			user:       "user5",
			attrClaims: map[string]string{"dn": "name"},
			want:       map[string]interface{}{"name": "user5", roleClaim: map[string][]string{"app1": {"r1"}}},
		},
		{
			// user6's only role sits directly under the role base DN, without an
			// application OU, so it must be dropped.
			name:       "skip invalid role without role base DN",
			connector:  newTestConnector("ep1", &testConn{users: users}),
			user:       "user6",
			attrClaims: map[string]string{"dn": "name"},
			want:       map[string]interface{}{"name": "user6", roleClaim: noRoles},
		},
		{
			name:       "auth with invalid service account",
			connector:  newTestConnector("ep1", &testConn{users: users}),
			bindDN:     "serviceUser",
			bindPass:   "invalid",
			user:       "user1",
			attrClaims: map[string]string{"dn": "name", "a": "claimA", "b": "claimB"},
			wantErr:    errInvalidCredentials,
		},
		{
			name:       "auth with valid service account",
			connector:  newTestConnector("ep1", &testConn{users: users}),
			bindDN:     "serviceUser",
			bindPass:   "servicePass",
			user:       "user1",
			attrClaims: map[string]string{"dn": "name", "a": "claimA", "b": "claimB"},
			want:       map[string]interface{}{"name": "user1", "claimA": "valA", "claimB": "valB", roleClaim: noRoles},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			client := New(Config{
				Endpoints:  tc.connector.Endpoints(),
				BindDN:     tc.bindDN,
				BindPass:   tc.bindPass,
				AttrClaims: tc.attrClaims,
				RoleBaseDN: "OU=test,DC=local",
				RoleClaim:  roleClaim,
				RoleAttr:   "test-roles-attr",
			})
			client.connector = tc.connector
			got, err := client.FindOIDCClaims(context.Background(), tc.user)
			if tc.wantErr != nil {
				if err == nil {
					t.Fatalf("\ngot no errors\nwant error:\n\t%s", tc.wantErr)
				}
				err = errors.Cause(err)
				if err != tc.wantErr {
					t.Fatalf("\ngot error:\n\t%s\nwant error:\n\t%s", err, tc.wantErr)
				}
				return
			}
			if err != nil {
				t.Fatalf("\ngot error:\n\t%s\nwant no errors", err)
			}
			// Fail when the claims DIFFER. Compare %v renderings rather than using
			// reflect.DeepEqual so the test pins claim names and values without
			// depending on the concrete container types the client uses for
			// roles (e.g. map[string][]string vs map[string]interface{}). fmt
			// prints map keys in sorted order, so the rendering is deterministic.
			gotStr, wantStr := fmt.Sprintf("%v", got), fmt.Sprintf("%v", tc.want)
			if gotStr != wantStr {
				t.Errorf("\ngot claims:\n\t%s\nwant claims:\n\t%s", gotStr, wantStr)
			}
		})
	}
}
// TestClaimsCache checks claims caching per user:
//   - a repeated FindOIDCClaims must be answered without reaching LDAP;
//   - a new Authenticate for the same user must make the next lookup query
//     LDAP again.
//
// testConn.claimsRequest is set by every SearchUser call, and Authenticate also
// calls SearchUser. The probe is therefore reset after each Authenticate. Without
// the reset, the "want claims from ldap" checks would pass even when the claims
// were served from the cache.
func TestClaimsCache(t *testing.T) {
	ep := &testConn{users: users}
	connector := newTestConnector("ep", ep)
	client := New(Config{
		Endpoints:  connector.Endpoints(),
		AttrClaims: map[string]string{"dn": "name", "a": "claimA", "d": "claimD"},
		RoleBaseDN: "OU=test,DC=local",
		RoleClaim:  "test-roles-claim",
		RoleAttr:   "test-roles-attr",
	})
	client.connector = connector

	ok, err := client.Authenticate(context.Background(), "user2", "user2")
	if err != nil {
		t.Fatalf("initial auth: unexpected error: %s", err)
	}
	if !ok {
		t.Fatal("initial auth: got no auth, want auth")
	}
	// Forget the SearchUser call made by Authenticate itself.
	ep.claimsRequest = false

	claims1, err := client.FindOIDCClaims(context.Background(), "user2")
	if err != nil {
		t.Fatalf("claims request 1: unexpected error: %s", err)
	}
	if claims1 == nil {
		t.Fatal("claims request 1: got no claims, want claims")
	}
	if !ep.claimsRequest {
		t.Fatal("claims request 1: got claims from cache, want claims from ldap")
	}

	ep.claimsRequest = false
	claims2, err := client.FindOIDCClaims(context.Background(), "user2")
	if err != nil {
		t.Fatalf("claims request 2: unexpected error: %s", err)
	}
	if claims2 == nil {
		t.Fatal("claims request 2: got no claims, want claims")
	}
	if ep.claimsRequest {
		t.Fatal("claims request 2: got claims from ldap, want claims from cache")
	}
	if !reflect.DeepEqual(claims1, claims2) {
		t.Fatalf("claims request 2:\ngot claims:\n\t%v\nwant claims:\n\t%v", claims2, claims1)
	}

	ok, err = client.Authenticate(context.Background(), "user2", "user2")
	if err != nil {
		t.Fatalf("re-auth: unexpected error: %s", err)
	}
	if !ok {
		t.Fatal("re-auth: got no auth, want auth")
	}
	// Again ignore the lookup done by Authenticate; only FindOIDCClaims counts.
	ep.claimsRequest = false

	claims3, err := client.FindOIDCClaims(context.Background(), "user2")
	if err != nil {
		t.Fatalf("claims request 3: unexpected error: %s", err)
	}
	if claims3 == nil {
		t.Fatal("claims request 3: got no claims, want claims")
	}
	if !ep.claimsRequest {
		t.Fatal("claims request 3: got claims from cache, want claims from ldap")
	}
}
// testConnector is a fake connector. Each endpoint address maps either to a
// *testConn (the connection handed out) or to an error (the dial failure).
type testConnector struct {
	conns map[string]interface{}
}

// newTestConnector builds a testConnector from alternating address/outcome
// arguments, e.g. newTestConnector("ep1", &testConn{...}, "ep2", someErr).
// It panics on an odd argument count, on a non-string address, or on an
// outcome that is neither an error nor a *testConn.
func newTestConnector(args ...interface{}) *testConnector {
	const usage = "newTestConnector want args in format \"addr1, conn1, addr2, conn2, addr3, err3\""
	if len(args)%2 == 1 {
		panic(usage)
	}
	conns := make(map[string]interface{}, len(args)/2)
	for i := 0; i < len(args); i += 2 {
		addr, isString := args[i].(string)
		if !isString {
			panic(usage)
		}
		outcome := args[i+1]
		switch outcome.(type) {
		case error, *testConn:
			conns[addr] = outcome
		default:
			panic(usage)
		}
	}
	return &testConnector{conns: conns}
}
// Endpoints lists every registered address in Go's unspecified map iteration
// order, returning nil when no address is registered.
func (c *testConnector) Endpoints() []string {
	if len(c.conns) == 0 {
		return nil
	}
	addrs := make([]string, 0, len(c.conns))
	for a := range c.conns {
		addrs = append(addrs, a)
	}
	return addrs
}
// Connect simulates dialing addr. It returns the error registered for addr as
// the dial failure, or the registered *testConn as the connection. Any other
// configuration, including an unregistered addr, panics.
func (c *testConnector) Connect(ctx context.Context, addr string) (conn, error) {
	target := c.conns[addr]
	if dialErr, failed := target.(error); failed {
		return nil, dialErr
	}
	if fake, isConn := target.(*testConn); isConn {
		return fake, nil
	}
	panic(fmt.Sprintf("Invalid config for endpoint %q", addr))
}
// testConn is a fake LDAP connection backed by a slice of user entries. The
// error fields inject failures into the matching methods. The bool fields
// record which methods were called so tests can assert on them.
// NOTE(review): the probes are read and written without synchronization; if
// the client calls into a testConn from background goroutines, tests that
// inspect them may be flagged by the race detector — confirm.
type testConn struct {
	users []map[string]interface{}

	// Injected failures.
	bindErr  error // returned by Bind
	userErr  error // returned by SearchUser
	rolesErr error // returned by SearchUserRoles

	// Call probes.
	authRequest   bool // set by Bind
	claimsRequest bool // set by SearchUser
	closed        bool // set by Close
}

// Bind records the call and then fails in one of three ways: with bindErr when
// it is set, with a generic error when bindDN is unknown, or with
// errInvalidCredentials when password differs from the entry's "pass" value.
func (c *testConn) Bind(bindDN, password string) error {
	c.authRequest = true
	if err := c.bindErr; err != nil {
		return err
	}
	switch user := c.findUser(bindDN); {
	case user == nil:
		return fmt.Errorf("user is not found")
	case user["pass"] != password:
		return errInvalidCredentials
	}
	return nil
}
// SearchUser records the call and returns the single entry whose "dn" equals
// bindDN, or nil with no error when there is none. attrs is ignored; the whole
// entry is returned. When userErr is set, it is returned instead.
func (c *testConn) SearchUser(bindDN string, attrs ...string) ([]map[string]interface{}, error) {
	c.claimsRequest = true
	if err := c.userErr; err != nil {
		return nil, err
	}
	if entry := c.findUser(bindDN); entry != nil {
		return []map[string]interface{}{entry}, nil
	}
	return nil, nil
}
// SearchUserRoles returns the "roles" list stored on the entry for bindDN.
// The results, in order of precedence:
//   - rolesErr, when it is set;
//   - an error when the user is unknown;
//   - nil when the entry has no roles;
//   - an error when "roles" is not a []map[string]interface{}.
func (c *testConn) SearchUserRoles(bindDN string, attrs ...string) ([]map[string]interface{}, error) {
	if err := c.rolesErr; err != nil {
		return nil, err
	}
	entry := c.findUser(bindDN)
	if entry == nil {
		return nil, fmt.Errorf("user is not found")
	}
	raw := entry["roles"]
	if raw == nil {
		return nil, nil
	}
	roles, valid := raw.([]map[string]interface{})
	if !valid {
		return nil, fmt.Errorf("invalid test roles")
	}
	return roles, nil
}
// findUser returns the first entry whose "dn" equals bindDN, or nil if none does.
func (c *testConn) findUser(bindDN string) map[string]interface{} {
	for i := range c.users {
		if entry := c.users[i]; entry["dn"] == bindDN {
			return entry
		}
	}
	return nil
}
// Close only marks the fake connection as closed so tests can verify that
// the client released it.
func (c *testConn) Close() { c.closed = true }