From d31c9b19ce311f1b717c29df50a75503980efe81 Mon Sep 17 00:00:00 2001 From: kevgliss Date: Sun, 16 Oct 2016 03:56:13 -0700 Subject: [PATCH] Closes #412. Allows 'name' be a valid attribute to specify a role. (#457) --- lemur/roles/views.py | 4 +- lemur/schemas.py | 98 +++++++++++++++++++++---------------- lemur/tests/test_schemas.py | 58 ++++++++++++++++++++++ lemur/users/views.py | 8 ++- 4 files changed, 122 insertions(+), 46 deletions(-) create mode 100644 lemur/tests/test_schemas.py diff --git a/lemur/roles/views.py b/lemur/roles/views.py index 639efeb2..3d85b4d6 100644 --- a/lemur/roles/views.py +++ b/lemur/roles/views.py @@ -108,7 +108,9 @@ class RolesList(AuthenticatedResource): "description": "this is role3", "username": null, "password": null, - "users": [] + "users": [ + {'id': 1} + ] } **Example response**: diff --git a/lemur/schemas.py b/lemur/schemas.py index cf550f25..a876af59 100644 --- a/lemur/schemas.py +++ b/lemur/schemas.py @@ -24,23 +24,51 @@ from lemur.destinations.models import Destination from lemur.notifications.models import Notification -def fetch_object(model, field, value): - try: - return model.query.filter(getattr(model, field) == value).one() - except NoResultFound: - raise ValidationError('Unable to find {model} with {field}: {data}'.format(model=model, field=field, data=value)) +def get_object_attribute(data, many=False): + if many: + ids = [d.get('id') for d in data] + names = [d.get('name') for d in data] + + if None in ids: + if None in names: + raise ValidationError('Associated object require a name or id.') + else: + return 'name' + return 'id' + else: + if data.get('id'): + return 'id' + elif data.get('name'): + return 'name' + else: + raise ValidationError('Associated object require a name or id.') -def fetch_objects(model, field, values): - values = [v[field] for v in values] - items = model.query.filter(getattr(model, field).in_(values)).all() - found = [getattr(i, field) for i in items] - diff = set(values).symmetric_difference(set(found)) +def fetch_objects(model, data, many=False): + attr = get_object_attribute(data, many=many) - if diff: - raise ValidationError('Unable to locate {model} with {field} {diff}'.format(model=model, field=field, diff=",".join([list(diff)]))) + if many: + values = [v[attr] for v in data] + items = model.query.filter(getattr(model, attr).in_(values)).all() + found = [getattr(i, attr) for i in items] + diff = set(values).symmetric_difference(set(found)) - return items + if diff: + raise ValidationError('Unable to locate {model} with {attr} {diff}'.format( + model=model, + attr=attr, + diff=",".join(list(diff)))) + + return items + + else: + try: + return model.query.filter(getattr(model, attr) == data[attr]).one() + except NoResultFound: + raise ValidationError('Unable to find {model} with {attr}: {data}'.format( + model=model, + attr=attr, + data=data[attr])) class AssociatedAuthoritySchema(LemurInputSchema): @@ -49,68 +77,52 @@ class AssociatedAuthoritySchema(LemurInputSchema): @post_load def get_object(self, data, many=False): - if data.get('id'): - return fetch_object(Authority, 'id', data['id']) - - elif data.get('name'): - return fetch_object(Authority, 'name', data['name']) + return fetch_objects(Authority, data, many=many) class AssociatedRoleSchema(LemurInputSchema): - id = fields.Int(required=True) + id = fields.Int() name = fields.String() @post_load def get_object(self, data, many=False): - if many: - return fetch_objects(Role, 'id', data) - else: - return fetch_object(Role, 'id', data['id']) + return fetch_objects(Role, data, many=many) class AssociatedDestinationSchema(LemurInputSchema): - id = fields.Int(required=True) + id = fields.Int() name = fields.String() @post_load def get_object(self, data, many=False): - if many: - return fetch_objects(Destination, 'id', data) - else: - return fetch_object(Destination, 'id', data['id']) + return fetch_objects(Destination, data, many=many) class AssociatedNotificationSchema(LemurInputSchema): - id = fields.Int(required=True) + id = fields.Int() + name = fields.String() @post_load def get_object(self, data, many=False): - if many: - return fetch_objects(Notification, 'id', data) - else: - return fetch_object(Notification, 'id', data['id']) + return fetch_objects(Notification, data, many=many) class AssociatedCertificateSchema(LemurInputSchema): - id = fields.Int(required=True) + id = fields.Int() + name = fields.String() @post_load def get_object(self, data, many=False): - if many: - return fetch_objects(Certificate, 'id', data) - else: - return fetch_object(Certificate, 'id', data['id']) + return fetch_objects(Certificate, data, many=many) class AssociatedUserSchema(LemurInputSchema): - id = fields.Int(required=True) + id = fields.Int() + name = fields.String() @post_load def get_object(self, data, many=False): - if many: - return fetch_objects(User, 'id', data) - else: - return fetch_object(User, 'id', data['id']) + return fetch_objects(User, data, many=many) class PluginInputSchema(LemurInputSchema): diff --git a/lemur/tests/test_schemas.py b/lemur/tests/test_schemas.py new file mode 100644 index 00000000..e2a05213 --- /dev/null +++ b/lemur/tests/test_schemas.py @@ -0,0 +1,58 @@ +import pytest +from marshmallow.exceptions import ValidationError + +from lemur.tests.factories import RoleFactory + + +def test_get_object_attribute(): + from lemur.schemas import get_object_attribute + + with pytest.raises(ValidationError): + get_object_attribute({}) + + with pytest.raises(ValidationError): + get_object_attribute([{}], many=True) + + with pytest.raises(ValidationError): + get_object_attribute([{}, {'id': 1}], many=True) + + with pytest.raises(ValidationError): + get_object_attribute([{}, {'name': 'test'}], many=True) + + assert get_object_attribute({'name': 'test'}) == 'name' + assert get_object_attribute({'id': 1}) == 'id' + assert get_object_attribute([{'name': 'test'}], many=True) == 'name' + assert get_object_attribute([{'id': 1}], many=True) == 'id' + + +def test_fetch_objects(session): + from lemur.roles.models import Role + from lemur.schemas import fetch_objects + + role = RoleFactory() + role1 = RoleFactory() + session.commit() + + data = {'id': role.id} + found_role = fetch_objects(Role, data) + assert found_role == role + + data = {'name': role.name} + found_role = fetch_objects(Role, data) + assert found_role == role + + data = [{'id': role.id}, {'id': role1.id}] + found_roles = fetch_objects(Role, data, many=True) + assert found_roles == [role, role1] + + data = [{'name': role.name}, {'name': role1.name}] + found_roles = fetch_objects(Role, data, many=True) + assert found_roles == [role, role1] + + with pytest.raises(ValidationError): + data = [{'name': 'blah'}, {'name': role1.name}] + fetch_objects(Role, data, many=True) + + with pytest.raises(ValidationError): + data = {'name': 'nah'} + fetch_objects(Role, data) diff --git a/lemur/users/views.py b/lemur/users/views.py index a9a4695e..43445286 100644 --- a/lemur/users/views.py +++ b/lemur/users/views.py @@ -108,7 +108,9 @@ class UsersList(AuthenticatedResource): "username": "user3", "email": "user3@example.com", "active": true, - "roles": [] + "roles": [ + {'id': 1} - or - {'name': 'myRole'} + ] } **Example response**: @@ -199,7 +201,9 @@ class Users(AuthenticatedResource): "username": "user1", "email": "user1@example.com", "active": false, - "roles": [] + "roles": [ + {'id': 1} - or - {'name': 'myRole'} + ] } **Example response**: