diff --git a/lemur/common/schema.py b/lemur/common/schema.py index a46ccc70..f7f2caf4 100644 --- a/lemur/common/schema.py +++ b/lemur/common/schema.py @@ -12,8 +12,8 @@ from flask import request, current_app from sqlalchemy.orm.collections import InstrumentedList -from marshmallow import Schema, post_dump, pre_load, pre_dump from inflection import camelize, underscore +from marshmallow import Schema, post_dump, pre_load, pre_dump class LemurSchema(Schema): diff --git a/lemur/exceptions.py b/lemur/exceptions.py index 42543a59..ca807db9 100644 --- a/lemur/exceptions.py +++ b/lemur/exceptions.py @@ -36,6 +36,14 @@ class IntegrityError(LemurException): return repr(self.message) +class AssociatedObjectNotFound(LemurException): + def __init__(self, message): + self.message = message + + def __str__(self): + return repr(self.message) + + class InvalidListener(LemurException): def __str__(self): return repr("Invalid listener, ensure you select a certificate if you are using a secure protocol") diff --git a/lemur/notifications/service.py b/lemur/notifications/service.py index 7baf2c74..81d69625 100644 --- a/lemur/notifications/service.py +++ b/lemur/notifications/service.py @@ -244,27 +244,31 @@ def create_default_expiration_notifications(name, recipients): def create(label, plugin_name, options, description, certificates): """ - Creates a new destination, that can then be used as a destination for certificates. + Creates a new notification. - :param label: Notification common name + :param label: Notification label :param plugin_name: :param options: :param description: + :param certificates: :rtype : Notification :return: """ notification = Notification(label=label, options=options, plugin_name=plugin_name, description=description) - notification = database.update_list(notification, 'certificates', Certificate, certificates) + notification.certificates = certificates return database.create(notification) def update(notification_id, label, options, description, active, certificates): """ - Updates an existing destination. + Updates an existing notification. - :param label: Notification common name + :param notification_id: + :param label: Notification label :param options: :param description: + :param active: + :param certificates: :rtype : Notification :return: """ diff --git a/lemur/roles/service.py b/lemur/roles/service.py index b36a8665..2fac2bdb 100644 --- a/lemur/roles/service.py +++ b/lemur/roles/service.py @@ -27,7 +27,8 @@ def update(role_id, name, description, users): role = get(role_id) role.name = name role.description = description - role = database.update_list(role, 'users', User, users) + if users: + role.users = users database.update(role) return role @@ -44,10 +45,8 @@ def create(name, password=None, description=None, username=None, users=None): :return: """ role = Role(name=name, description=description, username=username, password=password) - if users: - role = database.update_list(role, 'users', User, users) - + role.users = users return database.create(role) diff --git a/lemur/schemas.py b/lemur/schemas.py index 0bea9335..cf550f25 100644 --- a/lemur/schemas.py +++ b/lemur/schemas.py @@ -7,17 +7,40 @@ .. moduleauthor:: Kevin Glisson """ -from marshmallow import fields, post_load, pre_load, post_dump, validates_schema +from sqlalchemy.orm.exc import NoResultFound + +from marshmallow import fields, post_load, pre_load, post_dump, validates_schema +from marshmallow.exceptions import ValidationError -from lemur.authorities.models import Authority -from lemur.certificates.models import Certificate from lemur.common import validators from lemur.common.schema import LemurSchema, LemurInputSchema, LemurOutputSchema -from lemur.destinations.models import Destination -from lemur.notifications.models import Notification + from lemur.plugins import plugins from lemur.roles.models import Role from lemur.users.models import User +from lemur.authorities.models import Authority +from lemur.certificates.models import Certificate +from lemur.destinations.models import Destination +from lemur.notifications.models import Notification + + +def fetch_object(model, field, value): + try: + return model.query.filter(getattr(model, field) == value).one() + except NoResultFound: + raise ValidationError('Unable to find {model} with {field}: {data}'.format(model=model, field=field, data=value)) + + +def fetch_objects(model, field, values): + values = [v[field] for v in values] + items = model.query.filter(getattr(model, field).in_(values)).all() + found = [getattr(i, field) for i in items] + diff = set(values).symmetric_difference(set(found)) + + if diff: + raise ValidationError('Unable to locate {model} with {field} {diff}'.format(model=model, field=field, diff=",".join([list(diff)]))) + + return items class AssociatedAuthoritySchema(LemurInputSchema): @@ -27,9 +50,10 @@ class AssociatedAuthoritySchema(LemurInputSchema): @post_load def get_object(self, data, many=False): if data.get('id'): - return Authority.query.filter(Authority.id == data['id']).one() + return fetch_object(Authority, 'id', data['id']) + elif data.get('name'): - return Authority.query.filter(Authority.name == data['name']).one() + return fetch_object(Authority, 'name', data['name']) class AssociatedRoleSchema(LemurInputSchema): @@ -39,10 +63,9 @@ class AssociatedRoleSchema(LemurInputSchema): @post_load def get_object(self, data, many=False): if many: - ids = [d['id'] for d in data] - return Role.query.filter(Role.id.in_(ids)).all() + return fetch_objects(Role, 'id', data) else: - return Role.query.filter(Role.id == data['id']).one() + return fetch_object(Role, 'id', data['id']) class AssociatedDestinationSchema(LemurInputSchema): @@ -52,10 +75,9 @@ class AssociatedDestinationSchema(LemurInputSchema): @post_load def get_object(self, data, many=False): if many: - ids = [d['id'] for d in data] - return Destination.query.filter(Destination.id.in_(ids)).all() + return fetch_objects(Destination, 'id', data) else: - return Destination.query.filter(Destination.id == data['id']).one() + return fetch_object(Destination, 'id', data['id']) class AssociatedNotificationSchema(LemurInputSchema): @@ -64,10 +86,9 @@ class AssociatedNotificationSchema(LemurInputSchema): @post_load def get_object(self, data, many=False): if many: - ids = [d['id'] for d in data] - return Notification.query.filter(Notification.id.in_(ids)).all() + return fetch_objects(Notification, 'id', data) else: - return Notification.query.filter(Notification.id == data['id']).one() + return fetch_object(Notification, 'id', data['id']) class AssociatedCertificateSchema(LemurInputSchema): @@ -76,10 +97,9 @@ class AssociatedCertificateSchema(LemurInputSchema): @post_load def get_object(self, data, many=False): if many: - ids = [d['id'] for d in data] - return Certificate.query.filter(Certificate.id.in_(ids)).all() + return fetch_objects(Certificate, 'id', data) else: - return Certificate.query.filter(Certificate.id == data['id']).one() + return fetch_object(Certificate, 'id', data['id']) class AssociatedUserSchema(LemurInputSchema): @@ -88,10 +108,9 @@ class AssociatedUserSchema(LemurInputSchema): @post_load def get_object(self, data, many=False): if many: - ids = [d['id'] for d in data] - return User.query.filter(User.id.in_(ids)).all() + return fetch_objects(User, 'id', data) else: - return User.query.filter(User.id == data['id']).one() + return fetch_object(User, 'id', data['id']) class PluginInputSchema(LemurInputSchema): @@ -102,8 +121,11 @@ class PluginInputSchema(LemurInputSchema): @post_load def get_object(self, data, many=False): - data['plugin_object'] = plugins.get(data['slug']) - return data + try: + data['plugin_object'] = plugins.get(data['slug']) + return data + except Exception: + raise ValidationError('Unable to find plugin: {0}'.format(data['slug'])) class PluginOutputSchema(LemurOutputSchema):