move to GitHub.

This commit is contained in:
Nikolay Stupak
2019-05-24 16:13:15 +03:00
parent d761ad579a
commit 3bbac7bb74
28 changed files with 1840 additions and 336 deletions

View File

@ -1,8 +1,8 @@
/*
Copyright (C) JSC iCore - All Rights Reserved
Copyright (c) JSC iCore.
Unauthorized copying of this file, via any medium is strictly prohibited
Proprietary and confidential
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*/
package ldapclient
@ -17,22 +17,44 @@ import (
"time"
"github.com/coocood/freecache"
"github.com/i-core/rlog"
"github.com/pkg/errors"
"go.uber.org/zap"
"gopkg.i-core.ru/logutil"
ldap "gopkg.in/ldap.v2"
)
var (
// errInvalidCredentials is an error that happens when a user's password is invalid.
errInvalidCredentials = fmt.Errorf("invalid credentials")
// errConnectionTimeout is an error that happens when no one LDAP endpoint responds.
errConnectionTimeout = fmt.Errorf("connection timeout")
// errMissedUsername is an error that happens
errMissedUsername = errors.New("username is missed")
// errUnknownUsername is an error that happens
errUnknownUsername = errors.New("unknown username")
)
type conn interface {
Bind(bindDN, password string) error
SearchUser(user string, attrs ...string) ([]map[string]interface{}, error)
SearchUserRoles(user string, attrs ...string) ([]map[string]interface{}, error)
Close()
}
type connector interface {
Connect(ctx context.Context, addr string) (conn, error)
}
// Config is a LDAP configuration.
type Config struct {
Endpoints []string `envconfig:"endpoints" required:"true" desc:"a LDAP's server URLs as \"<address>:<port>\""`
BaseDN string `envconfig:"basedn" required:"true" desc:"a LDAP base DN for searching users"`
BindDN string `envconfig:"binddn" desc:"a LDAP bind DN"`
BindPass string `envconfig:"bindpw" json:"-" desc:"a LDAP bind password"`
BaseDN string `envconfig:"basedn" required:"true" desc:"a LDAP base DN for searching users"`
AttrClaims map[string]string `envconfig:"attr_claims" default:"name:name,sn:family_name,givenName:given_name,mail:email" desc:"a mapping of LDAP attributes to OpenID connect claims"`
RoleBaseDN string `envconfig:"role_basedn" required:"true" desc:"a LDAP base DN for searching roles"`
RoleAttr string `envconfig:"role_attr" default:"description" desc:"a LDAP attribute for role's name"`
RoleClaim string `ignored:"true"` // is custom OIDC claim name for roles' list
AttrClaims map[string]string `envconfig:"attr_claims" default:"name:name,sn:family_name,givenName:given_name,mail:email" desc:"a mapping of LDAP attributes to OIDC claims"`
RoleAttr string `envconfig:"role_attr" default:"description" desc:"a LDAP group's attribute that contains a role's name"`
RoleClaim string `envconfig:"role_claim" default:"https://github.com/i-core/werther/claims/roles" desc:"a name of an OpenID Connect claim that contains user roles"`
CacheSize int `envconfig:"cache_size" default:"512" desc:"a user info cache's size in KiB"`
CacheTTL time.Duration `envconfig:"cache_ttl" default:"30m" desc:"a user info cache TTL"`
}
@ -40,17 +62,16 @@ type Config struct {
// Client is a LDAP client (compatible with Active Directory).
type Client struct {
Config
cache *freecache.Cache
connector connector
cache *freecache.Cache
}
// New creates a new LDAP client.
func New(cnf Config) *Client {
if cnf.RoleClaim == "" {
cnf.RoleClaim = "http://i-core.ru/claims/roles"
}
return &Client{
Config: cnf,
cache: freecache.NewCache(cnf.CacheSize * 1024),
Config: cnf,
connector: &ldapConnector{BaseDN: cnf.BaseDN, RoleBaseDN: cnf.RoleBaseDN},
cache: freecache.NewCache(cnf.CacheSize * 1024),
}
}
@ -64,10 +85,10 @@ func (cli *Client) Authenticate(ctx context.Context, username, password string)
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
cn, ok := <-cli.dialTCP(ctx)
cn, ok := <-cli.connect(ctx)
cancel()
if !ok {
return false, errors.New("connection timeout")
return false, errConnectionTimeout
}
defer cn.Close()
@ -81,7 +102,7 @@ func (cli *Client) Authenticate(ctx context.Context, username, password string)
}
if err := cn.Bind(details["dn"].(string), password); err != nil {
if ldapErr, ok := err.(*ldap.Error); ok && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials {
if err == errInvalidCredentials {
return false, nil
}
return false, err
@ -89,85 +110,20 @@ func (cli *Client) Authenticate(ctx context.Context, username, password string)
// Clear the claims' cache because of possible re-authentication. We don't want stale claims after re-login.
if ok := cli.cache.Del([]byte(username)); ok {
log := logutil.FromContext(ctx)
log := rlog.FromContext(ctx)
log.Debug("Cleared user's OIDC claims in the cache")
}
return true, nil
}
func (cli *Client) dialTCP(ctx context.Context) <-chan *ldap.Conn {
var (
wg sync.WaitGroup
ch = make(chan *ldap.Conn)
)
wg.Add(len(cli.Endpoints))
for _, addr := range cli.Endpoints {
go func(addr string) {
defer wg.Done()
log := logutil.FromContext(ctx).Sugar()
d := net.Dialer{Timeout: ldap.DefaultTimeout}
tcpcn, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
log.Debug("Failed to create a LDAP connection", "address", addr)
return
}
ldapcn := ldap.NewConn(tcpcn, false)
ldapcn.Start()
select {
case <-ctx.Done():
ldapcn.Close()
log.Debug("a LDAP connection is cancelled", "address", addr)
return
case ch <- ldapcn:
}
}(addr)
}
go func() {
wg.Wait()
close(ch)
}()
return ch
}
// findBasicUserDetails finds user's LDAP attributes that were specified. It returns nil if no such user.
func (cli *Client) findBasicUserDetails(cn *ldap.Conn, username string, attrs []string) (map[string]interface{}, error) {
if cli.BindDN != "" {
// We need to login to a LDAP server with a service account for retrieving user data.
if err := cn.Bind(cli.BindDN, cli.BindPass); err != nil {
return nil, err
}
}
query := fmt.Sprintf(
"(&(|(objectClass=organizationalPerson)(objectClass=inetOrgPerson))"+
"(|(uid=%[1]s)(mail=%[1]s)(userPrincipalName=%[1]s)(sAMAccountName=%[1]s)))", username)
entries, err := cli.searchEntries(cn, cli.BaseDN, query, attrs...)
if err != nil {
return nil, err
}
if len(entries) != 1 {
// We didn't find the user.
return nil, nil
}
var (
entry = entries[0]
details = make(map[string]interface{})
)
for _, attr := range attrs {
if v, ok := entry[attr]; ok {
details[attr] = v
}
}
return details, nil
}
// FindOIDCClaims finds all OIDC claims for a user.
func (cli *Client) FindOIDCClaims(ctx context.Context, username string) (map[string]interface{}, error) {
log := logutil.FromContext(ctx).Sugar()
if username == "" {
return nil, errMissedUsername
}
log := rlog.FromContext(ctx).Sugar()
// Retrieving from LDAP is slow. So, we try to get claims for the given username from the cache.
switch cdata, err := cli.cache.Get([]byte(username)); err {
@ -190,10 +146,10 @@ func (cli *Client) FindOIDCClaims(ctx context.Context, username string) (map[str
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
cn, ok := <-cli.dialTCP(ctx)
cn, ok := <-cli.connect(ctx)
cancel()
if !ok {
return nil, errors.New("connection timeout")
return nil, errConnectionTimeout
}
defer cn.Close()
@ -208,11 +164,11 @@ func (cli *Client) FindOIDCClaims(ctx context.Context, username string) (map[str
return nil, err
}
if details == nil {
return nil, errors.New("unknown username")
return nil, errUnknownUsername
}
log.Infow("Retrieved user's info from LDAP", "details", details)
// Transform the retrived attributes to corresponding claims.
// Transform the retrieved attributes to corresponding claims.
claims := make(map[string]interface{})
for attr, v := range details {
if claim, ok := cli.AttrClaims[attr]; ok {
@ -222,16 +178,15 @@ func (cli *Client) FindOIDCClaims(ctx context.Context, username string) (map[str
// User's roles is stored in LDAP as groups. We find all groups in a role's DN
// that include the user as a member.
query := fmt.Sprintf("(&(objectClass=group)(member=%s))", details["dn"])
entries, err := cli.searchEntries(cn, cli.RoleBaseDN, query, "dn", cli.RoleAttr)
entries, err := cn.SearchUserRoles(fmt.Sprintf("%s", details["dn"]), "dn", cli.RoleAttr)
if err != nil {
return nil, err
}
roles := make(map[string][]string)
roles := make(map[string]interface{})
for _, entry := range entries {
roleDN := entry["dn"].(string)
if roleDN == "" {
roleDN, ok := entry["dn"].(string)
if !ok || roleDN == "" {
log.Infow("No required LDAP attribute for a role", "ldapAttribute", "dn", "entry", entry)
continue
}
@ -248,14 +203,19 @@ func (cli *Client) FindOIDCClaims(ctx context.Context, username string) (map[str
}
// The DN without the role's base DN must contain a CN and OU
// where the CN is for uniqueness only, and the OU is an application id.
v := strings.Split(roleDN[:n-k-1], ",")
if len(v) != 2 {
path := strings.Split(roleDN[:n-k-1], ",")
if len(path) != 2 {
log.Infow("A role's DN without the role's base DN must contain two nodes only",
"roleBaseDN", cli.RoleBaseDN, "roleDN", roleDN)
continue
}
appID := v[1][len("OU="):]
roles[appID] = append(roles[appID], entry[cli.RoleAttr].(string))
appID := path[1][len("OU="):]
var appRoles []interface{}
if v := roles[appID]; v != nil {
appRoles = v.([]interface{})
}
roles[appID] = append(appRoles, entry[cli.RoleAttr])
}
claims[cli.RoleClaim] = roles
@ -271,11 +231,114 @@ func (cli *Client) FindOIDCClaims(ctx context.Context, username string) (map[str
return claims, nil
}
func (cli *Client) connect(ctx context.Context) <-chan conn {
var (
wg sync.WaitGroup
ch = make(chan conn)
)
wg.Add(len(cli.Endpoints))
for _, addr := range cli.Endpoints {
go func(addr string) {
defer wg.Done()
log := rlog.FromContext(ctx).Sugar()
cn, err := cli.connector.Connect(ctx, addr)
if err != nil {
log.Debug("Failed to create a LDAP connection", "address", addr)
return
}
select {
case <-ctx.Done():
cn.Close()
log.Debug("a LDAP connection is cancelled", "address", addr)
return
case ch <- cn:
}
}(addr)
}
go func() {
wg.Wait()
close(ch)
}()
return ch
}
// findBasicUserDetails finds user's LDAP attributes that were specified. It returns nil if no such user.
func (cli *Client) findBasicUserDetails(cn conn, username string, attrs []string) (map[string]interface{}, error) {
if cli.BindDN != "" {
// We need to login to a LDAP server with a service account for retrieving user data.
if err := cn.Bind(cli.BindDN, cli.BindPass); err != nil {
return nil, errors.Wrap(err, "failed to login to a LDAP woth a service account")
}
}
entries, err := cn.SearchUser(username, attrs...)
if err != nil {
return nil, err
}
if len(entries) != 1 {
// We didn't find the user.
return nil, nil
}
var (
entry = entries[0]
details = make(map[string]interface{})
)
for _, attr := range attrs {
if v, ok := entry[attr]; ok {
details[attr] = v
}
}
return details, nil
}
type ldapConnector struct {
BaseDN string
RoleBaseDN string
}
func (c *ldapConnector) Connect(ctx context.Context, addr string) (conn, error) {
d := net.Dialer{Timeout: ldap.DefaultTimeout}
tcpcn, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
ldapcn := ldap.NewConn(tcpcn, false)
ldapcn.Start()
return &ldapConn{Conn: ldapcn, BaseDN: c.BaseDN, RoleBaseDN: c.RoleBaseDN}, nil
}
type ldapConn struct {
*ldap.Conn
BaseDN string
RoleBaseDN string
}
func (c *ldapConn) Bind(bindDN, password string) error {
err := c.Conn.Bind(bindDN, password)
if ldapErr, ok := err.(*ldap.Error); ok && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials {
return errInvalidCredentials
}
return err
}
func (c *ldapConn) SearchUser(user string, attrs ...string) ([]map[string]interface{}, error) {
query := fmt.Sprintf(
"(&(|(objectClass=organizationalPerson)(objectClass=inetOrgPerson))"+
"(|(uid=%[1]s)(mail=%[1]s)(userPrincipalName=%[1]s)(sAMAccountName=%[1]s)))", user)
return c.searchEntries(c.BaseDN, query, attrs)
}
func (c *ldapConn) SearchUserRoles(user string, attrs ...string) ([]map[string]interface{}, error) {
query := fmt.Sprintf("(&(objectClass=group)(member=%s))", user)
return c.searchEntries(c.RoleBaseDN, query, attrs)
}
// searchEntries executes a LDAP query, and returns a result as entries where each entry is mapping of LDAP attributes.
func (cli *Client) searchEntries(cn *ldap.Conn, baseDN, query string, attrs ...string) ([]map[string]interface{}, error) {
res, err := cn.Search(ldap.NewSearchRequest(
baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, query, attrs, nil,
))
func (c *ldapConn) searchEntries(baseDN, query string, attrs []string) ([]map[string]interface{}, error) {
req := ldap.NewSearchRequest(baseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, query, attrs, nil)
res, err := c.Search(req)
if err != nil {
return nil, err
}

View File

@ -0,0 +1,588 @@
package ldapclient
import (
"context"
"fmt"
"reflect"
"strings"
"testing"
"time"
"github.com/pkg/errors"
)
var (
errBindUser = fmt.Errorf("bind user error")
errSearchUser = fmt.Errorf("search user error")
errSearchRoles = fmt.Errorf("search user roles error")
users = []map[string]interface{}{
{
"dn": "user1",
"pass": "user1",
"a": "valA",
"b": "valB",
"c": "valC",
},
{
"dn": "user2",
"pass": "user2",
"a": "valA",
"b": "valB",
"c": "valC",
"roles": []map[string]interface{}{
{"dn": "CN=role1,OU=app1,OU=test,DC=local", "test-roles-attr": "r1"},
{"dn": "CN=role2,OU=app1,OU=test,DC=local", "test-roles-attr": "r2"},
},
},
{
"dn": "user3",
"pass": "user3",
"a": "valA",
"b": "valB",
"c": "valC",
"roles": []map[string]interface{}{
{"dn": "CN=role1,OU=app1,OU=test,DC=local", "test-roles-attr": "r1"},
{"dn": "CN=role2,OU=app1,OU=test,DC=local", "test-roles-attr": "r2"},
{"dn": "CN=role3,OU=app2,OU=test,DC=local", "test-roles-attr": "r3"},
{"dn": "CN=role4,OU=app2,OU=test,DC=local", "test-roles-attr": "r4"},
},
},
{
"dn": "user4",
"pass": "user4",
"a": "valA",
"b": "valB",
"c": "valC",
"roles": []map[string]interface{}{
{"dn": "CN=role1,OU=app1,OU=test,DC=local", "test-roles-attr": "r1"},
{"test-roles-attr": "r2"},
},
},
{
"dn": "user5",
"pass": "user5",
"a": "valA",
"b": "valB",
"c": "valC",
"roles": []map[string]interface{}{
{"dn": "CN=role1,OU=app1,OU=test,DC=local", "test-roles-attr": "r1"},
{"dn": "CN=role2,OU=app1,OU=test,DC=local"},
},
},
{
"dn": "user6",
"pass": "user6",
"a": "valA",
"b": "valB",
"c": "valC",
"roles": []map[string]interface{}{
{"dn": "CN=role1,OU=test,DC=local", "test-roles-attr": "r1"},
},
},
{
"dn": "serviceUser",
"pass": "servicePass",
},
}
)
func TestAuthenticate(t *testing.T) {
testCases := []struct {
name string
connector *testConnector
bindDN string
bindPass string
user string
pass string
wantErr error
wantAuth bool
}{
{
name: "username is empty",
connector: newTestConnector("ep1", &testConn{users: users}),
},
{
name: "password is empty",
connector: newTestConnector("ep1", &testConn{users: users}),
user: "user1",
},
{
name: "connection timeout",
connector: newTestConnector("ep1", fmt.Errorf("failed to connect to endpoint")),
user: "user1",
pass: "user1",
wantErr: errConnectionTimeout,
},
{
name: "search user error",
connector: newTestConnector("ep1", &testConn{userErr: errSearchUser}),
user: "user1",
pass: "user1",
wantErr: errSearchUser,
},
{
name: "user is not found",
connector: newTestConnector("ep1", &testConn{}),
user: "user1",
pass: "user1",
},
{
name: "authentication error",
connector: newTestConnector("ep1", &testConn{users: users, bindErr: errBindUser}),
user: "user1",
pass: "user1",
wantErr: errBindUser,
},
{
name: "invalid password",
connector: newTestConnector("ep1", &testConn{users: users}),
user: "user1",
pass: "invalid",
},
{
name: "success auth",
connector: newTestConnector("ep1", &testConn{users: users}),
user: "user1",
pass: "user1",
wantAuth: true,
},
{
name: "auth with invalid service account",
connector: newTestConnector("ep1", &testConn{users: users}),
bindDN: "serviceUser",
bindPass: "invalid",
user: "user1",
pass: "user1",
wantErr: errInvalidCredentials,
},
{
name: "auth with valid service account",
connector: newTestConnector("ep1", &testConn{users: users}),
bindDN: "serviceUser",
bindPass: "servicePass",
user: "user1",
pass: "user1",
wantAuth: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
client := New(Config{Endpoints: tc.connector.Endpoints(), BindDN: tc.bindDN, BindPass: tc.bindPass})
client.connector = tc.connector
ok, err := client.Authenticate(context.Background(), tc.user, tc.pass)
if ok != tc.wantAuth {
t.Errorf("got auth: %t, want auth: %t", ok, tc.wantAuth)
}
if tc.wantErr != nil {
if err == nil {
t.Fatalf("\ngot no errors\nwant error:\n\t%s", tc.wantErr)
}
err = errors.Cause(err)
if err != tc.wantErr {
t.Fatalf("\ngot error:\n\t%s\nwant error:\n\t%s", err, tc.wantErr)
}
return
}
if err != nil {
t.Fatalf("\ngot error:\n\t%s\nwant no errors", err)
}
})
}
}
func TestAuthenticateWhenMultipleEndpointsFailed(t *testing.T) {
connector := newTestConnector("ep1", fmt.Errorf("error"), "ep2", fmt.Errorf("error"))
client := New(Config{Endpoints: connector.Endpoints()})
client.connector = connector
_, err := client.Authenticate(context.Background(), "user1", "user1")
if err == nil {
t.Fatalf("\ngot no errors\nwant error:\n\t%s", errConnectionTimeout)
}
err = errors.Cause(err)
if err != errConnectionTimeout {
t.Fatalf("\ngot error:\n\t%s\nwant error:\n\t%s", err, errConnectionTimeout)
}
}
func TestAuthenticateWhenOneEndpointFailedAndOneSuccess(t *testing.T) {
ep2 := &testConn{users: users}
connector := newTestConnector("ep1", fmt.Errorf("error"), "ep2", ep2)
client := New(Config{Endpoints: connector.Endpoints()})
client.connector = connector
ok, err := client.Authenticate(context.Background(), "user1", "user1")
if err != nil {
t.Fatalf("\ngot error:\n\t%s\nwant no errors", err)
}
if !ok {
t.Errorf("got auth: %t, want auth: true", ok)
}
if !ep2.authRequest {
t.Error("\ngot: endpoint \"ep2\" is not called, want: endpoint \"ep2\" is called")
}
}
func TestAuthenticateWhenMultipleEndpointsSuccess(t *testing.T) {
ep1 := &testConn{users: users}
ep2 := &testConn{users: users}
connector := newTestConnector("ep1", ep1, "ep2", ep2)
client := New(Config{Endpoints: connector.Endpoints()})
client.connector = connector
ok, err := client.Authenticate(context.Background(), "user1", "user1")
// Wait for closing all opened LDAP connections.
time.Sleep(100 * time.Millisecond)
if err != nil {
t.Fatalf("\ngot error:\n\t%s\nwant no errors", err)
}
if !ok {
t.Errorf("got auth: %t, want auth: true", ok)
}
switch {
case ep1.authRequest && ep2.authRequest:
t.Error("got: every endpoint is called, want: only one endpoint is called")
case !ep1.authRequest && !ep2.authRequest:
t.Error("got: no one endpoint is not called, want: only one endpoint is called")
}
var notClosed []string
if !ep1.closed {
notClosed = append(notClosed, "ep1")
}
if !ep2.closed {
notClosed = append(notClosed, "ep2")
}
if len(notClosed) > 0 {
t.Errorf("got: endpoints %s are not closed, want: all endpoints are closed", strings.Join(notClosed, ", "))
}
}
func TestFindOIDCClaims(t *testing.T) {
testCases := []struct {
name string
connector *testConnector
bindDN string
bindPass string
user string
attrClaims map[string]string
wantErr error
want map[string]interface{}
}{
{
name: "username is empty",
connector: newTestConnector("ep1", &testConn{users: users}),
wantErr: errMissedUsername,
},
{
name: "connection timeout",
connector: newTestConnector("ep1", fmt.Errorf("failed to connect to endpoint")),
user: "user1",
wantErr: errConnectionTimeout,
},
{
name: "search user error",
connector: newTestConnector("ep1", &testConn{userErr: errSearchUser}),
user: "user1",
wantErr: errSearchUser,
},
{
name: "user is not found",
connector: newTestConnector("ep1", &testConn{}),
user: "user1",
wantErr: errUnknownUsername,
},
{
name: "search roles error",
connector: newTestConnector("ep1", &testConn{users: users, rolesErr: errSearchRoles}),
user: "user1",
wantErr: errSearchRoles,
},
{
name: "extra attributes is filtered from claims",
connector: newTestConnector("ep1", &testConn{users: users}),
user: "user1",
attrClaims: map[string]string{"dn": "name", "a": "claimA", "b": "claimB"},
want: map[string]interface{}{"name": "user1", "claimA": "valA", "claimB": "valB", "roles": nil},
},
{
name: "skip claim if no attribute",
connector: newTestConnector("ep1", &testConn{users: users}),
user: "user1",
attrClaims: map[string]string{"dn": "name", "a": "claimA", "d": "claimD"},
want: map[string]interface{}{"name": "user1", "claimA": "valA", "roles": nil},
},
{
name: "claims with roles for one application",
connector: newTestConnector("ep1", &testConn{users: users}),
user: "user2",
attrClaims: map[string]string{"dn": "name"},
want: map[string]interface{}{"name": "user1", "test-roles-claim": map[string][]string{"app1": {"r1", "r2"}}},
},
{
name: "claims with roles for multiple applications",
connector: newTestConnector("ep1", &testConn{users: users}),
user: "user3",
attrClaims: map[string]string{"dn": "name"},
want: map[string]interface{}{"name": "user1", "test-roles-claim": map[string][]string{"app1": {"r1", "r2"}, "app2": {"r3", "r4"}}},
},
{
name: "skip role without DN",
connector: newTestConnector("ep1", &testConn{users: users}),
user: "user4",
attrClaims: map[string]string{"dn": "name"},
want: map[string]interface{}{"name": "user1", "roles": map[string][]string{"app1": {"r1"}}},
},
{
name: "skip role without role attribute",
connector: newTestConnector("ep1", &testConn{users: users}),
user: "user5",
attrClaims: map[string]string{"dn": "name"},
want: map[string]interface{}{"name": "user1", "roles": map[string][]string{"app1": {"r1"}}},
},
{
name: "skip invalid role without role base DN",
connector: newTestConnector("ep1", &testConn{users: users}),
user: "user6",
attrClaims: map[string]string{"dn": "name"},
want: map[string]interface{}{"name": "user1", "roles": map[string][]string{"app1": {"r1"}}},
},
{
name: "auth with invalid service account",
connector: newTestConnector("ep1", &testConn{users: users}),
bindDN: "serviceUser",
bindPass: "invalid",
user: "user1",
attrClaims: map[string]string{"dn": "name", "a": "claimA", "b": "claimB"},
wantErr: errInvalidCredentials,
},
{
name: "auth with valid service account",
connector: newTestConnector("ep1", &testConn{users: users}),
bindDN: "serviceUser",
bindPass: "servicePass",
user: "user1",
attrClaims: map[string]string{"dn": "name", "a": "claimA", "b": "claimB"},
want: map[string]interface{}{"name": "user1", "claimA": "valA", "claimB": "valB", "roles": nil},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
client := New(Config{
Endpoints: tc.connector.Endpoints(),
BindDN: tc.bindDN,
BindPass: tc.bindPass,
AttrClaims: tc.attrClaims,
RoleBaseDN: "OU=test,DC=local",
RoleClaim: "test-roles-claim",
RoleAttr: "test-roles-attr",
})
client.connector = tc.connector
got, err := client.FindOIDCClaims(context.Background(), tc.user)
if tc.wantErr != nil {
if err == nil {
t.Fatalf("\ngot no errors\nwant error:\n\t%s", tc.wantErr)
}
err = errors.Cause(err)
if err != tc.wantErr {
t.Fatalf("\ngot error:\n\t%s\nwant error:\n\t%s", err, tc.wantErr)
}
return
}
if err != nil {
t.Fatalf("\ngot error:\n\t%s\nwant no errors", err)
}
if reflect.DeepEqual(got, tc.want) {
t.Errorf("\ngot claims:\n\t%v\nwant claims:\n\t%v", got, tc.want)
}
})
}
}
func TestClaimsCache(t *testing.T) {
ep := &testConn{users: users}
connector := newTestConnector("ep", ep)
client := New(Config{
Endpoints: connector.Endpoints(),
AttrClaims: map[string]string{"dn": "name", "a": "claimA", "d": "claimD"},
RoleBaseDN: "OU=test,DC=local",
RoleClaim: "test-roles-claim",
RoleAttr: "test-roles-attr",
})
client.connector = connector
ok, err := client.Authenticate(context.Background(), "user2", "user2")
if err != nil {
t.Fatalf("initial auth: unexpected error: %s", err)
}
if !ok {
t.Fatal("initial auth: got no auth, want auth")
}
claims1, err := client.FindOIDCClaims(context.Background(), "user2")
if err != nil {
t.Fatalf("claims request 1: unexpected error: %s", err)
}
if claims1 == nil {
t.Fatal("claims request 1: got no claims, want claims")
}
if !ep.claimsRequest {
t.Fatal("claims request 1: got claims from cache, want claims from ldap")
}
ep.claimsRequest = false
claims2, err := client.FindOIDCClaims(context.Background(), "user2")
if err != nil {
t.Fatalf("claims request 2: unexpected error: %s", err)
}
if claims2 == nil {
t.Fatal("claims request 2: got no claims, want claims")
}
if ep.claimsRequest {
t.Fatal("claims request 2: got claims from ldap, want claims from cache")
}
if !reflect.DeepEqual(claims1, claims2) {
t.Fatalf("claims request 2:\ngot claims:\n\t%v\nwant claims:\n\t%v", claims2, claims1)
}
ok, err = client.Authenticate(context.Background(), "user2", "user2")
if err != nil {
t.Fatalf("re-auth: unexpected error: %s", err)
}
if !ok {
t.Fatal("re-auth: got no auth, want auth")
}
claims3, err := client.FindOIDCClaims(context.Background(), "user2")
if err != nil {
t.Fatalf("claims request 3: unexpected error: %s", err)
}
if claims3 == nil {
t.Fatal("claims request 3: got no claims, want claims")
}
if !ep.claimsRequest {
t.Fatal("claims request 3: got claims from cache, want claims from ldap")
}
}
type testConnector struct {
conns map[string]interface{}
}
func newTestConnector(args ...interface{}) *testConnector {
if len(args)%2 != 0 {
panic("newTestConnector want args in format \"addr1, conn1, addr2, conn2, addr3, err3\"")
}
conns := make(map[string]interface{})
for i := 0; i < len(args)/2; i++ {
addr, ok := args[i*2].(string)
if !ok {
panic("newTestConnector want args in format \"addr1, conn1, addr2, conn2, addr3, err3\"")
}
switch arg := args[i*2+1].(type) {
case error, *testConn:
conns[addr] = arg
default:
panic("newTestConnector want args in format \"addr1, conn1, addr2, conn2, addr3, err3\"")
}
}
return &testConnector{conns: conns}
}
func (c *testConnector) Endpoints() []string {
var eps []string
for addr := range c.conns {
eps = append(eps, addr)
}
return eps
}
func (c *testConnector) Connect(ctx context.Context, addr string) (conn, error) {
switch v := c.conns[addr].(type) {
case error:
return nil, v
case *testConn:
return v, nil
default:
panic(fmt.Sprintf("Invalid config for endpoint %q", addr))
}
}
type testConn struct {
users []map[string]interface{}
bindErr error
userErr error
rolesErr error
authRequest bool
claimsRequest bool
closed bool
}
func (c *testConn) Bind(bindDN, password string) error {
c.authRequest = true
if c.bindErr != nil {
return c.bindErr
}
user := c.findUser(bindDN)
if user == nil {
return fmt.Errorf("user is not found")
}
if user["pass"] != password {
return errInvalidCredentials
}
return nil
}
func (c *testConn) SearchUser(bindDN string, attrs ...string) ([]map[string]interface{}, error) {
c.claimsRequest = true
if c.userErr != nil {
return nil, c.userErr
}
user := c.findUser(bindDN)
if user == nil {
return nil, nil
}
return []map[string]interface{}{user}, nil
}
func (c *testConn) SearchUserRoles(bindDN string, attrs ...string) ([]map[string]interface{}, error) {
if c.rolesErr != nil {
return nil, c.rolesErr
}
user := c.findUser(bindDN)
if user == nil {
return nil, fmt.Errorf("user is not found")
}
switch roles := user["roles"].(type) {
case nil:
return nil, nil
case []map[string]interface{}:
return roles, nil
default:
return nil, fmt.Errorf("invalid test roles")
}
}
func (c *testConn) findUser(bindDN string) map[string]interface{} {
for _, v := range c.users {
if v["dn"] == bindDN {
return v
}
}
return nil
}
func (c *testConn) Close() {
c.closed = true
}