diff --git a/lemur/common/utils.py b/lemur/common/utils.py index e69d71a8..7e3f48e8 100644 --- a/lemur/common/utils.py +++ b/lemur/common/utils.py @@ -9,6 +9,9 @@ import string import random +import sqlalchemy +from sqlalchemy import and_, func + from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import rsa @@ -97,3 +100,57 @@ def validate_conf(app, required_vars): for var in required_vars: if not app.config.get(var): raise InvalidConfiguration("Required variable '{var}' is not set in Lemur's conf.".format(var=var)) + + +# https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/WindowedRangeQuery +def column_windows(session, column, windowsize): + """Return a series of WHERE clauses against + a given column that break it into windows. + + Result is an iterable of tuples, consisting of + ((start, end), whereclause), where (start, end) are the ids. + + Requires a database that supports window functions, + i.e. Postgresql, SQL Server, Oracle. + + Enhance this yourself ! Add a "where" argument + so that windows of just a subset of rows can + be computed. + + """ + def int_for_range(start_id, end_id): + if end_id: + return and_( + column >= start_id, + column < end_id + ) + else: + return column >= start_id + + q = session.query( + column, + func.row_number().over(order_by=column).label('rownum') + ).from_self(column) + + if windowsize > 1: + q = q.filter(sqlalchemy.text("rownum %% %d=1" % windowsize)) + + intervals = [id for id, in q] + + while intervals: + start = intervals.pop(0) + if intervals: + end = intervals[0] + else: + end = None + yield int_for_range(start, end) + + +def windowed_query(q, column, windowsize): + """"Break a Query into windows on a given column.""" + + for whereclause in column_windows( + q.session, + column, windowsize): + for row in q.filter(whereclause).order_by(column): + yield row diff --git a/lemur/notifications/messaging.py b/lemur/notifications/messaging.py index 9aa61473..65841217 100644 --- a/lemur/notifications/messaging.py +++ b/lemur/notifications/messaging.py @@ -11,12 +11,13 @@ from itertools import groupby from collections import defaultdict -from sqlalchemy.orm import joinedload - import arrow +from datetime import timedelta from flask import current_app from lemur import database, metrics +from lemur.common.utils import windowed_query + from lemur.certificates.schemas import certificate_notification_output_schema from lemur.certificates.models import Certificate @@ -29,11 +30,21 @@ def get_certificates(): Finds all certificates that are eligible for notifications. :return: """ - return database.session_query(Certificate)\ - .options(joinedload('notifications'))\ - .filter(Certificate.notify == True)\ - .filter(Certificate.expired == False)\ - .filter(Certificate.notifications.any()).all() # noqa + now = arrow.utcnow() + max = now + timedelta(days=90) + + q = database.db.session.query(Certificate) \ + .filter(Certificate.not_after <= max) \ + .filter(Certificate.notify == True) \ + .filter(Certificate.expired == False) # noqa + + certs = [] + + for c in windowed_query(q, Certificate.id, 100): + if needs_notification(c): + certs.append(c) + + return certs def get_eligible_certificates(): @@ -151,6 +162,9 @@ def needs_notification(certificate): days = (certificate.not_after - now).days for notification in certificate.notifications: + if not notification.options: + return + interval = get_plugin_option('interval', notification.options) unit = get_plugin_option('unit', notification.options) diff --git a/lemur/notifications/service.py b/lemur/notifications/service.py index 05d9ac27..efbfd512 100644 --- a/lemur/notifications/service.py +++ b/lemur/notifications/service.py @@ -21,6 +21,7 @@ def create_default_expiration_notifications(name, recipients): already exist these will be returned instead of new notifications. :param name: + :param recipients: :return: """ if not recipients: diff --git a/lemur/schemas.py b/lemur/schemas.py index 082f4172..7a5e62cc 100644 --- a/lemur/schemas.py +++ b/lemur/schemas.py @@ -17,6 +17,7 @@ from lemur.common.schema import LemurSchema, LemurInputSchema, LemurOutputSchema from lemur.common.fields import KeyUsageExtension, ExtendedKeyUsageExtension, BasicConstraintsExtension, SubjectAlternativeNameExtension from lemur.plugins import plugins +from lemur.plugins.utils import get_plugin_option from lemur.roles.models import Role from lemur.users.models import User from lemur.authorities.models import Authority @@ -25,6 +26,25 @@ from lemur.destinations.models import Destination from lemur.notifications.models import Notification +def validate_options(options): + """ + Ensures that the plugin options are valid. + :param options: + :return: + """ + interval = get_plugin_option('interval', options) + unit = get_plugin_option('unit', options) + + if interval == 'month': + unit *= 30 + + elif interval == 'week': + unit *= 7 + + if unit > 90: + raise ValidationError('Notification cannot be more than 90 days into the future.') + + def get_object_attribute(data, many=False): if many: ids = [d.get('id') for d in data] @@ -127,7 +147,7 @@ class AssociatedUserSchema(LemurInputSchema): class PluginInputSchema(LemurInputSchema): - plugin_options = fields.List(fields.Dict()) + plugin_options = fields.List(fields.Dict(), validate=validate_options) slug = fields.String(required=True) title = fields.String() description = fields.String() diff --git a/lemur/tests/test_messaging.py b/lemur/tests/test_messaging.py index f45ebbfd..e7e8c16b 100644 --- a/lemur/tests/test_messaging.py +++ b/lemur/tests/test_messaging.py @@ -2,7 +2,7 @@ import pytest from freezegun import freeze_time from datetime import timedelta - +import arrow from moto import mock_ses @@ -25,7 +25,14 @@ def test_needs_notification(app, certificate, notification): def test_get_certificates(app, certificate, notification): from lemur.notifications.messaging import get_certificates + + certificate.not_after = arrow.utcnow() + timedelta(days=30) delta = certificate.not_after - timedelta(days=2) + + notification.options = [ + {'name': 'interval', 'value': 2}, {'name': 'unit', 'value': 'days'} + ] + with freeze_time(delta.datetime): # no notification certs = len(get_certificates()) @@ -41,7 +48,7 @@ def test_get_certificates(app, certificate, notification): delta = certificate.not_after + timedelta(days=2) with freeze_time(delta.datetime): certificate.notifications.append(notification) - assert len(get_certificates()) == 1 + assert len(get_certificates()) == 0 def test_get_eligible_certificates(app, certificate, notification): diff --git a/setup.py b/setup.py index 1be3aee8..893a3c68 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ install_requires = [ 'Flask-Mail==0.9.1', 'SQLAlchemy-Utils==0.32.12', 'requests==2.11.1', + 'ndg-httpsclient==0.4.2', 'psycopg2==2.6.2', 'arrow==0.10.0', 'six==1.10.0',