diff --git a/README.rst b/README.rst index d42bc810..16b14d79 100644 --- a/README.rst +++ b/README.rst @@ -22,7 +22,7 @@ Lemur Lemur manages TLS certificate creation. While not able to issue certificates itself, Lemur acts as a broker between CAs and environments providing a central portal for developers to issue TLS certificates with 'sane' defaults. -It works on CPython 3.5. We deploy on Ubuntu and develop on OS X. +It works on Python 3.7. We deploy on Ubuntu and develop on OS X. Project resources diff --git a/bower.json b/bower.json index f7d5500d..8a042a8d 100644 --- a/bower.json +++ b/bower.json @@ -11,12 +11,12 @@ "angular": "1.4.9", "json3": "~3.3", "es5-shim": "~4.5.0", - "bootstrap": "~3.3.6", "angular-bootstrap": "~1.1.1", "angular-animate": "~1.4.9", "restangular": "~1.5.1", "ng-table": "~0.8.3", "moment": "~2.11.1", + "bootstrap": "~3.4.1", "angular-loading-bar": "~0.8.0", "angular-moment": "~0.10.3", "moment-range": "~2.1.0", @@ -24,7 +24,7 @@ "angularjs-toaster": "~1.0.0", "angular-chart.js": "~0.8.8", "ngletteravatar": "~4.0.0", - "bootswatch": "~3.3.6", + "bootswatch": "3.4.1+1", "fontawesome": "~4.5.0", "satellizer": "~0.13.4", "angular-ui-router": "~0.2.15", diff --git a/docker/Dockerfile b/docker/Dockerfile index 5c80606f..d12c55ee 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -3,6 +3,8 @@ FROM alpine:3.8 ARG VERSION ENV VERSION master +ARG URLCONTEXT + ENV uid 1337 ENV gid 1337 ENV user lemur @@ -22,6 +24,7 @@ RUN addgroup -S ${group} -g ${gid} && \ gcc \ autoconf \ automake \ + libtool \ make \ nasm \ zlib-dev \ @@ -42,7 +45,7 @@ WORKDIR /opt/lemur RUN npm install --unsafe-perm && \ pip3 install -e . && \ node_modules/.bin/gulp build && \ - node_modules/.bin/gulp package --urlContextPath=$(urlContextPath) && \ + node_modules/.bin/gulp package --urlContextPath=${URLCONTEXT} && \ apk del build-dependencies COPY entrypoint / diff --git a/docker/Dockerfile-src b/docker/Dockerfile-src new file mode 100644 index 00000000..50d408b0 --- /dev/null +++ b/docker/Dockerfile-src @@ -0,0 +1,67 @@ +FROM alpine:3.8 + +ARG VERSION +ENV VERSION master + +ARG URLCONTEXT + +ENV uid 1337 +ENV gid 1337 +ENV user lemur +ENV group lemur + +RUN addgroup -S ${group} -g ${gid} && \ + adduser -D -S ${user} -G ${group} -u ${uid} && \ + apk --update add python3 libldap postgresql-client nginx supervisor curl tzdata openssl bash && \ + apk --update add --virtual build-dependencies \ + git \ + tar \ + curl \ + python3-dev \ + npm \ + bash \ + musl-dev \ + gcc \ + autoconf \ + automake \ + libtool \ + make \ + nasm \ + zlib-dev \ + postgresql-dev \ + libressl-dev \ + libffi-dev \ + cyrus-sasl-dev \ + openldap-dev && \ + pip3 install --upgrade pip && \ + pip3 install --upgrade setuptools && \ + mkdir -p /home/lemur/.lemur/ && \ + mkdir -p /run/nginx/ /etc/nginx/ssl/ + +COPY ./ /opt/lemur +WORKDIR /opt/lemur + +RUN chown -R $user:$group /opt/lemur/ /home/lemur/.lemur/ && \ + npm install --unsafe-perm && \ + pip3 install -e . && \ + node_modules/.bin/gulp build && \ + node_modules/.bin/gulp package --urlContextPath=${URLCONTEXT} && \ + apk del build-dependencies + +COPY docker/entrypoint / +COPY docker/src/lemur.conf.py /home/lemur/.lemur/lemur.conf.py +COPY docker/supervisor.conf / +COPY docker/nginx/default.conf /etc/nginx/conf.d/ +COPY docker/nginx/default-ssl.conf /etc/nginx/conf.d/ + +RUN chmod +x /entrypoint +WORKDIR / + +HEALTHCHECK --interval=12s --timeout=12s --start-period=30s \ + CMD curl --fail http://localhost:80/api/1/healthcheck | grep -q ok || exit 1 + +USER root + +ENTRYPOINT ["/entrypoint"] + +CMD ["/usr/bin/supervisord","-c","supervisor.conf"] diff --git a/docker/entrypoint b/docker/entrypoint index 2a3a84e3..3f25951a 100644 --- a/docker/entrypoint +++ b/docker/entrypoint @@ -36,7 +36,7 @@ fi # fi echo " # Running init" -su lemur -s /bin/bash -c "cd /opt/lemur/lemur; python3 /opt/lemur/lemur/manage.py init -p ${LEMUR_ADMIN_PASSWORD}" +su lemur -s /bin/bash -c "cd /opt/lemur/lemur; lemur init -p ${LEMUR_ADMIN_PASSWORD}" echo " # Done" # echo "Creating user" @@ -47,11 +47,13 @@ echo " # Done" cron_notify="${CRON_NOTIFY:-"0 22 * * *"}" cron_sync="${CRON_SYNC:-"*/15 * * * *"}" cron_revoked="${CRON_CHECK_REVOKED:-"0 22 * * *"}" +cron_reissue="${CRON_REISSUE:-"0 23 * * *"}" echo " # Populating crontab" -echo "${cron_notify} lemur python3 /opt/lemur/lemur/manage.py notify expirations" > /etc/crontabs/lemur_notify -echo "${cron_sync} lemur python3 /opt/lemur/lemur/manage.py source sync -s all" > /etc/crontabs/lemur_sync -echo "${cron_revoked} lemur python3 /opt/lemur/lemur/manage.py certificate check_revoked" > /etc/crontabs/lemur_revoked +echo "${cron_notify} lemur notify expirations" > /etc/crontabs/lemur +echo "${cron_sync} lemur source sync -s all" >> /etc/crontabs/lemur +echo "${cron_revoked} lemur certificate check_revoked" >> /etc/crontabs/lemur +echo "${cron_reissue} lemur certificate reissue -c" >> /etc/crontabs/lemur echo " # Done" exec "$@" diff --git a/docker/src/lemur.conf.py b/docker/src/lemur.conf.py index 0f294b28..3cc51792 100644 --- a/docker/src/lemur.conf.py +++ b/docker/src/lemur.conf.py @@ -16,12 +16,16 @@ LEMUR_WHITELISTED_DOMAINS = [] LEMUR_EMAIL = '' LEMUR_SECURITY_TEAM_EMAIL = [] +ALLOW_CERT_DELETION = os.environ.get('ALLOW_CERT_DELETION') == "True" -LEMUR_DEFAULT_COUNTRY = repr(os.environ.get('LEMUR_DEFAULT_COUNTRY','')) -LEMUR_DEFAULT_STATE = repr(os.environ.get('LEMUR_DEFAULT_STATE','')) -LEMUR_DEFAULT_LOCATION = repr(os.environ.get('LEMUR_DEFAULT_LOCATION','')) -LEMUR_DEFAULT_ORGANIZATION = repr(os.environ.get('LEMUR_DEFAULT_ORGANIZATION','')) -LEMUR_DEFAULT_ORGANIZATIONAL_UNIT = repr(os.environ.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT','')) +LEMUR_DEFAULT_COUNTRY = str(os.environ.get('LEMUR_DEFAULT_COUNTRY','')) +LEMUR_DEFAULT_STATE = str(os.environ.get('LEMUR_DEFAULT_STATE','')) +LEMUR_DEFAULT_LOCATION = str(os.environ.get('LEMUR_DEFAULT_LOCATION','')) +LEMUR_DEFAULT_ORGANIZATION = str(os.environ.get('LEMUR_DEFAULT_ORGANIZATION','')) +LEMUR_DEFAULT_ORGANIZATIONAL_UNIT = str(os.environ.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT','')) + +LEMUR_DEFAULT_ISSUER_PLUGIN = str(os.environ.get('LEMUR_DEFAULT_ISSUER_PLUGIN','')) +LEMUR_DEFAULT_AUTHORITY = str(os.environ.get('LEMUR_DEFAULT_AUTHORITY','')) ACTIVE_PROVIDERS = [] diff --git a/docker/supervisor.conf b/docker/supervisor.conf index fed01581..ec4b221d 100644 --- a/docker/supervisor.conf +++ b/docker/supervisor.conf @@ -7,7 +7,7 @@ pidfile = /tmp/supervisord.pid [program:lemur] environment=LEMUR_CONF=/home/lemur/.lemur/lemur.conf.py -command=/usr/bin/python3 manage.py start -b 0.0.0.0:8000 +command=lemur start -b 0.0.0.0:8000 user=lemur directory=/opt/lemur/lemur stdout_logfile=/dev/stdout @@ -24,6 +24,7 @@ stderr_logfile=/dev/stderr stderr_logfile_maxbytes=0 [program:cron] +environment=LEMUR_CONF=/home/lemur/.lemur/lemur.conf.py command=/usr/sbin/crond -f user=root stdout_logfile=/dev/stdout diff --git a/docs/production/index.rst b/docs/production/index.rst index cd044ca4..b91ed6bd 100644 --- a/docs/production/index.rst +++ b/docs/production/index.rst @@ -390,6 +390,10 @@ Here are the Celery configuration variables that should be set:: CELERY_IMPORTS = ('lemur.common.celery') CELERY_TIMEZONE = 'UTC' +Do not forget to import crontab module in your configuration file:: + + from celery.task.schedules import crontab + You must start a single Celery scheduler instance and one or more worker instances in order to handle incoming tasks. The scheduler can be started with:: diff --git a/lemur/auth/views.py b/lemur/auth/views.py index e7f87356..eaed419d 100644 --- a/lemur/auth/views.py +++ b/lemur/auth/views.py @@ -127,6 +127,10 @@ def retrieve_user(user_api_url, access_token): # retrieve information about the current user. r = requests.get(user_api_url, params=user_params, headers=headers) + # Some IDPs, like "Keycloak", require a POST instead of a GET + if r.status_code == 400: + r = requests.post(user_api_url, data=user_params, headers=headers) + profile = r.json() user = user_service.get_by_email(profile["email"]) @@ -434,7 +438,7 @@ class OAuth2(Resource): verify_cert=verify_cert, ) - jwks_url = current_app.config.get("PING_JWKS_URL") + jwks_url = current_app.config.get("OAUTH2_JWKS_URL") error_code = validate_id_token(id_token, args["clientId"], jwks_url) if error_code: return error_code diff --git a/lemur/authorizations/models.py b/lemur/authorizations/models.py index 04ac0508..0797a489 100644 --- a/lemur/authorizations/models.py +++ b/lemur/authorizations/models.py @@ -25,7 +25,7 @@ class Authorization(db.Model): return plugins.get(self.plugin_name) def __repr__(self): - return "Authorization(id={id})".format(label=self.id) + return "Authorization(id={id})".format(id=self.id) def __init__(self, account_number, domains, dns_provider_type, options=None): self.account_number = account_number diff --git a/lemur/certificates/cli.py b/lemur/certificates/cli.py index b57ff175..b883dee0 100644 --- a/lemur/certificates/cli.py +++ b/lemur/certificates/cli.py @@ -5,39 +5,36 @@ :license: Apache, see LICENSE for more details. .. moduleauthor:: Kevin Glisson """ -import sys import multiprocessing -from tabulate import tabulate -from sqlalchemy import or_ - +import sys from flask import current_app - -from flask_script import Manager from flask_principal import Identity, identity_changed - +from flask_script import Manager +from sqlalchemy import or_ +from tabulate import tabulate from lemur import database -from lemur.extensions import sentry -from lemur.extensions import metrics -from lemur.plugins.base import plugins -from lemur.constants import SUCCESS_METRIC_STATUS, FAILURE_METRIC_STATUS -from lemur.deployment import service as deployment_service -from lemur.endpoints import service as endpoint_service -from lemur.notifications.messaging import send_rotation_notification -from lemur.domains.models import Domain from lemur.authorities.models import Authority -from lemur.certificates.schemas import CertificateOutputSchema +from lemur.authorities.service import get as authorities_get_by_id from lemur.certificates.models import Certificate +from lemur.certificates.schemas import CertificateOutputSchema from lemur.certificates.service import ( reissue_certificate, get_certificate_primitives, get_all_pending_reissue, get_by_name, - get_all_certs, + get_all_valid_certs, get, + get_all_certs_attached_to_endpoint_without_autorotate, ) - from lemur.certificates.verify import verify_string +from lemur.constants import SUCCESS_METRIC_STATUS, FAILURE_METRIC_STATUS +from lemur.deployment import service as deployment_service +from lemur.domains.models import Domain +from lemur.endpoints import service as endpoint_service +from lemur.extensions import sentry, metrics +from lemur.notifications.messaging import send_rotation_notification +from lemur.plugins.base import plugins manager = Manager(usage="Handles all certificate related tasks.") @@ -213,6 +210,10 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c status = FAILURE_METRIC_STATUS + log_data = { + "function": f"{__name__}.{sys._getframe().f_code.co_name}", + } + try: old_cert = validate_certificate(old_certificate_name) new_cert = validate_certificate(new_certificate_name) @@ -222,26 +223,43 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c print( f"[+] Rotating endpoint: {endpoint.name} to certificate {new_cert.name}" ) + log_data["message"] = "Rotating endpoint" + log_data["endpoint"] = endpoint.dnsname + log_data["certificate"] = new_cert.name request_rotation(endpoint, new_cert, message, commit) + current_app.logger.info(log_data) elif old_cert and new_cert: print(f"[+] Rotating all endpoints from {old_cert.name} to {new_cert.name}") + log_data["message"] = "Rotating all endpoints" + log_data["certificate"] = new_cert.name + log_data["certificate_old"] = old_cert.name + log_data["message"] = "Rotating endpoint from old to new cert" for endpoint in old_cert.endpoints: print(f"[+] Rotating {endpoint.name}") + log_data["endpoint"] = endpoint.dnsname request_rotation(endpoint, new_cert, message, commit) + current_app.logger.info(log_data) else: print("[+] Rotating all endpoints that have new certificates available") + log_data["message"] = "Rotating all endpoints that have new certificates available" for endpoint in endpoint_service.get_all_pending_rotation(): + log_data["endpoint"] = endpoint.dnsname if len(endpoint.certificate.replaced) == 1: print( f"[+] Rotating {endpoint.name} to {endpoint.certificate.replaced[0].name}" ) + log_data["certificate"] = endpoint.certificate.replaced[0].name request_rotation( endpoint, endpoint.certificate.replaced[0], message, commit ) + current_app.logger.info(log_data) + else: + log_data["message"] = "Failed to rotate endpoint due to Multiple replacement certificates found" + print(log_data) metrics.send( "endpoint_rotation", "counter", @@ -289,6 +307,178 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c ) +def request_rotation_region(endpoint, new_cert, message, commit, log_data, region): + if region in endpoint.dnsname: + log_data["message"] = "Rotating endpoint in region" + request_rotation(endpoint, new_cert, message, commit) + else: + log_data["message"] = "Skipping rotation, region mismatch" + + print(log_data) + current_app.logger.info(log_data) + + +@manager.option( + "-e", + "--endpoint", + dest="endpoint_name", + help="Name of the endpoint you wish to rotate.", +) +@manager.option( + "-n", + "--new-certificate", + dest="new_certificate_name", + help="Name of the certificate you wish to rotate to.", +) +@manager.option( + "-o", + "--old-certificate", + dest="old_certificate_name", + help="Name of the certificate you wish to rotate.", +) +@manager.option( + "-a", + "--notify", + dest="message", + action="store_true", + help="Send a rotation notification to the certificates owner.", +) +@manager.option( + "-c", + "--commit", + dest="commit", + action="store_true", + default=False, + help="Persist changes.", +) +@manager.option( + "-r", + "--region", + dest="region", + required=True, + help="Region in which to rotate the endpoint.", +) +def rotate_region(endpoint_name, new_certificate_name, old_certificate_name, message, commit, region): + """ + Rotates an endpoint in a defined region it if it has not already been replaced. If it has + been replaced, will use the replacement certificate for the rotation. + :param old_certificate_name: Name of the certificate you wish to rotate. + :param new_certificate_name: Name of the certificate you wish to rotate to. + :param endpoint_name: Name of the endpoint you wish to rotate. + :param message: Send a rotation notification to the certificates owner. + :param commit: Persist changes. + :param region: Region in which to rotate the endpoint. + """ + if commit: + print("[!] Running in COMMIT mode.") + + print("[+] Starting endpoint rotation.") + status = FAILURE_METRIC_STATUS + + log_data = { + "function": f"{__name__}.{sys._getframe().f_code.co_name}", + "region": region, + } + + try: + old_cert = validate_certificate(old_certificate_name) + new_cert = validate_certificate(new_certificate_name) + endpoint = validate_endpoint(endpoint_name) + + if endpoint and new_cert: + log_data["endpoint"] = endpoint.dnsname + log_data["certificate"] = new_cert.name + request_rotation_region(endpoint, new_cert, message, commit, log_data, region) + + elif old_cert and new_cert: + log_data["certificate"] = new_cert.name + log_data["certificate_old"] = old_cert.name + log_data["message"] = "Rotating endpoint from old to new cert" + print(log_data) + current_app.logger.info(log_data) + for endpoint in old_cert.endpoints: + log_data["endpoint"] = endpoint.dnsname + request_rotation_region(endpoint, new_cert, message, commit, log_data, region) + + else: + log_data["message"] = "Rotating all endpoints that have new certificates available" + print(log_data) + current_app.logger.info(log_data) + all_pending_rotation_endpoints = endpoint_service.get_all_pending_rotation() + for endpoint in all_pending_rotation_endpoints: + log_data["endpoint"] = endpoint.dnsname + if region not in endpoint.dnsname: + log_data["message"] = "Skipping rotation, region mismatch" + print(log_data) + current_app.logger.info(log_data) + metrics.send( + "endpoint_rotation_region_skipped", + "counter", + 1, + metric_tags={ + "region": region, + "old_certificate_name": str(old_cert), + "new_certificate_name": str(endpoint.certificate.replaced[0].name), + "endpoint_name": str(endpoint.dnsname), + }, + ) + + if len(endpoint.certificate.replaced) == 1: + log_data["certificate"] = endpoint.certificate.replaced[0].name + log_data["message"] = "Rotating all endpoints in region" + print(log_data) + current_app.logger.info(log_data) + request_rotation(endpoint, endpoint.certificate.replaced[0], message, commit) + status = SUCCESS_METRIC_STATUS + else: + status = FAILURE_METRIC_STATUS + log_data["message"] = "Failed to rotate endpoint due to Multiple replacement certificates found" + print(log_data) + current_app.logger.info(log_data) + + metrics.send( + "endpoint_rotation_region", + "counter", + 1, + metric_tags={ + "status": FAILURE_METRIC_STATUS, + "old_certificate_name": str(old_cert), + "new_certificate_name": str(endpoint.certificate.replaced[0].name), + "endpoint_name": str(endpoint.dnsname), + "message": str(message), + "region": str(region), + }, + ) + status = SUCCESS_METRIC_STATUS + print("[+] Done!") + + except Exception as e: + sentry.captureException( + extra={ + "old_certificate_name": str(old_certificate_name), + "new_certificate_name": str(new_certificate_name), + "endpoint": str(endpoint_name), + "message": str(message), + "region": str(region), + } + ) + + metrics.send( + "endpoint_rotation_region_job", + "counter", + 1, + metric_tags={ + "status": status, + "old_certificate_name": str(old_certificate_name), + "new_certificate_name": str(new_certificate_name), + "endpoint_name": str(endpoint_name), + "message": str(message), + "endpoint": str(globals().get("endpoint")), + "region": str(region), + }, + ) + + @manager.option( "-o", "--old-certificate", @@ -467,7 +657,14 @@ def check_revoked(): encounters an issue with verification it marks the certificate status as `unknown`. """ - for cert in get_all_certs(): + + log_data = { + "function": f"{__name__}.{sys._getframe().f_code.co_name}", + "message": "Checking for revoked Certificates" + } + + certs = get_all_valid_certs(current_app.config.get("SUPPORTED_REVOCATION_AUTHORITY_PLUGINS", [])) + for cert in certs: try: if cert.chain: status = verify_string(cert.body, cert.chain) @@ -476,9 +673,65 @@ def check_revoked(): cert.status = "valid" if status else "revoked" + if cert.status == "revoked": + log_data["valid"] = cert.status + log_data["certificate_name"] = cert.name + log_data["certificate_id"] = cert.id + metrics.send( + "certificate_revoked", + "counter", + 1, + metric_tags={"status": log_data["valid"], + "certificate_name": log_data["certificate_name"], + "certificate_id": log_data["certificate_id"]}, + ) + current_app.logger.info(log_data) + except Exception as e: sentry.captureException() current_app.logger.exception(e) cert.status = "unknown" database.update(cert) + + +@manager.command +def automatically_enable_autorotate(): + """ + This function automatically enables auto-rotation for unexpired certificates that are + attached to an endpoint but do not have autorotate enabled. + + WARNING: This will overwrite the Auto-rotate toggle! + """ + log_data = { + "function": f"{__name__}.{sys._getframe().f_code.co_name}", + "message": "Enabling auto-rotate for certificate" + } + + permitted_authorities = current_app.config.get("ENABLE_AUTO_ROTATE_AUTHORITY", []) + + eligible_certs = get_all_certs_attached_to_endpoint_without_autorotate() + for cert in eligible_certs: + + if cert.authority_id not in permitted_authorities: + continue + + log_data["certificate"] = cert.name + log_data["certificate_id"] = cert.id + log_data["authority_id"] = cert.authority_id + log_data["authority_name"] = authorities_get_by_id(cert.authority_id).name + if cert.destinations: + log_data["destination_names"] = ', '.join([d.label for d in cert.destinations]) + else: + log_data["destination_names"] = "NONE" + current_app.logger.info(log_data) + metrics.send("automatically_enable_autorotate", + "counter", 1, + metric_tags={"certificate": log_data["certificate"], + "certificate_id": log_data["certificate_id"], + "authority_id": log_data["authority_id"], + "authority_name": log_data["authority_name"], + "destination_names": log_data["destination_names"] + }) + cert.rotation = True + database.update(cert) diff --git a/lemur/certificates/schemas.py b/lemur/certificates/schemas.py index 8f15542d..42e444bc 100644 --- a/lemur/certificates/schemas.py +++ b/lemur/certificates/schemas.py @@ -146,7 +146,8 @@ class CertificateInputSchema(CertificateCreationSchema): data["extensions"]["subAltNames"] = {"names": []} elif not data["extensions"]["subAltNames"].get("names"): data["extensions"]["subAltNames"]["names"] = [] - data["extensions"]["subAltNames"]["names"] += csr_sans + + data["extensions"]["subAltNames"]["names"] = csr_sans return missing.convert_validity_years(data) diff --git a/lemur/certificates/service.py b/lemur/certificates/service.py index a6bbba30..df73487d 100644 --- a/lemur/certificates/service.py +++ b/lemur/certificates/service.py @@ -20,6 +20,7 @@ from lemur.common.utils import generate_private_key, truthiness from lemur.destinations.models import Destination from lemur.domains.models import Domain from lemur.extensions import metrics, sentry, signals +from lemur.models import certificate_associations from lemur.notifications.models import Notification from lemur.pending_certificates.models import PendingCertificate from lemur.plugins.base import plugins @@ -102,6 +103,27 @@ def get_all_certs(): return Certificate.query.all() +def get_all_valid_certs(authority_plugin_name): + """ + Retrieves all valid (not expired) certificates within Lemur, for the given authority plugin names + ignored if no authority_plugin_name provided. + + Note that depending on the DB size retrieving all certificates might an expensive operation + + :return: + """ + if authority_plugin_name: + return ( + Certificate.query.outerjoin(Authority, Authority.id == Certificate.authority_id).filter( + Certificate.not_after > arrow.now().format("YYYY-MM-DD")).filter( + Authority.plugin_name.in_(authority_plugin_name)).all() + ) + else: + return ( + Certificate.query.filter(Certificate.not_after > arrow.now().format("YYYY-MM-DD")).all() + ) + + def get_all_pending_cleaning_expired(source): """ Retrieves all certificates that are available for cleaning. These are certificates which are expired and are not @@ -118,6 +140,21 @@ def get_all_pending_cleaning_expired(source): ) +def get_all_certs_attached_to_endpoint_without_autorotate(): + """ + Retrieves all certificates that are attached to an endpoint, but that do not have autorotate enabled. + + :return: list of certificates attached to an endpoint without autorotate + """ + return ( + Certificate.query.filter(Certificate.endpoints.any()) + .filter(Certificate.rotation == False) + .filter(Certificate.not_after >= arrow.now()) + .filter(not_(Certificate.replaced.any())) + .all() # noqa + ) + + def get_all_pending_cleaning_expiring_in_days(source, days_to_expire): """ Retrieves all certificates that are available for cleaning, not attached to endpoint, @@ -144,7 +181,9 @@ def get_all_pending_cleaning_issued_since_days(source, days_since_issuance): :param source: the source to search for certificates :return: list of pending certificates """ - not_in_use_window = arrow.now().shift(days=-days_since_issuance).format("YYYY-MM-DD") + not_in_use_window = ( + arrow.now().shift(days=-days_since_issuance).format("YYYY-MM-DD") + ) return ( Certificate.query.filter(Certificate.sources.any(id=source.id)) .filter(not_(Certificate.endpoints.any())) @@ -367,9 +406,11 @@ def render(args): show_expired = args.pop("showExpired") if show_expired != 1: - one_month_old = arrow.now()\ - .shift(months=current_app.config.get("HIDE_EXPIRED_CERTS_AFTER_MONTHS", -1))\ + one_month_old = ( + arrow.now() + .shift(months=current_app.config.get("HIDE_EXPIRED_CERTS_AFTER_MONTHS", -1)) .format("YYYY-MM-DD") + ) query = query.filter(Certificate.not_after > one_month_old) time_range = args.pop("time_range") @@ -415,8 +456,8 @@ def render(args): elif "cn" in terms: query = query.filter( or_( - Certificate.cn.ilike(term), - Certificate.domains.any(Domain.name.ilike(term)), + func.lower(Certificate.cn).like(term.lower()), + Certificate.id.in_(like_domain_query(term)), ) ) elif "id" in terms: @@ -424,9 +465,9 @@ def render(args): elif "name" in terms: query = query.filter( or_( - Certificate.name.ilike(term), - Certificate.domains.any(Domain.name.ilike(term)), - Certificate.cn.ilike(term), + func.lower(Certificate.name).like(term.lower()), + Certificate.id.in_(like_domain_query(term)), + func.lower(Certificate.cn).like(term.lower()), ) ) elif "fixedName" in terms: @@ -471,6 +512,14 @@ def render(args): return result +def like_domain_query(term): + domain_query = database.session_query(Domain.id) + domain_query = domain_query.filter(func.lower(Domain.name).like(term.lower())) + assoc_query = database.session_query(certificate_associations.c.certificate_id) + assoc_query = assoc_query.filter(certificate_associations.c.domain_id.in_(domain_query)) + return assoc_query + + def query_name(certificate_name, args): """ Helper function that queries for a certificate by name diff --git a/lemur/certificates/verify.py b/lemur/certificates/verify.py index 76c6b521..0fe379f8 100644 --- a/lemur/certificates/verify.py +++ b/lemur/certificates/verify.py @@ -8,6 +8,7 @@ import requests import subprocess from flask import current_app +from lemur.extensions import sentry from requests.exceptions import ConnectionError, InvalidSchema from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -152,10 +153,19 @@ def verify(cert_path, issuer_chain_path): # OCSP is our main source of truth, in a lot of cases CRLs # have been deprecated and are no longer updated - verify_result = ocsp_verify(cert, cert_path, issuer_chain_path) + verify_result = None + try: + verify_result = ocsp_verify(cert, cert_path, issuer_chain_path) + except Exception as e: + sentry.captureException() + current_app.logger.exception(e) if verify_result is None: - verify_result = crl_verify(cert, cert_path) + try: + verify_result = crl_verify(cert, cert_path) + except Exception as e: + sentry.captureException() + current_app.logger.exception(e) if verify_result is None: current_app.logger.debug("Failed to verify {}".format(cert.serial_number)) diff --git a/lemur/common/celery.py b/lemur/common/celery.py index 4af33d86..a490b13b 100644 --- a/lemur/common/celery.py +++ b/lemur/common/celery.py @@ -10,27 +10,27 @@ command: celery -A lemur.common.celery worker --loglevel=info -l DEBUG -B import copy import sys import time -from datetime import datetime, timezone, timedelta - from celery import Celery +from celery.app.task import Context from celery.exceptions import SoftTimeLimitExceeded +from celery.signals import task_failure, task_received, task_revoked, task_success +from datetime import datetime, timezone, timedelta from flask import current_app from lemur.authorities.service import get as get_authority +from lemur.certificates import cli as cli_certificate from lemur.common.redis import RedisHandler from lemur.destinations import service as destinations_service +from lemur.dns_providers import cli as cli_dns_providers +from lemur.endpoints import cli as cli_endpoints from lemur.extensions import metrics, sentry from lemur.factory import create_app +from lemur.notifications import cli as cli_notification from lemur.notifications.messaging import send_pending_failure_notification from lemur.pending_certificates import service as pending_certificate_service from lemur.plugins.base import plugins from lemur.sources.cli import clean, sync, validate_sources from lemur.sources.service import add_aws_destination_to_sources -from lemur.certificates import cli as cli_certificate -from lemur.dns_providers import cli as cli_dns_providers -from lemur.notifications import cli as cli_notification -from lemur.endpoints import cli as cli_endpoints - if current_app: flask_app = current_app @@ -67,7 +67,7 @@ def is_task_active(fun, task_id, args): from celery.task.control import inspect if not args: - args = '()' # empty args + args = "()" # empty args i = inspect() active_tasks = i.active() @@ -80,6 +80,37 @@ def is_task_active(fun, task_id, args): return False +def get_celery_request_tags(**kwargs): + request = kwargs.get("request") + sender_hostname = "unknown" + sender = kwargs.get("sender") + if sender: + try: + sender_hostname = sender.hostname + except AttributeError: + sender_hostname = vars(sender.request).get("origin", "unknown") + if request and not isinstance( + request, Context + ): # unlike others, task_revoked sends a Context for `request` + task_name = request.name + task_id = request.id + receiver_hostname = request.hostname + else: + task_name = sender.name + task_id = sender.request.id + receiver_hostname = sender.request.hostname + + tags = { + "task_name": task_name, + "task_id": task_id, + "sender_hostname": sender_hostname, + "receiver_hostname": receiver_hostname, + } + if kwargs.get("exception"): + tags["error"] = repr(kwargs["exception"]) + return tags + + @celery.task() def report_celery_last_success_metrics(): """ @@ -89,7 +120,6 @@ def report_celery_last_success_metrics(): report_celery_last_success_metrics should be ran periodically to emit metrics on when a task was last successful. Admins can then alert when tasks are not ran when intended. Admins should also alert when no metrics are emitted from this function. - """ function = f"{__name__}.{sys._getframe().f_code.co_name}" task_id = None @@ -108,15 +138,91 @@ def report_celery_last_success_metrics(): return current_time = int(time.time()) - schedule = current_app.config.get('CELERYBEAT_SCHEDULE') + schedule = current_app.config.get("CELERYBEAT_SCHEDULE") for _, t in schedule.items(): task = t.get("task") last_success = int(red.get(f"{task}.last_success") or 0) - metrics.send(f"{task}.time_since_last_success", 'gauge', current_time - last_success) + metrics.send( + f"{task}.time_since_last_success", "gauge", current_time - last_success + ) red.set( f"{function}.last_success", int(time.time()) ) # Alert if this metric is not seen - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + + +@task_received.connect +def report_number_pending_tasks(**kwargs): + """ + Report the number of pending tasks to our metrics broker every time a task is published. This metric can be used + for autoscaling workers. + https://docs.celeryproject.org/en/latest/userguide/signals.html#task-received + """ + with flask_app.app_context(): + metrics.send( + "celery.new_pending_task", + "TIMER", + 1, + metric_tags=get_celery_request_tags(**kwargs), + ) + + +@task_success.connect +def report_successful_task(**kwargs): + """ + Report a generic success metric as tasks to our metrics broker every time a task finished correctly. + This metric can be used for autoscaling workers. + https://docs.celeryproject.org/en/latest/userguide/signals.html#task-success + """ + with flask_app.app_context(): + tags = get_celery_request_tags(**kwargs) + red.set(f"{tags['task_name']}.last_success", int(time.time())) + metrics.send("celery.successful_task", "TIMER", 1, metric_tags=tags) + + +@task_failure.connect +def report_failed_task(**kwargs): + """ + Report a generic failure metric as tasks to our metrics broker every time a task fails. + This metric can be used for alerting. + https://docs.celeryproject.org/en/latest/userguide/signals.html#task-failure + """ + with flask_app.app_context(): + log_data = { + "function": f"{__name__}.{sys._getframe().f_code.co_name}", + "Message": "Celery Task Failure", + } + + # Add traceback if exception info is in the kwargs + einfo = kwargs.get("einfo") + if einfo: + log_data["traceback"] = einfo.traceback + + error_tags = get_celery_request_tags(**kwargs) + + log_data.update(error_tags) + current_app.logger.error(log_data) + metrics.send("celery.failed_task", "TIMER", 1, metric_tags=error_tags) + + +@task_revoked.connect +def report_revoked_task(**kwargs): + """ + Report a generic failure metric as tasks to our metrics broker every time a task is revoked. + This metric can be used for alerting. + https://docs.celeryproject.org/en/latest/userguide/signals.html#task-revoked + """ + with flask_app.app_context(): + log_data = { + "function": f"{__name__}.{sys._getframe().f_code.co_name}", + "Message": "Celery Task Revoked", + } + + error_tags = get_celery_request_tags(**kwargs) + + log_data.update(error_tags) + current_app.logger.error(log_data) + metrics.send("celery.revoked_task", "TIMER", 1, metric_tags=error_tags) @celery.task(soft_time_limit=600) @@ -217,15 +323,15 @@ def fetch_acme_cert(id): log_data["failed"] = failed log_data["wrong_issuer"] = wrong_issuer current_app.logger.debug(log_data) - metrics.send(f"{function}.resolved", 'gauge', new) - metrics.send(f"{function}.failed", 'gauge', failed) - metrics.send(f"{function}.wrong_issuer", 'gauge', wrong_issuer) + metrics.send(f"{function}.resolved", "gauge", new) + metrics.send(f"{function}.failed", "gauge", failed) + metrics.send(f"{function}.wrong_issuer", "gauge", wrong_issuer) print( "[+] Certificates: New: {new} Failed: {failed} Not using ACME: {wrong_issuer}".format( new=new, failed=failed, wrong_issuer=wrong_issuer ) ) - red.set(f'{function}.last_success', int(time.time())) + return log_data @celery.task() @@ -262,8 +368,8 @@ def fetch_all_pending_acme_certs(): current_app.logger.debug(log_data) fetch_acme_cert.delay(cert.id) - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data @celery.task() @@ -296,8 +402,8 @@ def remove_old_acme_certs(): current_app.logger.debug(log_data) pending_certificate_service.delete(cert) - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data @celery.task() @@ -328,11 +434,11 @@ def clean_all_sources(): current_app.logger.debug(log_data) clean_source.delay(source.label) - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data -@celery.task(soft_time_limit=600) +@celery.task(soft_time_limit=3600) def clean_source(source): """ This celery task will clean the specified source. This is a destructive operation that will delete unused @@ -366,6 +472,7 @@ def clean_source(source): current_app.logger.error(log_data) sentry.captureException() metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) + return log_data @celery.task() @@ -395,8 +502,8 @@ def sync_all_sources(): current_app.logger.debug(log_data) sync_source.delay(source.label) - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data @celery.task(soft_time_limit=7200) @@ -428,19 +535,23 @@ def sync_source(source): current_app.logger.debug(log_data) try: sync([source]) - metrics.send(f"{function}.success", 'counter', 1, metric_tags={"source": source}) + metrics.send( + f"{function}.success", "counter", 1, metric_tags={"source": source} + ) except SoftTimeLimitExceeded: log_data["message"] = "Error syncing source: Time limit exceeded." current_app.logger.error(log_data) sentry.captureException() - metrics.send("sync_source_timeout", "counter", 1, metric_tags={"source": source}) + metrics.send( + "sync_source_timeout", "counter", 1, metric_tags={"source": source} + ) metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) return log_data["message"] = "Done syncing source" current_app.logger.debug(log_data) - metrics.send(f"{function}.success", 'counter', 1, metric_tags={"source": source}) - red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", "counter", 1, metric_tags={"source": source}) + return log_data @celery.task() @@ -477,8 +588,8 @@ def sync_source_destination(): log_data["message"] = "completed Syncing AWS destinations and sources" current_app.logger.debug(log_data) - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data @celery.task(soft_time_limit=3600) @@ -515,12 +626,13 @@ def certificate_reissue(): log_data["message"] = "reissuance completed" current_app.logger.debug(log_data) - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data @celery.task(soft_time_limit=3600) -def certificate_rotate(): +def certificate_rotate(**kwargs): + """ This celery task rotates certificates which are reissued but having endpoints attached to the replaced cert :return: @@ -530,11 +642,11 @@ def certificate_rotate(): if celery.current_task: task_id = celery.current_task.request.id + region = kwargs.get("region") log_data = { "function": function, "message": "rotating certificates", "task_id": task_id, - } if task_id and is_task_active(function, task_id, None): @@ -544,7 +656,11 @@ def certificate_rotate(): current_app.logger.debug(log_data) try: - cli_certificate.rotate(None, None, None, None, True) + if region: + log_data["region"] = region + cli_certificate.rotate_region(None, None, None, None, True, region) + else: + cli_certificate.rotate(None, None, None, None, True) except SoftTimeLimitExceeded: log_data["message"] = "Certificate rotate: Time limit exceeded." current_app.logger.error(log_data) @@ -554,8 +670,8 @@ def certificate_rotate(): log_data["message"] = "rotation completed" current_app.logger.debug(log_data) - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data @celery.task(soft_time_limit=3600) @@ -590,8 +706,8 @@ def endpoints_expire(): metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) return - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data @celery.task(soft_time_limit=600) @@ -626,8 +742,8 @@ def get_all_zones(): metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) return - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data @celery.task(soft_time_limit=3600) @@ -662,8 +778,8 @@ def check_revoked(): metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) return - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data @celery.task(soft_time_limit=3600) @@ -690,7 +806,9 @@ def notify_expirations(): current_app.logger.debug(log_data) try: - cli_notification.expirations(current_app.config.get("EXCLUDE_CN_FROM_NOTIFICATION", [])) + cli_notification.expirations( + current_app.config.get("EXCLUDE_CN_FROM_NOTIFICATION", []) + ) except SoftTimeLimitExceeded: log_data["message"] = "Notify expiring Time limit exceeded." current_app.logger.error(log_data) @@ -698,5 +816,29 @@ def notify_expirations(): metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) return - red.set(f'{function}.last_success', int(time.time())) - metrics.send(f"{function}.success", 'counter', 1) + metrics.send(f"{function}.success", "counter", 1) + return log_data + + +@celery.task(soft_time_limit=3600) +def enable_autorotate_for_certs_attached_to_endpoint(): + """ + This celery task automatically enables autorotation for unexpired certificates that are + attached to an endpoint but do not have autorotate enabled. + :return: + """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "task_id": task_id, + "message": "Enabling autorotate to eligible certificates", + } + current_app.logger.debug(log_data) + + cli_certificate.automatically_enable_autorotate() + metrics.send(f"{function}.success", "counter", 1) + return log_data diff --git a/lemur/common/defaults.py b/lemur/common/defaults.py index d563dbd0..b9c88e49 100644 --- a/lemur/common/defaults.py +++ b/lemur/common/defaults.py @@ -2,6 +2,7 @@ import re import unicodedata from cryptography import x509 +from cryptography.hazmat.primitives.serialization import Encoding from flask import current_app from lemur.common.utils import is_selfsigned @@ -71,12 +72,20 @@ def common_name(cert): :return: Common name or None """ try: - return cert.subject.get_attributes_for_oid(x509.OID_COMMON_NAME)[ - 0 - ].value.strip() + subject_oid = cert.subject.get_attributes_for_oid(x509.OID_COMMON_NAME) + if len(subject_oid) > 0: + return subject_oid[0].value.strip() + return None except Exception as e: sentry.captureException() - current_app.logger.error("Unable to get common name! {0}".format(e)) + current_app.logger.error( + { + "message": "Unable to get common name", + "error": e, + "public_key": cert.public_bytes(Encoding.PEM).decode("utf-8") + }, + exc_info=True + ) def organization(cert): diff --git a/lemur/common/validators.py b/lemur/common/validators.py index 2412e2d3..e1dfe3c1 100644 --- a/lemur/common/validators.py +++ b/lemur/common/validators.py @@ -99,8 +99,12 @@ def csr(data): raise ValidationError("CSR presented is not valid.") # Validate common name and SubjectAltNames - for name in request.subject.get_attributes_for_oid(NameOID.COMMON_NAME): - common_name(name.value) + try: + for name in request.subject.get_attributes_for_oid(NameOID.COMMON_NAME): + common_name(name.value) + except ValueError as err: + current_app.logger.info("Error parsing Subject from CSR: %s", err) + raise ValidationError("Invalid Subject value in supplied CSR") try: alt_names = request.extensions.get_extension_for_class( diff --git a/lemur/dns_providers/util.py b/lemur/dns_providers/util.py index cc8d9bb3..0fa84ac1 100644 --- a/lemur/dns_providers/util.py +++ b/lemur/dns_providers/util.py @@ -31,11 +31,11 @@ class DNSResolveError(DNSError): def is_valid_domain(domain): """Checks if a domain is syntactically valid and returns a bool""" - if len(domain) > 253: - return False if domain[-1] == ".": domain = domain[:-1] - fqdn_re = re.compile("(?=^.{1,254}$)(^(?:(?!\d+\.|-)[a-zA-Z0-9_\-]{1,63}(? 253: + return False + fqdn_re = re.compile("(?=^.{1,63}$)(^(?:[a-z0-9_](?:-*[a-z0-9_])+)$|^[a-z0-9]$)", re.IGNORECASE) return all(fqdn_re.match(d) for d in domain.split(".")) diff --git a/lemur/migrations/versions/8323a5ea723a_.py b/lemur/migrations/versions/8323a5ea723a_.py new file mode 100644 index 00000000..9505cdb1 --- /dev/null +++ b/lemur/migrations/versions/8323a5ea723a_.py @@ -0,0 +1,50 @@ +"""Add lowercase index for certificate name and cn and also for domain name + +Revision ID: 8323a5ea723a +Revises: b33c838cb669 +Create Date: 2020-01-10 10:51:44.776052 + +""" + +# revision identifiers, used by Alembic. +revision = '8323a5ea723a' +down_revision = 'b33c838cb669' + +from alembic import op +from sqlalchemy import text + +import sqlalchemy as sa + + +def upgrade(): + op.create_index( + "ix_certificates_cn_lower", + "certificates", + [text("lower(cn)")], + unique=False, + postgresql_ops={"lower(cn)": "gin_trgm_ops"}, + postgresql_using="gin", + ) + op.create_index( + "ix_certificates_name_lower", + "certificates", + [text("lower(name)")], + unique=False, + postgresql_ops={"lower(name)": "gin_trgm_ops"}, + postgresql_using="gin", + ) + op.create_index( + "ix_domains_name_lower", + "domains", + [text("lower(name)")], + unique=False, + postgresql_ops={"lower(name)": "gin_trgm_ops"}, + postgresql_using="gin", + ) + + + +def downgrade(): + op.drop_index("ix_certificates_cn_lower", table_name="certificates") + op.drop_index("ix_certificates_name_lower", table_name="certificates") + op.drop_index("ix_domains_name_lower", table_name="domains") diff --git a/lemur/migrations/versions/ee827d1e1974_.py b/lemur/migrations/versions/ee827d1e1974_.py index 56696fe3..649f1ed7 100644 --- a/lemur/migrations/versions/ee827d1e1974_.py +++ b/lemur/migrations/versions/ee827d1e1974_.py @@ -45,6 +45,6 @@ def upgrade(): def downgrade(): - op.drop_index("ix_domains_name", table_name="domains") + op.drop_index("ix_domains_name_gin", table_name="domains") op.drop_index("ix_certificates_name", table_name="certificates") op.drop_index("ix_certificates_cn", table_name="certificates") diff --git a/lemur/plugins/lemur_acme/powerdns.py b/lemur/plugins/lemur_acme/powerdns.py index a26faaac..a5d02353 100644 --- a/lemur/plugins/lemur_acme/powerdns.py +++ b/lemur/plugins/lemur_acme/powerdns.py @@ -1,11 +1,10 @@ -import time -import requests import json import sys +import time import lemur.common.utils as utils import lemur.dns_providers.util as dnsutil - +import requests from flask import current_app from lemur.extensions import metrics, sentry @@ -17,7 +16,9 @@ REQUIRED_VARIABLES = [ class Zone: - """ This class implements a PowerDNS zone in JSON. """ + """ + This class implements a PowerDNS zone in JSON. + """ def __init__(self, _data): self._data = _data @@ -39,7 +40,9 @@ class Zone: class Record: - """ This class implements a PowerDNS record. """ + """ + This class implements a PowerDNS record. + """ def __init__(self, _data): self._data = _data @@ -49,20 +52,30 @@ class Record: return self._data["name"] @property - def disabled(self): - return self._data["disabled"] + def type(self): + return self._data["type"] + + @property + def ttl(self): + return self._data["ttl"] @property def content(self): return self._data["content"] @property - def ttl(self): - return self._data["ttl"] + def disabled(self): + return self._data["disabled"] def get_zones(account_number): - """Retrieve authoritative zones from the PowerDNS API and return a list""" + """ + Retrieve authoritative zones from the PowerDNS API and return a list of zones + + :param account_number: + :raise: Exception + :return: list of Zone Objects + """ _check_conf() server_id = current_app.config.get("ACME_POWERDNS_SERVERID", "localhost") path = f"/api/v1/servers/{server_id}/zones" @@ -90,44 +103,41 @@ def get_zones(account_number): def create_txt_record(domain, token, account_number): - """ Create a TXT record for the given domain and token and return a change_id tuple """ + """ + Create a TXT record for the given domain and token and return a change_id tuple + + :param domain: FQDN + :param token: challenge value + :param account_number: + :return: tuple of domain/token + """ _check_conf() - zone_name = _get_zone_name(domain, account_number) - server_id = current_app.config.get("ACME_POWERDNS_SERVERID", "localhost") - zone_id = zone_name + "." - domain_id = domain + "." - path = f"/api/v1/servers/{server_id}/zones/{zone_id}" - payload = { - "rrsets": [ - { - "name": domain_id, - "type": "TXT", - "ttl": 300, - "changetype": "REPLACE", - "records": [ - { - "content": f"\"{token}\"", - "disabled": False - } - ], - "comments": [] - } - ] - } + function = sys._getframe().f_code.co_name log_data = { "function": function, "fqdn": domain, "token": token, } + + # Create new record + domain_id = domain + "." + records = [Record({'name': domain_id, 'content': f"\"{token}\"", 'disabled': False})] + + # Get current records + cur_records = _get_txt_records(domain) + for record in cur_records: + if record.content != token: + records.append(record) + try: - _patch(path, payload) - log_data["message"] = "TXT record successfully created" + _patch_txt_records(domain, account_number, records) + log_data["message"] = "TXT record(s) successfully created" current_app.logger.debug(log_data) except Exception as e: sentry.captureException() log_data["Exception"] = e - log_data["message"] = "Unable to create TXT record" + log_data["message"] = "Unable to create TXT record(s)" current_app.logger.debug(log_data) change_id = (domain, token) @@ -136,8 +146,11 @@ def create_txt_record(domain, token, account_number): def wait_for_dns_change(change_id, account_number=None): """ - Checks the authoritative DNS Server to see if changes have propagated to DNS - Retries and waits until successful. + Checks the authoritative DNS Server to see if changes have propagated. + + :param change_id: tuple of domain/token + :param account_number: + :return: """ _check_conf() domain, token = change_id @@ -171,53 +184,115 @@ def wait_for_dns_change(change_id, account_number=None): def delete_txt_record(change_id, account_number, domain, token): - """ Delete the TXT record for the given domain and token """ + """ + Delete the TXT record for the given domain and token + + :param change_id: tuple of domain/token + :param account_number: + :param domain: FQDN + :param token: challenge to delete + :return: + """ _check_conf() - zone_name = _get_zone_name(domain, account_number) - server_id = current_app.config.get("ACME_POWERDNS_SERVERID", "localhost") - zone_id = zone_name + "." - domain_id = domain + "." - path = f"/api/v1/servers/{server_id}/zones/{zone_id}" - payload = { - "rrsets": [ - { - "name": domain_id, - "type": "TXT", - "ttl": 300, - "changetype": "DELETE", - "records": [ - { - "content": f"\"{token}\"", - "disabled": False - } - ], - "comments": [] - } - ] - } + function = sys._getframe().f_code.co_name log_data = { "function": function, "fqdn": domain, - "token": token + "token": token, } - try: - _patch(path, payload) - log_data["message"] = "TXT record successfully deleted" - current_app.logger.debug(log_data) - except Exception as e: - sentry.captureException() - log_data["Exception"] = e - log_data["message"] = "Unable to delete TXT record" + + """ + Get existing TXT records matching the domain from DNS + The token to be deleted should already exist + There may be other records with different tokens as well + """ + cur_records = _get_txt_records(domain) + found = False + new_records = [] + for record in cur_records: + if record.content == f"\"{token}\"": + found = True + else: + new_records.append(record) + + # Since the matching token is not in DNS, there is nothing to delete + if not found: + log_data["message"] = "Unable to delete TXT record: Token not found in existing TXT records" current_app.logger.debug(log_data) + return + + # The record to delete has been found AND there are other tokens set on the same domain + # Since we only want to delete one token value from the RRSet, we need to use the Patch command to + # overwrite the current RRSet with the existing records. + elif new_records: + try: + _patch_txt_records(domain, account_number, new_records) + log_data["message"] = "TXT record successfully deleted" + current_app.logger.debug(log_data) + except Exception as e: + sentry.captureException() + log_data["Exception"] = e + log_data["message"] = "Unable to delete TXT record: patching exception" + current_app.logger.debug(log_data) + + # The record to delete has been found AND there are no other token values set on the same domain + # Use the Delete command to delete the whole RRSet. + else: + zone_name = _get_zone_name(domain, account_number) + server_id = current_app.config.get("ACME_POWERDNS_SERVERID", "localhost") + zone_id = zone_name + "." + domain_id = domain + "." + path = f"/api/v1/servers/{server_id}/zones/{zone_id}" + payload = { + "rrsets": [ + { + "name": domain_id, + "type": "TXT", + "ttl": 300, + "changetype": "DELETE", + "records": [ + { + "content": f"\"{token}\"", + "disabled": False + } + ], + "comments": [] + } + ] + } + function = sys._getframe().f_code.co_name + log_data = { + "function": function, + "fqdn": domain, + "token": token + } + try: + _patch(path, payload) + log_data["message"] = "TXT record successfully deleted" + current_app.logger.debug(log_data) + except Exception as e: + sentry.captureException() + log_data["Exception"] = e + log_data["message"] = "Unable to delete TXT record" + current_app.logger.debug(log_data) def _check_conf(): + """ + Verifies required configuration variables are set + + :return: + """ utils.validate_conf(current_app, REQUIRED_VARIABLES) def _generate_header(): - """Generate a PowerDNS API header and return it as a dictionary""" + """ + Generate a PowerDNS API header and return it as a dictionary + + :return: Dict of header parameters + """ api_key_name = current_app.config.get("ACME_POWERDNS_APIKEYNAME") api_key = current_app.config.get("ACME_POWERDNS_APIKEY") headers = {api_key_name: api_key} @@ -225,7 +300,13 @@ def _generate_header(): def _get_zone_name(domain, account_number): - """Get most specific matching zone for the given domain and return as a String""" + """ + Get most specific matching zone for the given domain and return as a String + + :param domain: FQDN + :param account_number: + :return: FQDN of domain + """ zones = get_zones(account_number) zone_name = "" for z in zones: @@ -243,8 +324,47 @@ def _get_zone_name(domain, account_number): return zone_name +def _get_txt_records(domain): + """ + Retrieve TXT records for a given domain and return list of Record Objects + + :param domain: FQDN + :return: list of Record objects + """ + server_id = current_app.config.get("ACME_POWERDNS_SERVERID", "localhost") + + path = f"/api/v1/servers/{server_id}/search-data?q={domain}&max=100&object_type=record" + function = sys._getframe().f_code.co_name + log_data = { + "function": function + } + try: + records = _get(path) + log_data["message"] = "Retrieved TXT Records Successfully" + current_app.logger.debug(log_data) + + except Exception as e: + sentry.captureException() + log_data["Exception"] = e + log_data["message"] = "Failed to Retrieve TXT Records" + current_app.logger.debug(log_data) + return [] + + txt_records = [] + for record in records: + cur_record = Record(record) + txt_records.append(cur_record) + return txt_records + + def _get(path, params=None): - """ Execute a GET request on the given URL (base_uri + path) and return response as JSON object """ + """ + Execute a GET request on the given URL (base_uri + path) and return response as JSON object + + :param path: Relative URL path + :param params: additional parameters + :return: json response + """ base_uri = current_app.config.get("ACME_POWERDNS_DOMAIN") verify_value = current_app.config.get("ACME_POWERDNS_VERIFY", True) resp = requests.get( @@ -257,8 +377,54 @@ def _get(path, params=None): return resp.json() +def _patch_txt_records(domain, account_number, records): + """ + Send Patch request to PowerDNS Server + + :param domain: FQDN + :param account_number: + :param records: List of Record objects + :return: + """ + domain_id = domain + "." + + # Create records + txt_records = [] + for record in records: + txt_records.append( + {'content': record.content, 'disabled': record.disabled} + ) + + # Create RRSet + payload = { + "rrsets": [ + { + "name": domain_id, + "type": "TXT", + "ttl": 300, + "changetype": "REPLACE", + "records": txt_records, + "comments": [] + } + ] + } + + # Create Txt Records + server_id = current_app.config.get("ACME_POWERDNS_SERVERID", "localhost") + zone_name = _get_zone_name(domain, account_number) + zone_id = zone_name + "." + path = f"/api/v1/servers/{server_id}/zones/{zone_id}" + _patch(path, payload) + + def _patch(path, payload): - """ Execute a Patch request on the given URL (base_uri + path) with given payload """ + """ + Execute a Patch request on the given URL (base_uri + path) with given payload + + :param path: + :param payload: + :return: + """ base_uri = current_app.config.get("ACME_POWERDNS_DOMAIN") verify_value = current_app.config.get("ACME_POWERDNS_VERIFY", True) resp = requests.patch( diff --git a/lemur/plugins/lemur_acme/route53.py b/lemur/plugins/lemur_acme/route53.py index 55da5161..aaccb57e 100644 --- a/lemur/plugins/lemur_acme/route53.py +++ b/lemur/plugins/lemur_acme/route53.py @@ -35,9 +35,10 @@ def get_zones(client=None): zones = [] for page in paginator.paginate(): for zone in page["HostedZones"]: - zones.append( - zone["Name"][:-1] - ) # We need [:-1] to strip out the trailing dot. + if not zone["Config"]["PrivateZone"]: + zones.append( + zone["Name"][:-1] + ) # We need [:-1] to strip out the trailing dot. return zones diff --git a/lemur/plugins/lemur_acme/tests/test_acme.py b/lemur/plugins/lemur_acme/tests/test_acme.py index b2c32eec..94949a74 100644 --- a/lemur/plugins/lemur_acme/tests/test_acme.py +++ b/lemur/plugins/lemur_acme/tests/test_acme.py @@ -1,11 +1,9 @@ import unittest +from unittest.mock import patch, Mock from cryptography.x509 import DNSName -from requests.models import Response - -from mock import MagicMock, Mock, patch - -from lemur.plugins.lemur_acme import plugin, ultradns +from lemur.plugins.lemur_acme import plugin +from mock import MagicMock class TestAcme(unittest.TestCase): @@ -57,7 +55,7 @@ class TestAcme(unittest.TestCase): @patch("lemur.plugins.lemur_acme.plugin.len", return_value=1) @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.get_dns_challenges") def test_start_dns_challenge( - self, mock_get_dns_challenges, mock_len, mock_app, mock_acme + self, mock_get_dns_challenges, mock_len, mock_app, mock_acme ): assert mock_len mock_order = Mock() @@ -88,7 +86,7 @@ class TestAcme(unittest.TestCase): @patch("lemur.plugins.lemur_acme.cloudflare.wait_for_dns_change") @patch("time.sleep") def test_complete_dns_challenge_success( - self, mock_sleep, mock_wait_for_dns_change, mock_current_app, mock_acme + self, mock_sleep, mock_wait_for_dns_change, mock_current_app, mock_acme ): mock_dns_provider = Mock() mock_dns_provider.wait_for_dns_change = Mock(return_value=True) @@ -112,7 +110,7 @@ class TestAcme(unittest.TestCase): @patch("lemur.plugins.lemur_acme.plugin.current_app") @patch("lemur.plugins.lemur_acme.cloudflare.wait_for_dns_change") def test_complete_dns_challenge_fail( - self, mock_wait_for_dns_change, mock_current_app, mock_acme + self, mock_wait_for_dns_change, mock_current_app, mock_acme ): mock_dns_provider = Mock() mock_dns_provider.wait_for_dns_change = Mock(return_value=True) @@ -140,12 +138,12 @@ class TestAcme(unittest.TestCase): @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.get_dns_challenges") @patch("lemur.plugins.lemur_acme.plugin.current_app") def test_request_certificate( - self, - mock_current_app, - mock_get_dns_challenges, - mock_jose, - mock_crypto, - mock_acme, + self, + mock_current_app, + mock_get_dns_challenges, + mock_jose, + mock_crypto, + mock_acme, ): mock_cert_response = Mock() mock_cert_response.body = "123" @@ -182,7 +180,7 @@ class TestAcme(unittest.TestCase): assert result_client assert result_registration - @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch('lemur.plugins.lemur_acme.plugin.current_app') def test_get_domains_single(self, mock_current_app): options = {"common_name": "test.netflix.net"} result = self.acme.get_domains(options) @@ -288,14 +286,14 @@ class TestAcme(unittest.TestCase): @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations") @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate") def test_get_ordered_certificate( - self, - mock_request_certificate, - mock_finalize_authorizations, - mock_get_authorizations, - mock_dns_provider_service, - mock_authorization_service, - mock_current_app, - mock_acme, + self, + mock_request_certificate, + mock_finalize_authorizations, + mock_get_authorizations, + mock_dns_provider_service, + mock_authorization_service, + mock_current_app, + mock_acme, ): mock_client = Mock() mock_acme.return_value = (mock_client, "") @@ -319,14 +317,14 @@ class TestAcme(unittest.TestCase): @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations") @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate") def test_get_ordered_certificates( - self, - mock_request_certificate, - mock_finalize_authorizations, - mock_get_authorizations, - mock_dns_provider_service, - mock_authorization_service, - mock_current_app, - mock_acme, + self, + mock_request_certificate, + mock_finalize_authorizations, + mock_get_authorizations, + mock_dns_provider_service, + mock_authorization_service, + mock_current_app, + mock_acme, ): mock_client = Mock() mock_acme.return_value = (mock_client, "") @@ -388,121 +386,3 @@ class TestAcme(unittest.TestCase): mock_request_certificate.return_value = ("pem_certificate", "chain") result = provider.create_certificate(csr, issuer_options) assert result - - @patch("lemur.plugins.lemur_acme.ultradns.requests") - @patch("lemur.plugins.lemur_acme.ultradns.current_app") - def test_ultradns_get_token(self, mock_current_app, mock_requests): - # ret_val = json.dumps({"access_token": "access"}) - the_response = Response() - the_response._content = b'{"access_token": "access"}' - mock_requests.post = Mock(return_value=the_response) - mock_current_app.config.get = Mock(return_value="Test") - result = ultradns.get_ultradns_token() - self.assertTrue(len(result) > 0) - - @patch("lemur.plugins.lemur_acme.ultradns.current_app") - def test_ultradns_create_txt_record(self, mock_current_app): - domain = "_acme_challenge.test.example.com" - zone = "test.example.com" - token = "ABCDEFGHIJ" - account_number = "1234567890" - change_id = (domain, token) - ultradns.get_zone_name = Mock(return_value=zone) - mock_current_app.logger.debug = Mock() - ultradns._post = Mock() - log_data = { - "function": "create_txt_record", - "fqdn": domain, - "token": token, - "message": "TXT record created" - } - result = ultradns.create_txt_record(domain, token, account_number) - mock_current_app.logger.debug.assert_called_with(log_data) - self.assertEqual(result, change_id) - - @patch("lemur.plugins.lemur_acme.ultradns.current_app") - @patch("lemur.extensions.metrics") - def test_ultradns_delete_txt_record(self, mock_metrics, mock_current_app): - domain = "_acme_challenge.test.example.com" - zone = "test.example.com" - token = "ABCDEFGHIJ" - account_number = "1234567890" - change_id = (domain, token) - mock_current_app.logger.debug = Mock() - ultradns.get_zone_name = Mock(return_value=zone) - ultradns._post = Mock() - ultradns._get = Mock() - ultradns._get.return_value = {'zoneName': 'test.example.com.com', - 'rrSets': [{'ownerName': '_acme-challenge.test.example.com.', - 'rrtype': 'TXT (16)', 'ttl': 5, 'rdata': ['ABCDEFGHIJ']}], - 'queryInfo': {'sort': 'OWNER', 'reverse': False, 'limit': 100}, - 'resultInfo': {'totalCount': 1, 'offset': 0, 'returnedCount': 1}} - ultradns._delete = Mock() - mock_metrics.send = Mock() - ultradns.delete_txt_record(change_id, account_number, domain, token) - mock_current_app.logger.debug.assert_not_called() - mock_metrics.send.assert_not_called() - - @patch("lemur.plugins.lemur_acme.ultradns.current_app") - @patch("lemur.extensions.metrics") - def test_ultradns_wait_for_dns_change(self, mock_metrics, mock_current_app): - ultradns._has_dns_propagated = Mock(return_value=True) - nameserver = "1.1.1.1" - ultradns.get_authoritative_nameserver = Mock(return_value=nameserver) - mock_metrics.send = Mock() - domain = "_acme-challenge.test.example.com" - token = "ABCDEFGHIJ" - change_id = (domain, token) - mock_current_app.logger.debug = Mock() - ultradns.wait_for_dns_change(change_id) - # mock_metrics.send.assert_not_called() - log_data = { - "function": "wait_for_dns_change", - "fqdn": domain, - "status": True, - "message": "Record status on Public DNS" - } - mock_current_app.logger.debug.assert_called_with(log_data) - - def test_ultradns_get_zone_name(self): - zones = ['example.com', 'test.example.com'] - zone = "test.example.com" - domain = "_acme-challenge.test.example.com" - account_number = "1234567890" - ultradns.get_zones = Mock(return_value=zones) - result = ultradns.get_zone_name(domain, account_number) - self.assertEqual(result, zone) - - def test_ultradns_get_zones(self): - account_number = "1234567890" - path = "a/b/c" - zones = ['example.com', 'test.example.com'] - paginate_response = [{ - 'properties': { - 'name': 'example.com.', 'accountName': 'example', 'type': 'PRIMARY', - 'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9, - 'lastModifiedDateTime': '2017-06-14T06:45Z'}, - 'registrarInfo': { - 'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.', - 'example.ultradns.biz.', 'example.ultradns.org.']}}, - 'inherit': 'ALL'}, { - 'properties': { - 'name': 'test.example.com.', 'accountName': 'example', 'type': 'PRIMARY', - 'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9, - 'lastModifiedDateTime': '2017-06-14T06:45Z'}, - 'registrarInfo': { - 'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.', - 'example.ultradns.biz.', 'example.ultradns.org.']}}, - 'inherit': 'ALL'}, { - 'properties': { - 'name': 'example2.com.', 'accountName': 'example', 'type': 'SECONDARY', - 'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9, - 'lastModifiedDateTime': '2017-06-14T06:45Z'}, - 'registrarInfo': { - 'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.', - 'example.ultradns.biz.', 'example.ultradns.org.']}}, - 'inherit': 'ALL'}] - ultradns._paginate = Mock(path, "zones") - ultradns._paginate.side_effect = [[paginate_response]] - result = ultradns.get_zones(account_number) - self.assertEqual(result, zones) diff --git a/lemur/plugins/lemur_acme/tests/test_powerdns.py b/lemur/plugins/lemur_acme/tests/test_powerdns.py index c8b0a11e..37e4968e 100644 --- a/lemur/plugins/lemur_acme/tests/test_powerdns.py +++ b/lemur/plugins/lemur_acme/tests/test_powerdns.py @@ -1,5 +1,5 @@ import unittest -from mock import Mock, patch +from unittest.mock import patch, Mock from lemur.plugins.lemur_acme import plugin, powerdns @@ -48,13 +48,14 @@ class TestPowerdns(unittest.TestCase): self.assertEqual(result, zone) @patch("lemur.plugins.lemur_acme.powerdns.current_app") - def test_create_txt_record(self, mock_current_app): + def test_create_txt_record_write_only(self, mock_current_app): domain = "_acme_challenge.test.example.com" zone = "test.example.com" token = "ABCDEFGHIJ" account_number = "1234567890" change_id = (domain, token) powerdns._check_conf = Mock() + powerdns._get_txt_records = Mock(return_value=[]) powerdns._get_zone_name = Mock(return_value=zone) mock_current_app.logger.debug = Mock() mock_current_app.config.get = Mock(return_value="localhost") @@ -63,24 +64,74 @@ class TestPowerdns(unittest.TestCase): "function": "create_txt_record", "fqdn": domain, "token": token, - "message": "TXT record successfully created" + "message": "TXT record(s) successfully created" } result = powerdns.create_txt_record(domain, token, account_number) mock_current_app.logger.debug.assert_called_with(log_data) self.assertEqual(result, change_id) + @patch("lemur.plugins.lemur_acme.powerdns.current_app") + def test_create_txt_record_append(self, mock_current_app): + domain = "_acme_challenge.test.example.com" + zone = "test.example.com" + token = "ABCDEFGHIJ" + account_number = "1234567890" + change_id = (domain, token) + powerdns._check_conf = Mock() + cur_token = "123456" + cur_records = [powerdns.Record({'name': domain, 'content': f"\"{cur_token}\"", 'disabled': False})] + powerdns._get_txt_records = Mock(return_value=cur_records) + powerdns._get_zone_name = Mock(return_value=zone) + mock_current_app.logger.debug = Mock() + mock_current_app.config.get = Mock(return_value="localhost") + powerdns._patch = Mock() + log_data = { + "function": "create_txt_record", + "fqdn": domain, + "token": token, + "message": "TXT record(s) successfully created" + } + expected_path = "/api/v1/servers/localhost/zones/test.example.com." + expected_payload = { + "rrsets": [ + { + "name": domain + ".", + "type": "TXT", + "ttl": 300, + "changetype": "REPLACE", + "records": [ + { + "content": f"\"{token}\"", + "disabled": False + }, + { + "content": f"\"{cur_token}\"", + "disabled": False + } + ], + "comments": [] + } + ] + } + + result = powerdns.create_txt_record(domain, token, account_number) + mock_current_app.logger.debug.assert_called_with(log_data) + powerdns._patch.assert_called_with(expected_path, expected_payload) + self.assertEqual(result, change_id) + @patch("lemur.plugins.lemur_acme.powerdns.dnsutil") @patch("lemur.plugins.lemur_acme.powerdns.current_app") @patch("lemur.extensions.metrics") @patch("time.sleep") def test_wait_for_dns_change(self, mock_sleep, mock_metrics, mock_current_app, mock_dnsutil): domain = "_acme-challenge.test.example.com" - token = "ABCDEFG" + token1 = "ABCDEFG" + token2 = "HIJKLMN" zone_name = "test.example.com" nameserver = "1.1.1.1" - change_id = (domain, token) + change_id = (domain, token1) powerdns._check_conf = Mock() - mock_records = (token,) + mock_records = (token2, token1) mock_current_app.config.get = Mock(return_value=1) powerdns._get_zone_name = Mock(return_value=zone_name) mock_dnsutil.get_authoritative_nameserver = Mock(return_value=nameserver) @@ -114,7 +165,7 @@ class TestPowerdns(unittest.TestCase): "function": "delete_txt_record", "fqdn": domain, "token": token, - "message": "TXT record successfully deleted" + "message": "Unable to delete TXT record: Token not found in existing TXT records" } powerdns.delete_txt_record(change_id, account_number, domain, token) mock_current_app.logger.debug.assert_called_with(log_data) diff --git a/lemur/plugins/lemur_acme/tests/test_ultradns.py b/lemur/plugins/lemur_acme/tests/test_ultradns.py new file mode 100644 index 00000000..f1d61e68 --- /dev/null +++ b/lemur/plugins/lemur_acme/tests/test_ultradns.py @@ -0,0 +1,138 @@ +import unittest +from unittest.mock import patch, Mock + +from lemur.plugins.lemur_acme import plugin, ultradns +from requests.models import Response + + +class TestUltradns(unittest.TestCase): + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") + def setUp(self, mock_dns_provider_service): + self.ACMEIssuerPlugin = plugin.ACMEIssuerPlugin() + self.acme = plugin.AcmeHandler() + mock_dns_provider = Mock() + mock_dns_provider.name = "cloudflare" + mock_dns_provider.credentials = "{}" + mock_dns_provider.provider_type = "cloudflare" + self.acme.dns_providers_for_domain = { + "www.test.com": [mock_dns_provider], + "test.fakedomain.net": [mock_dns_provider], + } + + @patch("lemur.plugins.lemur_acme.ultradns.requests") + @patch("lemur.plugins.lemur_acme.ultradns.current_app") + def test_ultradns_get_token(self, mock_current_app, mock_requests): + # ret_val = json.dumps({"access_token": "access"}) + the_response = Response() + the_response._content = b'{"access_token": "access"}' + mock_requests.post = Mock(return_value=the_response) + mock_current_app.config.get = Mock(return_value="Test") + result = ultradns.get_ultradns_token() + self.assertTrue(len(result) > 0) + + @patch("lemur.plugins.lemur_acme.ultradns.current_app") + def test_ultradns_create_txt_record(self, mock_current_app): + domain = "_acme_challenge.test.example.com" + zone = "test.example.com" + token = "ABCDEFGHIJ" + account_number = "1234567890" + change_id = (domain, token) + ultradns.get_zone_name = Mock(return_value=zone) + mock_current_app.logger.debug = Mock() + ultradns._post = Mock() + log_data = { + "function": "create_txt_record", + "fqdn": domain, + "token": token, + "message": "TXT record created" + } + result = ultradns.create_txt_record(domain, token, account_number) + mock_current_app.logger.debug.assert_called_with(log_data) + self.assertEqual(result, change_id) + + @patch("lemur.plugins.lemur_acme.ultradns.current_app") + @patch("lemur.extensions.metrics") + def test_ultradns_delete_txt_record(self, mock_metrics, mock_current_app): + domain = "_acme_challenge.test.example.com" + zone = "test.example.com" + token = "ABCDEFGHIJ" + account_number = "1234567890" + change_id = (domain, token) + mock_current_app.logger.debug = Mock() + ultradns.get_zone_name = Mock(return_value=zone) + ultradns._post = Mock() + ultradns._get = Mock() + ultradns._get.return_value = {'zoneName': 'test.example.com.com', + 'rrSets': [{'ownerName': '_acme-challenge.test.example.com.', + 'rrtype': 'TXT (16)', 'ttl': 5, 'rdata': ['ABCDEFGHIJ']}], + 'queryInfo': {'sort': 'OWNER', 'reverse': False, 'limit': 100}, + 'resultInfo': {'totalCount': 1, 'offset': 0, 'returnedCount': 1}} + ultradns._delete = Mock() + mock_metrics.send = Mock() + ultradns.delete_txt_record(change_id, account_number, domain, token) + mock_current_app.logger.debug.assert_not_called() + mock_metrics.send.assert_not_called() + + @patch("lemur.plugins.lemur_acme.ultradns.current_app") + @patch("lemur.extensions.metrics") + def test_ultradns_wait_for_dns_change(self, mock_metrics, mock_current_app): + ultradns._has_dns_propagated = Mock(return_value=True) + nameserver = "1.1.1.1" + ultradns.get_authoritative_nameserver = Mock(return_value=nameserver) + mock_metrics.send = Mock() + domain = "_acme-challenge.test.example.com" + token = "ABCDEFGHIJ" + change_id = (domain, token) + mock_current_app.logger.debug = Mock() + ultradns.wait_for_dns_change(change_id) + # mock_metrics.send.assert_not_called() + log_data = { + "function": "wait_for_dns_change", + "fqdn": domain, + "status": True, + "message": "Record status on Public DNS" + } + mock_current_app.logger.debug.assert_called_with(log_data) + + def test_ultradns_get_zone_name(self): + zones = ['example.com', 'test.example.com'] + zone = "test.example.com" + domain = "_acme-challenge.test.example.com" + account_number = "1234567890" + ultradns.get_zones = Mock(return_value=zones) + result = ultradns.get_zone_name(domain, account_number) + self.assertEqual(result, zone) + + def test_ultradns_get_zones(self): + account_number = "1234567890" + path = "a/b/c" + zones = ['example.com', 'test.example.com'] + paginate_response = [{ + 'properties': { + 'name': 'example.com.', 'accountName': 'example', 'type': 'PRIMARY', + 'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9, + 'lastModifiedDateTime': '2017-06-14T06:45Z'}, + 'registrarInfo': { + 'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.', + 'example.ultradns.biz.', 'example.ultradns.org.']}}, + 'inherit': 'ALL'}, { + 'properties': { + 'name': 'test.example.com.', 'accountName': 'example', 'type': 'PRIMARY', + 'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9, + 'lastModifiedDateTime': '2017-06-14T06:45Z'}, + 'registrarInfo': { + 'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.', + 'example.ultradns.biz.', 'example.ultradns.org.']}}, + 'inherit': 'ALL'}, { + 'properties': { + 'name': 'example2.com.', 'accountName': 'example', 'type': 'SECONDARY', + 'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9, + 'lastModifiedDateTime': '2017-06-14T06:45Z'}, + 'registrarInfo': { + 'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.', + 'example.ultradns.biz.', 'example.ultradns.org.']}}, + 'inherit': 'ALL'}] + ultradns._paginate = Mock(path, "zones") + ultradns._paginate.side_effect = [[paginate_response]] + result = ultradns.get_zones(account_number) + self.assertEqual(result, zones) diff --git a/lemur/plugins/lemur_aws/iam.py b/lemur/plugins/lemur_aws/iam.py index 13590ddd..8d80e020 100644 --- a/lemur/plugins/lemur_aws/iam.py +++ b/lemur/plugins/lemur_aws/iam.py @@ -24,6 +24,12 @@ def retry_throttled(exception): if exception.response["Error"]["Code"] == "NoSuchEntity": return False + # No need to retry deletion requests if there is a DeleteConflict error. + # This error indicates that the certificate is still attached to an entity + # and cannot be deleted. + if exception.response["Error"]["Code"] == "DeleteConflict": + return False + metrics.send("iam_retry", "counter", 1, metric_tags={"exception": str(exception)}) return True diff --git a/lemur/plugins/lemur_aws/plugin.py b/lemur/plugins/lemur_aws/plugin.py index 7bb7a3a2..8692348a 100644 --- a/lemur/plugins/lemur_aws/plugin.py +++ b/lemur/plugins/lemur_aws/plugin.py @@ -216,22 +216,24 @@ class AWSSourcePlugin(SourcePlugin): for region in regions: elbs = elb.get_all_elbs(account_number=account_number, region=region) - current_app.logger.info( - "Describing classic load balancers in {0}-{1}".format( - account_number, region - ) - ) + current_app.logger.info({ + "message": "Describing classic load balancers", + "account_number": account_number, + "region": region, + "number_of_load_balancers": len(elbs) + }) for e in elbs: endpoints.extend(get_elb_endpoints(account_number, region, e)) # fetch advanced ELBs elbs_v2 = elb.get_all_elbs_v2(account_number=account_number, region=region) - current_app.logger.info( - "Describing advanced load balancers in {0}-{1}".format( - account_number, region - ) - ) + current_app.logger.info({ + "message": "Describing advanced load balancers", + "account_number": account_number, + "region": region, + "number_of_load_balancers": len(elbs_v2) + }) for e in elbs_v2: endpoints.extend(get_elb_endpoints_v2(account_number, region, e)) diff --git a/lemur/plugins/lemur_cryptography/plugin.py b/lemur/plugins/lemur_cryptography/plugin.py index 005f36f9..1cf60fba 100644 --- a/lemur/plugins/lemur_cryptography/plugin.py +++ b/lemur/plugins/lemur_cryptography/plugin.py @@ -24,7 +24,12 @@ from lemur.certificates.service import create_csr def build_certificate_authority(options): options["certificate_authority"] = True csr, private_key = create_csr(**options) - cert_pem, chain_cert_pem = issue_certificate(csr, options, private_key) + + if options.get("parent"): + # Intermediate Cert Issuance + cert_pem, chain_cert_pem = issue_certificate(csr, options, None) + else: + cert_pem, chain_cert_pem = issue_certificate(csr, options, private_key) return cert_pem, private_key, chain_cert_pem diff --git a/lemur/plugins/lemur_cryptography/tests/test_cryptography.py b/lemur/plugins/lemur_cryptography/tests/test_cryptography.py index 7f1777fc..05012c03 100644 --- a/lemur/plugins/lemur_cryptography/tests/test_cryptography.py +++ b/lemur/plugins/lemur_cryptography/tests/test_cryptography.py @@ -25,6 +25,31 @@ def test_build_certificate_authority(): assert chain_cert_pem == "" +def test_build_intermediate_certificate_authority(authority): + from lemur.plugins.lemur_cryptography.plugin import build_certificate_authority + + options = { + "key_type": "RSA2048", + "country": "US", + "state": "CA", + "location": "Example place", + "organization": "Example, Inc.", + "organizational_unit": "Example Unit", + "common_name": "Example INTERMEDIATE", + "validity_start": arrow.get("2016-12-01").datetime, + "validity_end": arrow.get("2016-12-02").datetime, + "first_serial": 1, + "serial_number": 1, + "owner": "owner@example.com", + "parent": authority + } + cert_pem, private_key_pem, chain_cert_pem = build_certificate_authority(options) + + assert cert_pem + assert private_key_pem + assert chain_cert_pem == authority.authority_certificate.body + + def test_issue_certificate(authority): from lemur.tests.vectors import CSR_STR from lemur.plugins.lemur_cryptography.plugin import issue_certificate diff --git a/lemur/plugins/lemur_digicert/tests/test_digicert.py b/lemur/plugins/lemur_digicert/tests/test_digicert.py index 1e9ebca4..8bfd1dcf 100644 --- a/lemur/plugins/lemur_digicert/tests/test_digicert.py +++ b/lemur/plugins/lemur_digicert/tests/test_digicert.py @@ -1,4 +1,5 @@ import json +from unittest.mock import patch, Mock import arrow import pytest @@ -6,7 +7,6 @@ from cryptography import x509 from freezegun import freeze_time from lemur.plugins.lemur_digicert import plugin from lemur.tests.vectors import CSR_STR -from mock import Mock, patch def config_mock(*args): diff --git a/lemur/plugins/lemur_email/templates/expiration.html b/lemur/plugins/lemur_email/templates/expiration.html index f5185acd..16b59733 100644 --- a/lemur/plugins/lemur_email/templates/expiration.html +++ b/lemur/plugins/lemur_email/templates/expiration.html @@ -75,7 +75,8 @@ -
This is a Lemur certificate expiration notice. Please verify that the following certificates are no longer used. +
This is a Lemur certificate expiration notice. Please verify that the following certificates are no longer used, + and disable notifications via the Notify toggle in Lemur, if applicable. diff --git a/lemur/plugins/lemur_vault_dest/plugin.py b/lemur/plugins/lemur_vault_dest/plugin.py index 41b9c252..3c5301f7 100755 --- a/lemur/plugins/lemur_vault_dest/plugin.py +++ b/lemur/plugins/lemur_vault_dest/plugin.py @@ -14,7 +14,7 @@ import re import hvac from flask import current_app -from lemur.common.defaults import common_name +from lemur.common.defaults import common_name, country, state, location, organizational_unit, organization from lemur.common.utils import parse_certificate from lemur.plugins.bases import DestinationPlugin from lemur.plugins.bases import SourcePlugin @@ -58,7 +58,7 @@ class VaultSourcePlugin(SourcePlugin): "helpMessage": "Authentication method to use", }, { - "name": "tokenFile/VaultRole", + "name": "tokenFileOrVaultRole", "type": "str", "required": True, "validation": "^([a-zA-Z0-9/._-]+/?)+$", @@ -94,7 +94,7 @@ class VaultSourcePlugin(SourcePlugin): body = "" url = self.get_option("vaultUrl", options) auth_method = self.get_option("authenticationMethod", options) - auth_key = self.get_option("tokenFile/vaultRole", options) + auth_key = self.get_option("tokenFileOrVaultRole", options) mount = self.get_option("vaultMount", options) path = self.get_option("vaultPath", options) obj_name = self.get_option("objectName", options) @@ -185,7 +185,7 @@ class VaultDestinationPlugin(DestinationPlugin): "helpMessage": "Authentication method to use", }, { - "name": "tokenFile/VaultRole", + "name": "tokenFileOrVaultRole", "type": "str", "required": True, "validation": "^([a-zA-Z0-9/._-]+/?)+$", @@ -202,15 +202,15 @@ class VaultDestinationPlugin(DestinationPlugin): "name": "vaultPath", "type": "str", "required": True, - "validation": "^([a-zA-Z0-9._-]+/?)+$", - "helpMessage": "Must be a valid Vault secrets path", + "validation": "^(([a-zA-Z0-9._-]+|{(CN|OU|O|L|S|C)})+/?)+$", + "helpMessage": "Must be a valid Vault secrets path. Support vars: {CN|OU|O|L|S|C}", }, { "name": "objectName", "type": "str", "required": False, - "validation": "[0-9a-zA-Z.:_-]+", - "helpMessage": "Name to bundle certs under, if blank use cn", + "validation": "^([0-9a-zA-Z.:_-]+|{(CN|OU|O|L|S|C)})+$", + "helpMessage": "Name to bundle certs under, if blank use {CN}. Support vars: {CN|OU|O|L|S|C}", }, { "name": "bundleChain", @@ -241,11 +241,12 @@ class VaultDestinationPlugin(DestinationPlugin): :param cert_chain: :return: """ - cname = common_name(parse_certificate(body)) + cert = parse_certificate(body) + cname = common_name(cert) url = self.get_option("vaultUrl", options) auth_method = self.get_option("authenticationMethod", options) - auth_key = self.get_option("tokenFile/vaultRole", options) + auth_key = self.get_option("tokenFileOrVaultRole", options) mount = self.get_option("vaultMount", options) path = self.get_option("vaultPath", options) bundle = self.get_option("bundleChain", options) @@ -285,10 +286,27 @@ class VaultDestinationPlugin(DestinationPlugin): client.secrets.kv.default_kv_version = api_version - if obj_name: - path = "{0}/{1}".format(path, obj_name) - else: - path = "{0}/{1}".format(path, cname) + t_path = path.format( + CN=cname, + OU=organizational_unit(cert), + O=organization(cert), # noqa: E741 + L=location(cert), + S=state(cert), + C=country(cert) + ) + if not obj_name: + obj_name = '{CN}' + + f_obj_name = obj_name.format( + CN=cname, + OU=organizational_unit(cert), + O=organization(cert), # noqa: E741 + L=location(cert), + S=state(cert), + C=country(cert) + ) + + path = "{0}/{1}".format(t_path, f_obj_name) secret = get_secret(client, mount, path) secret["data"][cname] = {} diff --git a/lemur/sources/cli.py b/lemur/sources/cli.py index 0d537500..c415b567 100644 --- a/lemur/sources/cli.py +++ b/lemur/sources/cli.py @@ -58,6 +58,13 @@ def execute_clean(plugin, certificate, source): try: plugin.clean(certificate, source.options) certificate.sources.remove(source) + + # If we want to remove the source from the certificate, we also need to clear any equivalent destinations to + # prevent Lemur from re-uploading the certificate. + for destination in certificate.destinations: + if destination.label == source.label: + certificate.destinations.remove(destination) + certificate_service.database.update(certificate) return SUCCESS_METRIC_STATUS except Exception as e: diff --git a/lemur/sources/service.py b/lemur/sources/service.py index f4783313..fafa6f5a 100644 --- a/lemur/sources/service.py +++ b/lemur/sources/service.py @@ -123,15 +123,19 @@ def sync_endpoints(source): "acct": s.get_option("accountNumber", source.options)}) if not endpoint["certificate"]: - current_app.logger.error( - "Certificate Not Found. Name: {0} Endpoint: {1}".format( - certificate_name, endpoint["name"] - ) - ) + current_app.logger.error({ + "message": "Certificate Not Found", + "certificate_name": certificate_name, + "endpoint_name": endpoint["name"], + "dns_name": endpoint.get("dnsname"), + "account": s.get_option("accountNumber", source.options), + }) + metrics.send("endpoint.certificate.not.found", "counter", 1, metric_tags={"cert": certificate_name, "endpoint": endpoint["name"], - "acct": s.get_option("accountNumber", source.options)}) + "acct": s.get_option("accountNumber", source.options), + "dnsname": endpoint.get("dnsname")}) continue policy = endpoint.pop("policy") @@ -193,6 +197,11 @@ def sync_certificates(source, user): s = plugins.get(source.plugin_name) certificates = s.get_certificates(source.options) + # emitting the count of certificates on the source + metrics.send("sync_certificates_count", + "gauge", len(certificates), + metric_tags={"source": source.label}) + for certificate in certificates: exists, updated_by_hash = find_cert(certificate) diff --git a/lemur/tests/test_certificates.py b/lemur/tests/test_certificates.py index adafa605..41584cb3 100644 --- a/lemur/tests/test_certificates.py +++ b/lemur/tests/test_certificates.py @@ -9,7 +9,8 @@ from cryptography import x509 from cryptography.hazmat.backends import default_backend from marshmallow import ValidationError from freezegun import freeze_time -from mock import patch +# from mock import patch +from unittest.mock import patch from lemur.certificates.service import create_csr from lemur.certificates.views import * # noqa @@ -906,12 +907,12 @@ def test_certificate_get_body(client): assert response_body["serial"] == "211983098819107449768450703123665283596" assert response_body["serialHex"] == "9F7A75B39DAE4C3F9524C68B06DA6A0C" assert response_body["distinguishedName"] == ( - "CN=LemurTrust Unittests Class 1 CA 2018," - "O=LemurTrust Enterprises Ltd," - "OU=Unittesting Operations Center," - "C=EE," + "L=Earth," "ST=N/A," - "L=Earth" + "C=EE," + "OU=Unittesting Operations Center," + "O=LemurTrust Enterprises Ltd," + "CN=LemurTrust Unittests Class 1 CA 2018" ) diff --git a/lemur/tests/test_dns_providers.py b/lemur/tests/test_dns_providers.py index b8714a2d..83315be5 100644 --- a/lemur/tests/test_dns_providers.py +++ b/lemur/tests/test_dns_providers.py @@ -4,9 +4,20 @@ from lemur.dns_providers import util as dnsutil class TestDNSProvider(unittest.TestCase): def test_is_valid_domain(self): - self.assertTrue(dnsutil.is_valid_domain("example.com")) - self.assertTrue(dnsutil.is_valid_domain("foo.bar.org")) - self.assertTrue(dnsutil.is_valid_domain("_acme-chall.example.com")) - self.assertFalse(dnsutil.is_valid_domain("e/xample.com")) - self.assertFalse(dnsutil.is_valid_domain("exam\ple.com")) - self.assertFalse(dnsutil.is_valid_domain("*.example.com")) + self.assertTrue(dnsutil.is_valid_domain('example.com')) + self.assertTrue(dnsutil.is_valid_domain('foo.bar.org')) + self.assertTrue(dnsutil.is_valid_domain('exam--ple.io')) + self.assertTrue(dnsutil.is_valid_domain('a.example.com')) + self.assertTrue(dnsutil.is_valid_domain('example.io')) + self.assertTrue(dnsutil.is_valid_domain('example-of-under-63-character-domain-label-length-limit-1234567.com')) + self.assertFalse(dnsutil.is_valid_domain('example-of-over-63-character-domain-label-length-limit-123456789.com')) + self.assertTrue(dnsutil.is_valid_domain('_acme-chall.example.com')) + self.assertFalse(dnsutil.is_valid_domain('e/xample.com')) + self.assertFalse(dnsutil.is_valid_domain('exam\ple.com')) + self.assertFalse(dnsutil.is_valid_domain('= (20, 1): + install_requires = [str(ir.requirement) for ir in install_requires_g] + tests_require = [str(ir.requirement) for ir in tests_require_g] + docs_require = [str(ir.requirement) for ir in docs_require_g] + dev_requires = [str(ir.requirement) for ir in dev_requires_g] +else: + install_requires = [str(ir.req) for ir in install_requires_g] + tests_require = [str(ir.req) for ir in tests_require_g] + docs_require = [str(ir.req) for ir in docs_require_g] + dev_requires = [str(ir.req) for ir in dev_requires_g] class SmartInstall(install):