diff --git a/.gitignore b/.gitignore index 97af00ca..72e85f26 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,11 @@ package-lock.json /lemur/static/dist/ /lemur/static/app/vendor/ /wheelhouse +/lemur/lib +/lemur/bin +/lemur/lib64 +/lemur/include + docs/_build .editorconfig .idea diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f3d19151..be4fee92 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,3 +8,17 @@ sha: v2.9.5 hooks: - id: jshint +- repo: https://github.com/ambv/black + rev: stable + hooks: + - id: black + language_version: python3.7 + +- repo: local + hooks: + - id: python-bandit-vulnerability-check + name: bandit + entry: bandit + args: ['--ini', 'tox.ini', '-r', 'consoleme'] + language: system + pass_filenames: false \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index b540937d..f1abf3f3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,5 @@ language: python -sudo: required -dist: trusty +dist: xenial node_js: - "6.2.0" @@ -10,8 +9,8 @@ addons: matrix: include: - - python: "3.5" - env: TOXENV=py35 + - python: "3.7" + env: TOXENV=py37 cache: directories: diff --git a/Dockerfile b/Dockerfile index 46efd50a..fc83a034 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,9 @@ -FROM python:3.5 +FROM python:3.7 RUN apt-get update -RUN apt-get install -y make python-software-properties curl +RUN apt-get install -y make software-properties-common curl RUN curl -sL https://deb.nodesource.com/setup_7.x | bash - RUN apt-get update -RUN apt-get install -y nodejs libldap2-dev libsasl2-dev libldap2-dev libssl-dev +RUN apt-get install -y npm libldap2-dev libsasl2-dev libldap2-dev libssl-dev RUN pip install -U setuptools RUN pip install coveralls bandit WORKDIR /app diff --git a/Makefile b/Makefile index 19a69236..069eb29b 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ endif @echo "" dev-docs: - pip install -r docs/requirements.txt + pip install -r requirements-docs.txt reset-db: @echo "--> Dropping existing 'lemur' database" @@ -46,7 +46,7 @@ reset-db: @echo "--> Enabling pg_trgm extension" psql lemur -c "create extension IF NOT EXISTS pg_trgm;" @echo "--> Applying migrations" - lemur db upgrade + cd lemur && lemur db upgrade setup-git: @echo "--> Installing git hooks" @@ -113,10 +113,10 @@ endif @echo "--> Updating Python requirements" pip install --upgrade pip pip install --upgrade pip-tools + pip-compile --output-file requirements.txt requirements.in -U --no-index pip-compile --output-file requirements-docs.txt requirements-docs.in -U --no-index pip-compile --output-file requirements-dev.txt requirements-dev.in -U --no-index pip-compile --output-file requirements-tests.txt requirements-tests.in -U --no-index - pip-compile --output-file requirements.txt requirements.in -U --no-index @echo "--> Done updating Python requirements" @echo "--> Removing python-ldap from requirements-docs.txt" grep -v "python-ldap" requirements-docs.txt > tempreqs && mv tempreqs requirements-docs.txt @@ -125,5 +125,9 @@ endif @echo "--> Done installing new dependencies" @echo "" +# Execute with make checkout-pr pr= +checkout-pr: + git fetch upstream pull/$(pr)/head:pr-$(pr) + .PHONY: develop dev-postgres dev-docs setup-git build clean update-submodules test testloop test-cli test-js test-python lint lint-python lint-js coverage publish release diff --git a/docker-compose.yml b/docker-compose.yml index 66f2f0b1..ee0d8396 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,10 +13,13 @@ services: VIRTUAL_ENV: 'true' postgres: - image: postgres:9.4 + image: postgres + restart: always environment: POSTGRES_USER: lemur POSTGRES_PASSWORD: lemur + ports: + - "5432:5432" redis: image: "redis:alpine" diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 00000000..f7d1caf7 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,64 @@ +FROM alpine:3.8 + +ARG VERSION +ENV VERSION master + +ENV uid 1337 +ENV gid 1337 +ENV user lemur +ENV group lemur + +COPY entrypoint / +COPY src/lemur.conf.py /home/lemur/.lemur/lemur.conf.py +COPY supervisor.conf / +COPY nginx/default.conf /etc/nginx/conf.d/ +COPY nginx/default-ssl.conf /etc/nginx/conf.d/ + +RUN addgroup -S ${group} -g ${gid} && \ + adduser -D -S ${user} -G ${group} -u ${uid} && \ + apk --update add python3 libldap postgresql-client nginx supervisor curl tzdata openssl bash && \ + apk --update add --virtual build-dependencies \ + git \ + tar \ + curl \ + python3-dev \ + npm \ + bash \ + musl-dev \ + gcc \ + autoconf \ + automake \ + make \ + nasm \ + zlib-dev \ + postgresql-dev \ + libressl-dev \ + libffi-dev \ + cyrus-sasl-dev \ + openldap-dev && \ + mkdir -p /opt/lemur /home/lemur/.lemur/ && \ + curl -sSL https://github.com/Netflix/lemur/archive/$VERSION.tar.gz | tar xz -C /opt/lemur --strip-components=1 && \ + pip3 install --upgrade pip && \ + pip3 install --upgrade setuptools && \ + chmod +x /entrypoint && \ + mkdir -p /run/nginx/ /etc/nginx/ssl/ && \ + chown -R $user:$group /opt/lemur/ /home/lemur/.lemur/ + +WORKDIR /opt/lemur + +RUN npm install --unsafe-perm && \ + pip3 install -e . && \ + node_modules/.bin/gulp build && \ + node_modules/.bin/gulp package --urlContextPath=$(urlContextPath) && \ + apk del build-dependencies + +WORKDIR / + +HEALTHCHECK --interval=12s --timeout=12s --start-period=30s \ + CMD curl --fail http://localhost:80/api/1/healthcheck | grep -q ok || exit 1 + +USER root + +ENTRYPOINT ["/entrypoint"] + +CMD ["/usr/bin/supervisord","-c","supervisor.conf"] diff --git a/docker/entrypoint b/docker/entrypoint new file mode 100644 index 00000000..6077167a --- /dev/null +++ b/docker/entrypoint @@ -0,0 +1,54 @@ +#!/bin/sh + +if [ -z "${POSTGRES_USER}" ] || [ -z "${POSTGRES_PASSWORD}" ] || [ -z "${POSTGRES_HOST}" ] || [ -z "${POSTGRES_DB}" ];then + echo "Database vars not set" + exit 1 +fi + +export POSTGRES_PORT="${POSTGRES_PORT:-5432}" + +echo 'export SQLALCHEMY_DATABASE_URI="postgresql://$POSTGRES_USER:$POSTGRES_PASSWORD@$POSTGRES_HOST:$POSTGRES_PORT/$POSTGRES_DB"' >> /etc/profile + +source /etc/profile + +PGPASSWORD=$POSTGRES_PASSWORD psql -h $POSTGRES_HOST -p $POSTGRES_PORT -U $POSTGRES_USER -d $POSTGRES_DB --command 'select 1;' + +echo " # Create Postgres trgm extension" +PGPASSWORD=$POSTGRES_PASSWORD psql -h $POSTGRES_HOST -p $POSTGRES_PORT -U $POSTGRES_USER -d $POSTGRES_DB --command 'CREATE EXTENSION pg_trgm;' +echo " # Done" + +if [ -z "${SKIP_SSL}" ]; then + if [ ! -f /etc/nginx/ssl/server.crt ] && [ ! -f /etc/nginx/ssl/server.key ]; then + openssl req -x509 -newkey rsa:4096 -nodes -keyout /etc/nginx/ssl/server.key -out /etc/nginx/ssl/server.crt -days 365 -subj "/C=US/ST=FAKE/L=FAKE/O=FAKE/OU=FAKE/CN=FAKE" + fi + mv /etc/nginx/conf.d/default-ssl.conf.a /etc/nginx/conf.d/default-ssl.conf + mv /etc/nginx/conf.d/default.conf /etc/nginx/conf.d/default.conf.a +fi + +# if [ ! -f /home/lemur/.lemur/lemur.conf.py ]; then +# echo "Creating config" +# https://github.com/Netflix/lemur/issues/2257 +# python3 /opt/lemur/lemur/manage.py create_config +# echo "Done" +# fi + +echo " # Running init" +su lemur -c "python3 /opt/lemur/lemur/manage.py init" +echo " # Done" + +# echo "Creating user" +# https://github.com/Netflix/lemur/issues/ +# echo "something that will create user" | python3 /opt/lemur/lemur/manage.py shell +# echo "Done" + +cron_notify="${CRON_NOTIFY:-"0 22 * * *"}" +cron_sync="${CRON_SYNC:-"*/15 * * * *"}" +cron_revoked="${CRON_CHECK_REVOKED:-"0 22 * * *"}" + +echo " # Populating crontab" +echo "${cron_notify} lemur python3 /opt/lemur/lemur/manage.py notify expirations" > /etc/crontabs/lemur_notify +echo "${cron_sync} lemur python3 /opt/lemur/lemur/manage.py source sync -s all" > /etc/crontabs/lemur_sync +echo "${cron_revoked} lemur python3 /opt/lemur/lemur/manage.py certificate check_revoked" > /etc/crontabs/lemur_revoked +echo " # Done" + +exec "$@" diff --git a/docker/nginx/default-ssl.conf b/docker/nginx/default-ssl.conf new file mode 100644 index 00000000..86c770df --- /dev/null +++ b/docker/nginx/default-ssl.conf @@ -0,0 +1,37 @@ +add_header X-Frame-Options DENY; +add_header X-Content-Type-Options nosniff; +add_header X-XSS-Protection "1; mode=block"; + +server { + listen 80; + server_name _; + return 301 https://$host$request_uri; +} + +server { + listen 443; + server_name _; + access_log /dev/stdout; + error_log /dev/stderr; + ssl_certificate /etc/nginx/ssl/server.crt; + ssl_certificate_key /etc/nginx/ssl/server.key; + ssl_protocols TLSv1 TLSv1.1 TLSv1.2; + ssl_ciphers HIGH:!aNULL:!MD5; + + location /api { + proxy_pass http://127.0.0.1:8000; + proxy_next_upstream error timeout invalid_header http_500 http_502 http_503 http_504; + proxy_redirect off; + proxy_buffering off; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + } + + location / { + root /opt/lemur/lemur/static/dist; + include mime.types; + index index.html; + } + +} diff --git a/docker/nginx/default.conf b/docker/nginx/default.conf new file mode 100644 index 00000000..d71a93d3 --- /dev/null +++ b/docker/nginx/default.conf @@ -0,0 +1,26 @@ +add_header X-Frame-Options DENY; +add_header X-Content-Type-Options nosniff; +add_header X-XSS-Protection "1; mode=block"; + +server { + listen 80; + access_log /dev/stdout; + error_log /dev/stderr; + + location /api { + proxy_pass http://127.0.0.1:8000; + proxy_next_upstream error timeout invalid_header http_500 http_502 http_503 http_504; + proxy_redirect off; + proxy_buffering off; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + } + + location / { + root /opt/lemur/lemur/static/dist; + include mime.types; + index index.html; + } + +} diff --git a/docker/src/lemur.conf.py b/docker/src/lemur.conf.py new file mode 100644 index 00000000..a5f7e8b6 --- /dev/null +++ b/docker/src/lemur.conf.py @@ -0,0 +1,31 @@ +import os +_basedir = os.path.abspath(os.path.dirname(__file__)) + +CORS = os.environ.get("CORS") == "True" +debug = os.environ.get("DEBUG") == "True" + +SECRET_KEY = repr(os.environ.get('SECRET_KEY','Hrs8kCDNPuT9vtshsSWzlrYW+d+PrAXvg/HwbRE6M3vzSJTTrA/ZEw==')) + +LEMUR_TOKEN_SECRET = repr(os.environ.get('LEMUR_TOKEN_SECRET','YVKT6nNHnWRWk28Lra1OPxMvHTqg1ZXvAcO7bkVNSbrEuDQPABM0VQ==')) +LEMUR_ENCRYPTION_KEYS = repr(os.environ.get('LEMUR_ENCRYPTION_KEYS','Ls-qg9j3EMFHyGB_NL0GcQLI6622n9pSyGM_Pu0GdCo=')) + +LEMUR_WHITELISTED_DOMAINS = [] + +LEMUR_EMAIL = '' +LEMUR_SECURITY_TEAM_EMAIL = [] + + +LEMUR_DEFAULT_COUNTRY = repr(os.environ.get('LEMUR_DEFAULT_COUNTRY','')) +LEMUR_DEFAULT_STATE = repr(os.environ.get('LEMUR_DEFAULT_STATE','')) +LEMUR_DEFAULT_LOCATION = repr(os.environ.get('LEMUR_DEFAULT_LOCATION','')) +LEMUR_DEFAULT_ORGANIZATION = repr(os.environ.get('LEMUR_DEFAULT_ORGANIZATION','')) +LEMUR_DEFAULT_ORGANIZATIONAL_UNIT = repr(os.environ.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT','')) + +ACTIVE_PROVIDERS = [] + +METRIC_PROVIDERS = [] + +LOG_LEVEL = str(os.environ.get('LOG_LEVEL','DEBUG')) +LOG_FILE = str(os.environ.get('LOG_FILE','/home/lemur/.lemur/lemur.log')) + +SQLALCHEMY_DATABASE_URI = os.environ.get('SQLALCHEMY_DATABASE_URI','postgresql://lemur:lemur@localhost:5432/lemur') diff --git a/docker/supervisor.conf b/docker/supervisor.conf new file mode 100644 index 00000000..fed01581 --- /dev/null +++ b/docker/supervisor.conf @@ -0,0 +1,32 @@ +[supervisord] +nodaemon=true +user=root +logfile=/dev/stdout +logfile_maxbytes=0 +pidfile = /tmp/supervisord.pid + +[program:lemur] +environment=LEMUR_CONF=/home/lemur/.lemur/lemur.conf.py +command=/usr/bin/python3 manage.py start -b 0.0.0.0:8000 +user=lemur +directory=/opt/lemur/lemur +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes = 0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 + +[program:nginx] +command=/usr/sbin/nginx -g "daemon off;" +user=root +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes = 0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 + +[program:cron] +command=/usr/sbin/crond -f +user=root +stdout_logfile=/dev/stdout +stdout_logfile_maxbytes = 0 +stderr_logfile=/dev/stderr +stderr_logfile_maxbytes=0 diff --git a/docs/administration.rst b/docs/administration.rst index 9d6c8d12..e292ae03 100644 --- a/docs/administration.rst +++ b/docs/administration.rst @@ -161,6 +161,13 @@ Specifying the `SQLALCHEMY_MAX_OVERFLOW` to 0 will enforce limit to not create c Dump all imported or generated CSR and certificate details to stdout using OpenSSL. (default: `False`) +.. data:: ALLOW_CERT_DELETION + :noindex: + + When set to True, certificates can be marked as deleted via the API and deleted certificates will not be displayed + in the UI. When set to False (the default), the certificate delete API will always return "405 method not allowed" + and deleted certificates will always be visible in the UI. (default: `False`) + Certificate Default Options --------------------------- @@ -313,7 +320,7 @@ LDAP support requires the pyldap python library, which also depends on the follo To configure the use of an LDAP server, a number of settings need to be configured in `lemur.conf.py`. Here is an example LDAP configuration stanza you can add to your config. Adjust to suit your environment of course. - + .. code-block:: python LDAP_AUTH = True @@ -586,8 +593,60 @@ If you are not using a metric provider you do not need to configure any of these Plugin Specific Options ----------------------- +Active Directory Certificate Services Plugin +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +.. data:: ADCS_SERVER + :noindex: + + FQDN of your ADCS Server + + +.. data:: ADCS_AUTH_METHOD + :noindex: + + The chosen authentication method. Either ‘basic’ (the default), ‘ntlm’ or ‘cert’ (SSL client certificate). The next 2 variables are interpreted differently for different methods. + + +.. data:: ADCS_USER + :noindex: + + The username (basic) or the path to the public cert (cert) of the user accessing PKI + + +.. data:: ADCS_PWD + :noindex: + + The passwd (basic) or the path to the private key (cert) of the user accessing PKI + + +.. data:: ADCS_TEMPLATE + :noindex: + + Template to be used for certificate issuing. Usually display name w/o spaces + + +.. data:: ADCS_START + :noindex: + +.. data:: ADCS_STOP + :noindex: + +.. data:: ADCS_ISSUING + :noindex: + + Contains the issuing cert of the CA + + +.. data:: ADCS_ROOT + :noindex: + + Contains the root cert of the CA + + Verisign Issuer Plugin -^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~ Authorities will each have their own configuration options. There is currently just one plugin bundled with Lemur, Verisign/Symantec. Additional plugins may define additional options. Refer to the plugin's own documentation @@ -683,7 +742,7 @@ The following configuration properties are required to use the Digicert issuer p CFSSL Issuer Plugin -^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~ The following configuration properties are required to use the CFSSL issuer plugin. @@ -702,9 +761,36 @@ The following configuration properties are required to use the CFSSL issuer plug This is the intermediate to be used for your CA chain +.. data:: CFSSL_KEY + :noindex: + + This is the hmac key to authenticate to the CFSSL service. (Optional) + + +Hashicorp Vault Source/Destination Plugin +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Lemur can import and export certificate data to and from a Hashicorp Vault secrets store. Lemur can connect to a different Vault service per source/destination. + +.. note:: This plugin does not supersede or overlap the 3rd party Vault Issuer plugin. + +.. note:: Vault does not have any configuration properties however it does read from a file on disk for a vault access token. The Lemur service account needs read access to this file. + +Vault Source +"""""""""""" + +The Vault Source Plugin will read from one Vault object location per source defined. There is expected to be one or more certificates defined in each object in Vault. + +Vault Destination +""""""""""""""""" + +A Vault destination can be one object in Vault or a directory where all certificates will be stored as their own object by CN. + +Vault Destination supports a regex filter to prevent certificates with SAN that do not match the regex filter from being deployed. This is an optional feature per destination defined. + AWS Source/Destination Plugin -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In order for Lemur to manage its own account and other accounts we must ensure it has the correct AWS permissions. @@ -1056,7 +1142,9 @@ Verisign/Symantec ----------------- :Authors: - Kevin Glisson + Kevin Glisson , + Curtis Castrapel , + Hossein Shafagh :Type: Issuer :Description: @@ -1082,6 +1170,8 @@ Acme :Authors: Kevin Glisson , + Curtis Castrapel , + Hossein Shafagh , Mikhail Khodorovskiy :Type: Issuer @@ -1093,7 +1183,9 @@ Atlas ----- :Authors: - Kevin Glisson + Kevin Glisson , + Curtis Castrapel , + Hossein Shafagh :Type: Metric :Description: @@ -1104,7 +1196,9 @@ Email ----- :Authors: - Kevin Glisson + Kevin Glisson , + Curtis Castrapel , + Hossein Shafagh :Type: Notification :Description: @@ -1126,7 +1220,9 @@ AWS ---- :Authors: - Kevin Glisson + Kevin Glisson , + Curtis Castrapel , + Hossein Shafagh :Type: Source :Description: @@ -1137,7 +1233,9 @@ AWS ---- :Authors: - Kevin Glisson + Kevin Glisson , + Curtis Castrapel , + Hossein Shafagh :Type: Destination :Description: @@ -1187,6 +1285,26 @@ CFSSL :Description: Basic support for generating certificates from the private certificate authority CFSSL +Vault +----- + +:Authors: + Christopher Jolley +:Type: + Source +:Description: + Source plugin imports certificates from Hashicorp Vault secret store. + +Vault +----- + +:Authors: + Christopher Jolley +:Type: + Destination +:Description: + Destination plugin to deploy certificates to Hashicorp Vault secret store. + 3rd Party Plugins ================= diff --git a/docs/conf.py b/docs/conf.py index d5b1698c..55bd20d2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -18,48 +18,45 @@ from unittest.mock import MagicMock # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) # Mock packages that cannot be installed on rtd -on_rtd = os.environ.get('READTHEDOCS') == 'True' +on_rtd = os.environ.get("READTHEDOCS") == "True" if on_rtd: + class Mock(MagicMock): @classmethod def __getattr__(cls, name): return MagicMock() - MOCK_MODULES = ['ldap'] + MOCK_MODULES = ["ldap"] sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = [ - 'sphinx.ext.autodoc', - 'sphinxcontrib.autohttp.flask', - 'sphinx.ext.todo', -] +extensions = ["sphinx.ext.autodoc", "sphinxcontrib.autohttp.flask", "sphinx.ext.todo"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'lemur' -copyright = u'2018, Netflix Inc.' +project = u"lemur" +copyright = u"2018, Netflix Inc." # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -68,191 +65,186 @@ copyright = u'2018, Netflix Inc.' base_dir = os.path.join(os.path.dirname(__file__), os.pardir) about = {} with open(os.path.join(base_dir, "lemur", "__about__.py")) as f: - exec(f.read(), about) + exec(f.read(), about) # nosec version = release = about["__version__"] # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # -- Options for HTML output ---------------------------------------------- # on_rtd is whether we are on readthedocs.org, this line of code grabbed from docs.readthedocs.org -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +on_rtd = os.environ.get("READTHEDOCS", None) == "True" if not on_rtd: # only import and set the theme if we're building docs locally import sphinx_rtd_theme - html_theme = 'sphinx_rtd_theme' + + html_theme = "sphinx_rtd_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'lemurdoc' +htmlhelp_basename = "lemurdoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ('index', 'lemur.tex', u'Lemur Documentation', - u'Kevin Glisson', 'manual'), + ("index", "lemur.tex", u"Lemur Documentation", u"Netflix Security", "manual") ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'Lemur', u'Lemur Documentation', - [u'Kevin Glisson'], 1) -] +man_pages = [("index", "Lemur", u"Lemur Documentation", [u"Netflix Security"], 1)] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -261,19 +253,25 @@ man_pages = [ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'Lemur', u'Lemur Documentation', - u'Kevin Glisson', 'Lemur', 'SSL Certificate Management', - 'Miscellaneous'), + ( + "index", + "Lemur", + u"Lemur Documentation", + u"Netflix Security", + "Lemur", + "SSL Certificate Management", + "Miscellaneous", + ) ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False diff --git a/docs/developer/index.rst b/docs/developer/index.rst index 4c46566a..0033c3f4 100644 --- a/docs/developer/index.rst +++ b/docs/developer/index.rst @@ -22,12 +22,18 @@ Once you've got all that, the rest is simple: # If you have a fork, you'll want to clone it instead git clone git://github.com/netflix/lemur.git - # Create a python virtualenv - mkvirtualenv lemur + # Create and activate python virtualenv from within the lemur repo + python3 -m venv env + . env/bin/activate + + # Install doc requirements - # Make the magic happen make dev-docs + # Make the docs + cd docs + make html + Running ``make dev-docs`` will install the basic requirements to get Sphinx running. @@ -58,7 +64,7 @@ Once you've got all that, the rest is simple: git clone git://github.com/lemur/lemur.git # Create a python virtualenv - mkvirtualenv lemur + python3 -m venv env # Make the magic happen make @@ -135,7 +141,7 @@ The test suite consists of multiple parts, testing both the Python and JavaScrip make test -If you only need to run the Python tests, you can do so with ``make test-python``, as well as ``test-js`` for the JavaScript tests. +If you only need to run the Python tests, you can do so with ``make test-python``, as well as ``make test-js`` for the JavaScript tests. You'll notice that the test suite is structured based on where the code lives, and strongly encourages using the mock library to drive more accurate individual tests. diff --git a/docs/production/create_dns_provider.png b/docs/production/create_dns_provider.png new file mode 100644 index 00000000..71d5a0d3 Binary files /dev/null and b/docs/production/create_dns_provider.png differ diff --git a/docs/production/index.rst b/docs/production/index.rst index 42f6648a..cd044ca4 100644 --- a/docs/production/index.rst +++ b/docs/production/index.rst @@ -217,23 +217,23 @@ An example apache config:: # HSTS (mod_headers is required) (15768000 seconds = 6 months) Header always set Strict-Transport-Security "max-age=15768000" ... - + # Set the lemur DocumentRoot to static/dist DocumentRoot /www/lemur/lemur/static/dist - + # Uncomment to force http 1.0 connections to proxy # SetEnv force-proxy-request-1.0 1 - + #Don't keep proxy connections alive SetEnv proxy-nokeepalive 1 - + # Only need to do reverse proxy ProxyRequests Off - + # Proxy requests to the api to the lemur service (and sanitize redirects from it) ProxyPass "/api" "http://127.0.0.1:8000/api" ProxyPassReverse "/api" "http://127.0.0.1:8000/api" - + Also included in the configurations above are several best practices when it comes to deploying TLS. Things like enabling @@ -318,7 +318,7 @@ Periodic Tasks ============== Lemur contains a few tasks that are run and scheduled basis, currently the recommend way to run these tasks is to create -a cron job that runs the commands. +celery tasks or cron jobs that run these commands. There are currently three commands that could/should be run on a periodic basis: @@ -326,11 +326,124 @@ There are currently three commands that could/should be run on a periodic basis: - `check_revoked` - `sync` +If you are using LetsEncrypt, you must also run the following: + +- `fetch_all_pending_acme_certs` +- `remove_old_acme_certs` + How often you run these commands is largely up to the user. `notify` and `check_revoked` are typically run at least once a day. -`sync` is typically run every 15 minutes. +`sync` is typically run every 15 minutes. `fetch_all_pending_acme_certs` should be ran frequently (Every minute is fine). +`remove_old_acme_certs` can be ran more rarely, such as once every week. Example cron entries:: 0 22 * * * lemuruser export LEMUR_CONF=/Users/me/.lemur/lemur.conf.py; /www/lemur/bin/lemur notify expirations */15 * * * * lemuruser export LEMUR_CONF=/Users/me/.lemur/lemur.conf.py; /www/lemur/bin/lemur source sync -s all 0 22 * * * lemuruser export LEMUR_CONF=/Users/me/.lemur/lemur.conf.py; /www/lemur/bin/lemur certificate check_revoked + + +Example Celery configuration (To be placed in your configuration file):: + + CELERYBEAT_SCHEDULE = { + 'fetch_all_pending_acme_certs': { + 'task': 'lemur.common.celery.fetch_all_pending_acme_certs', + 'options': { + 'expires': 180 + }, + 'schedule': crontab(minute="*"), + }, + 'remove_old_acme_certs': { + 'task': 'lemur.common.celery.remove_old_acme_certs', + 'options': { + 'expires': 180 + }, + 'schedule': crontab(hour=7, minute=30, day_of_week=1), + }, + 'clean_all_sources': { + 'task': 'lemur.common.celery.clean_all_sources', + 'options': { + 'expires': 180 + }, + 'schedule': crontab(hour=1, minute=0, day_of_week=1), + }, + 'sync_all_sources': { + 'task': 'lemur.common.celery.sync_all_sources', + 'options': { + 'expires': 180 + }, + 'schedule': crontab(hour="*/3", minute=5), + }, + 'sync_source_destination': { + 'task': 'lemur.common.celery.sync_source_destination', + 'options': { + 'expires': 180 + }, + 'schedule': crontab(hour="*"), + } + } + +To enable celery support, you must also have configuration values that tell Celery which broker and backend to use. +Here are the Celery configuration variables that should be set:: + + CELERY_RESULT_BACKEND = 'redis://your_redis_url:6379' + CELERY_BROKER_URL = 'redis://your_redis_url:6379' + CELERY_IMPORTS = ('lemur.common.celery') + CELERY_TIMEZONE = 'UTC' + +You must start a single Celery scheduler instance and one or more worker instances in order to handle incoming tasks. +The scheduler can be started with:: + + LEMUR_CONF='/location/to/conf.py' /location/to/lemur/bin/celery -A lemur.common.celery beat + +And the worker can be started with desired options such as the following:: + + LEMUR_CONF='/location/to/conf.py' /location/to/lemur/bin/celery -A lemur.common.celery worker --concurrency 10 -E -n lemurworker1@%%h + +supervisor or systemd configurations should be created for these in production environments as appropriate. + +Add support for LetsEncrypt +=========================== + +LetsEncrypt is a free, limited-feature certificate authority that offers publicly trusted certificates that are valid +for 90 days. LetsEncrypt does not use organizational validation (OV), and instead relies on domain validation (DV). +LetsEncrypt requires that we prove ownership of a domain before we're able to issue a certificate for that domain, each +time we want a certificate. + +The most common methods to prove ownership are HTTP validation and DNS validation. Lemur supports DNS validation +through the creation of DNS TXT records. + +In a nutshell, when we send a certificate request to LetsEncrypt, they generate a random token and ask us to put that +token in a DNS text record to prove ownership of a domain. If a certificate request has multiple domains, we must +prove ownership of all of these domains through this method. The token is typically written to a TXT record at +-acme_challenge.domain.com. Once we create the appropriate TXT record(s), Lemur will try to validate propagation +before requesting that LetsEncrypt finalize the certificate request and send us the certificate. + +.. figure:: letsencrypt_flow.png + +To start issuing certificates through LetsEncrypt, you must enable Celery support within Lemur first. After doing so, +you need to create a LetsEncrypt authority. To do this, visit +Authorities -> Create. Set the applicable attributes and click "More Options". + +.. figure:: letsencrypt_authority_1.png + +You will need to set "Certificate" to LetsEncrypt's active chain of trust for the authority you want to use. To find +the active chain of trust at the time of writing, please visit `LetsEncrypt +`_. + +Under Acme_url, enter in the appropriate endpoint URL. Lemur supports LetsEncrypt's V2 API, and we recommend you to use +this. At the time of writing, the staging and production URLs for LetsEncrypt V2 are +https://acme-staging-v02.api.letsencrypt.org/directory and https://acme-v02.api.letsencrypt.org/directory. + +.. figure:: letsencrypt_authority_2.png + +After creating the authorities, we will need to create a DNS provider. Visit `Admin` -> `DNS Providers` and click +`Create`. Lemur comes with a few provider plugins built in, with different options. Create a DNS provider with the +appropriate choices. + +.. figure:: create_dns_provider.png + +By default, users will need to select the DNS provider that is authoritative over their domain in order for the +LetsEncrypt flow to function. However, Lemur will attempt to automatically determine the appropriate provider if +possible. To enable this functionality, periodically (or through Cron/Celery) run `lemur dns_providers get_all_zones`. +This command will traverse all DNS providers, determine which zones they control, and upload this list of zones to +Lemur's database (in the dns_providers table). Alternatively, you can manually input this data. diff --git a/docs/production/letsencrypt_authority_1.png b/docs/production/letsencrypt_authority_1.png new file mode 100644 index 00000000..5898b0e2 Binary files /dev/null and b/docs/production/letsencrypt_authority_1.png differ diff --git a/docs/production/letsencrypt_authority_2.png b/docs/production/letsencrypt_authority_2.png new file mode 100644 index 00000000..04947ca2 Binary files /dev/null and b/docs/production/letsencrypt_authority_2.png differ diff --git a/docs/production/letsencrypt_flow.png b/docs/production/letsencrypt_flow.png new file mode 100644 index 00000000..f35a1410 Binary files /dev/null and b/docs/production/letsencrypt_flow.png differ diff --git a/docs/quickstart/index.rst b/docs/quickstart/index.rst index 70ca1312..82bfc357 100644 --- a/docs/quickstart/index.rst +++ b/docs/quickstart/index.rst @@ -12,7 +12,7 @@ Dependencies Some basic prerequisites which you'll need in order to run Lemur: * A UNIX-based operating system (we test on Ubuntu, develop on OS X) -* Python 3.5 or greater +* Python 3.7 or greater * PostgreSQL 9.4 or greater * Nginx @@ -22,7 +22,7 @@ Some basic prerequisites which you'll need in order to run Lemur: Installing Build Dependencies ----------------------------- -If installing Lemur on a bare Ubuntu OS you will need to grab the following packages so that Lemur can correctly build it's dependencies: +If installing Lemur on a bare Ubuntu OS you will need to grab the following packages so that Lemur can correctly build its dependencies: .. code-block:: bash @@ -31,7 +31,7 @@ If installing Lemur on a bare Ubuntu OS you will need to grab the following pack .. note:: PostgreSQL is only required if your database is going to be on the same host as the webserver. npm is needed if you're installing Lemur from the source (e.g., from git). -.. note:: Installing node from a package manager may creat the nodejs bin at /usr/bin/nodejs instead of /usr/bin/node If that is the case run the following +.. note:: Installing node from a package manager may create the nodejs bin at /usr/bin/nodejs instead of /usr/bin/node If that is the case run the following sudo ln -s /user/bin/nodejs /usr/bin/node Now, install Python ``virtualenv`` package: @@ -117,7 +117,7 @@ Simply run: .. note:: This command will create a default configuration under ``~/.lemur/lemur.conf.py`` you can specify this location by passing the ``config_path`` parameter to the ``create_config`` command. -You can specify ``-c`` or ``--config`` to any Lemur command to specify the current environment you are working in. Lemur will also look under the environmental variable ``LEMUR_CONF`` should that be easier to setup in your environment. +You can specify ``-c`` or ``--config`` to any Lemur command to specify the current environment you are working in. Lemur will also look under the environmental variable ``LEMUR_CONF`` should that be easier to set up in your environment. Update your configuration @@ -144,7 +144,7 @@ Before Lemur will run you need to fill in a few required variables in the config LEMUR_DEFAULT_ORGANIZATION LEMUR_DEFAULT_ORGANIZATIONAL_UNIT -Setup Postgres +Set Up Postgres -------------- For production, a dedicated database is recommended, for this guide we will assume postgres has been installed and is on the same machine that Lemur is installed on. @@ -180,6 +180,13 @@ Lemur provides a helpful command that will initialize your database for you. It In addition to creating a new user, Lemur also creates a few default email notifications. These notifications are based on a few configuration options such as ``LEMUR_SECURITY_TEAM_EMAIL``. They basically guarantee that every certificate within Lemur will send one expiration notification to the security team. +Your database installation requires the pg_trgm extension. If you do not have this installed already, you can allow the script to install this for you by adding the SUPERUSER permission to the lemur database user. + +.. code-block:: bash + sudo -u postgres -i + psql + postgres=# ALTER USER lemur WITH SUPERUSER + Additional notifications can be created through the UI or API. See :ref:`Creating Notifications ` and :ref:`Command Line Interface ` for details. **Make note of the password used as this will be used during first login to the Lemur UI.** @@ -189,14 +196,20 @@ Additional notifications can be created through the UI or API. See :ref:`Creati cd /www/lemur/lemur lemur init +.. note:: If you added the SUPERUSER permission to the lemur database user above, it is recommended you revoke that permission now. + +.. code-block:: bash + sudo -u postgres -i + psql + postgres=# ALTER USER lemur WITH NOSUPERUSER + .. note:: It is recommended that once the ``lemur`` user is created that you create individual users for every day access. There is currently no way for a user to self enroll for Lemur access, they must have an administrator create an account for them or be enrolled automatically through SSO. This can be done through the CLI or UI. See :ref:`Creating Users ` and :ref:`Command Line Interface ` for details. - -Setup a Reverse Proxy +Set Up a Reverse Proxy --------------------- -By default, Lemur runs on port 8000. Even if you change this, under normal conditions you won't be able to bind to port 80. To get around this (and to avoid running Lemur as a privileged user, which you shouldn't), we need setup a simple web proxy. There are many different web servers you can use for this, we like and recommend Nginx. +By default, Lemur runs on port 8000. Even if you change this, under normal conditions you won't be able to bind to port 80. To get around this (and to avoid running Lemur as a privileged user, which you shouldn't), we need to set up a simple web proxy. There are many different web servers you can use for this, we like and recommend Nginx. Proxying with Nginx diff --git a/gulp/server.js b/gulp/server.js index 777100f6..6c61273e 100644 --- a/gulp/server.js +++ b/gulp/server.js @@ -6,31 +6,31 @@ var browserSync = require('browser-sync'); var httpProxy = require('http-proxy'); /* This configuration allow you to configure browser sync to proxy your backend */ -/* - var proxyTarget = 'http://localhost/context/'; // The location of your backend - var proxyApiPrefix = 'api'; // The element in the URL which differentiate between API request and static file request + + var proxyTarget = 'http://localhost:8000/'; // The location of your backend + var proxyApiPrefix = '/api/'; // The element in the URL which differentiate between API request and static file request var proxy = httpProxy.createProxyServer({ - target: proxyTarget + target: proxyTarget }); function proxyMiddleware(req, res, next) { - if (req.url.indexOf(proxyApiPrefix) !== -1) { - proxy.web(req, res); - } else { - next(); + if (req.url.indexOf(proxyApiPrefix) !== -1) { + proxy.web(req, res); + } else { + next(); + } } - } - */ function browserSyncInit(baseDir, files, browser) { browser = browser === undefined ? 'default' : browser; browserSync.instance = browserSync.init(files, { startPath: '/index.html', - server: { - baseDir: baseDir, - routes: { - '/bower_components': './bower_components' - } + server: { + middleware: [proxyMiddleware], + baseDir: baseDir, + routes: { + '/bower_components': './bower_components' + } }, browser: browser, ghostMode: false diff --git a/lemur/__about__.py b/lemur/__about__.py index d15b7dea..766d3668 100644 --- a/lemur/__about__.py +++ b/lemur/__about__.py @@ -1,12 +1,18 @@ from __future__ import absolute_import, division, print_function __all__ = [ - "__title__", "__summary__", "__uri__", "__version__", "__author__", - "__email__", "__license__", "__copyright__", + "__title__", + "__summary__", + "__uri__", + "__version__", + "__author__", + "__email__", + "__license__", + "__copyright__", ] __title__ = "lemur" -__summary__ = ("Certificate management and orchestration service") +__summary__ = "Certificate management and orchestration service" __uri__ = "https://github.com/Netflix/lemur" __version__ = "0.7.0" diff --git a/lemur/__init__.py b/lemur/__init__.py index 1cdb3468..27deb4cd 100644 --- a/lemur/__init__.py +++ b/lemur/__init__.py @@ -5,7 +5,8 @@ :license: Apache, see LICENSE for more details. .. moduleauthor:: Kevin Glisson - +.. moduleauthor:: Curtis Castrapel +.. moduleauthor:: Hossein Shafagh """ import time @@ -32,14 +33,26 @@ from lemur.pending_certificates.views import mod as pending_certificates_bp from lemur.dns_providers.views import mod as dns_providers_bp from lemur.__about__ import ( - __author__, __copyright__, __email__, __license__, __summary__, __title__, - __uri__, __version__ + __author__, + __copyright__, + __email__, + __license__, + __summary__, + __title__, + __uri__, + __version__, ) __all__ = [ - "__title__", "__summary__", "__uri__", "__version__", "__author__", - "__email__", "__license__", "__copyright__", + "__title__", + "__summary__", + "__uri__", + "__version__", + "__author__", + "__email__", + "__license__", + "__copyright__", ] LEMUR_BLUEPRINTS = ( @@ -62,8 +75,10 @@ LEMUR_BLUEPRINTS = ( ) -def create_app(config=None): - app = factory.create_app(app_name=__name__, blueprints=LEMUR_BLUEPRINTS, config=config) +def create_app(config_path=None): + app = factory.create_app( + app_name=__name__, blueprints=LEMUR_BLUEPRINTS, config=config_path + ) configure_hook(app) return app @@ -93,7 +108,7 @@ def configure_hook(app): @app.after_request def after_request(response): # Return early if we don't have the start time - if not hasattr(g, 'request_start_time'): + if not hasattr(g, "request_start_time"): return response # Get elapsed time in milliseconds @@ -102,12 +117,12 @@ def configure_hook(app): # Collect request/response tags tags = { - 'endpoint': request.endpoint, - 'request_method': request.method.lower(), - 'status_code': response.status_code + "endpoint": request.endpoint, + "request_method": request.method.lower(), + "status_code": response.status_code, } # Record our response time metric - metrics.send('response_time', 'TIMER', elapsed, metric_tags=tags) - metrics.send('status_code_{}'.format(response.status_code), 'counter', 1) + metrics.send("response_time", "TIMER", elapsed, metric_tags=tags) + metrics.send("status_code_{}".format(response.status_code), "counter", 1) return response diff --git a/lemur/api_keys/cli.py b/lemur/api_keys/cli.py index 2259d774..8aed0497 100644 --- a/lemur/api_keys/cli.py +++ b/lemur/api_keys/cli.py @@ -14,23 +14,32 @@ from datetime import datetime manager = Manager(usage="Handles all api key related tasks.") -@manager.option('-u', '--user-id', dest='uid', help='The User ID this access key belongs too.') -@manager.option('-n', '--name', dest='name', help='The name of this API Key.') -@manager.option('-t', '--ttl', dest='ttl', help='The TTL of this API Key. -1 for forever.') +@manager.option( + "-u", "--user-id", dest="uid", help="The User ID this access key belongs too." +) +@manager.option("-n", "--name", dest="name", help="The name of this API Key.") +@manager.option( + "-t", "--ttl", dest="ttl", help="The TTL of this API Key. -1 for forever." +) def create(uid, name, ttl): """ Create a new api key for a user. :return: """ print("[+] Creating a new api key.") - key = api_key_service.create(user_id=uid, name=name, - ttl=ttl, issued_at=int(datetime.utcnow().timestamp()), revoked=False) + key = api_key_service.create( + user_id=uid, + name=name, + ttl=ttl, + issued_at=int(datetime.utcnow().timestamp()), + revoked=False, + ) print("[+] Successfully created a new api key. Generating a JWT...") jwt = create_token(uid, key.id, key.ttl) print("[+] Your JWT is: {jwt}".format(jwt=jwt)) -@manager.option('-a', '--api-key-id', dest='aid', help='The API Key ID to revoke.') +@manager.option("-a", "--api-key-id", dest="aid", help="The API Key ID to revoke.") def revoke(aid): """ Revokes an api key for a user. diff --git a/lemur/api_keys/models.py b/lemur/api_keys/models.py index df77edb1..fbcc3e44 100644 --- a/lemur/api_keys/models.py +++ b/lemur/api_keys/models.py @@ -12,14 +12,19 @@ from lemur.database import db class ApiKey(db.Model): - __tablename__ = 'api_keys' + __tablename__ = "api_keys" id = Column(Integer, primary_key=True) name = Column(String) - user_id = Column(Integer, ForeignKey('users.id')) + user_id = Column(Integer, ForeignKey("users.id")) ttl = Column(BigInteger) issued_at = Column(BigInteger) revoked = Column(Boolean) def __repr__(self): return "ApiKey(name={name}, user_id={user_id}, ttl={ttl}, issued_at={iat}, revoked={revoked})".format( - user_id=self.user_id, name=self.name, ttl=self.ttl, iat=self.issued_at, revoked=self.revoked) + user_id=self.user_id, + name=self.name, + ttl=self.ttl, + iat=self.issued_at, + revoked=self.revoked, + ) diff --git a/lemur/api_keys/schemas.py b/lemur/api_keys/schemas.py index a3c11417..e690b859 100644 --- a/lemur/api_keys/schemas.py +++ b/lemur/api_keys/schemas.py @@ -13,12 +13,18 @@ from lemur.users.schemas import UserNestedOutputSchema, UserInputSchema def current_user_id(): - return {'id': g.current_user.id, 'email': g.current_user.email, 'username': g.current_user.username} + return { + "id": g.current_user.id, + "email": g.current_user.email, + "username": g.current_user.username, + } class ApiKeyInputSchema(LemurInputSchema): name = fields.String(required=False) - user = fields.Nested(UserInputSchema, missing=current_user_id, default=current_user_id) + user = fields.Nested( + UserInputSchema, missing=current_user_id, default=current_user_id + ) ttl = fields.Integer() diff --git a/lemur/api_keys/service.py b/lemur/api_keys/service.py index 5ddb8a3a..ea681a62 100644 --- a/lemur/api_keys/service.py +++ b/lemur/api_keys/service.py @@ -34,7 +34,7 @@ def revoke(aid): :return: """ api_key = get(aid) - setattr(api_key, 'revoked', False) + setattr(api_key, "revoked", False) return database.update(api_key) @@ -80,10 +80,10 @@ def render(args): :return: """ query = database.session_query(ApiKey) - user_id = args.pop('user_id', None) - aid = args.pop('id', None) - has_permission = args.pop('has_permission', False) - requesting_user_id = args.pop('requesting_user_id') + user_id = args.pop("user_id", None) + aid = args.pop("id", None) + has_permission = args.pop("has_permission", False) + requesting_user_id = args.pop("requesting_user_id") if user_id: query = query.filter(ApiKey.user_id == user_id) diff --git a/lemur/api_keys/views.py b/lemur/api_keys/views.py index b7af2944..ee09d3f7 100644 --- a/lemur/api_keys/views.py +++ b/lemur/api_keys/views.py @@ -19,10 +19,16 @@ from lemur.auth.permissions import ApiKeyCreatorPermission from lemur.common.schema import validate_schema from lemur.common.utils import paginated_parser -from lemur.api_keys.schemas import api_key_input_schema, api_key_revoke_schema, api_key_output_schema, \ - api_keys_output_schema, api_key_described_output_schema, user_api_key_input_schema +from lemur.api_keys.schemas import ( + api_key_input_schema, + api_key_revoke_schema, + api_key_output_schema, + api_keys_output_schema, + api_key_described_output_schema, + user_api_key_input_schema, +) -mod = Blueprint('api_keys', __name__) +mod = Blueprint("api_keys", __name__) api = Api(mod) @@ -81,8 +87,8 @@ class ApiKeyList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['has_permission'] = ApiKeyCreatorPermission().can() - args['requesting_user_id'] = g.current_user.id + args["has_permission"] = ApiKeyCreatorPermission().can() + args["requesting_user_id"] = g.current_user.id return service.render(args) @validate_schema(api_key_input_schema, api_key_output_schema) @@ -124,12 +130,26 @@ class ApiKeyList(AuthenticatedResource): :statuscode 403: unauthenticated """ if not ApiKeyCreatorPermission().can(): - if data['user']['id'] != g.current_user.id: - return dict(message="You are not authorized to create tokens for: {0}".format(data['user']['username'])), 403 + if data["user"]["id"] != g.current_user.id: + return ( + dict( + message="You are not authorized to create tokens for: {0}".format( + data["user"]["username"] + ) + ), + 403, + ) - access_token = service.create(name=data['name'], user_id=data['user']['id'], ttl=data['ttl'], - revoked=False, issued_at=int(datetime.utcnow().timestamp())) - return dict(jwt=create_token(access_token.user_id, access_token.id, access_token.ttl)) + access_token = service.create( + name=data["name"], + user_id=data["user"]["id"], + ttl=data["ttl"], + revoked=False, + issued_at=int(datetime.utcnow().timestamp()), + ) + return dict( + jwt=create_token(access_token.user_id, access_token.id, access_token.ttl) + ) class ApiKeyUserList(AuthenticatedResource): @@ -186,9 +206,9 @@ class ApiKeyUserList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['has_permission'] = ApiKeyCreatorPermission().can() - args['requesting_user_id'] = g.current_user.id - args['user_id'] = user_id + args["has_permission"] = ApiKeyCreatorPermission().can() + args["requesting_user_id"] = g.current_user.id + args["user_id"] = user_id return service.render(args) @validate_schema(user_api_key_input_schema, api_key_output_schema) @@ -230,11 +250,25 @@ class ApiKeyUserList(AuthenticatedResource): """ if not ApiKeyCreatorPermission().can(): if user_id != g.current_user.id: - return dict(message="You are not authorized to create tokens for: {0}".format(user_id)), 403 + return ( + dict( + message="You are not authorized to create tokens for: {0}".format( + user_id + ) + ), + 403, + ) - access_token = service.create(name=data['name'], user_id=user_id, ttl=data['ttl'], - revoked=False, issued_at=int(datetime.utcnow().timestamp())) - return dict(jwt=create_token(access_token.user_id, access_token.id, access_token.ttl)) + access_token = service.create( + name=data["name"], + user_id=user_id, + ttl=data["ttl"], + revoked=False, + issued_at=int(datetime.utcnow().timestamp()), + ) + return dict( + jwt=create_token(access_token.user_id, access_token.id, access_token.ttl) + ) class ApiKeys(AuthenticatedResource): @@ -329,7 +363,9 @@ class ApiKeys(AuthenticatedResource): if not ApiKeyCreatorPermission().can(): return dict(message="You are not authorized to update this token!"), 403 - service.update(access_key, name=data['name'], revoked=data['revoked'], ttl=data['ttl']) + service.update( + access_key, name=data["name"], revoked=data["revoked"], ttl=data["ttl"] + ) return dict(jwt=create_token(access_key.user_id, access_key.id, access_key.ttl)) def delete(self, aid): @@ -371,7 +407,7 @@ class ApiKeys(AuthenticatedResource): return dict(message="You are not authorized to delete this token!"), 403 service.delete(access_key) - return {'result': True} + return {"result": True} class UserApiKeys(AuthenticatedResource): @@ -472,7 +508,9 @@ class UserApiKeys(AuthenticatedResource): if access_key.user_id != uid: return dict(message="You are not authorized to update this token!"), 403 - service.update(access_key, name=data['name'], revoked=data['revoked'], ttl=data['ttl']) + service.update( + access_key, name=data["name"], revoked=data["revoked"], ttl=data["ttl"] + ) return dict(jwt=create_token(access_key.user_id, access_key.id, access_key.ttl)) def delete(self, uid, aid): @@ -517,7 +555,7 @@ class UserApiKeys(AuthenticatedResource): return dict(message="You are not authorized to delete this token!"), 403 service.delete(access_key) - return {'result': True} + return {"result": True} class ApiKeysDescribed(AuthenticatedResource): @@ -572,8 +610,12 @@ class ApiKeysDescribed(AuthenticatedResource): return access_key -api.add_resource(ApiKeyList, '/keys', endpoint='api_keys') -api.add_resource(ApiKeys, '/keys/', endpoint='api_key') -api.add_resource(ApiKeysDescribed, '/keys//described', endpoint='api_key_described') -api.add_resource(ApiKeyUserList, '/users//keys', endpoint='user_api_keys') -api.add_resource(UserApiKeys, '/users//keys/', endpoint='user_api_key') +api.add_resource(ApiKeyList, "/keys", endpoint="api_keys") +api.add_resource(ApiKeys, "/keys/", endpoint="api_key") +api.add_resource( + ApiKeysDescribed, "/keys//described", endpoint="api_key_described" +) +api.add_resource(ApiKeyUserList, "/users//keys", endpoint="user_api_keys") +api.add_resource( + UserApiKeys, "/users//keys/", endpoint="user_api_key" +) diff --git a/lemur/auth/ldap.py b/lemur/auth/ldap.py index 7eded060..ed87b76c 100644 --- a/lemur/auth/ldap.py +++ b/lemur/auth/ldap.py @@ -14,35 +14,41 @@ from lemur.roles import service as role_service from lemur.common.utils import validate_conf, get_psuedo_random_string -class LdapPrincipal(): +class LdapPrincipal: """ Provides methods for authenticating against an LDAP server. """ + def __init__(self, args): self._ldap_validate_conf() # setup ldap config - if not args['username']: + if not args["username"]: raise Exception("missing ldap username") - if not args['password']: + if not args["password"]: self.error_message = "missing ldap password" raise Exception("missing ldap password") - self.ldap_principal = args['username'] + self.ldap_principal = args["username"] self.ldap_email_domain = current_app.config.get("LDAP_EMAIL_DOMAIN", None) - if '@' not in self.ldap_principal: - self.ldap_principal = '%s@%s' % (self.ldap_principal, self.ldap_email_domain) - self.ldap_username = args['username'] - if '@' in self.ldap_username: - self.ldap_username = args['username'].split("@")[0] - self.ldap_password = args['password'] - self.ldap_server = current_app.config.get('LDAP_BIND_URI', None) + if "@" not in self.ldap_principal: + self.ldap_principal = "%s@%s" % ( + self.ldap_principal, + self.ldap_email_domain, + ) + self.ldap_username = args["username"] + if "@" in self.ldap_username: + self.ldap_username = args["username"].split("@")[0] + self.ldap_password = args["password"] + self.ldap_server = current_app.config.get("LDAP_BIND_URI", None) self.ldap_base_dn = current_app.config.get("LDAP_BASE_DN", None) self.ldap_use_tls = current_app.config.get("LDAP_USE_TLS", False) self.ldap_cacert_file = current_app.config.get("LDAP_CACERT_FILE", None) self.ldap_default_role = current_app.config.get("LEMUR_DEFAULT_ROLE", None) self.ldap_required_group = current_app.config.get("LDAP_REQUIRED_GROUP", None) self.ldap_groups_to_roles = current_app.config.get("LDAP_GROUPS_TO_ROLES", None) - self.ldap_is_active_directory = current_app.config.get("LDAP_IS_ACTIVE_DIRECTORY", False) - self.ldap_attrs = ['memberOf'] + self.ldap_is_active_directory = current_app.config.get( + "LDAP_IS_ACTIVE_DIRECTORY", False + ) + self.ldap_attrs = ["memberOf"] self.ldap_client = None self.ldap_groups = None @@ -60,8 +66,8 @@ class LdapPrincipal(): get_psuedo_random_string(), self.ldap_principal, True, - '', # thumbnailPhotoUrl - list(roles) + "", # thumbnailPhotoUrl + list(roles), ) else: # we add 'lemur' specific roles, so they do not get marked as removed @@ -76,7 +82,7 @@ class LdapPrincipal(): self.ldap_principal, user.active, user.profile_picture, - list(roles) + list(roles), ) return user @@ -99,15 +105,18 @@ class LdapPrincipal(): role = role_service.get_by_name(self.ldap_default_role) if role: if not role.third_party: - role = role.set_third_party(role.id, third_party_status=True) + role = role_service.set_third_party(role.id, third_party_status=True) roles.add(role) # update their 'roles' role = role_service.get_by_name(self.ldap_principal) if not role: - description = "auto generated role based on owner: {0}".format(self.ldap_principal) - role = role_service.create(self.ldap_principal, description=description, - third_party=True) + description = "auto generated role based on owner: {0}".format( + self.ldap_principal + ) + role = role_service.create( + self.ldap_principal, description=description, third_party=True + ) if not role.third_party: role = role_service.set_third_party(role.id, third_party_status=True) roles.add(role) @@ -118,9 +127,15 @@ class LdapPrincipal(): role = role_service.get_by_name(role_name) if role: if ldap_group_name in self.ldap_groups: - current_app.logger.debug("assigning role {0} to ldap user {1}".format(self.ldap_principal, role)) + current_app.logger.debug( + "assigning role {0} to ldap user {1}".format( + self.ldap_principal, role + ) + ) if not role.third_party: - role = role_service.set_third_party(role.id, third_party_status=True) + role = role_service.set_third_party( + role.id, third_party_status=True + ) roles.add(role) return roles @@ -132,7 +147,7 @@ class LdapPrincipal(): self._bind() roles = self._authorize() if not roles: - raise Exception('ldap authorization failed') + raise Exception("ldap authorization failed") return self._update_user(roles) def _bind(self): @@ -141,9 +156,12 @@ class LdapPrincipal(): list groups for a user. raise an exception on error. """ - if '@' not in self.ldap_principal: - self.ldap_principal = '%s@%s' % (self.ldap_principal, self.ldap_email_domain) - ldap_filter = 'userPrincipalName=%s' % self.ldap_principal + if "@" not in self.ldap_principal: + self.ldap_principal = "%s@%s" % ( + self.ldap_principal, + self.ldap_email_domain, + ) + ldap_filter = "userPrincipalName=%s" % self.ldap_principal # query ldap for auth try: @@ -159,37 +177,47 @@ class LdapPrincipal(): self.ldap_client.set_option(ldap.OPT_X_TLS_DEMAND, True) self.ldap_client.set_option(ldap.OPT_DEBUG_LEVEL, 255) if self.ldap_cacert_file: - self.ldap_client.set_option(ldap.OPT_X_TLS_CACERTFILE, self.ldap_cacert_file) + self.ldap_client.set_option( + ldap.OPT_X_TLS_CACERTFILE, self.ldap_cacert_file + ) self.ldap_client.simple_bind_s(self.ldap_principal, self.ldap_password) except ldap.INVALID_CREDENTIALS: self.ldap_client.unbind() - raise Exception('The supplied ldap credentials are invalid') + raise Exception("The supplied ldap credentials are invalid") except ldap.SERVER_DOWN: - raise Exception('ldap server unavailable') + raise Exception("ldap server unavailable") except ldap.LDAPError as e: raise Exception("ldap error: {0}".format(e)) if self.ldap_is_active_directory: # Lookup user DN, needed to search for group membership - userdn = self.ldap_client.search_s(self.ldap_base_dn, - ldap.SCOPE_SUBTREE, ldap_filter, - ['distinguishedName'])[0][1]['distinguishedName'][0] - userdn = userdn.decode('utf-8') + userdn = self.ldap_client.search_s( + self.ldap_base_dn, + ldap.SCOPE_SUBTREE, + ldap_filter, + ["distinguishedName"], + )[0][1]["distinguishedName"][0] + userdn = userdn.decode("utf-8") # Search all groups that have the userDN as a member - groupfilter = '(&(objectclass=group)(member:1.2.840.113556.1.4.1941:={0}))'.format(userdn) - lgroups = self.ldap_client.search_s(self.ldap_base_dn, ldap.SCOPE_SUBTREE, groupfilter, ['cn']) + groupfilter = "(&(objectclass=group)(member:1.2.840.113556.1.4.1941:={0}))".format( + userdn + ) + lgroups = self.ldap_client.search_s( + self.ldap_base_dn, ldap.SCOPE_SUBTREE, groupfilter, ["cn"] + ) # Create a list of group CN's from the result self.ldap_groups = [] for group in lgroups: (dn, values) = group - self.ldap_groups.append(values['cn'][0].decode('ascii')) + self.ldap_groups.append(values["cn"][0].decode("ascii")) else: - lgroups = self.ldap_client.search_s(self.ldap_base_dn, - ldap.SCOPE_SUBTREE, ldap_filter, self.ldap_attrs)[0][1]['memberOf'] + lgroups = self.ldap_client.search_s( + self.ldap_base_dn, ldap.SCOPE_SUBTREE, ldap_filter, self.ldap_attrs + )[0][1]["memberOf"] # lgroups is a list of utf-8 encoded strings # convert to a single string of groups to allow matching - self.ldap_groups = b''.join(lgroups).decode('ascii') + self.ldap_groups = b"".join(lgroups).decode("ascii") self.ldap_client.unbind() @@ -197,9 +225,5 @@ class LdapPrincipal(): """ Confirms required ldap config settings exist. """ - required_vars = [ - 'LDAP_BIND_URI', - 'LDAP_BASE_DN', - 'LDAP_EMAIL_DOMAIN', - ] + required_vars = ["LDAP_BIND_URI", "LDAP_BASE_DN", "LDAP_EMAIL_DOMAIN"] validate_conf(current_app, required_vars) diff --git a/lemur/auth/permissions.py b/lemur/auth/permissions.py index 68c48773..a5964880 100644 --- a/lemur/auth/permissions.py +++ b/lemur/auth/permissions.py @@ -9,24 +9,32 @@ from functools import partial from collections import namedtuple +from flask import current_app from flask_principal import Permission, RoleNeed # Permissions -operator_permission = Permission(RoleNeed('operator')) -admin_permission = Permission(RoleNeed('admin')) +operator_permission = Permission(RoleNeed("operator")) +admin_permission = Permission(RoleNeed("admin")) -CertificateOwner = namedtuple('certificate', ['method', 'value']) -CertificateOwnerNeed = partial(CertificateOwner, 'role') +CertificateOwner = namedtuple("certificate", ["method", "value"]) +CertificateOwnerNeed = partial(CertificateOwner, "role") class SensitiveDomainPermission(Permission): def __init__(self): - super(SensitiveDomainPermission, self).__init__(RoleNeed('admin')) + needs = [RoleNeed("admin")] + sensitive_domain_roles = current_app.config.get("SENSITIVE_DOMAIN_ROLES", []) + + if sensitive_domain_roles: + for role in sensitive_domain_roles: + needs.append(RoleNeed(role)) + + super(SensitiveDomainPermission, self).__init__(*needs) class CertificatePermission(Permission): def __init__(self, owner, roles): - needs = [RoleNeed('admin'), RoleNeed(owner), RoleNeed('creator')] + needs = [RoleNeed("admin"), RoleNeed(owner), RoleNeed("creator")] for r in roles: needs.append(CertificateOwnerNeed(str(r))) # Backwards compatibility with mixed-case role names @@ -38,29 +46,29 @@ class CertificatePermission(Permission): class ApiKeyCreatorPermission(Permission): def __init__(self): - super(ApiKeyCreatorPermission, self).__init__(RoleNeed('admin')) + super(ApiKeyCreatorPermission, self).__init__(RoleNeed("admin")) -RoleMember = namedtuple('role', ['method', 'value']) -RoleMemberNeed = partial(RoleMember, 'member') +RoleMember = namedtuple("role", ["method", "value"]) +RoleMemberNeed = partial(RoleMember, "member") class RoleMemberPermission(Permission): def __init__(self, role_id): - needs = [RoleNeed('admin'), RoleMemberNeed(role_id)] + needs = [RoleNeed("admin"), RoleMemberNeed(role_id)] super(RoleMemberPermission, self).__init__(*needs) -AuthorityCreator = namedtuple('authority', ['method', 'value']) -AuthorityCreatorNeed = partial(AuthorityCreator, 'authorityUse') +AuthorityCreator = namedtuple("authority", ["method", "value"]) +AuthorityCreatorNeed = partial(AuthorityCreator, "authorityUse") -AuthorityOwner = namedtuple('authority', ['method', 'value']) -AuthorityOwnerNeed = partial(AuthorityOwner, 'role') +AuthorityOwner = namedtuple("authority", ["method", "value"]) +AuthorityOwnerNeed = partial(AuthorityOwner, "role") class AuthorityPermission(Permission): def __init__(self, authority_id, roles): - needs = [RoleNeed('admin'), AuthorityCreatorNeed(str(authority_id))] + needs = [RoleNeed("admin"), AuthorityCreatorNeed(str(authority_id))] for r in roles: needs.append(AuthorityOwnerNeed(str(r))) diff --git a/lemur/auth/service.py b/lemur/auth/service.py index c862aa2e..0e1521b3 100644 --- a/lemur/auth/service.py +++ b/lemur/auth/service.py @@ -39,13 +39,13 @@ def get_rsa_public_key(n, e): :param e: :return: a RSA Public Key in PEM format """ - n = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(n, 'utf-8'))), 16) - e = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(e, 'utf-8'))), 16) + n = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(n, "utf-8"))), 16) + e = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(e, "utf-8"))), 16) pub = RSAPublicNumbers(e, n).public_key(default_backend()) return pub.public_bytes( encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo + format=serialization.PublicFormat.SubjectPublicKeyInfo, ) @@ -57,28 +57,27 @@ def create_token(user, aid=None, ttl=None): :param user: :return: """ - expiration_delta = timedelta(days=int(current_app.config.get('LEMUR_TOKEN_EXPIRATION', 1))) - payload = { - 'iat': datetime.utcnow(), - 'exp': datetime.utcnow() + expiration_delta - } + expiration_delta = timedelta( + days=int(current_app.config.get("LEMUR_TOKEN_EXPIRATION", 1)) + ) + payload = {"iat": datetime.utcnow(), "exp": datetime.utcnow() + expiration_delta} # Handle Just a User ID & User Object. if isinstance(user, int): - payload['sub'] = user + payload["sub"] = user else: - payload['sub'] = user.id + payload["sub"] = user.id if aid is not None: - payload['aid'] = aid + payload["aid"] = aid # Custom TTLs are only supported on Access Keys. if ttl is not None and aid is not None: # Tokens that are forever until revoked. if ttl == -1: - del payload['exp'] + del payload["exp"] else: - payload['exp'] = ttl - token = jwt.encode(payload, current_app.config['LEMUR_TOKEN_SECRET']) - return token.decode('unicode_escape') + payload["exp"] = ttl + token = jwt.encode(payload, current_app.config["LEMUR_TOKEN_SECRET"]) + return token.decode("unicode_escape") def login_required(f): @@ -88,49 +87,54 @@ def login_required(f): :param f: :return: """ + @wraps(f) def decorated_function(*args, **kwargs): - if not request.headers.get('Authorization'): - response = jsonify(message='Missing authorization header') + if not request.headers.get("Authorization"): + response = jsonify(message="Missing authorization header") response.status_code = 401 return response try: - token = request.headers.get('Authorization').split()[1] + token = request.headers.get("Authorization").split()[1] except Exception as e: - return dict(message='Token is invalid'), 403 + return dict(message="Token is invalid"), 403 try: - payload = jwt.decode(token, current_app.config['LEMUR_TOKEN_SECRET']) + payload = jwt.decode(token, current_app.config["LEMUR_TOKEN_SECRET"]) except jwt.DecodeError: - return dict(message='Token is invalid'), 403 + return dict(message="Token is invalid"), 403 except jwt.ExpiredSignatureError: - return dict(message='Token has expired'), 403 + return dict(message="Token has expired"), 403 except jwt.InvalidTokenError: - return dict(message='Token is invalid'), 403 + return dict(message="Token is invalid"), 403 - if 'aid' in payload: - access_key = api_key_service.get(payload['aid']) + if "aid" in payload: + access_key = api_key_service.get(payload["aid"]) if access_key.revoked: - return dict(message='Token has been revoked'), 403 + return dict(message="Token has been revoked"), 403 if access_key.ttl != -1: current_time = datetime.utcnow() - expired_time = datetime.fromtimestamp(access_key.issued_at + access_key.ttl) + expired_time = datetime.fromtimestamp( + access_key.issued_at + access_key.ttl + ) if current_time >= expired_time: - return dict(message='Token has expired'), 403 + return dict(message="Token has expired"), 403 - user = user_service.get(payload['sub']) + user = user_service.get(payload["sub"]) if not user.active: - return dict(message='User is not currently active'), 403 + return dict(message="User is not currently active"), 403 g.current_user = user if not g.current_user: - return dict(message='You are not logged in'), 403 + return dict(message="You are not logged in"), 403 # Tell Flask-Principal the identity changed - identity_changed.send(current_app._get_current_object(), identity=Identity(g.current_user.id)) + identity_changed.send( + current_app._get_current_object(), identity=Identity(g.current_user.id) + ) return f(*args, **kwargs) @@ -144,18 +148,18 @@ def fetch_token_header(token): :param token: :return: :raise jwt.DecodeError: """ - token = token.encode('utf-8') + token = token.encode("utf-8") try: - signing_input, crypto_segment = token.rsplit(b'.', 1) - header_segment, payload_segment = signing_input.split(b'.', 1) + signing_input, crypto_segment = token.rsplit(b".", 1) + header_segment, payload_segment = signing_input.split(b".", 1) except ValueError: - raise jwt.DecodeError('Not enough segments') + raise jwt.DecodeError("Not enough segments") try: - return json.loads(jwt.utils.base64url_decode(header_segment).decode('utf-8')) + return json.loads(jwt.utils.base64url_decode(header_segment).decode("utf-8")) except TypeError as e: current_app.logger.exception(e) - raise jwt.DecodeError('Invalid header padding') + raise jwt.DecodeError("Invalid header padding") @identity_loaded.connect @@ -174,13 +178,13 @@ def on_identity_loaded(sender, identity): identity.provides.add(UserNeed(identity.id)) # identity with the roles that the user provides - if hasattr(user, 'roles'): + if hasattr(user, "roles"): for role in user.roles: identity.provides.add(RoleNeed(role.name)) identity.provides.add(RoleMemberNeed(role.id)) # apply ownership for authorities - if hasattr(user, 'authorities'): + if hasattr(user, "authorities"): for authority in user.authorities: identity.provides.add(AuthorityCreatorNeed(authority.id)) @@ -191,6 +195,7 @@ class AuthenticatedResource(Resource): """ Inherited by all resources that need to be protected by authentication. """ + method_decorators = [login_required] def __init__(self): diff --git a/lemur/auth/views.py b/lemur/auth/views.py index 7a1bb34c..e7f87356 100644 --- a/lemur/auth/views.py +++ b/lemur/auth/views.py @@ -24,11 +24,13 @@ from lemur.auth.service import create_token, fetch_token_header, get_rsa_public_ from lemur.auth import ldap -mod = Blueprint('auth', __name__) +mod = Blueprint("auth", __name__) api = Api(mod) -def exchange_for_access_token(code, redirect_uri, client_id, secret, access_token_url=None, verify_cert=True): +def exchange_for_access_token( + code, redirect_uri, client_id, secret, access_token_url=None, verify_cert=True +): """ Exchanges authorization code for access token. @@ -43,28 +45,32 @@ def exchange_for_access_token(code, redirect_uri, client_id, secret, access_toke """ # take the information we have received from the provider to create a new request params = { - 'grant_type': 'authorization_code', - 'scope': 'openid email profile address', - 'code': code, - 'redirect_uri': redirect_uri, - 'client_id': client_id + "grant_type": "authorization_code", + "scope": "openid email profile address", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, } # the secret and cliendId will be given to you when you signup for the provider - token = '{0}:{1}'.format(client_id, secret) + token = "{0}:{1}".format(client_id, secret) - basic = base64.b64encode(bytes(token, 'utf-8')) + basic = base64.b64encode(bytes(token, "utf-8")) headers = { - 'Content-Type': 'application/x-www-form-urlencoded', - 'authorization': 'basic {0}'.format(basic.decode('utf-8')) + "Content-Type": "application/x-www-form-urlencoded", + "authorization": "basic {0}".format(basic.decode("utf-8")), } # exchange authorization code for access token. - r = requests.post(access_token_url, headers=headers, params=params, verify=verify_cert) + r = requests.post( + access_token_url, headers=headers, params=params, verify=verify_cert + ) if r.status_code == 400: - r = requests.post(access_token_url, headers=headers, data=params, verify=verify_cert) - id_token = r.json()['id_token'] - access_token = r.json()['access_token'] + r = requests.post( + access_token_url, headers=headers, data=params, verify=verify_cert + ) + id_token = r.json()["id_token"] + access_token = r.json()["access_token"] return id_token, access_token @@ -83,23 +89,25 @@ def validate_id_token(id_token, client_id, jwks_url): # retrieve the key material as specified by the token header r = requests.get(jwks_url) - for key in r.json()['keys']: - if key['kid'] == header_data['kid']: - secret = get_rsa_public_key(key['n'], key['e']) - algo = header_data['alg'] + for key in r.json()["keys"]: + if key["kid"] == header_data["kid"]: + secret = get_rsa_public_key(key["n"], key["e"]) + algo = header_data["alg"] break else: - return dict(message='Key not found'), 401 + return dict(message="Key not found"), 401 # validate your token based on the key it was signed with try: - jwt.decode(id_token, secret.decode('utf-8'), algorithms=[algo], audience=client_id) + jwt.decode( + id_token, secret.decode("utf-8"), algorithms=[algo], audience=client_id + ) except jwt.DecodeError: - return dict(message='Token is invalid'), 401 + return dict(message="Token is invalid"), 401 except jwt.ExpiredSignatureError: - return dict(message='Token has expired'), 401 + return dict(message="Token has expired"), 401 except jwt.InvalidTokenError: - return dict(message='Token is invalid'), 401 + return dict(message="Token is invalid"), 401 def retrieve_user(user_api_url, access_token): @@ -110,13 +118,18 @@ def retrieve_user(user_api_url, access_token): :param access_token: :return: """ - user_params = dict(access_token=access_token, schema='profile') + user_params = dict(access_token=access_token, schema="profile") + + headers = {} + + if current_app.config.get("PING_INCLUDE_BEARER_TOKEN"): + headers = {"Authorization": f"Bearer {access_token}"} # retrieve information about the current user. - r = requests.get(user_api_url, params=user_params) + r = requests.get(user_api_url, params=user_params, headers=headers) profile = r.json() - user = user_service.get_by_email(profile['email']) + user = user_service.get_by_email(profile["email"]) return user, profile @@ -129,28 +142,44 @@ def create_user_roles(profile): roles = [] # update their google 'roles' - for group in profile['googleGroups']: - role = role_service.get_by_name(group) - if not role: - role = role_service.create(group, description='This is a google group based role created by Lemur', third_party=True) - if not role.third_party: - role = role_service.set_third_party(role.id, third_party_status=True) - roles.append(role) + if "googleGroups" in profile: + for group in profile["googleGroups"]: + role = role_service.get_by_name(group) + if not role: + role = role_service.create( + group, + description="This is a google group based role created by Lemur", + third_party=True, + ) + if not role.third_party: + role = role_service.set_third_party(role.id, third_party_status=True) + roles.append(role) + else: + current_app.logger.warning( + "'googleGroups' not sent by identity provider, no specific roles will assigned to the user." + ) - role = role_service.get_by_name(profile['email']) + role = role_service.get_by_name(profile["email"]) if not role: - role = role_service.create(profile['email'], description='This is a user specific role', third_party=True) + role = role_service.create( + profile["email"], + description="This is a user specific role", + third_party=True, + ) if not role.third_party: role = role_service.set_third_party(role.id, third_party_status=True) roles.append(role) # every user is an operator (tied to a default role) - if current_app.config.get('LEMUR_DEFAULT_ROLE'): - default = role_service.get_by_name(current_app.config['LEMUR_DEFAULT_ROLE']) + if current_app.config.get("LEMUR_DEFAULT_ROLE"): + default = role_service.get_by_name(current_app.config["LEMUR_DEFAULT_ROLE"]) if not default: - default = role_service.create(current_app.config['LEMUR_DEFAULT_ROLE'], description='This is the default Lemur role.') + default = role_service.create( + current_app.config["LEMUR_DEFAULT_ROLE"], + description="This is the default Lemur role.", + ) if not default.third_party: role_service.set_third_party(default.id, third_party_status=True) roles.append(default) @@ -169,12 +198,12 @@ def update_user(user, profile, roles): # if we get an sso user create them an account if not user: user = user_service.create( - profile['email'], + profile["email"], get_psuedo_random_string(), - profile['email'], + profile["email"], True, - profile.get('thumbnailPhotoUrl'), - roles + profile.get("thumbnailPhotoUrl"), + roles, ) else: @@ -186,11 +215,11 @@ def update_user(user, profile, roles): # update any changes to the user user_service.update( user.id, - profile['email'], - profile['email'], + profile["email"], + profile["email"], True, - profile.get('thumbnailPhotoUrl'), # profile isn't google+ enabled - roles + profile.get("thumbnailPhotoUrl"), # profile isn't google+ enabled + roles, ) @@ -211,6 +240,7 @@ class Login(Resource): on your uses cases but. It is important to not that there is currently no build in method to revoke a users token \ and force re-authentication. """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(Login, self).__init__() @@ -251,23 +281,26 @@ class Login(Resource): :statuscode 401: invalid credentials :statuscode 200: no error """ - self.reqparse.add_argument('username', type=str, required=True, location='json') - self.reqparse.add_argument('password', type=str, required=True, location='json') + self.reqparse.add_argument("username", type=str, required=True, location="json") + self.reqparse.add_argument("password", type=str, required=True, location="json") args = self.reqparse.parse_args() - if '@' in args['username']: - user = user_service.get_by_email(args['username']) + if "@" in args["username"]: + user = user_service.get_by_email(args["username"]) else: - user = user_service.get_by_username(args['username']) + user = user_service.get_by_username(args["username"]) # default to local authentication - if user and user.check_password(args['password']) and user.active: + if user and user.check_password(args["password"]) and user.active: # Tell Flask-Principal the identity changed - identity_changed.send(current_app._get_current_object(), - identity=Identity(user.id)) + identity_changed.send( + current_app._get_current_object(), identity=Identity(user.id) + ) - metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + metrics.send( + "login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS} + ) return dict(token=create_token(user)) # try ldap login @@ -277,19 +310,29 @@ class Login(Resource): user = ldap_principal.authenticate() if user and user.active: # Tell Flask-Principal the identity changed - identity_changed.send(current_app._get_current_object(), - identity=Identity(user.id)) - metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + identity_changed.send( + current_app._get_current_object(), identity=Identity(user.id) + ) + metrics.send( + "login", + "counter", + 1, + metric_tags={"status": SUCCESS_METRIC_STATUS}, + ) return dict(token=create_token(user)) except Exception as e: - current_app.logger.error("ldap error: {0}".format(e)) - ldap_message = 'ldap error: %s' % e - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - return dict(message=ldap_message), 403 + current_app.logger.error("ldap error: {0}".format(e)) + ldap_message = "ldap error: %s" % e + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) + return dict(message=ldap_message), 403 # if not valid user - no certificates for you - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - return dict(message='The supplied credentials are invalid'), 403 + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) + return dict(message="The supplied credentials are invalid"), 403 class Ping(Resource): @@ -302,49 +345,59 @@ class Ping(Resource): provider uses for its callbacks. 2. Add or change the Lemur AngularJS Configuration to point to your new provider """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(Ping, self).__init__() def get(self): - return 'Redirecting...' + return "Redirecting..." def post(self): - self.reqparse.add_argument('clientId', type=str, required=True, location='json') - self.reqparse.add_argument('redirectUri', type=str, required=True, location='json') - self.reqparse.add_argument('code', type=str, required=True, location='json') + self.reqparse.add_argument("clientId", type=str, required=True, location="json") + self.reqparse.add_argument( + "redirectUri", type=str, required=True, location="json" + ) + self.reqparse.add_argument("code", type=str, required=True, location="json") args = self.reqparse.parse_args() # you can either discover these dynamically or simply configure them - access_token_url = current_app.config.get('PING_ACCESS_TOKEN_URL') - user_api_url = current_app.config.get('PING_USER_API_URL') + access_token_url = current_app.config.get("PING_ACCESS_TOKEN_URL") + user_api_url = current_app.config.get("PING_USER_API_URL") - secret = current_app.config.get('PING_SECRET') + secret = current_app.config.get("PING_SECRET") id_token, access_token = exchange_for_access_token( - args['code'], - args['redirectUri'], - args['clientId'], + args["code"], + args["redirectUri"], + args["clientId"], secret, - access_token_url=access_token_url + access_token_url=access_token_url, ) - jwks_url = current_app.config.get('PING_JWKS_URL') - validate_id_token(id_token, args['clientId'], jwks_url) - + jwks_url = current_app.config.get("PING_JWKS_URL") + error_code = validate_id_token(id_token, args["clientId"], jwks_url) + if error_code: + return error_code user, profile = retrieve_user(user_api_url, access_token) roles = create_user_roles(profile) update_user(user, profile, roles) if not user or not user.active: - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - return dict(message='The supplied credentials are invalid'), 403 + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) + return dict(message="The supplied credentials are invalid"), 403 # Tell Flask-Principal the identity changed - identity_changed.send(current_app._get_current_object(), identity=Identity(user.id)) + identity_changed.send( + current_app._get_current_object(), identity=Identity(user.id) + ) - metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + metrics.send( + "login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS} + ) return dict(token=create_token(user)) @@ -354,46 +407,56 @@ class OAuth2(Resource): super(OAuth2, self).__init__() def get(self): - return 'Redirecting...' + return "Redirecting..." def post(self): - self.reqparse.add_argument('clientId', type=str, required=True, location='json') - self.reqparse.add_argument('redirectUri', type=str, required=True, location='json') - self.reqparse.add_argument('code', type=str, required=True, location='json') + self.reqparse.add_argument("clientId", type=str, required=True, location="json") + self.reqparse.add_argument( + "redirectUri", type=str, required=True, location="json" + ) + self.reqparse.add_argument("code", type=str, required=True, location="json") args = self.reqparse.parse_args() # you can either discover these dynamically or simply configure them - access_token_url = current_app.config.get('OAUTH2_ACCESS_TOKEN_URL') - user_api_url = current_app.config.get('OAUTH2_USER_API_URL') - verify_cert = current_app.config.get('OAUTH2_VERIFY_CERT') + access_token_url = current_app.config.get("OAUTH2_ACCESS_TOKEN_URL") + user_api_url = current_app.config.get("OAUTH2_USER_API_URL") + verify_cert = current_app.config.get("OAUTH2_VERIFY_CERT") - secret = current_app.config.get('OAUTH2_SECRET') + secret = current_app.config.get("OAUTH2_SECRET") id_token, access_token = exchange_for_access_token( - args['code'], - args['redirectUri'], - args['clientId'], + args["code"], + args["redirectUri"], + args["clientId"], secret, access_token_url=access_token_url, - verify_cert=verify_cert + verify_cert=verify_cert, ) - jwks_url = current_app.config.get('PING_JWKS_URL') - validate_id_token(id_token, args['clientId'], jwks_url) + jwks_url = current_app.config.get("PING_JWKS_URL") + error_code = validate_id_token(id_token, args["clientId"], jwks_url) + if error_code: + return error_code user, profile = retrieve_user(user_api_url, access_token) roles = create_user_roles(profile) update_user(user, profile, roles) if not user.active: - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - return dict(message='The supplied credentials are invalid'), 403 + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) + return dict(message="The supplied credentials are invalid"), 403 # Tell Flask-Principal the identity changed - identity_changed.send(current_app._get_current_object(), identity=Identity(user.id)) + identity_changed.send( + current_app._get_current_object(), identity=Identity(user.id) + ) - metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + metrics.send( + "login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS} + ) return dict(token=create_token(user)) @@ -404,44 +467,52 @@ class Google(Resource): super(Google, self).__init__() def post(self): - access_token_url = 'https://accounts.google.com/o/oauth2/token' - people_api_url = 'https://www.googleapis.com/plus/v1/people/me/openIdConnect' + access_token_url = "https://accounts.google.com/o/oauth2/token" + people_api_url = "https://www.googleapis.com/plus/v1/people/me/openIdConnect" - self.reqparse.add_argument('clientId', type=str, required=True, location='json') - self.reqparse.add_argument('redirectUri', type=str, required=True, location='json') - self.reqparse.add_argument('code', type=str, required=True, location='json') + self.reqparse.add_argument("clientId", type=str, required=True, location="json") + self.reqparse.add_argument( + "redirectUri", type=str, required=True, location="json" + ) + self.reqparse.add_argument("code", type=str, required=True, location="json") args = self.reqparse.parse_args() # Step 1. Exchange authorization code for access token payload = { - 'client_id': args['clientId'], - 'grant_type': 'authorization_code', - 'redirect_uri': args['redirectUri'], - 'code': args['code'], - 'client_secret': current_app.config.get('GOOGLE_SECRET') + "client_id": args["clientId"], + "grant_type": "authorization_code", + "redirect_uri": args["redirectUri"], + "code": args["code"], + "client_secret": current_app.config.get("GOOGLE_SECRET"), } r = requests.post(access_token_url, data=payload) token = r.json() # Step 2. Retrieve information about the current user - headers = {'Authorization': 'Bearer {0}'.format(token['access_token'])} + headers = {"Authorization": "Bearer {0}".format(token["access_token"])} r = requests.get(people_api_url, headers=headers) profile = r.json() - user = user_service.get_by_email(profile['email']) + user = user_service.get_by_email(profile["email"]) if not (user and user.active): - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - return dict(message='The supplied credentials are invalid.'), 403 + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) + return dict(message="The supplied credentials are invalid."), 403 if user: - metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + metrics.send( + "login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS} + ) return dict(token=create_token(user)) - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) class Providers(Resource): @@ -452,47 +523,57 @@ class Providers(Resource): provider = provider.lower() if provider == "google": - active_providers.append({ - 'name': 'google', - 'clientId': current_app.config.get("GOOGLE_CLIENT_ID"), - 'url': api.url_for(Google) - }) + active_providers.append( + { + "name": "google", + "clientId": current_app.config.get("GOOGLE_CLIENT_ID"), + "url": api.url_for(Google), + } + ) elif provider == "ping": - active_providers.append({ - 'name': current_app.config.get("PING_NAME"), - 'url': current_app.config.get('PING_REDIRECT_URI'), - 'redirectUri': current_app.config.get("PING_REDIRECT_URI"), - 'clientId': current_app.config.get("PING_CLIENT_ID"), - 'responseType': 'code', - 'scope': ['openid', 'email', 'profile', 'address'], - 'scopeDelimiter': ' ', - 'authorizationEndpoint': current_app.config.get("PING_AUTH_ENDPOINT"), - 'requiredUrlParams': ['scope'], - 'type': '2.0' - }) + active_providers.append( + { + "name": current_app.config.get("PING_NAME"), + "url": current_app.config.get("PING_REDIRECT_URI"), + "redirectUri": current_app.config.get("PING_REDIRECT_URI"), + "clientId": current_app.config.get("PING_CLIENT_ID"), + "responseType": "code", + "scope": ["openid", "email", "profile", "address"], + "scopeDelimiter": " ", + "authorizationEndpoint": current_app.config.get( + "PING_AUTH_ENDPOINT" + ), + "requiredUrlParams": ["scope"], + "type": "2.0", + } + ) elif provider == "oauth2": - active_providers.append({ - 'name': current_app.config.get("OAUTH2_NAME"), - 'url': current_app.config.get('OAUTH2_REDIRECT_URI'), - 'redirectUri': current_app.config.get("OAUTH2_REDIRECT_URI"), - 'clientId': current_app.config.get("OAUTH2_CLIENT_ID"), - 'responseType': 'code', - 'scope': ['openid', 'email', 'profile', 'groups'], - 'scopeDelimiter': ' ', - 'authorizationEndpoint': current_app.config.get("OAUTH2_AUTH_ENDPOINT"), - 'requiredUrlParams': ['scope', 'state', 'nonce'], - 'state': 'STATE', - 'nonce': get_psuedo_random_string(), - 'type': '2.0' - }) + active_providers.append( + { + "name": current_app.config.get("OAUTH2_NAME"), + "url": current_app.config.get("OAUTH2_REDIRECT_URI"), + "redirectUri": current_app.config.get("OAUTH2_REDIRECT_URI"), + "clientId": current_app.config.get("OAUTH2_CLIENT_ID"), + "responseType": "code", + "scope": ["openid", "email", "profile", "groups"], + "scopeDelimiter": " ", + "authorizationEndpoint": current_app.config.get( + "OAUTH2_AUTH_ENDPOINT" + ), + "requiredUrlParams": ["scope", "state", "nonce"], + "state": "STATE", + "nonce": get_psuedo_random_string(), + "type": "2.0", + } + ) return active_providers -api.add_resource(Login, '/auth/login', endpoint='login') -api.add_resource(Ping, '/auth/ping', endpoint='ping') -api.add_resource(Google, '/auth/google', endpoint='google') -api.add_resource(OAuth2, '/auth/oauth2', endpoint='oauth2') -api.add_resource(Providers, '/auth/providers', endpoint='providers') +api.add_resource(Login, "/auth/login", endpoint="login") +api.add_resource(Ping, "/auth/ping", endpoint="ping") +api.add_resource(Google, "/auth/google", endpoint="google") +api.add_resource(OAuth2, "/auth/oauth2", endpoint="oauth2") +api.add_resource(Providers, "/auth/providers", endpoint="providers") diff --git a/lemur/authorities/models.py b/lemur/authorities/models.py index 6c5f790b..ccd1fab8 100644 --- a/lemur/authorities/models.py +++ b/lemur/authorities/models.py @@ -7,7 +7,17 @@ .. moduleauthor:: Kevin Glisson """ from sqlalchemy.orm import relationship -from sqlalchemy import Column, Integer, String, Text, func, ForeignKey, DateTime, PassiveDefault, Boolean +from sqlalchemy import ( + Column, + Integer, + String, + Text, + func, + ForeignKey, + DateTime, + PassiveDefault, + Boolean, +) from sqlalchemy.dialects.postgresql import JSON from lemur.database import db @@ -16,7 +26,7 @@ from lemur.models import roles_authorities class Authority(db.Model): - __tablename__ = 'authorities' + __tablename__ = "authorities" id = Column(Integer, primary_key=True) owner = Column(String(128), nullable=False) name = Column(String(128), unique=True) @@ -27,22 +37,44 @@ class Authority(db.Model): description = Column(Text) options = Column(JSON) date_created = Column(DateTime, PassiveDefault(func.now()), nullable=False) - roles = relationship('Role', secondary=roles_authorities, passive_deletes=True, backref=db.backref('authority'), lazy='dynamic') - user_id = Column(Integer, ForeignKey('users.id')) - authority_certificate = relationship("Certificate", backref='root_authority', uselist=False, foreign_keys='Certificate.root_authority_id') - certificates = relationship("Certificate", backref='authority', foreign_keys='Certificate.authority_id') + roles = relationship( + "Role", + secondary=roles_authorities, + passive_deletes=True, + backref=db.backref("authority"), + lazy="dynamic", + ) + user_id = Column(Integer, ForeignKey("users.id")) + authority_certificate = relationship( + "Certificate", + backref="root_authority", + uselist=False, + foreign_keys="Certificate.root_authority_id", + ) + certificates = relationship( + "Certificate", backref="authority", foreign_keys="Certificate.authority_id" + ) - authority_pending_certificate = relationship("PendingCertificate", backref='root_authority', uselist=False, foreign_keys='PendingCertificate.root_authority_id') - pending_certificates = relationship('PendingCertificate', backref='authority', foreign_keys='PendingCertificate.authority_id') + authority_pending_certificate = relationship( + "PendingCertificate", + backref="root_authority", + uselist=False, + foreign_keys="PendingCertificate.root_authority_id", + ) + pending_certificates = relationship( + "PendingCertificate", + backref="authority", + foreign_keys="PendingCertificate.authority_id", + ) def __init__(self, **kwargs): - self.owner = kwargs['owner'] - self.roles = kwargs.get('roles', []) - self.name = kwargs.get('name') - self.description = kwargs.get('description') - self.authority_certificate = kwargs['authority_certificate'] - self.plugin_name = kwargs['plugin']['slug'] - self.options = kwargs.get('options') + self.owner = kwargs["owner"] + self.roles = kwargs.get("roles", []) + self.name = kwargs.get("name") + self.description = kwargs.get("description") + self.authority_certificate = kwargs["authority_certificate"] + self.plugin_name = kwargs["plugin"]["slug"] + self.options = kwargs.get("options") @property def plugin(self): diff --git a/lemur/authorities/schemas.py b/lemur/authorities/schemas.py index d1f0adfc..c78aec94 100644 --- a/lemur/authorities/schemas.py +++ b/lemur/authorities/schemas.py @@ -11,7 +11,13 @@ from marshmallow import fields, validates_schema, pre_load from marshmallow import validate from marshmallow.exceptions import ValidationError -from lemur.schemas import PluginInputSchema, PluginOutputSchema, ExtensionSchema, AssociatedAuthoritySchema, AssociatedRoleSchema +from lemur.schemas import ( + PluginInputSchema, + PluginOutputSchema, + ExtensionSchema, + AssociatedAuthoritySchema, + AssociatedRoleSchema, +) from lemur.users.schemas import UserNestedOutputSchema from lemur.common.schema import LemurInputSchema, LemurOutputSchema from lemur.common import validators, missing @@ -30,21 +36,36 @@ class AuthorityInputSchema(LemurInputSchema): validity_years = fields.Integer() # certificate body fields - organizational_unit = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT')) - organization = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATION')) - location = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_LOCATION')) - country = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_COUNTRY')) - state = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_STATE')) + organizational_unit = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATIONAL_UNIT") + ) + organization = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATION") + ) + location = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_LOCATION") + ) + country = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_COUNTRY") + ) + state = fields.String(missing=lambda: current_app.config.get("LEMUR_DEFAULT_STATE")) plugin = fields.Nested(PluginInputSchema) # signing related options - type = fields.String(validate=validate.OneOf(['root', 'subca']), missing='root') + type = fields.String(validate=validate.OneOf(["root", "subca"]), missing="root") parent = fields.Nested(AssociatedAuthoritySchema) - signing_algorithm = fields.String(validate=validate.OneOf(['sha256WithRSA', 'sha1WithRSA']), missing='sha256WithRSA') - key_type = fields.String(validate=validate.OneOf(['RSA2048', 'RSA4096']), missing='RSA2048') + signing_algorithm = fields.String( + validate=validate.OneOf(["sha256WithRSA", "sha1WithRSA"]), + missing="sha256WithRSA", + ) + key_type = fields.String( + validate=validate.OneOf(["RSA2048", "RSA4096"]), missing="RSA2048" + ) key_name = fields.String() - sensitivity = fields.String(validate=validate.OneOf(['medium', 'high']), missing='medium') + sensitivity = fields.String( + validate=validate.OneOf(["medium", "high"]), missing="medium" + ) serial_number = fields.Integer() first_serial = fields.Integer(missing=1) @@ -58,9 +79,11 @@ class AuthorityInputSchema(LemurInputSchema): @validates_schema def validate_subca(self, data): - if data['type'] == 'subca': - if not data.get('parent'): - raise ValidationError("If generating a subca, parent 'authority' must be specified.") + if data["type"] == "subca": + if not data.get("parent"): + raise ValidationError( + "If generating a subca, parent 'authority' must be specified." + ) @pre_load def ensure_dates(self, data): diff --git a/lemur/authorities/service.py b/lemur/authorities/service.py index 024cb42a..c70c6fc5 100644 --- a/lemur/authorities/service.py +++ b/lemur/authorities/service.py @@ -15,6 +15,7 @@ from lemur import database from lemur.common.utils import truthiness from lemur.extensions import metrics from lemur.authorities.models import Authority +from lemur.certificates.models import Certificate from lemur.roles import service as role_service from lemur.certificates.service import upload @@ -42,7 +43,7 @@ def mint(**kwargs): """ Creates the authority based on the plugin provided. """ - issuer = kwargs['plugin']['plugin_object'] + issuer = kwargs["plugin"]["plugin_object"] values = issuer.create_authority(kwargs) # support older plugins @@ -52,7 +53,12 @@ def mint(**kwargs): elif len(values) == 4: body, private_key, chain, roles = values - roles = create_authority_roles(roles, kwargs['owner'], kwargs['plugin']['plugin_object'].title, kwargs['creator']) + roles = create_authority_roles( + roles, + kwargs["owner"], + kwargs["plugin"]["plugin_object"].title, + kwargs["creator"], + ) return body, private_key, chain, roles @@ -65,16 +71,17 @@ def create_authority_roles(roles, owner, plugin_title, creator): """ role_objs = [] for r in roles: - role = role_service.get_by_name(r['name']) + role = role_service.get_by_name(r["name"]) if not role: role = role_service.create( - r['name'], - password=r['password'], + r["name"], + password=r["password"], description="Auto generated role for {0}".format(plugin_title), - username=r['username']) + username=r["username"], + ) # the user creating the authority should be able to administer it - if role.username == 'admin': + if role.username == "admin": creator.roles.append(role) role_objs.append(role) @@ -83,8 +90,7 @@ def create_authority_roles(roles, owner, plugin_title, creator): owner_role = role_service.get_by_name(owner) if not owner_role: owner_role = role_service.create( - owner, - description="Auto generated role based on owner: {0}".format(owner) + owner, description="Auto generated role based on owner: {0}".format(owner) ) role_objs.append(owner_role) @@ -97,27 +103,29 @@ def create(**kwargs): """ body, private_key, chain, roles = mint(**kwargs) - kwargs['creator'].roles = list(set(list(kwargs['creator'].roles) + roles)) + kwargs["creator"].roles = list(set(list(kwargs["creator"].roles) + roles)) - kwargs['body'] = body - kwargs['private_key'] = private_key - kwargs['chain'] = chain + kwargs["body"] = body + kwargs["private_key"] = private_key + kwargs["chain"] = chain - if kwargs.get('roles'): - kwargs['roles'] += roles + if kwargs.get("roles"): + kwargs["roles"] += roles else: - kwargs['roles'] = roles + kwargs["roles"] = roles cert = upload(**kwargs) - kwargs['authority_certificate'] = cert - if kwargs.get('plugin', {}).get('plugin_options', []): - kwargs['options'] = json.dumps(kwargs['plugin']['plugin_options']) + kwargs["authority_certificate"] = cert + if kwargs.get("plugin", {}).get("plugin_options", []): + kwargs["options"] = json.dumps(kwargs["plugin"]["plugin_options"]) authority = Authority(**kwargs) authority = database.create(authority) - kwargs['creator'].authorities.append(authority) + kwargs["creator"].authorities.append(authority) - metrics.send('authority_created', 'counter', 1, metric_tags=dict(owner=authority.owner)) + metrics.send( + "authority_created", "counter", 1, metric_tags=dict(owner=authority.owner) + ) return authority @@ -149,7 +157,7 @@ def get_by_name(authority_name): :param authority_name: :return: """ - return database.get(Authority, authority_name, field='name') + return database.get(Authority, authority_name, field="name") def get_authority_role(ca_name, creator=None): @@ -172,24 +180,31 @@ def render(args): :return: """ query = database.session_query(Authority) - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') - if 'active' in filt: + terms = filt.split(";") + if "active" in filt: query = query.filter(Authority.active == truthiness(terms[1])) - elif 'cn' in filt: - query = query.join(Authority.active == truthiness(terms[1])) + elif "cn" in filt: + term = "%{0}%".format(terms[1]) + sub_query = ( + database.session_query(Certificate.root_authority_id) + .filter(Certificate.cn.ilike(term)) + .subquery() + ) + + query = query.filter(Authority.id.in_(sub_query)) else: query = database.filter(query, Authority, terms) # we make sure that a user can only use an authority they either own are a member of - admins can see all - if not args['user'].is_admin: + if not args["user"].is_admin: authority_ids = [] - for authority in args['user'].authorities: + for authority in args["user"].authorities: authority_ids.append(authority.id) - for role in args['user'].roles: + for role in args["user"].roles: for authority in role.authorities: authority_ids.append(authority.id) query = query.filter(Authority.id.in_(authority_ids)) diff --git a/lemur/authorities/views.py b/lemur/authorities/views.py index b85c9b70..49bce63e 100644 --- a/lemur/authorities/views.py +++ b/lemur/authorities/views.py @@ -16,15 +16,21 @@ from lemur.auth.permissions import AuthorityPermission from lemur.certificates import service as certificate_service from lemur.authorities import service -from lemur.authorities.schemas import authority_input_schema, authority_output_schema, authorities_output_schema, authority_update_schema +from lemur.authorities.schemas import ( + authority_input_schema, + authority_output_schema, + authorities_output_schema, + authority_update_schema, +) -mod = Blueprint('authorities', __name__) +mod = Blueprint("authorities", __name__) api = Api(mod) class AuthoritiesList(AuthenticatedResource): """ Defines the 'authorities' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(AuthoritiesList, self).__init__() @@ -107,7 +113,7 @@ class AuthoritiesList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['user'] = g.current_user + args["user"] = g.current_user return service.render(args) @validate_schema(authority_input_schema, authority_output_schema) @@ -220,7 +226,7 @@ class AuthoritiesList(AuthenticatedResource): :statuscode 403: unauthenticated :statuscode 200: no error """ - data['creator'] = g.current_user + data["creator"] = g.current_user return service.create(**data) @@ -388,7 +394,7 @@ class Authorities(AuthenticatedResource): authority = service.get(authority_id) if not authority: - return dict(message='Not Found'), 404 + return dict(message="Not Found"), 404 # all the authority role members should be allowed roles = [x.name for x in authority.roles] @@ -397,10 +403,10 @@ class Authorities(AuthenticatedResource): if permission.can(): return service.update( authority_id, - owner=data['owner'], - description=data['description'], - active=data['active'], - roles=data['roles'] + owner=data["owner"], + description=data["description"], + active=data["active"], + roles=data["roles"], ) return dict(message="You are not authorized to update this authority."), 403 @@ -505,10 +511,21 @@ class AuthorityVisualizations(AuthenticatedResource): ]} """ authority = service.get(authority_id) - return dict(name=authority.name, children=[{"name": c.name} for c in authority.certificates]) + return dict( + name=authority.name, + children=[{"name": c.name} for c in authority.certificates], + ) -api.add_resource(AuthoritiesList, '/authorities', endpoint='authorities') -api.add_resource(Authorities, '/authorities/', endpoint='authority') -api.add_resource(AuthorityVisualizations, '/authorities//visualize', endpoint='authority_visualizations') -api.add_resource(CertificateAuthority, '/certificates//authority', endpoint='certificateAuthority') +api.add_resource(AuthoritiesList, "/authorities", endpoint="authorities") +api.add_resource(Authorities, "/authorities/", endpoint="authority") +api.add_resource( + AuthorityVisualizations, + "/authorities//visualize", + endpoint="authority_visualizations", +) +api.add_resource( + CertificateAuthority, + "/certificates//authority", + endpoint="certificateAuthority", +) diff --git a/lemur/authorizations/models.py b/lemur/authorizations/models.py index d30de7ed..04ac0508 100644 --- a/lemur/authorizations/models.py +++ b/lemur/authorizations/models.py @@ -13,7 +13,7 @@ from lemur.plugins.base import plugins class Authorization(db.Model): - __tablename__ = 'pending_dns_authorizations' + __tablename__ = "pending_dns_authorizations" id = Column(Integer, primary_key=True, autoincrement=True) account_number = Column(String(128)) domains = Column(JSONType) diff --git a/lemur/certificates/cli.py b/lemur/certificates/cli.py index c4a95187..b57ff175 100644 --- a/lemur/certificates/cli.py +++ b/lemur/certificates/cli.py @@ -34,7 +34,7 @@ from lemur.certificates.service import ( get_all_pending_reissue, get_by_name, get_all_certs, - get + get, ) from lemur.certificates.verify import verify_string @@ -56,11 +56,14 @@ def print_certificate_details(details): "\t[+] Authority: {authority_name}\n" "\t[+] Validity Start: {validity_start}\n" "\t[+] Validity End: {validity_end}\n".format( - common_name=details['commonName'], - sans=",".join(x['value'] for x in details['extensions']['subAltNames']['names']) or None, - authority_name=details['authority']['name'], - validity_start=details['validityStart'], - validity_end=details['validityEnd'] + common_name=details["commonName"], + sans=",".join( + x["value"] for x in details["extensions"]["subAltNames"]["names"] + ) + or None, + authority_name=details["authority"]["name"], + validity_start=details["validityStart"], + validity_end=details["validityEnd"], ) ) @@ -120,13 +123,11 @@ def request_rotation(endpoint, certificate, message, commit): except Exception as e: print( "[!] Failed to rotate endpoint {0} to certificate {1} reason: {2}".format( - endpoint.name, - certificate.name, - e + endpoint.name, certificate.name, e ) ) - metrics.send('endpoint_rotation', 'counter', 1, metric_tags={'status': status}) + metrics.send("endpoint_rotation", "counter", 1, metric_tags={"status": status}) def request_reissue(certificate, commit): @@ -153,22 +154,53 @@ def request_reissue(certificate, commit): status = SUCCESS_METRIC_STATUS except Exception as e: - sentry.captureException() - current_app.logger.exception("Error reissuing certificate.", exc_info=True) - print( - "[!] Failed to reissue certificates. Reason: {}".format( - e - ) + sentry.captureException(extra={"certificate_name": str(certificate.name)}) + current_app.logger.exception( + f"Error reissuing certificate: {certificate.name}", exc_info=True ) + print(f"[!] Failed to reissue certificate: {certificate.name}. Reason: {e}") - metrics.send('certificate_reissue', 'counter', 1, metric_tags={'status': status}) + metrics.send( + "certificate_reissue", + "counter", + 1, + metric_tags={"status": status, "certificate": certificate.name}, + ) -@manager.option('-e', '--endpoint', dest='endpoint_name', help='Name of the endpoint you wish to rotate.') -@manager.option('-n', '--new-certificate', dest='new_certificate_name', help='Name of the certificate you wish to rotate to.') -@manager.option('-o', '--old-certificate', dest='old_certificate_name', help='Name of the certificate you wish to rotate.') -@manager.option('-a', '--notify', dest='message', action='store_true', help='Send a rotation notification to the certificates owner.') -@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') +@manager.option( + "-e", + "--endpoint", + dest="endpoint_name", + help="Name of the endpoint you wish to rotate.", +) +@manager.option( + "-n", + "--new-certificate", + dest="new_certificate_name", + help="Name of the certificate you wish to rotate to.", +) +@manager.option( + "-o", + "--old-certificate", + dest="old_certificate_name", + help="Name of the certificate you wish to rotate.", +) +@manager.option( + "-a", + "--notify", + dest="message", + action="store_true", + help="Send a rotation notification to the certificates owner.", +) +@manager.option( + "-c", + "--commit", + dest="commit", + action="store_true", + default=False, + help="Persist changes.", +) def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, commit): """ Rotates an endpoint and reissues it if it has not already been replaced. If it has @@ -187,39 +219,90 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c endpoint = validate_endpoint(endpoint_name) if endpoint and new_cert: - print("[+] Rotating endpoint: {0} to certificate {1}".format(endpoint.name, new_cert.name)) + print( + f"[+] Rotating endpoint: {endpoint.name} to certificate {new_cert.name}" + ) request_rotation(endpoint, new_cert, message, commit) elif old_cert and new_cert: - print("[+] Rotating all endpoints from {0} to {1}".format(old_cert.name, new_cert.name)) + print(f"[+] Rotating all endpoints from {old_cert.name} to {new_cert.name}") for endpoint in old_cert.endpoints: - print("[+] Rotating {0}".format(endpoint.name)) + print(f"[+] Rotating {endpoint.name}") request_rotation(endpoint, new_cert, message, commit) else: print("[+] Rotating all endpoints that have new certificates available") for endpoint in endpoint_service.get_all_pending_rotation(): if len(endpoint.certificate.replaced) == 1: - print("[+] Rotating {0} to {1}".format(endpoint.name, endpoint.certificate.replaced[0].name)) - request_rotation(endpoint, endpoint.certificate.replaced[0], message, commit) + print( + f"[+] Rotating {endpoint.name} to {endpoint.certificate.replaced[0].name}" + ) + request_rotation( + endpoint, endpoint.certificate.replaced[0], message, commit + ) else: - metrics.send('endpoint_rotation', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - print("[!] Failed to rotate endpoint {0} reason: Multiple replacement certificates found.".format( - endpoint.name - )) + metrics.send( + "endpoint_rotation", + "counter", + 1, + metric_tags={ + "status": FAILURE_METRIC_STATUS, + "old_certificate_name": str(old_cert), + "new_certificate_name": str( + endpoint.certificate.replaced[0].name + ), + "endpoint_name": str(endpoint.name), + "message": str(message), + }, + ) + print( + f"[!] Failed to rotate endpoint {endpoint.name} reason: " + "Multiple replacement certificates found." + ) status = SUCCESS_METRIC_STATUS print("[+] Done!") except Exception as e: - sentry.captureException() + sentry.captureException( + extra={ + "old_certificate_name": str(old_certificate_name), + "new_certificate_name": str(new_certificate_name), + "endpoint_name": str(endpoint_name), + "message": str(message), + } + ) - metrics.send('endpoint_rotation_job', 'counter', 1, metric_tags={'status': status}) + metrics.send( + "endpoint_rotation_job", + "counter", + 1, + metric_tags={ + "status": status, + "old_certificate_name": str(old_certificate_name), + "new_certificate_name": str(new_certificate_name), + "endpoint_name": str(endpoint_name), + "message": str(message), + "endpoint": str(globals().get("endpoint")), + }, + ) -@manager.option('-o', '--old-certificate', dest='old_certificate_name', help='Name of the certificate you wish to reissue.') -@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') +@manager.option( + "-o", + "--old-certificate", + dest="old_certificate_name", + help="Name of the certificate you wish to reissue.", +) +@manager.option( + "-c", + "--commit", + dest="commit", + action="store_true", + default=False, + help="Persist changes.", +) def reissue(old_certificate_name, commit): """ Reissues certificate with the same parameters as it was originally issued with. @@ -247,76 +330,94 @@ def reissue(old_certificate_name, commit): except Exception as e: sentry.captureException() current_app.logger.exception("Error reissuing certificate.", exc_info=True) - print( - "[!] Failed to reissue certificates. Reason: {}".format( - e - ) - ) + print("[!] Failed to reissue certificates. Reason: {}".format(e)) - metrics.send('certificate_reissue_job', 'counter', 1, metric_tags={'status': status}) + metrics.send( + "certificate_reissue_job", "counter", 1, metric_tags={"status": status} + ) -@manager.option('-f', '--fqdns', dest='fqdns', help='FQDNs to query. Multiple fqdns specified via comma.') -@manager.option('-i', '--issuer', dest='issuer', help='Issuer to query for.') -@manager.option('-o', '--owner', dest='owner', help='Owner to query for.') -@manager.option('-e', '--expired', dest='expired', type=bool, default=False, help='Include expired certificates.') +@manager.option( + "-f", + "--fqdns", + dest="fqdns", + help="FQDNs to query. Multiple fqdns specified via comma.", +) +@manager.option("-i", "--issuer", dest="issuer", help="Issuer to query for.") +@manager.option("-o", "--owner", dest="owner", help="Owner to query for.") +@manager.option( + "-e", + "--expired", + dest="expired", + type=bool, + default=False, + help="Include expired certificates.", +) def query(fqdns, issuer, owner, expired): """Prints certificates that match the query params.""" table = [] q = database.session_query(Certificate) if issuer: - sub_query = database.session_query(Authority.id) \ - .filter(Authority.name.ilike('%{0}%'.format(issuer))) \ + sub_query = ( + database.session_query(Authority.id) + .filter(Authority.name.ilike("%{0}%".format(issuer))) .subquery() + ) q = q.filter( or_( - Certificate.issuer.ilike('%{0}%'.format(issuer)), - Certificate.authority_id.in_(sub_query) + Certificate.issuer.ilike("%{0}%".format(issuer)), + Certificate.authority_id.in_(sub_query), ) ) if owner: - q = q.filter(Certificate.owner.ilike('%{0}%'.format(owner))) + q = q.filter(Certificate.owner.ilike("%{0}%".format(owner))) if not expired: q = q.filter(Certificate.expired == False) # noqa if fqdns: - for f in fqdns.split(','): + for f in fqdns.split(","): q = q.filter( or_( - Certificate.cn.ilike('%{0}%'.format(f)), - Certificate.domains.any(Domain.name.ilike('%{0}%'.format(f))) + Certificate.cn.ilike("%{0}%".format(f)), + Certificate.domains.any(Domain.name.ilike("%{0}%".format(f))), ) ) for c in q.all(): table.append([c.id, c.name, c.owner, c.issuer]) - print(tabulate(table, headers=['Id', 'Name', 'Owner', 'Issuer'], tablefmt='csv')) + print(tabulate(table, headers=["Id", "Name", "Owner", "Issuer"], tablefmt="csv")) def worker(data, commit, reason): - parts = [x for x in data.split(' ') if x] + parts = [x for x in data.split(" ") if x] try: cert = get(int(parts[0].strip())) plugin = plugins.get(cert.authority.plugin_name) - print('[+] Revoking certificate. Id: {0} Name: {1}'.format(cert.id, cert.name)) + print("[+] Revoking certificate. Id: {0} Name: {1}".format(cert.id, cert.name)) if commit: plugin.revoke_certificate(cert, reason) - metrics.send('certificate_revoke', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + metrics.send( + "certificate_revoke", + "counter", + 1, + metric_tags={"status": SUCCESS_METRIC_STATUS}, + ) except Exception as e: sentry.captureException() - metrics.send('certificate_revoke', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - print( - "[!] Failed to revoke certificates. Reason: {}".format( - e - ) + metrics.send( + "certificate_revoke", + "counter", + 1, + metric_tags={"status": FAILURE_METRIC_STATUS}, ) + print("[!] Failed to revoke certificates. Reason: {}".format(e)) @manager.command @@ -325,13 +426,22 @@ def clear_pending(): Function clears all pending certificates. :return: """ - v = plugins.get('verisign-issuer') + v = plugins.get("verisign-issuer") v.clear_pending_certificates() -@manager.option('-p', '--path', dest='path', help='Absolute file path to a Lemur query csv.') -@manager.option('-r', '--reason', dest='reason', help='Reason to revoke certificate.') -@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') +@manager.option( + "-p", "--path", dest="path", help="Absolute file path to a Lemur query csv." +) +@manager.option("-r", "--reason", dest="reason", help="Reason to revoke certificate.") +@manager.option( + "-c", + "--commit", + dest="commit", + action="store_true", + default=False, + help="Persist changes.", +) def revoke(path, reason, commit): """ Revokes given certificate. @@ -341,7 +451,7 @@ def revoke(path, reason, commit): print("[+] Starting certificate revocation.") - with open(path, 'r') as f: + with open(path, "r") as f: args = [[x, commit, reason] for x in f.readlines()[2:]] with multiprocessing.Pool(processes=3) as pool: @@ -364,11 +474,11 @@ def check_revoked(): else: status = verify_string(cert.body, "") - cert.status = 'valid' if status else 'revoked' + cert.status = "valid" if status else "revoked" except Exception as e: sentry.captureException() current_app.logger.exception(e) - cert.status = 'unknown' + cert.status = "unknown" database.update(cert) diff --git a/lemur/certificates/hooks.py b/lemur/certificates/hooks.py index 16f6c3b0..93409bb4 100644 --- a/lemur/certificates/hooks.py +++ b/lemur/certificates/hooks.py @@ -12,21 +12,30 @@ import subprocess from flask import current_app -from lemur.certificates.service import csr_created, csr_imported, certificate_issued, certificate_imported +from lemur.certificates.service import ( + csr_created, + csr_imported, + certificate_issued, + certificate_imported, +) def csr_dump_handler(sender, csr, **kwargs): try: - subprocess.run(['openssl', 'req', '-text', '-noout', '-reqopt', 'no_sigdump,no_pubkey'], - input=csr.encode('utf8')) + subprocess.run( + ["openssl", "req", "-text", "-noout", "-reqopt", "no_sigdump,no_pubkey"], + input=csr.encode("utf8"), + ) except Exception as err: current_app.logger.warning("Error inspecting CSR: %s", err) def cert_dump_handler(sender, certificate, **kwargs): try: - subprocess.run(['openssl', 'x509', '-text', '-noout', '-certopt', 'no_sigdump,no_pubkey'], - input=certificate.body.encode('utf8')) + subprocess.run( + ["openssl", "x509", "-text", "-noout", "-certopt", "no_sigdump,no_pubkey"], + input=certificate.body.encode("utf8"), + ) except Exception as err: current_app.logger.warning("Error inspecting certificate: %s", err) diff --git a/lemur/certificates/models.py b/lemur/certificates/models.py index e2ac2cba..0a76cd6b 100644 --- a/lemur/certificates/models.py +++ b/lemur/certificates/models.py @@ -12,32 +12,49 @@ from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import rsa from flask import current_app from idna.core import InvalidCodepoint -from sqlalchemy import event, Integer, ForeignKey, String, PassiveDefault, func, Column, Text, Boolean, Index +from sqlalchemy import ( + event, + Integer, + ForeignKey, + String, + PassiveDefault, + func, + Column, + Text, + Boolean, + Index, +) from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import relationship from sqlalchemy.sql.expression import case, extract from sqlalchemy_utils.types.arrow import ArrowType from werkzeug.utils import cached_property -from lemur.common import defaults, utils +from lemur.common import defaults, utils, validators from lemur.constants import SUCCESS_METRIC_STATUS, FAILURE_METRIC_STATUS from lemur.database import db from lemur.domains.models import Domain from lemur.extensions import metrics from lemur.extensions import sentry -from lemur.models import certificate_associations, certificate_source_associations, \ - certificate_destination_associations, certificate_notification_associations, \ - certificate_replacement_associations, roles_certificates, pending_cert_replacement_associations +from lemur.models import ( + certificate_associations, + certificate_source_associations, + certificate_destination_associations, + certificate_notification_associations, + certificate_replacement_associations, + roles_certificates, + pending_cert_replacement_associations, +) from lemur.plugins.base import plugins from lemur.policies.models import RotationPolicy from lemur.utils import Vault def get_sequence(name): - if '-' not in name: + if "-" not in name: return name, None - parts = name.split('-') + parts = name.split("-") # see if we have an int at the end of our name try: @@ -49,22 +66,26 @@ def get_sequence(name): if len(parts[-1]) == 8: return name, None - root = '-'.join(parts[:-1]) + root = "-".join(parts[:-1]) return root, seq def get_or_increase_name(name, serial): - certificates = Certificate.query.filter(Certificate.name.ilike('{0}%'.format(name))).all() + certificates = Certificate.query.filter(Certificate.name == name).all() if not certificates: return name - serial_name = '{0}-{1}'.format(name, hex(int(serial))[2:].upper()) - certificates = Certificate.query.filter(Certificate.name.ilike('{0}%'.format(serial_name))).all() + serial_name = "{0}-{1}".format(name, hex(int(serial))[2:].upper()) + certificates = Certificate.query.filter(Certificate.name == serial_name).all() if not certificates: return serial_name + certificates = Certificate.query.filter( + Certificate.name.ilike("{0}%".format(serial_name)) + ).all() + ends = [0] root, end = get_sequence(serial_name) for cert in certificates: @@ -72,21 +93,29 @@ def get_or_increase_name(name, serial): if end: ends.append(end) - return '{0}-{1}'.format(root, max(ends) + 1) + return "{0}-{1}".format(root, max(ends) + 1) class Certificate(db.Model): - __tablename__ = 'certificates' + __tablename__ = "certificates" __table_args__ = ( - Index('ix_certificates_cn', "cn", - postgresql_ops={"cn": "gin_trgm_ops"}, - postgresql_using='gin'), - Index('ix_certificates_name', "name", - postgresql_ops={"name": "gin_trgm_ops"}, - postgresql_using='gin'), + Index( + "ix_certificates_cn", + "cn", + postgresql_ops={"cn": "gin_trgm_ops"}, + postgresql_using="gin", + ), + Index( + "ix_certificates_name", + "name", + postgresql_ops={"name": "gin_trgm_ops"}, + postgresql_using="gin", + ), ) id = Column(Integer, primary_key=True) - ix = Index('ix_certificates_id_desc', id.desc(), postgresql_using='btree', unique=True) + ix = Index( + "ix_certificates_id_desc", id.desc(), postgresql_using="btree", unique=True + ) external_id = Column(String(128)) owner = Column(String(128), nullable=False) name = Column(String(256), unique=True) @@ -101,11 +130,15 @@ class Certificate(db.Model): issuer = Column(String(128)) serial = Column(String(128)) cn = Column(String(128)) - deleted = Column(Boolean, index=True) - dns_provider_id = Column(Integer(), ForeignKey('dns_providers.id', ondelete='CASCADE'), nullable=True) + deleted = Column(Boolean, index=True, default=False) + dns_provider_id = Column( + Integer(), ForeignKey("dns_providers.id", ondelete="CASCADE"), nullable=True + ) not_before = Column(ArrowType) not_after = Column(ArrowType) + not_after_ix = Index("ix_certificates_not_after", not_after.desc()) + date_created = Column(ArrowType, PassiveDefault(func.now()), nullable=False) signing_algorithm = Column(String(128)) @@ -114,34 +147,53 @@ class Certificate(db.Model): san = Column(String(1024)) # TODO this should be migrated to boolean rotation = Column(Boolean, default=False) - user_id = Column(Integer, ForeignKey('users.id')) - authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE")) - root_authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE")) - rotation_policy_id = Column(Integer, ForeignKey('rotation_policies.id')) + user_id = Column(Integer, ForeignKey("users.id")) + authority_id = Column(Integer, ForeignKey("authorities.id", ondelete="CASCADE")) + root_authority_id = Column( + Integer, ForeignKey("authorities.id", ondelete="CASCADE") + ) + rotation_policy_id = Column(Integer, ForeignKey("rotation_policies.id")) - notifications = relationship('Notification', secondary=certificate_notification_associations, backref='certificate') - destinations = relationship('Destination', secondary=certificate_destination_associations, backref='certificate') - sources = relationship('Source', secondary=certificate_source_associations, backref='certificate') - domains = relationship('Domain', secondary=certificate_associations, backref='certificate') - roles = relationship('Role', secondary=roles_certificates, backref='certificate') - replaces = relationship('Certificate', - secondary=certificate_replacement_associations, - primaryjoin=id == certificate_replacement_associations.c.certificate_id, # noqa - secondaryjoin=id == certificate_replacement_associations.c.replaced_certificate_id, # noqa - backref='replaced') + notifications = relationship( + "Notification", + secondary=certificate_notification_associations, + backref="certificate", + ) + destinations = relationship( + "Destination", + secondary=certificate_destination_associations, + backref="certificate", + ) + sources = relationship( + "Source", secondary=certificate_source_associations, backref="certificate" + ) + domains = relationship( + "Domain", secondary=certificate_associations, backref="certificate" + ) + roles = relationship("Role", secondary=roles_certificates, backref="certificate") + replaces = relationship( + "Certificate", + secondary=certificate_replacement_associations, + primaryjoin=id == certificate_replacement_associations.c.certificate_id, # noqa + secondaryjoin=id + == certificate_replacement_associations.c.replaced_certificate_id, # noqa + backref="replaced", + ) - replaced_by_pending = relationship('PendingCertificate', - secondary=pending_cert_replacement_associations, - backref='pending_replace', - viewonly=True) + replaced_by_pending = relationship( + "PendingCertificate", + secondary=pending_cert_replacement_associations, + backref="pending_replace", + viewonly=True, + ) - logs = relationship('Log', backref='certificate') - endpoints = relationship('Endpoint', backref='certificate') + logs = relationship("Log", backref="certificate") + endpoints = relationship("Endpoint", backref="certificate") rotation_policy = relationship("RotationPolicy") - sensitive_fields = ('private_key',) + sensitive_fields = ("private_key",) def __init__(self, **kwargs): - self.body = kwargs['body'].strip() + self.body = kwargs["body"].strip() cert = self.parsed_cert self.issuer = defaults.issuer(cert) @@ -152,40 +204,65 @@ class Certificate(db.Model): self.serial = defaults.serial(cert) # when destinations are appended they require a valid name. - if kwargs.get('name'): - self.name = get_or_increase_name(defaults.text_to_slug(kwargs['name']), self.serial) + if kwargs.get("name"): + self.name = get_or_increase_name( + defaults.text_to_slug(kwargs["name"]), self.serial + ) else: self.name = get_or_increase_name( - defaults.certificate_name(self.cn, self.issuer, self.not_before, self.not_after, self.san), self.serial) + defaults.certificate_name( + self.cn, self.issuer, self.not_before, self.not_after, self.san + ), + self.serial, + ) - self.owner = kwargs['owner'] + self.owner = kwargs["owner"] - if kwargs.get('private_key'): - self.private_key = kwargs['private_key'].strip() + if kwargs.get("private_key"): + self.private_key = kwargs["private_key"].strip() - if kwargs.get('chain'): - self.chain = kwargs['chain'].strip() + if kwargs.get("chain"): + self.chain = kwargs["chain"].strip() - if kwargs.get('csr'): - self.csr = kwargs['csr'].strip() + if kwargs.get("csr"): + self.csr = kwargs["csr"].strip() - self.notify = kwargs.get('notify', True) - self.destinations = kwargs.get('destinations', []) - self.notifications = kwargs.get('notifications', []) - self.description = kwargs.get('description') - self.roles = list(set(kwargs.get('roles', []))) - self.replaces = kwargs.get('replaces', []) - self.rotation = kwargs.get('rotation') - self.rotation_policy = kwargs.get('rotation_policy') + self.notify = kwargs.get("notify", True) + self.destinations = kwargs.get("destinations", []) + self.notifications = kwargs.get("notifications", []) + self.description = kwargs.get("description") + self.roles = list(set(kwargs.get("roles", []))) + self.replaces = kwargs.get("replaces", []) + self.rotation = kwargs.get("rotation") + self.rotation_policy = kwargs.get("rotation_policy") self.signing_algorithm = defaults.signing_algorithm(cert) self.bits = defaults.bitstrength(cert) - self.external_id = kwargs.get('external_id') - self.authority_id = kwargs.get('authority_id') - self.dns_provider_id = kwargs.get('dns_provider_id') + self.external_id = kwargs.get("external_id") + self.authority_id = kwargs.get("authority_id") + self.dns_provider_id = kwargs.get("dns_provider_id") for domain in defaults.domains(cert): self.domains.append(Domain(name=domain)) + # Check integrity before saving anything into the database. + # For user-facing API calls, validation should also be done in schema validators. + self.check_integrity() + + def check_integrity(self): + """ + Integrity checks: Does the cert have a valid chain and matching private key? + """ + if self.private_key: + validators.verify_private_key_match( + utils.parse_private_key(self.private_key), + self.parsed_cert, + error_class=AssertionError, + ) + + if self.chain: + chain = [self.parsed_cert] + utils.parse_cert_chain(self.chain) + validators.verify_cert_chain(chain, error_class=AssertionError) + @cached_property def parsed_cert(self): assert self.body, "Certificate body not set" @@ -215,10 +292,16 @@ class Certificate(db.Model): def location(self): return defaults.location(self.parsed_cert) + @property + def distinguished_name(self): + return self.parsed_cert.subject.rfc4514_string() + @property def key_type(self): if isinstance(self.parsed_cert.public_key(), rsa.RSAPublicKey): - return 'RSA{key_size}'.format(key_size=self.parsed_cert.public_key().key_size) + return "RSA{key_size}".format( + key_size=self.parsed_cert.public_key().key_size + ) @property def validity_remaining(self): @@ -243,26 +326,24 @@ class Certificate(db.Model): @expired.expression def expired(cls): - return case( - [ - (cls.not_after <= arrow.utcnow(), True) - ], - else_=False - ) + return case([(cls.not_after <= arrow.utcnow(), True)], else_=False) @hybrid_property def revoked(self): - if 'revoked' == self.status: + if "revoked" == self.status: return True @revoked.expression def revoked(cls): - return case( - [ - (cls.status == 'revoked', True) - ], - else_=False - ) + return case([(cls.status == "revoked", True)], else_=False) + + @hybrid_property + def has_private_key(self): + return self.private_key is not None + + @has_private_key.expression + def has_private_key(cls): + return case([(cls.private_key.is_(None), True)], else_=False) @hybrid_property def in_rotation_window(self): @@ -285,66 +366,65 @@ class Certificate(db.Model): :return: """ return case( - [ - (extract('day', cls.not_after - func.now()) <= RotationPolicy.days, True) - ], - else_=False + [(extract("day", cls.not_after - func.now()) <= RotationPolicy.days, True)], + else_=False, ) @property def extensions(self): # setup default values - return_extensions = { - 'sub_alt_names': {'names': []} - } + return_extensions = {"sub_alt_names": {"names": []}} try: for extension in self.parsed_cert.extensions: value = extension.value if isinstance(value, x509.BasicConstraints): - return_extensions['basic_constraints'] = value + return_extensions["basic_constraints"] = value elif isinstance(value, x509.SubjectAlternativeName): - return_extensions['sub_alt_names']['names'] = value + return_extensions["sub_alt_names"]["names"] = value elif isinstance(value, x509.ExtendedKeyUsage): - return_extensions['extended_key_usage'] = value + return_extensions["extended_key_usage"] = value elif isinstance(value, x509.KeyUsage): - return_extensions['key_usage'] = value + return_extensions["key_usage"] = value elif isinstance(value, x509.SubjectKeyIdentifier): - return_extensions['subject_key_identifier'] = {'include_ski': True} + return_extensions["subject_key_identifier"] = {"include_ski": True} elif isinstance(value, x509.AuthorityInformationAccess): - return_extensions['certificate_info_access'] = {'include_aia': True} + return_extensions["certificate_info_access"] = {"include_aia": True} elif isinstance(value, x509.AuthorityKeyIdentifier): - aki = { - 'use_key_identifier': False, - 'use_authority_cert': False - } + aki = {"use_key_identifier": False, "use_authority_cert": False} if value.key_identifier: - aki['use_key_identifier'] = True + aki["use_key_identifier"] = True if value.authority_cert_issuer: - aki['use_authority_cert'] = True + aki["use_authority_cert"] = True - return_extensions['authority_key_identifier'] = aki + return_extensions["authority_key_identifier"] = aki elif isinstance(value, x509.CRLDistributionPoints): - return_extensions['crl_distribution_points'] = {'include_crl_dp': value} + return_extensions["crl_distribution_points"] = { + "include_crl_dp": value + } # TODO: Not supporting custom OIDs yet. https://github.com/Netflix/lemur/issues/665 else: - current_app.logger.warning('Custom OIDs not yet supported for clone operation.') + current_app.logger.warning( + "Custom OIDs not yet supported for clone operation." + ) except InvalidCodepoint as e: sentry.captureException() - current_app.logger.warning('Unable to parse extensions due to underscore in dns name') + current_app.logger.warning( + "Unable to parse extensions due to underscore in dns name" + ) except ValueError as e: sentry.captureException() - current_app.logger.warning('Unable to parse') + current_app.logger.warning("Unable to parse") current_app.logger.exception(e) return return_extensions @@ -353,7 +433,7 @@ class Certificate(db.Model): return "Certificate(name={name})".format(name=self.name) -@event.listens_for(Certificate.destinations, 'append') +@event.listens_for(Certificate.destinations, "append") def update_destinations(target, value, initiator): """ Attempt to upload certificate to the new destination @@ -367,17 +447,31 @@ def update_destinations(target, value, initiator): status = FAILURE_METRIC_STATUS try: if target.private_key or not destination_plugin.requires_key: - destination_plugin.upload(target.name, target.body, target.private_key, target.chain, value.options) + destination_plugin.upload( + target.name, + target.body, + target.private_key, + target.chain, + value.options, + ) status = SUCCESS_METRIC_STATUS except Exception as e: sentry.captureException() raise - metrics.send('destination_upload', 'counter', 1, - metric_tags={'status': status, 'certificate': target.name, 'destination': value.label}) + metrics.send( + "destination_upload", + "counter", + 1, + metric_tags={ + "status": status, + "certificate": target.name, + "destination": value.label, + }, + ) -@event.listens_for(Certificate.replaces, 'append') +@event.listens_for(Certificate.replaces, "append") def update_replacement(target, value, initiator): """ When a certificate is marked as 'replaced' we should not notify. diff --git a/lemur/certificates/schemas.py b/lemur/certificates/schemas.py index bf18eac9..8f15542d 100644 --- a/lemur/certificates/schemas.py +++ b/lemur/certificates/schemas.py @@ -6,11 +6,14 @@ .. moduleauthor:: Kevin Glisson """ from flask import current_app +from flask_restful import inputs +from flask_restful.reqparse import RequestParser from marshmallow import fields, validate, validates_schema, post_load, pre_load from marshmallow.exceptions import ValidationError from lemur.authorities.schemas import AuthorityNestedOutputSchema -from lemur.common import validators, missing +from lemur.certificates import utils as cert_utils +from lemur.common import missing, utils, validators from lemur.common.fields import ArrowDateTime, Hex from lemur.common.schema import LemurInputSchema, LemurOutputSchema from lemur.constants import CERTIFICATE_KEY_TYPES @@ -38,22 +41,26 @@ from lemur.users.schemas import UserNestedOutputSchema class CertificateSchema(LemurInputSchema): owner = fields.Email(required=True) - description = fields.String(missing='', allow_none=True) + description = fields.String(missing="", allow_none=True) class CertificateCreationSchema(CertificateSchema): @post_load def default_notification(self, data): - if not data['notifications']: - data['notifications'] += notification_service.create_default_expiration_notifications( - "DEFAULT_{0}".format(data['owner'].split('@')[0].upper()), - [data['owner']], + if not data["notifications"]: + data[ + "notifications" + ] += notification_service.create_default_expiration_notifications( + "DEFAULT_{0}".format(data["owner"].split("@")[0].upper()), + [data["owner"]], ) - data['notifications'] += notification_service.create_default_expiration_notifications( - 'DEFAULT_SECURITY', - current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL'), - current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL_INTERVALS', None) + data[ + "notifications" + ] += notification_service.create_default_expiration_notifications( + "DEFAULT_SECURITY", + current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL"), + current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL_INTERVALS", None), ) return data @@ -70,34 +77,56 @@ class CertificateInputSchema(CertificateCreationSchema): destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True) notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True) replaces = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) - replacements = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) # deprecated + replacements = fields.Nested( + AssociatedCertificateSchema, missing=[], many=True + ) # deprecated roles = fields.Nested(AssociatedRoleSchema, missing=[], many=True) - dns_provider = fields.Nested(AssociatedDnsProviderSchema, missing=None, allow_none=True, required=False) + dns_provider = fields.Nested( + AssociatedDnsProviderSchema, missing=None, allow_none=True, required=False + ) csr = fields.String(allow_none=True, validate=validators.csr) key_type = fields.String( - validate=validate.OneOf(CERTIFICATE_KEY_TYPES), - missing='RSA2048') + validate=validate.OneOf(CERTIFICATE_KEY_TYPES), missing="RSA2048" + ) notify = fields.Boolean(default=True) rotation = fields.Boolean() - rotation_policy = fields.Nested(AssociatedRotationPolicySchema, missing={'name': 'default'}, allow_none=True, - default={'name': 'default'}) + rotation_policy = fields.Nested( + AssociatedRotationPolicySchema, + missing={"name": "default"}, + allow_none=True, + default={"name": "default"}, + ) # certificate body fields - organizational_unit = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT')) - organization = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATION')) - location = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_LOCATION')) - country = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_COUNTRY')) - state = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_STATE')) + organizational_unit = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATIONAL_UNIT") + ) + organization = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATION") + ) + location = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_LOCATION") + ) + country = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_COUNTRY") + ) + state = fields.String(missing=lambda: current_app.config.get("LEMUR_DEFAULT_STATE")) extensions = fields.Nested(ExtensionSchema) @validates_schema def validate_authority(self, data): - if not data['authority'].active: - raise ValidationError("The authority is inactive.", ['authority']) + if 'authority' not in data: + raise ValidationError("Missing Authority.") + + if isinstance(data["authority"], str): + raise ValidationError("Authority not found.") + + if not data["authority"].active: + raise ValidationError("The authority is inactive.", ["authority"]) @validates_schema def validate_dates(self, data): @@ -105,8 +134,19 @@ class CertificateInputSchema(CertificateCreationSchema): @pre_load def load_data(self, data): - if data.get('replacements'): - data['replaces'] = data['replacements'] # TODO remove when field is deprecated + if data.get("replacements"): + data["replaces"] = data[ + "replacements" + ] # TODO remove when field is deprecated + if data.get("csr"): + csr_sans = cert_utils.get_sans_from_csr(data["csr"]) + if not data.get("extensions"): + data["extensions"] = {"subAltNames": {"names": []}} + elif not data["extensions"].get("subAltNames"): + data["extensions"]["subAltNames"] = {"names": []} + elif not data["extensions"]["subAltNames"].get("names"): + data["extensions"]["subAltNames"]["names"] = [] + data["extensions"]["subAltNames"]["names"] += csr_sans return missing.convert_validity_years(data) @@ -119,13 +159,17 @@ class CertificateEditInputSchema(CertificateSchema): destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True) notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True) replaces = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) - replacements = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) # deprecated + replacements = fields.Nested( + AssociatedCertificateSchema, missing=[], many=True + ) # deprecated roles = fields.Nested(AssociatedRoleSchema, missing=[], many=True) @pre_load def load_data(self, data): - if data.get('replacements'): - data['replaces'] = data['replacements'] # TODO remove when field is deprecated + if data.get("replacements"): + data["replaces"] = data[ + "replacements" + ] # TODO remove when field is deprecated return data @post_load @@ -136,10 +180,15 @@ class CertificateEditInputSchema(CertificateSchema): :param data: :return: """ - if data['owner']: - notification_name = "DEFAULT_{0}".format(data['owner'].split('@')[0].upper()) - data['notifications'] += notification_service.create_default_expiration_notifications(notification_name, - [data['owner']]) + if data["owner"]: + notification_name = "DEFAULT_{0}".format( + data["owner"].split("@")[0].upper() + ) + data[ + "notifications" + ] += notification_service.create_default_expiration_notifications( + notification_name, [data["owner"]] + ) return data @@ -165,13 +214,13 @@ class CertificateNestedOutputSchema(LemurOutputSchema): # Note aliasing is the first step in deprecating these fields. cn = fields.String() # deprecated - common_name = fields.String(attribute='cn') + common_name = fields.String(attribute="cn") not_after = fields.DateTime() # deprecated - validity_end = ArrowDateTime(attribute='not_after') + validity_end = ArrowDateTime(attribute="not_after") not_before = fields.DateTime() # deprecated - validity_start = ArrowDateTime(attribute='not_before') + validity_start = ArrowDateTime(attribute="not_before") issuer = fields.Nested(AuthorityNestedOutputSchema) @@ -202,21 +251,23 @@ class CertificateOutputSchema(LemurOutputSchema): # Note aliasing is the first step in deprecating these fields. notify = fields.Boolean() - active = fields.Boolean(attribute='notify') + active = fields.Boolean(attribute="notify") + has_private_key = fields.Boolean() cn = fields.String() - common_name = fields.String(attribute='cn') + common_name = fields.String(attribute="cn") + distinguished_name = fields.String() not_after = fields.DateTime() - validity_end = ArrowDateTime(attribute='not_after') + validity_end = ArrowDateTime(attribute="not_after") not_before = fields.DateTime() - validity_start = ArrowDateTime(attribute='not_before') + validity_start = ArrowDateTime(attribute="not_before") owner = fields.Email() san = fields.Boolean() serial = fields.String() - serial_hex = Hex(attribute='serial') + serial_hex = Hex(attribute="serial") signing_algorithm = fields.String() status = fields.String() @@ -233,19 +284,31 @@ class CertificateOutputSchema(LemurOutputSchema): dns_provider = fields.Nested(DnsProvidersNestedOutputSchema) roles = fields.Nested(RoleNestedOutputSchema, many=True) endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[]) - replaced_by = fields.Nested(CertificateNestedOutputSchema, many=True, attribute='replaced') + replaced_by = fields.Nested( + CertificateNestedOutputSchema, many=True, attribute="replaced" + ) rotation_policy = fields.Nested(RotationPolicyNestedOutputSchema) +class CertificateShortOutputSchema(LemurOutputSchema): + id = fields.Integer() + name = fields.String() + owner = fields.Email() + notify = fields.Boolean() + authority = fields.Nested(AuthorityNestedOutputSchema) + issuer = fields.String() + cn = fields.String() + + class CertificateUploadInputSchema(CertificateCreationSchema): name = fields.String() authority = fields.Nested(AssociatedAuthoritySchema, required=False) notify = fields.Boolean(missing=True) external_id = fields.String(missing=None, allow_none=True) - private_key = fields.String(validate=validators.private_key) - body = fields.String(required=True, validate=validators.public_certificate) - chain = fields.String(validate=validators.public_certificate, missing=None, - allow_none=True) # TODO this could be multiple certificates + private_key = fields.String() + body = fields.String(required=True) + chain = fields.String(missing=None, allow_none=True) + csr = fields.String(required=False, allow_none=True, validate=validators.csr) destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True) notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True) @@ -254,9 +317,44 @@ class CertificateUploadInputSchema(CertificateCreationSchema): @validates_schema def keys(self, data): - if data.get('destinations'): - if not data.get('private_key'): - raise ValidationError('Destinations require private key.') + if data.get("destinations"): + if not data.get("private_key"): + raise ValidationError("Destinations require private key.") + + @validates_schema + def validate_cert_private_key_chain(self, data): + cert = None + key = None + if data.get("body"): + try: + cert = utils.parse_certificate(data["body"]) + except ValueError: + raise ValidationError( + "Public certificate presented is not valid.", field_names=["body"] + ) + + if data.get("private_key"): + try: + key = utils.parse_private_key(data["private_key"]) + except ValueError: + raise ValidationError( + "Private key presented is not valid.", field_names=["private_key"] + ) + + if cert and key: + # Throws ValidationError + validators.verify_private_key_match(key, cert) + + if data.get("chain"): + try: + chain = utils.parse_cert_chain(data["chain"]) + except ValueError: + raise ValidationError( + "Invalid certificate in certificate chain.", field_names=["chain"] + ) + + # Throws ValidationError + validators.verify_cert_chain([cert] + chain) class CertificateExportInputSchema(LemurInputSchema): @@ -269,8 +367,10 @@ class CertificateNotificationOutputSchema(LemurOutputSchema): name = fields.String() owner = fields.Email() user = fields.Nested(UserNestedOutputSchema) - validity_end = ArrowDateTime(attribute='not_after') - replaced_by = fields.Nested(CertificateNestedOutputSchema, many=True, attribute='replaced') + validity_end = ArrowDateTime(attribute="not_after") + replaced_by = fields.Nested( + CertificateNestedOutputSchema, many=True, attribute="replaced" + ) endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[]) @@ -278,9 +378,22 @@ class CertificateRevokeSchema(LemurInputSchema): comments = fields.String() +certificates_list_request_parser = RequestParser() +certificates_list_request_parser.add_argument("short", type=inputs.boolean, default=False, location="args") + + +def certificates_list_output_schema_factory(): + args = certificates_list_request_parser.parse_args() + if args["short"]: + return certificates_short_output_schema + else: + return certificates_output_schema + + certificate_input_schema = CertificateInputSchema() certificate_output_schema = CertificateOutputSchema() certificates_output_schema = CertificateOutputSchema(many=True) +certificates_short_output_schema = CertificateShortOutputSchema(many=True) certificate_upload_input_schema = CertificateUploadInputSchema() certificate_export_input_schema = CertificateExportInputSchema() certificate_edit_input_schema = CertificateEditInputSchema() diff --git a/lemur/certificates/service.py b/lemur/certificates/service.py index c9a2fa24..0e91b563 100644 --- a/lemur/certificates/service.py +++ b/lemur/certificates/service.py @@ -20,17 +20,20 @@ from lemur.common.utils import generate_private_key, truthiness from lemur.destinations.models import Destination from lemur.domains.models import Domain from lemur.extensions import metrics, sentry, signals -from lemur.models import certificate_associations from lemur.notifications.models import Notification from lemur.pending_certificates.models import PendingCertificate from lemur.plugins.base import plugins from lemur.roles import service as role_service from lemur.roles.models import Role -csr_created = signals.signal('csr_created', "CSR generated") -csr_imported = signals.signal('csr_imported', "CSR imported from external source") -certificate_issued = signals.signal('certificate_issued', "Authority issued a certificate") -certificate_imported = signals.signal('certificate_imported', "Certificate imported from external source") +csr_created = signals.signal("csr_created", "CSR generated") +csr_imported = signals.signal("csr_imported", "CSR imported from external source") +certificate_issued = signals.signal( + "certificate_issued", "Authority issued a certificate" +) +certificate_imported = signals.signal( + "certificate_imported", "Certificate imported from external source" +) def get(cert_id): @@ -50,12 +53,12 @@ def get_by_name(name): :param name: :return: """ - return database.get(Certificate, name, field='name') + return database.get(Certificate, name, field="name") def get_by_serial(serial): """ - Retrieves certificate by it's Serial. + Retrieves certificate(s) by serial number. :param serial: :return: """ @@ -65,6 +68,22 @@ def get_by_serial(serial): return Certificate.query.filter(Certificate.serial == serial).all() +def get_by_attributes(conditions): + """ + Retrieves certificate(s) by conditions given in a hash of given key=>value pairs. + :param serial: + :return: + """ + # Ensure that each of the given conditions corresponds to actual columns + # if not, silently remove it + for attr in conditions.keys(): + if attr not in Certificate.__table__.columns: + conditions.pop(attr) + + query = database.session_query(Certificate) + return database.find_all(query, Certificate, conditions).all() + + def delete(cert_id): """ Delete's a certificate. @@ -90,8 +109,12 @@ def get_all_pending_cleaning(source): :param source: :return: """ - return Certificate.query.filter(Certificate.sources.any(id=source.id)) \ - .filter(not_(Certificate.endpoints.any())).filter(Certificate.expired).all() + return ( + Certificate.query.filter(Certificate.sources.any(id=source.id)) + .filter(not_(Certificate.endpoints.any())) + .filter(Certificate.expired) + .all() + ) def get_all_pending_reissue(): @@ -104,9 +127,12 @@ def get_all_pending_reissue(): :return: """ - return Certificate.query.filter(Certificate.rotation == True) \ - .filter(not_(Certificate.replaced.any())) \ - .filter(Certificate.in_rotation_window == True).all() # noqa + return ( + Certificate.query.filter(Certificate.rotation == True) + .filter(not_(Certificate.replaced.any())) + .filter(Certificate.in_rotation_window == True) + .all() + ) # noqa def find_duplicates(cert): @@ -118,10 +144,12 @@ def find_duplicates(cert): :param cert: :return: """ - if cert['chain']: - return Certificate.query.filter_by(body=cert['body'].strip(), chain=cert['chain'].strip()).all() + if cert["chain"]: + return Certificate.query.filter_by( + body=cert["body"].strip(), chain=cert["chain"].strip() + ).all() else: - return Certificate.query.filter_by(body=cert['body'].strip(), chain=None).all() + return Certificate.query.filter_by(body=cert["body"].strip(), chain=None).all() def export(cert, export_plugin): @@ -133,8 +161,10 @@ def export(cert, export_plugin): :param cert: :return: """ - plugin = plugins.get(export_plugin['slug']) - return plugin.export(cert.body, cert.chain, cert.private_key, export_plugin['pluginOptions']) + plugin = plugins.get(export_plugin["slug"]) + return plugin.export( + cert.body, cert.chain, cert.private_key, export_plugin["pluginOptions"] + ) def update(cert_id, **kwargs): @@ -153,17 +183,19 @@ def update(cert_id, **kwargs): def create_certificate_roles(**kwargs): # create an role for the owner and assign it - owner_role = role_service.get_by_name(kwargs['owner']) + owner_role = role_service.get_by_name(kwargs["owner"]) if not owner_role: owner_role = role_service.create( - kwargs['owner'], - description="Auto generated role based on owner: {0}".format(kwargs['owner']) + kwargs["owner"], + description="Auto generated role based on owner: {0}".format( + kwargs["owner"] + ), ) # ensure that the authority's owner is also associated with the certificate - if kwargs.get('authority'): - authority_owner_role = role_service.get_by_name(kwargs['authority'].owner) + if kwargs.get("authority"): + authority_owner_role = role_service.get_by_name(kwargs["authority"].owner) return [owner_role, authority_owner_role] return [owner_role] @@ -175,16 +207,16 @@ def mint(**kwargs): Support for multiple authorities is handled by individual plugins. """ - authority = kwargs['authority'] + authority = kwargs["authority"] issuer = plugins.get(authority.plugin_name) # allow the CSR to be specified by the user - if not kwargs.get('csr'): + if not kwargs.get("csr"): csr, private_key = create_csr(**kwargs) csr_created.send(authority=authority, csr=csr) else: - csr = str(kwargs.get('csr')) + csr = str(kwargs.get("csr")) private_key = None csr_imported.send(authority=authority, csr=csr) @@ -205,8 +237,8 @@ def import_certificate(**kwargs): :param kwargs: """ - if not kwargs.get('owner'): - kwargs['owner'] = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL')[0] + if not kwargs.get("owner"): + kwargs["owner"] = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL")[0] return upload(**kwargs) @@ -217,21 +249,16 @@ def upload(**kwargs): """ roles = create_certificate_roles(**kwargs) - if kwargs.get('roles'): - kwargs['roles'] += roles + if kwargs.get("roles"): + kwargs["roles"] += roles else: - kwargs['roles'] = roles - - if kwargs.get('private_key'): - private_key = kwargs['private_key'] - if not isinstance(private_key, bytes): - kwargs['private_key'] = private_key.encode('utf-8') + kwargs["roles"] = roles cert = Certificate(**kwargs) - cert.authority = kwargs.get('authority') + cert.authority = kwargs.get("authority") cert = database.create(cert) - kwargs['creator'].certificates.append(cert) + kwargs["creator"].certificates.append(cert) cert = database.update(cert) certificate_imported.send(certificate=cert, authority=cert.authority) @@ -248,39 +275,45 @@ def create(**kwargs): current_app.logger.error("Exception minting certificate", exc_info=True) sentry.captureException() raise - kwargs['body'] = cert_body - kwargs['private_key'] = private_key - kwargs['chain'] = cert_chain - kwargs['external_id'] = external_id - kwargs['csr'] = csr + kwargs["body"] = cert_body + kwargs["private_key"] = private_key + kwargs["chain"] = cert_chain + kwargs["external_id"] = external_id + kwargs["csr"] = csr roles = create_certificate_roles(**kwargs) - if kwargs.get('roles'): - kwargs['roles'] += roles + if kwargs.get("roles"): + kwargs["roles"] += roles else: - kwargs['roles'] = roles + kwargs["roles"] = roles if cert_body: cert = Certificate(**kwargs) - kwargs['creator'].certificates.append(cert) + kwargs["creator"].certificates.append(cert) else: cert = PendingCertificate(**kwargs) - kwargs['creator'].pending_certificates.append(cert) + kwargs["creator"].pending_certificates.append(cert) - cert.authority = kwargs['authority'] + cert.authority = kwargs["authority"] database.commit() if isinstance(cert, Certificate): certificate_issued.send(certificate=cert, authority=cert.authority) - metrics.send('certificate_issued', 'counter', 1, metric_tags=dict(owner=cert.owner, issuer=cert.issuer)) + metrics.send( + "certificate_issued", + "counter", + 1, + metric_tags=dict(owner=cert.owner, issuer=cert.issuer), + ) if isinstance(cert, PendingCertificate): # We need to refresh the pending certificate to avoid "Instance is not bound to a Session; " # "attribute refresh operation cannot proceed" pending_cert = database.session_query(PendingCertificate).get(cert.id) from lemur.common.celery import fetch_acme_cert + if not current_app.config.get("ACME_DISABLE_AUTORESOLVE", False): fetch_acme_cert.apply_async((pending_cert.id,), countdown=5) @@ -296,85 +329,150 @@ def render(args): """ query = database.session_query(Certificate) - time_range = args.pop('time_range') - destination_id = args.pop('destination_id') - notification_id = args.pop('notification_id', None) - show = args.pop('show') + show_expired = args.pop("showExpired") + if show_expired != 1: + one_month_old = arrow.now()\ + .shift(months=current_app.config.get("HIDE_EXPIRED_CERTS_AFTER_MONTHS", -1))\ + .format("YYYY-MM-DD") + query = query.filter(Certificate.not_after > one_month_old) + + time_range = args.pop("time_range") + + destination_id = args.pop("destination_id") + notification_id = args.pop("notification_id", None) + show = args.pop("show") # owner = args.pop('owner') # creator = args.pop('creator') # TODO we should enabling filtering by owner - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') - term = '%{0}%'.format(terms[1]) + terms = filt.split(";") + term = "%{0}%".format(terms[1]) # Exact matches for quotes. Only applies to name, issuer, and cn if terms[1].startswith('"') and terms[1].endswith('"'): term = terms[1][1:-1] - if 'issuer' in terms: + if "issuer" in terms: # we can't rely on issuer being correct in the cert directly so we combine queries - sub_query = database.session_query(Authority.id) \ - .filter(Authority.name.ilike(term)) \ + sub_query = ( + database.session_query(Authority.id) + .filter(Authority.name.ilike(term)) .subquery() + ) query = query.filter( or_( Certificate.issuer.ilike(term), - Certificate.authority_id.in_(sub_query) + Certificate.authority_id.in_(sub_query), ) ) - elif 'destination' in terms: - query = query.filter(Certificate.destinations.any(Destination.id == terms[1])) - elif 'notify' in filt: + elif "destination" in terms: + query = query.filter( + Certificate.destinations.any(Destination.id == terms[1]) + ) + elif "notify" in filt: query = query.filter(Certificate.notify == truthiness(terms[1])) - elif 'active' in filt: + elif "active" in filt: query = query.filter(Certificate.active == truthiness(terms[1])) - elif 'cn' in terms: + elif "cn" in terms: query = query.filter( or_( Certificate.cn.ilike(term), - Certificate.domains.any(Domain.name.ilike(term)) + Certificate.domains.any(Domain.name.ilike(term)), ) ) - elif 'id' in terms: + elif "id" in terms: query = query.filter(Certificate.id == cast(terms[1], Integer)) - elif 'name' in terms: - query = query.outerjoin(certificate_associations).outerjoin(Domain).filter( + elif "name" in terms: + query = query.filter( or_( Certificate.name.ilike(term), - Domain.name.ilike(term), + Certificate.domains.any(Domain.name.ilike(term)), Certificate.cn.ilike(term), ) - ).group_by(Certificate.id) + ) + elif "fixedName" in terms: + # only what matches the fixed name directly if a fixedname is provided + query = query.filter(Certificate.name == terms[1]) else: query = database.filter(query, Certificate, terms) if show: - sub_query = database.session_query(Role.name).filter(Role.user_id == args['user'].id).subquery() + sub_query = ( + database.session_query(Role.name) + .filter(Role.user_id == args["user"].id) + .subquery() + ) query = query.filter( or_( - Certificate.user_id == args['user'].id, - Certificate.owner.in_(sub_query) + Certificate.user_id == args["user"].id, Certificate.owner.in_(sub_query) ) ) if destination_id: - query = query.filter(Certificate.destinations.any(Destination.id == destination_id)) + query = query.filter( + Certificate.destinations.any(Destination.id == destination_id) + ) if notification_id: - query = query.filter(Certificate.notifications.any(Notification.id == notification_id)) + query = query.filter( + Certificate.notifications.any(Notification.id == notification_id) + ) if time_range: - to = arrow.now().replace(weeks=+time_range).format('YYYY-MM-DD') - now = arrow.now().format('YYYY-MM-DD') - query = query.filter(Certificate.not_after <= to).filter(Certificate.not_after >= now) + to = arrow.now().shift(weeks=+time_range).format("YYYY-MM-DD") + now = arrow.now().format("YYYY-MM-DD") + query = query.filter(Certificate.not_after <= to).filter( + Certificate.not_after >= now + ) + + if current_app.config.get("ALLOW_CERT_DELETION", False): + query = query.filter(Certificate.deleted == False) # noqa result = database.sort_and_page(query, Certificate, args) return result +def query_name(certificate_name, args): + """ + Helper function that queries for a certificate by name + + :param args: + :return: + """ + query = database.session_query(Certificate) + query = query.filter(Certificate.name == certificate_name) + result = database.sort_and_page(query, Certificate, args) + return result + + +def query_common_name(common_name, args): + """ + Helper function that queries for not expired certificates by common name (and owner) + + :param common_name: + :param args: + :return: + """ + owner = args.pop("owner") + if not owner: + owner = "%" + + # only not expired certificates + current_time = arrow.utcnow() + + result = ( + Certificate.query.filter(Certificate.cn.ilike(common_name)) + .filter(Certificate.owner.ilike(owner)) + .filter(Certificate.not_after >= current_time.format("YYYY-MM-DD")) + .all() + ) + + return result + + def create_csr(**csr_config): """ Given a list of domains create the appropriate csr @@ -382,65 +480,77 @@ def create_csr(**csr_config): :param csr_config: """ - private_key = generate_private_key(csr_config.get('key_type')) + private_key = generate_private_key(csr_config.get("key_type")) builder = x509.CertificateSigningRequestBuilder() - name_list = [x509.NameAttribute(x509.OID_COMMON_NAME, csr_config['common_name'])] - if current_app.config.get('LEMUR_OWNER_EMAIL_IN_SUBJECT', True): - name_list.append(x509.NameAttribute(x509.OID_EMAIL_ADDRESS, csr_config['owner'])) - if 'organization' in csr_config and csr_config['organization'].strip(): - name_list.append(x509.NameAttribute(x509.OID_ORGANIZATION_NAME, csr_config['organization'])) - if 'organizational_unit' in csr_config and csr_config['organizational_unit'].strip(): - name_list.append(x509.NameAttribute(x509.OID_ORGANIZATIONAL_UNIT_NAME, csr_config['organizational_unit'])) - if 'country' in csr_config and csr_config['country'].strip(): - name_list.append(x509.NameAttribute(x509.OID_COUNTRY_NAME, csr_config['country'])) - if 'state' in csr_config and csr_config['state'].strip(): - name_list.append(x509.NameAttribute(x509.OID_STATE_OR_PROVINCE_NAME, csr_config['state'])) - if 'location' in csr_config and csr_config['location'].strip(): - name_list.append(x509.NameAttribute(x509.OID_LOCALITY_NAME, csr_config['location'])) + name_list = [x509.NameAttribute(x509.OID_COMMON_NAME, csr_config["common_name"])] + if current_app.config.get("LEMUR_OWNER_EMAIL_IN_SUBJECT", True): + name_list.append( + x509.NameAttribute(x509.OID_EMAIL_ADDRESS, csr_config["owner"]) + ) + if "organization" in csr_config and csr_config["organization"].strip(): + name_list.append( + x509.NameAttribute(x509.OID_ORGANIZATION_NAME, csr_config["organization"]) + ) + if ( + "organizational_unit" in csr_config + and csr_config["organizational_unit"].strip() + ): + name_list.append( + x509.NameAttribute( + x509.OID_ORGANIZATIONAL_UNIT_NAME, csr_config["organizational_unit"] + ) + ) + if "country" in csr_config and csr_config["country"].strip(): + name_list.append( + x509.NameAttribute(x509.OID_COUNTRY_NAME, csr_config["country"]) + ) + if "state" in csr_config and csr_config["state"].strip(): + name_list.append( + x509.NameAttribute(x509.OID_STATE_OR_PROVINCE_NAME, csr_config["state"]) + ) + if "location" in csr_config and csr_config["location"].strip(): + name_list.append( + x509.NameAttribute(x509.OID_LOCALITY_NAME, csr_config["location"]) + ) builder = builder.subject_name(x509.Name(name_list)) - extensions = csr_config.get('extensions', {}) - critical_extensions = ['basic_constraints', 'sub_alt_names', 'key_usage'] - noncritical_extensions = ['extended_key_usage'] + extensions = csr_config.get("extensions", {}) + critical_extensions = ["basic_constraints", "sub_alt_names", "key_usage"] + noncritical_extensions = ["extended_key_usage"] for k, v in extensions.items(): if v: if k in critical_extensions: - current_app.logger.debug('Adding Critical Extension: {0} {1}'.format(k, v)) - if k == 'sub_alt_names': - if v['names']: - builder = builder.add_extension(v['names'], critical=True) + current_app.logger.debug( + "Adding Critical Extension: {0} {1}".format(k, v) + ) + if k == "sub_alt_names": + if v["names"]: + builder = builder.add_extension(v["names"], critical=True) else: builder = builder.add_extension(v, critical=True) if k in noncritical_extensions: - current_app.logger.debug('Adding Extension: {0} {1}'.format(k, v)) + current_app.logger.debug("Adding Extension: {0} {1}".format(k, v)) builder = builder.add_extension(v, critical=False) - ski = extensions.get('subject_key_identifier', {}) - if ski.get('include_ski', False): + ski = extensions.get("subject_key_identifier", {}) + if ski.get("include_ski", False): builder = builder.add_extension( x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()), - critical=False + critical=False, ) - request = builder.sign( - private_key, hashes.SHA256(), default_backend() - ) + request = builder.sign(private_key, hashes.SHA256(), default_backend()) # serialize our private key and CSR private_key = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, # would like to use PKCS8 but AWS ELBs don't like it - encryption_algorithm=serialization.NoEncryption() - ) + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") - if isinstance(private_key, bytes): - private_key = private_key.decode('utf-8') - - csr = request.public_bytes( - encoding=serialization.Encoding.PEM - ).decode('utf-8') + csr = request.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8") return csr, private_key @@ -452,16 +562,19 @@ def stats(**kwargs): :param kwargs: :return: """ - if kwargs.get('metric') == 'not_after': + if kwargs.get("metric") == "not_after": start = arrow.utcnow() - end = start.replace(weeks=+32) - items = database.db.session.query(Certificate.issuer, func.count(Certificate.id)) \ - .group_by(Certificate.issuer) \ - .filter(Certificate.not_after <= end.format('YYYY-MM-DD')) \ - .filter(Certificate.not_after >= start.format('YYYY-MM-DD')).all() + end = start.shift(weeks=+32) + items = ( + database.db.session.query(Certificate.issuer, func.count(Certificate.id)) + .group_by(Certificate.issuer) + .filter(Certificate.not_after <= end.format("YYYY-MM-DD")) + .filter(Certificate.not_after >= start.format("YYYY-MM-DD")) + .all() + ) else: - attr = getattr(Certificate, kwargs.get('metric')) + attr = getattr(Certificate, kwargs.get("metric")) query = database.db.session.query(attr, func.count(attr)) items = query.group_by(attr).all() @@ -472,7 +585,7 @@ def stats(**kwargs): keys.append(key) values.append(count) - return {'labels': keys, 'values': values} + return {"labels": keys, "values": values} def get_account_number(arn): @@ -519,22 +632,24 @@ def get_certificate_primitives(certificate): certificate via `create`. """ start, end = calculate_reissue_range(certificate.not_before, certificate.not_after) - ser = CertificateInputSchema().load(CertificateOutputSchema().dump(certificate).data) + ser = CertificateInputSchema().load( + CertificateOutputSchema().dump(certificate).data + ) assert not ser.errors, "Error re-serializing certificate: %s" % ser.errors data = ser.data # we can't quite tell if we are using a custom name, as this is an automated process (typically) # we will rely on the Lemur generated name - data.pop('name', None) + data.pop("name", None) # TODO this can be removed once we migrate away from cn - data['cn'] = data['common_name'] + data["cn"] = data["common_name"] # needed until we move off not_* - data['not_before'] = start - data['not_after'] = end - data['validity_start'] = start - data['validity_end'] = end + data["not_before"] = start + data["not_after"] = end + data["validity_start"] = start + data["validity_end"] = end return data @@ -552,13 +667,13 @@ def reissue_certificate(certificate, replace=None, user=None): # We do not want to re-use the CSR when creating a certificate because this defeats the purpose of rotation. del primitives["csr"] if not user: - primitives['creator'] = certificate.user + primitives["creator"] = certificate.user else: - primitives['creator'] = user + primitives["creator"] = user if replace: - primitives['replaces'] = [certificate] + primitives["replaces"] = [certificate] new_cert = create(**primitives) diff --git a/lemur/certificates/utils.py b/lemur/certificates/utils.py new file mode 100644 index 00000000..4e6cc4f1 --- /dev/null +++ b/lemur/certificates/utils.py @@ -0,0 +1,41 @@ +""" +Utils to parse certificate data. + +.. module: lemur.certificates.hooks + :platform: Unix + :copyright: (c) 2019 by Javier Ramos, see AUTHORS for more + :license: Apache, see LICENSE for more details. + +.. moduleauthor:: Javier Ramos +""" + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from marshmallow.exceptions import ValidationError + + +def get_sans_from_csr(data): + """ + Fetches SubjectAlternativeNames from CSR. + Works with any kind of SubjectAlternativeName + :param data: PEM-encoded string with CSR + :return: List of LemurAPI-compatible subAltNames + """ + sub_alt_names = [] + try: + request = x509.load_pem_x509_csr(data.encode("utf-8"), default_backend()) + except Exception: + raise ValidationError("CSR presented is not valid.") + + try: + alt_names = request.extensions.get_extension_for_class( + x509.SubjectAlternativeName + ) + for alt_name in alt_names.value: + sub_alt_names.append( + {"nameType": type(alt_name).__name__, "value": alt_name.value} + ) + except x509.ExtensionNotFound: + pass + + return sub_alt_names diff --git a/lemur/certificates/verify.py b/lemur/certificates/verify.py index d42e306c..76c6b521 100644 --- a/lemur/certificates/verify.py +++ b/lemur/certificates/verify.py @@ -29,31 +29,45 @@ def ocsp_verify(cert, cert_path, issuer_chain_path): :param issuer_chain_path: :return bool: True if certificate is valid, False otherwise """ - command = ['openssl', 'x509', '-noout', '-ocsp_uri', '-in', cert_path] + command = ["openssl", "x509", "-noout", "-ocsp_uri", "-in", cert_path] p1 = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) url, err = p1.communicate() if not url: - current_app.logger.debug("No OCSP URL in certificate {}".format(cert.serial_number)) + current_app.logger.debug( + "No OCSP URL in certificate {}".format(cert.serial_number) + ) return None - p2 = subprocess.Popen(['openssl', 'ocsp', '-issuer', issuer_chain_path, - '-cert', cert_path, "-url", url.strip()], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + p2 = subprocess.Popen( + [ + "openssl", + "ocsp", + "-issuer", + issuer_chain_path, + "-cert", + cert_path, + "-url", + url.strip(), + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) message, err = p2.communicate() - p_message = message.decode('utf-8') + p_message = message.decode("utf-8") - if 'error' in p_message or 'Error' in p_message: + if "error" in p_message or "Error" in p_message: raise Exception("Got error when parsing OCSP url") - elif 'revoked' in p_message: - current_app.logger.debug("OCSP reports certificate revoked: {}".format(cert.serial_number)) + elif "revoked" in p_message: + current_app.logger.debug( + "OCSP reports certificate revoked: {}".format(cert.serial_number) + ) return False - elif 'good' not in p_message: + elif "good" not in p_message: raise Exception("Did not receive a valid response") return True @@ -73,7 +87,9 @@ def crl_verify(cert, cert_path): x509.OID_CRL_DISTRIBUTION_POINTS ).value except x509.ExtensionNotFound: - current_app.logger.debug("No CRLDP extension in certificate {}".format(cert.serial_number)) + current_app.logger.debug( + "No CRLDP extension in certificate {}".format(cert.serial_number) + ) return None for p in distribution_points: @@ -92,8 +108,9 @@ def crl_verify(cert, cert_path): except ConnectionError: raise Exception("Unable to retrieve CRL: {0}".format(point)) - crl_cache[point] = x509.load_der_x509_crl(response.content, - backend=default_backend()) + crl_cache[point] = x509.load_der_x509_crl( + response.content, backend=default_backend() + ) else: current_app.logger.debug("CRL point is cached {}".format(point)) @@ -110,8 +127,9 @@ def crl_verify(cert, cert_path): except x509.ExtensionNotFound: pass - current_app.logger.debug("CRL reports certificate " - "revoked: {}".format(cert.serial_number)) + current_app.logger.debug( + "CRL reports certificate " "revoked: {}".format(cert.serial_number) + ) return False return True @@ -125,7 +143,7 @@ def verify(cert_path, issuer_chain_path): :param issuer_chain_path: :return: True if valid, False otherwise """ - with open(cert_path, 'rt') as c: + with open(cert_path, "rt") as c: try: cert = parse_certificate(c.read()) except ValueError as e: @@ -154,10 +172,10 @@ def verify_string(cert_string, issuer_string): :return: True if valid, False otherwise """ with mktempfile() as cert_tmp: - with open(cert_tmp, 'w') as f: + with open(cert_tmp, "w") as f: f.write(cert_string) with mktempfile() as issuer_tmp: - with open(issuer_tmp, 'w') as f: + with open(issuer_tmp, "w") as f: f.write(issuer_string) status = verify(cert_tmp, issuer_tmp) return status diff --git a/lemur/certificates/views.py b/lemur/certificates/views.py index 54c60924..51f7f615 100644 --- a/lemur/certificates/views.py +++ b/lemur/certificates/views.py @@ -8,7 +8,7 @@ import base64 from builtins import str -from flask import Blueprint, make_response, jsonify, g +from flask import Blueprint, make_response, jsonify, g, current_app from flask_restful import reqparse, Api, inputs from lemur.common.schema import validate_schema @@ -26,17 +26,224 @@ from lemur.certificates.schemas import ( certificate_upload_input_schema, certificates_output_schema, certificate_export_input_schema, - certificate_edit_input_schema + certificate_edit_input_schema, + certificates_list_output_schema_factory, ) from lemur.roles import service as role_service from lemur.logs import service as log_service -mod = Blueprint('certificates', __name__) +mod = Blueprint("certificates", __name__) api = Api(mod) +class CertificatesListValid(AuthenticatedResource): + """ Defines the 'certificates/valid' endpoint """ + + def __init__(self): + self.reqparse = reqparse.RequestParser() + super(CertificatesListValid, self).__init__() + + @validate_schema(None, certificates_output_schema) + def get(self): + """ + .. http:get:: /certificates/valid/ + + The current list of not-expired certificates for a given common name, and owner + + **Example request**: + + .. sourcecode:: http + GET /certificates/valid?filter=cn;*.test.example.net&owner=joe@example.com + HTTP/1.1 + Host: example.com + Accept: application/json, text/javascript + + **Example response**: + + .. sourcecode:: http + + HTTP/1.1 200 OK + Vary: Accept + Content-Type: text/javascript + + { + "items": [{ + "status": null, + "cn": "*.test.example.net", + "chain": "", + "csr": "-----BEGIN CERTIFICATE REQUEST-----" + "authority": { + "active": true, + "owner": "secure@example.com", + "id": 1, + "description": "verisign test authority", + "name": "verisign" + }, + "owner": "joe@example.com", + "serial": "82311058732025924142789179368889309156", + "id": 2288, + "issuer": "SymantecCorporation", + "dateCreated": "2016-06-03T06:09:42.133769+00:00", + "notBefore": "2016-06-03T00:00:00+00:00", + "notAfter": "2018-01-12T23:59:59+00:00", + "destinations": [], + "bits": 2048, + "body": "-----BEGIN CERTIFICATE-----...", + "description": null, + "deleted": null, + "notifications": [{ + "id": 1 + }], + "signingAlgorithm": "sha256", + "user": { + "username": "jane", + "active": true, + "email": "jane@example.com", + "id": 2 + }, + "active": true, + "domains": [{ + "sensitive": false, + "id": 1090, + "name": "*.test.example.net" + }], + "replaces": [], + "replaced": [], + "name": "WILDCARD.test.example.net-SymantecCorporation-20160603-20180112", + "roles": [{ + "id": 464, + "description": "This is a google group based role created by Lemur", + "name": "joe@example.com" + }], + "san": null + }], + "total": 1 + } + + :reqheader Authorization: OAuth token to authenticate + :statuscode 200: no error + :statuscode 403: unauthenticated + + """ + parser = paginated_parser.copy() + args = parser.parse_args() + args["user"] = g.user + common_name = args["filter"].split(";")[1] + return service.query_common_name(common_name, args) + + +class CertificatesNameQuery(AuthenticatedResource): + """ Defines the 'certificates/name' endpoint """ + + def __init__(self): + self.reqparse = reqparse.RequestParser() + super(CertificatesNameQuery, self).__init__() + + @validate_schema(None, certificates_output_schema) + def get(self, certificate_name): + """ + .. http:get:: /certificates/name/ + + The current list of certificates + + **Example request**: + + .. sourcecode:: http + + GET /certificates/name/WILDCARD.test.example.net-SymantecCorporation-20160603-20180112 HTTP/1.1 + Host: example.com + Accept: application/json, text/javascript + + **Example response**: + + .. sourcecode:: http + + HTTP/1.1 200 OK + Vary: Accept + Content-Type: text/javascript + + { + "items": [{ + "status": null, + "cn": "*.test.example.net", + "chain": "", + "csr": "-----BEGIN CERTIFICATE REQUEST-----" + "authority": { + "active": true, + "owner": "secure@example.com", + "id": 1, + "description": "verisign test authority", + "name": "verisign" + }, + "owner": "joe@example.com", + "serial": "82311058732025924142789179368889309156", + "id": 2288, + "issuer": "SymantecCorporation", + "dateCreated": "2016-06-03T06:09:42.133769+00:00", + "notBefore": "2016-06-03T00:00:00+00:00", + "notAfter": "2018-01-12T23:59:59+00:00", + "destinations": [], + "bits": 2048, + "body": "-----BEGIN CERTIFICATE-----...", + "description": null, + "deleted": null, + "notifications": [{ + "id": 1 + }], + "signingAlgorithm": "sha256", + "user": { + "username": "jane", + "active": true, + "email": "jane@example.com", + "id": 2 + }, + "active": true, + "domains": [{ + "sensitive": false, + "id": 1090, + "name": "*.test.example.net" + }], + "replaces": [], + "replaced": [], + "name": "WILDCARD.test.example.net-SymantecCorporation-20160603-20180112", + "roles": [{ + "id": 464, + "description": "This is a google group based role created by Lemur", + "name": "joe@example.com" + }], + "san": null + }], + "total": 1 + } + + :query sortBy: field to sort on + :query sortDir: asc or desc + :query page: int. default is 1 + :query filter: key value pair format is k;v + :query count: count number. default is 10 + :reqheader Authorization: OAuth token to authenticate + :statuscode 200: no error + :statuscode 403: unauthenticated + + """ + parser = paginated_parser.copy() + parser.add_argument("timeRange", type=int, dest="time_range", location="args") + parser.add_argument("owner", type=inputs.boolean, location="args") + parser.add_argument("id", type=str, location="args") + parser.add_argument("active", type=inputs.boolean, location="args") + parser.add_argument( + "destinationId", type=int, dest="destination_id", location="args" + ) + parser.add_argument("creator", type=str, location="args") + parser.add_argument("show", type=str, location="args") + + args = parser.parse_args() + args["user"] = g.user + return service.query_name(certificate_name, args) + + class CertificatesList(AuthenticatedResource): """ Defines the 'certificates' endpoint """ @@ -44,7 +251,7 @@ class CertificatesList(AuthenticatedResource): self.reqparse = reqparse.RequestParser() super(CertificatesList, self).__init__() - @validate_schema(None, certificates_output_schema) + @validate_schema(None, certificates_list_output_schema_factory) def get(self): """ .. http:get:: /certificates @@ -132,16 +339,19 @@ class CertificatesList(AuthenticatedResource): """ parser = paginated_parser.copy() - parser.add_argument('timeRange', type=int, dest='time_range', location='args') - parser.add_argument('owner', type=inputs.boolean, location='args') - parser.add_argument('id', type=str, location='args') - parser.add_argument('active', type=inputs.boolean, location='args') - parser.add_argument('destinationId', type=int, dest="destination_id", location='args') - parser.add_argument('creator', type=str, location='args') - parser.add_argument('show', type=str, location='args') + parser.add_argument("timeRange", type=int, dest="time_range", location="args") + parser.add_argument("owner", type=inputs.boolean, location="args") + parser.add_argument("id", type=str, location="args") + parser.add_argument("active", type=inputs.boolean, location="args") + parser.add_argument( + "destinationId", type=int, dest="destination_id", location="args" + ) + parser.add_argument("creator", type=str, location="args") + parser.add_argument("show", type=str, location="args") + parser.add_argument("showExpired", type=int, location="args") args = parser.parse_args() - args['user'] = g.user + args["user"] = g.user return service.render(args) @validate_schema(certificate_input_schema, certificate_output_schema) @@ -259,24 +469,31 @@ class CertificatesList(AuthenticatedResource): :statuscode 403: unauthenticated """ - role = role_service.get_by_name(data['authority'].owner) + role = role_service.get_by_name(data["authority"].owner) # all the authority role members should be allowed - roles = [x.name for x in data['authority'].roles] + roles = [x.name for x in data["authority"].roles] # allow "owner" roles by team DL roles.append(role) - authority_permission = AuthorityPermission(data['authority'].id, roles) + authority_permission = AuthorityPermission(data["authority"].id, roles) if authority_permission.can(): - data['creator'] = g.user + data["creator"] = g.user cert = service.create(**data) if isinstance(cert, Certificate): # only log if created, not pending - log_service.create(g.user, 'create_cert', certificate=cert) + log_service.create(g.user, "create_cert", certificate=cert) return cert - return dict(message="You are not authorized to use the authority: {0}".format(data['authority'].name)), 403 + return ( + dict( + message="You are not authorized to use the authority: {0}".format( + data["authority"].name + ) + ), + 403, + ) class CertificatesUpload(AuthenticatedResource): @@ -306,6 +523,7 @@ class CertificatesUpload(AuthenticatedResource): "body": "-----BEGIN CERTIFICATE-----...", "chain": "-----BEGIN CERTIFICATE-----...", "privateKey": "-----BEGIN RSA PRIVATE KEY-----..." + "csr": "-----BEGIN CERTIFICATE REQUEST-----..." "destinations": [], "notifications": [], "replacements": [], @@ -378,12 +596,14 @@ class CertificatesUpload(AuthenticatedResource): :statuscode 200: no error """ - data['creator'] = g.user - if data.get('destinations'): - if data.get('private_key'): + data["creator"] = g.user + if data.get("destinations"): + if data.get("private_key"): return service.upload(**data) else: - raise Exception("Private key must be provided in order to upload certificate to AWS") + raise Exception( + "Private key must be provided in order to upload certificate to AWS" + ) return service.upload(**data) @@ -395,10 +615,12 @@ class CertificatesStats(AuthenticatedResource): super(CertificatesStats, self).__init__() def get(self): - self.reqparse.add_argument('metric', type=str, location='args') - self.reqparse.add_argument('range', default=32, type=int, location='args') - self.reqparse.add_argument('destinationId', dest='destination_id', location='args') - self.reqparse.add_argument('active', type=str, default='true', location='args') + self.reqparse.add_argument("metric", type=str, location="args") + self.reqparse.add_argument("range", default=32, type=int, location="args") + self.reqparse.add_argument( + "destinationId", dest="destination_id", location="args" + ) + self.reqparse.add_argument("active", type=str, default="true", location="args") args = self.reqparse.parse_args() @@ -450,12 +672,12 @@ class CertificatePrivateKey(AuthenticatedResource): permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) if not permission.can(): - return dict(message='You are not authorized to view this key'), 403 + return dict(message="You are not authorized to view this key"), 403 - log_service.create(g.current_user, 'key_view', certificate=cert) + log_service.create(g.current_user, "key_view", certificate=cert) response = make_response(jsonify(key=cert.private_key), 200) - response.headers['cache-control'] = 'private, max-age=0, no-cache, no-store' - response.headers['pragma'] = 'no-cache' + response.headers["cache-control"] = "private, max-age=0, no-cache, no-store" + response.headers["pragma"] = "no-cache" return response @@ -645,21 +867,79 @@ class Certificates(AuthenticatedResource): permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) if not permission.can(): - return dict(message='You are not authorized to update this certificate'), 403 + return ( + dict(message="You are not authorized to update this certificate"), + 403, + ) - for destination in data['destinations']: + for destination in data["destinations"]: if destination.plugin.requires_key: if not cert.private_key: - return dict( - message='Unable to add destination: {0}. Certificate does not have required private key.'.format( - destination.label - ) - ), 400 + return ( + dict( + message="Unable to add destination: {0}. Certificate does not have required private key.".format( + destination.label + ) + ), + 400, + ) cert = service.update(certificate_id, **data) - log_service.create(g.current_user, 'update_cert', certificate=cert) + log_service.create(g.current_user, "update_cert", certificate=cert) return cert + def delete(self, certificate_id, data=None): + """ + .. http:delete:: /certificates/1 + + Delete a certificate + + **Example request**: + + .. sourcecode:: http + + DELETE /certificates/1 HTTP/1.1 + Host: example.com + + **Example response**: + + .. sourcecode:: http + + HTTP/1.1 204 OK + + :reqheader Authorization: OAuth token to authenticate + :statuscode 204: no error + :statuscode 403: unauthenticated + :statuscode 404: certificate not found + :statuscode 405: certificate deletion is disabled + + """ + if not current_app.config.get("ALLOW_CERT_DELETION", False): + return dict(message="Certificate deletion is disabled"), 405 + + cert = service.get(certificate_id) + + if not cert: + return dict(message="Cannot find specified certificate"), 404 + + if cert.deleted: + return dict(message="Certificate is already deleted"), 412 + + # allow creators + if g.current_user != cert.user: + owner_role = role_service.get_by_name(cert.owner) + permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) + + if not permission.can(): + return ( + dict(message="You are not authorized to delete this certificate"), + 403, + ) + + service.update(certificate_id, deleted=True) + log_service.create(g.current_user, "delete_cert", certificate=cert) + return "Certificate deleted", 204 + class NotificationCertificatesList(AuthenticatedResource): """ Defines the 'certificates' endpoint """ @@ -758,17 +1038,19 @@ class NotificationCertificatesList(AuthenticatedResource): """ parser = paginated_parser.copy() - parser.add_argument('timeRange', type=int, dest='time_range', location='args') - parser.add_argument('owner', type=inputs.boolean, location='args') - parser.add_argument('id', type=str, location='args') - parser.add_argument('active', type=inputs.boolean, location='args') - parser.add_argument('destinationId', type=int, dest="destination_id", location='args') - parser.add_argument('creator', type=str, location='args') - parser.add_argument('show', type=str, location='args') + parser.add_argument("timeRange", type=int, dest="time_range", location="args") + parser.add_argument("owner", type=inputs.boolean, location="args") + parser.add_argument("id", type=str, location="args") + parser.add_argument("active", type=inputs.boolean, location="args") + parser.add_argument( + "destinationId", type=int, dest="destination_id", location="args" + ) + parser.add_argument("creator", type=str, location="args") + parser.add_argument("show", type=str, location="args") args = parser.parse_args() - args['notification_id'] = notification_id - args['user'] = g.current_user + args["notification_id"] = notification_id + args["user"] = g.current_user return service.render(args) @@ -941,30 +1223,48 @@ class CertificateExport(AuthenticatedResource): if not cert: return dict(message="Cannot find specified certificate"), 404 - plugin = data['plugin']['plugin_object'] + plugin = data["plugin"]["plugin_object"] if plugin.requires_key: if not cert.private_key: - return dict( - message='Unable to export certificate, plugin: {0} requires a private key but no key was found.'.format( - plugin.slug)), 400 + return ( + dict( + message="Unable to export certificate, plugin: {0} requires a private key but no key was found.".format( + plugin.slug + ) + ), + 400, + ) else: # allow creators if g.current_user != cert.user: owner_role = role_service.get_by_name(cert.owner) - permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) + permission = CertificatePermission( + owner_role, [x.name for x in cert.roles] + ) if not permission.can(): - return dict(message='You are not authorized to export this certificate.'), 403 + return ( + dict( + message="You are not authorized to export this certificate." + ), + 403, + ) - options = data['plugin']['plugin_options'] + options = data["plugin"]["plugin_options"] - log_service.create(g.current_user, 'key_view', certificate=cert) - extension, passphrase, data = plugin.export(cert.body, cert.chain, cert.private_key, options) + log_service.create(g.current_user, "key_view", certificate=cert) + extension, passphrase, data = plugin.export( + cert.body, cert.chain, cert.private_key, options + ) # we take a hit in message size when b64 encoding - return dict(extension=extension, passphrase=passphrase, data=base64.b64encode(data).decode('utf-8')) + return dict( + extension=extension, + passphrase=passphrase, + data=base64.b64encode(data).decode("utf-8"), + ) class CertificateRevoke(AuthenticatedResource): @@ -1015,28 +1315,66 @@ class CertificateRevoke(AuthenticatedResource): permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) if not permission.can(): - return dict(message='You are not authorized to revoke this certificate.'), 403 + return ( + dict(message="You are not authorized to revoke this certificate."), + 403, + ) if not cert.external_id: - return dict(message='Cannot revoke certificate. No external id found.'), 400 + return dict(message="Cannot revoke certificate. No external id found."), 400 if cert.endpoints: - return dict(message='Cannot revoke certificate. Endpoints are deployed with the given certificate.'), 403 + return ( + dict( + message="Cannot revoke certificate. Endpoints are deployed with the given certificate." + ), + 403, + ) plugin = plugins.get(cert.authority.plugin_name) plugin.revoke_certificate(cert, data) - log_service.create(g.current_user, 'revoke_cert', certificate=cert) + log_service.create(g.current_user, "revoke_cert", certificate=cert) return dict(id=cert.id) -api.add_resource(CertificateRevoke, '/certificates//revoke', endpoint='revokeCertificate') -api.add_resource(CertificatesList, '/certificates', endpoint='certificates') -api.add_resource(Certificates, '/certificates/', endpoint='certificate') -api.add_resource(CertificatesStats, '/certificates/stats', endpoint='certificateStats') -api.add_resource(CertificatesUpload, '/certificates/upload', endpoint='certificateUpload') -api.add_resource(CertificatePrivateKey, '/certificates//key', endpoint='privateKeyCertificates') -api.add_resource(CertificateExport, '/certificates//export', endpoint='exportCertificate') -api.add_resource(NotificationCertificatesList, '/notifications//certificates', - endpoint='notificationCertificates') -api.add_resource(CertificatesReplacementsList, '/certificates//replacements', - endpoint='replacements') +api.add_resource( + CertificateRevoke, + "/certificates//revoke", + endpoint="revokeCertificate", +) +api.add_resource( + CertificatesNameQuery, + "/certificates/name/", + endpoint="certificatesNameQuery", +) +api.add_resource(CertificatesList, "/certificates", endpoint="certificates") +api.add_resource( + CertificatesListValid, "/certificates/valid", endpoint="certificatesListValid" +) +api.add_resource( + Certificates, "/certificates/", endpoint="certificate" +) +api.add_resource(CertificatesStats, "/certificates/stats", endpoint="certificateStats") +api.add_resource( + CertificatesUpload, "/certificates/upload", endpoint="certificateUpload" +) +api.add_resource( + CertificatePrivateKey, + "/certificates//key", + endpoint="privateKeyCertificates", +) +api.add_resource( + CertificateExport, + "/certificates//export", + endpoint="exportCertificate", +) +api.add_resource( + NotificationCertificatesList, + "/notifications//certificates", + endpoint="notificationCertificates", +) +api.add_resource( + CertificatesReplacementsList, + "/certificates//replacements", + endpoint="replacements", +) diff --git a/lemur/common/celery.py b/lemur/common/celery.py index f2a2f826..4af33d86 100644 --- a/lemur/common/celery.py +++ b/lemur/common/celery.py @@ -9,27 +9,43 @@ command: celery -A lemur.common.celery worker --loglevel=info -l DEBUG -B """ import copy import sys +import time from datetime import datetime, timezone, timedelta from celery import Celery +from celery.exceptions import SoftTimeLimitExceeded from flask import current_app from lemur.authorities.service import get as get_authority +from lemur.common.redis import RedisHandler +from lemur.destinations import service as destinations_service +from lemur.extensions import metrics, sentry from lemur.factory import create_app from lemur.notifications.messaging import send_pending_failure_notification from lemur.pending_certificates import service as pending_certificate_service from lemur.plugins.base import plugins from lemur.sources.cli import clean, sync, validate_sources +from lemur.sources.service import add_aws_destination_to_sources +from lemur.certificates import cli as cli_certificate +from lemur.dns_providers import cli as cli_dns_providers +from lemur.notifications import cli as cli_notification +from lemur.endpoints import cli as cli_endpoints + if current_app: flask_app = current_app else: flask_app = create_app() +red = RedisHandler().redis() + def make_celery(app): - celery = Celery(app.import_name, backend=app.config.get('CELERY_RESULT_BACKEND'), - broker=app.config.get('CELERY_BROKER_URL')) + celery = Celery( + app.import_name, + backend=app.config.get("CELERY_RESULT_BACKEND"), + broker=app.config.get("CELERY_BROKER_URL"), + ) celery.conf.update(app.config) TaskBase = celery.Task @@ -47,7 +63,63 @@ def make_celery(app): celery = make_celery(flask_app) +def is_task_active(fun, task_id, args): + from celery.task.control import inspect + + if not args: + args = '()' # empty args + + i = inspect() + active_tasks = i.active() + for _, tasks in active_tasks.items(): + for task in tasks: + if task.get("id") == task_id: + continue + if task.get("name") == fun and task.get("args") == str(args): + return True + return False + + @celery.task() +def report_celery_last_success_metrics(): + """ + For each celery task, this will determine the number of seconds since it has last been successful. + + Celery tasks should be emitting redis stats with a deterministic key (In our case, `f"{task}.last_success"`. + report_celery_last_success_metrics should be ran periodically to emit metrics on when a task was last successful. + Admins can then alert when tasks are not ran when intended. Admins should also alert when no metrics are emitted + from this function. + + """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "recurrent task", + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_time = int(time.time()) + schedule = current_app.config.get('CELERYBEAT_SCHEDULE') + for _, t in schedule.items(): + task = t.get("task") + last_success = int(red.get(f"{task}.last_success") or 0) + metrics.send(f"{task}.time_since_last_success", 'gauge', current_time - last_success) + red.set( + f"{function}.last_success", int(time.time()) + ) # Alert if this metric is not seen + metrics.send(f"{function}.success", 'counter', 1) + + +@celery.task(soft_time_limit=600) def fetch_acme_cert(id): """ Attempt to get the full certificate for the pending certificate listed. @@ -55,11 +127,25 @@ def fetch_acme_cert(id): Args: id: an id of a PendingCertificate """ + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + function = f"{__name__}.{sys._getframe().f_code.co_name}" log_data = { - "function": "{}.{}".format(__name__, sys._getframe().f_code.co_name), - "message": "Resolving pending certificate {}".format(id) + "function": function, + "message": "Resolving pending certificate {}".format(id), + "task_id": task_id, + "id": id, } + current_app.logger.debug(log_data) + + if task_id and is_task_active(log_data["function"], task_id, (id,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + pending_certs = pending_certificate_service.get_pending_certs([id]) new = 0 failed = 0 @@ -69,7 +155,7 @@ def fetch_acme_cert(id): # We only care about certs using the acme-issuer plugin for cert in pending_certs: cert_authority = get_authority(cert.authority_id) - if cert_authority.plugin_name == 'acme-issuer': + if cert_authority.plugin_name == "acme-issuer": acme_certs.append(cert) else: wrong_issuer += 1 @@ -82,20 +168,22 @@ def fetch_acme_cert(id): # It's necessary to reload the pending cert due to detached instance: http://sqlalche.me/e/bhk3 pending_cert = pending_certificate_service.get(cert.get("pending_cert").id) if not pending_cert: - log_data["message"] = "Pending certificate doesn't exist anymore. Was it resolved by another process?" + log_data[ + "message" + ] = "Pending certificate doesn't exist anymore. Was it resolved by another process?" current_app.logger.error(log_data) continue if real_cert: # If a real certificate was returned from issuer, then create it in Lemur and mark # the pending certificate as resolved - final_cert = pending_certificate_service.create_certificate(pending_cert, real_cert, pending_cert.user) - pending_certificate_service.update( - cert.get("pending_cert").id, - resolved=True + final_cert = pending_certificate_service.create_certificate( + pending_cert, real_cert, pending_cert.user ) pending_certificate_service.update( - cert.get("pending_cert").id, - resolved_cert_id=final_cert.id + cert.get("pending_cert").id, resolved_cert_id=final_cert.id + ) + pending_certificate_service.update( + cert.get("pending_cert").id, resolved=True ) # add metrics to metrics extension new += 1 @@ -109,17 +197,17 @@ def fetch_acme_cert(id): if pending_cert.number_attempts > 4: error_log["message"] = "Deleting pending certificate" - send_pending_failure_notification(pending_cert, notify_owner=pending_cert.notify) + send_pending_failure_notification( + pending_cert, notify_owner=pending_cert.notify + ) # Mark the pending cert as resolved pending_certificate_service.update( - cert.get("pending_cert").id, - resolved=True + cert.get("pending_cert").id, resolved=True ) else: pending_certificate_service.increment_attempt(pending_cert) pending_certificate_service.update( - cert.get("pending_cert").id, - status=str(cert.get("last_error")) + cert.get("pending_cert").id, status=str(cert.get("last_error")) ) # Add failed pending cert task back to queue fetch_acme_cert.delay(id) @@ -129,31 +217,44 @@ def fetch_acme_cert(id): log_data["failed"] = failed log_data["wrong_issuer"] = wrong_issuer current_app.logger.debug(log_data) + metrics.send(f"{function}.resolved", 'gauge', new) + metrics.send(f"{function}.failed", 'gauge', failed) + metrics.send(f"{function}.wrong_issuer", 'gauge', wrong_issuer) print( "[+] Certificates: New: {new} Failed: {failed} Not using ACME: {wrong_issuer}".format( - new=new, - failed=failed, - wrong_issuer=wrong_issuer + new=new, failed=failed, wrong_issuer=wrong_issuer ) ) + red.set(f'{function}.last_success', int(time.time())) @celery.task() def fetch_all_pending_acme_certs(): """Instantiate celery workers to resolve all pending Acme certificates""" - pending_certs = pending_certificate_service.get_unresolved_pending_certs() + + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id log_data = { - "function": "{}.{}".format(__name__, sys._getframe().f_code.co_name), - "message": "Starting job." + "function": function, + "message": "Starting job.", + "task_id": task_id, } + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + current_app.logger.debug(log_data) + pending_certs = pending_certificate_service.get_unresolved_pending_certs() # We only care about certs using the acme-issuer plugin for cert in pending_certs: cert_authority = get_authority(cert.authority_id) - if cert_authority.plugin_name == 'acme-issuer': + if cert_authority.plugin_name == "acme-issuer": if datetime.now(timezone.utc) - cert.last_updated > timedelta(minutes=5): log_data["message"] = "Triggering job for cert {}".format(cert.name) log_data["cert_name"] = cert.name @@ -161,23 +262,42 @@ def fetch_all_pending_acme_certs(): current_app.logger.debug(log_data) fetch_acme_cert.delay(cert.id) + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) + @celery.task() def remove_old_acme_certs(): """Prune old pending acme certificates from the database""" + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + log_data = { - "function": "{}.{}".format(__name__, sys._getframe().f_code.co_name) + "function": function, + "message": "Starting job.", + "task_id": task_id, } - pending_certs = pending_certificate_service.get_pending_certs('all') + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + pending_certs = pending_certificate_service.get_pending_certs("all") # Delete pending certs more than a week old for cert in pending_certs: if datetime.now(timezone.utc) - cert.last_updated > timedelta(days=7): - log_data['pending_cert_id'] = cert.id - log_data['pending_cert_name'] = cert.name - log_data['message'] = "Deleting pending certificate" + log_data["pending_cert_id"] = cert.id + log_data["pending_cert_name"] = cert.name + log_data["message"] = "Deleting pending certificate" current_app.logger.debug(log_data) - pending_certificate_service.delete(cert.id) + pending_certificate_service.delete(cert) + + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) @celery.task() @@ -186,13 +306,33 @@ def clean_all_sources(): This function will clean unused certificates from sources. This is a destructive operation and should only be ran periodically. This function triggers one celery task per source. """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "Creating celery task to clean source", + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + sources = validate_sources("all") for source in sources: - current_app.logger.debug("Creating celery task to clean source {}".format(source.label)) + log_data["source"] = source.label + current_app.logger.debug(log_data) clean_source.delay(source.label) + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) -@celery.task() + +@celery.task(soft_time_limit=600) def clean_source(source): """ This celery task will clean the specified source. This is a destructive operation that will delete unused @@ -201,8 +341,31 @@ def clean_source(source): :param source: :return: """ - current_app.logger.debug("Cleaning source {}".format(source)) - clean([source], True) + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "Cleaning source", + "source": source, + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, (source,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_app.logger.debug(log_data) + try: + clean([source], True) + except SoftTimeLimitExceeded: + log_data["message"] = "Clean source: Time limit exceeded." + current_app.logger.error(log_data) + sentry.captureException() + metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) @celery.task() @@ -210,13 +373,33 @@ def sync_all_sources(): """ This function will sync certificates from all sources. This function triggers one celery task per source. """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "creating celery task to sync source", + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + sources = validate_sources("all") for source in sources: - current_app.logger.debug("Creating celery task to sync source {}".format(source.label)) + log_data["source"] = source.label + current_app.logger.debug(log_data) sync_source.delay(source.label) + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) -@celery.task() + +@celery.task(soft_time_limit=7200) def sync_source(source): """ This celery task will sync the specified source. @@ -224,5 +407,296 @@ def sync_source(source): :param source: :return: """ - current_app.logger.debug("Syncing source {}".format(source)) - sync([source]) + + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "Syncing source", + "source": source, + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, (source,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_app.logger.debug(log_data) + try: + sync([source]) + metrics.send(f"{function}.success", 'counter', 1, metric_tags={"source": source}) + except SoftTimeLimitExceeded: + log_data["message"] = "Error syncing source: Time limit exceeded." + current_app.logger.error(log_data) + sentry.captureException() + metrics.send("sync_source_timeout", "counter", 1, metric_tags={"source": source}) + metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) + return + + log_data["message"] = "Done syncing source" + current_app.logger.debug(log_data) + metrics.send(f"{function}.success", 'counter', 1, metric_tags={"source": source}) + red.set(f'{function}.last_success', int(time.time())) + + +@celery.task() +def sync_source_destination(): + """ + This celery task will sync destination and source, to make sure all new destinations are also present as source. + Some destinations do not qualify as sources, and hence should be excluded from being added as sources + We identify qualified destinations based on the sync_as_source attributed of the plugin. + The destination sync_as_source_name reveals the name of the suitable source-plugin. + We rely on account numbers to avoid duplicates. + """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "syncing AWS destinations and sources", + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_app.logger.debug(log_data) + for dst in destinations_service.get_all(): + if add_aws_destination_to_sources(dst): + log_data["message"] = "new source added" + log_data["source"] = dst.label + current_app.logger.debug(log_data) + + log_data["message"] = "completed Syncing AWS destinations and sources" + current_app.logger.debug(log_data) + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) + + +@celery.task(soft_time_limit=3600) +def certificate_reissue(): + """ + This celery task reissues certificates which are pending reissue + :return: + """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "reissuing certificates", + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_app.logger.debug(log_data) + try: + cli_certificate.reissue(None, True) + except SoftTimeLimitExceeded: + log_data["message"] = "Certificate reissue: Time limit exceeded." + current_app.logger.error(log_data) + sentry.captureException() + metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) + return + + log_data["message"] = "reissuance completed" + current_app.logger.debug(log_data) + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) + + +@celery.task(soft_time_limit=3600) +def certificate_rotate(): + """ + This celery task rotates certificates which are reissued but having endpoints attached to the replaced cert + :return: + """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "rotating certificates", + "task_id": task_id, + + } + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_app.logger.debug(log_data) + try: + cli_certificate.rotate(None, None, None, None, True) + except SoftTimeLimitExceeded: + log_data["message"] = "Certificate rotate: Time limit exceeded." + current_app.logger.error(log_data) + sentry.captureException() + metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) + return + + log_data["message"] = "rotation completed" + current_app.logger.debug(log_data) + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) + + +@celery.task(soft_time_limit=3600) +def endpoints_expire(): + """ + This celery task removes all endpoints that have not been recently updated + :return: + """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "endpoints expire", + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_app.logger.debug(log_data) + try: + cli_endpoints.expire(2) # Time in hours + except SoftTimeLimitExceeded: + log_data["message"] = "endpoint expire: Time limit exceeded." + current_app.logger.error(log_data) + sentry.captureException() + metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) + return + + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) + + +@celery.task(soft_time_limit=600) +def get_all_zones(): + """ + This celery syncs all zones from the available dns providers + :return: + """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "refresh all zones from available DNS providers", + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_app.logger.debug(log_data) + try: + cli_dns_providers.get_all_zones() + except SoftTimeLimitExceeded: + log_data["message"] = "get all zones: Time limit exceeded." + current_app.logger.error(log_data) + sentry.captureException() + metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) + return + + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) + + +@celery.task(soft_time_limit=3600) +def check_revoked(): + """ + This celery task attempts to check if any certs are expired + :return: + """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "check if any certificates are revoked revoked", + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_app.logger.debug(log_data) + try: + cli_certificate.check_revoked() + except SoftTimeLimitExceeded: + log_data["message"] = "Checking revoked: Time limit exceeded." + current_app.logger.error(log_data) + sentry.captureException() + metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) + return + + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) + + +@celery.task(soft_time_limit=3600) +def notify_expirations(): + """ + This celery task notifies about expiring certs + :return: + """ + function = f"{__name__}.{sys._getframe().f_code.co_name}" + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + log_data = { + "function": function, + "message": "notify for cert expiration", + "task_id": task_id, + } + + if task_id and is_task_active(function, task_id, None): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_app.logger.debug(log_data) + try: + cli_notification.expirations(current_app.config.get("EXCLUDE_CN_FROM_NOTIFICATION", [])) + except SoftTimeLimitExceeded: + log_data["message"] = "Notify expiring Time limit exceeded." + current_app.logger.error(log_data) + sentry.captureException() + metrics.send("celery.timeout", "counter", 1, metric_tags={"function": function}) + return + + red.set(f'{function}.last_success', int(time.time())) + metrics.send(f"{function}.success", 'counter', 1) diff --git a/lemur/common/defaults.py b/lemur/common/defaults.py index e9bbc6e6..d563dbd0 100644 --- a/lemur/common/defaults.py +++ b/lemur/common/defaults.py @@ -3,22 +3,29 @@ import unicodedata from cryptography import x509 from flask import current_app + +from lemur.common.utils import is_selfsigned from lemur.extensions import sentry from lemur.constants import SAN_NAMING_TEMPLATE, DEFAULT_NAMING_TEMPLATE -def text_to_slug(value): - """Normalize a string to a "slug" value, stripping character accents and removing non-alphanum characters.""" +def text_to_slug(value, joiner="-"): + """ + Normalize a string to a "slug" value, stripping character accents and removing non-alphanum characters. + A series of non-alphanumeric characters is replaced with the joiner character. + """ # Strip all character accents: decompose Unicode characters and then drop combining chars. - value = ''.join(c for c in unicodedata.normalize('NFKD', value) if not unicodedata.combining(c)) + value = "".join( + c for c in unicodedata.normalize("NFKD", value) if not unicodedata.combining(c) + ) - # Replace all remaining non-alphanumeric characters with '-'. Multiple characters get collapsed into a single dash. - # Except, keep 'xn--' used in IDNA domain names as is. - value = re.sub(r'[^A-Za-z0-9.]+(?' is returned. + If issuer cannot be determined, '' is returned. + + :param cert: Parsed certificate object + :return: Issuer slug """ - delchars = ''.join(c for c in map(chr, range(256)) if not c.isalnum()) - try: - # Try organization name or fall back to CN - issuer = (cert.issuer.get_attributes_for_oid(x509.OID_COMMON_NAME) or - cert.issuer.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME)) - issuer = str(issuer[0].value) - for c in delchars: - issuer = issuer.replace(c, "") - return issuer - except Exception as e: - sentry.captureException() - current_app.logger.error("Unable to get issuer! {0}".format(e)) - return "Unknown" + # If certificate is self-signed, we return a special value -- there really is no distinct "issuer" for it + if is_selfsigned(cert): + return "" + + # Try Common Name or fall back to Organization name + attrs = cert.issuer.get_attributes_for_oid( + x509.OID_COMMON_NAME + ) or cert.issuer.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME) + if not attrs: + current_app.logger.error( + "Unable to get issuer! Cert serial {:x}".format(cert.serial_number) + ) + return "" + + return text_to_slug(attrs[0].value, "") def not_before(cert): diff --git a/lemur/common/fields.py b/lemur/common/fields.py index 5ab0c6f0..15631832 100644 --- a/lemur/common/fields.py +++ b/lemur/common/fields.py @@ -25,6 +25,7 @@ class Hex(Field): """ A hex formatted string. """ + def _serialize(self, value, attr, obj): if value: value = hex(int(value))[2:].upper() @@ -48,25 +49,25 @@ class ArrowDateTime(Field): """ DATEFORMAT_SERIALIZATION_FUNCS = { - 'iso': utils.isoformat, - 'iso8601': utils.isoformat, - 'rfc': utils.rfcformat, - 'rfc822': utils.rfcformat, + "iso": utils.isoformat, + "iso8601": utils.isoformat, + "rfc": utils.rfcformat, + "rfc822": utils.rfcformat, } DATEFORMAT_DESERIALIZATION_FUNCS = { - 'iso': utils.from_iso, - 'iso8601': utils.from_iso, - 'rfc': utils.from_rfc, - 'rfc822': utils.from_rfc, + "iso": utils.from_iso, + "iso8601": utils.from_iso, + "rfc": utils.from_rfc, + "rfc822": utils.from_rfc, } - DEFAULT_FORMAT = 'iso' + DEFAULT_FORMAT = "iso" localtime = False default_error_messages = { - 'invalid': 'Not a valid datetime.', - 'format': '"{input}" cannot be formatted as a datetime.', + "invalid": "Not a valid datetime.", + "format": '"{input}" cannot be formatted as a datetime.', } def __init__(self, format=None, **kwargs): @@ -89,34 +90,36 @@ class ArrowDateTime(Field): try: return format_func(value, localtime=self.localtime) except (AttributeError, ValueError) as err: - self.fail('format', input=value) + self.fail("format", input=value) else: return value.strftime(self.dateformat) def _deserialize(self, value, attr, data): if not value: # Falsy values, e.g. '', None, [] are not valid - raise self.fail('invalid') + raise self.fail("invalid") self.dateformat = self.dateformat or self.DEFAULT_FORMAT func = self.DATEFORMAT_DESERIALIZATION_FUNCS.get(self.dateformat) if func: try: return arrow.get(func(value)) except (TypeError, AttributeError, ValueError): - raise self.fail('invalid') + raise self.fail("invalid") elif self.dateformat: try: return dt.datetime.strptime(value, self.dateformat) except (TypeError, AttributeError, ValueError): - raise self.fail('invalid') + raise self.fail("invalid") elif utils.dateutil_available: try: return arrow.get(utils.from_datestring(value)) except TypeError: - raise self.fail('invalid') + raise self.fail("invalid") else: - warnings.warn('It is recommended that you install python-dateutil ' - 'for improved datetime deserialization.') - raise self.fail('invalid') + warnings.warn( + "It is recommended that you install python-dateutil " + "for improved datetime deserialization." + ) + raise self.fail("invalid") class KeyUsageExtension(Field): @@ -131,73 +134,75 @@ class KeyUsageExtension(Field): def _serialize(self, value, attr, obj): return { - 'useDigitalSignature': value.digital_signature, - 'useNonRepudiation': value.content_commitment, - 'useKeyEncipherment': value.key_encipherment, - 'useDataEncipherment': value.data_encipherment, - 'useKeyAgreement': value.key_agreement, - 'useKeyCertSign': value.key_cert_sign, - 'useCRLSign': value.crl_sign, - 'useEncipherOnly': value._encipher_only, - 'useDecipherOnly': value._decipher_only + "useDigitalSignature": value.digital_signature, + "useNonRepudiation": value.content_commitment, + "useKeyEncipherment": value.key_encipherment, + "useDataEncipherment": value.data_encipherment, + "useKeyAgreement": value.key_agreement, + "useKeyCertSign": value.key_cert_sign, + "useCRLSign": value.crl_sign, + "useEncipherOnly": value._encipher_only, + "useDecipherOnly": value._decipher_only, } def _deserialize(self, value, attr, data): keyusages = { - 'digital_signature': False, - 'content_commitment': False, - 'key_encipherment': False, - 'data_encipherment': False, - 'key_agreement': False, - 'key_cert_sign': False, - 'crl_sign': False, - 'encipher_only': False, - 'decipher_only': False + "digital_signature": False, + "content_commitment": False, + "key_encipherment": False, + "data_encipherment": False, + "key_agreement": False, + "key_cert_sign": False, + "crl_sign": False, + "encipher_only": False, + "decipher_only": False, } for k, v in value.items(): - if k == 'useDigitalSignature': - keyusages['digital_signature'] = v + if k == "useDigitalSignature": + keyusages["digital_signature"] = v - elif k == 'useNonRepudiation': - keyusages['content_commitment'] = v + elif k == "useNonRepudiation": + keyusages["content_commitment"] = v - elif k == 'useKeyEncipherment': - keyusages['key_encipherment'] = v + elif k == "useKeyEncipherment": + keyusages["key_encipherment"] = v - elif k == 'useDataEncipherment': - keyusages['data_encipherment'] = v + elif k == "useDataEncipherment": + keyusages["data_encipherment"] = v - elif k == 'useKeyCertSign': - keyusages['key_cert_sign'] = v + elif k == "useKeyCertSign": + keyusages["key_cert_sign"] = v - elif k == 'useCRLSign': - keyusages['crl_sign'] = v + elif k == "useCRLSign": + keyusages["crl_sign"] = v - elif k == 'useKeyAgreement': - keyusages['key_agreement'] = v + elif k == "useKeyAgreement": + keyusages["key_agreement"] = v - elif k == 'useEncipherOnly' and v: - keyusages['encipher_only'] = True - keyusages['key_agreement'] = True + elif k == "useEncipherOnly" and v: + keyusages["encipher_only"] = True + keyusages["key_agreement"] = True - elif k == 'useDecipherOnly' and v: - keyusages['decipher_only'] = True - keyusages['key_agreement'] = True + elif k == "useDecipherOnly" and v: + keyusages["decipher_only"] = True + keyusages["key_agreement"] = True - if keyusages['encipher_only'] and keyusages['decipher_only']: - raise ValidationError('A certificate cannot have both Encipher Only and Decipher Only Extended Key Usages.') + if keyusages["encipher_only"] and keyusages["decipher_only"]: + raise ValidationError( + "A certificate cannot have both Encipher Only and Decipher Only Extended Key Usages." + ) return x509.KeyUsage( - digital_signature=keyusages['digital_signature'], - content_commitment=keyusages['content_commitment'], - key_encipherment=keyusages['key_encipherment'], - data_encipherment=keyusages['data_encipherment'], - key_agreement=keyusages['key_agreement'], - key_cert_sign=keyusages['key_cert_sign'], - crl_sign=keyusages['crl_sign'], - encipher_only=keyusages['encipher_only'], - decipher_only=keyusages['decipher_only'] + digital_signature=keyusages["digital_signature"], + content_commitment=keyusages["content_commitment"], + key_encipherment=keyusages["key_encipherment"], + data_encipherment=keyusages["data_encipherment"], + key_agreement=keyusages["key_agreement"], + key_cert_sign=keyusages["key_cert_sign"], + crl_sign=keyusages["crl_sign"], + encipher_only=keyusages["encipher_only"], + decipher_only=keyusages["decipher_only"], ) @@ -216,69 +221,77 @@ class ExtendedKeyUsageExtension(Field): usage_list = {} for usage in usages: if usage == x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH: - usage_list['useClientAuthentication'] = True + usage_list["useClientAuthentication"] = True elif usage == x509.oid.ExtendedKeyUsageOID.SERVER_AUTH: - usage_list['useServerAuthentication'] = True + usage_list["useServerAuthentication"] = True elif usage == x509.oid.ExtendedKeyUsageOID.CODE_SIGNING: - usage_list['useCodeSigning'] = True + usage_list["useCodeSigning"] = True elif usage == x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION: - usage_list['useEmailProtection'] = True + usage_list["useEmailProtection"] = True elif usage == x509.oid.ExtendedKeyUsageOID.TIME_STAMPING: - usage_list['useTimestamping'] = True + usage_list["useTimestamping"] = True elif usage == x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING: - usage_list['useOCSPSigning'] = True + usage_list["useOCSPSigning"] = True - elif usage.dotted_string == '1.3.6.1.5.5.7.3.14': - usage_list['useEapOverLAN'] = True + elif usage.dotted_string == "1.3.6.1.5.5.7.3.14": + usage_list["useEapOverLAN"] = True - elif usage.dotted_string == '1.3.6.1.5.5.7.3.13': - usage_list['useEapOverPPP'] = True + elif usage.dotted_string == "1.3.6.1.5.5.7.3.13": + usage_list["useEapOverPPP"] = True - elif usage.dotted_string == '1.3.6.1.4.1.311.20.2.2': - usage_list['useSmartCardLogon'] = True + elif usage.dotted_string == "1.3.6.1.4.1.311.20.2.2": + usage_list["useSmartCardLogon"] = True else: - current_app.logger.warning('Unable to serialize ExtendedKeyUsage with OID: {usage}'.format(usage=usage.dotted_string)) + current_app.logger.warning( + "Unable to serialize ExtendedKeyUsage with OID: {usage}".format( + usage=usage.dotted_string + ) + ) return usage_list def _deserialize(self, value, attr, data): usage_oids = [] for k, v in value.items(): - if k == 'useClientAuthentication' and v: + if k == "useClientAuthentication" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH) - elif k == 'useServerAuthentication' and v: + elif k == "useServerAuthentication" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.SERVER_AUTH) - elif k == 'useCodeSigning' and v: + elif k == "useCodeSigning" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.CODE_SIGNING) - elif k == 'useEmailProtection' and v: + elif k == "useEmailProtection" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION) - elif k == 'useTimestamping' and v: + elif k == "useTimestamping" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.TIME_STAMPING) - elif k == 'useOCSPSigning' and v: + elif k == "useOCSPSigning" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING) - elif k == 'useEapOverLAN' and v: + elif k == "useEapOverLAN" and v: usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.14")) - elif k == 'useEapOverPPP' and v: + elif k == "useEapOverPPP" and v: usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.13")) - elif k == 'useSmartCardLogon' and v: + elif k == "useSmartCardLogon" and v: usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.4.1.311.20.2.2")) else: - current_app.logger.warning('Unable to deserialize ExtendedKeyUsage with name: {key}'.format(key=k)) + current_app.logger.warning( + "Unable to deserialize ExtendedKeyUsage with name: {key}".format( + key=k + ) + ) return x509.ExtendedKeyUsage(usage_oids) @@ -294,15 +307,17 @@ class BasicConstraintsExtension(Field): """ def _serialize(self, value, attr, obj): - return {'ca': value.ca, 'path_length': value.path_length} + return {"ca": value.ca, "path_length": value.path_length} def _deserialize(self, value, attr, data): - ca = value.get('ca', False) - path_length = value.get('path_length', None) + ca = value.get("ca", False) + path_length = value.get("path_length", None) if ca: if not isinstance(path_length, (type(None), int)): - raise ValidationError('A CA certificate path_length (for BasicConstraints) must be None or an integer.') + raise ValidationError( + "A CA certificate path_length (for BasicConstraints) must be None or an integer." + ) return x509.BasicConstraints(ca=True, path_length=path_length) else: return x509.BasicConstraints(ca=False, path_length=None) @@ -317,6 +332,7 @@ class SubjectAlternativeNameExtension(Field): :param kwargs: The same keyword arguments that :class:`Field` receives. """ + def _serialize(self, value, attr, obj): general_names = [] name_type = None @@ -326,53 +342,59 @@ class SubjectAlternativeNameExtension(Field): value = name.value if isinstance(name, x509.DNSName): - name_type = 'DNSName' + name_type = "DNSName" elif isinstance(name, x509.IPAddress): if isinstance(value, ipaddress.IPv4Network): - name_type = 'IPNetwork' + name_type = "IPNetwork" else: - name_type = 'IPAddress' + name_type = "IPAddress" value = str(value) elif isinstance(name, x509.UniformResourceIdentifier): - name_type = 'uniformResourceIdentifier' + name_type = "uniformResourceIdentifier" elif isinstance(name, x509.DirectoryName): - name_type = 'directoryName' + name_type = "directoryName" elif isinstance(name, x509.RFC822Name): - name_type = 'rfc822Name' + name_type = "rfc822Name" elif isinstance(name, x509.RegisteredID): - name_type = 'registeredID' + name_type = "registeredID" value = value.dotted_string else: - current_app.logger.warning('Unknown SubAltName type: {name}'.format(name=name)) + current_app.logger.warning( + "Unknown SubAltName type: {name}".format(name=name) + ) continue - general_names.append({'nameType': name_type, 'value': value}) + general_names.append({"nameType": name_type, "value": value}) return general_names def _deserialize(self, value, attr, data): general_names = [] for name in value: - if name['nameType'] == 'DNSName': - validators.sensitive_domain(name['value']) - general_names.append(x509.DNSName(name['value'])) + if name["nameType"] == "DNSName": + validators.sensitive_domain(name["value"]) + general_names.append(x509.DNSName(name["value"])) - elif name['nameType'] == 'IPAddress': - general_names.append(x509.IPAddress(ipaddress.ip_address(name['value']))) + elif name["nameType"] == "IPAddress": + general_names.append( + x509.IPAddress(ipaddress.ip_address(name["value"])) + ) - elif name['nameType'] == 'IPNetwork': - general_names.append(x509.IPAddress(ipaddress.ip_network(name['value']))) + elif name["nameType"] == "IPNetwork": + general_names.append( + x509.IPAddress(ipaddress.ip_network(name["value"])) + ) - elif name['nameType'] == 'uniformResourceIdentifier': - general_names.append(x509.UniformResourceIdentifier(name['value'])) + elif name["nameType"] == "uniformResourceIdentifier": + general_names.append(x509.UniformResourceIdentifier(name["value"])) - elif name['nameType'] == 'directoryName': + elif name["nameType"] == "directoryName": # TODO: Need to parse a string in name['value'] like: # 'CN=Common Name, O=Org Name, OU=OrgUnit Name, C=US, ST=ST, L=City/emailAddress=person@example.com' # or @@ -390,26 +412,32 @@ class SubjectAlternativeNameExtension(Field): # general_names.append(x509.DirectoryName(x509.Name(BLAH)))) pass - elif name['nameType'] == 'rfc822Name': - general_names.append(x509.RFC822Name(name['value'])) + elif name["nameType"] == "rfc822Name": + general_names.append(x509.RFC822Name(name["value"])) - elif name['nameType'] == 'registeredID': - general_names.append(x509.RegisteredID(x509.ObjectIdentifier(name['value']))) + elif name["nameType"] == "registeredID": + general_names.append( + x509.RegisteredID(x509.ObjectIdentifier(name["value"])) + ) - elif name['nameType'] == 'otherName': + elif name["nameType"] == "otherName": # This has two inputs (type and value), so it doesn't fit the mold of the rest of these GeneralName entities. # general_names.append(x509.OtherName(name['type'], bytes(name['value']), 'utf-8')) pass - elif name['nameType'] == 'x400Address': + elif name["nameType"] == "x400Address": # The Python Cryptography library doesn't support x400Address types (yet?) pass - elif name['nameType'] == 'EDIPartyName': + elif name["nameType"] == "EDIPartyName": # The Python Cryptography library doesn't support EDIPartyName types (yet?) pass else: - current_app.logger.warning('Unable to deserialize SubAltName with type: {name_type}'.format(name_type=name['nameType'])) + current_app.logger.warning( + "Unable to deserialize SubAltName with type: {name_type}".format( + name_type=name["nameType"] + ) + ) return x509.SubjectAlternativeName(general_names) diff --git a/lemur/common/health.py b/lemur/common/health.py index 69df3f0c..7e0a17ff 100644 --- a/lemur/common/health.py +++ b/lemur/common/health.py @@ -10,20 +10,20 @@ from flask import Blueprint from lemur.database import db from lemur.extensions import sentry -mod = Blueprint('healthCheck', __name__) +mod = Blueprint("healthCheck", __name__) -@mod.route('/healthcheck') +@mod.route("/healthcheck") def health(): try: if healthcheck(db): - return 'ok' + return "ok" except Exception: sentry.captureException() - return 'db check failed' + return "db check failed" def healthcheck(db): with db.engine.connect() as connection: - connection.execute('SELECT 1;') + connection.execute("SELECT 1;") return True diff --git a/lemur/common/managers.py b/lemur/common/managers.py index 9f30f216..6ce2608f 100644 --- a/lemur/common/managers.py +++ b/lemur/common/managers.py @@ -52,7 +52,7 @@ class InstanceManager(object): results = [] for cls_path in class_list: - module_name, class_name = cls_path.rsplit('.', 1) + module_name, class_name = cls_path.rsplit(".", 1) try: module = __import__(module_name, {}, {}, class_name) cls = getattr(module, class_name) @@ -62,10 +62,14 @@ class InstanceManager(object): results.append(cls) except InvalidConfiguration as e: - current_app.logger.warning("Plugin '{0}' may not work correctly. {1}".format(class_name, e)) + current_app.logger.warning( + "Plugin '{0}' may not work correctly. {1}".format(class_name, e) + ) except Exception as e: - current_app.logger.exception("Unable to import {0}. Reason: {1}".format(cls_path, e)) + current_app.logger.exception( + "Unable to import {0}. Reason: {1}".format(cls_path, e) + ) continue self.cache = results diff --git a/lemur/common/missing.py b/lemur/common/missing.py index a4bbba77..f991d2e3 100644 --- a/lemur/common/missing.py +++ b/lemur/common/missing.py @@ -11,14 +11,15 @@ def convert_validity_years(data): :param data: :return: """ - if data.get('validity_years'): + if data.get("validity_years"): now = arrow.utcnow() - data['validity_start'] = now.isoformat() + data["validity_start"] = now.isoformat() - end = now.replace(years=+int(data['validity_years'])) - if not current_app.config.get('LEMUR_ALLOW_WEEKEND_EXPIRATION', True): + end = now.shift(years=+int(data["validity_years"])) + + if not current_app.config.get("LEMUR_ALLOW_WEEKEND_EXPIRATION", True): if is_weekend(end): - end = end.replace(days=-2) + end = end.shift(days=-2) - data['validity_end'] = end.isoformat() + data["validity_end"] = end.isoformat() return data diff --git a/lemur/common/redis.py b/lemur/common/redis.py new file mode 100644 index 00000000..ca15734f --- /dev/null +++ b/lemur/common/redis.py @@ -0,0 +1,52 @@ +""" +Helper Class for Redis + +""" +import redis +import sys +from flask import current_app +from lemur.extensions import sentry +from lemur.factory import create_app + +if current_app: + flask_app = current_app +else: + flask_app = create_app() + + +class RedisHandler: + def __init__(self, host=flask_app.config.get('REDIS_HOST', 'localhost'), + port=flask_app.config.get('REDIS_PORT', 6379), + db=flask_app.config.get('REDIS_DB', 0)): + self.host = host + self.port = port + self.db = db + + def redis(self, db=0): + # The decode_responses flag here directs the client to convert the responses from Redis into Python strings + # using the default encoding utf-8. This is client specific. + function = f"{__name__}.{sys._getframe().f_code.co_name}" + try: + red = redis.StrictRedis(host=self.host, port=self.port, db=self.db, encoding="utf-8", decode_responses=True) + red.set("test", 0) + except redis.ConnectionError: + log_data = { + "function": function, + "message": "Redis Connection error", + "host": self.host, + "port": self.port + } + current_app.logger.error(log_data) + sentry.captureException() + return red + + +def redis_get(key, default=None): + red = RedisHandler().redis() + try: + v = red.get(key) + except redis.exceptions.ConnectionError: + v = None + if not v: + return default + return v diff --git a/lemur/common/schema.py b/lemur/common/schema.py index ee765dc4..ee1db464 100644 --- a/lemur/common/schema.py +++ b/lemur/common/schema.py @@ -22,27 +22,26 @@ class LemurSchema(Schema): """ Base schema from which all grouper schema's inherit """ + __envelope__ = True def under(self, data, many=None): items = [] if many: for i in data: - items.append( - {underscore(key): value for key, value in i.items()} - ) + items.append({underscore(key): value for key, value in i.items()}) return items - return { - underscore(key): value - for key, value in data.items() - } + return {underscore(key): value for key, value in data.items()} def camel(self, data, many=None): items = [] if many: for i in data: items.append( - {camelize(key, uppercase_first_letter=False): value for key, value in i.items()} + { + camelize(key, uppercase_first_letter=False): value + for key, value in i.items() + } ) return items return { @@ -52,16 +51,16 @@ class LemurSchema(Schema): def wrap_with_envelope(self, data, many): if many: - if 'total' in self.context.keys(): - return dict(total=self.context['total'], items=data) + if "total" in self.context.keys(): + return dict(total=self.context["total"], items=data) return data class LemurInputSchema(LemurSchema): @pre_load(pass_many=True) def preprocess(self, data, many): - if isinstance(data, dict) and data.get('owner'): - data['owner'] = data['owner'].lower() + if isinstance(data, dict) and data.get("owner"): + data["owner"] = data["owner"].lower() return self.under(data, many=many) @@ -74,17 +73,17 @@ class LemurOutputSchema(LemurSchema): def unwrap_envelope(self, data, many): if many: - if data['items']: + if data["items"]: if isinstance(data, InstrumentedList) or isinstance(data, list): - self.context['total'] = len(data) + self.context["total"] = len(data) return data else: - self.context['total'] = data['total'] + self.context["total"] = data["total"] else: - self.context['total'] = 0 - data = {'items': []} + self.context["total"] = 0 + data = {"items": []} - return data['items'] + return data["items"] return data @@ -110,11 +109,11 @@ def format_errors(messages): def wrap_errors(messages): - errors = dict(message='Validation Error.') - if messages.get('_schema'): - errors['reasons'] = {'Schema': {'rule': messages['_schema']}} + errors = dict(message="Validation Error.") + if messages.get("_schema"): + errors["reasons"] = {"Schema": {"rule": messages["_schema"]}} else: - errors['reasons'] = format_errors(messages) + errors["reasons"] = format_errors(messages) return errors @@ -123,19 +122,19 @@ def unwrap_pagination(data, output_schema): return data if isinstance(data, dict): - if 'total' in data.keys(): - if data.get('total') == 0: + if "total" in data.keys(): + if data.get("total") == 0: return data - marshaled_data = {'total': data['total']} - marshaled_data['items'] = output_schema.dump(data['items'], many=True).data + marshaled_data = {"total": data["total"]} + marshaled_data["items"] = output_schema.dump(data["items"], many=True).data return marshaled_data return output_schema.dump(data).data elif isinstance(data, list): - marshaled_data = {'total': len(data)} - marshaled_data['items'] = output_schema.dump(data, many=True).data + marshaled_data = {"total": len(data)} + marshaled_data["items"] = output_schema.dump(data, many=True).data return marshaled_data return output_schema.dump(data).data @@ -155,7 +154,7 @@ def validate_schema(input_schema, output_schema): if errors: return wrap_errors(errors), 400 - kwargs['data'] = data + kwargs["data"] = data try: resp = f(*args, **kwargs) @@ -170,7 +169,13 @@ def validate_schema(input_schema, output_schema): if not resp: return dict(message="No data found"), 404 - return unwrap_pagination(resp, output_schema), 200 + if callable(output_schema): + output_schema_to_use = output_schema() + else: + output_schema_to_use = output_schema + + return unwrap_pagination(resp, output_schema_to_use), 200 return decorated_function + return decorator diff --git a/lemur/common/utils.py b/lemur/common/utils.py index 7ea9d7f2..c33722b2 100644 --- a/lemur/common/utils.py +++ b/lemur/common/utils.py @@ -7,12 +7,16 @@ .. moduleauthor:: Kevin Glisson """ import random +import re import string import sqlalchemy from cryptography import x509 +from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.asymmetric import rsa, ec +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa, ec, padding +from cryptography.hazmat.primitives.serialization import load_pem_private_key from flask_restful.reqparse import RequestParser from sqlalchemy import and_, func @@ -21,21 +25,22 @@ from lemur.exceptions import InvalidConfiguration paginated_parser = RequestParser() -paginated_parser.add_argument('count', type=int, default=10, location='args') -paginated_parser.add_argument('page', type=int, default=1, location='args') -paginated_parser.add_argument('sortDir', type=str, dest='sort_dir', location='args') -paginated_parser.add_argument('sortBy', type=str, dest='sort_by', location='args') -paginated_parser.add_argument('filter', type=str, location='args') +paginated_parser.add_argument("count", type=int, default=10, location="args") +paginated_parser.add_argument("page", type=int, default=1, location="args") +paginated_parser.add_argument("sortDir", type=str, dest="sort_dir", location="args") +paginated_parser.add_argument("sortBy", type=str, dest="sort_by", location="args") +paginated_parser.add_argument("filter", type=str, location="args") +paginated_parser.add_argument("owner", type=str, location="args") def get_psuedo_random_string(): """ Create a random and strongish challenge. """ - challenge = ''.join(random.choice(string.ascii_uppercase) for x in range(6)) # noqa - challenge += ''.join(random.choice("~!@#$%^&*()_+") for x in range(6)) # noqa - challenge += ''.join(random.choice(string.ascii_lowercase) for x in range(6)) - challenge += ''.join(random.choice(string.digits) for x in range(6)) # noqa + challenge = "".join(random.choice(string.ascii_uppercase) for x in range(6)) # noqa + challenge += "".join(random.choice("~!@#$%^&*()_+") for x in range(6)) # noqa + challenge += "".join(random.choice(string.ascii_lowercase) for x in range(6)) + challenge += "".join(random.choice(string.digits) for x in range(6)) # noqa return challenge @@ -46,10 +51,46 @@ def parse_certificate(body): :param body: :return: """ - if isinstance(body, str): - body = body.encode('utf-8') + assert isinstance(body, str) - return x509.load_pem_x509_certificate(body, default_backend()) + return x509.load_pem_x509_certificate(body.encode("utf-8"), default_backend()) + + +def parse_private_key(private_key): + """ + Parses a PEM-format private key (RSA, DSA, ECDSA or any other supported algorithm). + + Raises ValueError for an invalid string. Raises AssertionError when passed value is not str-type. + + :param private_key: String containing PEM private key + """ + assert isinstance(private_key, str) + + return load_pem_private_key( + private_key.encode("utf8"), password=None, backend=default_backend() + ) + + +def split_pem(data): + """ + Split a string of several PEM payloads to a list of strings. + + :param data: String + :return: List of strings + """ + return re.split("\n(?=-----BEGIN )", data) + + +def parse_cert_chain(pem_chain): + """ + Helper function to split and parse a series of PEM certificates. + + :param pem_chain: string + :return: List of parsed certificates + """ + if pem_chain is None: + return [] + return [parse_certificate(cert) for cert in split_pem(pem_chain) if cert] def parse_csr(csr): @@ -59,17 +100,17 @@ def parse_csr(csr): :param csr: :return: """ - if isinstance(csr, str): - csr = csr.encode('utf-8') + assert isinstance(csr, str) - return x509.load_pem_x509_csr(csr, default_backend()) + return x509.load_pem_x509_csr(csr.encode("utf-8"), default_backend()) def get_authority_key(body): """Returns the authority key for a given certificate in hex format""" parsed_cert = parse_certificate(body) authority_key = parsed_cert.extensions.get_extension_for_class( - x509.AuthorityKeyIdentifier).value.key_identifier + x509.AuthorityKeyIdentifier + ).value.key_identifier return authority_key.hex() @@ -89,20 +130,17 @@ def generate_private_key(key_type): _CURVE_TYPES = { "ECCPRIME192V1": ec.SECP192R1(), "ECCPRIME256V1": ec.SECP256R1(), - "ECCSECP192R1": ec.SECP192R1(), "ECCSECP224R1": ec.SECP224R1(), "ECCSECP256R1": ec.SECP256R1(), "ECCSECP384R1": ec.SECP384R1(), "ECCSECP521R1": ec.SECP521R1(), "ECCSECP256K1": ec.SECP256K1(), - "ECCSECT163K1": ec.SECT163K1(), "ECCSECT233K1": ec.SECT233K1(), "ECCSECT283K1": ec.SECT283K1(), "ECCSECT409K1": ec.SECT409K1(), "ECCSECT571K1": ec.SECT571K1(), - "ECCSECT163R2": ec.SECT163R2(), "ECCSECT233R1": ec.SECT233R1(), "ECCSECT283R1": ec.SECT283R1(), @@ -111,25 +149,74 @@ def generate_private_key(key_type): } if key_type not in CERTIFICATE_KEY_TYPES: - raise Exception("Invalid key type: {key_type}. Supported key types: {choices}".format( - key_type=key_type, - choices=",".join(CERTIFICATE_KEY_TYPES) - )) + raise Exception( + "Invalid key type: {key_type}. Supported key types: {choices}".format( + key_type=key_type, choices=",".join(CERTIFICATE_KEY_TYPES) + ) + ) - if 'RSA' in key_type: + if "RSA" in key_type: key_size = int(key_type[3:]) return rsa.generate_private_key( - public_exponent=65537, - key_size=key_size, - backend=default_backend() + public_exponent=65537, key_size=key_size, backend=default_backend() ) - elif 'ECC' in key_type: + elif "ECC" in key_type: return ec.generate_private_key( - curve=_CURVE_TYPES[key_type], - backend=default_backend() + curve=_CURVE_TYPES[key_type], backend=default_backend() ) +def check_cert_signature(cert, issuer_public_key): + """ + Check a certificate's signature against an issuer public key. + Before EC validation, make sure we support the algorithm, otherwise raise UnsupportedAlgorithm + On success, returns None; on failure, raises UnsupportedAlgorithm or InvalidSignature. + """ + if isinstance(issuer_public_key, rsa.RSAPublicKey): + # RSA requires padding, just to make life difficult for us poor developers :( + if cert.signature_algorithm_oid == x509.SignatureAlgorithmOID.RSASSA_PSS: + # In 2005, IETF devised a more secure padding scheme to replace PKCS #1 v1.5. To make sure that + # nobody can easily support or use it, they mandated lots of complicated parameters, unlike any + # other X.509 signature scheme. + # https://tools.ietf.org/html/rfc4056 + raise UnsupportedAlgorithm("RSASSA-PSS not supported") + else: + padder = padding.PKCS1v15() + issuer_public_key.verify( + cert.signature, + cert.tbs_certificate_bytes, + padder, + cert.signature_hash_algorithm, + ) + elif isinstance(issuer_public_key, ec.EllipticCurvePublicKey) and isinstance( + ec.ECDSA(cert.signature_hash_algorithm), ec.ECDSA + ): + issuer_public_key.verify( + cert.signature, + cert.tbs_certificate_bytes, + ec.ECDSA(cert.signature_hash_algorithm), + ) + else: + raise UnsupportedAlgorithm( + "Unsupported Algorithm '{var}'.".format( + var=cert.signature_algorithm_oid._name + ) + ) + + +def is_selfsigned(cert): + """ + Returns True if the certificate is self-signed. + Returns False for failed verification or unsupported signing algorithm. + """ + try: + check_cert_signature(cert, cert.public_key()) + # If verification was successful, it's self-signed. + return True + except InvalidSignature: + return False + + def is_weekend(date): """ Determines if a given date is on a weekend. @@ -150,7 +237,9 @@ def validate_conf(app, required_vars): """ for var in required_vars: if var not in app.config: - raise InvalidConfiguration("Required variable '{var}' is not set in Lemur's conf.".format(var=var)) + raise InvalidConfiguration( + "Required variable '{var}' is not set in Lemur's conf.".format(var=var) + ) # https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/WindowedRangeQuery @@ -169,18 +258,15 @@ def column_windows(session, column, windowsize): be computed. """ + def int_for_range(start_id, end_id): if end_id: - return and_( - column >= start_id, - column < end_id - ) + return and_(column >= start_id, column < end_id) else: return column >= start_id q = session.query( - column, - func.row_number().over(order_by=column).label('rownum') + column, func.row_number().over(order_by=column).label("rownum") ).from_self(column) if windowsize > 1: @@ -200,9 +286,7 @@ def column_windows(session, column, windowsize): def windowed_query(q, column, windowsize): """"Break a Query into windows on a given column.""" - for whereclause in column_windows( - q.session, - column, windowsize): + for whereclause in column_windows(q.session, column, windowsize): for row in q.filter(whereclause).order_by(column): yield row @@ -210,4 +294,16 @@ def windowed_query(q, column, windowsize): def truthiness(s): """If input string resembles something truthy then return True, else False.""" - return s.lower() in ('true', 'yes', 'on', 't', '1') + return s.lower() in ("true", "yes", "on", "t", "1") + + +def find_matching_certificates_by_hash(cert, matching_certs): + """Given a Cryptography-formatted certificate cert, and Lemur-formatted certificates (matching_certs), + determine if any of the certificate hashes match and return the matches.""" + matching = [] + for c in matching_certs: + if parse_certificate(c.body).fingerprint(hashes.SHA256()) == cert.fingerprint( + hashes.SHA256() + ): + matching.append(c) + return matching diff --git a/lemur/common/validators.py b/lemur/common/validators.py index 47a94a30..2412e2d3 100644 --- a/lemur/common/validators.py +++ b/lemur/common/validators.py @@ -1,45 +1,14 @@ import re from cryptography import x509 +from cryptography.exceptions import UnsupportedAlgorithm, InvalidSignature from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization from cryptography.x509 import NameOID from flask import current_app from marshmallow.exceptions import ValidationError from lemur.auth.permissions import SensitiveDomainPermission -from lemur.common.utils import parse_certificate, is_weekend -from lemur.domains import service as domain_service - - -def public_certificate(body): - """ - Determines if specified string is valid public certificate. - - :param body: - :return: - """ - try: - parse_certificate(body) - except Exception as e: - current_app.logger.exception(e) - raise ValidationError('Public certificate presented is not valid.') - - -def private_key(key): - """ - User to validate that a given string is a RSA private key - - :param key: - :return: :raise ValueError: - """ - try: - if isinstance(key, bytes): - serialization.load_pem_private_key(key, None, backend=default_backend()) - else: - serialization.load_pem_private_key(key.encode('utf-8'), None, backend=default_backend()) - except Exception: - raise ValidationError('Private key presented is not valid.') +from lemur.common.utils import check_cert_signature, is_weekend def common_name(value): @@ -47,7 +16,7 @@ def common_name(value): # Common name could be a domain name, or a human-readable name of the subject (often used in CA names or client # certificates). As a simple heuristic, we assume that human-readable names always include a space. # However, to avoid confusion for humans, we also don't count spaces at the beginning or end of the string. - if ' ' not in value.strip(): + if " " not in value.strip(): return sensitive_domain(value) @@ -61,14 +30,21 @@ def sensitive_domain(domain): # User has permission, no need to check anything return - whitelist = current_app.config.get('LEMUR_WHITELISTED_DOMAINS', []) + whitelist = current_app.config.get("LEMUR_WHITELISTED_DOMAINS", []) if whitelist and not any(re.match(pattern, domain) for pattern in whitelist): - raise ValidationError('Domain {0} does not match whitelisted domain patterns. ' - 'Contact an administrator to issue the certificate.'.format(domain)) + raise ValidationError( + "Domain {0} does not match whitelisted domain patterns. " + "Contact an administrator to issue the certificate.".format(domain) + ) - if any(d.sensitive for d in domain_service.get_by_name(domain)): - raise ValidationError('Domain {0} has been marked as sensitive. ' - 'Contact an administrator to issue the certificate.'.format(domain)) + # Avoid circular import. + from lemur.domains import service as domain_service + + if domain_service.is_domain_sensitive(domain): + raise ValidationError( + "Domain {0} has been marked as sensitive. " + "Contact an administrator to issue the certificate.".format(domain) + ) def encoding(oid_encoding): @@ -77,9 +53,13 @@ def encoding(oid_encoding): :param oid_encoding: :return: """ - valid_types = ['b64asn1', 'string', 'ia5string'] + valid_types = ["b64asn1", "string", "ia5string"] if oid_encoding.lower() not in [o_type.lower() for o_type in valid_types]: - raise ValidationError('Invalid Oid Encoding: {0} choose from {1}'.format(oid_encoding, ",".join(valid_types))) + raise ValidationError( + "Invalid Oid Encoding: {0} choose from {1}".format( + oid_encoding, ",".join(valid_types) + ) + ) def sub_alt_type(alt_type): @@ -88,10 +68,23 @@ def sub_alt_type(alt_type): :param alt_type: :return: """ - valid_types = ['DNSName', 'IPAddress', 'uniFormResourceIdentifier', 'directoryName', 'rfc822Name', 'registrationID', - 'otherName', 'x400Address', 'EDIPartyName'] + valid_types = [ + "DNSName", + "IPAddress", + "uniFormResourceIdentifier", + "directoryName", + "rfc822Name", + "registrationID", + "otherName", + "x400Address", + "EDIPartyName", + ] if alt_type.lower() not in [a_type.lower() for a_type in valid_types]: - raise ValidationError('Invalid SubAltName Type: {0} choose from {1}'.format(type, ",".join(valid_types))) + raise ValidationError( + "Invalid SubAltName Type: {0} choose from {1}".format( + type, ",".join(valid_types) + ) + ) def csr(data): @@ -101,16 +94,18 @@ def csr(data): :return: """ try: - request = x509.load_pem_x509_csr(data.encode('utf-8'), default_backend()) + request = x509.load_pem_x509_csr(data.encode("utf-8"), default_backend()) except Exception: - raise ValidationError('CSR presented is not valid.') + raise ValidationError("CSR presented is not valid.") # Validate common name and SubjectAltNames for name in request.subject.get_attributes_for_oid(NameOID.COMMON_NAME): common_name(name.value) try: - alt_names = request.extensions.get_extension_for_class(x509.SubjectAlternativeName) + alt_names = request.extensions.get_extension_for_class( + x509.SubjectAlternativeName + ) for name in alt_names.value.get_values_for_type(x509.DNSName): sensitive_domain(name) @@ -119,25 +114,87 @@ def csr(data): def dates(data): - if not data.get('validity_start') and data.get('validity_end'): - raise ValidationError('If validity start is specified so must validity end.') + if not data.get("validity_start") and data.get("validity_end"): + raise ValidationError("If validity start is specified so must validity end.") - if not data.get('validity_end') and data.get('validity_start'): - raise ValidationError('If validity end is specified so must validity start.') + if not data.get("validity_end") and data.get("validity_start"): + raise ValidationError("If validity end is specified so must validity start.") - if data.get('validity_start') and data.get('validity_end'): - if not current_app.config.get('LEMUR_ALLOW_WEEKEND_EXPIRATION', True): - if is_weekend(data.get('validity_end')): - raise ValidationError('Validity end must not land on a weekend.') + if data.get("validity_start") and data.get("validity_end"): + if not current_app.config.get("LEMUR_ALLOW_WEEKEND_EXPIRATION", True): + if is_weekend(data.get("validity_end")): + raise ValidationError("Validity end must not land on a weekend.") - if not data['validity_start'] < data['validity_end']: - raise ValidationError('Validity start must be before validity end.') + if not data["validity_start"] < data["validity_end"]: + raise ValidationError("Validity start must be before validity end.") - if data.get('authority'): - if data.get('validity_start').date() < data['authority'].authority_certificate.not_before.date(): - raise ValidationError('Validity start must not be before {0}'.format(data['authority'].authority_certificate.not_before)) + if data.get("authority"): + if ( + data.get("validity_start").date() + < data["authority"].authority_certificate.not_before.date() + ): + raise ValidationError( + "Validity start must not be before {0}".format( + data["authority"].authority_certificate.not_before + ) + ) - if data.get('validity_end').date() > data['authority'].authority_certificate.not_after.date(): - raise ValidationError('Validity end must not be after {0}'.format(data['authority'].authority_certificate.not_after)) + if ( + data.get("validity_end").date() + > data["authority"].authority_certificate.not_after.date() + ): + raise ValidationError( + "Validity end must not be after {0}".format( + data["authority"].authority_certificate.not_after + ) + ) return data + + +def verify_private_key_match(key, cert, error_class=ValidationError): + """ + Checks that the supplied private key matches the certificate. + + :param cert: Parsed certificate + :param key: Parsed private key + :param error_class: Exception class to raise on error + """ + if key.public_key().public_numbers() != cert.public_key().public_numbers(): + raise error_class("Private key does not match certificate.") + + +def verify_cert_chain(certs, error_class=ValidationError): + """ + Verifies that the certificates in the chain are correct. + + We don't bother with full cert validation but just check that certs in the chain are signed by the next, to avoid + basic human errors -- such as pasting the wrong certificate. + + :param certs: List of parsed certificates, use parse_cert_chain() + :param error_class: Exception class to raise on error + """ + cert = certs[0] + for issuer in certs[1:]: + # Use the current cert's public key to verify the previous signature. + # "certificate validation is a complex problem that involves much more than just signature checks" + try: + check_cert_signature(cert, issuer.public_key()) + + except InvalidSignature: + # Avoid circular import. + from lemur.common import defaults + + raise error_class( + "Incorrect chain certificate(s) provided: '%s' is not signed by '%s'" + % ( + defaults.common_name(cert) or "Unknown", + defaults.common_name(issuer), + ) + ) + + except UnsupportedAlgorithm as err: + current_app.logger.warning("Skipping chain validation: %s", err) + + # Next loop will validate that *this issuer* cert is signed by the next chain cert. + cert = issuer diff --git a/lemur/constants.py b/lemur/constants.py index 060ecfed..cc1653cb 100644 --- a/lemur/constants.py +++ b/lemur/constants.py @@ -7,28 +7,28 @@ SAN_NAMING_TEMPLATE = "SAN-{subject}-{issuer}-{not_before}-{not_after}" DEFAULT_NAMING_TEMPLATE = "{subject}-{issuer}-{not_before}-{not_after}" NONSTANDARD_NAMING_TEMPLATE = "{issuer}-{not_before}-{not_after}" -SUCCESS_METRIC_STATUS = 'success' -FAILURE_METRIC_STATUS = 'failure' +SUCCESS_METRIC_STATUS = "success" +FAILURE_METRIC_STATUS = "failure" CERTIFICATE_KEY_TYPES = [ - 'RSA2048', - 'RSA4096', - 'ECCPRIME192V1', - 'ECCPRIME256V1', - 'ECCSECP192R1', - 'ECCSECP224R1', - 'ECCSECP256R1', - 'ECCSECP384R1', - 'ECCSECP521R1', - 'ECCSECP256K1', - 'ECCSECT163K1', - 'ECCSECT233K1', - 'ECCSECT283K1', - 'ECCSECT409K1', - 'ECCSECT571K1', - 'ECCSECT163R2', - 'ECCSECT233R1', - 'ECCSECT283R1', - 'ECCSECT409R1', - 'ECCSECT571R2' + "RSA2048", + "RSA4096", + "ECCPRIME192V1", + "ECCPRIME256V1", + "ECCSECP192R1", + "ECCSECP224R1", + "ECCSECP256R1", + "ECCSECP384R1", + "ECCSECP521R1", + "ECCSECP256K1", + "ECCSECT163K1", + "ECCSECT233K1", + "ECCSECT283K1", + "ECCSECT409K1", + "ECCSECT571K1", + "ECCSECT163R2", + "ECCSECT233R1", + "ECCSECT283R1", + "ECCSECT409R1", + "ECCSECT571R2", ] diff --git a/lemur/database.py b/lemur/database.py index 82fb0423..a9610325 100644 --- a/lemur/database.py +++ b/lemur/database.py @@ -43,7 +43,7 @@ def session_query(model): :param model: sqlalchemy model :return: query object for model """ - return model.query if hasattr(model, 'query') else db.session.query(model) + return model.query if hasattr(model, "query") else db.session.query(model) def create_query(model, kwargs): @@ -77,7 +77,7 @@ def add(model): def get_model_column(model, field): - if field in getattr(model, 'sensitive_fields', ()): + if field in getattr(model, "sensitive_fields", ()): raise AttrNotFound(field) column = model.__table__.columns._data.get(field, None) if column is None: @@ -100,7 +100,7 @@ def find_all(query, model, kwargs): kwargs = filter_none(kwargs) for attr, value in kwargs.items(): if not isinstance(value, list): - value = value.split(',') + value = value.split(",") conditions.append(get_model_column(model, attr).in_(value)) @@ -200,7 +200,7 @@ def filter(query, model, terms): :return: """ column = get_model_column(model, underscore(terms[0])) - return query.filter(column.ilike('%{}%'.format(terms[1]))) + return query.filter(column.ilike("%{}%".format(terms[1]))) def sort(query, model, field, direction): @@ -214,7 +214,7 @@ def sort(query, model, field, direction): :param direction: """ column = get_model_column(model, underscore(field)) - return query.order_by(column.desc() if direction == 'desc' else column.asc()) + return query.order_by(column.desc() if direction == "desc" else column.asc()) def paginate(query, page, count): @@ -247,10 +247,10 @@ def update_list(model, model_attr, item_model, items): for i in items: for item in getattr(model, model_attr): - if item.id == i['id']: + if item.id == i["id"]: break else: - getattr(model, model_attr).append(get(item_model, i['id'])) + getattr(model, model_attr).append(get(item_model, i["id"])) return model @@ -276,9 +276,9 @@ def get_count(q): disable_group_by = False if len(q._entities) > 1: # currently support only one entity - raise Exception('only one entity is supported for get_count, got: %s' % q) + raise Exception("only one entity is supported for get_count, got: %s" % q) entity = q._entities[0] - if hasattr(entity, 'column'): + if hasattr(entity, "column"): # _ColumnEntity has column attr - on case: query(Model.column)... col = entity.column if q._group_by and q._distinct: @@ -295,7 +295,11 @@ def get_count(q): count_func = func.count() if q._group_by and not disable_group_by: count_func = count_func.over(None) - count_q = q.options(lazyload('*')).statement.with_only_columns([count_func]).order_by(None) + count_q = ( + q.options(lazyload("*")) + .statement.with_only_columns([count_func]) + .order_by(None) + ) if disable_group_by: count_q = count_q.group_by(None) count = q.session.execute(count_q).scalar() @@ -311,13 +315,13 @@ def sort_and_page(query, model, args): :param args: :return: """ - sort_by = args.pop('sort_by') - sort_dir = args.pop('sort_dir') - page = args.pop('page') - count = args.pop('count') + sort_by = args.pop("sort_by") + sort_dir = args.pop("sort_dir") + page = args.pop("page") + count = args.pop("count") - if args.get('user'): - user = args.pop('user') + if args.get("user"): + user = args.pop("user") query = find_all(query, model, args) diff --git a/lemur/default.conf.py b/lemur/default.conf.py index 217d8371..bd67bf7a 100644 --- a/lemur/default.conf.py +++ b/lemur/default.conf.py @@ -1,6 +1,7 @@ # This is just Python which means you can inherit and tweak settings import os + _basedir = os.path.abspath(os.path.dirname(__file__)) THREADS_PER_PAGE = 8 diff --git a/lemur/defaults/views.py b/lemur/defaults/views.py index 5a573829..b3741b15 100644 --- a/lemur/defaults/views.py +++ b/lemur/defaults/views.py @@ -13,12 +13,13 @@ from lemur.auth.service import AuthenticatedResource from lemur.defaults.schemas import default_output_schema -mod = Blueprint('default', __name__) +mod = Blueprint("default", __name__) api = Api(mod) class LemurDefaults(AuthenticatedResource): """ Defines the 'defaults' endpoint """ + def __init__(self): super(LemurDefaults) @@ -59,17 +60,21 @@ class LemurDefaults(AuthenticatedResource): :statuscode 403: unauthenticated """ - default_authority = get_by_name(current_app.config.get('LEMUR_DEFAULT_AUTHORITY')) + default_authority = get_by_name( + current_app.config.get("LEMUR_DEFAULT_AUTHORITY") + ) return dict( - country=current_app.config.get('LEMUR_DEFAULT_COUNTRY'), - state=current_app.config.get('LEMUR_DEFAULT_STATE'), - location=current_app.config.get('LEMUR_DEFAULT_LOCATION'), - organization=current_app.config.get('LEMUR_DEFAULT_ORGANIZATION'), - organizational_unit=current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT'), - issuer_plugin=current_app.config.get('LEMUR_DEFAULT_ISSUER_PLUGIN'), + country=current_app.config.get("LEMUR_DEFAULT_COUNTRY"), + state=current_app.config.get("LEMUR_DEFAULT_STATE"), + location=current_app.config.get("LEMUR_DEFAULT_LOCATION"), + organization=current_app.config.get("LEMUR_DEFAULT_ORGANIZATION"), + organizational_unit=current_app.config.get( + "LEMUR_DEFAULT_ORGANIZATIONAL_UNIT" + ), + issuer_plugin=current_app.config.get("LEMUR_DEFAULT_ISSUER_PLUGIN"), authority=default_authority, ) -api.add_resource(LemurDefaults, '/defaults', endpoint='default') +api.add_resource(LemurDefaults, "/defaults", endpoint="default") diff --git a/lemur/destinations/models.py b/lemur/destinations/models.py index 192a5f5d..a2575378 100644 --- a/lemur/destinations/models.py +++ b/lemur/destinations/models.py @@ -13,7 +13,7 @@ from lemur.plugins.base import plugins class Destination(db.Model): - __tablename__ = 'destinations' + __tablename__ = "destinations" id = Column(Integer, primary_key=True) label = Column(String(32)) options = Column(JSONType) diff --git a/lemur/destinations/schemas.py b/lemur/destinations/schemas.py index 279889b4..cc46ecd4 100644 --- a/lemur/destinations/schemas.py +++ b/lemur/destinations/schemas.py @@ -30,7 +30,7 @@ class DestinationOutputSchema(LemurOutputSchema): @post_dump def fill_object(self, data): if data: - data['plugin']['pluginOptions'] = data['options'] + data["plugin"]["pluginOptions"] = data["options"] return data diff --git a/lemur/destinations/service.py b/lemur/destinations/service.py index ed6fcb0f..92162f4b 100644 --- a/lemur/destinations/service.py +++ b/lemur/destinations/service.py @@ -6,11 +6,13 @@ .. moduleauthor:: Kevin Glisson """ from sqlalchemy import func +from flask import current_app from lemur import database from lemur.models import certificate_destination_associations from lemur.destinations.models import Destination from lemur.certificates.models import Certificate +from lemur.sources.service import add_aws_destination_to_sources def create(label, plugin_name, options, description=None): @@ -24,10 +26,18 @@ def create(label, plugin_name, options, description=None): """ # remove any sub-plugin objects before try to save the json options for option in options: - if 'plugin' in option['type']: - del option['value']['plugin_object'] + if "plugin" in option["type"]: + del option["value"]["plugin_object"] + + destination = Destination( + label=label, options=options, plugin_name=plugin_name, description=description + ) + current_app.logger.info("Destination: %s created", label) + + # add the destination as source, to avoid new destinations that are not in source, as long as an AWS destination + if add_aws_destination_to_sources(destination): + current_app.logger.info("Source: %s created", label) - destination = Destination(label=label, options=options, plugin_name=plugin_name, description=description) return database.create(destination) @@ -77,7 +87,7 @@ def get_by_label(label): :param label: :return: """ - return database.get(Destination, label, field='label') + return database.get(Destination, label, field="label") def get_all(): @@ -91,17 +101,19 @@ def get_all(): def render(args): - filt = args.pop('filter') - certificate_id = args.pop('certificate_id', None) + filt = args.pop("filter") + certificate_id = args.pop("certificate_id", None) if certificate_id: - query = database.session_query(Destination).join(Certificate, Destination.certificate) + query = database.session_query(Destination).join( + Certificate, Destination.certificate + ) query = query.filter(Certificate.id == certificate_id) else: query = database.session_query(Destination) if filt: - terms = filt.split(';') + terms = filt.split(";") query = database.filter(query, Destination, terms) return database.sort_and_page(query, Destination, args) @@ -114,9 +126,15 @@ def stats(**kwargs): :param kwargs: :return: """ - items = database.db.session.query(Destination.label, func.count(certificate_destination_associations.c.certificate_id))\ - .join(certificate_destination_associations)\ - .group_by(Destination.label).all() + items = ( + database.db.session.query( + Destination.label, + func.count(certificate_destination_associations.c.certificate_id), + ) + .join(certificate_destination_associations) + .group_by(Destination.label) + .all() + ) keys = [] values = [] @@ -124,4 +142,4 @@ def stats(**kwargs): keys.append(key) values.append(count) - return {'labels': keys, 'values': values} + return {"labels": keys, "values": values} diff --git a/lemur/destinations/views.py b/lemur/destinations/views.py index 7084e8e9..0b0559fe 100644 --- a/lemur/destinations/views.py +++ b/lemur/destinations/views.py @@ -15,15 +15,20 @@ from lemur.auth.permissions import admin_permission from lemur.common.utils import paginated_parser from lemur.common.schema import validate_schema -from lemur.destinations.schemas import destinations_output_schema, destination_input_schema, destination_output_schema +from lemur.destinations.schemas import ( + destinations_output_schema, + destination_input_schema, + destination_output_schema, +) -mod = Blueprint('destinations', __name__) +mod = Blueprint("destinations", __name__) api = Api(mod) class DestinationsList(AuthenticatedResource): """ Defines the 'destinations' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(DestinationsList, self).__init__() @@ -176,7 +181,12 @@ class DestinationsList(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.create(data['label'], data['plugin']['slug'], data['plugin']['plugin_options'], data['description']) + return service.create( + data["label"], + data["plugin"]["slug"], + data["plugin"]["plugin_options"], + data["description"], + ) class Destinations(AuthenticatedResource): @@ -325,16 +335,22 @@ class Destinations(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.update(destination_id, data['label'], data['plugin']['plugin_options'], data['description']) + return service.update( + destination_id, + data["label"], + data["plugin"]["plugin_options"], + data["description"], + ) @admin_permission.require(http_exception=403) def delete(self, destination_id): service.delete(destination_id) - return {'result': True} + return {"result": True} class CertificateDestinations(AuthenticatedResource): """ Defines the 'certificate/', endpoint='destination') -api.add_resource(CertificateDestinations, '/certificates//destinations', - endpoint='certificateDestinations') -api.add_resource(DestinationsStats, '/destinations/stats', endpoint='destinationStats') +api.add_resource(DestinationsList, "/destinations", endpoint="destinations") +api.add_resource( + Destinations, "/destinations/", endpoint="destination" +) +api.add_resource( + CertificateDestinations, + "/certificates//destinations", + endpoint="certificateDestinations", +) +api.add_resource(DestinationsStats, "/destinations/stats", endpoint="destinationStats") diff --git a/lemur/dns_providers/cli.py b/lemur/dns_providers/cli.py index 159bdaa0..72f9c874 100644 --- a/lemur/dns_providers/cli.py +++ b/lemur/dns_providers/cli.py @@ -5,7 +5,9 @@ from lemur.dns_providers.service import get_all_dns_providers, set_domains from lemur.extensions import metrics from lemur.plugins.base import plugins -manager = Manager(usage="Iterates through all DNS providers and sets DNS zones in the database.") +manager = Manager( + usage="Iterates through all DNS providers and sets DNS zones in the database." +) @manager.command @@ -27,5 +29,5 @@ def get_all_zones(): status = SUCCESS_METRIC_STATUS - metrics.send('get_all_zones', 'counter', 1, metric_tags={'status': status}) + metrics.send("get_all_zones", "counter", 1, metric_tags={"status": status}) print("[+] Done with dns provider zone lookup and configuration.") diff --git a/lemur/dns_providers/models.py b/lemur/dns_providers/models.py index 435a2398..eb8cdff9 100644 --- a/lemur/dns_providers/models.py +++ b/lemur/dns_providers/models.py @@ -9,22 +9,23 @@ from lemur.utils import Vault class DnsProvider(db.Model): - __tablename__ = 'dns_providers' - id = Column( - Integer(), - primary_key=True, - ) + __tablename__ = "dns_providers" + id = Column(Integer(), primary_key=True) name = Column(String(length=256), unique=True, nullable=True) description = Column(Text(), nullable=True) provider_type = Column(String(length=256), nullable=True) credentials = Column(Vault, nullable=True) api_endpoint = Column(String(length=256), nullable=True) - date_created = Column(ArrowType(), server_default=text('now()'), nullable=False) + date_created = Column(ArrowType(), server_default=text("now()"), nullable=False) status = Column(String(length=128), nullable=True) options = Column(JSON, nullable=True) domains = Column(JSON, nullable=True) - certificates = relationship("Certificate", backref='dns_provider', foreign_keys='Certificate.dns_provider_id', - lazy='dynamic') + certificates = relationship( + "Certificate", + backref="dns_provider", + foreign_keys="Certificate.dns_provider_id", + lazy="dynamic", + ) def __init__(self, name, description, provider_type, credentials): self.name = name diff --git a/lemur/dns_providers/service.py b/lemur/dns_providers/service.py index bf50bba1..29f98a5b 100644 --- a/lemur/dns_providers/service.py +++ b/lemur/dns_providers/service.py @@ -49,7 +49,9 @@ def get_friendly(dns_provider_id): } if dns_provider.provider_type == "route53": - dns_provider_friendly["account_id"] = json.loads(dns_provider.credentials).get("account_id") + dns_provider_friendly["account_id"] = json.loads(dns_provider.credentials).get( + "account_id" + ) return dns_provider_friendly @@ -64,40 +66,41 @@ def delete(dns_provider_id): def get_types(): provider_config = current_app.config.get( - 'ACME_DNS_PROVIDER_TYPES', - {"items": [ - { - 'name': 'route53', - 'requirements': [ - { - 'name': 'account_id', - 'type': 'int', - 'required': True, - 'helpMessage': 'AWS Account number' - }, - ] - }, - { - 'name': 'cloudflare', - 'requirements': [ - { - 'name': 'email', - 'type': 'str', - 'required': True, - 'helpMessage': 'Cloudflare Email' - }, - { - 'name': 'key', - 'type': 'str', - 'required': True, - 'helpMessage': 'Cloudflare Key' - }, - ] - }, - { - 'name': 'dyn', - }, - ]} + "ACME_DNS_PROVIDER_TYPES", + { + "items": [ + { + "name": "route53", + "requirements": [ + { + "name": "account_id", + "type": "int", + "required": True, + "helpMessage": "AWS Account number", + } + ], + }, + { + "name": "cloudflare", + "requirements": [ + { + "name": "email", + "type": "str", + "required": True, + "helpMessage": "Cloudflare Email", + }, + { + "name": "key", + "type": "str", + "required": True, + "helpMessage": "Cloudflare Key", + }, + ], + }, + {"name": "dyn"}, + {"name": "ultradns"}, + ] + }, ) if not provider_config: raise Exception("No DNS Provider configuration specified.") diff --git a/lemur/dns_providers/views.py b/lemur/dns_providers/views.py index 1f5b3164..d470aa2f 100644 --- a/lemur/dns_providers/views.py +++ b/lemur/dns_providers/views.py @@ -13,9 +13,12 @@ from lemur.auth.service import AuthenticatedResource from lemur.common.schema import validate_schema from lemur.common.utils import paginated_parser from lemur.dns_providers import service -from lemur.dns_providers.schemas import dns_provider_output_schema, dns_provider_input_schema +from lemur.dns_providers.schemas import ( + dns_provider_output_schema, + dns_provider_input_schema, +) -mod = Blueprint('dns_providers', __name__) +mod = Blueprint("dns_providers", __name__) api = Api(mod) @@ -71,12 +74,12 @@ class DnsProvidersList(AuthenticatedResource): """ parser = paginated_parser.copy() - parser.add_argument('dns_provider_id', type=int, location='args') - parser.add_argument('name', type=str, location='args') - parser.add_argument('type', type=str, location='args') + parser.add_argument("dns_provider_id", type=int, location="args") + parser.add_argument("name", type=str, location="args") + parser.add_argument("type", type=str, location="args") args = parser.parse_args() - args['user'] = g.user + args["user"] = g.user return service.render(args) @validate_schema(dns_provider_input_schema, None) @@ -152,7 +155,7 @@ class DnsProviders(AuthenticatedResource): @admin_permission.require(http_exception=403) def delete(self, dns_provider_id): service.delete(dns_provider_id) - return {'result': True} + return {"result": True} class DnsProviderOptions(AuthenticatedResource): @@ -166,6 +169,10 @@ class DnsProviderOptions(AuthenticatedResource): return service.get_types() -api.add_resource(DnsProvidersList, '/dns_providers', endpoint='dns_providers') -api.add_resource(DnsProviders, '/dns_providers/', endpoint='dns_provider') -api.add_resource(DnsProviderOptions, '/dns_provider_options', endpoint='dns_provider_options') +api.add_resource(DnsProvidersList, "/dns_providers", endpoint="dns_providers") +api.add_resource( + DnsProviders, "/dns_providers/", endpoint="dns_provider" +) +api.add_resource( + DnsProviderOptions, "/dns_provider_options", endpoint="dns_provider_options" +) diff --git a/lemur/domains/models.py b/lemur/domains/models.py index 05fccd9c..791e74de 100644 --- a/lemur/domains/models.py +++ b/lemur/domains/models.py @@ -13,11 +13,14 @@ from lemur.database import db class Domain(db.Model): - __tablename__ = 'domains' + __tablename__ = "domains" __table_args__ = ( - Index('ix_domains_name_gin', "name", - postgresql_ops={"name": "gin_trgm_ops"}, - postgresql_using='gin'), + Index( + "ix_domains_name_gin", + "name", + postgresql_ops={"name": "gin_trgm_ops"}, + postgresql_using="gin", + ), ) id = Column(Integer, primary_key=True) name = Column(String(256), index=True) diff --git a/lemur/domains/service.py b/lemur/domains/service.py index c9b8f759..1944d9db 100644 --- a/lemur/domains/service.py +++ b/lemur/domains/service.py @@ -6,10 +6,11 @@ .. moduleauthor:: Kevin Glisson """ -from lemur.domains.models import Domain -from lemur.certificates.models import Certificate +from sqlalchemy import and_ from lemur import database +from lemur.certificates.models import Certificate +from lemur.domains.models import Domain def get(domain_id): @@ -42,6 +43,20 @@ def get_by_name(name): return database.get_all(Domain, name, field="name").all() +def is_domain_sensitive(name): + """ + Return True if domain is marked sensitive + + :param name: + :return: + """ + query = database.session_query(Domain) + + query = query.filter(and_(Domain.sensitive, Domain.name == name)) + + return database.find_all(query, Domain, {}).all() + + def create(name, sensitive): """ Create a new domain @@ -77,11 +92,11 @@ def render(args): :return: """ query = database.session_query(Domain) - filt = args.pop('filter') - certificate_id = args.pop('certificate_id', None) + filt = args.pop("filter") + certificate_id = args.pop("certificate_id", None) if filt: - terms = filt.split(';') + terms = filt.split(";") query = database.filter(query, Domain, terms) if certificate_id: diff --git a/lemur/domains/views.py b/lemur/domains/views.py index db73f5cd..a3e0cdff 100644 --- a/lemur/domains/views.py +++ b/lemur/domains/views.py @@ -17,14 +17,19 @@ from lemur.auth.permissions import SensitiveDomainPermission from lemur.common.schema import validate_schema from lemur.common.utils import paginated_parser -from lemur.domains.schemas import domain_input_schema, domain_output_schema, domains_output_schema +from lemur.domains.schemas import ( + domain_input_schema, + domain_output_schema, + domains_output_schema, +) -mod = Blueprint('domains', __name__) +mod = Blueprint("domains", __name__) api = Api(mod) class DomainsList(AuthenticatedResource): """ Defines the 'domains' endpoint """ + def __init__(self): super(DomainsList, self).__init__() @@ -123,7 +128,7 @@ class DomainsList(AuthenticatedResource): :statuscode 200: no error :statuscode 403: unauthenticated """ - return service.create(data['name'], data['sensitive']) + return service.create(data["name"], data["sensitive"]) class Domains(AuthenticatedResource): @@ -205,13 +210,14 @@ class Domains(AuthenticatedResource): :statuscode 403: unauthenticated """ if SensitiveDomainPermission().can(): - return service.update(domain_id, data['name'], data['sensitive']) + return service.update(domain_id, data["name"], data["sensitive"]) - return dict(message='You are not authorized to modify this domain'), 403 + return dict(message="You are not authorized to modify this domain"), 403 class CertificateDomains(AuthenticatedResource): """ Defines the 'domains' endpoint """ + def __init__(self): super(CertificateDomains, self).__init__() @@ -265,10 +271,14 @@ class CertificateDomains(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['certificate_id'] = certificate_id + args["certificate_id"] = certificate_id return service.render(args) -api.add_resource(DomainsList, '/domains', endpoint='domains') -api.add_resource(Domains, '/domains/', endpoint='domain') -api.add_resource(CertificateDomains, '/certificates//domains', endpoint='certificateDomains') +api.add_resource(DomainsList, "/domains", endpoint="domains") +api.add_resource(Domains, "/domains/", endpoint="domain") +api.add_resource( + CertificateDomains, + "/certificates//domains", + endpoint="certificateDomains", +) diff --git a/lemur/endpoints/cli.py b/lemur/endpoints/cli.py index 59496930..99f8c342 100644 --- a/lemur/endpoints/cli.py +++ b/lemur/endpoints/cli.py @@ -21,7 +21,14 @@ from lemur.endpoints.models import Endpoint manager = Manager(usage="Handles all endpoint related tasks.") -@manager.option('-ttl', '--time-to-live', type=int, dest='ttl', default=2, help='Time in hours, which endpoint has not been refreshed to remove the endpoint.') +@manager.option( + "-ttl", + "--time-to-live", + type=int, + dest="ttl", + default=2, + help="Time in hours, which endpoint has not been refreshed to remove the endpoint.", +) def expire(ttl): """ Removed all endpoints that have not been recently updated. @@ -31,12 +38,18 @@ def expire(ttl): try: now = arrow.utcnow() expiration = now - timedelta(hours=ttl) - endpoints = database.session_query(Endpoint).filter(cast(Endpoint.last_updated, ArrowType) <= expiration) + endpoints = database.session_query(Endpoint).filter( + cast(Endpoint.last_updated, ArrowType) <= expiration + ) for endpoint in endpoints: - print("[!] Expiring endpoint: {name} Last Updated: {last_updated}".format(name=endpoint.name, last_updated=endpoint.last_updated)) + print( + "[!] Expiring endpoint: {name} Last Updated: {last_updated}".format( + name=endpoint.name, last_updated=endpoint.last_updated + ) + ) database.delete(endpoint) - metrics.send('endpoint_expired', 'counter', 1) + metrics.send("endpoint_expired", "counter", 1) print("[+] Finished expiration.") except Exception as e: diff --git a/lemur/endpoints/models.py b/lemur/endpoints/models.py index b5823327..6e44fe71 100644 --- a/lemur/endpoints/models.py +++ b/lemur/endpoints/models.py @@ -20,15 +20,11 @@ from lemur.database import db from lemur.models import policies_ciphers -BAD_CIPHERS = [ - 'Protocol-SSLv3', - 'Protocol-SSLv2', - 'Protocol-TLSv1' -] +BAD_CIPHERS = ["Protocol-SSLv3", "Protocol-SSLv2", "Protocol-TLSv1"] class Cipher(db.Model): - __tablename__ = 'ciphers' + __tablename__ = "ciphers" id = Column(Integer, primary_key=True) name = Column(String(128), nullable=False) @@ -38,23 +34,18 @@ class Cipher(db.Model): @deprecated.expression def deprecated(cls): - return case( - [ - (cls.name in BAD_CIPHERS, True) - ], - else_=False - ) + return case([(cls.name in BAD_CIPHERS, True)], else_=False) class Policy(db.Model): - ___tablename__ = 'policies' + ___tablename__ = "policies" id = Column(Integer, primary_key=True) name = Column(String(128), nullable=True) - ciphers = relationship('Cipher', secondary=policies_ciphers, backref='policy') + ciphers = relationship("Cipher", secondary=policies_ciphers, backref="policy") class Endpoint(db.Model): - __tablename__ = 'endpoints' + __tablename__ = "endpoints" id = Column(Integer, primary_key=True) owner = Column(String(128)) name = Column(String(128)) @@ -62,16 +53,18 @@ class Endpoint(db.Model): type = Column(String(128)) active = Column(Boolean, default=True) port = Column(Integer) - policy_id = Column(Integer, ForeignKey('policy.id')) - policy = relationship('Policy', backref='endpoint') - certificate_id = Column(Integer, ForeignKey('certificates.id')) - source_id = Column(Integer, ForeignKey('sources.id')) + policy_id = Column(Integer, ForeignKey("policy.id")) + policy = relationship("Policy", backref="endpoint") + certificate_id = Column(Integer, ForeignKey("certificates.id")) + source_id = Column(Integer, ForeignKey("sources.id")) sensitive = Column(Boolean, default=False) - source = relationship('Source', back_populates='endpoints') + source = relationship("Source", back_populates="endpoints") last_updated = Column(ArrowType, default=arrow.utcnow, nullable=False) - date_created = Column(ArrowType, default=arrow.utcnow, onupdate=arrow.utcnow, nullable=False) + date_created = Column( + ArrowType, default=arrow.utcnow, onupdate=arrow.utcnow, nullable=False + ) - replaced = association_proxy('certificate', 'replaced') + replaced = association_proxy("certificate", "replaced") @property def issues(self): @@ -79,13 +72,30 @@ class Endpoint(db.Model): for cipher in self.policy.ciphers: if cipher.deprecated: - issues.append({'name': 'deprecated cipher', 'value': '{0} has been deprecated consider removing it.'.format(cipher.name)}) + issues.append( + { + "name": "deprecated cipher", + "value": "{0} has been deprecated consider removing it.".format( + cipher.name + ), + } + ) if self.certificate.expired: - issues.append({'name': 'expired certificate', 'value': 'There is an expired certificate attached to this endpoint consider replacing it.'}) + issues.append( + { + "name": "expired certificate", + "value": "There is an expired certificate attached to this endpoint consider replacing it.", + } + ) if self.certificate.revoked: - issues.append({'name': 'revoked', 'value': 'There is a revoked certificate attached to this endpoint consider replacing it.'}) + issues.append( + { + "name": "revoked", + "value": "There is a revoked certificate attached to this endpoint consider replacing it.", + } + ) return issues diff --git a/lemur/endpoints/service.py b/lemur/endpoints/service.py index d14174df..2a737858 100644 --- a/lemur/endpoints/service.py +++ b/lemur/endpoints/service.py @@ -46,7 +46,7 @@ def get_by_name(name): :param name: :return: """ - return database.get(Endpoint, name, field='name') + return database.get(Endpoint, name, field="name") def get_by_dnsname(dnsname): @@ -56,7 +56,7 @@ def get_by_dnsname(dnsname): :param dnsname: :return: """ - return database.get(Endpoint, dnsname, field='dnsname') + return database.get(Endpoint, dnsname, field="dnsname") def get_by_dnsname_and_port(dnsname, port): @@ -66,7 +66,11 @@ def get_by_dnsname_and_port(dnsname, port): :param port: :return: """ - return Endpoint.query.filter(Endpoint.dnsname == dnsname).filter(Endpoint.port == port).scalar() + return ( + Endpoint.query.filter(Endpoint.dnsname == dnsname) + .filter(Endpoint.port == port) + .scalar() + ) def get_by_source(source_label): @@ -95,12 +99,14 @@ def create(**kwargs): """ endpoint = Endpoint(**kwargs) database.create(endpoint) - metrics.send('endpoint_added', 'counter', 1, metric_tags={'source': endpoint.source.label}) + metrics.send( + "endpoint_added", "counter", 1, metric_tags={"source": endpoint.source.label} + ) return endpoint def get_or_create_policy(**kwargs): - policy = database.get(Policy, kwargs['name'], field='name') + policy = database.get(Policy, kwargs["name"], field="name") if not policy: policy = Policy(**kwargs) @@ -110,7 +116,7 @@ def get_or_create_policy(**kwargs): def get_or_create_cipher(**kwargs): - cipher = database.get(Cipher, kwargs['name'], field='name') + cipher = database.get(Cipher, kwargs["name"], field="name") if not cipher: cipher = Cipher(**kwargs) @@ -122,11 +128,13 @@ def get_or_create_cipher(**kwargs): def update(endpoint_id, **kwargs): endpoint = database.get(Endpoint, endpoint_id) - endpoint.policy = kwargs['policy'] - endpoint.certificate = kwargs['certificate'] - endpoint.source = kwargs['source'] + endpoint.policy = kwargs["policy"] + endpoint.certificate = kwargs["certificate"] + endpoint.source = kwargs["source"] endpoint.last_updated = arrow.utcnow() - metrics.send('endpoint_updated', 'counter', 1, metric_tags={'source': endpoint.source.label}) + metrics.send( + "endpoint_updated", "counter", 1, metric_tags={"source": endpoint.source.label} + ) database.update(endpoint) return endpoint @@ -138,19 +146,17 @@ def render(args): :return: """ query = database.session_query(Endpoint) - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') - if 'active' in filt: # this is really weird but strcmp seems to not work here?? + terms = filt.split(";") + if "active" in filt: # this is really weird but strcmp seems to not work here?? query = query.filter(Endpoint.active == truthiness(terms[1])) - elif 'port' in filt: - if terms[1] != 'null': # ng-table adds 'null' if a number is removed + elif "port" in filt: + if terms[1] != "null": # ng-table adds 'null' if a number is removed query = query.filter(Endpoint.port == terms[1]) - elif 'ciphers' in filt: - query = query.filter( - Cipher.name == terms[1] - ) + elif "ciphers" in filt: + query = query.filter(Cipher.name == terms[1]) else: query = database.filter(query, Endpoint, terms) @@ -164,7 +170,7 @@ def stats(**kwargs): :param kwargs: :return: """ - attr = getattr(Endpoint, kwargs.get('metric')) + attr = getattr(Endpoint, kwargs.get("metric")) query = database.db.session.query(attr, func.count(attr)) items = query.group_by(attr).all() @@ -175,4 +181,4 @@ def stats(**kwargs): keys.append(key) values.append(count) - return {'labels': keys, 'values': values} + return {"labels": keys, "values": values} diff --git a/lemur/endpoints/views.py b/lemur/endpoints/views.py index 6509f056..9f469a6b 100644 --- a/lemur/endpoints/views.py +++ b/lemur/endpoints/views.py @@ -16,12 +16,13 @@ from lemur.endpoints import service from lemur.endpoints.schemas import endpoint_output_schema, endpoints_output_schema -mod = Blueprint('endpoints', __name__) +mod = Blueprint("endpoints", __name__) api = Api(mod) class EndpointsList(AuthenticatedResource): """ Defines the 'endpoints' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(EndpointsList, self).__init__() @@ -63,7 +64,7 @@ class EndpointsList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['user'] = g.current_user + args["user"] = g.current_user return service.render(args) @@ -103,5 +104,5 @@ class Endpoints(AuthenticatedResource): return service.get(endpoint_id) -api.add_resource(EndpointsList, '/endpoints', endpoint='endpoints') -api.add_resource(Endpoints, '/endpoints/', endpoint='endpoint') +api.add_resource(EndpointsList, "/endpoints", endpoint="endpoints") +api.add_resource(Endpoints, "/endpoints/", endpoint="endpoint") diff --git a/lemur/exceptions.py b/lemur/exceptions.py index d392fe5d..98e216bb 100644 --- a/lemur/exceptions.py +++ b/lemur/exceptions.py @@ -21,7 +21,9 @@ class DuplicateError(LemurException): class InvalidListener(LemurException): def __str__(self): - return repr("Invalid listener, ensure you select a certificate if you are using a secure protocol") + return repr( + "Invalid listener, ensure you select a certificate if you are using a secure protocol" + ) class AttrNotFound(LemurException): diff --git a/lemur/extensions.py b/lemur/extensions.py index a54df6c7..24c4c814 100644 --- a/lemur/extensions.py +++ b/lemur/extensions.py @@ -15,25 +15,33 @@ class SQLAlchemy(SA): db = SQLAlchemy() from flask_migrate import Migrate + migrate = Migrate() from flask_bcrypt import Bcrypt + bcrypt = Bcrypt() from flask_principal import Principal + principal = Principal(use_sessions=False) from flask_mail import Mail + smtp_mail = Mail() from lemur.metrics import Metrics + metrics = Metrics() from raven.contrib.flask import Sentry + sentry = Sentry() from blinker import Namespace + signals = Namespace() from flask_cors import CORS + cors = CORS() diff --git a/lemur/factory.py b/lemur/factory.py index c2719e9b..0563d873 100644 --- a/lemur/factory.py +++ b/lemur/factory.py @@ -13,20 +13,21 @@ import os import imp import errno import pkg_resources +import socket from logging import Formatter, StreamHandler from logging.handlers import RotatingFileHandler from flask import Flask +from flask_replicated import FlaskReplicated +import logmatic from lemur.certificates.hooks import activate_debug_dump from lemur.common.health import mod as health from lemur.extensions import db, migrate, principal, smtp_mail, metrics, sentry, cors -DEFAULT_BLUEPRINTS = ( - health, -) +DEFAULT_BLUEPRINTS = (health,) API_VERSION = 1 @@ -53,6 +54,7 @@ def create_app(app_name=None, blueprints=None, config=None): configure_blueprints(app, blueprints) configure_extensions(app) configure_logging(app) + configure_database(app) install_plugins(app) @app.teardown_appcontext @@ -71,16 +73,17 @@ def from_file(file_path, silent=False): :param file_path: :param silent: """ - d = imp.new_module('config') + d = imp.new_module("config") d.__file__ = file_path try: with open(file_path) as config_file: - exec(compile(config_file.read(), # nosec: config file safe - file_path, 'exec'), d.__dict__) + exec( # nosec: config file safe + compile(config_file.read(), file_path, "exec"), d.__dict__ + ) except IOError as e: if silent and e.errno in (errno.ENOENT, errno.EISDIR): return False - e.strerror = 'Unable to load configuration file (%s)' % e.strerror + e.strerror = "Unable to load configuration file (%s)" % e.strerror raise return d @@ -94,8 +97,8 @@ def configure_app(app, config=None): :return: """ # respect the config first - if config and config != 'None': - app.config['CONFIG_PATH'] = config + if config and config != "None": + app.config["CONFIG_PATH"] = config app.config.from_object(from_file(config)) else: try: @@ -103,12 +106,21 @@ def configure_app(app, config=None): except RuntimeError: # look in default paths if os.path.isfile(os.path.expanduser("~/.lemur/lemur.conf.py")): - app.config.from_object(from_file(os.path.expanduser("~/.lemur/lemur.conf.py"))) + app.config.from_object( + from_file(os.path.expanduser("~/.lemur/lemur.conf.py")) + ) else: - app.config.from_object(from_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'default.conf.py'))) + app.config.from_object( + from_file( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "default.conf.py", + ) + ) + ) # we don't use this - app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False + app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False def configure_extensions(app): @@ -125,9 +137,15 @@ def configure_extensions(app): metrics.init_app(app) sentry.init_app(app) - if app.config['CORS']: - app.config['CORS_HEADERS'] = 'Content-Type' - cors.init_app(app, resources=r'/api/*', headers='Content-Type', origin='*', supports_credentials=True) + if app.config["CORS"]: + app.config["CORS_HEADERS"] = "Content-Type" + cors.init_app( + app, + resources=r"/api/*", + headers="Content-Type", + origin="*", + supports_credentials=True, + ) def configure_blueprints(app, blueprints): @@ -142,28 +160,41 @@ def configure_blueprints(app, blueprints): app.register_blueprint(blueprint, url_prefix="/api/{0}".format(API_VERSION)) +def configure_database(app): + if app.config.get("SQLALCHEMY_ENABLE_FLASK_REPLICATED"): + FlaskReplicated(app) + + def configure_logging(app): """ Sets up application wide logging. :param app: """ - handler = RotatingFileHandler(app.config.get('LOG_FILE', 'lemur.log'), maxBytes=10000000, backupCount=100) + handler = RotatingFileHandler( + app.config.get("LOG_FILE", "lemur.log"), maxBytes=10000000, backupCount=100 + ) - handler.setFormatter(Formatter( - '%(asctime)s %(levelname)s: %(message)s ' - '[in %(pathname)s:%(lineno)d]' - )) + handler.setFormatter( + Formatter( + "%(asctime)s %(levelname)s: %(message)s " "[in %(pathname)s:%(lineno)d]" + ) + ) - handler.setLevel(app.config.get('LOG_LEVEL', 'DEBUG')) - app.logger.setLevel(app.config.get('LOG_LEVEL', 'DEBUG')) + if app.config.get("LOG_JSON", False): + handler.setFormatter( + logmatic.JsonFormatter(extra={"hostname": socket.gethostname()}) + ) + + handler.setLevel(app.config.get("LOG_LEVEL", "DEBUG")) + app.logger.setLevel(app.config.get("LOG_LEVEL", "DEBUG")) app.logger.addHandler(handler) stream_handler = StreamHandler() - stream_handler.setLevel(app.config.get('LOG_LEVEL', 'DEBUG')) + stream_handler.setLevel(app.config.get("LOG_LEVEL", "DEBUG")) app.logger.addHandler(stream_handler) - if app.config.get('DEBUG_DUMP', False): + if app.config.get("DEBUG_DUMP", False): activate_debug_dump() @@ -176,17 +207,21 @@ def install_plugins(app): """ from lemur.plugins import plugins from lemur.plugins.base import register + # entry_points={ # 'lemur.plugins': [ # 'verisign = lemur_verisign.plugin:VerisignPlugin' # ], # }, - for ep in pkg_resources.iter_entry_points('lemur.plugins'): + for ep in pkg_resources.iter_entry_points("lemur.plugins"): try: plugin = ep.load() except Exception: import traceback - app.logger.error("Failed to load plugin %r:\n%s\n" % (ep.name, traceback.format_exc())) + + app.logger.error( + "Failed to load plugin %r:\n%s\n" % (ep.name, traceback.format_exc()) + ) else: register(plugin) @@ -196,6 +231,9 @@ def install_plugins(app): try: plugins.get(slug) except KeyError: - raise Exception("Unable to location notification plugin: {slug}. Ensure that " - "LEMUR_DEFAULT_NOTIFICATION_PLUGIN is set to a valid and installed notification plugin." - .format(slug=slug)) + raise Exception( + "Unable to location notification plugin: {slug}. Ensure that " + "LEMUR_DEFAULT_NOTIFICATION_PLUGIN is set to a valid and installed notification plugin.".format( + slug=slug + ) + ) diff --git a/lemur/logs/models.py b/lemur/logs/models.py index d4239e59..07a2ded3 100644 --- a/lemur/logs/models.py +++ b/lemur/logs/models.py @@ -15,9 +15,19 @@ from lemur.database import db class Log(db.Model): - __tablename__ = 'logs' + __tablename__ = "logs" id = Column(Integer, primary_key=True) - certificate_id = Column(Integer, ForeignKey('certificates.id')) - log_type = Column(Enum('key_view', 'create_cert', 'update_cert', 'revoke_cert', name='log_type'), nullable=False) + certificate_id = Column(Integer, ForeignKey("certificates.id")) + log_type = Column( + Enum( + "key_view", + "create_cert", + "update_cert", + "revoke_cert", + "delete_cert", + name="log_type", + ), + nullable=False, + ) logged_at = Column(ArrowType(), PassiveDefault(func.now()), nullable=False) - user_id = Column(Integer, ForeignKey('users.id'), nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) diff --git a/lemur/logs/service.py b/lemur/logs/service.py index 04355938..f4949911 100644 --- a/lemur/logs/service.py +++ b/lemur/logs/service.py @@ -24,7 +24,11 @@ def create(user, type, certificate=None): :param certificate: :return: """ - current_app.logger.info("[lemur-audit] action: {0}, user: {1}, certificate: {2}.".format(type, user.email, certificate.name)) + current_app.logger.info( + "[lemur-audit] action: {0}, user: {1}, certificate: {2}.".format( + type, user.email, certificate.name + ) + ) view = Log(user_id=user.id, log_type=type, certificate_id=certificate.id) database.add(view) database.commit() @@ -50,20 +54,22 @@ def render(args): """ query = database.session_query(Log) - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') + terms = filt.split(";") - if 'certificate.name' in terms: - sub_query = database.session_query(Certificate.id)\ - .filter(Certificate.name.ilike('%{0}%'.format(terms[1]))) + if "certificate.name" in terms: + sub_query = database.session_query(Certificate.id).filter( + Certificate.name.ilike("%{0}%".format(terms[1])) + ) query = query.filter(Log.certificate_id.in_(sub_query)) - elif 'user.email' in terms: - sub_query = database.session_query(User.id)\ - .filter(User.email.ilike('%{0}%'.format(terms[1]))) + elif "user.email" in terms: + sub_query = database.session_query(User.id).filter( + User.email.ilike("%{0}%".format(terms[1])) + ) query = query.filter(Log.user_id.in_(sub_query)) diff --git a/lemur/logs/views.py b/lemur/logs/views.py index 1e0bd184..57c588ed 100644 --- a/lemur/logs/views.py +++ b/lemur/logs/views.py @@ -17,12 +17,13 @@ from lemur.logs.schemas import logs_output_schema from lemur.logs import service -mod = Blueprint('logs', __name__) +mod = Blueprint("logs", __name__) api = Api(mod) class LogsList(AuthenticatedResource): """ Defines the 'logs' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(LogsList, self).__init__() @@ -65,10 +66,10 @@ class LogsList(AuthenticatedResource): :statuscode 200: no error """ parser = paginated_parser.copy() - parser.add_argument('owner', type=str, location='args') - parser.add_argument('id', type=str, location='args') + parser.add_argument("owner", type=str, location="args") + parser.add_argument("id", type=str, location="args") args = parser.parse_args() return service.render(args) -api.add_resource(LogsList, '/logs', endpoint='logs') +api.add_resource(LogsList, "/logs", endpoint="logs") diff --git a/lemur/manage.py b/lemur/manage.py index b972e8a5..7dd3b3b4 100755 --- a/lemur/manage.py +++ b/lemur/manage.py @@ -1,4 +1,5 @@ -from __future__ import unicode_literals # at top of module +#!/usr/bin/env python +from __future__ import unicode_literals # at top of module import os import sys @@ -49,25 +50,27 @@ from lemur.policies.models import RotationPolicy # noqa from lemur.pending_certificates.models import PendingCertificate # noqa from lemur.dns_providers.models import DnsProvider # noqa +from sqlalchemy.sql import text + manager = Manager(create_app) -manager.add_option('-c', '--config', dest='config') +manager.add_option("-c", "--config", dest="config_path", required=False) migrate = Migrate(create_app) REQUIRED_VARIABLES = [ - 'LEMUR_SECURITY_TEAM_EMAIL', - 'LEMUR_DEFAULT_ORGANIZATIONAL_UNIT', - 'LEMUR_DEFAULT_ORGANIZATION', - 'LEMUR_DEFAULT_LOCATION', - 'LEMUR_DEFAULT_COUNTRY', - 'LEMUR_DEFAULT_STATE', - 'SQLALCHEMY_DATABASE_URI' + "LEMUR_SECURITY_TEAM_EMAIL", + "LEMUR_DEFAULT_ORGANIZATIONAL_UNIT", + "LEMUR_DEFAULT_ORGANIZATION", + "LEMUR_DEFAULT_LOCATION", + "LEMUR_DEFAULT_COUNTRY", + "LEMUR_DEFAULT_STATE", + "SQLALCHEMY_DATABASE_URI", ] KEY_LENGTH = 40 -DEFAULT_CONFIG_PATH = '~/.lemur/lemur.conf.py' -DEFAULT_SETTINGS = 'lemur.conf.server' -SETTINGS_ENVVAR = 'LEMUR_CONF' +DEFAULT_CONFIG_PATH = "~/.lemur/lemur.conf.py" +DEFAULT_SETTINGS = "lemur.conf.server" +SETTINGS_ENVVAR = "LEMUR_CONF" CONFIG_TEMPLATE = """ # This is just Python which means you can inherit and tweak settings @@ -142,8 +145,9 @@ SQLALCHEMY_DATABASE_URI = 'postgresql://lemur:lemur@localhost:5432/lemur' @MigrateCommand.command def create(): + database.db.engine.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm")) database.db.create_all() - stamp(revision='head') + stamp(revision="head") @MigrateCommand.command @@ -171,9 +175,9 @@ def generate_settings(): output = CONFIG_TEMPLATE.format( # we use Fernet.generate_key to make sure that the key length is # compatible with Fernet - encryption_key=Fernet.generate_key().decode('utf-8'), - secret_token=base64.b64encode(os.urandom(KEY_LENGTH)).decode('utf-8'), - flask_secret_key=base64.b64encode(os.urandom(KEY_LENGTH)).decode('utf-8'), + encryption_key=Fernet.generate_key().decode("utf-8"), + secret_token=base64.b64encode(os.urandom(KEY_LENGTH)).decode("utf-8"), + flask_secret_key=base64.b64encode(os.urandom(KEY_LENGTH)).decode("utf-8"), ) return output @@ -187,39 +191,44 @@ class InitializeApp(Command): Additionally a Lemur user will be created as a default user and be used when certificates are discovered by Lemur. """ - option_list = ( - Option('-p', '--password', dest='password'), - ) + + option_list = (Option("-p", "--password", dest="password"),) def run(self, password): create() user = user_service.get_by_username("lemur") - admin_role = role_service.get_by_name('admin') + admin_role = role_service.get_by_name("admin") if admin_role: sys.stdout.write("[-] Admin role already created, skipping...!\n") else: # we create an admin role - admin_role = role_service.create('admin', description='This is the Lemur administrator role.') + admin_role = role_service.create( + "admin", description="This is the Lemur administrator role." + ) sys.stdout.write("[+] Created 'admin' role\n") - operator_role = role_service.get_by_name('operator') + operator_role = role_service.get_by_name("operator") if operator_role: sys.stdout.write("[-] Operator role already created, skipping...!\n") else: # we create an operator role - operator_role = role_service.create('operator', description='This is the Lemur operator role.') + operator_role = role_service.create( + "operator", description="This is the Lemur operator role." + ) sys.stdout.write("[+] Created 'operator' role\n") - read_only_role = role_service.get_by_name('read-only') + read_only_role = role_service.get_by_name("read-only") if read_only_role: sys.stdout.write("[-] Read only role already created, skipping...!\n") else: # we create an read only role - read_only_role = role_service.create('read-only', description='This is the Lemur read only role.') + read_only_role = role_service.create( + "read-only", description="This is the Lemur read only role." + ) sys.stdout.write("[+] Created 'read-only' role\n") if not user: @@ -232,34 +241,54 @@ class InitializeApp(Command): sys.stderr.write("[!] Passwords do not match!\n") sys.exit(1) - user_service.create("lemur", password, 'lemur@nobody.com', True, None, [admin_role]) - sys.stdout.write("[+] Created the user 'lemur' and granted it the 'admin' role!\n") + user_service.create( + "lemur", password, "lemur@nobody.com", True, None, [admin_role] + ) + sys.stdout.write( + "[+] Created the user 'lemur' and granted it the 'admin' role!\n" + ) else: - sys.stdout.write("[-] Default user has already been created, skipping...!\n") + sys.stdout.write( + "[-] Default user has already been created, skipping...!\n" + ) - intervals = current_app.config.get("LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", []) + intervals = current_app.config.get( + "LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", [] + ) sys.stdout.write( "[!] Creating {num} notifications for {intervals} days as specified by LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS\n".format( - num=len(intervals), - intervals=",".join([str(x) for x in intervals]) + num=len(intervals), intervals=",".join([str(x) for x in intervals]) ) ) - recipients = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL') + recipients = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL") sys.stdout.write("[+] Creating expiration email notifications!\n") - sys.stdout.write("[!] Using {0} as specified by LEMUR_SECURITY_TEAM_EMAIL for notifications\n".format(recipients)) - notification_service.create_default_expiration_notifications("DEFAULT_SECURITY", recipients=recipients) + sys.stdout.write( + "[!] Using {0} as specified by LEMUR_SECURITY_TEAM_EMAIL for notifications\n".format( + recipients + ) + ) + notification_service.create_default_expiration_notifications( + "DEFAULT_SECURITY", recipients=recipients + ) - _DEFAULT_ROTATION_INTERVAL = 'default' - default_rotation_interval = policy_service.get_by_name(_DEFAULT_ROTATION_INTERVAL) + _DEFAULT_ROTATION_INTERVAL = "default" + default_rotation_interval = policy_service.get_by_name( + _DEFAULT_ROTATION_INTERVAL + ) if default_rotation_interval: - sys.stdout.write("[-] Default rotation interval policy already created, skipping...!\n") + sys.stdout.write( + "[-] Default rotation interval policy already created, skipping...!\n" + ) else: days = current_app.config.get("LEMUR_DEFAULT_ROTATION_INTERVAL", 30) - sys.stdout.write("[+] Creating default certificate rotation policy of {days} days before issuance.\n".format( - days=days)) + sys.stdout.write( + "[+] Creating default certificate rotation policy of {days} days before issuance.\n".format( + days=days + ) + ) policy_service.create(days=days, name=_DEFAULT_ROTATION_INTERVAL) sys.stdout.write("[/] Done!\n") @@ -269,14 +298,16 @@ class CreateUser(Command): """ This command allows for the creation of a new user within Lemur. """ + option_list = ( - Option('-u', '--username', dest='username', required=True), - Option('-e', '--email', dest='email', required=True), - Option('-a', '--active', dest='active', default=True), - Option('-r', '--roles', dest='roles', action='append', default=[]) + Option("-u", "--username", dest="username", required=True), + Option("-e", "--email", dest="email", required=True), + Option("-a", "--active", dest="active", default=True), + Option("-r", "--roles", dest="roles", action="append", default=[]), + Option("-p", "--password", dest="password", default=None), ) - def run(self, username, email, active, roles): + def run(self, username, email, active, roles, password): role_objs = [] for r in roles: role_obj = role_service.get_by_name(r) @@ -286,14 +317,16 @@ class CreateUser(Command): sys.stderr.write("[!] Cannot find role {0}\n".format(r)) sys.exit(1) - password1 = prompt_pass("Password") - password2 = prompt_pass("Confirm Password") + if not password: + password1 = prompt_pass("Password") + password2 = prompt_pass("Confirm Password") + password = password1 - if password1 != password2: - sys.stderr.write("[!] Passwords do not match!\n") - sys.exit(1) + if password1 != password2: + sys.stderr.write("[!] Passwords do not match!\n") + sys.exit(1) - user_service.create(username, password1, email, active, None, role_objs) + user_service.create(username, password, email, active, None, role_objs) sys.stdout.write("[+] Created new user: {0}\n".format(username)) @@ -301,9 +334,8 @@ class ResetPassword(Command): """ This command allows you to reset a user's password. """ - option_list = ( - Option('-u', '--username', dest='username', required=True), - ) + + option_list = (Option("-u", "--username", dest="username", required=True),) def run(self, username): user = user_service.get_by_username(username) @@ -329,10 +361,11 @@ class CreateRole(Command): """ This command allows for the creation of a new role within Lemur """ + option_list = ( - Option('-n', '--name', dest='name', required=True), - Option('-u', '--users', dest='users', default=[]), - Option('-d', '--description', dest='description', required=True) + Option("-n", "--name", dest="name", required=True), + Option("-u", "--users", dest="users", default=[]), + Option("-d", "--description", dest="description", required=True), ) def run(self, name, users, description): @@ -363,7 +396,8 @@ class LemurServer(Command): Will start gunicorn with 4 workers bound to 127.0.0.0:8002 """ - description = 'Run the app within Gunicorn' + + description = "Run the app within Gunicorn" def get_options(self): settings = make_settings() @@ -371,8 +405,10 @@ class LemurServer(Command): for setting, klass in settings.items(): if klass.cli: if klass.action: - if klass.action == 'store_const': - options.append(Option(*klass.cli, const=klass.const, action=klass.action)) + if klass.action == "store_const": + options.append( + Option(*klass.cli, const=klass.const, action=klass.action) + ) else: options.append(Option(*klass.cli, action=klass.action)) else: @@ -388,7 +424,9 @@ class LemurServer(Command): # run startup tasks on an app like object validate_conf(current_app, REQUIRED_VARIABLES) - app.app_uri = 'lemur:create_app(config="{0}")'.format(current_app.config.get('CONFIG_PATH')) + app.app_uri = 'lemur:create_app(config_path="{0}")'.format( + current_app.config.get("CONFIG_PATH") + ) return app.run() @@ -408,7 +446,7 @@ def create_config(config_path=None): os.makedirs(dir) config = generate_settings() - with open(config_path, 'w') as f: + with open(config_path, "w") as f: f.write(config) sys.stdout.write("[+] Created a new configuration file {0}\n".format(config_path)) @@ -430,7 +468,7 @@ def lock(path=None): :param: path """ if not path: - path = os.path.expanduser('~/.lemur/keys') + path = os.path.expanduser("~/.lemur/keys") dest_dir = os.path.join(path, "encrypted") sys.stdout.write("[!] Generating a new key...\n") @@ -441,15 +479,17 @@ def lock(path=None): sys.stdout.write("[+] Creating encryption directory: {0}\n".format(dest_dir)) os.makedirs(dest_dir) - for root, dirs, files in os.walk(os.path.join(path, 'decrypted')): + for root, dirs, files in os.walk(os.path.join(path, "decrypted")): for f in files: source = os.path.join(root, f) dest = os.path.join(dest_dir, f + ".enc") - with open(source, 'rb') as in_file, open(dest, 'wb') as out_file: + with open(source, "rb") as in_file, open(dest, "wb") as out_file: f = Fernet(key) data = f.encrypt(in_file.read()) out_file.write(data) - sys.stdout.write("[+] Writing file: {0} Source: {1}\n".format(dest, source)) + sys.stdout.write( + "[+] Writing file: {0} Source: {1}\n".format(dest, source) + ) sys.stdout.write("[+] Keys have been encrypted with key {0}\n".format(key)) @@ -469,7 +509,7 @@ def unlock(path=None): key = prompt_pass("[!] Please enter the encryption password") if not path: - path = os.path.expanduser('~/.lemur/keys') + path = os.path.expanduser("~/.lemur/keys") dest_dir = os.path.join(path, "decrypted") source_dir = os.path.join(path, "encrypted") @@ -482,11 +522,13 @@ def unlock(path=None): for f in files: source = os.path.join(source_dir, f) dest = os.path.join(dest_dir, ".".join(f.split(".")[:-1])) - with open(source, 'rb') as in_file, open(dest, 'wb') as out_file: + with open(source, "rb") as in_file, open(dest, "wb") as out_file: f = Fernet(key) data = f.decrypt(in_file.read()) out_file.write(data) - sys.stdout.write("[+] Writing file: {0} Source: {1}\n".format(dest, source)) + sys.stdout.write( + "[+] Writing file: {0} Source: {1}\n".format(dest, source) + ) sys.stdout.write("[+] Keys have been unencrypted!\n") @@ -499,15 +541,16 @@ def publish_verisign_units(): :return: """ from lemur.plugins import plugins - v = plugins.get('verisign-issuer') + + v = plugins.get("verisign-issuer") units = v.get_available_units() metrics = {} for item in units: - if item['@type'] in metrics.keys(): - metrics[item['@type']] += int(item['@remaining']) + if item["@type"] in metrics.keys(): + metrics[item["@type"]] += int(item["@remaining"]) else: - metrics.update({item['@type']: int(item['@remaining'])}) + metrics.update({item["@type"]: int(item["@remaining"])}) for name, value in metrics.items(): metric = [ @@ -516,16 +559,16 @@ def publish_verisign_units(): "type": "GAUGE", "name": "Symantec {0} Unit Count".format(name), "tags": {}, - "value": value + "value": value, } ] - requests.post('http://localhost:8078/metrics', data=json.dumps(metric)) + requests.post("http://localhost:8078/metrics", data=json.dumps(metric)) def main(): manager.add_command("start", LemurServer()) - manager.add_command("runserver", Server(host='127.0.0.1', threaded=True)) + manager.add_command("runserver", Server(host="127.0.0.1", threaded=True)) manager.add_command("clean", Clean()) manager.add_command("show_urls", ShowUrls()) manager.add_command("db", MigrateCommand) diff --git a/lemur/metrics.py b/lemur/metrics.py index 381dc605..52f8c25b 100644 --- a/lemur/metrics.py +++ b/lemur/metrics.py @@ -11,6 +11,7 @@ class Metrics(object): """ :param app: The Flask application object. Defaults to None. """ + _providers = [] def __init__(self, app=None): @@ -22,11 +23,14 @@ class Metrics(object): :param app: The Flask application object. """ - self._providers = app.config.get('METRIC_PROVIDERS', []) + self._providers = app.config.get("METRIC_PROVIDERS", []) def send(self, metric_name, metric_type, metric_value, *args, **kwargs): for provider in self._providers: current_app.logger.debug( - "Sending metric '{metric}' to the {provider} provider.".format(metric=metric_name, provider=provider)) + "Sending metric '{metric}' to the {provider} provider.".format( + metric=metric_name, provider=provider + ) + ) p = plugins.get(provider) p.submit(metric_name, metric_type, metric_value, *args, **kwargs) diff --git a/lemur/migrations/env.py b/lemur/migrations/env.py index 63425041..008a9952 100644 --- a/lemur/migrations/env.py +++ b/lemur/migrations/env.py @@ -19,8 +19,11 @@ fileConfig(config.config_file_name) # from myapp import mymodel # target_metadata = mymodel.Base.metadata from flask import current_app -config.set_main_option('sqlalchemy.url', current_app.config.get('SQLALCHEMY_DATABASE_URI')) -target_metadata = current_app.extensions['migrate'].db.metadata + +config.set_main_option( + "sqlalchemy.url", current_app.config.get("SQLALCHEMY_DATABASE_URI") +) +target_metadata = current_app.extensions["migrate"].db.metadata # other values from the config, defined by the needs of env.py, # can be acquired: @@ -54,14 +57,18 @@ def run_migrations_online(): and associate a connection with the context. """ - engine = engine_from_config(config.get_section(config.config_ini_section), - prefix='sqlalchemy.', - poolclass=pool.NullPool) + engine = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) connection = engine.connect() - context.configure(connection=connection, - target_metadata=target_metadata, - **current_app.extensions['migrate'].configure_args) + context.configure( + connection=connection, + target_metadata=target_metadata, + **current_app.extensions["migrate"].configure_args + ) try: with context.begin_transaction(): @@ -69,8 +76,8 @@ def run_migrations_online(): finally: connection.close() + if context.is_offline_mode(): run_migrations_offline() else: run_migrations_online() - diff --git a/lemur/migrations/versions/131ec6accff5_.py b/lemur/migrations/versions/131ec6accff5_.py index bddc5fe2..d5b42462 100644 --- a/lemur/migrations/versions/131ec6accff5_.py +++ b/lemur/migrations/versions/131ec6accff5_.py @@ -7,8 +7,8 @@ Create Date: 2016-12-07 17:29:42.049986 """ # revision identifiers, used by Alembic. -revision = '131ec6accff5' -down_revision = 'e3691fc396e9' +revision = "131ec6accff5" +down_revision = "e3691fc396e9" from alembic import op import sqlalchemy as sa @@ -16,13 +16,24 @@ import sqlalchemy as sa def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column('certificates', sa.Column('rotation', sa.Boolean(), nullable=False, server_default=sa.false())) - op.add_column('endpoints', sa.Column('last_updated', sa.DateTime(), server_default=sa.text('now()'), nullable=False)) + op.add_column( + "certificates", + sa.Column("rotation", sa.Boolean(), nullable=False, server_default=sa.false()), + ) + op.add_column( + "endpoints", + sa.Column( + "last_updated", + sa.DateTime(), + server_default=sa.text("now()"), + nullable=False, + ), + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('endpoints', 'last_updated') - op.drop_column('certificates', 'rotation') + op.drop_column("endpoints", "last_updated") + op.drop_column("certificates", "rotation") # ### end Alembic commands ### diff --git a/lemur/migrations/versions/1ae8e3104db8_.py b/lemur/migrations/versions/1ae8e3104db8_.py index 3cb3bb9e..9e19f0e7 100644 --- a/lemur/migrations/versions/1ae8e3104db8_.py +++ b/lemur/migrations/versions/1ae8e3104db8_.py @@ -7,15 +7,19 @@ Create Date: 2017-07-13 12:32:09.162800 """ # revision identifiers, used by Alembic. -revision = '1ae8e3104db8' -down_revision = 'a02a678ddc25' +revision = "1ae8e3104db8" +down_revision = "a02a678ddc25" from alembic import op def upgrade(): - op.sync_enum_values('public', 'log_type', ['key_view'], ['create_cert', 'key_view', 'update_cert']) + op.sync_enum_values( + "public", "log_type", ["key_view"], ["create_cert", "key_view", "update_cert"] + ) def downgrade(): - op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'update_cert'], ['key_view']) + op.sync_enum_values( + "public", "log_type", ["create_cert", "key_view", "update_cert"], ["key_view"] + ) diff --git a/lemur/migrations/versions/1db4f82bc780_.py b/lemur/migrations/versions/1db4f82bc780_.py index 2d917e2e..e6fb47f0 100644 --- a/lemur/migrations/versions/1db4f82bc780_.py +++ b/lemur/migrations/versions/1db4f82bc780_.py @@ -7,8 +7,8 @@ Create Date: 2018-08-03 12:56:44.565230 """ # revision identifiers, used by Alembic. -revision = '1db4f82bc780' -down_revision = '3adfdd6598df' +revision = "1db4f82bc780" +down_revision = "3adfdd6598df" import logging @@ -20,12 +20,14 @@ log = logging.getLogger(__name__) def upgrade(): connection = op.get_bind() - result = connection.execute("""\ + result = connection.execute( + """\ UPDATE certificates SET rotation_policy_id=(SELECT id FROM rotation_policies WHERE name='default') WHERE rotation_policy_id IS NULL RETURNING id - """) + """ + ) log.info("Filled rotation_policy for %d certificates" % result.rowcount) diff --git a/lemur/migrations/versions/29d8c8455c86_.py b/lemur/migrations/versions/29d8c8455c86_.py index f0b4749f..3a0e8717 100644 --- a/lemur/migrations/versions/29d8c8455c86_.py +++ b/lemur/migrations/versions/29d8c8455c86_.py @@ -7,8 +7,8 @@ Create Date: 2016-06-28 16:05:25.720213 """ # revision identifiers, used by Alembic. -revision = '29d8c8455c86' -down_revision = '3307381f3b88' +revision = "29d8c8455c86" +down_revision = "3307381f3b88" from alembic import op import sqlalchemy as sa @@ -17,46 +17,60 @@ from sqlalchemy.dialects import postgresql def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.create_table('ciphers', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(length=128), nullable=False), - sa.PrimaryKeyConstraint('id') + op.create_table( + "ciphers", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=128), nullable=False), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('policy', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(length=128), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "policy", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=128), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('policies_ciphers', - sa.Column('cipher_id', sa.Integer(), nullable=True), - sa.Column('policy_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['cipher_id'], ['ciphers.id'], ), - sa.ForeignKeyConstraint(['policy_id'], ['policy.id'], ) + op.create_table( + "policies_ciphers", + sa.Column("cipher_id", sa.Integer(), nullable=True), + sa.Column("policy_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["cipher_id"], ["ciphers.id"]), + sa.ForeignKeyConstraint(["policy_id"], ["policy.id"]), ) - op.create_index('policies_ciphers_ix', 'policies_ciphers', ['cipher_id', 'policy_id'], unique=False) - op.create_table('endpoints', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('owner', sa.String(length=128), nullable=True), - sa.Column('name', sa.String(length=128), nullable=True), - sa.Column('dnsname', sa.String(length=256), nullable=True), - sa.Column('type', sa.String(length=128), nullable=True), - sa.Column('active', sa.Boolean(), nullable=True), - sa.Column('port', sa.Integer(), nullable=True), - sa.Column('date_created', sa.DateTime(), server_default=sa.text(u'now()'), nullable=False), - sa.Column('policy_id', sa.Integer(), nullable=True), - sa.Column('certificate_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ), - sa.ForeignKeyConstraint(['policy_id'], ['policy.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + "policies_ciphers_ix", + "policies_ciphers", + ["cipher_id", "policy_id"], + unique=False, + ) + op.create_table( + "endpoints", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("owner", sa.String(length=128), nullable=True), + sa.Column("name", sa.String(length=128), nullable=True), + sa.Column("dnsname", sa.String(length=256), nullable=True), + sa.Column("type", sa.String(length=128), nullable=True), + sa.Column("active", sa.Boolean(), nullable=True), + sa.Column("port", sa.Integer(), nullable=True), + sa.Column( + "date_created", + sa.DateTime(), + server_default=sa.text(u"now()"), + nullable=False, + ), + sa.Column("policy_id", sa.Integer(), nullable=True), + sa.Column("certificate_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]), + sa.ForeignKeyConstraint(["policy_id"], ["policy.id"]), + sa.PrimaryKeyConstraint("id"), ) ### end Alembic commands ### def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_table('endpoints') - op.drop_index('policies_ciphers_ix', table_name='policies_ciphers') - op.drop_table('policies_ciphers') - op.drop_table('policy') - op.drop_table('ciphers') + op.drop_table("endpoints") + op.drop_index("policies_ciphers_ix", table_name="policies_ciphers") + op.drop_table("policies_ciphers") + op.drop_table("policy") + op.drop_table("ciphers") ### end Alembic commands ### diff --git a/lemur/migrations/versions/318b66568358_.py b/lemur/migrations/versions/318b66568358_.py new file mode 100644 index 00000000..8578cd78 --- /dev/null +++ b/lemur/migrations/versions/318b66568358_.py @@ -0,0 +1,23 @@ +""" Set 'deleted' flag from null to false on all certificates once + +Revision ID: 318b66568358 +Revises: 9f79024fe67b +Create Date: 2019-02-05 15:42:25.477587 + +""" + +# revision identifiers, used by Alembic. +revision = "318b66568358" +down_revision = "9f79024fe67b" + +from alembic import op + + +def upgrade(): + connection = op.get_bind() + # Delete duplicate entries + connection.execute("UPDATE certificates SET deleted = false WHERE deleted IS NULL") + + +def downgrade(): + pass diff --git a/lemur/migrations/versions/3307381f3b88_.py b/lemur/migrations/versions/3307381f3b88_.py index e4da96a6..2af0448b 100644 --- a/lemur/migrations/versions/3307381f3b88_.py +++ b/lemur/migrations/versions/3307381f3b88_.py @@ -12,8 +12,8 @@ Create Date: 2016-05-20 17:33:04.360687 """ # revision identifiers, used by Alembic. -revision = '3307381f3b88' -down_revision = '412b22cb656a' +revision = "3307381f3b88" +down_revision = "412b22cb656a" from alembic import op import sqlalchemy as sa @@ -23,109 +23,165 @@ from sqlalchemy.dialects import postgresql def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.alter_column('authorities', 'owner', - existing_type=sa.VARCHAR(length=128), - nullable=True) - op.drop_column('authorities', 'not_after') - op.drop_column('authorities', 'bits') - op.drop_column('authorities', 'cn') - op.drop_column('authorities', 'not_before') - op.add_column('certificates', sa.Column('root_authority_id', sa.Integer(), nullable=True)) - op.alter_column('certificates', 'body', - existing_type=sa.TEXT(), - nullable=False) - op.alter_column('certificates', 'owner', - existing_type=sa.VARCHAR(length=128), - nullable=True) - op.drop_constraint(u'certificates_authority_id_fkey', 'certificates', type_='foreignkey') - op.create_foreign_key(None, 'certificates', 'authorities', ['authority_id'], ['id'], ondelete='CASCADE') - op.create_foreign_key(None, 'certificates', 'authorities', ['root_authority_id'], ['id'], ondelete='CASCADE') + op.alter_column( + "authorities", "owner", existing_type=sa.VARCHAR(length=128), nullable=True + ) + op.drop_column("authorities", "not_after") + op.drop_column("authorities", "bits") + op.drop_column("authorities", "cn") + op.drop_column("authorities", "not_before") + op.add_column( + "certificates", sa.Column("root_authority_id", sa.Integer(), nullable=True) + ) + op.alter_column("certificates", "body", existing_type=sa.TEXT(), nullable=False) + op.alter_column( + "certificates", "owner", existing_type=sa.VARCHAR(length=128), nullable=True + ) + op.drop_constraint( + u"certificates_authority_id_fkey", "certificates", type_="foreignkey" + ) + op.create_foreign_key( + None, + "certificates", + "authorities", + ["authority_id"], + ["id"], + ondelete="CASCADE", + ) + op.create_foreign_key( + None, + "certificates", + "authorities", + ["root_authority_id"], + ["id"], + ondelete="CASCADE", + ) ### end Alembic commands ### # link existing certificate to their authority certificates conn = op.get_bind() - for id, body, owner in conn.execute(text('select id, body, owner from authorities')): + for id, body, owner in conn.execute( + text("select id, body, owner from authorities") + ): if not owner: owner = "lemur@nobody" # look up certificate by body, if duplications are found, pick one - stmt = text('select id from certificates where body=:body') + stmt = text("select id from certificates where body=:body") stmt = stmt.bindparams(body=body) root_certificate = conn.execute(stmt).fetchone() if root_certificate: - stmt = text('update certificates set root_authority_id=:root_authority_id where id=:id') + stmt = text( + "update certificates set root_authority_id=:root_authority_id where id=:id" + ) stmt = stmt.bindparams(root_authority_id=id, id=root_certificate[0]) op.execute(stmt) # link owner roles to their authorities - stmt = text('select id from roles where name=:name') + stmt = text("select id from roles where name=:name") stmt = stmt.bindparams(name=owner) owner_role = conn.execute(stmt).fetchone() if not owner_role: - stmt = text('insert into roles (name, description) values (:name, :description)') - stmt = stmt.bindparams(name=owner, description='Lemur generated role or existing owner.') + stmt = text( + "insert into roles (name, description) values (:name, :description)" + ) + stmt = stmt.bindparams( + name=owner, description="Lemur generated role or existing owner." + ) op.execute(stmt) - stmt = text('select id from roles where name=:name') + stmt = text("select id from roles where name=:name") stmt = stmt.bindparams(name=owner) owner_role = conn.execute(stmt).fetchone() - stmt = text('select * from roles_authorities where role_id=:role_id and authority_id=:authority_id') + stmt = text( + "select * from roles_authorities where role_id=:role_id and authority_id=:authority_id" + ) stmt = stmt.bindparams(role_id=owner_role[0], authority_id=id) exists = conn.execute(stmt).fetchone() if not exists: - stmt = text('insert into roles_authorities (role_id, authority_id) values (:role_id, :authority_id)') + stmt = text( + "insert into roles_authorities (role_id, authority_id) values (:role_id, :authority_id)" + ) stmt = stmt.bindparams(role_id=owner_role[0], authority_id=id) op.execute(stmt) # link owner roles to their certificates - for id, owner in conn.execute(text('select id, owner from certificates')): + for id, owner in conn.execute(text("select id, owner from certificates")): if not owner: owner = "lemur@nobody" - stmt = text('select id from roles where name=:name') + stmt = text("select id from roles where name=:name") stmt = stmt.bindparams(name=owner) owner_role = conn.execute(stmt).fetchone() if not owner_role: - stmt = text('insert into roles (name, description) values (:name, :description)') - stmt = stmt.bindparams(name=owner, description='Lemur generated role or existing owner.') + stmt = text( + "insert into roles (name, description) values (:name, :description)" + ) + stmt = stmt.bindparams( + name=owner, description="Lemur generated role or existing owner." + ) op.execute(stmt) # link owner roles to their authorities - stmt = text('select id from roles where name=:name') + stmt = text("select id from roles where name=:name") stmt = stmt.bindparams(name=owner) owner_role = conn.execute(stmt).fetchone() - stmt = text('select * from roles_certificates where role_id=:role_id and certificate_id=:certificate_id') + stmt = text( + "select * from roles_certificates where role_id=:role_id and certificate_id=:certificate_id" + ) stmt = stmt.bindparams(role_id=owner_role[0], certificate_id=id) exists = conn.execute(stmt).fetchone() if not exists: - stmt = text('insert into roles_certificates (role_id, certificate_id) values (:role_id, :certificate_id)') + stmt = text( + "insert into roles_certificates (role_id, certificate_id) values (:role_id, :certificate_id)" + ) stmt = stmt.bindparams(role_id=owner_role[0], certificate_id=id) op.execute(stmt) def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, 'certificates', type_='foreignkey') - op.drop_constraint(None, 'certificates', type_='foreignkey') - op.create_foreign_key(u'certificates_authority_id_fkey', 'certificates', 'authorities', ['authority_id'], ['id']) - op.alter_column('certificates', 'owner', - existing_type=sa.VARCHAR(length=128), - nullable=True) - op.alter_column('certificates', 'body', - existing_type=sa.TEXT(), - nullable=True) - op.drop_column('certificates', 'root_authority_id') - op.add_column('authorities', sa.Column('not_before', postgresql.TIMESTAMP(), autoincrement=False, nullable=True)) - op.add_column('authorities', sa.Column('cn', sa.VARCHAR(length=128), autoincrement=False, nullable=True)) - op.add_column('authorities', sa.Column('bits', sa.INTEGER(), autoincrement=False, nullable=True)) - op.add_column('authorities', sa.Column('not_after', postgresql.TIMESTAMP(), autoincrement=False, nullable=True)) - op.alter_column('authorities', 'owner', - existing_type=sa.VARCHAR(length=128), - nullable=True) + op.drop_constraint(None, "certificates", type_="foreignkey") + op.drop_constraint(None, "certificates", type_="foreignkey") + op.create_foreign_key( + u"certificates_authority_id_fkey", + "certificates", + "authorities", + ["authority_id"], + ["id"], + ) + op.alter_column( + "certificates", "owner", existing_type=sa.VARCHAR(length=128), nullable=True + ) + op.alter_column("certificates", "body", existing_type=sa.TEXT(), nullable=True) + op.drop_column("certificates", "root_authority_id") + op.add_column( + "authorities", + sa.Column( + "not_before", postgresql.TIMESTAMP(), autoincrement=False, nullable=True + ), + ) + op.add_column( + "authorities", + sa.Column("cn", sa.VARCHAR(length=128), autoincrement=False, nullable=True), + ) + op.add_column( + "authorities", + sa.Column("bits", sa.INTEGER(), autoincrement=False, nullable=True), + ) + op.add_column( + "authorities", + sa.Column( + "not_after", postgresql.TIMESTAMP(), autoincrement=False, nullable=True + ), + ) + op.alter_column( + "authorities", "owner", existing_type=sa.VARCHAR(length=128), nullable=True + ) ### end Alembic commands ### diff --git a/lemur/migrations/versions/33de094da890_.py b/lemur/migrations/versions/33de094da890_.py index 76624e96..718e908f 100644 --- a/lemur/migrations/versions/33de094da890_.py +++ b/lemur/migrations/versions/33de094da890_.py @@ -7,25 +7,31 @@ Create Date: 2015-11-30 15:40:19.827272 """ # revision identifiers, used by Alembic. -revision = '33de094da890' +revision = "33de094da890" down_revision = None from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql + def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.create_table('certificate_replacement_associations', - sa.Column('replaced_certificate_id', sa.Integer(), nullable=True), - sa.Column('certificate_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ondelete='cascade'), - sa.ForeignKeyConstraint(['replaced_certificate_id'], ['certificates.id'], ondelete='cascade') + op.create_table( + "certificate_replacement_associations", + sa.Column("replaced_certificate_id", sa.Integer(), nullable=True), + sa.Column("certificate_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["certificate_id"], ["certificates.id"], ondelete="cascade" + ), + sa.ForeignKeyConstraint( + ["replaced_certificate_id"], ["certificates.id"], ondelete="cascade" + ), ) ### end Alembic commands ### def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_table('certificate_replacement_associations') + op.drop_table("certificate_replacement_associations") ### end Alembic commands ### diff --git a/lemur/migrations/versions/3adfdd6598df_.py b/lemur/migrations/versions/3adfdd6598df_.py index 1f290153..7f587f49 100644 --- a/lemur/migrations/versions/3adfdd6598df_.py +++ b/lemur/migrations/versions/3adfdd6598df_.py @@ -7,8 +7,8 @@ Create Date: 2018-04-10 13:25:47.007556 """ # revision identifiers, used by Alembic. -revision = '3adfdd6598df' -down_revision = '556ceb3e3c3e' +revision = "3adfdd6598df" +down_revision = "556ceb3e3c3e" import sqlalchemy as sa from alembic import op @@ -22,84 +22,90 @@ def upgrade(): # create provider table print("Creating dns_providers table") op.create_table( - 'dns_providers', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(length=256), nullable=True), - sa.Column('description', sa.String(length=1024), nullable=True), - sa.Column('provider_type', sa.String(length=256), nullable=True), - sa.Column('credentials', Vault(), nullable=True), - sa.Column('api_endpoint', sa.String(length=256), nullable=True), - sa.Column('date_created', ArrowType(), server_default=sa.text('now()'), nullable=False), - sa.Column('status', sa.String(length=128), nullable=True), - sa.Column('options', JSON), - sa.Column('domains', sa.JSON(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('name') + "dns_providers", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=256), nullable=True), + sa.Column("description", sa.String(length=1024), nullable=True), + sa.Column("provider_type", sa.String(length=256), nullable=True), + sa.Column("credentials", Vault(), nullable=True), + sa.Column("api_endpoint", sa.String(length=256), nullable=True), + sa.Column( + "date_created", ArrowType(), server_default=sa.text("now()"), nullable=False + ), + sa.Column("status", sa.String(length=128), nullable=True), + sa.Column("options", JSON), + sa.Column("domains", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), ) print("Adding dns_provider_id column to certificates") - op.add_column('certificates', sa.Column('dns_provider_id', sa.Integer(), nullable=True)) + op.add_column( + "certificates", sa.Column("dns_provider_id", sa.Integer(), nullable=True) + ) print("Adding dns_provider_id column to pending_certs") - op.add_column('pending_certs', sa.Column('dns_provider_id', sa.Integer(), nullable=True)) + op.add_column( + "pending_certs", sa.Column("dns_provider_id", sa.Integer(), nullable=True) + ) print("Adding options column to pending_certs") - op.add_column('pending_certs', sa.Column('options', JSON)) + op.add_column("pending_certs", sa.Column("options", JSON)) print("Creating pending_dns_authorizations table") op.create_table( - 'pending_dns_authorizations', - sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True), - sa.Column('account_number', sa.String(length=128), nullable=True), - sa.Column('domains', JSON, nullable=True), - sa.Column('dns_provider_type', sa.String(length=128), nullable=True), - sa.Column('options', JSON, nullable=True), + "pending_dns_authorizations", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("account_number", sa.String(length=128), nullable=True), + sa.Column("domains", JSON, nullable=True), + sa.Column("dns_provider_type", sa.String(length=128), nullable=True), + sa.Column("options", JSON, nullable=True), ) print("Creating certificates_dns_providers_fk foreign key") - op.create_foreign_key('certificates_dns_providers_fk', 'certificates', 'dns_providers', ['dns_provider_id'], ['id'], - ondelete='cascade') + op.create_foreign_key( + "certificates_dns_providers_fk", + "certificates", + "dns_providers", + ["dns_provider_id"], + ["id"], + ondelete="cascade", + ) print("Altering column types in the api_keys table") - op.alter_column('api_keys', 'issued_at', - existing_type=sa.BIGINT(), - nullable=True) - op.alter_column('api_keys', 'revoked', - existing_type=sa.BOOLEAN(), - nullable=True) - op.alter_column('api_keys', 'ttl', - existing_type=sa.BIGINT(), - nullable=True) - op.alter_column('api_keys', 'user_id', - existing_type=sa.INTEGER(), - nullable=True) + op.alter_column("api_keys", "issued_at", existing_type=sa.BIGINT(), nullable=True) + op.alter_column("api_keys", "revoked", existing_type=sa.BOOLEAN(), nullable=True) + op.alter_column("api_keys", "ttl", existing_type=sa.BIGINT(), nullable=True) + op.alter_column("api_keys", "user_id", existing_type=sa.INTEGER(), nullable=True) print("Creating dns_providers_id foreign key on pending_certs table") - op.create_foreign_key(None, 'pending_certs', 'dns_providers', ['dns_provider_id'], ['id'], ondelete='CASCADE') + op.create_foreign_key( + None, + "pending_certs", + "dns_providers", + ["dns_provider_id"], + ["id"], + ondelete="CASCADE", + ) + def downgrade(): print("Removing dns_providers_id foreign key on pending_certs table") - op.drop_constraint(None, 'pending_certs', type_='foreignkey') + op.drop_constraint(None, "pending_certs", type_="foreignkey") print("Reverting column types in the api_keys table") - op.alter_column('api_keys', 'user_id', - existing_type=sa.INTEGER(), - nullable=False) - op.alter_column('api_keys', 'ttl', - existing_type=sa.BIGINT(), - nullable=False) - op.alter_column('api_keys', 'revoked', - existing_type=sa.BOOLEAN(), - nullable=False) - op.alter_column('api_keys', 'issued_at', - existing_type=sa.BIGINT(), - nullable=False) + op.alter_column("api_keys", "user_id", existing_type=sa.INTEGER(), nullable=False) + op.alter_column("api_keys", "ttl", existing_type=sa.BIGINT(), nullable=False) + op.alter_column("api_keys", "revoked", existing_type=sa.BOOLEAN(), nullable=False) + op.alter_column("api_keys", "issued_at", existing_type=sa.BIGINT(), nullable=False) print("Reverting certificates_dns_providers_fk foreign key") - op.drop_constraint('certificates_dns_providers_fk', 'certificates', type_='foreignkey') + op.drop_constraint( + "certificates_dns_providers_fk", "certificates", type_="foreignkey" + ) print("Dropping pending_dns_authorizations table") - op.drop_table('pending_dns_authorizations') + op.drop_table("pending_dns_authorizations") print("Undoing modifications to pending_certs table") - op.drop_column('pending_certs', 'options') - op.drop_column('pending_certs', 'dns_provider_id') + op.drop_column("pending_certs", "options") + op.drop_column("pending_certs", "dns_provider_id") print("Undoing modifications to certificates table") - op.drop_column('certificates', 'dns_provider_id') + op.drop_column("certificates", "dns_provider_id") print("Deleting dns_providers table") - op.drop_table('dns_providers') + op.drop_table("dns_providers") diff --git a/lemur/migrations/versions/412b22cb656a_.py b/lemur/migrations/versions/412b22cb656a_.py index d95ec701..c24ddfba 100644 --- a/lemur/migrations/versions/412b22cb656a_.py +++ b/lemur/migrations/versions/412b22cb656a_.py @@ -7,8 +7,8 @@ Create Date: 2016-05-17 17:37:41.210232 """ # revision identifiers, used by Alembic. -revision = '412b22cb656a' -down_revision = '4c50b903d1ae' +revision = "412b22cb656a" +down_revision = "4c50b903d1ae" from alembic import op import sqlalchemy as sa @@ -17,47 +17,102 @@ from sqlalchemy.sql import text def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.create_table('roles_authorities', - sa.Column('authority_id', sa.Integer(), nullable=True), - sa.Column('role_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['authority_id'], ['authorities.id'], ), - sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ) + op.create_table( + "roles_authorities", + sa.Column("authority_id", sa.Integer(), nullable=True), + sa.Column("role_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["authority_id"], ["authorities.id"]), + sa.ForeignKeyConstraint(["role_id"], ["roles.id"]), ) - op.create_index('roles_authorities_ix', 'roles_authorities', ['authority_id', 'role_id'], unique=True) - op.create_table('roles_certificates', - sa.Column('certificate_id', sa.Integer(), nullable=True), - sa.Column('role_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ), - sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ) + op.create_index( + "roles_authorities_ix", + "roles_authorities", + ["authority_id", "role_id"], + unique=True, + ) + op.create_table( + "roles_certificates", + sa.Column("certificate_id", sa.Integer(), nullable=True), + sa.Column("role_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]), + sa.ForeignKeyConstraint(["role_id"], ["roles.id"]), + ) + op.create_index( + "roles_certificates_ix", + "roles_certificates", + ["certificate_id", "role_id"], + unique=True, + ) + op.create_index( + "certificate_associations_ix", + "certificate_associations", + ["domain_id", "certificate_id"], + unique=True, + ) + op.create_index( + "certificate_destination_associations_ix", + "certificate_destination_associations", + ["destination_id", "certificate_id"], + unique=True, + ) + op.create_index( + "certificate_notification_associations_ix", + "certificate_notification_associations", + ["notification_id", "certificate_id"], + unique=True, + ) + op.create_index( + "certificate_replacement_associations_ix", + "certificate_replacement_associations", + ["certificate_id", "certificate_id"], + unique=True, + ) + op.create_index( + "certificate_source_associations_ix", + "certificate_source_associations", + ["source_id", "certificate_id"], + unique=True, + ) + op.create_index( + "roles_users_ix", "roles_users", ["user_id", "role_id"], unique=True ) - op.create_index('roles_certificates_ix', 'roles_certificates', ['certificate_id', 'role_id'], unique=True) - op.create_index('certificate_associations_ix', 'certificate_associations', ['domain_id', 'certificate_id'], unique=True) - op.create_index('certificate_destination_associations_ix', 'certificate_destination_associations', ['destination_id', 'certificate_id'], unique=True) - op.create_index('certificate_notification_associations_ix', 'certificate_notification_associations', ['notification_id', 'certificate_id'], unique=True) - op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['certificate_id', 'certificate_id'], unique=True) - op.create_index('certificate_source_associations_ix', 'certificate_source_associations', ['source_id', 'certificate_id'], unique=True) - op.create_index('roles_users_ix', 'roles_users', ['user_id', 'role_id'], unique=True) ### end Alembic commands ### # migrate existing authority_id relationship to many_to_many conn = op.get_bind() - for id, authority_id in conn.execute(text('select id, authority_id from roles where authority_id is not null')): - stmt = text('insert into roles_authoritties (role_id, authority_id) values (:role_id, :authority_id)') + for id, authority_id in conn.execute( + text("select id, authority_id from roles where authority_id is not null") + ): + stmt = text( + "insert into roles_authoritties (role_id, authority_id) values (:role_id, :authority_id)" + ) stmt = stmt.bindparams(role_id=id, authority_id=authority_id) op.execute(stmt) def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_index('roles_users_ix', table_name='roles_users') - op.drop_index('certificate_source_associations_ix', table_name='certificate_source_associations') - op.drop_index('certificate_replacement_associations_ix', table_name='certificate_replacement_associations') - op.drop_index('certificate_notification_associations_ix', table_name='certificate_notification_associations') - op.drop_index('certificate_destination_associations_ix', table_name='certificate_destination_associations') - op.drop_index('certificate_associations_ix', table_name='certificate_associations') - op.drop_index('roles_certificates_ix', table_name='roles_certificates') - op.drop_table('roles_certificates') - op.drop_index('roles_authorities_ix', table_name='roles_authorities') - op.drop_table('roles_authorities') + op.drop_index("roles_users_ix", table_name="roles_users") + op.drop_index( + "certificate_source_associations_ix", + table_name="certificate_source_associations", + ) + op.drop_index( + "certificate_replacement_associations_ix", + table_name="certificate_replacement_associations", + ) + op.drop_index( + "certificate_notification_associations_ix", + table_name="certificate_notification_associations", + ) + op.drop_index( + "certificate_destination_associations_ix", + table_name="certificate_destination_associations", + ) + op.drop_index("certificate_associations_ix", table_name="certificate_associations") + op.drop_index("roles_certificates_ix", table_name="roles_certificates") + op.drop_table("roles_certificates") + op.drop_index("roles_authorities_ix", table_name="roles_authorities") + op.drop_table("roles_authorities") ### end Alembic commands ### diff --git a/lemur/migrations/versions/449c3d5c7299_.py b/lemur/migrations/versions/449c3d5c7299_.py index 1dcb7ab5..f33548da 100644 --- a/lemur/migrations/versions/449c3d5c7299_.py +++ b/lemur/migrations/versions/449c3d5c7299_.py @@ -7,8 +7,8 @@ Create Date: 2018-02-24 22:51:35.369229 """ # revision identifiers, used by Alembic. -revision = '449c3d5c7299' -down_revision = '5770674184de' +revision = "449c3d5c7299" +down_revision = "5770674184de" from alembic import op from flask_sqlalchemy import SQLAlchemy @@ -21,6 +21,16 @@ COLUMNS = ["notification_id", "certificate_id"] def upgrade(): + connection = op.get_bind() + # Delete duplicate entries + connection.execute( + """\ + DELETE FROM certificate_notification_associations WHERE ctid NOT IN ( + -- Select the first tuple ID for each (notification_id, certificate_id) combination and keep that + SELECT min(ctid) FROM certificate_notification_associations GROUP BY notification_id, certificate_id + ) + """ + ) op.create_unique_constraint(CONSTRAINT_NAME, TABLE, COLUMNS) diff --git a/lemur/migrations/versions/4c50b903d1ae_.py b/lemur/migrations/versions/4c50b903d1ae_.py index 7b0515d4..93d4a312 100644 --- a/lemur/migrations/versions/4c50b903d1ae_.py +++ b/lemur/migrations/versions/4c50b903d1ae_.py @@ -7,20 +7,21 @@ Create Date: 2015-12-30 10:19:30.057791 """ # revision identifiers, used by Alembic. -revision = '4c50b903d1ae' -down_revision = '33de094da890' +revision = "4c50b903d1ae" +down_revision = "33de094da890" from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql + def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.add_column('domains', sa.Column('sensitive', sa.Boolean(), nullable=True)) + op.add_column("domains", sa.Column("sensitive", sa.Boolean(), nullable=True)) ### end Alembic commands ### def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_column('domains', 'sensitive') + op.drop_column("domains", "sensitive") ### end Alembic commands ### diff --git a/lemur/migrations/versions/556ceb3e3c3e_.py b/lemur/migrations/versions/556ceb3e3c3e_.py index 2916c0eb..60304138 100644 --- a/lemur/migrations/versions/556ceb3e3c3e_.py +++ b/lemur/migrations/versions/556ceb3e3c3e_.py @@ -7,8 +7,8 @@ Create Date: 2018-01-05 01:18:45.571595 """ # revision identifiers, used by Alembic. -revision = '556ceb3e3c3e' -down_revision = '449c3d5c7299' +revision = "556ceb3e3c3e" +down_revision = "449c3d5c7299" from alembic import op import sqlalchemy as sa @@ -16,84 +16,150 @@ from lemur.utils import Vault from sqlalchemy.dialects import postgresql from sqlalchemy_utils import ArrowType + def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.create_table('pending_certs', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('external_id', sa.String(length=128), nullable=True), - sa.Column('owner', sa.String(length=128), nullable=False), - sa.Column('name', sa.String(length=256), nullable=True), - sa.Column('description', sa.String(length=1024), nullable=True), - sa.Column('notify', sa.Boolean(), nullable=True), - sa.Column('number_attempts', sa.Integer(), nullable=True), - sa.Column('rename', sa.Boolean(), nullable=True), - sa.Column('cn', sa.String(length=128), nullable=True), - sa.Column('csr', sa.Text(), nullable=False), - sa.Column('chain', sa.Text(), nullable=True), - sa.Column('private_key', Vault(), nullable=True), - sa.Column('date_created', ArrowType(), server_default=sa.text('now()'), nullable=False), - sa.Column('status', sa.String(length=128), nullable=True), - sa.Column('rotation', sa.Boolean(), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.Column('authority_id', sa.Integer(), nullable=True), - sa.Column('root_authority_id', sa.Integer(), nullable=True), - sa.Column('rotation_policy_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['authority_id'], ['authorities.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['root_authority_id'], ['authorities.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['rotation_policy_id'], ['rotation_policies.id'], ), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('name') + op.create_table( + "pending_certs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("external_id", sa.String(length=128), nullable=True), + sa.Column("owner", sa.String(length=128), nullable=False), + sa.Column("name", sa.String(length=256), nullable=True), + sa.Column("description", sa.String(length=1024), nullable=True), + sa.Column("notify", sa.Boolean(), nullable=True), + sa.Column("number_attempts", sa.Integer(), nullable=True), + sa.Column("rename", sa.Boolean(), nullable=True), + sa.Column("cn", sa.String(length=128), nullable=True), + sa.Column("csr", sa.Text(), nullable=False), + sa.Column("chain", sa.Text(), nullable=True), + sa.Column("private_key", Vault(), nullable=True), + sa.Column( + "date_created", ArrowType(), server_default=sa.text("now()"), nullable=False + ), + sa.Column("status", sa.String(length=128), nullable=True), + sa.Column("rotation", sa.Boolean(), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("authority_id", sa.Integer(), nullable=True), + sa.Column("root_authority_id", sa.Integer(), nullable=True), + sa.Column("rotation_policy_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["authority_id"], ["authorities.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["root_authority_id"], ["authorities.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["rotation_policy_id"], ["rotation_policies.id"]), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), ) - op.create_table('pending_cert_destination_associations', - sa.Column('destination_id', sa.Integer(), nullable=True), - sa.Column('pending_cert_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['destination_id'], ['destinations.id'], ondelete='cascade'), - sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade') + op.create_table( + "pending_cert_destination_associations", + sa.Column("destination_id", sa.Integer(), nullable=True), + sa.Column("pending_cert_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["destination_id"], ["destinations.id"], ondelete="cascade" + ), + sa.ForeignKeyConstraint( + ["pending_cert_id"], ["pending_certs.id"], ondelete="cascade" + ), ) - op.create_index('pending_cert_destination_associations_ix', 'pending_cert_destination_associations', ['destination_id', 'pending_cert_id'], unique=False) - op.create_table('pending_cert_notification_associations', - sa.Column('notification_id', sa.Integer(), nullable=True), - sa.Column('pending_cert_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['notification_id'], ['notifications.id'], ondelete='cascade'), - sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade') + op.create_index( + "pending_cert_destination_associations_ix", + "pending_cert_destination_associations", + ["destination_id", "pending_cert_id"], + unique=False, ) - op.create_index('pending_cert_notification_associations_ix', 'pending_cert_notification_associations', ['notification_id', 'pending_cert_id'], unique=False) - op.create_table('pending_cert_replacement_associations', - sa.Column('replaced_certificate_id', sa.Integer(), nullable=True), - sa.Column('pending_cert_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade'), - sa.ForeignKeyConstraint(['replaced_certificate_id'], ['certificates.id'], ondelete='cascade') + op.create_table( + "pending_cert_notification_associations", + sa.Column("notification_id", sa.Integer(), nullable=True), + sa.Column("pending_cert_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["notification_id"], ["notifications.id"], ondelete="cascade" + ), + sa.ForeignKeyConstraint( + ["pending_cert_id"], ["pending_certs.id"], ondelete="cascade" + ), ) - op.create_index('pending_cert_replacement_associations_ix', 'pending_cert_replacement_associations', ['replaced_certificate_id', 'pending_cert_id'], unique=False) - op.create_table('pending_cert_role_associations', - sa.Column('pending_cert_id', sa.Integer(), nullable=True), - sa.Column('role_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ), - sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ) + op.create_index( + "pending_cert_notification_associations_ix", + "pending_cert_notification_associations", + ["notification_id", "pending_cert_id"], + unique=False, ) - op.create_index('pending_cert_role_associations_ix', 'pending_cert_role_associations', ['pending_cert_id', 'role_id'], unique=False) - op.create_table('pending_cert_source_associations', - sa.Column('source_id', sa.Integer(), nullable=True), - sa.Column('pending_cert_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade'), - sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='cascade') + op.create_table( + "pending_cert_replacement_associations", + sa.Column("replaced_certificate_id", sa.Integer(), nullable=True), + sa.Column("pending_cert_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["pending_cert_id"], ["pending_certs.id"], ondelete="cascade" + ), + sa.ForeignKeyConstraint( + ["replaced_certificate_id"], ["certificates.id"], ondelete="cascade" + ), + ) + op.create_index( + "pending_cert_replacement_associations_ix", + "pending_cert_replacement_associations", + ["replaced_certificate_id", "pending_cert_id"], + unique=False, + ) + op.create_table( + "pending_cert_role_associations", + sa.Column("pending_cert_id", sa.Integer(), nullable=True), + sa.Column("role_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["pending_cert_id"], ["pending_certs.id"]), + sa.ForeignKeyConstraint(["role_id"], ["roles.id"]), + ) + op.create_index( + "pending_cert_role_associations_ix", + "pending_cert_role_associations", + ["pending_cert_id", "role_id"], + unique=False, + ) + op.create_table( + "pending_cert_source_associations", + sa.Column("source_id", sa.Integer(), nullable=True), + sa.Column("pending_cert_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["pending_cert_id"], ["pending_certs.id"], ondelete="cascade" + ), + sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="cascade"), + ) + op.create_index( + "pending_cert_source_associations_ix", + "pending_cert_source_associations", + ["source_id", "pending_cert_id"], + unique=False, ) - op.create_index('pending_cert_source_associations_ix', 'pending_cert_source_associations', ['source_id', 'pending_cert_id'], unique=False) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_index('pending_cert_source_associations_ix', table_name='pending_cert_source_associations') - op.drop_table('pending_cert_source_associations') - op.drop_index('pending_cert_role_associations_ix', table_name='pending_cert_role_associations') - op.drop_table('pending_cert_role_associations') - op.drop_index('pending_cert_replacement_associations_ix', table_name='pending_cert_replacement_associations') - op.drop_table('pending_cert_replacement_associations') - op.drop_index('pending_cert_notification_associations_ix', table_name='pending_cert_notification_associations') - op.drop_table('pending_cert_notification_associations') - op.drop_index('pending_cert_destination_associations_ix', table_name='pending_cert_destination_associations') - op.drop_table('pending_cert_destination_associations') - op.drop_table('pending_certs') + op.drop_index( + "pending_cert_source_associations_ix", + table_name="pending_cert_source_associations", + ) + op.drop_table("pending_cert_source_associations") + op.drop_index( + "pending_cert_role_associations_ix", table_name="pending_cert_role_associations" + ) + op.drop_table("pending_cert_role_associations") + op.drop_index( + "pending_cert_replacement_associations_ix", + table_name="pending_cert_replacement_associations", + ) + op.drop_table("pending_cert_replacement_associations") + op.drop_index( + "pending_cert_notification_associations_ix", + table_name="pending_cert_notification_associations", + ) + op.drop_table("pending_cert_notification_associations") + op.drop_index( + "pending_cert_destination_associations_ix", + table_name="pending_cert_destination_associations", + ) + op.drop_table("pending_cert_destination_associations") + op.drop_table("pending_certs") # ### end Alembic commands ### diff --git a/lemur/migrations/versions/5770674184de_.py b/lemur/migrations/versions/5770674184de_.py index 88262a84..49d89367 100644 --- a/lemur/migrations/versions/5770674184de_.py +++ b/lemur/migrations/versions/5770674184de_.py @@ -7,8 +7,8 @@ Create Date: 2018-02-23 15:27:30.335435 """ # revision identifiers, used by Alembic. -revision = '5770674184de' -down_revision = 'ce547319f7be' +revision = "5770674184de" +down_revision = "ce547319f7be" from flask_sqlalchemy import SQLAlchemy from lemur.models import certificate_notification_associations @@ -32,7 +32,9 @@ def upgrade(): # If we've seen a pair already, delete the duplicates if seen.get("{}-{}".format(x.certificate_id, x.notification_id)): print("Deleting duplicate: {}".format(x)) - d = session.query(certificate_notification_associations).filter(certificate_notification_associations.c.id==x.id) + d = session.query(certificate_notification_associations).filter( + certificate_notification_associations.c.id == x.id + ) d.delete(synchronize_session=False) seen["{}-{}".format(x.certificate_id, x.notification_id)] = True db.session.commit() diff --git a/lemur/migrations/versions/5ae0ecefb01f_.py b/lemur/migrations/versions/5ae0ecefb01f_.py index a471c4bf..7b0d5ae0 100644 --- a/lemur/migrations/versions/5ae0ecefb01f_.py +++ b/lemur/migrations/versions/5ae0ecefb01f_.py @@ -7,8 +7,8 @@ Create Date: 2018-08-14 08:16:43.329316 """ # revision identifiers, used by Alembic. -revision = '5ae0ecefb01f' -down_revision = '1db4f82bc780' +revision = "5ae0ecefb01f" +down_revision = "1db4f82bc780" from alembic import op import sqlalchemy as sa @@ -16,17 +16,14 @@ import sqlalchemy as sa def upgrade(): op.alter_column( - table_name='pending_certs', - column_name='status', - nullable=True, - type_=sa.TEXT() + table_name="pending_certs", column_name="status", nullable=True, type_=sa.TEXT() ) def downgrade(): op.alter_column( - table_name='pending_certs', - column_name='status', + table_name="pending_certs", + column_name="status", nullable=True, - type_=sa.VARCHAR(128) + type_=sa.VARCHAR(128), ) diff --git a/lemur/migrations/versions/5bc47fa7cac4_.py b/lemur/migrations/versions/5bc47fa7cac4_.py index f4a145c8..f786c527 100644 --- a/lemur/migrations/versions/5bc47fa7cac4_.py +++ b/lemur/migrations/versions/5bc47fa7cac4_.py @@ -7,16 +7,18 @@ Create Date: 2017-12-08 14:19:11.903864 """ # revision identifiers, used by Alembic. -revision = '5bc47fa7cac4' -down_revision = 'c05a8998b371' +revision = "5bc47fa7cac4" +down_revision = "c05a8998b371" from alembic import op import sqlalchemy as sa def upgrade(): - op.add_column('roles', sa.Column('third_party', sa.Boolean(), nullable=True, default=False)) + op.add_column( + "roles", sa.Column("third_party", sa.Boolean(), nullable=True, default=False) + ) def downgrade(): - op.drop_column('roles', 'third_party') + op.drop_column("roles", "third_party") diff --git a/lemur/migrations/versions/5e680529b666_.py b/lemur/migrations/versions/5e680529b666_.py index d59d996f..4cca4521 100644 --- a/lemur/migrations/versions/5e680529b666_.py +++ b/lemur/migrations/versions/5e680529b666_.py @@ -7,20 +7,20 @@ Create Date: 2017-01-26 05:05:25.168125 """ # revision identifiers, used by Alembic. -revision = '5e680529b666' -down_revision = '131ec6accff5' +revision = "5e680529b666" +down_revision = "131ec6accff5" from alembic import op import sqlalchemy as sa def upgrade(): - op.add_column('endpoints', sa.Column('sensitive', sa.Boolean(), nullable=True)) - op.add_column('endpoints', sa.Column('source_id', sa.Integer(), nullable=True)) - op.create_foreign_key(None, 'endpoints', 'sources', ['source_id'], ['id']) + op.add_column("endpoints", sa.Column("sensitive", sa.Boolean(), nullable=True)) + op.add_column("endpoints", sa.Column("source_id", sa.Integer(), nullable=True)) + op.create_foreign_key(None, "endpoints", "sources", ["source_id"], ["id"]) def downgrade(): - op.drop_constraint(None, 'endpoints', type_='foreignkey') - op.drop_column('endpoints', 'source_id') - op.drop_column('endpoints', 'sensitive') + op.drop_constraint(None, "endpoints", type_="foreignkey") + op.drop_column("endpoints", "source_id") + op.drop_column("endpoints", "sensitive") diff --git a/lemur/migrations/versions/6006c79b6011_.py b/lemur/migrations/versions/6006c79b6011_.py index c41b1d25..86727716 100644 --- a/lemur/migrations/versions/6006c79b6011_.py +++ b/lemur/migrations/versions/6006c79b6011_.py @@ -7,15 +7,15 @@ Create Date: 2018-10-19 15:23:06.750510 """ # revision identifiers, used by Alembic. -revision = '6006c79b6011' -down_revision = '984178255c83' +revision = "6006c79b6011" +down_revision = "984178255c83" from alembic import op def upgrade(): - op.create_unique_constraint("uq_label", 'sources', ['label']) + op.create_unique_constraint("uq_label", "sources", ["label"]) def downgrade(): - op.drop_constraint("uq_label", 'sources', type_='unique') + op.drop_constraint("uq_label", "sources", type_="unique") diff --git a/lemur/migrations/versions/7ead443ba911_.py b/lemur/migrations/versions/7ead443ba911_.py index 62be01aa..10b8e576 100644 --- a/lemur/migrations/versions/7ead443ba911_.py +++ b/lemur/migrations/versions/7ead443ba911_.py @@ -7,15 +7,16 @@ Create Date: 2018-10-21 22:06:23.056906 """ # revision identifiers, used by Alembic. -revision = '7ead443ba911' -down_revision = '6006c79b6011' +revision = "7ead443ba911" +down_revision = "6006c79b6011" from alembic import op import sqlalchemy as sa def upgrade(): - op.add_column('certificates', sa.Column('csr', sa.TEXT(), nullable=True)) + op.add_column("certificates", sa.Column("csr", sa.TEXT(), nullable=True)) + def downgrade(): - op.drop_column('certificates', 'csr') + op.drop_column("certificates", "csr") diff --git a/lemur/migrations/versions/7f71c0cea31a_.py b/lemur/migrations/versions/7f71c0cea31a_.py index 04bb02ea..5e90cbb1 100644 --- a/lemur/migrations/versions/7f71c0cea31a_.py +++ b/lemur/migrations/versions/7f71c0cea31a_.py @@ -9,8 +9,8 @@ Create Date: 2016-07-28 09:39:12.736506 """ # revision identifiers, used by Alembic. -revision = '7f71c0cea31a' -down_revision = '29d8c8455c86' +revision = "7f71c0cea31a" +down_revision = "29d8c8455c86" from alembic import op import sqlalchemy as sa @@ -19,17 +19,25 @@ from sqlalchemy.sql import text def upgrade(): conn = op.get_bind() - for name in conn.execute(text('select name from certificates group by name having count(*) > 1')): - for idx, id in enumerate(conn.execute(text("select id from certificates where certificates.name like :name order by id ASC").bindparams(name=name[0]))): + for name in conn.execute( + text("select name from certificates group by name having count(*) > 1") + ): + for idx, id in enumerate( + conn.execute( + text( + "select id from certificates where certificates.name like :name order by id ASC" + ).bindparams(name=name[0]) + ) + ): if not idx: continue - new_name = name[0] + '-' + str(idx) - stmt = text('update certificates set name=:name where id=:id') + new_name = name[0] + "-" + str(idx) + stmt = text("update certificates set name=:name where id=:id") stmt = stmt.bindparams(name=new_name, id=id[0]) op.execute(stmt) - op.create_unique_constraint(None, 'certificates', ['name']) + op.create_unique_constraint(None, "certificates", ["name"]) def downgrade(): - op.drop_constraint(None, 'certificates', type_='unique') + op.drop_constraint(None, "certificates", type_="unique") diff --git a/lemur/migrations/versions/8ae67285ff14_.py b/lemur/migrations/versions/8ae67285ff14_.py index f45be70d..e8f6a217 100644 --- a/lemur/migrations/versions/8ae67285ff14_.py +++ b/lemur/migrations/versions/8ae67285ff14_.py @@ -7,18 +7,28 @@ Create Date: 2017-05-10 11:56:13.999332 """ # revision identifiers, used by Alembic. -revision = '8ae67285ff14' -down_revision = '5e680529b666' +revision = "8ae67285ff14" +down_revision = "5e680529b666" from alembic import op import sqlalchemy as sa def upgrade(): - op.drop_index('certificate_replacement_associations_ix') - op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['replaced_certificate_id', 'certificate_id'], unique=True) + op.drop_index("certificate_replacement_associations_ix") + op.create_index( + "certificate_replacement_associations_ix", + "certificate_replacement_associations", + ["replaced_certificate_id", "certificate_id"], + unique=True, + ) def downgrade(): - op.drop_index('certificate_replacement_associations_ix') - op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['certificate_id', 'certificate_id'], unique=True) + op.drop_index("certificate_replacement_associations_ix") + op.create_index( + "certificate_replacement_associations_ix", + "certificate_replacement_associations", + ["certificate_id", "certificate_id"], + unique=True, + ) diff --git a/lemur/migrations/versions/932525b82f1a_.py b/lemur/migrations/versions/932525b82f1a_.py index 2ee95d07..8ff36d1c 100644 --- a/lemur/migrations/versions/932525b82f1a_.py +++ b/lemur/migrations/versions/932525b82f1a_.py @@ -7,15 +7,15 @@ Create Date: 2016-10-13 20:14:33.928029 """ # revision identifiers, used by Alembic. -revision = '932525b82f1a' -down_revision = '7f71c0cea31a' +revision = "932525b82f1a" +down_revision = "7f71c0cea31a" from alembic import op def upgrade(): - op.alter_column('certificates', 'active', new_column_name='notify') + op.alter_column("certificates", "active", new_column_name="notify") def downgrade(): - op.alter_column('certificates', 'notify', new_column_name='active') + op.alter_column("certificates", "notify", new_column_name="active") diff --git a/lemur/migrations/versions/9392b9f9a805_.py b/lemur/migrations/versions/9392b9f9a805_.py index d6ca734b..8ff09333 100644 --- a/lemur/migrations/versions/9392b9f9a805_.py +++ b/lemur/migrations/versions/9392b9f9a805_.py @@ -6,8 +6,8 @@ Create Date: 2018-09-17 08:33:37.087488 """ # revision identifiers, used by Alembic. -revision = '9392b9f9a805' -down_revision = '5ae0ecefb01f' +revision = "9392b9f9a805" +down_revision = "5ae0ecefb01f" from alembic import op from sqlalchemy_utils import ArrowType @@ -15,10 +15,17 @@ import sqlalchemy as sa def upgrade(): - op.add_column('pending_certs', sa.Column('last_updated', ArrowType, server_default=sa.text('now()'), onupdate=sa.text('now()'), - nullable=False)) + op.add_column( + "pending_certs", + sa.Column( + "last_updated", + ArrowType, + server_default=sa.text("now()"), + onupdate=sa.text("now()"), + nullable=False, + ), + ) def downgrade(): - op.drop_column('pending_certs', 'last_updated') - + op.drop_column("pending_certs", "last_updated") diff --git a/lemur/migrations/versions/984178255c83_.py b/lemur/migrations/versions/984178255c83_.py index 40d2ce31..88cab183 100644 --- a/lemur/migrations/versions/984178255c83_.py +++ b/lemur/migrations/versions/984178255c83_.py @@ -7,18 +7,20 @@ Create Date: 2018-10-11 20:49:12.704563 """ # revision identifiers, used by Alembic. -revision = '984178255c83' -down_revision = 'f2383bf08fbc' +revision = "984178255c83" +down_revision = "f2383bf08fbc" from alembic import op import sqlalchemy as sa def upgrade(): - op.add_column('pending_certs', sa.Column('resolved', sa.Boolean(), nullable=True)) - op.add_column('pending_certs', sa.Column('resolved_cert_id', sa.Integer(), nullable=True)) + op.add_column("pending_certs", sa.Column("resolved", sa.Boolean(), nullable=True)) + op.add_column( + "pending_certs", sa.Column("resolved_cert_id", sa.Integer(), nullable=True) + ) def downgrade(): - op.drop_column('pending_certs', 'resolved_cert_id') - op.drop_column('pending_certs', 'resolved') + op.drop_column("pending_certs", "resolved_cert_id") + op.drop_column("pending_certs", "resolved") diff --git a/lemur/migrations/versions/9f79024fe67b_.py b/lemur/migrations/versions/9f79024fe67b_.py new file mode 100644 index 00000000..cb7db296 --- /dev/null +++ b/lemur/migrations/versions/9f79024fe67b_.py @@ -0,0 +1,32 @@ +""" Add delete_cert to log_type enum + +Revision ID: 9f79024fe67b +Revises: ee827d1e1974 +Create Date: 2019-01-03 15:36:59.181911 + +""" + +# revision identifiers, used by Alembic. +revision = "9f79024fe67b" +down_revision = "ee827d1e1974" + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.sync_enum_values( + "public", + "log_type", + ["create_cert", "key_view", "revoke_cert", "update_cert"], + ["create_cert", "delete_cert", "key_view", "revoke_cert", "update_cert"], + ) + + +def downgrade(): + op.sync_enum_values( + "public", + "log_type", + ["create_cert", "delete_cert", "key_view", "revoke_cert", "update_cert"], + ["create_cert", "key_view", "revoke_cert", "update_cert"], + ) diff --git a/lemur/migrations/versions/a02a678ddc25_.py b/lemur/migrations/versions/a02a678ddc25_.py index 603bc06a..f8fa09bb 100644 --- a/lemur/migrations/versions/a02a678ddc25_.py +++ b/lemur/migrations/versions/a02a678ddc25_.py @@ -10,8 +10,8 @@ Create Date: 2017-07-12 11:45:49.257927 """ # revision identifiers, used by Alembic. -revision = 'a02a678ddc25' -down_revision = '8ae67285ff14' +revision = "a02a678ddc25" +down_revision = "8ae67285ff14" from alembic import op import sqlalchemy as sa @@ -20,25 +20,30 @@ from sqlalchemy.sql import text def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.create_table('rotation_policies', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('days', sa.Integer(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "rotation_policies", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("days", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.add_column( + "certificates", sa.Column("rotation_policy_id", sa.Integer(), nullable=True) + ) + op.create_foreign_key( + None, "certificates", "rotation_policies", ["rotation_policy_id"], ["id"] ) - op.add_column('certificates', sa.Column('rotation_policy_id', sa.Integer(), nullable=True)) - op.create_foreign_key(None, 'certificates', 'rotation_policies', ['rotation_policy_id'], ['id']) conn = op.get_bind() - stmt = text('insert into rotation_policies (days, name) values (:days, :name)') - stmt = stmt.bindparams(days=30, name='default') + stmt = text("insert into rotation_policies (days, name) values (:days, :name)") + stmt = stmt.bindparams(days=30, name="default") conn.execute(stmt) - stmt = text('select id from rotation_policies where name=:name') - stmt = stmt.bindparams(name='default') + stmt = text("select id from rotation_policies where name=:name") + stmt = stmt.bindparams(name="default") rotation_policy_id = conn.execute(stmt).fetchone()[0] - stmt = text('update certificates set rotation_policy_id=:rotation_policy_id') + stmt = text("update certificates set rotation_policy_id=:rotation_policy_id") stmt = stmt.bindparams(rotation_policy_id=rotation_policy_id) conn.execute(stmt) # ### end Alembic commands ### @@ -46,9 +51,17 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, 'certificates', type_='foreignkey') - op.drop_column('certificates', 'rotation_policy_id') - op.drop_index('certificate_replacement_associations_ix', table_name='certificate_replacement_associations') - op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['replaced_certificate_id', 'certificate_id'], unique=True) - op.drop_table('rotation_policies') + op.drop_constraint(None, "certificates", type_="foreignkey") + op.drop_column("certificates", "rotation_policy_id") + op.drop_index( + "certificate_replacement_associations_ix", + table_name="certificate_replacement_associations", + ) + op.create_index( + "certificate_replacement_associations_ix", + "certificate_replacement_associations", + ["replaced_certificate_id", "certificate_id"], + unique=True, + ) + op.drop_table("rotation_policies") # ### end Alembic commands ### diff --git a/lemur/migrations/versions/ac483cfeb230_.py b/lemur/migrations/versions/ac483cfeb230_.py index d28a2599..d1e2361d 100644 --- a/lemur/migrations/versions/ac483cfeb230_.py +++ b/lemur/migrations/versions/ac483cfeb230_.py @@ -7,8 +7,8 @@ Create Date: 2017-10-11 10:16:39.682591 """ # revision identifiers, used by Alembic. -revision = 'ac483cfeb230' -down_revision = 'b29e2c4bf8c9' +revision = "ac483cfeb230" +down_revision = "b29e2c4bf8c9" from alembic import op import sqlalchemy as sa @@ -16,12 +16,18 @@ from sqlalchemy.dialects import postgresql def upgrade(): - op.alter_column('certificates', 'name', - existing_type=sa.VARCHAR(length=128), - type_=sa.String(length=256)) + op.alter_column( + "certificates", + "name", + existing_type=sa.VARCHAR(length=128), + type_=sa.String(length=256), + ) def downgrade(): - op.alter_column('certificates', 'name', - existing_type=sa.VARCHAR(length=256), - type_=sa.String(length=128)) + op.alter_column( + "certificates", + "name", + existing_type=sa.VARCHAR(length=256), + type_=sa.String(length=128), + ) diff --git a/lemur/migrations/versions/b29e2c4bf8c9_.py b/lemur/migrations/versions/b29e2c4bf8c9_.py index 19835e09..6f9dc526 100644 --- a/lemur/migrations/versions/b29e2c4bf8c9_.py +++ b/lemur/migrations/versions/b29e2c4bf8c9_.py @@ -7,8 +7,8 @@ Create Date: 2017-09-26 10:50:35.740367 """ # revision identifiers, used by Alembic. -revision = 'b29e2c4bf8c9' -down_revision = '1ae8e3104db8' +revision = "b29e2c4bf8c9" +down_revision = "1ae8e3104db8" from alembic import op import sqlalchemy as sa @@ -16,13 +16,25 @@ import sqlalchemy as sa def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column('certificates', sa.Column('external_id', sa.String(128), nullable=True)) - op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'update_cert'], ['create_cert', 'key_view', 'revoke_cert', 'update_cert']) + op.add_column( + "certificates", sa.Column("external_id", sa.String(128), nullable=True) + ) + op.sync_enum_values( + "public", + "log_type", + ["create_cert", "key_view", "update_cert"], + ["create_cert", "key_view", "revoke_cert", "update_cert"], + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'revoke_cert', 'update_cert'], ['create_cert', 'key_view', 'update_cert']) - op.drop_column('certificates', 'external_id') + op.sync_enum_values( + "public", + "log_type", + ["create_cert", "key_view", "revoke_cert", "update_cert"], + ["create_cert", "key_view", "update_cert"], + ) + op.drop_column("certificates", "external_id") # ### end Alembic commands ### diff --git a/lemur/migrations/versions/b33c838cb669_.py b/lemur/migrations/versions/b33c838cb669_.py new file mode 100644 index 00000000..eb04d4a1 --- /dev/null +++ b/lemur/migrations/versions/b33c838cb669_.py @@ -0,0 +1,26 @@ +"""adding index on the not_after field + +Revision ID: b33c838cb669 +Revises: 318b66568358 +Create Date: 2019-05-30 08:42:05.294109 + +""" + +# revision identifiers, used by Alembic. +revision = 'b33c838cb669' +down_revision = '318b66568358' + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_index('ix_certificates_not_after', 'certificates', [sa.text('not_after DESC')], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('ix_certificates_not_after', table_name='certificates') + # ### end Alembic commands ### diff --git a/lemur/migrations/versions/c05a8998b371_.py b/lemur/migrations/versions/c05a8998b371_.py index cf600043..a5c9abff 100644 --- a/lemur/migrations/versions/c05a8998b371_.py +++ b/lemur/migrations/versions/c05a8998b371_.py @@ -7,25 +7,27 @@ Create Date: 2017-11-10 14:51:28.975927 """ # revision identifiers, used by Alembic. -revision = 'c05a8998b371' -down_revision = 'ac483cfeb230' +revision = "c05a8998b371" +down_revision = "ac483cfeb230" from alembic import op import sqlalchemy as sa import sqlalchemy_utils + def upgrade(): - op.create_table('api_keys', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(length=128), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('ttl', sa.BigInteger(), nullable=False), - sa.Column('issued_at', sa.BigInteger(), nullable=False), - sa.Column('revoked', sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "api_keys", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=128), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("ttl", sa.BigInteger(), nullable=False), + sa.Column("issued_at", sa.BigInteger(), nullable=False), + sa.Column("revoked", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.PrimaryKeyConstraint("id"), ) def downgrade(): - op.drop_table('api_keys') + op.drop_table("api_keys") diff --git a/lemur/migrations/versions/c87cb989af04_.py b/lemur/migrations/versions/c87cb989af04_.py index 4959e727..69f53bf4 100644 --- a/lemur/migrations/versions/c87cb989af04_.py +++ b/lemur/migrations/versions/c87cb989af04_.py @@ -5,15 +5,15 @@ Create Date: 2018-10-11 09:44:57.099854 """ -revision = 'c87cb989af04' -down_revision = '9392b9f9a805' +revision = "c87cb989af04" +down_revision = "9392b9f9a805" from alembic import op def upgrade(): - op.create_index(op.f('ix_domains_name'), 'domains', ['name'], unique=False) + op.create_index(op.f("ix_domains_name"), "domains", ["name"], unique=False) def downgrade(): - op.drop_index(op.f('ix_domains_name'), table_name='domains') + op.drop_index(op.f("ix_domains_name"), table_name="domains") diff --git a/lemur/migrations/versions/ce547319f7be_.py b/lemur/migrations/versions/ce547319f7be_.py index 41ef1fa8..d139c6fb 100644 --- a/lemur/migrations/versions/ce547319f7be_.py +++ b/lemur/migrations/versions/ce547319f7be_.py @@ -7,8 +7,8 @@ Create Date: 2018-02-23 11:00:02.150561 """ # revision identifiers, used by Alembic. -revision = 'ce547319f7be' -down_revision = '5bc47fa7cac4' +revision = "ce547319f7be" +down_revision = "5bc47fa7cac4" import sqlalchemy as sa @@ -24,12 +24,12 @@ TABLE = "certificate_notification_associations" def upgrade(): print("Adding id column") op.add_column( - TABLE, - sa.Column('id', sa.Integer, primary_key=True, autoincrement=True) + TABLE, sa.Column("id", sa.Integer, primary_key=True, autoincrement=True) ) db.session.commit() db.session.flush() + def downgrade(): op.drop_column(TABLE, "id") db.session.commit() diff --git a/lemur/migrations/versions/e3691fc396e9_.py b/lemur/migrations/versions/e3691fc396e9_.py index 1c5c2f15..0007b804 100644 --- a/lemur/migrations/versions/e3691fc396e9_.py +++ b/lemur/migrations/versions/e3691fc396e9_.py @@ -7,29 +7,36 @@ Create Date: 2016-11-28 13:15:46.995219 """ # revision identifiers, used by Alembic. -revision = 'e3691fc396e9' -down_revision = '932525b82f1a' +revision = "e3691fc396e9" +down_revision = "932525b82f1a" from alembic import op import sqlalchemy as sa import sqlalchemy_utils + def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.create_table('logs', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('certificate_id', sa.Integer(), nullable=True), - sa.Column('log_type', sa.Enum('key_view', name='log_type'), nullable=False), - sa.Column('logged_at', sqlalchemy_utils.types.arrow.ArrowType(), server_default=sa.text('now()'), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "logs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("certificate_id", sa.Integer(), nullable=True), + sa.Column("log_type", sa.Enum("key_view", name="log_type"), nullable=False), + sa.Column( + "logged_at", + sqlalchemy_utils.types.arrow.ArrowType(), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.PrimaryKeyConstraint("id"), ) ### end Alembic commands ### def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_table('logs') + op.drop_table("logs") ### end Alembic commands ### diff --git a/lemur/migrations/versions/ee827d1e1974_.py b/lemur/migrations/versions/ee827d1e1974_.py index 62ac6222..56696fe3 100644 --- a/lemur/migrations/versions/ee827d1e1974_.py +++ b/lemur/migrations/versions/ee827d1e1974_.py @@ -7,25 +7,44 @@ Create Date: 2018-11-05 09:49:40.226368 """ # revision identifiers, used by Alembic. -revision = 'ee827d1e1974' -down_revision = '7ead443ba911' +revision = "ee827d1e1974" +down_revision = "7ead443ba911" from alembic import op from sqlalchemy.exc import ProgrammingError + def upgrade(): connection = op.get_bind() connection.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm") - op.create_index('ix_certificates_cn', 'certificates', ['cn'], unique=False, postgresql_ops={'cn': 'gin_trgm_ops'}, - postgresql_using='gin') - op.create_index('ix_certificates_name', 'certificates', ['name'], unique=False, - postgresql_ops={'name': 'gin_trgm_ops'}, postgresql_using='gin') - op.create_index('ix_domains_name_gin', 'domains', ['name'], unique=False, postgresql_ops={'name': 'gin_trgm_ops'}, - postgresql_using='gin') + op.create_index( + "ix_certificates_cn", + "certificates", + ["cn"], + unique=False, + postgresql_ops={"cn": "gin_trgm_ops"}, + postgresql_using="gin", + ) + op.create_index( + "ix_certificates_name", + "certificates", + ["name"], + unique=False, + postgresql_ops={"name": "gin_trgm_ops"}, + postgresql_using="gin", + ) + op.create_index( + "ix_domains_name_gin", + "domains", + ["name"], + unique=False, + postgresql_ops={"name": "gin_trgm_ops"}, + postgresql_using="gin", + ) def downgrade(): - op.drop_index('ix_domains_name', table_name='domains') - op.drop_index('ix_certificates_name', table_name='certificates') - op.drop_index('ix_certificates_cn', table_name='certificates') + op.drop_index("ix_domains_name", table_name="domains") + op.drop_index("ix_certificates_name", table_name="certificates") + op.drop_index("ix_certificates_cn", table_name="certificates") diff --git a/lemur/migrations/versions/f2383bf08fbc_.py b/lemur/migrations/versions/f2383bf08fbc_.py index 1fa36960..a54aa5d2 100644 --- a/lemur/migrations/versions/f2383bf08fbc_.py +++ b/lemur/migrations/versions/f2383bf08fbc_.py @@ -7,17 +7,22 @@ Create Date: 2018-10-11 11:23:31.195471 """ -revision = 'f2383bf08fbc' -down_revision = 'c87cb989af04' +revision = "f2383bf08fbc" +down_revision = "c87cb989af04" import sqlalchemy as sa from alembic import op def upgrade(): - op.create_index('ix_certificates_id_desc', 'certificates', [sa.text('id DESC')], unique=True, - postgresql_using='btree') + op.create_index( + "ix_certificates_id_desc", + "certificates", + [sa.text("id DESC")], + unique=True, + postgresql_using="btree", + ) def downgrade(): - op.drop_index('ix_certificates_id_desc', table_name='certificates') + op.drop_index("ix_certificates_id_desc", table_name="certificates") diff --git a/lemur/models.py b/lemur/models.py index 69f82360..163d156f 100644 --- a/lemur/models.py +++ b/lemur/models.py @@ -12,121 +12,201 @@ from sqlalchemy import Column, Integer, ForeignKey, Index, UniqueConstraint from lemur.database import db -certificate_associations = db.Table('certificate_associations', - Column('domain_id', Integer, ForeignKey('domains.id')), - Column('certificate_id', Integer, ForeignKey('certificates.id')) - ) +certificate_associations = db.Table( + "certificate_associations", + Column("domain_id", Integer, ForeignKey("domains.id")), + Column("certificate_id", Integer, ForeignKey("certificates.id")), +) -Index('certificate_associations_ix', certificate_associations.c.domain_id, certificate_associations.c.certificate_id) +Index( + "certificate_associations_ix", + certificate_associations.c.domain_id, + certificate_associations.c.certificate_id, +) -certificate_destination_associations = db.Table('certificate_destination_associations', - Column('destination_id', Integer, - ForeignKey('destinations.id', ondelete='cascade')), - Column('certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')) - ) +certificate_destination_associations = db.Table( + "certificate_destination_associations", + Column( + "destination_id", Integer, ForeignKey("destinations.id", ondelete="cascade") + ), + Column( + "certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade") + ), +) -Index('certificate_destination_associations_ix', certificate_destination_associations.c.destination_id, certificate_destination_associations.c.certificate_id) +Index( + "certificate_destination_associations_ix", + certificate_destination_associations.c.destination_id, + certificate_destination_associations.c.certificate_id, +) -certificate_source_associations = db.Table('certificate_source_associations', - Column('source_id', Integer, - ForeignKey('sources.id', ondelete='cascade')), - Column('certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')) - ) +certificate_source_associations = db.Table( + "certificate_source_associations", + Column("source_id", Integer, ForeignKey("sources.id", ondelete="cascade")), + Column( + "certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade") + ), +) -Index('certificate_source_associations_ix', certificate_source_associations.c.source_id, certificate_source_associations.c.certificate_id) +Index( + "certificate_source_associations_ix", + certificate_source_associations.c.source_id, + certificate_source_associations.c.certificate_id, +) -certificate_notification_associations = db.Table('certificate_notification_associations', - Column('notification_id', Integer, - ForeignKey('notifications.id', ondelete='cascade')), - Column('certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')), - Column('id', Integer, primary_key=True, autoincrement=True), - UniqueConstraint('notification_id', 'certificate_id', name='uq_dest_not_ids') - ) +certificate_notification_associations = db.Table( + "certificate_notification_associations", + Column( + "notification_id", Integer, ForeignKey("notifications.id", ondelete="cascade") + ), + Column( + "certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade") + ), + Column("id", Integer, primary_key=True, autoincrement=True), + UniqueConstraint("notification_id", "certificate_id", name="uq_dest_not_ids"), +) -Index('certificate_notification_associations_ix', certificate_notification_associations.c.notification_id, certificate_notification_associations.c.certificate_id) +Index( + "certificate_notification_associations_ix", + certificate_notification_associations.c.notification_id, + certificate_notification_associations.c.certificate_id, +) -certificate_replacement_associations = db.Table('certificate_replacement_associations', - Column('replaced_certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')), - Column('certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')) - ) +certificate_replacement_associations = db.Table( + "certificate_replacement_associations", + Column( + "replaced_certificate_id", + Integer, + ForeignKey("certificates.id", ondelete="cascade"), + ), + Column( + "certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade") + ), +) -Index('certificate_replacement_associations_ix', certificate_replacement_associations.c.replaced_certificate_id, certificate_replacement_associations.c.certificate_id, unique=True) +Index( + "certificate_replacement_associations_ix", + certificate_replacement_associations.c.replaced_certificate_id, + certificate_replacement_associations.c.certificate_id, + unique=True, +) -roles_authorities = db.Table('roles_authorities', - Column('authority_id', Integer, ForeignKey('authorities.id')), - Column('role_id', Integer, ForeignKey('roles.id')) - ) +roles_authorities = db.Table( + "roles_authorities", + Column("authority_id", Integer, ForeignKey("authorities.id")), + Column("role_id", Integer, ForeignKey("roles.id")), +) -Index('roles_authorities_ix', roles_authorities.c.authority_id, roles_authorities.c.role_id) +Index( + "roles_authorities_ix", + roles_authorities.c.authority_id, + roles_authorities.c.role_id, +) -roles_certificates = db.Table('roles_certificates', - Column('certificate_id', Integer, ForeignKey('certificates.id')), - Column('role_id', Integer, ForeignKey('roles.id')) - ) +roles_certificates = db.Table( + "roles_certificates", + Column("certificate_id", Integer, ForeignKey("certificates.id")), + Column("role_id", Integer, ForeignKey("roles.id")), +) -Index('roles_certificates_ix', roles_certificates.c.certificate_id, roles_certificates.c.role_id) +Index( + "roles_certificates_ix", + roles_certificates.c.certificate_id, + roles_certificates.c.role_id, +) -roles_users = db.Table('roles_users', - Column('user_id', Integer, ForeignKey('users.id')), - Column('role_id', Integer, ForeignKey('roles.id')) - ) +roles_users = db.Table( + "roles_users", + Column("user_id", Integer, ForeignKey("users.id")), + Column("role_id", Integer, ForeignKey("roles.id")), +) -Index('roles_users_ix', roles_users.c.user_id, roles_users.c.role_id) +Index("roles_users_ix", roles_users.c.user_id, roles_users.c.role_id) -policies_ciphers = db.Table('policies_ciphers', - Column('cipher_id', Integer, ForeignKey('ciphers.id')), - Column('policy_id', Integer, ForeignKey('policy.id'))) +policies_ciphers = db.Table( + "policies_ciphers", + Column("cipher_id", Integer, ForeignKey("ciphers.id")), + Column("policy_id", Integer, ForeignKey("policy.id")), +) -Index('policies_ciphers_ix', policies_ciphers.c.cipher_id, policies_ciphers.c.policy_id) +Index("policies_ciphers_ix", policies_ciphers.c.cipher_id, policies_ciphers.c.policy_id) -pending_cert_destination_associations = db.Table('pending_cert_destination_associations', - Column('destination_id', Integer, - ForeignKey('destinations.id', ondelete='cascade')), - Column('pending_cert_id', Integer, - ForeignKey('pending_certs.id', ondelete='cascade')) - ) +pending_cert_destination_associations = db.Table( + "pending_cert_destination_associations", + Column( + "destination_id", Integer, ForeignKey("destinations.id", ondelete="cascade") + ), + Column( + "pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade") + ), +) -Index('pending_cert_destination_associations_ix', pending_cert_destination_associations.c.destination_id, pending_cert_destination_associations.c.pending_cert_id) +Index( + "pending_cert_destination_associations_ix", + pending_cert_destination_associations.c.destination_id, + pending_cert_destination_associations.c.pending_cert_id, +) -pending_cert_notification_associations = db.Table('pending_cert_notification_associations', - Column('notification_id', Integer, - ForeignKey('notifications.id', ondelete='cascade')), - Column('pending_cert_id', Integer, - ForeignKey('pending_certs.id', ondelete='cascade')) - ) +pending_cert_notification_associations = db.Table( + "pending_cert_notification_associations", + Column( + "notification_id", Integer, ForeignKey("notifications.id", ondelete="cascade") + ), + Column( + "pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade") + ), +) -Index('pending_cert_notification_associations_ix', pending_cert_notification_associations.c.notification_id, pending_cert_notification_associations.c.pending_cert_id) +Index( + "pending_cert_notification_associations_ix", + pending_cert_notification_associations.c.notification_id, + pending_cert_notification_associations.c.pending_cert_id, +) -pending_cert_source_associations = db.Table('pending_cert_source_associations', - Column('source_id', Integer, - ForeignKey('sources.id', ondelete='cascade')), - Column('pending_cert_id', Integer, - ForeignKey('pending_certs.id', ondelete='cascade')) - ) +pending_cert_source_associations = db.Table( + "pending_cert_source_associations", + Column("source_id", Integer, ForeignKey("sources.id", ondelete="cascade")), + Column( + "pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade") + ), +) -Index('pending_cert_source_associations_ix', pending_cert_source_associations.c.source_id, pending_cert_source_associations.c.pending_cert_id) +Index( + "pending_cert_source_associations_ix", + pending_cert_source_associations.c.source_id, + pending_cert_source_associations.c.pending_cert_id, +) -pending_cert_replacement_associations = db.Table('pending_cert_replacement_associations', - Column('replaced_certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')), - Column('pending_cert_id', Integer, - ForeignKey('pending_certs.id', ondelete='cascade')) - ) +pending_cert_replacement_associations = db.Table( + "pending_cert_replacement_associations", + Column( + "replaced_certificate_id", + Integer, + ForeignKey("certificates.id", ondelete="cascade"), + ), + Column( + "pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade") + ), +) -Index('pending_cert_replacement_associations_ix', pending_cert_replacement_associations.c.replaced_certificate_id, pending_cert_replacement_associations.c.pending_cert_id) +Index( + "pending_cert_replacement_associations_ix", + pending_cert_replacement_associations.c.replaced_certificate_id, + pending_cert_replacement_associations.c.pending_cert_id, +) -pending_cert_role_associations = db.Table('pending_cert_role_associations', - Column('pending_cert_id', Integer, ForeignKey('pending_certs.id')), - Column('role_id', Integer, ForeignKey('roles.id')) - ) +pending_cert_role_associations = db.Table( + "pending_cert_role_associations", + Column("pending_cert_id", Integer, ForeignKey("pending_certs.id")), + Column("role_id", Integer, ForeignKey("roles.id")), +) -Index('pending_cert_role_associations_ix', pending_cert_role_associations.c.pending_cert_id, pending_cert_role_associations.c.role_id) +Index( + "pending_cert_role_associations_ix", + pending_cert_role_associations.c.pending_cert_id, + pending_cert_role_associations.c.role_id, +) diff --git a/lemur/notifications/cli.py b/lemur/notifications/cli.py index e3bf431e..a2848117 100644 --- a/lemur/notifications/cli.py +++ b/lemur/notifications/cli.py @@ -14,7 +14,14 @@ from lemur.notifications.messaging import send_expiration_notifications manager = Manager(usage="Handles notification related tasks.") -@manager.option('-e', '--exclude', dest='exclude', action='append', default=[], help='Common name matching of certificates that should be excluded from notification') +@manager.option( + "-e", + "--exclude", + dest="exclude", + action="append", + default=[], + help="Common name matching of certificates that should be excluded from notification", +) def expirations(exclude): """ Runs Lemur's notification engine, that looks for expired certificates and sends @@ -33,12 +40,13 @@ def expirations(exclude): success, failed = send_expiration_notifications(exclude) print( "Finished notifying subscribers about expiring certificates! Sent: {success} Failed: {failed}".format( - success=success, - failed=failed + success=success, failed=failed ) ) status = SUCCESS_METRIC_STATUS except Exception as e: sentry.captureException() - metrics.send('expiration_notification_job', 'counter', 1, metric_tags={'status': status}) + metrics.send( + "expiration_notification_job", "counter", 1, metric_tags={"status": status} + ) diff --git a/lemur/notifications/messaging.py b/lemur/notifications/messaging.py index cd88ebc8..82db7b6e 100644 --- a/lemur/notifications/messaging.py +++ b/lemur/notifications/messaging.py @@ -36,21 +36,23 @@ def get_certificates(exclude=None): now = arrow.utcnow() max = now + timedelta(days=90) - q = database.db.session.query(Certificate) \ - .filter(Certificate.not_after <= max) \ - .filter(Certificate.notify == True) \ - .filter(Certificate.expired == False) # noqa + q = ( + database.db.session.query(Certificate) + .filter(Certificate.not_after <= max) + .filter(Certificate.notify == True) + .filter(Certificate.expired == False) + ) # noqa exclude_conditions = [] if exclude: for e in exclude: - exclude_conditions.append(~Certificate.name.ilike('%{}%'.format(e))) + exclude_conditions.append(~Certificate.name.ilike("%{}%".format(e))) q = q.filter(and_(*exclude_conditions)) certs = [] - for c in windowed_query(q, Certificate.id, 100): + for c in windowed_query(q, Certificate.id, 10000): if needs_notification(c): certs.append(c) @@ -101,7 +103,12 @@ def send_notification(event_type, data, targets, notification): except Exception as e: sentry.captureException() - metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': event_type}) + metrics.send( + "notification", + "counter", + 1, + metric_tags={"status": status, "event_type": event_type}, + ) if status == SUCCESS_METRIC_STATUS: return True @@ -115,7 +122,7 @@ def send_expiration_notifications(exclude): success = failure = 0 # security team gets all - security_email = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL') + security_email = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL") security_data = [] for owner, notification_group in get_eligible_certificates(exclude=exclude).items(): @@ -127,26 +134,43 @@ def send_expiration_notifications(exclude): for data in certificates: n, certificate = data - cert_data = certificate_notification_output_schema.dump(certificate).data + cert_data = certificate_notification_output_schema.dump( + certificate + ).data notification_data.append(cert_data) security_data.append(cert_data) - notification_recipient = get_plugin_option('recipients', notification.options) - if notification_recipient: - notification_recipient = notification_recipient.split(",") - - if send_notification('expiration', notification_data, [owner], notification): + if send_notification( + "expiration", notification_data, [owner], notification + ): success += 1 else: failure += 1 - if notification_recipient and owner != notification_recipient and security_email != notification_recipient: - if send_notification('expiration', notification_data, notification_recipient, notification): + notification_recipient = get_plugin_option( + "recipients", notification.options + ) + if notification_recipient: + notification_recipient = notification_recipient.split(",") + # removing owner and security_email from notification_recipient + notification_recipient = [i for i in notification_recipient if i not in security_email and i != owner] + + if ( + notification_recipient + ): + if send_notification( + "expiration", + notification_data, + notification_recipient, + notification, + ): success += 1 else: failure += 1 - if send_notification('expiration', security_data, security_email, notification): + if send_notification( + "expiration", security_data, security_email, notification + ): success += 1 else: failure += 1 @@ -165,24 +189,35 @@ def send_rotation_notification(certificate, notification_plugin=None): """ status = FAILURE_METRIC_STATUS if not notification_plugin: - notification_plugin = plugins.get(current_app.config.get('LEMUR_DEFAULT_NOTIFICATION_PLUGIN')) + notification_plugin = plugins.get( + current_app.config.get("LEMUR_DEFAULT_NOTIFICATION_PLUGIN") + ) data = certificate_notification_output_schema.dump(certificate).data try: - notification_plugin.send('rotation', data, [data['owner']]) + notification_plugin.send("rotation", data, [data["owner"]]) status = SUCCESS_METRIC_STATUS except Exception as e: - current_app.logger.error('Unable to send notification to {}.'.format(data['owner']), exc_info=True) + current_app.logger.error( + "Unable to send notification to {}.".format(data["owner"]), exc_info=True + ) sentry.captureException() - metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': 'rotation'}) + metrics.send( + "notification", + "counter", + 1, + metric_tags={"status": status, "event_type": "rotation"}, + ) if status == SUCCESS_METRIC_STATUS: return True -def send_pending_failure_notification(pending_cert, notify_owner=True, notify_security=True, notification_plugin=None): +def send_pending_failure_notification( + pending_cert, notify_owner=True, notify_security=True, notification_plugin=None +): """ Sends a report to certificate owners when their pending certificate failed to be created. @@ -194,32 +229,47 @@ def send_pending_failure_notification(pending_cert, notify_owner=True, notify_se if not notification_plugin: notification_plugin = plugins.get( - current_app.config.get('LEMUR_DEFAULT_NOTIFICATION_PLUGIN', 'email-notification') + current_app.config.get( + "LEMUR_DEFAULT_NOTIFICATION_PLUGIN", "email-notification" + ) ) data = pending_certificate_output_schema.dump(pending_cert).data - data["security_email"] = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL') + data["security_email"] = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL") if notify_owner: try: - notification_plugin.send('failed', data, [data['owner']], pending_cert) + notification_plugin.send("failed", data, [data["owner"]], pending_cert) status = SUCCESS_METRIC_STATUS except Exception as e: - current_app.logger.error('Unable to send pending failure notification to {}.'.format(data['owner']), - exc_info=True) + current_app.logger.error( + "Unable to send pending failure notification to {}.".format( + data["owner"] + ), + exc_info=True, + ) sentry.captureException() if notify_security: try: - notification_plugin.send('failed', data, data["security_email"], pending_cert) + notification_plugin.send( + "failed", data, data["security_email"], pending_cert + ) status = SUCCESS_METRIC_STATUS except Exception as e: - current_app.logger.error('Unable to send pending failure notification to ' - '{}.'.format(data['security_email']), - exc_info=True) + current_app.logger.error( + "Unable to send pending failure notification to " + "{}.".format(data["security_email"]), + exc_info=True, + ) sentry.captureException() - metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': 'rotation'}) + metrics.send( + "notification", + "counter", + 1, + metric_tags={"status": status, "event_type": "rotation"}, + ) if status == SUCCESS_METRIC_STATUS: return True @@ -242,20 +292,22 @@ def needs_notification(certificate): if not notification.active or not notification.options: return - interval = get_plugin_option('interval', notification.options) - unit = get_plugin_option('unit', notification.options) + interval = get_plugin_option("interval", notification.options) + unit = get_plugin_option("unit", notification.options) - if unit == 'weeks': + if unit == "weeks": interval *= 7 - elif unit == 'months': + elif unit == "months": interval *= 30 - elif unit == 'days': # it's nice to be explicit about the base unit + elif unit == "days": # it's nice to be explicit about the base unit pass else: - raise Exception("Invalid base unit for expiration interval: {0}".format(unit)) + raise Exception( + "Invalid base unit for expiration interval: {0}".format(unit) + ) if days == interval: notifications.append(notification) diff --git a/lemur/notifications/models.py b/lemur/notifications/models.py index 87646b4c..7053b8d7 100644 --- a/lemur/notifications/models.py +++ b/lemur/notifications/models.py @@ -11,12 +11,14 @@ from sqlalchemy_utils import JSONType from lemur.database import db from lemur.plugins.base import plugins -from lemur.models import certificate_notification_associations, \ - pending_cert_notification_associations +from lemur.models import ( + certificate_notification_associations, + pending_cert_notification_associations, +) class Notification(db.Model): - __tablename__ = 'notifications' + __tablename__ = "notifications" id = Column(Integer, primary_key=True) label = Column(String(128), unique=True) description = Column(Text()) @@ -28,14 +30,14 @@ class Notification(db.Model): secondary=certificate_notification_associations, passive_deletes=True, backref="notification", - cascade='all,delete' + cascade="all,delete", ) pending_certificates = relationship( "PendingCertificate", secondary=pending_cert_notification_associations, passive_deletes=True, backref="notification", - cascade='all,delete' + cascade="all,delete", ) @property diff --git a/lemur/notifications/schemas.py b/lemur/notifications/schemas.py index b5d4e1e6..a3ff4c99 100644 --- a/lemur/notifications/schemas.py +++ b/lemur/notifications/schemas.py @@ -7,7 +7,11 @@ """ from marshmallow import fields, post_dump from lemur.common.schema import LemurInputSchema, LemurOutputSchema -from lemur.schemas import PluginInputSchema, PluginOutputSchema, AssociatedCertificateSchema +from lemur.schemas import ( + PluginInputSchema, + PluginOutputSchema, + AssociatedCertificateSchema, +) class NotificationInputSchema(LemurInputSchema): @@ -30,7 +34,7 @@ class NotificationOutputSchema(LemurOutputSchema): @post_dump def fill_object(self, data): if data: - data['plugin']['pluginOptions'] = data['options'] + data["plugin"]["pluginOptions"] = data["options"] return data diff --git a/lemur/notifications/service.py b/lemur/notifications/service.py index 957757bd..ac624d1c 100644 --- a/lemur/notifications/service.py +++ b/lemur/notifications/service.py @@ -31,26 +31,28 @@ def create_default_expiration_notifications(name, recipients, intervals=None): options = [ { - 'name': 'unit', - 'type': 'select', - 'required': True, - 'validation': '', - 'available': ['days', 'weeks', 'months'], - 'helpMessage': 'Interval unit', - 'value': 'days', + "name": "unit", + "type": "select", + "required": True, + "validation": "", + "available": ["days", "weeks", "months"], + "helpMessage": "Interval unit", + "value": "days", }, { - 'name': 'recipients', - 'type': 'str', - 'required': True, - 'validation': '^([\w+-.%]+@[\w-.]+\.[A-Za-z]{2,4},?)+$', - 'helpMessage': 'Comma delimited list of email addresses', - 'value': ','.join(recipients) + "name": "recipients", + "type": "str", + "required": True, + "validation": "^([\w+-.%]+@[\w-.]+\.[A-Za-z]{2,4},?)+$", + "helpMessage": "Comma delimited list of email addresses", + "value": ",".join(recipients), }, ] if intervals is None: - intervals = current_app.config.get("LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", [30, 15, 2]) + intervals = current_app.config.get( + "LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", [30, 15, 2] + ) notifications = [] for i in intervals: @@ -58,21 +60,25 @@ def create_default_expiration_notifications(name, recipients, intervals=None): if not n: inter = [ { - 'name': 'interval', - 'type': 'int', - 'required': True, - 'validation': '^\d+$', - 'helpMessage': 'Number of days to be alert before expiration.', - 'value': i, + "name": "interval", + "type": "int", + "required": True, + "validation": "^\d+$", + "helpMessage": "Number of days to be alert before expiration.", + "value": i, } ] inter.extend(options) n = create( label="{name}_{interval}_DAY".format(name=name, interval=i), - plugin_name=current_app.config.get("LEMUR_DEFAULT_NOTIFICATION_PLUGIN", "email-notification"), + plugin_name=current_app.config.get( + "LEMUR_DEFAULT_NOTIFICATION_PLUGIN", "email-notification" + ), options=list(inter), - description="Default {interval} day expiration notification".format(interval=i), - certificates=[] + description="Default {interval} day expiration notification".format( + interval=i + ), + certificates=[], ) notifications.append(n) @@ -91,7 +97,9 @@ def create(label, plugin_name, options, description, certificates): :rtype : Notification :return: """ - notification = Notification(label=label, options=options, plugin_name=plugin_name, description=description) + notification = Notification( + label=label, options=options, plugin_name=plugin_name, description=description + ) notification.certificates = certificates return database.create(notification) @@ -147,7 +155,7 @@ def get_by_label(label): :param label: :return: """ - return database.get(Notification, label, field='label') + return database.get(Notification, label, field="label") def get_all(): @@ -161,18 +169,20 @@ def get_all(): def render(args): - filt = args.pop('filter') - certificate_id = args.pop('certificate_id', None) + filt = args.pop("filter") + certificate_id = args.pop("certificate_id", None) if certificate_id: - query = database.session_query(Notification).join(Certificate, Notification.certificate) + query = database.session_query(Notification).join( + Certificate, Notification.certificate + ) query = query.filter(Certificate.id == certificate_id) else: query = database.session_query(Notification) if filt: - terms = filt.split(';') - if terms[0] == 'active': + terms = filt.split(";") + if terms[0] == "active": query = query.filter(Notification.active == truthiness(terms[1])) else: query = database.filter(query, Notification, terms) diff --git a/lemur/notifications/views.py b/lemur/notifications/views.py index 4a2d82a8..cdabb4d4 100644 --- a/lemur/notifications/views.py +++ b/lemur/notifications/views.py @@ -9,7 +9,11 @@ from flask import Blueprint from flask_restful import Api, reqparse, inputs from lemur.notifications import service -from lemur.notifications.schemas import notification_input_schema, notification_output_schema, notifications_output_schema +from lemur.notifications.schemas import ( + notification_input_schema, + notification_output_schema, + notifications_output_schema, +) from lemur.auth.service import AuthenticatedResource from lemur.common.utils import paginated_parser @@ -17,12 +21,13 @@ from lemur.common.utils import paginated_parser from lemur.common.schema import validate_schema -mod = Blueprint('notifications', __name__) +mod = Blueprint("notifications", __name__) api = Api(mod) class NotificationsList(AuthenticatedResource): """ Defines the 'notifications' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(NotificationsList, self).__init__() @@ -103,7 +108,7 @@ class NotificationsList(AuthenticatedResource): :statuscode 200: no error """ parser = paginated_parser.copy() - parser.add_argument('active', type=inputs.boolean, location='args') + parser.add_argument("active", type=inputs.boolean, location="args") args = parser.parse_args() return service.render(args) @@ -215,11 +220,11 @@ class NotificationsList(AuthenticatedResource): :statuscode 200: no error """ return service.create( - data['label'], - data['plugin']['slug'], - data['plugin']['plugin_options'], - data['description'], - data['certificates'] + data["label"], + data["plugin"]["slug"], + data["plugin"]["plugin_options"], + data["description"], + data["certificates"], ) @@ -334,20 +339,21 @@ class Notifications(AuthenticatedResource): """ return service.update( notification_id, - data['label'], - data['plugin']['plugin_options'], - data['description'], - data['active'], - data['certificates'] + data["label"], + data["plugin"]["plugin_options"], + data["description"], + data["active"], + data["certificates"], ) def delete(self, notification_id): service.delete(notification_id) - return {'result': True} + return {"result": True} class CertificateNotifications(AuthenticatedResource): """ Defines the 'certificate/', endpoint='notification') -api.add_resource(CertificateNotifications, '/certificates//notifications', - endpoint='certificateNotifications') +api.add_resource(NotificationsList, "/notifications", endpoint="notifications") +api.add_resource( + Notifications, "/notifications/", endpoint="notification" +) +api.add_resource( + CertificateNotifications, + "/certificates//notifications", + endpoint="certificateNotifications", +) diff --git a/lemur/pending_certificates/cli.py b/lemur/pending_certificates/cli.py index ccad8de5..2ff29f10 100644 --- a/lemur/pending_certificates/cli.py +++ b/lemur/pending_certificates/cli.py @@ -19,7 +19,9 @@ from lemur.plugins.base import plugins manager = Manager(usage="Handles pending certificate related tasks.") -@manager.option('-i', dest='ids', action='append', help='IDs of pending certificates to fetch') +@manager.option( + "-i", dest="ids", action="append", help="IDs of pending certificates to fetch" +) def fetch(ids): """ Attempt to get full certificate for each pending certificate listed. @@ -39,25 +41,18 @@ def fetch(ids): if real_cert: # If a real certificate was returned from issuer, then create it in Lemur and mark # the pending certificate as resolved - final_cert = pending_certificate_service.create_certificate(cert, real_cert, cert.user) - pending_certificate_service.update( - cert.id, - resolved=True - ) - pending_certificate_service.update( - cert.id, - resolved_cert_id=final_cert.id + final_cert = pending_certificate_service.create_certificate( + cert, real_cert, cert.user ) + pending_certificate_service.update(cert.id, resolved_cert_id=final_cert.id) + pending_certificate_service.update(cert.id, resolved=True) # add metrics to metrics extension new += 1 else: pending_certificate_service.increment_attempt(cert) failed += 1 print( - "[+] Certificates: New: {new} Failed: {failed}".format( - new=new, - failed=failed, - ) + "[+] Certificates: New: {new} Failed: {failed}".format(new=new, failed=failed) ) @@ -69,9 +64,7 @@ def fetch_all_acme(): certificates. """ - log_data = { - "function": "{}.{}".format(__name__, sys._getframe().f_code.co_name) - } + log_data = {"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name)} pending_certs = pending_certificate_service.get_unresolved_pending_certs() new = 0 failed = 0 @@ -81,7 +74,7 @@ def fetch_all_acme(): # We only care about certs using the acme-issuer plugin for cert in pending_certs: cert_authority = get_authority(cert.authority_id) - if cert_authority.plugin_name == 'acme-issuer': + if cert_authority.plugin_name == "acme-issuer": acme_certs.append(cert) else: wrong_issuer += 1 @@ -97,15 +90,13 @@ def fetch_all_acme(): if real_cert: # If a real certificate was returned from issuer, then create it in Lemur and mark # the pending certificate as resolved - final_cert = pending_certificate_service.create_certificate(pending_cert, real_cert, pending_cert.user) - pending_certificate_service.update( - pending_cert.id, - resolved=True + final_cert = pending_certificate_service.create_certificate( + pending_cert, real_cert, pending_cert.user ) pending_certificate_service.update( - pending_cert.id, - resolved_cert_id=final_cert.id + pending_cert.id, resolved_cert_id=final_cert.id ) + pending_certificate_service.update(pending_cert.id, resolved=True) # add metrics to metrics extension new += 1 else: @@ -118,17 +109,15 @@ def fetch_all_acme(): if pending_cert.number_attempts > 4: error_log["message"] = "Marking pending certificate as resolved" - send_pending_failure_notification(pending_cert, notify_owner=pending_cert.notify) - # Mark "resolved" as True - pending_certificate_service.update( - cert.id, - resolved=True + send_pending_failure_notification( + pending_cert, notify_owner=pending_cert.notify ) + # Mark "resolved" as True + pending_certificate_service.update(cert.id, resolved=True) else: pending_certificate_service.increment_attempt(pending_cert) pending_certificate_service.update( - cert.get("pending_cert").id, - status=str(cert.get("last_error")) + cert.get("pending_cert").id, status=str(cert.get("last_error")) ) current_app.logger.error(error_log) log_data["message"] = "Complete" @@ -138,8 +127,6 @@ def fetch_all_acme(): current_app.logger.debug(log_data) print( "[+] Certificates: New: {new} Failed: {failed} Not using ACME: {wrong_issuer}".format( - new=new, - failed=failed, - wrong_issuer=wrong_issuer + new=new, failed=failed, wrong_issuer=wrong_issuer ) ) diff --git a/lemur/pending_certificates/models.py b/lemur/pending_certificates/models.py index 7dc8e602..fa6be073 100644 --- a/lemur/pending_certificates/models.py +++ b/lemur/pending_certificates/models.py @@ -5,7 +5,16 @@ """ from datetime import datetime as dt -from sqlalchemy import Integer, ForeignKey, String, PassiveDefault, func, Column, Text, Boolean +from sqlalchemy import ( + Integer, + ForeignKey, + String, + PassiveDefault, + func, + Column, + Text, + Boolean, +) from sqlalchemy.orm import relationship from sqlalchemy_utils import JSONType from sqlalchemy_utils.types.arrow import ArrowType @@ -13,20 +22,28 @@ from sqlalchemy_utils.types.arrow import ArrowType from lemur.certificates.models import get_sequence from lemur.common import defaults, utils from lemur.database import db -from lemur.models import pending_cert_source_associations, \ - pending_cert_destination_associations, pending_cert_notification_associations, \ - pending_cert_replacement_associations, pending_cert_role_associations +from lemur.models import ( + pending_cert_source_associations, + pending_cert_destination_associations, + pending_cert_notification_associations, + pending_cert_replacement_associations, + pending_cert_role_associations, +) from lemur.utils import Vault def get_or_increase_name(name, serial): - certificates = PendingCertificate.query.filter(PendingCertificate.name.ilike('{0}%'.format(name))).all() + certificates = PendingCertificate.query.filter( + PendingCertificate.name.ilike("{0}%".format(name)) + ).all() if not certificates: return name - serial_name = '{0}-{1}'.format(name, hex(int(serial))[2:].upper()) - certificates = PendingCertificate.query.filter(PendingCertificate.name.ilike('{0}%'.format(serial_name))).all() + serial_name = "{0}-{1}".format(name, hex(int(serial))[2:].upper()) + certificates = PendingCertificate.query.filter( + PendingCertificate.name.ilike("{0}%".format(serial_name)) + ).all() if not certificates: return serial_name @@ -38,11 +55,11 @@ def get_or_increase_name(name, serial): if end: ends.append(end) - return '{0}-{1}'.format(root, max(ends) + 1) + return "{0}-{1}".format(root, max(ends) + 1) class PendingCertificate(db.Model): - __tablename__ = 'pending_certs' + __tablename__ = "pending_certs" id = Column(Integer, primary_key=True) external_id = Column(String(128)) owner = Column(String(128), nullable=False) @@ -60,69 +77,101 @@ class PendingCertificate(db.Model): private_key = Column(Vault, nullable=True) date_created = Column(ArrowType, PassiveDefault(func.now()), nullable=False) - dns_provider_id = Column(Integer, ForeignKey('dns_providers.id', ondelete="CASCADE")) + dns_provider_id = Column( + Integer, ForeignKey("dns_providers.id", ondelete="CASCADE") + ) status = Column(Text(), nullable=True) - last_updated = Column(ArrowType, PassiveDefault(func.now()), onupdate=func.now(), nullable=False) + last_updated = Column( + ArrowType, PassiveDefault(func.now()), onupdate=func.now(), nullable=False + ) rotation = Column(Boolean, default=False) - user_id = Column(Integer, ForeignKey('users.id')) - authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE")) - root_authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE")) - rotation_policy_id = Column(Integer, ForeignKey('rotation_policies.id')) + user_id = Column(Integer, ForeignKey("users.id")) + authority_id = Column(Integer, ForeignKey("authorities.id", ondelete="CASCADE")) + root_authority_id = Column( + Integer, ForeignKey("authorities.id", ondelete="CASCADE") + ) + rotation_policy_id = Column(Integer, ForeignKey("rotation_policies.id")) - notifications = relationship('Notification', secondary=pending_cert_notification_associations, - backref='pending_cert', passive_deletes=True) - destinations = relationship('Destination', secondary=pending_cert_destination_associations, backref='pending_cert', - passive_deletes=True) - sources = relationship('Source', secondary=pending_cert_source_associations, backref='pending_cert', - passive_deletes=True) - roles = relationship('Role', secondary=pending_cert_role_associations, backref='pending_cert', passive_deletes=True) - replaces = relationship('Certificate', - secondary=pending_cert_replacement_associations, - backref='pending_cert', - passive_deletes=True) + notifications = relationship( + "Notification", + secondary=pending_cert_notification_associations, + backref="pending_cert", + passive_deletes=True, + ) + destinations = relationship( + "Destination", + secondary=pending_cert_destination_associations, + backref="pending_cert", + passive_deletes=True, + ) + sources = relationship( + "Source", + secondary=pending_cert_source_associations, + backref="pending_cert", + passive_deletes=True, + ) + roles = relationship( + "Role", + secondary=pending_cert_role_associations, + backref="pending_cert", + passive_deletes=True, + ) + replaces = relationship( + "Certificate", + secondary=pending_cert_replacement_associations, + backref="pending_cert", + passive_deletes=True, + ) options = Column(JSONType) rotation_policy = relationship("RotationPolicy") - sensitive_fields = ('private_key',) + sensitive_fields = ("private_key",) def __init__(self, **kwargs): - self.csr = kwargs.get('csr') - self.private_key = kwargs.get('private_key', "") + self.csr = kwargs.get("csr") + self.private_key = kwargs.get("private_key", "") if self.private_key: # If the request does not send private key, the key exists but the value is None self.private_key = self.private_key.strip() - self.external_id = kwargs.get('external_id') + self.external_id = kwargs.get("external_id") # when destinations are appended they require a valid name. - if kwargs.get('name'): - self.name = get_or_increase_name(defaults.text_to_slug(kwargs['name']), 0) + if kwargs.get("name"): + self.name = get_or_increase_name(defaults.text_to_slug(kwargs["name"]), 0) self.rename = False else: # TODO: Fix auto-generated name, it should be renamed on creation self.name = get_or_increase_name( - defaults.certificate_name(kwargs['common_name'], kwargs['authority'].name, - dt.now(), dt.now(), False), self.external_id) + defaults.certificate_name( + kwargs["common_name"], + kwargs["authority"].name, + dt.now(), + dt.now(), + False, + ), + self.external_id, + ) self.rename = True self.cn = defaults.common_name(utils.parse_csr(self.csr)) - self.owner = kwargs['owner'] + self.owner = kwargs["owner"] self.number_attempts = 0 - if kwargs.get('chain'): - self.chain = kwargs['chain'].strip() + if kwargs.get("chain"): + self.chain = kwargs["chain"].strip() - self.notify = kwargs.get('notify', True) - self.destinations = kwargs.get('destinations', []) - self.notifications = kwargs.get('notifications', []) - self.description = kwargs.get('description') - self.roles = list(set(kwargs.get('roles', []))) - self.replaces = kwargs.get('replaces', []) - self.rotation = kwargs.get('rotation') - self.rotation_policy = kwargs.get('rotation_policy') + self.notify = kwargs.get("notify", True) + self.destinations = kwargs.get("destinations", []) + self.notifications = kwargs.get("notifications", []) + self.description = kwargs.get("description") + self.roles = list(set(kwargs.get("roles", []))) + self.replaces = kwargs.get("replaces", []) + self.rotation = kwargs.get("rotation") + self.rotation_policy = kwargs.get("rotation_policy") try: - self.dns_provider_id = kwargs.get('dns_provider').id + self.dns_provider_id = kwargs.get("dns_provider").id except (AttributeError, KeyError, TypeError, Exception): pass diff --git a/lemur/pending_certificates/schemas.py b/lemur/pending_certificates/schemas.py index fbc94f4e..68f22b4a 100644 --- a/lemur/pending_certificates/schemas.py +++ b/lemur/pending_certificates/schemas.py @@ -1,5 +1,7 @@ -from marshmallow import fields, post_load +from marshmallow import fields, validates_schema, post_load +from marshmallow.exceptions import ValidationError +from lemur.common import utils, validators from lemur.authorities.schemas import AuthorityNestedOutputSchema from lemur.certificates.schemas import CertificateNestedOutputSchema from lemur.common.schema import LemurInputSchema, LemurOutputSchema @@ -15,14 +17,14 @@ from lemur.schemas import ( AssociatedNotificationSchema, AssociatedRoleSchema, EndpointNestedOutputSchema, - ExtensionSchema + ExtensionSchema, ) from lemur.users.schemas import UserNestedOutputSchema class PendingCertificateSchema(LemurInputSchema): owner = fields.Email(required=True) - description = fields.String(missing='', allow_none=True) + description = fields.String(missing="", allow_none=True) class PendingCertificateOutputSchema(LemurOutputSchema): @@ -44,10 +46,10 @@ class PendingCertificateOutputSchema(LemurOutputSchema): # Note aliasing is the first step in deprecating these fields. notify = fields.Boolean() - active = fields.Boolean(attribute='notify') + active = fields.Boolean(attribute="notify") cn = fields.String() - common_name = fields.String(attribute='cn') + common_name = fields.String(attribute="cn") owner = fields.Email() @@ -64,7 +66,9 @@ class PendingCertificateOutputSchema(LemurOutputSchema): authority = fields.Nested(AuthorityNestedOutputSchema) roles = fields.Nested(RoleNestedOutputSchema, many=True) endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[]) - replaced_by = fields.Nested(CertificateNestedOutputSchema, many=True, attribute='replaced') + replaced_by = fields.Nested( + CertificateNestedOutputSchema, many=True, attribute="replaced" + ) rotation_policy = fields.Nested(RotationPolicyNestedOutputSchema) @@ -87,10 +91,15 @@ class PendingCertificateEditInputSchema(PendingCertificateSchema): :param data: :return: """ - if data['owner']: - notification_name = "DEFAULT_{0}".format(data['owner'].split('@')[0].upper()) - data['notifications'] += notification_service.create_default_expiration_notifications(notification_name, - [data['owner']]) + if data["owner"]: + notification_name = "DEFAULT_{0}".format( + data["owner"].split("@")[0].upper() + ) + data[ + "notifications" + ] += notification_service.create_default_expiration_notifications( + notification_name, [data["owner"]] + ) return data @@ -98,6 +107,35 @@ class PendingCertificateCancelSchema(LemurInputSchema): note = fields.String() +class PendingCertificateUploadInputSchema(LemurInputSchema): + external_id = fields.String(missing=None, allow_none=True) + body = fields.String(required=True) + chain = fields.String(missing=None, allow_none=True) + + @validates_schema + def validate_cert_chain(self, data): + cert = None + if data.get("body"): + try: + cert = utils.parse_certificate(data["body"]) + except ValueError: + raise ValidationError( + "Public certificate presented is not valid.", field_names=["body"] + ) + + if data.get("chain"): + try: + chain = utils.parse_cert_chain(data["chain"]) + except ValueError: + raise ValidationError( + "Invalid certificate in certificate chain.", field_names=["chain"] + ) + + # Throws ValidationError + validators.verify_cert_chain([cert] + chain) + + pending_certificate_output_schema = PendingCertificateOutputSchema() pending_certificate_edit_input_schema = PendingCertificateEditInputSchema() pending_certificate_cancel_schema = PendingCertificateCancelSchema() +pending_certificate_upload_input_schema = PendingCertificateUploadInputSchema() diff --git a/lemur/pending_certificates/service.py b/lemur/pending_certificates/service.py index 405b2c4b..8b4d033c 100644 --- a/lemur/pending_certificates/service.py +++ b/lemur/pending_certificates/service.py @@ -8,9 +8,11 @@ from sqlalchemy import or_, cast, Integer from lemur import database from lemur.authorities.models import Authority +from lemur.authorities import service as authorities_service from lemur.certificates import service as certificate_service from lemur.certificates.schemas import CertificateUploadInputSchema -from lemur.common.utils import truthiness +from lemur.common.utils import truthiness, parse_cert_chain, parse_certificate +from lemur.common import validators from lemur.destinations.models import Destination from lemur.domains.models import Domain from lemur.notifications.models import Notification @@ -38,17 +40,18 @@ def get_by_external_id(issuer, external_id): """ if isinstance(external_id, int): external_id = str(external_id) - return PendingCertificate.query \ - .filter(PendingCertificate.authority_id == issuer.id) \ - .filter(PendingCertificate.external_id == external_id) \ + return ( + PendingCertificate.query.filter(PendingCertificate.authority_id == issuer.id) + .filter(PendingCertificate.external_id == external_id) .one_or_none() + ) def get_by_name(pending_cert_name): """ Retrieve pending certificate by name """ - return database.get(PendingCertificate, pending_cert_name, field='name') + return database.get(PendingCertificate, pending_cert_name, field="name") def delete(pending_certificate): @@ -64,7 +67,9 @@ def get_unresolved_pending_certs(): Retrieve a list of unresolved pending certs given a list of ids Filters out non-existing pending certs """ - query = database.session_query(PendingCertificate).filter(PendingCertificate.resolved.is_(False)) + query = database.session_query(PendingCertificate).filter( + PendingCertificate.resolved.is_(False) + ) return database.find_all(query, PendingCertificate, {}).all() @@ -74,7 +79,7 @@ def get_pending_certs(pending_ids): Filters out non-existing pending certs """ pending_certs = [] - if 'all' in pending_ids: + if "all" in pending_ids: query = database.session_query(PendingCertificate) return database.find_all(query, PendingCertificate, {}).all() else: @@ -94,23 +99,25 @@ def create_certificate(pending_certificate, certificate, user): user: User that called this function, used as 'creator' of the certificate if it does not have an owner """ - certificate['owner'] = pending_certificate.owner + certificate["owner"] = pending_certificate.owner data, errors = CertificateUploadInputSchema().load(certificate) if errors: - raise Exception("Unable to create certificate: {reasons}".format(reasons=errors)) + raise Exception( + "Unable to create certificate: {reasons}".format(reasons=errors) + ) data.update(vars(pending_certificate)) # Copy relationships, vars doesn't copy this without explicit fields - data['notifications'] = list(pending_certificate.notifications) - data['destinations'] = list(pending_certificate.destinations) - data['sources'] = list(pending_certificate.sources) - data['roles'] = list(pending_certificate.roles) - data['replaces'] = list(pending_certificate.replaces) - data['rotation_policy'] = pending_certificate.rotation_policy + data["notifications"] = list(pending_certificate.notifications) + data["destinations"] = list(pending_certificate.destinations) + data["sources"] = list(pending_certificate.sources) + data["roles"] = list(pending_certificate.roles) + data["replaces"] = list(pending_certificate.replaces) + data["rotation_policy"] = pending_certificate.rotation_policy # Replace external id and chain with the one fetched from source - data['external_id'] = certificate['external_id'] - data['chain'] = certificate['chain'] + data["external_id"] = certificate["external_id"] + data["chain"] = certificate["chain"] creator = user_service.get_by_email(pending_certificate.owner) if not creator: # Owner of the pending certificate is not the creator, so use the current user who called @@ -119,8 +126,8 @@ def create_certificate(pending_certificate, certificate, user): if pending_certificate.rename: # If generating name from certificate, remove the one from pending certificate - del data['name'] - data['creator'] = creator + del data["name"] + data["creator"] = creator cert = certificate_service.import_certificate(**data) database.update(cert) @@ -157,76 +164,125 @@ def cancel(pending_certificate, **kwargs): """ plugin = plugins.get(pending_certificate.authority.plugin_name) plugin.cancel_ordered_certificate(pending_certificate, **kwargs) - pending_certificate.status = 'Cancelled' + pending_certificate.status = "Cancelled" database.update(pending_certificate) return pending_certificate def render(args): query = database.session_query(PendingCertificate) - time_range = args.pop('time_range') - destination_id = args.pop('destination_id') - notification_id = args.pop('notification_id', None) - show = args.pop('show') + time_range = args.pop("time_range") + destination_id = args.pop("destination_id") + notification_id = args.pop("notification_id", None) + show = args.pop("show") # owner = args.pop('owner') # creator = args.pop('creator') # TODO we should enabling filtering by owner - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') + terms = filt.split(";") - if 'issuer' in terms: + if "issuer" in terms: # we can't rely on issuer being correct in the cert directly so we combine queries - sub_query = database.session_query(Authority.id) \ - .filter(Authority.name.ilike('%{0}%'.format(terms[1]))) \ + sub_query = ( + database.session_query(Authority.id) + .filter(Authority.name.ilike("%{0}%".format(terms[1]))) .subquery() + ) query = query.filter( or_( - PendingCertificate.issuer.ilike('%{0}%'.format(terms[1])), - PendingCertificate.authority_id.in_(sub_query) + PendingCertificate.issuer.ilike("%{0}%".format(terms[1])), + PendingCertificate.authority_id.in_(sub_query), ) ) - elif 'destination' in terms: - query = query.filter(PendingCertificate.destinations.any(Destination.id == terms[1])) - elif 'notify' in filt: + elif "destination" in terms: + query = query.filter( + PendingCertificate.destinations.any(Destination.id == terms[1]) + ) + elif "notify" in filt: query = query.filter(PendingCertificate.notify == truthiness(terms[1])) - elif 'active' in filt: + elif "active" in filt: query = query.filter(PendingCertificate.active == truthiness(terms[1])) - elif 'cn' in terms: + elif "cn" in terms: query = query.filter( or_( - PendingCertificate.cn.ilike('%{0}%'.format(terms[1])), - PendingCertificate.domains.any(Domain.name.ilike('%{0}%'.format(terms[1]))) + PendingCertificate.cn.ilike("%{0}%".format(terms[1])), + PendingCertificate.domains.any( + Domain.name.ilike("%{0}%".format(terms[1])) + ), ) ) - elif 'id' in terms: + elif "id" in terms: query = query.filter(PendingCertificate.id == cast(terms[1], Integer)) else: query = database.filter(query, PendingCertificate, terms) if show: - sub_query = database.session_query(Role.name).filter(Role.user_id == args['user'].id).subquery() + sub_query = ( + database.session_query(Role.name) + .filter(Role.user_id == args["user"].id) + .subquery() + ) query = query.filter( or_( - PendingCertificate.user_id == args['user'].id, - PendingCertificate.owner.in_(sub_query) + PendingCertificate.user_id == args["user"].id, + PendingCertificate.owner.in_(sub_query), ) ) if destination_id: - query = query.filter(PendingCertificate.destinations.any(Destination.id == destination_id)) + query = query.filter( + PendingCertificate.destinations.any(Destination.id == destination_id) + ) if notification_id: - query = query.filter(PendingCertificate.notifications.any(Notification.id == notification_id)) + query = query.filter( + PendingCertificate.notifications.any(Notification.id == notification_id) + ) if time_range: - to = arrow.now().replace(weeks=+time_range).format('YYYY-MM-DD') - now = arrow.now().format('YYYY-MM-DD') - query = query.filter(PendingCertificate.not_after <= to).filter(PendingCertificate.not_after >= now) + to = arrow.now().shift(weeks=+time_range).format("YYYY-MM-DD") + now = arrow.now().format("YYYY-MM-DD") + query = query.filter(PendingCertificate.not_after <= to).filter( + PendingCertificate.not_after >= now + ) # Only show unresolved certificates in the UI query = query.filter(PendingCertificate.resolved.is_(False)) return database.sort_and_page(query, PendingCertificate, args) + + +def upload(pending_certificate_id, **kwargs): + """ + Uploads a (signed) pending certificate. The allowed fields are validated by + PendingCertificateUploadInputSchema. The certificate is also validated to be + signed by the correct authoritity. + """ + pending_cert = get(pending_certificate_id) + partial_cert = kwargs + uploaded_chain = partial_cert["chain"] + + authority = authorities_service.get(pending_cert.authority.id) + + # Construct the chain for cert validation + if uploaded_chain: + chain = uploaded_chain + "\n" + authority.authority_certificate.body + else: + chain = authority.authority_certificate.body + + parsed_chain = parse_cert_chain(chain) + + # Check that the certificate is actually signed by the CA to avoid incorrect cert pasting + validators.verify_cert_chain( + [parse_certificate(partial_cert["body"])] + parsed_chain + ) + + final_cert = create_certificate(pending_cert, partial_cert, pending_cert.user) + + pending_cert_final_result = update(pending_cert.id, resolved_cert_id=final_cert.id) + update(pending_cert.id, resolved=True) + + return pending_cert_final_result diff --git a/lemur/pending_certificates/views.py b/lemur/pending_certificates/views.py index 13598040..4651aed7 100644 --- a/lemur/pending_certificates/views.py +++ b/lemur/pending_certificates/views.py @@ -20,9 +20,10 @@ from lemur.pending_certificates.schemas import ( pending_certificate_output_schema, pending_certificate_edit_input_schema, pending_certificate_cancel_schema, + pending_certificate_upload_input_schema, ) -mod = Blueprint('pending_certificates', __name__) +mod = Blueprint("pending_certificates", __name__) api = Api(mod) @@ -109,15 +110,17 @@ class PendingCertificatesList(AuthenticatedResource): """ parser = paginated_parser.copy() - parser.add_argument('timeRange', type=int, dest='time_range', location='args') - parser.add_argument('owner', type=inputs.boolean, location='args') - parser.add_argument('id', type=str, location='args') - parser.add_argument('active', type=inputs.boolean, location='args') - parser.add_argument('destinationId', type=int, dest="destination_id", location='args') - parser.add_argument('creator', type=str, location='args') - parser.add_argument('show', type=str, location='args') + parser.add_argument("timeRange", type=int, dest="time_range", location="args") + parser.add_argument("owner", type=inputs.boolean, location="args") + parser.add_argument("id", type=str, location="args") + parser.add_argument("active", type=inputs.boolean, location="args") + parser.add_argument( + "destinationId", type=int, dest="destination_id", location="args" + ) + parser.add_argument("creator", type=str, location="args") + parser.add_argument("show", type=str, location="args") args = parser.parse_args() - args['user'] = g.user + args["user"] = g.user return service.render(args) @@ -205,7 +208,9 @@ class PendingCertificates(AuthenticatedResource): """ return service.get(pending_certificate_id) - @validate_schema(pending_certificate_edit_input_schema, pending_certificate_output_schema) + @validate_schema( + pending_certificate_edit_input_schema, pending_certificate_output_schema + ) def put(self, pending_certificate_id, data=None): """ .. http:put:: /pending_certificates/1 @@ -296,19 +301,27 @@ class PendingCertificates(AuthenticatedResource): # allow creators if g.current_user != pending_cert.user: owner_role = role_service.get_by_name(pending_cert.owner) - permission = CertificatePermission(owner_role, [x.name for x in pending_cert.roles]) + permission = CertificatePermission( + owner_role, [x.name for x in pending_cert.roles] + ) if not permission.can(): - return dict(message='You are not authorized to update this certificate'), 403 + return ( + dict(message="You are not authorized to update this certificate"), + 403, + ) - for destination in data['destinations']: + for destination in data["destinations"]: if destination.plugin.requires_key: if not pending_cert.private_key: - return dict( - message='Unable to add destination: {0}. Certificate does not have required private key.'.format( - destination.label - ) - ), 400 + return ( + dict( + message="Unable to add destination: {0}. Certificate does not have required private key.".format( + destination.label + ) + ), + 400, + ) pending_cert = service.update(pending_certificate_id, **data) return pending_cert @@ -353,18 +366,28 @@ class PendingCertificates(AuthenticatedResource): # allow creators if g.current_user != pending_cert.user: owner_role = role_service.get_by_name(pending_cert.owner) - permission = CertificatePermission(owner_role, [x.name for x in pending_cert.roles]) + permission = CertificatePermission( + owner_role, [x.name for x in pending_cert.roles] + ) if not permission.can(): - return dict(message='You are not authorized to update this certificate'), 403 + return ( + dict(message="You are not authorized to update this certificate"), + 403, + ) if service.cancel(pending_cert, **data): service.delete(pending_cert) - return('', 204) + return ("", 204) else: # service.cancel raises exception if there was an issue, but this will ensure something # is relayed to user in case of something unexpected (unsuccessful update somehow). - return dict(message="Unexpected error occurred while trying to cancel this certificate"), 500 + return ( + dict( + message="Unexpected error occurred while trying to cancel this certificate" + ), + 500, + ) class PendingCertificatePrivateKey(AuthenticatedResource): @@ -411,14 +434,125 @@ class PendingCertificatePrivateKey(AuthenticatedResource): permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) if not permission.can(): - return dict(message='You are not authorized to view this key'), 403 + return dict(message="You are not authorized to view this key"), 403 response = make_response(jsonify(key=cert.private_key), 200) - response.headers['cache-control'] = 'private, max-age=0, no-cache, no-store' - response.headers['pragma'] = 'no-cache' + response.headers["cache-control"] = "private, max-age=0, no-cache, no-store" + response.headers["pragma"] = "no-cache" return response -api.add_resource(PendingCertificatesList, '/pending_certificates', endpoint='pending_certificates') -api.add_resource(PendingCertificates, '/pending_certificates/', endpoint='pending_certificate') -api.add_resource(PendingCertificatePrivateKey, '/pending_certificates//key', endpoint='privateKeyPendingCertificates') +class PendingCertificatesUpload(AuthenticatedResource): + """ Defines the 'pending_certificates' upload endpoint """ + + def __init__(self): + self.reqparse = reqparse.RequestParser() + super(PendingCertificatesUpload, self).__init__() + + @validate_schema( + pending_certificate_upload_input_schema, pending_certificate_output_schema + ) + def post(self, pending_certificate_id, data=None): + """ + .. http:post:: /pending_certificates/1/upload + + Upload the body for a (signed) pending_certificate + + **Example request**: + + .. sourcecode:: http + + POST /certificates/1/upload HTTP/1.1 + Host: example.com + Accept: application/json, text/javascript + + { + "body": "-----BEGIN CERTIFICATE-----...", + "chain": "-----BEGIN CERTIFICATE-----...", + } + + **Example response**: + + .. sourcecode:: http + + HTTP/1.1 200 OK + Vary: Accept + Content-Type: text/javascript + + { + "status": null, + "cn": "*.test.example.net", + "chain": "", + "authority": { + "active": true, + "owner": "secure@example.com", + "id": 1, + "description": "verisign test authority", + "name": "verisign" + }, + "owner": "joe@example.com", + "serial": "82311058732025924142789179368889309156", + "id": 2288, + "issuer": "SymantecCorporation", + "dateCreated": "2016-06-03T06:09:42.133769+00:00", + "notBefore": "2016-06-03T00:00:00+00:00", + "notAfter": "2018-01-12T23:59:59+00:00", + "destinations": [], + "bits": 2048, + "body": "-----BEGIN CERTIFICATE-----...", + "description": null, + "deleted": null, + "notifications": [{ + "id": 1 + }], + "signingAlgorithm": "sha256", + "user": { + "username": "jane", + "active": true, + "email": "jane@example.com", + "id": 2 + }, + "active": true, + "domains": [{ + "sensitive": false, + "id": 1090, + "name": "*.test.example.net" + }], + "replaces": [], + "rotation": true, + "rotationPolicy": {"name": "default"}, + "name": "WILDCARD.test.example.net-SymantecCorporation-20160603-20180112", + "roles": [{ + "id": 464, + "description": "This is a google group based role created by Lemur", + "name": "joe@example.com" + }], + "san": null + } + + :reqheader Authorization: OAuth token to authenticate + :statuscode 403: unauthenticated + :statuscode 200: no error + + """ + return service.upload(pending_certificate_id, **data) + + +api.add_resource( + PendingCertificatesList, "/pending_certificates", endpoint="pending_certificates" +) +api.add_resource( + PendingCertificates, + "/pending_certificates/", + endpoint="pending_certificate", +) +api.add_resource( + PendingCertificatesUpload, + "/pending_certificates//upload", + endpoint="pendingCertificateUpload", +) +api.add_resource( + PendingCertificatePrivateKey, + "/pending_certificates//key", + endpoint="privateKeyPendingCertificates", +) diff --git a/lemur/plugins/base/manager.py b/lemur/plugins/base/manager.py index a2306445..117700a6 100644 --- a/lemur/plugins/base/manager.py +++ b/lemur/plugins/base/manager.py @@ -18,7 +18,9 @@ class PluginManager(InstanceManager): return sum(1 for i in self.all()) def all(self, version=1, plugin_type=None): - for plugin in sorted(super(PluginManager, self).all(), key=lambda x: x.get_title()): + for plugin in sorted( + super(PluginManager, self).all(), key=lambda x: x.get_title() + ): if not plugin.type == plugin_type and plugin_type: continue if not plugin.is_enabled(): @@ -36,29 +38,34 @@ class PluginManager(InstanceManager): return plugin current_app.logger.error( "Unable to find slug: {} in self.all version 1: {} or version 2: {}".format( - slug, self.all(version=1), self.all(version=2)) + slug, self.all(version=1), self.all(version=2) + ) ) raise KeyError(slug) def first(self, func_name, *args, **kwargs): - version = kwargs.pop('version', 1) + version = kwargs.pop("version", 1) for plugin in self.all(version=version): try: result = getattr(plugin, func_name)(*args, **kwargs) except Exception as e: - current_app.logger.error('Error processing %s() on %r: %s', func_name, plugin.__class__, e, extra={ - 'func_arg': args, - 'func_kwargs': kwargs, - }, exc_info=True) + current_app.logger.error( + "Error processing %s() on %r: %s", + func_name, + plugin.__class__, + e, + extra={"func_arg": args, "func_kwargs": kwargs}, + exc_info=True, + ) continue if result is not None: return result def register(self, cls): - self.add('%s.%s' % (cls.__module__, cls.__name__)) + self.add("%s.%s" % (cls.__module__, cls.__name__)) return cls def unregister(self, cls): - self.remove('%s.%s' % (cls.__module__, cls.__name__)) + self.remove("%s.%s" % (cls.__module__, cls.__name__)) return cls diff --git a/lemur/plugins/base/v1.py b/lemur/plugins/base/v1.py index fb688c73..664385b3 100644 --- a/lemur/plugins/base/v1.py +++ b/lemur/plugins/base/v1.py @@ -18,7 +18,7 @@ class PluginMount(type): if new_cls.title is None: new_cls.title = new_cls.__name__ if not new_cls.slug: - new_cls.slug = new_cls.title.replace(' ', '-').lower() + new_cls.slug = new_cls.title.replace(" ", "-").lower() return new_cls @@ -36,6 +36,7 @@ class IPlugin(local): As a general rule all inherited methods should allow ``**kwargs`` to ensure ease of future compatibility. """ + # Generic plugin information title = None slug = None @@ -72,7 +73,7 @@ class IPlugin(local): Returns a string representing the configuration keyspace prefix for this plugin. """ if not self.conf_key: - self.conf_key = self.get_conf_title().lower().replace(' ', '_') + self.conf_key = self.get_conf_title().lower().replace(" ", "_") return self.conf_key def get_conf_title(self): @@ -111,8 +112,8 @@ class IPlugin(local): @staticmethod def get_option(name, options): for o in options: - if o.get('name') == name: - return o.get('value', o.get('default')) + if o.get("name") == name: + return o.get("value", o.get("default")) class Plugin(IPlugin): @@ -121,5 +122,6 @@ class Plugin(IPlugin): control when or how the plugin gets instantiated, nor is it guaranteed that it will happen, or happen more than once. """ + __version__ = 1 __metaclass__ = PluginMount diff --git a/lemur/plugins/bases/destination.py b/lemur/plugins/bases/destination.py index 1e7e4ed2..e00c5090 100644 --- a/lemur/plugins/bases/destination.py +++ b/lemur/plugins/bases/destination.py @@ -10,8 +10,10 @@ from lemur.plugins.base import Plugin, plugins class DestinationPlugin(Plugin): - type = 'destination' + type = "destination" requires_key = True + sync_as_source = False + sync_as_source_name = "" def upload(self, name, body, private_key, cert_chain, options, **kwargs): raise NotImplementedError @@ -20,10 +22,10 @@ class DestinationPlugin(Plugin): class ExportDestinationPlugin(DestinationPlugin): default_options = [ { - 'name': 'exportPlugin', - 'type': 'export-plugin', - 'required': True, - 'helpMessage': 'Export plugin to use before sending data to destination.' + "name": "exportPlugin", + "type": "export-plugin", + "required": True, + "helpMessage": "Export plugin to use before sending data to destination.", } ] @@ -32,15 +34,17 @@ class ExportDestinationPlugin(DestinationPlugin): return self.default_options + self.additional_options def export(self, body, private_key, cert_chain, options): - export_plugin = self.get_option('exportPlugin', options) + export_plugin = self.get_option("exportPlugin", options) if export_plugin: - plugin = plugins.get(export_plugin['slug']) - extension, passphrase, data = plugin.export(body, cert_chain, private_key, export_plugin['plugin_options']) + plugin = plugins.get(export_plugin["slug"]) + extension, passphrase, data = plugin.export( + body, cert_chain, private_key, export_plugin["plugin_options"] + ) return [(extension, passphrase, data)] - data = body + '\n' + cert_chain + '\n' + private_key - return [('.pem', '', data)] + data = body + "\n" + cert_chain + "\n" + private_key + return [(".pem", "", data)] def upload(self, name, body, private_key, cert_chain, options, **kwargs): raise NotImplementedError diff --git a/lemur/plugins/bases/export.py b/lemur/plugins/bases/export.py index 1466c1ab..6d078906 100644 --- a/lemur/plugins/bases/export.py +++ b/lemur/plugins/bases/export.py @@ -14,7 +14,8 @@ class ExportPlugin(Plugin): This is the base class from which all supported exporters will inherit from. """ - type = 'export' + + type = "export" requires_key = True def export(self, body, chain, key, options, **kwargs): diff --git a/lemur/plugins/bases/issuer.py b/lemur/plugins/bases/issuer.py index 5eb0964c..f1e6aa0e 100644 --- a/lemur/plugins/bases/issuer.py +++ b/lemur/plugins/bases/issuer.py @@ -14,7 +14,8 @@ class IssuerPlugin(Plugin): This is the base class from which all of the supported issuers will inherit from. """ - type = 'issuer' + + type = "issuer" def create_certificate(self, csr, issuer_options): raise NotImplementedError diff --git a/lemur/plugins/bases/metric.py b/lemur/plugins/bases/metric.py index 259af235..2e4ce69b 100644 --- a/lemur/plugins/bases/metric.py +++ b/lemur/plugins/bases/metric.py @@ -10,7 +10,9 @@ from lemur.plugins.base import Plugin class MetricPlugin(Plugin): - type = 'metric' + type = "metric" - def submit(self, metric_name, metric_type, metric_value, metric_tags=None, options=None): + def submit( + self, metric_name, metric_type, metric_value, metric_tags=None, options=None + ): raise NotImplementedError diff --git a/lemur/plugins/bases/notification.py b/lemur/plugins/bases/notification.py index a7ba4e0d..730f68be 100644 --- a/lemur/plugins/bases/notification.py +++ b/lemur/plugins/bases/notification.py @@ -14,7 +14,8 @@ class NotificationPlugin(Plugin): This is the base class from which all of the supported issuers will inherit from. """ - type = 'notification' + + type = "notification" def send(self, notification_type, message, targets, options, **kwargs): raise NotImplementedError @@ -26,22 +27,23 @@ class ExpirationNotificationPlugin(NotificationPlugin): It contains some default options that are needed for all expiration notification plugins. """ + default_options = [ { - 'name': 'interval', - 'type': 'int', - 'required': True, - 'validation': '^\d+$', - 'helpMessage': 'Number of days to be alert before expiration.', + "name": "interval", + "type": "int", + "required": True, + "validation": "^\d+$", + "helpMessage": "Number of days to be alert before expiration.", }, { - 'name': 'unit', - 'type': 'select', - 'required': True, - 'validation': '', - 'available': ['days', 'weeks', 'months'], - 'helpMessage': 'Interval unit', - } + "name": "unit", + "type": "select", + "required": True, + "validation": "", + "available": ["days", "weeks", "months"], + "helpMessage": "Interval unit", + }, ] @property diff --git a/lemur/plugins/bases/source.py b/lemur/plugins/bases/source.py index ff3492fe..6f521e40 100644 --- a/lemur/plugins/bases/source.py +++ b/lemur/plugins/bases/source.py @@ -10,15 +10,15 @@ from lemur.plugins.base import Plugin class SourcePlugin(Plugin): - type = 'source' + type = "source" default_options = [ { - 'name': 'pollRate', - 'type': 'int', - 'required': False, - 'helpMessage': 'Rate in seconds to poll source for new information.', - 'default': '60', + "name": "pollRate", + "type": "int", + "required": False, + "helpMessage": "Rate in seconds to poll source for new information.", + "default": "60", } ] diff --git a/lemur/plugins/lemur_acme/__init__.py b/lemur/plugins/lemur_acme/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_acme/__init__.py +++ b/lemur/plugins/lemur_acme/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_acme/cloudflare.py b/lemur/plugins/lemur_acme/cloudflare.py index 77052242..a19495f8 100644 --- a/lemur/plugins/lemur_acme/cloudflare.py +++ b/lemur/plugins/lemur_acme/cloudflare.py @@ -5,24 +5,24 @@ from flask import current_app def cf_api_call(): - cf_key = current_app.config.get('ACME_CLOUDFLARE_KEY', '') - cf_email = current_app.config.get('ACME_CLOUDFLARE_EMAIL', '') + cf_key = current_app.config.get("ACME_CLOUDFLARE_KEY", "") + cf_email = current_app.config.get("ACME_CLOUDFLARE_EMAIL", "") return CloudFlare.CloudFlare(email=cf_email, token=cf_key) def find_zone_id(host): - elements = host.split('.') + elements = host.split(".") cf = cf_api_call() n = 1 while n < 5: n = n + 1 - domain = '.'.join(elements[-n:]) + domain = ".".join(elements[-n:]) current_app.logger.debug("Trying to get ID for zone {0}".format(domain)) try: - zone = cf.zones.get(params={'name': domain, 'per_page': 1}) + zone = cf.zones.get(params={"name": domain, "per_page": 1}) except Exception as e: current_app.logger.error("Cloudflare API error: %s" % e) pass @@ -31,10 +31,10 @@ def find_zone_id(host): break if len(zone) == 0: - current_app.logger.error('No zone found') + current_app.logger.error("No zone found") return else: - return zone[0]['id'] + return zone[0]["id"] def wait_for_dns_change(change_id, account_number=None): @@ -42,8 +42,8 @@ def wait_for_dns_change(change_id, account_number=None): zone_id, record_id = change_id while True: r = cf.zones.get(zone_id, record_id) - current_app.logger.debug("Record status: %s" % r['status']) - if r['status'] == 'active': + current_app.logger.debug("Record status: %s" % r["status"]) + if r["status"] == "active": break time.sleep(1) return @@ -55,22 +55,27 @@ def create_txt_record(host, value, account_number): if not zone_id: return - txt_record = {'name': host, 'type': 'TXT', 'content': value} + txt_record = {"name": host, "type": "TXT", "content": value} - current_app.logger.debug("Creating TXT record {0} with value {1}".format(host, value)) + current_app.logger.debug( + "Creating TXT record {0} with value {1}".format(host, value) + ) try: r = cf.zones.dns_records.post(zone_id, data=txt_record) except Exception as e: - current_app.logger.error('/zones.dns_records.post %s: %s' % (txt_record['name'], e)) - return zone_id, r['id'] + current_app.logger.error( + "/zones.dns_records.post %s: %s" % (txt_record["name"], e) + ) + return zone_id, r["id"] -def delete_txt_record(change_id, account_number, host, value): +def delete_txt_record(change_ids, account_number, host, value): cf = cf_api_call() - zone_id, record_id = change_id - current_app.logger.debug("Removing record with id {0}".format(record_id)) - try: - cf.zones.dns_records.delete(zone_id, record_id) - except Exception as e: - current_app.logger.error('/zones.dns_records.post: %s' % e) + for change_id in change_ids: + zone_id, record_id = change_id + current_app.logger.debug("Removing record with id {0}".format(record_id)) + try: + cf.zones.dns_records.delete(zone_id, record_id) + except Exception as e: + current_app.logger.error("/zones.dns_records.post: %s" % e) diff --git a/lemur/plugins/lemur_acme/dyn.py b/lemur/plugins/lemur_acme/dyn.py index 9bab3a65..fff2e632 100644 --- a/lemur/plugins/lemur_acme/dyn.py +++ b/lemur/plugins/lemur_acme/dyn.py @@ -5,35 +5,50 @@ import dns.exception import dns.name import dns.query import dns.resolver -from dyn.tm.errors import DynectCreateError +from dyn.tm.errors import ( + DynectCreateError, + DynectDeleteError, + DynectGetError, + DynectUpdateError, +) from dyn.tm.session import DynectSession from dyn.tm.zones import Node, Zone, get_all_zones from flask import current_app +from lemur.extensions import metrics, sentry + def get_dynect_session(): - dynect_session = DynectSession( - current_app.config.get('ACME_DYN_CUSTOMER_NAME', ''), - current_app.config.get('ACME_DYN_USERNAME', ''), - current_app.config.get('ACME_DYN_PASSWORD', ''), - ) + try: + dynect_session = DynectSession( + current_app.config.get("ACME_DYN_CUSTOMER_NAME", ""), + current_app.config.get("ACME_DYN_USERNAME", ""), + current_app.config.get("ACME_DYN_PASSWORD", ""), + ) + except Exception as e: + sentry.captureException() + metrics.send("get_dynect_session_fail", "counter", 1) + current_app.logger.debug("Unable to establish connection to Dyn", exc_info=True) + raise return dynect_session -def _has_dns_propagated(name, token): +def _has_dns_propagated(fqdn, token): txt_records = [] try: dns_resolver = dns.resolver.Resolver() - dns_resolver.nameservers = [get_authoritative_nameserver(name)] - dns_response = dns_resolver.query(name, 'TXT') + dns_resolver.nameservers = [get_authoritative_nameserver(fqdn)] + dns_response = dns_resolver.query(fqdn, "TXT") for rdata in dns_response: for txt_record in rdata.strings: txt_records.append(txt_record.decode("utf-8")) except dns.exception.DNSException: + metrics.send("has_dns_propagated_fail", "counter", 1, metric_tags={"dns": fqdn}) return False for txt_record in txt_records: if txt_record == token: + metrics.send("has_dns_propagated_success", "counter", 1, metric_tags={"dns": fqdn}) return True return False @@ -41,16 +56,24 @@ def _has_dns_propagated(name, token): def wait_for_dns_change(change_id, account_number=None): fqdn, token = change_id - number_of_attempts = 10 + number_of_attempts = 20 for attempts in range(0, number_of_attempts): status = _has_dns_propagated(fqdn, token) current_app.logger.debug("Record status for fqdn: {}: {}".format(fqdn, status)) if status: + metrics.send("wait_for_dns_change_success", "counter", 1, metric_tags={"dns": fqdn}) break - time.sleep(20) + time.sleep(10) if not status: # TODO: Delete associated DNS text record here - raise Exception("Unable to query DNS token for fqdn {}.".format(fqdn)) + metrics.send("wait_for_dns_change_fail", "counter", 1, metric_tags={"dns": fqdn}) + sentry.captureException(extra={"fqdn": str(fqdn), "txt_record": str(token)}) + metrics.send( + "wait_for_dns_change_error", + "counter", + 1, + metric_tags={"fqdn": fqdn, "txt_record": token}, + ) return @@ -67,6 +90,7 @@ def get_zone_name(domain): if z.name.count(".") > zone_name.count("."): zone_name = z.name if not zone_name: + metrics.send("dyn_no_zone_name", "counter", 1) raise Exception("No Dyn zone found for domain: {}".format(domain)) return zone_name @@ -83,22 +107,29 @@ def get_zones(account_number): def create_txt_record(domain, token, account_number): get_dynect_session() zone_name = get_zone_name(domain) - zone_parts = len(zone_name.split('.')) - node_name = '.'.join(domain.split('.')[:-zone_parts]) + zone_parts = len(zone_name.split(".")) + node_name = ".".join(domain.split(".")[:-zone_parts]) fqdn = "{0}.{1}".format(node_name, zone_name) zone = Zone(zone_name) try: - zone.add_record(node_name, record_type='TXT', txtdata="\"{}\"".format(token), ttl=5) + zone.add_record( + node_name, record_type="TXT", txtdata='"{}"'.format(token), ttl=5 + ) zone.publish() - current_app.logger.debug("TXT record created: {0}, token: {1}".format(fqdn, token)) - except DynectCreateError as e: + current_app.logger.debug( + "TXT record created: {0}, token: {1}".format(fqdn, token) + ) + except (DynectCreateError, DynectUpdateError) as e: if "Cannot duplicate existing record data" in e.message: current_app.logger.debug( "Unable to add record. Domain: {}. Token: {}. " - "Record already exists: {}".format(domain, token, e), exc_info=True + "Record already exists: {}".format(domain, token, e), + exc_info=True, ) else: + metrics.send("create_txt_record_error", "counter", 1) + sentry.captureException() raise change_id = (fqdn, token) @@ -112,19 +143,57 @@ def delete_txt_record(change_id, account_number, domain, token): return zone_name = get_zone_name(domain) - zone_parts = len(zone_name.split('.')) - node_name = '.'.join(domain.split('.')[:-zone_parts]) + zone_parts = len(zone_name.split(".")) + node_name = ".".join(domain.split(".")[:-zone_parts]) fqdn = "{0}.{1}".format(node_name, zone_name) zone = Zone(zone_name) node = Node(zone_name, fqdn) - all_txt_records = node.get_all_records_by_type('TXT') + try: + all_txt_records = node.get_all_records_by_type("TXT") + except DynectGetError: + metrics.send("delete_txt_record_geterror", "counter", 1) + # No Text Records remain or host is not in the zone anymore because all records have been deleted. + return for txt_record in all_txt_records: if txt_record.txtdata == ("{}".format(token)): current_app.logger.debug("Deleting TXT record name: {0}".format(fqdn)) - txt_record.delete() - zone.publish() + try: + txt_record.delete() + except DynectDeleteError: + sentry.captureException( + extra={ + "fqdn": str(fqdn), + "zone_name": str(zone_name), + "node_name": str(node_name), + "txt_record": str(txt_record.txtdata), + } + ) + metrics.send( + "delete_txt_record_deleteerror", + "counter", + 1, + metric_tags={"fqdn": fqdn, "txt_record": txt_record.txtdata}, + ) + + try: + zone.publish() + except DynectUpdateError: + sentry.captureException( + extra={ + "fqdn": str(fqdn), + "zone_name": str(zone_name), + "node_name": str(node_name), + "txt_record": str(txt_record.txtdata), + } + ) + metrics.send( + "delete_txt_record_publish_error", + "counter", + 1, + metric_tags={"fqdn": str(fqdn), "txt_record": str(txt_record.txtdata)}, + ) def delete_acme_txt_records(domain): @@ -136,26 +205,45 @@ def delete_acme_txt_records(domain): if not domain.startswith(acme_challenge_string): current_app.logger.debug( "delete_acme_txt_records: Domain {} doesn't start with string {}. " - "Cowardly refusing to delete TXT records".format(domain, acme_challenge_string)) + "Cowardly refusing to delete TXT records".format( + domain, acme_challenge_string + ) + ) return zone_name = get_zone_name(domain) - zone_parts = len(zone_name.split('.')) - node_name = '.'.join(domain.split('.')[:-zone_parts]) + zone_parts = len(zone_name.split(".")) + node_name = ".".join(domain.split(".")[:-zone_parts]) fqdn = "{0}.{1}".format(node_name, zone_name) zone = Zone(zone_name) node = Node(zone_name, fqdn) - all_txt_records = node.get_all_records_by_type('TXT') + all_txt_records = node.get_all_records_by_type("TXT") for txt_record in all_txt_records: current_app.logger.debug("Deleting TXT record name: {0}".format(fqdn)) - txt_record.delete() + try: + txt_record.delete() + except DynectDeleteError: + sentry.captureException( + extra={ + "fqdn": str(fqdn), + "zone_name": str(zone_name), + "node_name": str(node_name), + "txt_record": str(txt_record.txtdata), + } + ) + metrics.send( + "delete_txt_record_deleteerror", + "counter", + 1, + metric_tags={"fqdn": fqdn, "txt_record": txt_record.txtdata}, + ) zone.publish() def get_authoritative_nameserver(domain): - if current_app.config.get('ACME_DYN_GET_AUTHORATATIVE_NAMESERVER'): + if current_app.config.get("ACME_DYN_GET_AUTHORATATIVE_NAMESERVER"): n = dns.name.from_text(domain) depth = 2 @@ -166,7 +254,7 @@ def get_authoritative_nameserver(domain): while not last: s = n.split(depth) - last = s[0].to_unicode() == u'@' + last = s[0].to_unicode() == u"@" sub = s[1] query = dns.message.make_query(sub, dns.rdatatype.NS) @@ -174,10 +262,11 @@ def get_authoritative_nameserver(domain): rcode = response.rcode() if rcode != dns.rcode.NOERROR: + metrics.send("get_authoritative_nameserver_error", "counter", 1) if rcode == dns.rcode.NXDOMAIN: - raise Exception('%s does not exist.' % sub) + raise Exception("%s does not exist." % sub) else: - raise Exception('Error %s' % dns.rcode.to_text(rcode)) + raise Exception("Error %s" % dns.rcode.to_text(rcode)) if len(response.authority) > 0: rrset = response.authority[0] diff --git a/lemur/plugins/lemur_acme/plugin.py b/lemur/plugins/lemur_acme/plugin.py index 66295ed2..e38870d8 100644 --- a/lemur/plugins/lemur_acme/plugin.py +++ b/lemur/plugins/lemur_acme/plugin.py @@ -17,9 +17,9 @@ import time import OpenSSL.crypto import josepy as jose -from acme import challenges, messages +from acme import challenges, errors, messages from acme.client import BackwardsCompatibleClientV2, ClientNetwork -from acme.errors import PollError, WildcardUnsupportedError +from acme.errors import PollError, TimeoutError, WildcardUnsupportedError from acme.messages import Error as AcmeError from botocore.exceptions import ClientError from flask import current_app @@ -28,9 +28,11 @@ from lemur.authorizations import service as authorization_service from lemur.common.utils import generate_private_key from lemur.dns_providers import service as dns_provider_service from lemur.exceptions import InvalidAuthority, InvalidConfiguration, UnknownProvider +from lemur.extensions import metrics, sentry from lemur.plugins import lemur_acme as acme from lemur.plugins.bases import IssuerPlugin -from lemur.plugins.lemur_acme import cloudflare, dyn, route53 +from lemur.plugins.lemur_acme import cloudflare, dyn, route53, ultradns +from retrying import retry class AuthorizationRecord(object): @@ -47,12 +49,16 @@ class AcmeHandler(object): try: self.all_dns_providers = dns_provider_service.get_all_dns_providers() except Exception as e: - current_app.logger.error("Unable to fetch DNS Providers: {}".format(e)) + metrics.send("AcmeHandler_init_error", "counter", 1) + sentry.captureException() + current_app.logger.error(f"Unable to fetch DNS Providers: {e}") self.all_dns_providers = [] - def find_dns_challenge(self, authorizations): + def find_dns_challenge(self, host, authorizations): dns_challenges = [] for authz in authorizations: + if not authz.body.identifier.value.lower() == host.lower(): + continue for combo in authz.body.challenges: if isinstance(combo.chall, challenges.DNS01): dns_challenges.append(combo) @@ -62,39 +68,60 @@ class AcmeHandler(object): return host.replace("*.", "") def maybe_add_extension(self, host, dns_provider_options): - if dns_provider_options and dns_provider_options.get("acme_challenge_extension"): + if dns_provider_options and dns_provider_options.get( + "acme_challenge_extension" + ): host = host + dns_provider_options.get("acme_challenge_extension") return host - def start_dns_challenge(self, acme_client, account_number, host, dns_provider, order, dns_provider_options): + def start_dns_challenge( + self, + acme_client, + account_number, + host, + dns_provider, + order, + dns_provider_options, + ): current_app.logger.debug("Starting DNS challenge for {0}".format(host)) - dns_challenges = self.find_dns_challenge(order.authorizations) change_ids = [] host_to_validate = self.maybe_remove_wildcard(host) - host_to_validate = self.maybe_add_extension(host_to_validate, dns_provider_options) + dns_challenges = self.find_dns_challenge(host_to_validate, order.authorizations) + host_to_validate = self.maybe_add_extension( + host_to_validate, dns_provider_options + ) - for dns_challenge in self.find_dns_challenge(order.authorizations): + if not dns_challenges: + sentry.captureException() + metrics.send("start_dns_challenge_error_no_dns_challenges", "counter", 1) + raise Exception("Unable to determine DNS challenges from authorizations") + + for dns_challenge in dns_challenges: change_id = dns_provider.create_txt_record( dns_challenge.validation_domain_name(host_to_validate), dns_challenge.validation(acme_client.client.net.key), - account_number + account_number, ) change_ids.append(change_id) return AuthorizationRecord( - host, - order.authorizations, - dns_challenges, - change_ids + host, order.authorizations, dns_challenges, change_ids ) def complete_dns_challenge(self, acme_client, authz_record): - current_app.logger.debug("Finalizing DNS challenge for {0}".format(authz_record.authz[0].body.identifier.value)) + current_app.logger.debug( + "Finalizing DNS challenge for {0}".format( + authz_record.authz[0].body.identifier.value + ) + ) dns_providers = self.dns_providers_for_domain.get(authz_record.host) if not dns_providers: - raise Exception("No DNS providers found for domain: {}".format(authz_record.host)) + metrics.send("complete_dns_challenge_error_no_dnsproviders", "counter", 1) + raise Exception( + "No DNS providers found for domain: {}".format(authz_record.host) + ) for dns_provider in dns_providers: # Grab account number (For Route53) @@ -102,7 +129,19 @@ class AcmeHandler(object): account_number = dns_provider_options.get("account_id") dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type) for change_id in authz_record.change_id: - dns_provider_plugin.wait_for_dns_change(change_id, account_number=account_number) + try: + dns_provider_plugin.wait_for_dns_change( + change_id, account_number=account_number + ) + except Exception: + metrics.send("complete_dns_challenge_error", "counter", 1) + sentry.captureException() + current_app.logger.debug( + f"Unable to resolve DNS challenge for change_id: {change_id}, account_id: " + f"{account_number}", + exc_info=True, + ) + raise for dns_challenge in authz_record.dns_challenge: response = dns_challenge.response(acme_client.client.net.key) @@ -110,36 +149,56 @@ class AcmeHandler(object): verified = response.simple_verify( dns_challenge.chall, authz_record.host, - acme_client.client.net.key.public_key() + acme_client.client.net.key.public_key(), ) - if not verified: - raise ValueError("Failed verification") + if not verified: + metrics.send("complete_dns_challenge_verification_error", "counter", 1) + raise ValueError("Failed verification") - time.sleep(5) - acme_client.answer_challenge(dns_challenge, response) + time.sleep(5) + res = acme_client.answer_challenge(dns_challenge, response) + current_app.logger.debug(f"answer_challenge response: {res}") def request_certificate(self, acme_client, authorizations, order): for authorization in authorizations: for authz in authorization.authz: authorization_resource, _ = acme_client.poll(authz) - deadline = datetime.datetime.now() + datetime.timedelta(seconds=90) + deadline = datetime.datetime.now() + datetime.timedelta(seconds=360) try: - orderr = acme_client.finalize_order(order, deadline) - except AcmeError: - current_app.logger.error("Unable to resolve Acme order: {}".format(order), exc_info=True) + orderr = acme_client.poll_and_finalize(order, deadline) + + except (AcmeError, TimeoutError): + sentry.captureException(extra={"order_url": str(order.uri)}) + metrics.send("request_certificate_error", "counter", 1) + current_app.logger.error( + f"Unable to resolve Acme order: {order.uri}", exc_info=True + ) raise + except errors.ValidationError: + if order.fullchain_pem: + orderr = order + else: + raise - pem_certificate = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, - OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, - orderr.fullchain_pem)).decode() - pem_certificate_chain = orderr.fullchain_pem[len(pem_certificate):].lstrip() + pem_certificate = OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, orderr.fullchain_pem + ), + ).decode() + pem_certificate_chain = orderr.fullchain_pem[ + len(pem_certificate) : # noqa + ].lstrip() - current_app.logger.debug("{0} {1}".format(type(pem_certificate), type(pem_certificate_chain))) + current_app.logger.debug( + "{0} {1}".format(type(pem_certificate), type(pem_certificate_chain)) + ) return pem_certificate, pem_certificate_chain + @retry(stop_max_attempt_number=5, wait_fixed=5000) def setup_acme_client(self, authority): if not authority.options: raise InvalidAuthority("Invalid authority. Options not set") @@ -147,30 +206,40 @@ class AcmeHandler(object): for option in json.loads(authority.options): options[option["name"]] = option.get("value") - email = options.get('email', current_app.config.get('ACME_EMAIL')) - tel = options.get('telephone', current_app.config.get('ACME_TEL')) - directory_url = options.get('acme_url', current_app.config.get('ACME_DIRECTORY_URL')) + email = options.get("email", current_app.config.get("ACME_EMAIL")) + tel = options.get("telephone", current_app.config.get("ACME_TEL")) + directory_url = options.get( + "acme_url", current_app.config.get("ACME_DIRECTORY_URL") + ) - existing_key = options.get('acme_private_key', current_app.config.get('ACME_PRIVATE_KEY')) - existing_regr = options.get('acme_regr', current_app.config.get('ACME_REGR')) + existing_key = options.get( + "acme_private_key", current_app.config.get("ACME_PRIVATE_KEY") + ) + existing_regr = options.get("acme_regr", current_app.config.get("ACME_REGR")) if existing_key and existing_regr: # Reuse the same account for each certificate issuance key = jose.JWK.json_loads(existing_key) regr = messages.RegistrationResource.json_loads(existing_regr) - current_app.logger.debug("Connecting with directory at {0}".format(directory_url)) + current_app.logger.debug( + "Connecting with directory at {0}".format(directory_url) + ) net = ClientNetwork(key, account=regr) client = BackwardsCompatibleClientV2(net, key, directory_url) return client, {} else: # Create an account for each certificate issuance - key = jose.JWKRSA(key=generate_private_key('RSA2048')) + key = jose.JWKRSA(key=generate_private_key("RSA2048")) - current_app.logger.debug("Connecting with directory at {0}".format(directory_url)) + current_app.logger.debug( + "Connecting with directory at {0}".format(directory_url) + ) net = ClientNetwork(key, account=None, timeout=3600) client = BackwardsCompatibleClientV2(net, key, directory_url) - registration = client.new_account_and_tos(messages.NewRegistration.from_data(email=email)) + registration = client.new_account_and_tos( + messages.NewRegistration.from_data(email=email) + ) current_app.logger.debug("Connected: {0}".format(registration.uri)) return client, registration @@ -183,9 +252,9 @@ class AcmeHandler(object): """ current_app.logger.debug("Fetching domains") - domains = [options['common_name']] - if options.get('extensions'): - for name in options['extensions']['sub_alt_names']['names']: + domains = [options["common_name"]] + if options.get("extensions"): + for name in options["extensions"]["sub_alt_names"]["names"]: domains.append(name) current_app.logger.debug("Got these domains: {0}".format(domains)) @@ -196,15 +265,22 @@ class AcmeHandler(object): for domain in order_info.domains: if not self.dns_providers_for_domain.get(domain): + metrics.send( + "get_authorizations_no_dns_provider_for_domain", "counter", 1 + ) raise Exception("No DNS providers found for domain: {}".format(domain)) for dns_provider in self.dns_providers_for_domain[domain]: dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type) dns_provider_options = json.loads(dns_provider.credentials) account_number = dns_provider_options.get("account_id") - authz_record = self.start_dns_challenge(acme_client, account_number, domain, - dns_provider_plugin, - order, - dns_provider.options) + authz_record = self.start_dns_challenge( + acme_client, + account_number, + domain, + dns_provider_plugin, + order, + dns_provider.options, + ) authorizations.append(authz_record) return authorizations @@ -220,7 +296,7 @@ class AcmeHandler(object): if not dns_provider.domains: continue for name in dns_provider.domains: - if domain.endswith("." + name): + if name == domain or domain.endswith("." + name): if len(name) > match_length: self.dns_providers_for_domain[domain] = [dns_provider] match_length = len(name) @@ -238,16 +314,20 @@ class AcmeHandler(object): dns_providers = self.dns_providers_for_domain.get(authz_record.host) for dns_provider in dns_providers: # Grab account number (For Route53) - dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type) + dns_provider_plugin = self.get_dns_provider( + dns_provider.provider_type + ) dns_provider_options = json.loads(dns_provider.credentials) account_number = dns_provider_options.get("account_id") host_to_validate = self.maybe_remove_wildcard(authz_record.host) - host_to_validate = self.maybe_add_extension(host_to_validate, dns_provider_options) + host_to_validate = self.maybe_add_extension( + host_to_validate, dns_provider_options + ) dns_provider_plugin.delete_txt_record( authz_record.change_id, account_number, dns_challenge.validation_domain_name(host_to_validate), - dns_challenge.validation(acme_client.client.net.key) + dns_challenge.validation(acme_client.client.net.key), ) return authorizations @@ -272,25 +352,31 @@ class AcmeHandler(object): account_number = dns_provider_options.get("account_id") dns_challenges = authz_record.dns_challenge host_to_validate = self.maybe_remove_wildcard(authz_record.host) - host_to_validate = self.maybe_add_extension(host_to_validate, dns_provider_options) + host_to_validate = self.maybe_add_extension( + host_to_validate, dns_provider_options + ) + dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type) for dns_challenge in dns_challenges: try: - dns_provider.delete_txt_record( + dns_provider_plugin.delete_txt_record( authz_record.change_id, account_number, dns_challenge.validation_domain_name(host_to_validate), - dns_challenge.validation(acme_client.client.net.key) + dns_challenge.validation(acme_client.client.net.key), ) except Exception as e: # If this fails, it's most likely because the record doesn't exist (It was already cleaned up) # or we're not authorized to modify it. + metrics.send("cleanup_dns_challenges_error", "counter", 1) + sentry.captureException() pass def get_dns_provider(self, type): provider_types = { - 'cloudflare': cloudflare, - 'dyn': dyn, - 'route53': route53, + "cloudflare": cloudflare, + "dyn": dyn, + "route53": route53, + "ultradns": ultradns, } provider = provider_types.get(type) if not provider: @@ -299,41 +385,43 @@ class AcmeHandler(object): class ACMEIssuerPlugin(IssuerPlugin): - title = 'Acme' - slug = 'acme-issuer' - description = 'Enables the creation of certificates via ACME CAs (including Let\'s Encrypt)' + title = "Acme" + slug = "acme-issuer" + description = ( + "Enables the creation of certificates via ACME CAs (including Let's Encrypt)" + ) version = acme.VERSION - author = 'Netflix' - author_url = 'https://github.com/netflix/lemur.git' + author = "Netflix" + author_url = "https://github.com/netflix/lemur.git" options = [ { - 'name': 'acme_url', - 'type': 'str', - 'required': True, - 'validation': '/^http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+$/', - 'helpMessage': 'Must be a valid web url starting with http[s]://', + "name": "acme_url", + "type": "str", + "required": True, + "validation": "/^http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+$/", + "helpMessage": "Must be a valid web url starting with http[s]://", }, { - 'name': 'telephone', - 'type': 'str', - 'default': '', - 'helpMessage': 'Telephone to use' + "name": "telephone", + "type": "str", + "default": "", + "helpMessage": "Telephone to use", }, { - 'name': 'email', - 'type': 'str', - 'default': '', - 'validation': '/^?([-a-zA-Z0-9.`?{}]+@\w+\.\w+)$/', - 'helpMessage': 'Email to use' + "name": "email", + "type": "str", + "default": "", + "validation": "/^?([-a-zA-Z0-9.`?{}]+@\w+\.\w+)$/", + "helpMessage": "Email to use", }, { - 'name': 'certificate', - 'type': 'textarea', - 'default': '', - 'validation': '/^-----BEGIN CERTIFICATE-----/', - 'helpMessage': 'Certificate to use' + "name": "certificate", + "type": "textarea", + "default": "", + "validation": "/^-----BEGIN CERTIFICATE-----/", + "helpMessage": "Certificate to use", }, ] @@ -344,9 +432,10 @@ class ACMEIssuerPlugin(IssuerPlugin): self.acme = AcmeHandler() provider_types = { - 'cloudflare': cloudflare, - 'dyn': dyn, - 'route53': route53, + "cloudflare": cloudflare, + "dyn": dyn, + "route53": route53, + "ultradns": ultradns, } provider = provider_types.get(type) if not provider: @@ -378,21 +467,31 @@ class ACMEIssuerPlugin(IssuerPlugin): try: order = acme_client.new_order(pending_cert.csr) except WildcardUnsupportedError: - raise Exception("The currently selected ACME CA endpoint does" - " not support issuing wildcard certificates.") + metrics.send("get_ordered_certificate_wildcard_unsupported", "counter", 1) + raise Exception( + "The currently selected ACME CA endpoint does" + " not support issuing wildcard certificates." + ) try: - authorizations = self.acme.get_authorizations(acme_client, order, order_info) + authorizations = self.acme.get_authorizations( + acme_client, order, order_info + ) except ClientError: - current_app.logger.error("Unable to resolve pending cert: {}".format(pending_cert.name), exc_info=True) + sentry.captureException() + metrics.send("get_ordered_certificate_error", "counter", 1) + current_app.logger.error( + f"Unable to resolve pending cert: {pending_cert.name}", exc_info=True + ) return False authorizations = self.acme.finalize_authorizations(acme_client, authorizations) pem_certificate, pem_certificate_chain = self.acme.request_certificate( - acme_client, authorizations, order) + acme_client, authorizations, order + ) cert = { - 'body': "\n".join(str(pem_certificate).splitlines()), - 'chain': "\n".join(str(pem_certificate_chain).splitlines()), - 'external_id': str(pending_cert.external_id) + "body": "\n".join(str(pem_certificate).splitlines()), + "chain": "\n".join(str(pem_certificate_chain).splitlines()), + "external_id": str(pending_cert.external_id), } return cert @@ -402,10 +501,14 @@ class ACMEIssuerPlugin(IssuerPlugin): certs = [] for pending_cert in pending_certs: try: - acme_client, registration = self.acme.setup_acme_client(pending_cert.authority) + acme_client, registration = self.acme.setup_acme_client( + pending_cert.authority + ) order_info = authorization_service.get(pending_cert.external_id) if pending_cert.dns_provider_id: - dns_provider = dns_provider_service.get(pending_cert.dns_provider_id) + dns_provider = dns_provider_service.get( + pending_cert.dns_provider_id + ) for domain in order_info.domains: # Currently, we only support specifying one DNS provider per certificate, even if that @@ -418,57 +521,80 @@ class ACMEIssuerPlugin(IssuerPlugin): try: order = acme_client.new_order(pending_cert.csr) except WildcardUnsupportedError: - raise Exception("The currently selected ACME CA endpoint does" - " not support issuing wildcard certificates.") + sentry.captureException() + metrics.send( + "get_ordered_certificates_wildcard_unsupported_error", + "counter", + 1, + ) + raise Exception( + "The currently selected ACME CA endpoint does" + " not support issuing wildcard certificates." + ) - authorizations = self.acme.get_authorizations(acme_client, order, order_info) + authorizations = self.acme.get_authorizations( + acme_client, order, order_info + ) - pending.append({ - "acme_client": acme_client, - "authorizations": authorizations, - "pending_cert": pending_cert, - "order": order, - }) + pending.append( + { + "acme_client": acme_client, + "authorizations": authorizations, + "pending_cert": pending_cert, + "order": order, + } + ) except (ClientError, ValueError, Exception) as e: - current_app.logger.error("Unable to resolve pending cert: {}".format(pending_cert), exc_info=True) - certs.append({ - "cert": False, - "pending_cert": pending_cert, - "last_error": e, - }) + sentry.captureException() + metrics.send( + "get_ordered_certificates_pending_creation_error", "counter", 1 + ) + current_app.logger.error( + f"Unable to resolve pending cert: {pending_cert}", exc_info=True + ) + + error = e + if globals().get("order") and order: + error += f" Order uri: {order.uri}" + certs.append( + {"cert": False, "pending_cert": pending_cert, "last_error": e} + ) for entry in pending: try: entry["authorizations"] = self.acme.finalize_authorizations( - entry["acme_client"], - entry["authorizations"], + entry["acme_client"], entry["authorizations"] ) pem_certificate, pem_certificate_chain = self.acme.request_certificate( - entry["acme_client"], - entry["authorizations"], - entry["order"] + entry["acme_client"], entry["authorizations"], entry["order"] ) cert = { - 'body': "\n".join(str(pem_certificate).splitlines()), - 'chain': "\n".join(str(pem_certificate_chain).splitlines()), - 'external_id': str(entry["pending_cert"].external_id) + "body": "\n".join(str(pem_certificate).splitlines()), + "chain": "\n".join(str(pem_certificate_chain).splitlines()), + "external_id": str(entry["pending_cert"].external_id), } - certs.append({ - "cert": cert, - "pending_cert": entry["pending_cert"], - }) + certs.append({"cert": cert, "pending_cert": entry["pending_cert"]}) except (PollError, AcmeError, Exception) as e: - current_app.logger.error("Unable to resolve pending cert: {}".format(pending_cert), exc_info=True) - certs.append({ - "cert": False, - "pending_cert": entry["pending_cert"], - "last_error": e, - }) + sentry.captureException() + metrics.send("get_ordered_certificates_resolution_error", "counter", 1) + order_url = order.uri + error = f"{e}. Order URI: {order_url}" + current_app.logger.error( + f"Unable to resolve pending cert: {pending_cert}. " + f"Check out {order_url} for more information.", + exc_info=True, + ) + certs.append( + { + "cert": False, + "pending_cert": entry["pending_cert"], + "last_error": error, + } + ) # Ensure DNS records get deleted self.acme.cleanup_dns_challenges( - entry["acme_client"], - entry["authorizations"], + entry["acme_client"], entry["authorizations"] ) return certs @@ -481,20 +607,26 @@ class ACMEIssuerPlugin(IssuerPlugin): :return: :raise Exception: """ self.acme = AcmeHandler() - authority = issuer_options.get('authority') - create_immediately = issuer_options.get('create_immediately', False) + authority = issuer_options.get("authority") + create_immediately = issuer_options.get("create_immediately", False) acme_client, registration = self.acme.setup_acme_client(authority) - dns_provider = issuer_options.get('dns_provider', {}) + dns_provider = issuer_options.get("dns_provider", {}) if dns_provider: dns_provider_options = dns_provider.options credentials = json.loads(dns_provider.credentials) - current_app.logger.debug("Using DNS provider: {0}".format(dns_provider.provider_type)) - dns_provider_plugin = __import__(dns_provider.provider_type, globals(), locals(), [], 1) + current_app.logger.debug( + "Using DNS provider: {0}".format(dns_provider.provider_type) + ) + dns_provider_plugin = __import__( + dns_provider.provider_type, globals(), locals(), [], 1 + ) account_number = credentials.get("account_id") provider_type = dns_provider.provider_type if provider_type == "route53" and not account_number: - error = "Route53 DNS Provider {} does not have an account number configured.".format(dns_provider.name) + error = "Route53 DNS Provider {} does not have an account number configured.".format( + dns_provider.name + ) current_app.logger.error(error) raise InvalidConfiguration(error) else: @@ -513,16 +645,29 @@ class ACMEIssuerPlugin(IssuerPlugin): else: authz_domains.append(d.value) - dns_authorization = authorization_service.create(account_number, authz_domains, - provider_type) + dns_authorization = authorization_service.create( + account_number, authz_domains, provider_type + ) # Return id of the DNS Authorization return None, None, dns_authorization.id - authorizations = self.acme.get_authorizations(acme_client, account_number, domains, dns_provider_plugin, - dns_provider_options) - self.acme.finalize_authorizations(acme_client, account_number, dns_provider_plugin, authorizations, - dns_provider_options) - pem_certificate, pem_certificate_chain = self.acme.request_certificate(acme_client, authorizations, csr) + authorizations = self.acme.get_authorizations( + acme_client, + account_number, + domains, + dns_provider_plugin, + dns_provider_options, + ) + self.acme.finalize_authorizations( + acme_client, + account_number, + dns_provider_plugin, + authorizations, + dns_provider_options, + ) + pem_certificate, pem_certificate_chain = self.acme.request_certificate( + acme_client, authorizations, csr + ) # TODO add external ID (if possible) return pem_certificate, pem_certificate_chain, None @@ -535,18 +680,18 @@ class ACMEIssuerPlugin(IssuerPlugin): :param options: :return: """ - role = {'username': '', 'password': '', 'name': 'acme'} - plugin_options = options.get('plugin', {}).get('plugin_options') + role = {"username": "", "password": "", "name": "acme"} + plugin_options = options.get("plugin", {}).get("plugin_options") if not plugin_options: error = "Invalid options for lemur_acme plugin: {}".format(options) current_app.logger.error(error) raise InvalidConfiguration(error) # Define static acme_root based off configuration variable by default. However, if user has passed a # certificate, use this certificate as the root. - acme_root = current_app.config.get('ACME_ROOT') + acme_root = current_app.config.get("ACME_ROOT") for option in plugin_options: - if option.get('name') == 'certificate': - acme_root = option.get('value') + if option.get("name") == "certificate": + acme_root = option.get("value") return acme_root, "", [role] def cancel_ordered_certificate(self, pending_cert, **kwargs): diff --git a/lemur/plugins/lemur_acme/route53.py b/lemur/plugins/lemur_acme/route53.py index 3b6c5b32..55da5161 100644 --- a/lemur/plugins/lemur_acme/route53.py +++ b/lemur/plugins/lemur_acme/route53.py @@ -3,7 +3,7 @@ import time from lemur.plugins.lemur_aws.sts import sts_client -@sts_client('route53') +@sts_client("route53") def wait_for_dns_change(change_id, client=None): _, change_id = change_id @@ -14,7 +14,7 @@ def wait_for_dns_change(change_id, client=None): time.sleep(5) -@sts_client('route53') +@sts_client("route53") def find_zone_id(domain, client=None): paginator = client.get_paginator("list_hosted_zones") zones = [] @@ -25,34 +25,35 @@ def find_zone_id(domain, client=None): zones.append((zone["Name"], zone["Id"])) if not zones: - raise ValueError( - "Unable to find a Route53 hosted zone for {}".format(domain) - ) + raise ValueError("Unable to find a Route53 hosted zone for {}".format(domain)) return zones[0][1] -@sts_client('route53') +@sts_client("route53") def get_zones(client=None): paginator = client.get_paginator("list_hosted_zones") zones = [] for page in paginator.paginate(): for zone in page["HostedZones"]: - zones.append(zone["Name"][:-1]) # We need [:-1] to strip out the trailing dot. + zones.append( + zone["Name"][:-1] + ) # We need [:-1] to strip out the trailing dot. return zones -@sts_client('route53') +@sts_client("route53") def change_txt_record(action, zone_id, domain, value, client=None): current_txt_records = [] try: current_records = client.list_resource_record_sets( HostedZoneId=zone_id, StartRecordName=domain, - StartRecordType='TXT', - MaxItems="1")["ResourceRecordSets"] + StartRecordType="TXT", + MaxItems="1", + )["ResourceRecordSets"] for record in current_records: - if record.get('Type') == 'TXT': + if record.get("Type") == "TXT": current_txt_records.extend(record.get("ResourceRecords", [])) except Exception as e: # Current Resource Record does not exist @@ -72,7 +73,9 @@ def change_txt_record(action, zone_id, domain, value, client=None): # If we want to delete one record out of many, we'll update the record to not include the deleted value instead. # This allows us to support concurrent issuance. current_txt_records = [ - record for record in current_txt_records if not (record.get('Value') == '"{}"'.format(value)) + record + for record in current_txt_records + if not (record.get("Value") == '"{}"'.format(value)) ] action = "UPSERT" @@ -87,10 +90,10 @@ def change_txt_record(action, zone_id, domain, value, client=None): "Type": "TXT", "TTL": 300, "ResourceRecords": current_txt_records, - } + }, } ] - } + }, ) return response["ChangeInfo"]["Id"] @@ -98,11 +101,7 @@ def change_txt_record(action, zone_id, domain, value, client=None): def create_txt_record(host, value, account_number): zone_id = find_zone_id(host, account_number=account_number) change_id = change_txt_record( - "UPSERT", - zone_id, - host, - value, - account_number=account_number + "UPSERT", zone_id, host, value, account_number=account_number ) return zone_id, change_id @@ -113,11 +112,7 @@ def delete_txt_record(change_ids, account_number, host, value): zone_id, _ = change_id try: change_txt_record( - "DELETE", - zone_id, - host, - value, - account_number=account_number + "DELETE", zone_id, host, value, account_number=account_number ) except Exception as e: if "but it was not found" in e.response.get("Error", {}).get("Message"): diff --git a/lemur/plugins/lemur_acme/tests/test_acme.py b/lemur/plugins/lemur_acme/tests/test_acme.py index 0c406627..2f9dd719 100644 --- a/lemur/plugins/lemur_acme/tests/test_acme.py +++ b/lemur/plugins/lemur_acme/tests/test_acme.py @@ -1,13 +1,13 @@ import unittest +from requests.models import Response from mock import MagicMock, Mock, patch -from lemur.plugins.lemur_acme import plugin +from lemur.plugins.lemur_acme import plugin, ultradns class TestAcme(unittest.TestCase): - - @patch('lemur.plugins.lemur_acme.plugin.dns_provider_service') + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") def setUp(self, mock_dns_provider_service): self.ACMEIssuerPlugin = plugin.ACMEIssuerPlugin() self.acme = plugin.AcmeHandler() @@ -15,14 +15,17 @@ class TestAcme(unittest.TestCase): mock_dns_provider.name = "cloudflare" mock_dns_provider.credentials = "{}" mock_dns_provider.provider_type = "cloudflare" - self.acme.dns_providers_for_domain = {"www.test.com": [mock_dns_provider], - "test.fakedomain.net": [mock_dns_provider]} + self.acme.dns_providers_for_domain = { + "www.test.com": [mock_dns_provider], + "test.fakedomain.net": [mock_dns_provider], + } - @patch('lemur.plugins.lemur_acme.plugin.len', return_value=1) + @patch("lemur.plugins.lemur_acme.plugin.len", return_value=1) def test_find_dns_challenge(self, mock_len): assert mock_len from acme import challenges + c = challenges.DNS01() mock_authz = Mock() @@ -37,11 +40,13 @@ class TestAcme(unittest.TestCase): a = plugin.AuthorizationRecord("host", "authz", "challenge", "id") self.assertEqual(type(a), plugin.AuthorizationRecord) - @patch('acme.client.Client') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.plugin.len', return_value=1) - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.find_dns_challenge') - def test_start_dns_challenge(self, mock_find_dns_challenge, mock_len, mock_app, mock_acme): + @patch("acme.client.Client") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.plugin.len", return_value=1) + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.find_dns_challenge") + def test_start_dns_challenge( + self, mock_find_dns_challenge, mock_len, mock_app, mock_acme + ): assert mock_len mock_order = Mock() mock_app.logger.debug = Mock() @@ -49,6 +54,7 @@ class TestAcme(unittest.TestCase): mock_authz.body.resolved_combinations = [] mock_entry = MagicMock() from acme import challenges + c = challenges.DNS01() mock_entry.chall = TestAcme.test_complete_dns_challenge_fail mock_authz.body.resolved_combinations.append(mock_entry) @@ -60,13 +66,17 @@ class TestAcme(unittest.TestCase): iterable = mock_find_dns_challenge.return_value iterator = iter(values) iterable.__iter__.return_value = iterator - result = self.acme.start_dns_challenge(mock_acme, "accountid", "host", mock_dns_provider, mock_order, {}) + result = self.acme.start_dns_challenge( + mock_acme, "accountid", "host", mock_dns_provider, mock_order, {} + ) self.assertEqual(type(result), plugin.AuthorizationRecord) - @patch('acme.client.Client') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.cloudflare.wait_for_dns_change') - def test_complete_dns_challenge_success(self, mock_wait_for_dns_change, mock_current_app, mock_acme): + @patch("acme.client.Client") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.cloudflare.wait_for_dns_change") + def test_complete_dns_challenge_success( + self, mock_wait_for_dns_change, mock_current_app, mock_acme + ): mock_dns_provider = Mock() mock_dns_provider.wait_for_dns_change = Mock(return_value=True) mock_authz = Mock() @@ -84,10 +94,12 @@ class TestAcme(unittest.TestCase): mock_authz.dns_challenge.append(dns_challenge) self.acme.complete_dns_challenge(mock_acme, mock_authz) - @patch('acme.client.Client') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.cloudflare.wait_for_dns_change') - def test_complete_dns_challenge_fail(self, mock_wait_for_dns_change, mock_current_app, mock_acme): + @patch("acme.client.Client") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.cloudflare.wait_for_dns_change") + def test_complete_dns_challenge_fail( + self, mock_wait_for_dns_change, mock_current_app, mock_acme + ): mock_dns_provider = Mock() mock_dns_provider.wait_for_dns_change = Mock(return_value=True) @@ -105,16 +117,22 @@ class TestAcme(unittest.TestCase): dns_challenge = Mock() mock_authz.dns_challenge.append(dns_challenge) self.assertRaises( - ValueError, - self.acme.complete_dns_challenge(mock_acme, mock_authz) + ValueError, self.acme.complete_dns_challenge(mock_acme, mock_authz) ) - @patch('acme.client.Client') - @patch('OpenSSL.crypto', return_value="mock_cert") - @patch('josepy.util.ComparableX509') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.find_dns_challenge') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - def test_request_certificate(self, mock_current_app, mock_find_dns_challenge, mock_jose, mock_crypto, mock_acme): + @patch("acme.client.Client") + @patch("OpenSSL.crypto", return_value="mock_cert") + @patch("josepy.util.ComparableX509") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.find_dns_challenge") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + def test_request_certificate( + self, + mock_current_app, + mock_find_dns_challenge, + mock_jose, + mock_crypto, + mock_acme, + ): mock_cert_response = Mock() mock_cert_response.body = "123" mock_cert_response_full = [mock_cert_response, True] @@ -124,7 +142,7 @@ class TestAcme(unittest.TestCase): mock_authz_record.authz = Mock() mock_authz.append(mock_authz_record) mock_acme.fetch_chain = Mock(return_value="mock_chain") - mock_crypto.dump_certificate = Mock(return_value=b'chain') + mock_crypto.dump_certificate = Mock(return_value=b"chain") mock_order = Mock() self.acme.request_certificate(mock_acme, [], mock_order) @@ -134,8 +152,8 @@ class TestAcme(unittest.TestCase): with self.assertRaises(Exception): self.acme.setup_acme_client(mock_authority) - @patch('lemur.plugins.lemur_acme.plugin.BackwardsCompatibleClientV2') - @patch('lemur.plugins.lemur_acme.plugin.current_app') + @patch("lemur.plugins.lemur_acme.plugin.BackwardsCompatibleClientV2") + @patch("lemur.plugins.lemur_acme.plugin.current_app") def test_setup_acme_client_success(self, mock_current_app, mock_acme): mock_authority = Mock() mock_authority.options = '[{"name": "mock_name", "value": "mock_value"}]' @@ -150,31 +168,29 @@ class TestAcme(unittest.TestCase): assert result_client assert result_registration - @patch('lemur.plugins.lemur_acme.plugin.current_app') + @patch("lemur.plugins.lemur_acme.plugin.current_app") def test_get_domains_single(self, mock_current_app): - options = { - "common_name": "test.netflix.net" - } + options = {"common_name": "test.netflix.net"} result = self.acme.get_domains(options) self.assertEqual(result, [options["common_name"]]) - @patch('lemur.plugins.lemur_acme.plugin.current_app') + @patch("lemur.plugins.lemur_acme.plugin.current_app") def test_get_domains_multiple(self, mock_current_app): options = { "common_name": "test.netflix.net", "extensions": { - "sub_alt_names": { - "names": [ - "test2.netflix.net", - "test3.netflix.net" - ] - } - } + "sub_alt_names": {"names": ["test2.netflix.net", "test3.netflix.net"]} + }, } result = self.acme.get_domains(options) - self.assertEqual(result, [options["common_name"], "test2.netflix.net", "test3.netflix.net"]) + self.assertEqual( + result, [options["common_name"], "test2.netflix.net", "test3.netflix.net"] + ) - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.start_dns_challenge', return_value="test") + @patch( + "lemur.plugins.lemur_acme.plugin.AcmeHandler.start_dns_challenge", + return_value="test", + ) def test_get_authorizations(self, mock_start_dns_challenge): mock_order = Mock() mock_order.body.identifiers = [] @@ -183,10 +199,15 @@ class TestAcme(unittest.TestCase): mock_order_info = Mock() mock_order_info.account_number = 1 mock_order_info.domains = ["test.fakedomain.net"] - result = self.acme.get_authorizations("acme_client", mock_order, mock_order_info) + result = self.acme.get_authorizations( + "acme_client", mock_order, mock_order_info + ) self.assertEqual(result, ["test"]) - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.complete_dns_challenge', return_value="test") + @patch( + "lemur.plugins.lemur_acme.plugin.AcmeHandler.complete_dns_challenge", + return_value="test", + ) def test_finalize_authorizations(self, mock_complete_dns_challenge): mock_authz = [] mock_authz_record = MagicMock() @@ -202,28 +223,28 @@ class TestAcme(unittest.TestCase): result = self.acme.finalize_authorizations(mock_acme_client, mock_authz) self.assertEqual(result, mock_authz) - @patch('lemur.plugins.lemur_acme.plugin.current_app') + @patch("lemur.plugins.lemur_acme.plugin.current_app") def test_create_authority(self, mock_current_app): mock_current_app.config = Mock() options = { - "plugin": { - "plugin_options": [{ - "name": "certificate", - "value": "123" - }] - } + "plugin": {"plugin_options": [{"name": "certificate", "value": "123"}]} } acme_root, b, role = self.ACMEIssuerPlugin.create_authority(options) self.assertEqual(acme_root, "123") self.assertEqual(b, "") - self.assertEqual(role, [{'username': '', 'password': '', 'name': 'acme'}]) + self.assertEqual(role, [{"username": "", "password": "", "name": "acme"}]) - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.dyn.current_app') - @patch('lemur.plugins.lemur_acme.cloudflare.current_app') - @patch('lemur.plugins.lemur_acme.plugin.dns_provider_service') - def test_get_dns_provider(self, mock_dns_provider_service, mock_current_app_cloudflare, mock_current_app_dyn, - mock_current_app): + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.dyn.current_app") + @patch("lemur.plugins.lemur_acme.cloudflare.current_app") + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") + def test_get_dns_provider( + self, + mock_dns_provider_service, + mock_current_app_cloudflare, + mock_current_app_dyn, + mock_current_app, + ): provider = plugin.ACMEIssuerPlugin() route53 = provider.get_dns_provider("route53") assert route53 @@ -232,16 +253,23 @@ class TestAcme(unittest.TestCase): dyn = provider.get_dns_provider("dyn") assert dyn - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.plugin.authorization_service') - @patch('lemur.plugins.lemur_acme.plugin.dns_provider_service') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate') + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.plugin.authorization_service") + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate") def test_get_ordered_certificate( - self, mock_request_certificate, mock_finalize_authorizations, mock_get_authorizations, - mock_dns_provider_service, mock_authorization_service, mock_current_app, mock_acme): + self, + mock_request_certificate, + mock_finalize_authorizations, + mock_get_authorizations, + mock_dns_provider_service, + mock_authorization_service, + mock_current_app, + mock_acme, + ): mock_client = Mock() mock_acme.return_value = (mock_client, "") mock_request_certificate.return_value = ("pem_certificate", "chain") @@ -253,24 +281,26 @@ class TestAcme(unittest.TestCase): provider.get_dns_provider = Mock() result = provider.get_ordered_certificate(mock_cert) self.assertEqual( - result, - { - 'body': "pem_certificate", - 'chain': "chain", - 'external_id': "1" - } + result, {"body": "pem_certificate", "chain": "chain", "external_id": "1"} ) - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.plugin.authorization_service') - @patch('lemur.plugins.lemur_acme.plugin.dns_provider_service') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate') + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.plugin.authorization_service") + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate") def test_get_ordered_certificates( - self, mock_request_certificate, mock_finalize_authorizations, mock_get_authorizations, - mock_dns_provider_service, mock_authorization_service, mock_current_app, mock_acme): + self, + mock_request_certificate, + mock_finalize_authorizations, + mock_get_authorizations, + mock_dns_provider_service, + mock_authorization_service, + mock_current_app, + mock_acme, + ): mock_client = Mock() mock_acme.return_value = (mock_client, "") mock_request_certificate.return_value = ("pem_certificate", "chain") @@ -285,19 +315,32 @@ class TestAcme(unittest.TestCase): provider.get_dns_provider = Mock() result = provider.get_ordered_certificates([mock_cert, mock_cert2]) self.assertEqual(len(result), 2) - self.assertEqual(result[0]['cert'], {'body': 'pem_certificate', 'chain': 'chain', 'external_id': '1'}) - self.assertEqual(result[1]['cert'], {'body': 'pem_certificate', 'chain': 'chain', 'external_id': '2'}) + self.assertEqual( + result[0]["cert"], + {"body": "pem_certificate", "chain": "chain", "external_id": "1"}, + ) + self.assertEqual( + result[1]["cert"], + {"body": "pem_certificate", "chain": "chain", "external_id": "2"}, + ) - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client') - @patch('lemur.plugins.lemur_acme.plugin.dns_provider_service') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate') - @patch('lemur.plugins.lemur_acme.plugin.authorization_service') - def test_create_certificate(self, mock_authorization_service, mock_request_certificate, - mock_finalize_authorizations, mock_get_authorizations, - mock_current_app, mock_dns_provider_service, mock_acme): + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client") + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate") + @patch("lemur.plugins.lemur_acme.plugin.authorization_service") + def test_create_certificate( + self, + mock_authorization_service, + mock_request_certificate, + mock_finalize_authorizations, + mock_get_authorizations, + mock_current_app, + mock_dns_provider_service, + mock_acme, + ): provider = plugin.ACMEIssuerPlugin() mock_authority = Mock() @@ -310,11 +353,129 @@ class TestAcme(unittest.TestCase): mock_dns_provider_service.get.return_value = mock_dns_provider issuer_options = { - 'authority': mock_authority, - 'dns_provider': mock_dns_provider, - "common_name": "test.netflix.net" + "authority": mock_authority, + "dns_provider": mock_dns_provider, + "common_name": "test.netflix.net", } csr = "123" mock_request_certificate.return_value = ("pem_certificate", "chain") result = provider.create_certificate(csr, issuer_options) assert result + + @patch("lemur.plugins.lemur_acme.ultradns.requests") + @patch("lemur.plugins.lemur_acme.ultradns.current_app") + def test_get_ultradns_token(self, mock_current_app, mock_requests): + # ret_val = json.dumps({"access_token": "access"}) + the_response = Response() + the_response._content = b'{"access_token": "access"}' + mock_requests.post = Mock(return_value=the_response) + mock_current_app.config.get = Mock(return_value="Test") + result = ultradns.get_ultradns_token() + self.assertTrue(len(result) > 0) + + @patch("lemur.plugins.lemur_acme.ultradns.current_app") + def test_create_txt_record(self, mock_current_app): + domain = "_acme_challenge.test.example.com" + zone = "test.example.com" + token = "ABCDEFGHIJ" + account_number = "1234567890" + change_id = (domain, token) + ultradns.get_zone_name = Mock(return_value=zone) + mock_current_app.logger.debug = Mock() + ultradns._post = Mock() + log_data = { + "function": "create_txt_record", + "fqdn": domain, + "token": token, + "message": "TXT record created" + } + result = ultradns.create_txt_record(domain, token, account_number) + mock_current_app.logger.debug.assert_called_with(log_data) + self.assertEqual(result, change_id) + + @patch("lemur.plugins.lemur_acme.ultradns.current_app") + @patch("lemur.extensions.metrics") + def test_delete_txt_record(self, mock_metrics, mock_current_app): + domain = "_acme_challenge.test.example.com" + zone = "test.example.com" + token = "ABCDEFGHIJ" + account_number = "1234567890" + change_id = (domain, token) + mock_current_app.logger.debug = Mock() + ultradns.get_zone_name = Mock(return_value=zone) + ultradns._post = Mock() + ultradns._get = Mock() + ultradns._get.return_value = {'zoneName': 'test.example.com.com', + 'rrSets': [{'ownerName': '_acme-challenge.test.example.com.', + 'rrtype': 'TXT (16)', 'ttl': 5, 'rdata': ['ABCDEFGHIJ']}], + 'queryInfo': {'sort': 'OWNER', 'reverse': False, 'limit': 100}, + 'resultInfo': {'totalCount': 1, 'offset': 0, 'returnedCount': 1}} + ultradns._delete = Mock() + mock_metrics.send = Mock() + ultradns.delete_txt_record(change_id, account_number, domain, token) + mock_current_app.logger.debug.assert_not_called() + mock_metrics.send.assert_not_called() + + @patch("lemur.plugins.lemur_acme.ultradns.current_app") + @patch("lemur.extensions.metrics") + def test_wait_for_dns_change(self, mock_metrics, mock_current_app): + ultradns._has_dns_propagated = Mock(return_value=True) + nameserver = "1.1.1.1" + ultradns.get_authoritative_nameserver = Mock(return_value=nameserver) + mock_metrics.send = Mock() + domain = "_acme-challenge.test.example.com" + token = "ABCDEFGHIJ" + change_id = (domain, token) + mock_current_app.logger.debug = Mock() + ultradns.wait_for_dns_change(change_id) + # mock_metrics.send.assert_not_called() + log_data = { + "function": "wait_for_dns_change", + "fqdn": domain, + "status": True, + "message": "Record status on Public DNS" + } + mock_current_app.logger.debug.assert_called_with(log_data) + + def test_get_zone_name(self): + zones = ['example.com', 'test.example.com'] + zone = "test.example.com" + domain = "_acme-challenge.test.example.com" + account_number = "1234567890" + ultradns.get_zones = Mock(return_value=zones) + result = ultradns.get_zone_name(domain, account_number) + self.assertEqual(result, zone) + + def test_get_zones(self): + account_number = "1234567890" + path = "a/b/c" + zones = ['example.com', 'test.example.com'] + paginate_response = [{ + 'properties': { + 'name': 'example.com.', 'accountName': 'example', 'type': 'PRIMARY', + 'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9, + 'lastModifiedDateTime': '2017-06-14T06:45Z'}, + 'registrarInfo': { + 'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.', + 'example.ultradns.biz.', 'example.ultradns.org.']}}, + 'inherit': 'ALL'}, { + 'properties': { + 'name': 'test.example.com.', 'accountName': 'example', 'type': 'PRIMARY', + 'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9, + 'lastModifiedDateTime': '2017-06-14T06:45Z'}, + 'registrarInfo': { + 'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.', + 'example.ultradns.biz.', 'example.ultradns.org.']}}, + 'inherit': 'ALL'}, { + 'properties': { + 'name': 'example2.com.', 'accountName': 'example', 'type': 'SECONDARY', + 'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9, + 'lastModifiedDateTime': '2017-06-14T06:45Z'}, + 'registrarInfo': { + 'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.', + 'example.ultradns.biz.', 'example.ultradns.org.']}}, + 'inherit': 'ALL'}] + ultradns._paginate = Mock(path, "zones") + ultradns._paginate.side_effect = [[paginate_response]] + result = ultradns.get_zones(account_number) + self.assertEqual(result, zones) diff --git a/lemur/plugins/lemur_acme/ultradns.py b/lemur/plugins/lemur_acme/ultradns.py new file mode 100644 index 00000000..dcf3e3c6 --- /dev/null +++ b/lemur/plugins/lemur_acme/ultradns.py @@ -0,0 +1,445 @@ +import time +import requests +import json +import sys + +import dns +import dns.exception +import dns.name +import dns.query +import dns.resolver + +from flask import current_app +from lemur.extensions import metrics, sentry + + +class Record: + """ + This class implements an Ultra DNS record. + + Accepts the response from the API call as the argument. + """ + + def __init__(self, _data): + # Since we are dealing with only TXT records for Lemur, we expect only 1 RRSet in the response. + # Thus we default to picking up the first entry (_data["rrsets"][0]) from the response. + self._data = _data["rrSets"][0] + + @property + def name(self): + return self._data["ownerName"] + + @property + def rrtype(self): + return self._data["rrtype"] + + @property + def rdata(self): + return self._data["rdata"] + + @property + def ttl(self): + return self._data["ttl"] + + +class Zone: + """ + This class implements an Ultra DNS zone. + """ + + def __init__(self, _data, _client="Client"): + self._data = _data + self._client = _client + + @property + def name(self): + """ + Zone name, has a trailing "." at the end, which we manually remove. + """ + return self._data["properties"]["name"][:-1] + + @property + def authoritative_type(self): + """ + Indicates whether the zone is setup as a PRIMARY or SECONDARY + """ + return self._data["properties"]["type"] + + @property + def record_count(self): + return self._data["properties"]["resourceRecordCount"] + + @property + def status(self): + """ + Returns the status of the zone - ACTIVE, SUSPENDED, etc + """ + return self._data["properties"]["status"] + + +def get_ultradns_token(): + """ + Function to call the UltraDNS Authorization API. + + Returns the Authorization access_token which is valid for 1 hour. + Each request calls this function and we generate a new token every time. + """ + path = "/v2/authorization/token" + data = { + "grant_type": "password", + "username": current_app.config.get("ACME_ULTRADNS_USERNAME", ""), + "password": current_app.config.get("ACME_ULTRADNS_PASSWORD", ""), + } + base_uri = current_app.config.get("ACME_ULTRADNS_DOMAIN", "") + resp = requests.post(f"{base_uri}{path}", data=data, verify=True) + return resp.json()["access_token"] + + +def _generate_header(): + """ + Function to generate the header for a request. + + Contains the Authorization access_key obtained from the get_ultradns_token() function. + """ + access_token = get_ultradns_token() + return {"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"} + + +def _paginate(path, key): + limit = 100 + params = {"offset": 0, "limit": 1} + resp = _get(path, params) + for index in range(0, resp["resultInfo"]["totalCount"], limit): + params["offset"] = index + params["limit"] = limit + resp = _get(path, params) + yield resp[key] + + +def _get(path, params=None): + """Function to execute a GET request on the given URL (base_uri + path) with given params""" + base_uri = current_app.config.get("ACME_ULTRADNS_DOMAIN", "") + resp = requests.get( + f"{base_uri}{path}", + headers=_generate_header(), + params=params, + verify=True, + ) + resp.raise_for_status() + return resp.json() + + +def _delete(path): + """Function to execute a DELETE request on the given URL""" + base_uri = current_app.config.get("ACME_ULTRADNS_DOMAIN", "") + resp = requests.delete( + f"{base_uri}{path}", + headers=_generate_header(), + verify=True, + ) + resp.raise_for_status() + + +def _post(path, params): + """Executes a POST request on given URL. Body is sent in JSON format""" + base_uri = current_app.config.get("ACME_ULTRADNS_DOMAIN", "") + resp = requests.post( + f"{base_uri}{path}", + headers=_generate_header(), + data=json.dumps(params), + verify=True, + ) + resp.raise_for_status() + + +def _has_dns_propagated(name, token, domain): + """ + Check whether the DNS change made by Lemur have propagated to the public DNS or not. + + Invoked by wait_for_dns_change() function + """ + txt_records = [] + try: + dns_resolver = dns.resolver.Resolver() + dns_resolver.nameservers = [domain] + dns_response = dns_resolver.query(name, "TXT") + for rdata in dns_response: + for txt_record in rdata.strings: + txt_records.append(txt_record.decode("utf-8")) + except dns.exception.DNSException: + function = sys._getframe().f_code.co_name + metrics.send(f"{function}.fail", "counter", 1) + return False + + for txt_record in txt_records: + if txt_record == token: + function = sys._getframe().f_code.co_name + metrics.send(f"{function}.success", "counter", 1) + return True + + return False + + +def wait_for_dns_change(change_id, account_number=None): + """ + Waits and checks if the DNS changes have propagated or not. + + First check the domains authoritative server. Once this succeeds, + we ask a public DNS server (Google <8.8.8.8> in our case). + """ + fqdn, token = change_id + number_of_attempts = 20 + nameserver = get_authoritative_nameserver(fqdn) + for attempts in range(0, number_of_attempts): + status = _has_dns_propagated(fqdn, token, nameserver) + function = sys._getframe().f_code.co_name + log_data = { + "function": function, + "fqdn": fqdn, + "status": status, + "message": "Record status on ultraDNS authoritative server" + } + current_app.logger.debug(log_data) + if status: + time.sleep(10) + break + time.sleep(10) + if status: + nameserver = get_public_authoritative_nameserver() + for attempts in range(0, number_of_attempts): + status = _has_dns_propagated(fqdn, token, nameserver) + log_data = { + "function": function, + "fqdn": fqdn, + "status": status, + "message": "Record status on Public DNS" + } + current_app.logger.debug(log_data) + if status: + metrics.send(f"{function}.success", "counter", 1) + break + time.sleep(10) + if not status: + metrics.send(f"{function}.fail", "counter", 1, metric_tags={"fqdn": fqdn, "txt_record": token}) + sentry.captureException(extra={"fqdn": str(fqdn), "txt_record": str(token)}) + return + + +def get_zones(account_number): + """Get zones from the UltraDNS""" + path = "/v2/zones" + zones = [] + for page in _paginate(path, "zones"): + for elem in page: + # UltraDNS zone names end with a "." - Example - lemur.example.com. + # We pick out the names minus the "." at the end while returning the list + zone = Zone(elem) + if zone.authoritative_type == "PRIMARY" and zone.status == "ACTIVE": + zones.append(zone.name) + + return zones + + +def get_zone_name(domain, account_number): + """Get the matching zone for the given domain""" + zones = get_zones(account_number) + zone_name = "" + for z in zones: + if domain.endswith(z): + # Find the most specific zone possible for the domain + # Ex: If fqdn is a.b.c.com, there is a zone for c.com, + # and a zone for b.c.com, we want to use b.c.com. + if z.count(".") > zone_name.count("."): + zone_name = z + if not zone_name: + function = sys._getframe().f_code.co_name + metrics.send(f"{function}.fail", "counter", 1) + raise Exception(f"No UltraDNS zone found for domain: {domain}") + return zone_name + + +def create_txt_record(domain, token, account_number): + """ + Create a TXT record for the given domain. + + The part of the domain that matches with the zone becomes the zone name. + The remainder becomes the owner name (referred to as node name here) + Example: Let's say we have a zone named "exmaple.com" in UltraDNS and we + get a request to create a cert for lemur.example.com + Domain - _acme-challenge.lemur.example.com + Matching zone - example.com + Owner name - _acme-challenge.lemur + """ + + zone_name = get_zone_name(domain, account_number) + zone_parts = len(zone_name.split(".")) + node_name = ".".join(domain.split(".")[:-zone_parts]) + fqdn = f"{node_name}.{zone_name}" + path = f"/v2/zones/{zone_name}/rrsets/TXT/{node_name}" + params = { + "ttl": 5, + "rdata": [ + f"{token}" + ], + } + + try: + _post(path, params) + function = sys._getframe().f_code.co_name + log_data = { + "function": function, + "fqdn": fqdn, + "token": token, + "message": "TXT record created" + } + current_app.logger.debug(log_data) + except Exception as e: + function = sys._getframe().f_code.co_name + log_data = { + "function": function, + "domain": domain, + "token": token, + "Exception": e, + "message": "Unable to add record. Record already exists." + } + current_app.logger.debug(log_data) + + change_id = (fqdn, token) + return change_id + + +def delete_txt_record(change_id, account_number, domain, token): + """ + Delete the TXT record that was created in the create_txt_record() function. + + UltraDNS handles records differently compared to Dyn. It creates an RRSet + which is a set of records of the same type and owner. This means + that while deleting the record, we cannot delete any individual record from + the RRSet. Instead, we have to delete the entire RRSet. If multiple certs are + being created for the same domain at the same time, the challenge TXT records + that are created will be added under the same RRSet. If the RRSet had more + than 1 record, then we create a new RRSet on UltraDNS minus the record that + has to be deleted. + """ + + if not domain: + function = sys._getframe().f_code.co_name + log_data = { + "function": function, + "message": "No domain passed" + } + current_app.logger.debug(log_data) + return + + zone_name = get_zone_name(domain, account_number) + zone_parts = len(zone_name.split(".")) + node_name = ".".join(domain.split(".")[:-zone_parts]) + path = f"/v2/zones/{zone_name}/rrsets/16/{node_name}" + + try: + rrsets = _get(path) + record = Record(rrsets) + except Exception as e: + function = sys._getframe().f_code.co_name + metrics.send(f"{function}.geterror", "counter", 1) + # No Text Records remain or host is not in the zone anymore because all records have been deleted. + return + try: + # Remove the record from the RRSet locally + record.rdata.remove(f"{token}") + except ValueError: + function = sys._getframe().f_code.co_name + log_data = { + "function": function, + "token": token, + "message": "Token not found" + } + current_app.logger.debug(log_data) + return + + # Delete the RRSet from UltraDNS + _delete(path) + + # Check if the RRSet has more records. If yes, add the modified RRSet back to UltraDNS + if len(record.rdata) > 0: + params = { + "ttl": 5, + "rdata": record.rdata, + } + _post(path, params) + + +def delete_acme_txt_records(domain): + + if not domain: + function = sys._getframe().f_code.co_name + log_data = { + "function": function, + "message": "No domain passed" + } + current_app.logger.debug(log_data) + return + acme_challenge_string = "_acme-challenge" + if not domain.startswith(acme_challenge_string): + function = sys._getframe().f_code.co_name + log_data = { + "function": function, + "domain": domain, + "acme_challenge_string": acme_challenge_string, + "message": "Domain does not start with the acme challenge string" + } + current_app.logger.debug(log_data) + return + + zone_name = get_zone_name(domain) + zone_parts = len(zone_name.split(".")) + node_name = ".".join(domain.split(".")[:-zone_parts]) + path = f"/v2/zones/{zone_name}/rrsets/16/{node_name}" + + _delete(path) + + +def get_authoritative_nameserver(domain): + """Get the authoritative nameserver for the given domain""" + n = dns.name.from_text(domain) + + depth = 2 + default = dns.resolver.get_default_resolver() + nameserver = default.nameservers[0] + + last = False + while not last: + s = n.split(depth) + + last = s[0].to_unicode() == u"@" + sub = s[1] + + query = dns.message.make_query(sub, dns.rdatatype.NS) + response = dns.query.udp(query, nameserver) + + rcode = response.rcode() + if rcode != dns.rcode.NOERROR: + function = sys._getframe().f_code.co_name + metrics.send(f"{function}.error", "counter", 1) + if rcode == dns.rcode.NXDOMAIN: + raise Exception("%s does not exist." % sub) + else: + raise Exception("Error %s" % dns.rcode.to_text(rcode)) + + if len(response.authority) > 0: + rrset = response.authority[0] + else: + rrset = response.answer[0] + + rr = rrset[0] + if rr.rdtype != dns.rdatatype.SOA: + authority = rr.target + nameserver = default.query(authority).rrset[0].to_text() + + depth += 1 + + return nameserver + + +def get_public_authoritative_nameserver(): + return "8.8.8.8" diff --git a/lemur/plugins/lemur_adcs/__init__.py b/lemur/plugins/lemur_adcs/__init__.py new file mode 100644 index 00000000..b902ed7a --- /dev/null +++ b/lemur/plugins/lemur_adcs/__init__.py @@ -0,0 +1,5 @@ +"""Set the version information.""" +try: + VERSION = __import__("pkg_resources").get_distribution(__name__).version +except Exception as e: + VERSION = "unknown" diff --git a/lemur/plugins/lemur_adcs/plugin.py b/lemur/plugins/lemur_adcs/plugin.py new file mode 100644 index 00000000..bc07ede3 --- /dev/null +++ b/lemur/plugins/lemur_adcs/plugin.py @@ -0,0 +1,130 @@ +from lemur.plugins.bases import IssuerPlugin, SourcePlugin +import requests +from lemur.plugins import lemur_adcs as ADCS +from certsrv import Certsrv +from OpenSSL import crypto +from flask import current_app + + +class ADCSIssuerPlugin(IssuerPlugin): + title = "ADCS" + slug = "adcs-issuer" + description = "Enables the creation of certificates by ADCS (Active Directory Certificate Services)" + version = ADCS.VERSION + + author = "sirferl" + author_url = "https://github.com/sirferl/lemur" + + def __init__(self, *args, **kwargs): + """Initialize the issuer with the appropriate details.""" + self.session = requests.Session() + super(ADCSIssuerPlugin, self).__init__(*args, **kwargs) + + @staticmethod + def create_authority(options): + """Create an authority. + Creates an authority, this authority is then used by Lemur to + allow a user to specify which Certificate Authority they want + to sign their certificate. + + :param options: + :return: + """ + adcs_root = current_app.config.get("ADCS_ROOT") + adcs_issuing = current_app.config.get("ADCS_ISSUING") + role = {"username": "", "password": "", "name": "adcs"} + return adcs_root, adcs_issuing, [role] + + def create_certificate(self, csr, issuer_options): + adcs_server = current_app.config.get("ADCS_SERVER") + adcs_user = current_app.config.get("ADCS_USER") + adcs_pwd = current_app.config.get("ADCS_PWD") + adcs_auth_method = current_app.config.get("ADCS_AUTH_METHOD") + adcs_template = current_app.config.get("ADCS_TEMPLATE") + ca_server = Certsrv( + adcs_server, adcs_user, adcs_pwd, auth_method=adcs_auth_method + ) + current_app.logger.info("Requesting CSR: {0}".format(csr)) + current_app.logger.info("Issuer options: {0}".format(issuer_options)) + cert, req_id = ( + ca_server.get_cert(csr, adcs_template, encoding="b64") + .decode("utf-8") + .replace("\r\n", "\n") + ) + chain = ( + ca_server.get_ca_cert(encoding="b64").decode("utf-8").replace("\r\n", "\n") + ) + return cert, chain, req_id + + def revoke_certificate(self, certificate, comments): + raise NotImplementedError("Not implemented\n", self, certificate, comments) + + def get_ordered_certificate(self, order_id): + raise NotImplementedError("Not implemented\n", self, order_id) + + def canceled_ordered_certificate(self, pending_cert, **kwargs): + raise NotImplementedError("Not implemented\n", self, pending_cert, **kwargs) + + +class ADCSSourcePlugin(SourcePlugin): + title = "ADCS" + slug = "adcs-source" + description = "Enables the collecion of certificates" + version = ADCS.VERSION + + author = "sirferl" + author_url = "https://github.com/sirferl/lemur" + options = [ + { + "name": "dummy", + "type": "str", + "required": False, + "validation": "/^[0-9]{12,12}$/", + "helpMessage": "Just to prevent error", + } + ] + + def get_certificates(self, options, **kwargs): + adcs_server = current_app.config.get("ADCS_SERVER") + adcs_user = current_app.config.get("ADCS_USER") + adcs_pwd = current_app.config.get("ADCS_PWD") + adcs_auth_method = current_app.config.get("ADCS_AUTH_METHOD") + adcs_start = current_app.config.get("ADCS_START") + adcs_stop = current_app.config.get("ADCS_STOP") + ca_server = Certsrv( + adcs_server, adcs_user, adcs_pwd, auth_method=adcs_auth_method + ) + out_certlist = [] + for id in range(adcs_start, adcs_stop): + try: + cert = ( + ca_server.get_existing_cert(id, encoding="b64") + .decode("utf-8") + .replace("\r\n", "\n") + ) + except Exception as err: + if "{0}".format(err).find("CERTSRV_E_PROPERTY_EMPTY"): + # this error indicates end of certificate list(?), so we stop + break + else: + # We do nothing in case there is no certificate returned for other reasons + current_app.logger.info("Error with id {0}: {1}".format(id, err)) + else: + # we have a certificate + pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + # loop through extensions to see if we find "TLS Web Server Authentication" + for e_id in range(0, pubkey.get_extension_count() - 1): + try: + extension = "{0}".format(pubkey.get_extension(e_id)) + except Exception: + extensionn = "" + if extension.find("TLS Web Server Authentication") != -1: + out_certlist.append( + {"name": format(pubkey.get_subject().CN), "body": cert} + ) + break + return out_certlist + + def get_endpoints(self, options, **kwargs): + # There are no endpoints in the ADCS + raise NotImplementedError("Not implemented\n", self, options, **kwargs) diff --git a/lemur/plugins/lemur_atlas/__init__.py b/lemur/plugins/lemur_atlas/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_atlas/__init__.py +++ b/lemur/plugins/lemur_atlas/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_atlas/plugin.py b/lemur/plugins/lemur_atlas/plugin.py index 09d4c9f9..7cf78ed2 100644 --- a/lemur/plugins/lemur_atlas/plugin.py +++ b/lemur/plugins/lemur_atlas/plugin.py @@ -26,44 +26,41 @@ def millis_since_epoch(): class AtlasMetricPlugin(MetricPlugin): - title = 'Atlas' - slug = 'atlas-metric' - description = 'Adds support for sending key metrics to Atlas' + title = "Atlas" + slug = "atlas-metric" + description = "Adds support for sending key metrics to Atlas" version = atlas.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur" options = [ { - 'name': 'sidecar_host', - 'type': 'str', - 'required': False, - 'help_message': 'If no host is provided localhost is assumed', - 'default': 'localhost' + "name": "sidecar_host", + "type": "str", + "required": False, + "help_message": "If no host is provided localhost is assumed", + "default": "localhost", }, - { - 'name': 'sidecar_port', - 'type': 'int', - 'required': False, - 'default': 8078 - } + {"name": "sidecar_port", "type": "int", "required": False, "default": 8078}, ] metric_data = {} sidecar_host = None sidecar_port = None - def submit(self, metric_name, metric_type, metric_value, metric_tags=None, options=None): + def submit( + self, metric_name, metric_type, metric_value, metric_tags=None, options=None + ): if not options: options = self.options # TODO marshmallow schema? - valid_types = ['COUNTER', 'GAUGE', 'TIMER'] + valid_types = ["COUNTER", "GAUGE", "TIMER"] if metric_type.upper() not in valid_types: raise Exception( "Invalid Metric Type for Atlas: '{metric}' choose from: {options}".format( - metric=metric_type, options=','.join(valid_types) + metric=metric_type, options=",".join(valid_types) ) ) @@ -73,31 +70,35 @@ class AtlasMetricPlugin(MetricPlugin): "Invalid Metric Tags for Atlas: Tags must be in dict format" ) - if metric_value == "NaN" or isinstance(metric_value, int) or isinstance(metric_value, float): - self.metric_data['value'] = metric_value + if ( + metric_value == "NaN" + or isinstance(metric_value, int) + or isinstance(metric_value, float) + ): + self.metric_data["value"] = metric_value else: - raise Exception( - "Invalid Metric Value for Atlas: Metric must be a number" - ) + raise Exception("Invalid Metric Value for Atlas: Metric must be a number") - self.metric_data['type'] = metric_type.upper() - self.metric_data['name'] = str(metric_name) - self.metric_data['tags'] = metric_tags - self.metric_data['timestamp'] = millis_since_epoch() + self.metric_data["type"] = metric_type.upper() + self.metric_data["name"] = str(metric_name) + self.metric_data["tags"] = metric_tags + self.metric_data["timestamp"] = millis_since_epoch() - self.sidecar_host = self.get_option('sidecar_host', options) - self.sidecar_port = self.get_option('sidecar_port', options) + self.sidecar_host = self.get_option("sidecar_host", options) + self.sidecar_port = self.get_option("sidecar_port", options) try: res = requests.post( - 'http://{host}:{port}/metrics'.format( - host=self.sidecar_host, - port=self.sidecar_port), - data=json.dumps([self.metric_data]) + "http://{host}:{port}/metrics".format( + host=self.sidecar_host, port=self.sidecar_port + ), + data=json.dumps([self.metric_data]), ) if res.status_code != 200: - current_app.logger.warning("Failed to publish altas metric. {0}".format(res.content)) + current_app.logger.warning( + "Failed to publish altas metric. {0}".format(res.content) + ) except ConnectionError: current_app.logger.warning( diff --git a/lemur/plugins/lemur_atlas_redis/__init__.py b/lemur/plugins/lemur_atlas_redis/__init__.py new file mode 100644 index 00000000..f8afd7e3 --- /dev/null +++ b/lemur/plugins/lemur_atlas_redis/__init__.py @@ -0,0 +1,4 @@ +try: + VERSION = __import__("pkg_resources").get_distribution(__name__).version +except Exception as e: + VERSION = "unknown" diff --git a/lemur/plugins/lemur_atlas_redis/plugin.py b/lemur/plugins/lemur_atlas_redis/plugin.py new file mode 100644 index 00000000..e69ae672 --- /dev/null +++ b/lemur/plugins/lemur_atlas_redis/plugin.py @@ -0,0 +1,97 @@ +""" +.. module: lemur.plugins.lemur_atlas_redis.plugin + :platform: Unix + :copyright: (c) 2018 by Netflix Inc., see AUTHORS for more + :license: Apache, see LICENSE for more details. + +.. moduleauthor:: Jay Zarfoss +""" + +from redis import Redis +import json +from datetime import datetime + +from flask import current_app +from lemur.plugins import lemur_atlas as atlas +from lemur.plugins.bases.metric import MetricPlugin + + +def millis_since_epoch(): + """ + current time since epoch in milliseconds + """ + epoch = datetime.utcfromtimestamp(0) + delta = datetime.now() - epoch + return int(delta.total_seconds() * 1000.0) + + +class AtlasMetricRedisPlugin(MetricPlugin): + title = "AtlasRedis" + slug = "atlas-metric-redis" + description = "Adds support for sending key metrics to Atlas via local Redis" + version = atlas.VERSION + + author = "Jay Zarfoss" + author_url = "https://github.com/netflix/lemur" + + options = [ + { + "name": "redis_host", + "type": "str", + "required": False, + "help_message": "If no host is provided localhost is assumed", + "default": "localhost", + }, + {"name": "redis_port", "type": "int", "required": False, "default": 28527}, + ] + + metric_data = {} + redis_host = None + redis_port = None + + def submit( + self, metric_name, metric_type, metric_value, metric_tags=None, options=None + ): + if not options: + options = self.options + + valid_types = ["COUNTER", "GAUGE", "TIMER"] + if metric_type.upper() not in valid_types: + raise Exception( + "Invalid Metric Type for Atlas: '{metric}' choose from: {options}".format( + metric=metric_type, options=",".join(valid_types) + ) + ) + + if metric_tags: + if not isinstance(metric_tags, dict): + raise Exception( + "Invalid Metric Tags for Atlas: Tags must be in dict format" + ) + + self.metric_data["timestamp"] = millis_since_epoch() + self.metric_data["type"] = metric_type.upper() + self.metric_data["name"] = str(metric_name) + self.metric_data["tags"] = metric_tags + + if ( + metric_value == "NaN" + or isinstance(metric_value, int) + or isinstance(metric_value, float) + ): + self.metric_data["value"] = metric_value + else: + raise Exception("Invalid Metric Value for Atlas: Metric must be a number") + + self.redis_host = self.get_option("redis_host", options) + self.redis_port = self.get_option("redis_port", options) + + try: + r = Redis(host=self.redis_host, port=self.redis_port, socket_timeout=0.1) + r.rpush('atlas-agent', json.dumps(self.metric_data)) + except Exception as e: + current_app.logger.warning( + "AtlasMetricsRedis: exception [{exception}] could not post atlas metrics to AtlasRedis [{host}:{port}], metric [{metricdata}]".format( + exception=e, host=self.redis_host, port=self.redis_port, metricdata=json.dumps(self.metric_data) + ) + ) diff --git a/lemur/plugins/lemur_aws/__init__.py b/lemur/plugins/lemur_aws/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_aws/__init__.py +++ b/lemur/plugins/lemur_aws/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_aws/ec2.py b/lemur/plugins/lemur_aws/ec2.py index 3bd20e60..04b42140 100644 --- a/lemur/plugins/lemur_aws/ec2.py +++ b/lemur/plugins/lemur_aws/ec2.py @@ -8,16 +8,16 @@ from lemur.plugins.lemur_aws.sts import sts_client -@sts_client('ec2') +@sts_client("ec2") def get_regions(**kwargs): - regions = kwargs['client'].describe_regions() - return [x['RegionName'] for x in regions['Regions']] + regions = kwargs["client"].describe_regions() + return [x["RegionName"] for x in regions["Regions"]] -@sts_client('ec2') +@sts_client("ec2") def get_all_instances(**kwargs): """ Fetches all instance objects for a given account and region. """ - paginator = kwargs['client'].get_paginator('describe_instances') + paginator = kwargs["client"].get_paginator("describe_instances") return paginator.paginate() diff --git a/lemur/plugins/lemur_aws/elb.py b/lemur/plugins/lemur_aws/elb.py index 4c4ce97f..595a3826 100644 --- a/lemur/plugins/lemur_aws/elb.py +++ b/lemur/plugins/lemur_aws/elb.py @@ -10,7 +10,7 @@ from flask import current_app from retrying import retry -from lemur.extensions import metrics +from lemur.extensions import metrics, sentry from lemur.exceptions import InvalidListener from lemur.plugins.lemur_aws.sts import sts_client @@ -21,14 +21,21 @@ def retry_throttled(exception): :param exception: :return: """ + + # Log details about the exception + try: + raise exception + except Exception as e: + current_app.logger.error("ELB retry_throttled triggered", exc_info=True) + metrics.send("elb_retry", "counter", 1, metric_tags={"exception": str(e)}) + sentry.captureException() + if isinstance(exception, botocore.exceptions.ClientError): - if exception.response['Error']['Code'] == 'LoadBalancerNotFound': + if exception.response["Error"]["Code"] == "LoadBalancerNotFound": return False - if exception.response['Error']['Code'] == 'CertificateNotFound': + if exception.response["Error"]["Code"] == "CertificateNotFound": return False - - metrics.send('elb_retry', 'counter', 1) return True @@ -48,7 +55,7 @@ def is_valid(listener_tuple): :param listener_tuple: """ lb_port, i_port, lb_protocol, arn = listener_tuple - if lb_protocol.lower() in ['ssl', 'https']: + if lb_protocol.lower() in ["ssl", "https"]: if not arn: raise InvalidListener @@ -63,16 +70,20 @@ def get_all_elbs(**kwargs): :return: """ elbs = [] + try: + while True: + response = get_elbs(**kwargs) - while True: - response = get_elbs(**kwargs) + elbs += response["LoadBalancerDescriptions"] - elbs += response['LoadBalancerDescriptions'] - - if not response.get('NextMarker'): - return elbs - else: - kwargs.update(dict(Marker=response['NextMarker'])) + if not response.get("NextMarker"): + return elbs + else: + kwargs.update(dict(Marker=response["NextMarker"])) + except Exception as e: # noqa + metrics.send("get_all_elbs_error", "counter", 1) + sentry.captureException() + raise def get_all_elbs_v2(**kwargs): @@ -84,18 +95,23 @@ def get_all_elbs_v2(**kwargs): """ elbs = [] - while True: - response = get_elbs_v2(**kwargs) - elbs += response['LoadBalancers'] + try: + while True: + response = get_elbs_v2(**kwargs) + elbs += response["LoadBalancers"] - if not response.get('NextMarker'): - return elbs - else: - kwargs.update(dict(Marker=response['NextMarker'])) + if not response.get("NextMarker"): + return elbs + else: + kwargs.update(dict(Marker=response["NextMarker"])) + except Exception as e: # noqa + metrics.send("get_all_elbs_v2_error", "counter", 1) + sentry.captureException() + raise -@sts_client('elbv2') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=1000) +@sts_client("elbv2") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def get_listener_arn_from_endpoint(endpoint_name, endpoint_port, **kwargs): """ Get a listener ARN from an endpoint. @@ -103,27 +119,53 @@ def get_listener_arn_from_endpoint(endpoint_name, endpoint_port, **kwargs): :param endpoint_port: :return: """ - client = kwargs.pop('client') - elbs = client.describe_load_balancers(Names=[endpoint_name]) - for elb in elbs['LoadBalancers']: - listeners = client.describe_listeners(LoadBalancerArn=elb['LoadBalancerArn']) - for listener in listeners['Listeners']: - if listener['Port'] == endpoint_port: - return listener['ListenerArn'] + try: + client = kwargs.pop("client") + elbs = client.describe_load_balancers(Names=[endpoint_name]) + for elb in elbs["LoadBalancers"]: + listeners = client.describe_listeners( + LoadBalancerArn=elb["LoadBalancerArn"] + ) + for listener in listeners["Listeners"]: + if listener["Port"] == endpoint_port: + return listener["ListenerArn"] + except Exception as e: # noqa + metrics.send( + "get_listener_arn_from_endpoint_error", + "counter", + 1, + metric_tags={ + "error": str(e), + "endpoint_name": endpoint_name, + "endpoint_port": endpoint_port, + }, + ) + sentry.captureException( + extra={ + "endpoint_name": str(endpoint_name), + "endpoint_port": str(endpoint_port), + } + ) + raise -@sts_client('elb') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=1000) +@sts_client("elb") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def get_elbs(**kwargs): """ Fetches one page elb objects for a given account and region. """ - client = kwargs.pop('client') - return client.describe_load_balancers(**kwargs) + try: + client = kwargs.pop("client") + return client.describe_load_balancers(**kwargs) + except Exception as e: # noqa + metrics.send("get_elbs_error", "counter", 1, metric_tags={"error": str(e)}) + sentry.captureException() + raise -@sts_client('elbv2') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=1000) +@sts_client("elbv2") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def get_elbs_v2(**kwargs): """ Fetches one page of elb objects for a given account and region. @@ -131,12 +173,17 @@ def get_elbs_v2(**kwargs): :param kwargs: :return: """ - client = kwargs.pop('client') - return client.describe_load_balancers(**kwargs) + try: + client = kwargs.pop("client") + return client.describe_load_balancers(**kwargs) + except Exception as e: # noqa + metrics.send("get_elbs_v2_error", "counter", 1, metric_tags={"error": str(e)}) + sentry.captureException() + raise -@sts_client('elbv2') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=1000) +@sts_client("elbv2") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def describe_listeners_v2(**kwargs): """ Fetches one page of listener objects for a given elb arn. @@ -144,12 +191,19 @@ def describe_listeners_v2(**kwargs): :param kwargs: :return: """ - client = kwargs.pop('client') - return client.describe_listeners(**kwargs) + try: + client = kwargs.pop("client") + return client.describe_listeners(**kwargs) + except Exception as e: # noqa + metrics.send( + "describe_listeners_v2_error", "counter", 1, metric_tags={"error": str(e)} + ) + sentry.captureException() + raise -@sts_client('elb') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=1000) +@sts_client("elb") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def describe_load_balancer_policies(load_balancer_name, policy_names, **kwargs): """ Fetching all policies currently associated with an ELB. @@ -157,11 +211,33 @@ def describe_load_balancer_policies(load_balancer_name, policy_names, **kwargs): :param load_balancer_name: :return: """ - return kwargs['client'].describe_load_balancer_policies(LoadBalancerName=load_balancer_name, PolicyNames=policy_names) + + try: + return kwargs["client"].describe_load_balancer_policies( + LoadBalancerName=load_balancer_name, PolicyNames=policy_names + ) + except Exception as e: # noqa + metrics.send( + "describe_load_balancer_policies_error", + "counter", + 1, + metric_tags={ + "load_balancer_name": load_balancer_name, + "policy_names": policy_names, + "error": str(e), + }, + ) + sentry.captureException( + extra={ + "load_balancer_name": str(load_balancer_name), + "policy_names": str(policy_names), + } + ) + raise -@sts_client('elbv2') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=1000) +@sts_client("elbv2") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def describe_ssl_policies_v2(policy_names, **kwargs): """ Fetching all policies currently associated with an ELB. @@ -169,11 +245,21 @@ def describe_ssl_policies_v2(policy_names, **kwargs): :param policy_names: :return: """ - return kwargs['client'].describe_ssl_policies(Names=policy_names) + try: + return kwargs["client"].describe_ssl_policies(Names=policy_names) + except Exception as e: # noqa + metrics.send( + "describe_ssl_policies_v2_error", + "counter", + 1, + metric_tags={"policy_names": policy_names, "error": str(e)}, + ) + sentry.captureException(extra={"policy_names": str(policy_names)}) + raise -@sts_client('elb') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=1000) +@sts_client("elb") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def describe_load_balancer_types(policies, **kwargs): """ Describe the policies with policy details. @@ -181,11 +267,13 @@ def describe_load_balancer_types(policies, **kwargs): :param policies: :return: """ - return kwargs['client'].describe_load_balancer_policy_types(PolicyTypeNames=policies) + return kwargs["client"].describe_load_balancer_policy_types( + PolicyTypeNames=policies + ) -@sts_client('elb') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=1000) +@sts_client("elb") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def attach_certificate(name, port, certificate_id, **kwargs): """ Attaches a certificate to a listener, throws exception @@ -196,16 +284,20 @@ def attach_certificate(name, port, certificate_id, **kwargs): :param certificate_id: """ try: - return kwargs['client'].set_load_balancer_listener_ssl_certificate(LoadBalancerName=name, LoadBalancerPort=port, SSLCertificateId=certificate_id) + return kwargs["client"].set_load_balancer_listener_ssl_certificate( + LoadBalancerName=name, + LoadBalancerPort=port, + SSLCertificateId=certificate_id, + ) except botocore.exceptions.ClientError as e: - if e.response['Error']['Code'] == 'LoadBalancerNotFound': + if e.response["Error"]["Code"] == "LoadBalancerNotFound": current_app.logger.warning("Loadbalancer does not exist.") else: raise e -@sts_client('elbv2') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=1000) +@sts_client("elbv2") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def attach_certificate_v2(listener_arn, port, certificates, **kwargs): """ Attaches a certificate to a listener, throws exception @@ -216,9 +308,11 @@ def attach_certificate_v2(listener_arn, port, certificates, **kwargs): :param certificates: """ try: - return kwargs['client'].modify_listener(ListenerArn=listener_arn, Port=port, Certificates=certificates) + return kwargs["client"].modify_listener( + ListenerArn=listener_arn, Port=port, Certificates=certificates + ) except botocore.exceptions.ClientError as e: - if e.response['Error']['Code'] == 'LoadBalancerNotFound': + if e.response["Error"]["Code"] == "LoadBalancerNotFound": current_app.logger.warning("Loadbalancer does not exist.") else: raise e diff --git a/lemur/plugins/lemur_aws/iam.py b/lemur/plugins/lemur_aws/iam.py index b2a07798..13590ddd 100644 --- a/lemur/plugins/lemur_aws/iam.py +++ b/lemur/plugins/lemur_aws/iam.py @@ -10,7 +10,7 @@ import botocore from retrying import retry -from lemur.extensions import metrics +from lemur.extensions import metrics, sentry from lemur.plugins.lemur_aws.sts import sts_client @@ -21,10 +21,10 @@ def retry_throttled(exception): :return: """ if isinstance(exception, botocore.exceptions.ClientError): - if exception.response['Error']['Code'] == 'NoSuchEntity': + if exception.response["Error"]["Code"] == "NoSuchEntity": return False - metrics.send('iam_retry', 'counter', 1) + metrics.send("iam_retry", "counter", 1, metric_tags={"exception": str(exception)}) return True @@ -47,12 +47,12 @@ def create_arn_from_cert(account_number, region, certificate_name): :return: """ return "arn:aws:iam::{account_number}:server-certificate/{certificate_name}".format( - account_number=account_number, - certificate_name=certificate_name) + account_number=account_number, certificate_name=certificate_name + ) -@sts_client('iam') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=100) +@sts_client("iam") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=25) def upload_cert(name, body, private_key, path, cert_chain=None, **kwargs): """ Upload a certificate to AWS @@ -64,38 +64,38 @@ def upload_cert(name, body, private_key, path, cert_chain=None, **kwargs): :param path: :return: """ - client = kwargs.pop('client') + assert isinstance(private_key, str) + client = kwargs.pop("client") - if not path or path == '/': - path = '/' + if not path or path == "/": + path = "/" else: - name = name + '-' + path.strip('/') + name = name + "-" + path.strip("/") + metrics.send("upload_cert", "counter", 1, metric_tags={"name": name, "path": path}) try: - if isinstance(private_key, bytes): - private_key = private_key.decode("utf-8") if cert_chain: return client.upload_server_certificate( Path=path, ServerCertificateName=name, CertificateBody=str(body), PrivateKey=str(private_key), - CertificateChain=str(cert_chain) + CertificateChain=str(cert_chain), ) else: return client.upload_server_certificate( Path=path, ServerCertificateName=name, CertificateBody=str(body), - PrivateKey=str(private_key) + PrivateKey=str(private_key), ) except botocore.exceptions.ClientError as e: - if e.response['Error']['Code'] != 'EntityAlreadyExists': + if e.response["Error"]["Code"] != "EntityAlreadyExists": raise e -@sts_client('iam') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=100) +@sts_client("iam") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=25) def delete_cert(cert_name, **kwargs): """ Delete a certificate from AWS @@ -103,37 +103,42 @@ def delete_cert(cert_name, **kwargs): :param cert_name: :return: """ - client = kwargs.pop('client') + client = kwargs.pop("client") + metrics.send("delete_cert", "counter", 1, metric_tags={"cert_name": cert_name}) try: client.delete_server_certificate(ServerCertificateName=cert_name) except botocore.exceptions.ClientError as e: - if e.response['Error']['Code'] != 'NoSuchEntity': + if e.response["Error"]["Code"] != "NoSuchEntity": raise e -@sts_client('iam') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=100) +@sts_client("iam") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=25) def get_certificate(name, **kwargs): """ Retrieves an SSL certificate. :return: """ - client = kwargs.pop('client') - return client.get_server_certificate( - ServerCertificateName=name - )['ServerCertificate'] + client = kwargs.pop("client") + metrics.send("get_certificate", "counter", 1, metric_tags={"name": name}) + try: + return client.get_server_certificate(ServerCertificateName=name)["ServerCertificate"] + except client.exceptions.NoSuchEntityException: + sentry.captureException() + return None -@sts_client('iam') -@retry(retry_on_exception=retry_throttled, stop_max_attempt_number=7, wait_exponential_multiplier=100) +@sts_client("iam") +@retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=25) def get_certificates(**kwargs): """ Fetches one page of certificate objects for a given account. :param kwargs: :return: """ - client = kwargs.pop('client') + client = kwargs.pop("client") + metrics.send("get_certificates", "counter", 1) return client.list_server_certificates(**kwargs) @@ -142,16 +147,26 @@ def get_all_certificates(**kwargs): Use STS to fetch all of the SSL certificates from a given account """ certificates = [] - account_number = kwargs.get('account_number') + account_number = kwargs.get("account_number") + metrics.send( + "get_all_certificates", + "counter", + 1, + metric_tags={"account_number": account_number}, + ) while True: response = get_certificates(**kwargs) - metadata = response['ServerCertificateMetadataList'] + metadata = response["ServerCertificateMetadataList"] for m in metadata: - certificates.append(get_certificate(m['ServerCertificateName'], account_number=account_number)) + certificates.append( + get_certificate( + m["ServerCertificateName"], account_number=account_number + ) + ) - if not response.get('Marker'): + if not response.get("Marker"): return certificates else: - kwargs.update(dict(Marker=response['Marker'])) + kwargs.update(dict(Marker=response["Marker"])) diff --git a/lemur/plugins/lemur_aws/plugin.py b/lemur/plugins/lemur_aws/plugin.py index 1c2607a5..6669f641 100644 --- a/lemur/plugins/lemur_aws/plugin.py +++ b/lemur/plugins/lemur_aws/plugin.py @@ -32,7 +32,9 @@ .. moduleauthor:: Mikhail Khodorovskiy .. moduleauthor:: Harm Weites """ +from acme.errors import ClientError from flask import current_app +from lemur.extensions import sentry, metrics from lemur.plugins import lemur_aws as aws from lemur.plugins.bases import DestinationPlugin, ExportDestinationPlugin, SourcePlugin @@ -40,7 +42,12 @@ from lemur.plugins.lemur_aws import iam, s3, elb, ec2 def get_region_from_dns(dns): - return dns.split('.')[-4] + # XXX.REGION.elb.amazonaws.com + if dns.endswith(".elb.amazonaws.com"): + return dns.split(".")[-4] + else: + # NLBs have a different pattern on the dns XXXX.elb.REGION.amazonaws.com + return dns.split(".")[-3] def format_elb_cipher_policy_v2(policy): @@ -52,10 +59,10 @@ def format_elb_cipher_policy_v2(policy): ciphers = [] name = None - for descr in policy['SslPolicies']: - name = descr['Name'] - for cipher in descr['Ciphers']: - ciphers.append(cipher['Name']) + for descr in policy["SslPolicies"]: + name = descr["Name"] + for cipher in descr["Ciphers"]: + ciphers.append(cipher["Name"]) return dict(name=name, ciphers=ciphers) @@ -68,14 +75,14 @@ def format_elb_cipher_policy(policy): """ ciphers = [] name = None - for descr in policy['PolicyDescriptions']: - for attr in descr['PolicyAttributeDescriptions']: - if attr['AttributeName'] == 'Reference-Security-Policy': - name = attr['AttributeValue'] + for descr in policy["PolicyDescriptions"]: + for attr in descr["PolicyAttributeDescriptions"]: + if attr["AttributeName"] == "Reference-Security-Policy": + name = attr["AttributeValue"] continue - if attr['AttributeValue'] == 'true': - ciphers.append(attr['AttributeName']) + if attr["AttributeValue"] == "true": + ciphers.append(attr["AttributeName"]) return dict(name=name, ciphers=ciphers) @@ -89,25 +96,31 @@ def get_elb_endpoints(account_number, region, elb_dict): :return: """ endpoints = [] - for listener in elb_dict['ListenerDescriptions']: - if not listener['Listener'].get('SSLCertificateId'): + for listener in elb_dict["ListenerDescriptions"]: + if not listener["Listener"].get("SSLCertificateId"): continue - if listener['Listener']['SSLCertificateId'] == 'Invalid-Certificate': + if listener["Listener"]["SSLCertificateId"] == "Invalid-Certificate": continue endpoint = dict( - name=elb_dict['LoadBalancerName'], - dnsname=elb_dict['DNSName'], - type='elb', - port=listener['Listener']['LoadBalancerPort'], - certificate_name=iam.get_name_from_arn(listener['Listener']['SSLCertificateId']) + name=elb_dict["LoadBalancerName"], + dnsname=elb_dict["DNSName"], + type="elb", + port=listener["Listener"]["LoadBalancerPort"], + certificate_name=iam.get_name_from_arn( + listener["Listener"]["SSLCertificateId"] + ), ) - if listener['PolicyNames']: - policy = elb.describe_load_balancer_policies(elb_dict['LoadBalancerName'], listener['PolicyNames'], - account_number=account_number, region=region) - endpoint['policy'] = format_elb_cipher_policy(policy) + if listener["PolicyNames"]: + policy = elb.describe_load_balancer_policies( + elb_dict["LoadBalancerName"], + listener["PolicyNames"], + account_number=account_number, + region=region, + ) + endpoint["policy"] = format_elb_cipher_policy(policy) current_app.logger.debug("Found new endpoint. Endpoint: {}".format(endpoint)) @@ -125,120 +138,100 @@ def get_elb_endpoints_v2(account_number, region, elb_dict): :return: """ endpoints = [] - listeners = elb.describe_listeners_v2(account_number=account_number, region=region, - LoadBalancerArn=elb_dict['LoadBalancerArn']) - for listener in listeners['Listeners']: - if not listener.get('Certificates'): + listeners = elb.describe_listeners_v2( + account_number=account_number, + region=region, + LoadBalancerArn=elb_dict["LoadBalancerArn"], + ) + for listener in listeners["Listeners"]: + if not listener.get("Certificates"): continue - for certificate in listener['Certificates']: + for certificate in listener["Certificates"]: endpoint = dict( - name=elb_dict['LoadBalancerName'], - dnsname=elb_dict['DNSName'], - type='elbv2', - port=listener['Port'], - certificate_name=iam.get_name_from_arn(certificate['CertificateArn']) + name=elb_dict["LoadBalancerName"], + dnsname=elb_dict["DNSName"], + type="elbv2", + port=listener["Port"], + certificate_name=iam.get_name_from_arn(certificate["CertificateArn"]), ) - if listener['SslPolicy']: - policy = elb.describe_ssl_policies_v2([listener['SslPolicy']], account_number=account_number, region=region) - endpoint['policy'] = format_elb_cipher_policy_v2(policy) + if listener["SslPolicy"]: + policy = elb.describe_ssl_policies_v2( + [listener["SslPolicy"]], account_number=account_number, region=region + ) + endpoint["policy"] = format_elb_cipher_policy_v2(policy) endpoints.append(endpoint) return endpoints -class AWSDestinationPlugin(DestinationPlugin): - title = 'AWS' - slug = 'aws-destination' - description = 'Allow the uploading of certificates to AWS IAM' - version = aws.VERSION - - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' - - options = [ - { - 'name': 'accountNumber', - 'type': 'str', - 'required': True, - 'validation': '[0-9]{12}', - 'helpMessage': 'Must be a valid AWS account number!', - }, - { - 'name': 'path', - 'type': 'str', - 'default': '/', - 'helpMessage': 'Path to upload certificate.' - } - ] - - # 'elb': { - # 'name': {'type': 'name'}, - # 'region': {'type': 'str'}, - # 'port': {'type': 'int'} - # } - - def upload(self, name, body, private_key, cert_chain, options, **kwargs): - iam.upload_cert(name, body, private_key, - self.get_option('path', options), - cert_chain=cert_chain, - account_number=self.get_option('accountNumber', options)) - - def deploy(self, elb_name, account, region, certificate): - pass - - class AWSSourcePlugin(SourcePlugin): - title = 'AWS' - slug = 'aws-source' - description = 'Discovers all SSL certificates and ELB endpoints in an AWS account' + title = "AWS" + slug = "aws-source" + description = "Discovers all SSL certificates and ELB endpoints in an AWS account" version = aws.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur" options = [ { - 'name': 'accountNumber', - 'type': 'str', - 'required': True, - 'validation': '/^[0-9]{12,12}$/', - 'helpMessage': 'Must be a valid AWS account number!', + "name": "accountNumber", + "type": "str", + "required": True, + "validation": "/^[0-9]{12,12}$/", + "helpMessage": "Must be a valid AWS account number!", }, { - 'name': 'regions', - 'type': 'str', - 'helpMessage': 'Comma separated list of regions to search in, if no region is specified we look in all regions.' + "name": "regions", + "type": "str", + "helpMessage": "Comma separated list of regions to search in, if no region is specified we look in all regions.", }, ] def get_certificates(self, options, **kwargs): - cert_data = iam.get_all_certificates(account_number=self.get_option('accountNumber', options)) - return [dict(body=c['CertificateBody'], chain=c.get('CertificateChain'), - name=c['ServerCertificateMetadata']['ServerCertificateName']) for c in cert_data] + cert_data = iam.get_all_certificates( + account_number=self.get_option("accountNumber", options) + ) + return [ + dict( + body=c["CertificateBody"], + chain=c.get("CertificateChain"), + name=c["ServerCertificateMetadata"]["ServerCertificateName"], + ) + for c in cert_data + ] def get_endpoints(self, options, **kwargs): endpoints = [] - account_number = self.get_option('accountNumber', options) - regions = self.get_option('regions', options) + account_number = self.get_option("accountNumber", options) + regions = self.get_option("regions", options) if not regions: regions = ec2.get_regions(account_number=account_number) else: - regions = regions.split(',') + regions = "".join(regions.split()).split(",") for region in regions: elbs = elb.get_all_elbs(account_number=account_number, region=region) - current_app.logger.info("Describing classic load balancers in {0}-{1}".format(account_number, region)) + current_app.logger.info( + "Describing classic load balancers in {0}-{1}".format( + account_number, region + ) + ) for e in elbs: endpoints.extend(get_elb_endpoints(account_number, region, e)) # fetch advanced ELBs elbs_v2 = elb.get_all_elbs_v2(account_number=account_number, region=region) - current_app.logger.info("Describing advanced load balancers in {0}-{1}".format(account_number, region)) + current_app.logger.info( + "Describing advanced load balancers in {0}-{1}".format( + account_number, region + ) + ) for e in elbs_v2: endpoints.extend(get_elb_endpoints_v2(account_number, region, e)) @@ -247,69 +240,148 @@ class AWSSourcePlugin(SourcePlugin): def update_endpoint(self, endpoint, certificate): options = endpoint.source.options - account_number = self.get_option('accountNumber', options) + account_number = self.get_option("accountNumber", options) # relies on the fact that region is included in DNS name region = get_region_from_dns(endpoint.dnsname) arn = iam.create_arn_from_cert(account_number, region, certificate.name) - if endpoint.type == 'elbv2': - listener_arn = elb.get_listener_arn_from_endpoint(endpoint.name, endpoint.port, - account_number=account_number, region=region) - elb.attach_certificate_v2(listener_arn, endpoint.port, [{'CertificateArn': arn}], - account_number=account_number, region=region) + if endpoint.type == "elbv2": + listener_arn = elb.get_listener_arn_from_endpoint( + endpoint.name, + endpoint.port, + account_number=account_number, + region=region, + ) + elb.attach_certificate_v2( + listener_arn, + endpoint.port, + [{"CertificateArn": arn}], + account_number=account_number, + region=region, + ) else: - elb.attach_certificate(endpoint.name, endpoint.port, arn, account_number=account_number, region=region) + elb.attach_certificate( + endpoint.name, + endpoint.port, + arn, + account_number=account_number, + region=region, + ) def clean(self, certificate, options, **kwargs): - account_number = self.get_option('accountNumber', options) + account_number = self.get_option("accountNumber", options) iam.delete_cert(certificate.name, account_number=account_number) + def get_certificate_by_name(self, certificate_name, options): + account_number = self.get_option("accountNumber", options) + # certificate name may contain path, in which case we remove it + if "/" in certificate_name: + certificate_name = certificate_name.split('/')[-1] + try: + cert = iam.get_certificate(certificate_name, account_number=account_number) + if cert: + return dict( + body=cert["CertificateBody"], + chain=cert.get("CertificateChain"), + name=cert["ServerCertificateMetadata"]["ServerCertificateName"], + ) + except ClientError: + current_app.logger.warning( + "get_elb_certificate_failed: Unable to get certificate for {0}".format(certificate_name)) + sentry.captureException() + metrics.send( + "get_elb_certificate_failed", "counter", 1, + metric_tags={"certificate_name": certificate_name, "account_number": account_number} + ) + return None + + +class AWSDestinationPlugin(DestinationPlugin): + title = "AWS" + slug = "aws-destination" + description = "Allow the uploading of certificates to AWS IAM" + version = aws.VERSION + sync_as_source = True + sync_as_source_name = AWSSourcePlugin.slug + + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur" + + options = [ + { + "name": "accountNumber", + "type": "str", + "required": True, + "validation": "[0-9]{12}", + "helpMessage": "Must be a valid AWS account number!", + }, + { + "name": "path", + "type": "str", + "default": "/", + "helpMessage": "Path to upload certificate.", + }, + ] + + def upload(self, name, body, private_key, cert_chain, options, **kwargs): + iam.upload_cert( + name, + body, + private_key, + self.get_option("path", options), + cert_chain=cert_chain, + account_number=self.get_option("accountNumber", options), + ) + + def deploy(self, elb_name, account, region, certificate): + pass + class S3DestinationPlugin(ExportDestinationPlugin): - title = 'AWS-S3' - slug = 'aws-s3' - description = 'Allow the uploading of certificates to Amazon S3' + title = "AWS-S3" + slug = "aws-s3" + description = "Allow the uploading of certificates to Amazon S3" - author = 'Mikhail Khodorovskiy, Harm Weites ' - author_url = 'https://github.com/Netflix/lemur' + author = "Mikhail Khodorovskiy, Harm Weites " + author_url = "https://github.com/Netflix/lemur" additional_options = [ { - 'name': 'bucket', - 'type': 'str', - 'required': True, - 'validation': '[0-9a-z.-]{3,63}', - 'helpMessage': 'Must be a valid S3 bucket name!', + "name": "bucket", + "type": "str", + "required": True, + "validation": "[0-9a-z.-]{3,63}", + "helpMessage": "Must be a valid S3 bucket name!", }, { - 'name': 'accountNumber', - 'type': 'str', - 'required': True, - 'validation': '[0-9]{12}', - 'helpMessage': 'A valid AWS account number with permission to access S3', + "name": "accountNumber", + "type": "str", + "required": True, + "validation": "[0-9]{12}", + "helpMessage": "A valid AWS account number with permission to access S3", }, { - 'name': 'region', - 'type': 'str', - 'default': 'us-east-1', - 'required': False, - 'helpMessage': 'Region bucket exists', - 'available': ['us-east-1', 'us-west-2', 'eu-west-1'] + "name": "region", + "type": "str", + "default": "us-east-1", + "required": False, + "helpMessage": "Region bucket exists", + "available": ["us-east-1", "us-west-2", "eu-west-1"], }, { - 'name': 'encrypt', - 'type': 'bool', - 'required': False, - 'helpMessage': 'Enable server side encryption', - 'default': True + "name": "encrypt", + "type": "bool", + "required": False, + "helpMessage": "Enable server side encryption", + "default": True, }, { - 'name': 'prefix', - 'type': 'str', - 'required': False, - 'helpMessage': 'Must be a valid S3 object prefix!', - } + "name": "prefix", + "type": "str", + "required": False, + "helpMessage": "Must be a valid S3 object prefix!", + }, ] def __init__(self, *args, **kwargs): @@ -320,13 +392,12 @@ class S3DestinationPlugin(ExportDestinationPlugin): for ext, passphrase, data in files: s3.put( - self.get_option('bucket', options), - self.get_option('region', options), - '{prefix}/{name}.{extension}'.format( - prefix=self.get_option('prefix', options), - name=name, - extension=ext), + self.get_option("bucket", options), + self.get_option("region", options), + "{prefix}/{name}.{extension}".format( + prefix=self.get_option("prefix", options), name=name, extension=ext + ), data, - self.get_option('encrypt', options), - account_number=self.get_option('accountNumber', options) + self.get_option("encrypt", options), + account_number=self.get_option("accountNumber", options), ) diff --git a/lemur/plugins/lemur_aws/s3.py b/lemur/plugins/lemur_aws/s3.py index 2f8983e5..43faa28f 100644 --- a/lemur/plugins/lemur_aws/s3.py +++ b/lemur/plugins/lemur_aws/s3.py @@ -10,28 +10,26 @@ from flask import current_app from .sts import sts_client -@sts_client('s3', service_type='resource') +@sts_client("s3", service_type="resource") def put(bucket_name, region, prefix, data, encrypt, **kwargs): """ Use STS to write to an S3 bucket """ - bucket = kwargs['resource'].Bucket(bucket_name) - current_app.logger.debug('Persisting data to S3. Bucket: {0} Prefix: {1}'.format(bucket_name, prefix)) + bucket = kwargs["resource"].Bucket(bucket_name) + current_app.logger.debug( + "Persisting data to S3. Bucket: {0} Prefix: {1}".format(bucket_name, prefix) + ) # get data ready for writing if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") if encrypt: bucket.put_object( Key=prefix, Body=data, - ACL='bucket-owner-full-control', - ServerSideEncryption='AES256' + ACL="bucket-owner-full-control", + ServerSideEncryption="AES256", ) else: - bucket.put_object( - Key=prefix, - Body=data, - ACL='bucket-owner-full-control' - ) + bucket.put_object(Key=prefix, Body=data, ACL="bucket-owner-full-control") diff --git a/lemur/plugins/lemur_aws/sts.py b/lemur/plugins/lemur_aws/sts.py index 001ea2c8..c1bd562c 100644 --- a/lemur/plugins/lemur_aws/sts.py +++ b/lemur/plugins/lemur_aws/sts.py @@ -9,40 +9,46 @@ from functools import wraps import boto3 +from botocore.config import Config from flask import current_app -def sts_client(service, service_type='client'): +config = Config(retries=dict(max_attempts=20)) + + +def sts_client(service, service_type="client"): def decorator(f): @wraps(f) def decorated_function(*args, **kwargs): - sts = boto3.client('sts') - arn = 'arn:aws:iam::{0}:role/{1}'.format( - kwargs.pop('account_number'), - current_app.config.get('LEMUR_INSTANCE_PROFILE', 'Lemur') + sts = boto3.client("sts", config=config) + arn = "arn:aws:iam::{0}:role/{1}".format( + kwargs.pop("account_number"), + current_app.config.get("LEMUR_INSTANCE_PROFILE", "Lemur"), ) # TODO add user specific information to RoleSessionName - role = sts.assume_role(RoleArn=arn, RoleSessionName='lemur') + role = sts.assume_role(RoleArn=arn, RoleSessionName="lemur") - if service_type == 'client': + if service_type == "client": client = boto3.client( service, - region_name=kwargs.pop('region', 'us-east-1'), - aws_access_key_id=role['Credentials']['AccessKeyId'], - aws_secret_access_key=role['Credentials']['SecretAccessKey'], - aws_session_token=role['Credentials']['SessionToken'] + region_name=kwargs.pop("region", "us-east-1"), + aws_access_key_id=role["Credentials"]["AccessKeyId"], + aws_secret_access_key=role["Credentials"]["SecretAccessKey"], + aws_session_token=role["Credentials"]["SessionToken"], + config=config, ) - kwargs['client'] = client - elif service_type == 'resource': + kwargs["client"] = client + elif service_type == "resource": resource = boto3.resource( service, - region_name=kwargs.pop('region', 'us-east-1'), - aws_access_key_id=role['Credentials']['AccessKeyId'], - aws_secret_access_key=role['Credentials']['SecretAccessKey'], - aws_session_token=role['Credentials']['SessionToken'] + region_name=kwargs.pop("region", "us-east-1"), + aws_access_key_id=role["Credentials"]["AccessKeyId"], + aws_secret_access_key=role["Credentials"]["SecretAccessKey"], + aws_session_token=role["Credentials"]["SessionToken"], + config=config, ) - kwargs['resource'] = resource + kwargs["resource"] = resource return f(*args, **kwargs) return decorated_function diff --git a/lemur/plugins/lemur_aws/tests/test_elb.py b/lemur/plugins/lemur_aws/tests/test_elb.py index 7facc4dd..4571b87a 100644 --- a/lemur/plugins/lemur_aws/tests/test_elb.py +++ b/lemur/plugins/lemur_aws/tests/test_elb.py @@ -6,23 +6,24 @@ from moto import mock_sts, mock_elb @mock_elb() def test_get_all_elbs(app, aws_credentials): from lemur.plugins.lemur_aws.elb import get_all_elbs - client = boto3.client('elb', region_name='us-east-1') - elbs = get_all_elbs(account_number='123456789012', region='us-east-1') + client = boto3.client("elb", region_name="us-east-1") + + elbs = get_all_elbs(account_number="123456789012", region="us-east-1") assert not elbs client.create_load_balancer( - LoadBalancerName='example-lb', + LoadBalancerName="example-lb", Listeners=[ { - 'Protocol': 'string', - 'LoadBalancerPort': 443, - 'InstanceProtocol': 'tcp', - 'InstancePort': 5443, - 'SSLCertificateId': 'tcp' + "Protocol": "string", + "LoadBalancerPort": 443, + "InstanceProtocol": "tcp", + "InstancePort": 5443, + "SSLCertificateId": "tcp", } - ] + ], ) - elbs = get_all_elbs(account_number='123456789012', region='us-east-1') + elbs = get_all_elbs(account_number="123456789012", region="us-east-1") assert elbs diff --git a/lemur/plugins/lemur_aws/tests/test_iam.py b/lemur/plugins/lemur_aws/tests/test_iam.py index deec221e..5932d52d 100644 --- a/lemur/plugins/lemur_aws/tests/test_iam.py +++ b/lemur/plugins/lemur_aws/tests/test_iam.py @@ -6,15 +6,21 @@ from lemur.tests.vectors import EXTERNAL_VALID_STR, SAN_CERT_KEY def test_get_name_from_arn(): from lemur.plugins.lemur_aws.iam import get_name_from_arn - arn = 'arn:aws:iam::123456789012:server-certificate/tttt2.netflixtest.net-NetflixInc-20150624-20150625' - assert get_name_from_arn(arn) == 'tttt2.netflixtest.net-NetflixInc-20150624-20150625' + + arn = "arn:aws:iam::123456789012:server-certificate/tttt2.netflixtest.net-NetflixInc-20150624-20150625" + assert ( + get_name_from_arn(arn) == "tttt2.netflixtest.net-NetflixInc-20150624-20150625" + ) -@pytest.mark.skipif(True, reason="this fails because moto is not currently returning what boto does") +@pytest.mark.skipif( + True, reason="this fails because moto is not currently returning what boto does" +) @mock_sts() @mock_iam() def test_get_all_server_certs(app): from lemur.plugins.lemur_aws.iam import upload_cert, get_all_certificates - upload_cert('123456789012', 'testCert', EXTERNAL_VALID_STR, SAN_CERT_KEY) - certs = get_all_certificates('123456789012') + + upload_cert("123456789012", "testCert", EXTERNAL_VALID_STR, SAN_CERT_KEY) + certs = get_all_certificates("123456789012") assert len(certs) == 1 diff --git a/lemur/plugins/lemur_aws/tests/test_plugin.py b/lemur/plugins/lemur_aws/tests/test_plugin.py index 95e4c9a4..dbad7b02 100644 --- a/lemur/plugins/lemur_aws/tests/test_plugin.py +++ b/lemur/plugins/lemur_aws/tests/test_plugin.py @@ -1,6 +1,5 @@ - def test_get_certificates(app): from lemur.plugins.base import plugins - p = plugins.get('aws-s3') + p = plugins.get("aws-s3") assert p diff --git a/lemur/plugins/lemur_cfssl/__init__.py b/lemur/plugins/lemur_cfssl/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_cfssl/__init__.py +++ b/lemur/plugins/lemur_cfssl/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_cfssl/plugin.py b/lemur/plugins/lemur_cfssl/plugin.py index 030f290a..02f3159d 100644 --- a/lemur/plugins/lemur_cfssl/plugin.py +++ b/lemur/plugins/lemur_cfssl/plugin.py @@ -10,6 +10,9 @@ import json import requests +import base64 +import hmac +import hashlib from flask import current_app @@ -21,13 +24,13 @@ from lemur.extensions import metrics class CfsslIssuerPlugin(IssuerPlugin): - title = 'CFSSL' - slug = 'cfssl-issuer' - description = 'Enables the creation of certificates by CFSSL private CA' + title = "CFSSL" + slug = "cfssl-issuer" + description = "Enables the creation of certificates by CFSSL private CA" version = cfssl.VERSION - author = 'Charles Hendrie' - author_url = 'https://github.com/netflix/lemur.git' + author = "Charles Hendrie" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): self.session = requests.Session() @@ -41,23 +44,51 @@ class CfsslIssuerPlugin(IssuerPlugin): :param issuer_options: :return: """ - current_app.logger.info("Requesting a new cfssl certificate with csr: {0}".format(csr)) + current_app.logger.info( + "Requesting a new cfssl certificate with csr: {0}".format(csr) + ) - url = "{0}{1}".format(current_app.config.get('CFSSL_URL'), '/api/v1/cfssl/sign') + url = "{0}{1}".format(current_app.config.get("CFSSL_URL"), "/api/v1/cfssl/sign") - data = {'certificate_request': csr} + data = {"certificate_request": csr} data = json.dumps(data) - response = self.session.post(url, data=data.encode(encoding='utf_8', errors='strict')) + try: + hex_key = current_app.config.get("CFSSL_KEY") + key = bytes.fromhex(hex_key) + except (ValueError, NameError, TypeError): + # unable to find CFSSL_KEY in config, continue using normal sign method + pass + else: + data = data.encode() + + token = base64.b64encode( + hmac.new(key, data, digestmod=hashlib.sha256).digest() + ) + data = base64.b64encode(data) + + data = json.dumps( + {"token": token.decode("utf-8"), "request": data.decode("utf-8")} + ) + + url = "{0}{1}".format( + current_app.config.get("CFSSL_URL"), "/api/v1/cfssl/authsign" + ) + response = self.session.post( + url, data=data.encode(encoding="utf_8", errors="strict") + ) if response.status_code > 399: - metrics.send('cfssl_create_certificate_failure', 'counter', 1) - raise Exception( - "Error creating cert. Please check your CFSSL API server") - response_json = json.loads(response.content.decode('utf_8')) - cert = response_json['result']['certificate'] + metrics.send("cfssl_create_certificate_failure", "counter", 1) + raise Exception("Error creating cert. Please check your CFSSL API server") + response_json = json.loads(response.content.decode("utf_8")) + cert = response_json["result"]["certificate"] parsed_cert = parse_certificate(cert) - metrics.send('cfssl_create_certificate_success', 'counter', 1) - return cert, current_app.config.get('CFSSL_INTERMEDIATE'), parsed_cert.serial_number + metrics.send("cfssl_create_certificate_success", "counter", 1) + return ( + cert, + current_app.config.get("CFSSL_INTERMEDIATE"), + parsed_cert.serial_number, + ) @staticmethod def create_authority(options): @@ -68,22 +99,26 @@ class CfsslIssuerPlugin(IssuerPlugin): :param options: :return: """ - role = {'username': '', 'password': '', 'name': 'cfssl'} - return current_app.config.get('CFSSL_ROOT'), "", [role] + role = {"username": "", "password": "", "name": "cfssl"} + return current_app.config.get("CFSSL_ROOT"), "", [role] def revoke_certificate(self, certificate, comments): """Revoke a CFSSL certificate.""" - base_url = current_app.config.get('CFSSL_URL') - create_url = '{0}/api/v1/cfssl/revoke'.format(base_url) - data = '{"serial": "' + certificate.external_id + '","authority_key_id": "' + \ - get_authority_key(certificate.body) + \ - '", "reason": "superseded"}' + base_url = current_app.config.get("CFSSL_URL") + create_url = "{0}/api/v1/cfssl/revoke".format(base_url) + data = ( + '{"serial": "' + + certificate.external_id + + '","authority_key_id": "' + + get_authority_key(certificate.body) + + '", "reason": "superseded"}' + ) current_app.logger.debug("Revoking cert: {0}".format(data)) response = self.session.post( - create_url, data=data.encode(encoding='utf_8', errors='strict')) + create_url, data=data.encode(encoding="utf_8", errors="strict") + ) if response.status_code > 399: - metrics.send('cfssl_revoke_certificate_failure', 'counter', 1) - raise Exception( - "Error revoking cert. Please check your CFSSL API server") - metrics.send('cfssl_revoke_certificate_success', 'counter', 1) + metrics.send("cfssl_revoke_certificate_failure", "counter", 1) + raise Exception("Error revoking cert. Please check your CFSSL API server") + metrics.send("cfssl_revoke_certificate_success", "counter", 1) return response.json() diff --git a/lemur/plugins/lemur_cfssl/tests/test_cfssl.py b/lemur/plugins/lemur_cfssl/tests/test_cfssl.py index ea8f0856..10fb9963 100644 --- a/lemur/plugins/lemur_cfssl/tests/test_cfssl.py +++ b/lemur/plugins/lemur_cfssl/tests/test_cfssl.py @@ -1,6 +1,5 @@ - def test_get_certificates(app): from lemur.plugins.base import plugins - p = plugins.get('cfssl-issuer') + p = plugins.get("cfssl-issuer") assert p diff --git a/lemur/plugins/lemur_cryptography/__init__.py b/lemur/plugins/lemur_cryptography/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_cryptography/__init__.py +++ b/lemur/plugins/lemur_cryptography/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_cryptography/plugin.py b/lemur/plugins/lemur_cryptography/plugin.py index fe9d7bb3..005f36f9 100644 --- a/lemur/plugins/lemur_cryptography/plugin.py +++ b/lemur/plugins/lemur_cryptography/plugin.py @@ -14,6 +14,7 @@ from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization +from lemur.common.utils import parse_private_key from lemur.plugins.bases import IssuerPlugin from lemur.plugins import lemur_cryptography as cryptography_issuer @@ -21,7 +22,7 @@ from lemur.certificates.service import create_csr def build_certificate_authority(options): - options['certificate_authority'] = True + options["certificate_authority"] = True csr, private_key = create_csr(**options) cert_pem, chain_cert_pem = issue_certificate(csr, options, private_key) @@ -29,35 +30,43 @@ def build_certificate_authority(options): def issue_certificate(csr, options, private_key=None): - csr = x509.load_pem_x509_csr(csr.encode('utf-8'), default_backend()) + csr = x509.load_pem_x509_csr(csr.encode("utf-8"), default_backend()) if options.get("parent"): # creating intermediate authorities will have options['parent'] to specify the issuer # creating certificates will have options['authority'] to specify the issuer # This works around that by making sure options['authority'] can be referenced for either - options['authority'] = options['parent'] + options["authority"] = options["parent"] if options.get("authority"): # Issue certificate signed by an existing lemur_certificates authority - issuer_subject = options['authority'].authority_certificate.subject - issuer_private_key = options['authority'].authority_certificate.private_key - chain_cert_pem = options['authority'].authority_certificate.body - authority_key_identifier_public = options['authority'].authority_certificate.public_key - authority_key_identifier_subject = x509.SubjectKeyIdentifier.from_public_key(authority_key_identifier_public) + issuer_subject = options["authority"].authority_certificate.subject + assert ( + private_key is None + ), "Private would be ignored, authority key used instead" + private_key = options["authority"].authority_certificate.private_key + chain_cert_pem = options["authority"].authority_certificate.body + authority_key_identifier_public = options[ + "authority" + ].authority_certificate.public_key + authority_key_identifier_subject = x509.SubjectKeyIdentifier.from_public_key( + authority_key_identifier_public + ) authority_key_identifier_issuer = issuer_subject - authority_key_identifier_serial = int(options['authority'].authority_certificate.serial) + authority_key_identifier_serial = int( + options["authority"].authority_certificate.serial + ) # TODO figure out a better way to increment serial # New authorities have a value at options['serial_number'] that is being ignored here. serial = int(uuid.uuid4()) else: # Issue certificate that is self-signed (new lemur_certificates root authority) issuer_subject = csr.subject - issuer_private_key = private_key chain_cert_pem = "" authority_key_identifier_public = csr.public_key() authority_key_identifier_subject = None authority_key_identifier_issuer = csr.subject - authority_key_identifier_serial = options['serial_number'] + authority_key_identifier_serial = options["serial_number"] # TODO figure out a better way to increment serial serial = int(uuid.uuid4()) @@ -67,19 +76,20 @@ def issue_certificate(csr, options, private_key=None): issuer_name=issuer_subject, subject_name=csr.subject, public_key=csr.public_key(), - not_valid_before=options['validity_start'], - not_valid_after=options['validity_end'], + not_valid_before=options["validity_start"], + not_valid_after=options["validity_end"], serial_number=serial, - extensions=extensions) + extensions=extensions, + ) - for k, v in options.get('extensions', {}).items(): - if k == 'authority_key_identifier': + for k, v in options.get("extensions", {}).items(): + if k == "authority_key_identifier": # One or both of these options may be present inside the aki extension (authority_key_identifier, authority_identifier) = (False, False) for k2, v2 in v.items(): - if k2 == 'use_key_identifier' and v2: + if k2 == "use_key_identifier" and v2: authority_key_identifier = True - if k2 == 'use_authority_cert' and v2: + if k2 == "use_authority_cert" and v2: authority_identifier = True if authority_key_identifier: if authority_key_identifier_subject: @@ -88,13 +98,21 @@ def issue_certificate(csr, options, private_key=None): # but the digest of the ski is at just ski.digest. Until that library is fixed, # this function won't work. The second line has the same result. # aki = x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(authority_key_identifier_subject) - aki = x509.AuthorityKeyIdentifier(authority_key_identifier_subject.digest, None, None) + aki = x509.AuthorityKeyIdentifier( + authority_key_identifier_subject.digest, None, None + ) else: - aki = x509.AuthorityKeyIdentifier.from_issuer_public_key(authority_key_identifier_public) + aki = x509.AuthorityKeyIdentifier.from_issuer_public_key( + authority_key_identifier_public + ) elif authority_identifier: - aki = x509.AuthorityKeyIdentifier(None, [x509.DirectoryName(authority_key_identifier_issuer)], authority_key_identifier_serial) + aki = x509.AuthorityKeyIdentifier( + None, + [x509.DirectoryName(authority_key_identifier_issuer)], + authority_key_identifier_serial, + ) builder = builder.add_extension(aki, critical=False) - if k == 'certificate_info_access': + if k == "certificate_info_access": # FIXME: Implement the AuthorityInformationAccess extension # descriptions = [ # x509.AccessDescription(x509.oid.AuthorityInformationAccessOID.OCSP, x509.UniformResourceIdentifier(u"http://FIXME")), @@ -107,32 +125,32 @@ def issue_certificate(csr, options, private_key=None): # critical=False # ) pass - if k == 'crl_distribution_points': + if k == "crl_distribution_points": # FIXME: Implement the CRLDistributionPoints extension # FIXME: Not implemented in lemur/schemas.py yet https://github.com/Netflix/lemur/issues/662 pass - private_key = serialization.load_pem_private_key( - bytes(str(issuer_private_key).encode('utf-8')), - password=None, - backend=default_backend() - ) + private_key = parse_private_key(private_key) cert = builder.sign(private_key, hashes.SHA256(), default_backend()) - cert_pem = cert.public_bytes( - encoding=serialization.Encoding.PEM - ).decode('utf-8') + cert_pem = cert.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8") return cert_pem, chain_cert_pem def normalize_extensions(csr): try: - san_extension = csr.extensions.get_extension_for_oid(x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME) + san_extension = csr.extensions.get_extension_for_oid( + x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ) san_dnsnames = san_extension.value.get_values_for_type(x509.DNSName) except x509.extensions.ExtensionNotFound: san_dnsnames = [] - san_extension = x509.Extension(x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME, True, x509.SubjectAlternativeName(san_dnsnames)) + san_extension = x509.Extension( + x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + True, + x509.SubjectAlternativeName(san_dnsnames), + ) common_name = csr.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME) common_name = common_name[0].value @@ -152,7 +170,11 @@ def normalize_extensions(csr): for san in san_extension.value: general_names.append(san) - san_extension = x509.Extension(x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME, True, x509.SubjectAlternativeName(general_names)) + san_extension = x509.Extension( + x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + True, + x509.SubjectAlternativeName(general_names), + ) # Remove original san extension from CSR and add new SAN extension extensions = list(filter(filter_san_extensions, csr.extensions._extensions)) @@ -169,13 +191,13 @@ def filter_san_extensions(ext): class CryptographyIssuerPlugin(IssuerPlugin): - title = 'Cryptography' - slug = 'cryptography-issuer' - description = 'Enables the creation and signing of self-signed certificates' + title = "Cryptography" + slug = "cryptography-issuer" + description = "Enables the creation and signing of self-signed certificates" version = cryptography_issuer.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def create_certificate(self, csr, options): """ @@ -185,7 +207,9 @@ class CryptographyIssuerPlugin(IssuerPlugin): :param options: :return: :raise Exception: """ - current_app.logger.debug("Issuing new cryptography certificate with options: {0}".format(options)) + current_app.logger.debug( + "Issuing new cryptography certificate with options: {0}".format(options) + ) cert_pem, chain_cert_pem = issue_certificate(csr, options) return cert_pem, chain_cert_pem, None @@ -198,10 +222,12 @@ class CryptographyIssuerPlugin(IssuerPlugin): :param options: :return: """ - current_app.logger.debug("Issuing new cryptography authority with options: {0}".format(options)) + current_app.logger.debug( + "Issuing new cryptography authority with options: {0}".format(options) + ) cert_pem, private_key, chain_cert_pem = build_certificate_authority(options) roles = [ - {'username': '', 'password': '', 'name': options['name'] + '_admin'}, - {'username': '', 'password': '', 'name': options['name'] + '_operator'} + {"username": "", "password": "", "name": options["name"] + "_admin"}, + {"username": "", "password": "", "name": options["name"] + "_operator"}, ] return cert_pem, private_key, chain_cert_pem, roles diff --git a/lemur/plugins/lemur_cryptography/tests/test_cryptography.py b/lemur/plugins/lemur_cryptography/tests/test_cryptography.py index 8a81bf6c..7f1777fc 100644 --- a/lemur/plugins/lemur_cryptography/tests/test_cryptography.py +++ b/lemur/plugins/lemur_cryptography/tests/test_cryptography.py @@ -5,24 +5,24 @@ def test_build_certificate_authority(): from lemur.plugins.lemur_cryptography.plugin import build_certificate_authority options = { - 'key_type': 'RSA2048', - 'country': 'US', - 'state': 'CA', - 'location': 'Example place', - 'organization': 'Example, Inc.', - 'organizational_unit': 'Example Unit', - 'common_name': 'Example ROOT', - 'validity_start': arrow.get('2016-12-01').datetime, - 'validity_end': arrow.get('2016-12-02').datetime, - 'first_serial': 1, - 'serial_number': 1, - 'owner': 'owner@example.com' + "key_type": "RSA2048", + "country": "US", + "state": "CA", + "location": "Example place", + "organization": "Example, Inc.", + "organizational_unit": "Example Unit", + "common_name": "Example ROOT", + "validity_start": arrow.get("2016-12-01").datetime, + "validity_end": arrow.get("2016-12-02").datetime, + "first_serial": 1, + "serial_number": 1, + "owner": "owner@example.com", } cert_pem, private_key_pem, chain_cert_pem = build_certificate_authority(options) assert cert_pem assert private_key_pem - assert chain_cert_pem == '' + assert chain_cert_pem == "" def test_issue_certificate(authority): @@ -30,10 +30,10 @@ def test_issue_certificate(authority): from lemur.plugins.lemur_cryptography.plugin import issue_certificate options = { - 'common_name': 'Example.com', - 'authority': authority, - 'validity_start': arrow.get('2016-12-01').datetime, - 'validity_end': arrow.get('2016-12-02').datetime + "common_name": "Example.com", + "authority": authority, + "validity_start": arrow.get("2016-12-01").datetime, + "validity_end": arrow.get("2016-12-02").datetime, } cert_pem, chain_cert_pem = issue_certificate(CSR_STR, options) assert cert_pem diff --git a/lemur/plugins/lemur_csr/__init__.py b/lemur/plugins/lemur_csr/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_csr/__init__.py +++ b/lemur/plugins/lemur_csr/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_csr/plugin.py b/lemur/plugins/lemur_csr/plugin.py index e06035d1..776dfce5 100644 --- a/lemur/plugins/lemur_csr/plugin.py +++ b/lemur/plugins/lemur_csr/plugin.py @@ -38,48 +38,35 @@ def create_csr(cert, chain, csr_tmp, key): :param csr_tmp: :param key: """ - if isinstance(cert, bytes): - cert = cert.decode('utf-8') - - if isinstance(chain, bytes): - chain = chain.decode('utf-8') - - if isinstance(key, bytes): - key = key.decode('utf-8') + assert isinstance(cert, str) + assert isinstance(chain, str) + assert isinstance(key, str) with mktempfile() as key_tmp: - with open(key_tmp, 'w') as f: + with open(key_tmp, "w") as f: f.write(key) with mktempfile() as cert_tmp: - with open(cert_tmp, 'w') as f: + with open(cert_tmp, "w") as f: if chain: f.writelines([cert.strip() + "\n", chain.strip() + "\n"]) else: f.writelines([cert.strip() + "\n"]) - output = subprocess.check_output([ - "openssl", - "x509", - "-x509toreq", - "-in", cert_tmp, - "-signkey", key_tmp, - ]) - subprocess.run([ - "openssl", - "req", - "-out", csr_tmp - ], input=output) + output = subprocess.check_output( + ["openssl", "x509", "-x509toreq", "-in", cert_tmp, "-signkey", key_tmp] + ) + subprocess.run(["openssl", "req", "-out", csr_tmp], input=output) class CSRExportPlugin(ExportPlugin): - title = 'CSR' - slug = 'openssl-csr' - description = 'Exports a CSR' + title = "CSR" + slug = "openssl-csr" + description = "Exports a CSR" version = csr.VERSION - author = 'jchuong' - author_url = 'https://github.com/jchuong' + author = "jchuong" + author_url = "https://github.com/jchuong" def export(self, body, chain, key, options, **kwargs): """ @@ -98,7 +85,7 @@ class CSRExportPlugin(ExportPlugin): create_csr(body, chain, output_tmp, key) extension = "csr" - with open(output_tmp, 'rb') as f: + with open(output_tmp, "rb") as f: raw = f.read() # passphrase is None return extension, None, raw diff --git a/lemur/plugins/lemur_csr/tests/test_csr_export.py b/lemur/plugins/lemur_csr/tests/test_csr_export.py index 9b233a4e..0b55aefe 100644 --- a/lemur/plugins/lemur_csr/tests/test_csr_export.py +++ b/lemur/plugins/lemur_csr/tests/test_csr_export.py @@ -4,7 +4,8 @@ from lemur.tests.vectors import INTERNAL_PRIVATE_KEY_A_STR, INTERNAL_CERTIFICATE def test_export_certificate_to_csr(app): from lemur.plugins.base import plugins - p = plugins.get('openssl-csr') + + p = plugins.get("openssl-csr") options = [] with pytest.raises(Exception): p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) diff --git a/lemur/plugins/lemur_digicert/__init__.py b/lemur/plugins/lemur_digicert/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_digicert/__init__.py +++ b/lemur/plugins/lemur_digicert/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_digicert/plugin.py b/lemur/plugins/lemur_digicert/plugin.py index 619b24e7..88ea5b6b 100644 --- a/lemur/plugins/lemur_digicert/plugin.py +++ b/lemur/plugins/lemur_digicert/plugin.py @@ -40,7 +40,7 @@ def log_status_code(r, *args, **kwargs): :param kwargs: :return: """ - metrics.send('digicert_status_code_{}'.format(r.status_code), 'counter', 1) + metrics.send("digicert_status_code_{}".format(r.status_code), "counter", 1) def signature_hash(signing_algorithm): @@ -50,18 +50,18 @@ def signature_hash(signing_algorithm): :return: str digicert specific algorithm string """ if not signing_algorithm: - return current_app.config.get('DIGICERT_DEFAULT_SIGNING_ALGORITHM', 'sha256') + return current_app.config.get("DIGICERT_DEFAULT_SIGNING_ALGORITHM", "sha256") - if signing_algorithm == 'sha256WithRSA': - return 'sha256' + if signing_algorithm == "sha256WithRSA": + return "sha256" - elif signing_algorithm == 'sha384WithRSA': - return 'sha384' + elif signing_algorithm == "sha384WithRSA": + return "sha384" - elif signing_algorithm == 'sha512WithRSA': - return 'sha512' + elif signing_algorithm == "sha512WithRSA": + return "sha512" - raise Exception('Unsupported signing algorithm.') + raise Exception("Unsupported signing algorithm.") def determine_validity_years(end_date): @@ -72,15 +72,16 @@ def determine_validity_years(end_date): """ now = arrow.utcnow() - if end_date < now.replace(years=+1): + if end_date < now.shift(years=+1): return 1 - elif end_date < now.replace(years=+2): + elif end_date < now.shift(years=+2): return 2 - elif end_date < now.replace(years=+3): + elif end_date < now.shift(years=+3): return 3 - raise Exception("DigiCert issued certificates cannot exceed three" - " years in validity") + raise Exception( + "DigiCert issued certificates cannot exceed three" " years in validity" + ) def get_additional_names(options): @@ -92,8 +93,8 @@ def get_additional_names(options): """ names = [] # add SANs if present - if options.get('extensions'): - for san in options['extensions']['sub_alt_names']['names']: + if options.get("extensions"): + for san in options["extensions"]["sub_alt_names"]["names"]: if isinstance(san, x509.DNSName): names.append(san.value) return names @@ -106,31 +107,33 @@ def map_fields(options, csr): :param csr: :return: dict or valid DigiCert options """ - if not options.get('validity_years'): - if not options.get('validity_end'): - options['validity_years'] = current_app.config.get('DIGICERT_DEFAULT_VALIDITY', 1) + if not options.get("validity_years"): + if not options.get("validity_end"): + options["validity_years"] = current_app.config.get( + "DIGICERT_DEFAULT_VALIDITY", 1 + ) - data = dict(certificate={ - "common_name": options['common_name'], - "csr": csr, - "signature_hash": - signature_hash(options.get('signing_algorithm')), - }, organization={ - "id": current_app.config.get("DIGICERT_ORG_ID") - }) + data = dict( + certificate={ + "common_name": options["common_name"], + "csr": csr, + "signature_hash": signature_hash(options.get("signing_algorithm")), + }, + organization={"id": current_app.config.get("DIGICERT_ORG_ID")}, + ) - data['certificate']['dns_names'] = get_additional_names(options) + data["certificate"]["dns_names"] = get_additional_names(options) - if options.get('validity_years'): - data['validity_years'] = options['validity_years'] + if options.get("validity_years"): + data["validity_years"] = options["validity_years"] else: - data['custom_expiration_date'] = options['validity_end'].format('YYYY-MM-DD') + data["custom_expiration_date"] = options["validity_end"].format("YYYY-MM-DD") - if current_app.config.get('DIGICERT_PRIVATE', False): - if 'product' in data: - data['product']['type_hint'] = 'private' + if current_app.config.get("DIGICERT_PRIVATE", False): + if "product" in data: + data["product"]["type_hint"] = "private" else: - data['product'] = dict(type_hint='private') + data["product"] = dict(type_hint="private") return data @@ -143,27 +146,34 @@ def map_cis_fields(options, csr): :param csr: :return: """ - if not options.get('validity_years'): - if not options.get('validity_end'): - options['validity_end'] = arrow.utcnow().replace(years=current_app.config.get('DIGICERT_DEFAULT_VALIDITY', 1)) - options['validity_years'] = determine_validity_years(options['validity_end']) + if not options.get("validity_years"): + if not options.get("validity_end"): + options["validity_end"] = arrow.utcnow().shift( + years=current_app.config.get("DIGICERT_DEFAULT_VALIDITY", 1) + ) + options["validity_years"] = determine_validity_years(options["validity_end"]) else: - options['validity_end'] = arrow.utcnow().replace(years=options['validity_years']) + options["validity_end"] = arrow.utcnow().shift( + years=options["validity_years"] + ) data = { - "profile_name": current_app.config.get('DIGICERT_CIS_PROFILE_NAME'), - "common_name": options['common_name'], + "profile_name": current_app.config.get("DIGICERT_CIS_PROFILE_NAMES", {}).get(options['authority'].name), + "common_name": options["common_name"], "additional_dns_names": get_additional_names(options), "csr": csr, - "signature_hash": signature_hash(options.get('signing_algorithm')), + "signature_hash": signature_hash(options.get("signing_algorithm")), "validity": { - "valid_to": options['validity_end'].format('YYYY-MM-DDTHH:MM') + 'Z' + "valid_to": options["validity_end"].format("YYYY-MM-DDTHH:MM") + "Z" }, "organization": { - "name": options['organization'], - "units": [options['organizational_unit']] - } + "name": options["organization"], + "units": [options["organizational_unit"]], + }, } + # possibility to default to a SIGNING_ALGORITHM for a given profile + if current_app.config.get("DIGICERT_CIS_SIGNING_ALGORITHMS", {}).get(options['authority'].name): + data["signature_hash"] = current_app.config.get("DIGICERT_CIS_SIGNING_ALGORITHMS", {}).get(options['authority'].name) return data @@ -175,7 +185,7 @@ def handle_response(response): :return: """ if response.status_code > 399: - raise Exception(response.json()['errors'][0]['message']) + raise Exception(response.json()["errors"][0]["message"]) return response.json() @@ -187,7 +197,7 @@ def handle_cis_response(response): :return: """ if response.status_code > 399: - raise Exception(response.json()['errors'][0]['message']) + raise Exception(response.text) return response.json() @@ -197,19 +207,17 @@ def get_certificate_id(session, base_url, order_id): """Retrieve certificate order id from Digicert API.""" order_url = "{0}/services/v2/order/certificate/{1}".format(base_url, order_id) response_data = handle_response(session.get(order_url)) - if response_data['status'] != 'issued': + if response_data["status"] != "issued": raise Exception("Order not in issued state.") - return response_data['certificate']['id'] + return response_data["certificate"]["id"] @retry(stop_max_attempt_number=10, wait_fixed=10000) def get_cis_certificate(session, base_url, order_id): """Retrieve certificate order id from Digicert API.""" - certificate_url = '{0}/platform/cis/certificate/{1}'.format(base_url, order_id) - session.headers.update( - {'Accept': 'application/x-pem-file'} - ) + certificate_url = "{0}/platform/cis/certificate/{1}".format(base_url, order_id) + session.headers.update({"Accept": "application/x-pem-file"}) response = session.get(certificate_url) if response.status_code == 404: @@ -220,29 +228,30 @@ def get_cis_certificate(session, base_url, order_id): class DigiCertSourcePlugin(SourcePlugin): """Wrap the Digicert Certifcate API.""" - title = 'DigiCert' - slug = 'digicert-source' + + title = "DigiCert" + slug = "digicert-source" description = "Enables the use of Digicert as a source of existing certificates." version = digicert.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): """Initialize source with appropriate details.""" required_vars = [ - 'DIGICERT_API_KEY', - 'DIGICERT_URL', - 'DIGICERT_ORG_ID', - 'DIGICERT_ROOT', + "DIGICERT_API_KEY", + "DIGICERT_URL", + "DIGICERT_ORG_ID", + "DIGICERT_ROOT", ] validate_conf(current_app, required_vars) self.session = requests.Session() self.session.headers.update( { - 'X-DC-DEVKEY': current_app.config['DIGICERT_API_KEY'], - 'Content-Type': 'application/json' + "X-DC-DEVKEY": current_app.config["DIGICERT_API_KEY"], + "Content-Type": "application/json", } ) @@ -256,22 +265,23 @@ class DigiCertSourcePlugin(SourcePlugin): class DigiCertIssuerPlugin(IssuerPlugin): """Wrap the Digicert Issuer API.""" - title = 'DigiCert' - slug = 'digicert-issuer' + + title = "DigiCert" + slug = "digicert-issuer" description = "Enables the creation of certificates by the DigiCert REST API." version = digicert.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): """Initialize the issuer with the appropriate details.""" required_vars = [ - 'DIGICERT_API_KEY', - 'DIGICERT_URL', - 'DIGICERT_ORG_ID', - 'DIGICERT_ORDER_TYPE', - 'DIGICERT_ROOT', + "DIGICERT_API_KEY", + "DIGICERT_URL", + "DIGICERT_ORG_ID", + "DIGICERT_ORDER_TYPE", + "DIGICERT_ROOT", ] validate_conf(current_app, required_vars) @@ -279,8 +289,8 @@ class DigiCertIssuerPlugin(IssuerPlugin): self.session = requests.Session() self.session.headers.update( { - 'X-DC-DEVKEY': current_app.config['DIGICERT_API_KEY'], - 'Content-Type': 'application/json' + "X-DC-DEVKEY": current_app.config["DIGICERT_API_KEY"], + "Content-Type": "application/json", } ) @@ -295,69 +305,93 @@ class DigiCertIssuerPlugin(IssuerPlugin): :param issuer_options: :return: :raise Exception: """ - base_url = current_app.config.get('DIGICERT_URL') - cert_type = current_app.config.get('DIGICERT_ORDER_TYPE') + base_url = current_app.config.get("DIGICERT_URL") + cert_type = current_app.config.get("DIGICERT_ORDER_TYPE") # make certificate request - determinator_url = "{0}/services/v2/order/certificate/{1}".format(base_url, cert_type) + determinator_url = "{0}/services/v2/order/certificate/{1}".format( + base_url, cert_type + ) data = map_fields(issuer_options, csr) response = self.session.post(determinator_url, data=json.dumps(data)) if response.status_code > 399: - raise Exception(response.json()['errors'][0]['message']) + raise Exception(response.json()["errors"][0]["message"]) - order_id = response.json()['id'] + order_id = response.json()["id"] certificate_id = get_certificate_id(self.session, base_url, order_id) # retrieve certificate - certificate_url = "{0}/services/v2/certificate/{1}/download/format/pem_all".format(base_url, certificate_id) - end_entity, intermediate, root = pem.parse(self.session.get(certificate_url).content) - return "\n".join(str(end_entity).splitlines()), "\n".join(str(intermediate).splitlines()), certificate_id + certificate_url = "{0}/services/v2/certificate/{1}/download/format/pem_all".format( + base_url, certificate_id + ) + end_entity, intermediate, root = pem.parse( + self.session.get(certificate_url).content + ) + return ( + "\n".join(str(end_entity).splitlines()), + "\n".join(str(intermediate).splitlines()), + certificate_id, + ) def revoke_certificate(self, certificate, comments): """Revoke a Digicert certificate.""" - base_url = current_app.config.get('DIGICERT_URL') + base_url = current_app.config.get("DIGICERT_URL") # make certificate revoke request - create_url = '{0}/services/v2/certificate/{1}/revoke'.format(base_url, certificate.external_id) - metrics.send('digicert_revoke_certificate', 'counter', 1) - response = self.session.put(create_url, data=json.dumps({'comments': comments})) + create_url = "{0}/services/v2/certificate/{1}/revoke".format( + base_url, certificate.external_id + ) + metrics.send("digicert_revoke_certificate", "counter", 1) + response = self.session.put(create_url, data=json.dumps({"comments": comments})) return handle_response(response) def get_ordered_certificate(self, pending_cert): """ Retrieve a certificate via order id """ order_id = pending_cert.external_id - base_url = current_app.config.get('DIGICERT_URL') + base_url = current_app.config.get("DIGICERT_URL") try: certificate_id = get_certificate_id(self.session, base_url, order_id) except Exception as ex: return None - certificate_url = "{0}/services/v2/certificate/{1}/download/format/pem_all".format(base_url, certificate_id) - end_entity, intermediate, root = pem.parse(self.session.get(certificate_url).content) - cert = {'body': "\n".join(str(end_entity).splitlines()), - 'chain': "\n".join(str(intermediate).splitlines()), - 'external_id': str(certificate_id)} + certificate_url = "{0}/services/v2/certificate/{1}/download/format/pem_all".format( + base_url, certificate_id + ) + end_entity, intermediate, root = pem.parse( + self.session.get(certificate_url).content + ) + cert = { + "body": "\n".join(str(end_entity).splitlines()), + "chain": "\n".join(str(intermediate).splitlines()), + "external_id": str(certificate_id), + } return cert def cancel_ordered_certificate(self, pending_cert, **kwargs): """ Set the certificate order to canceled """ - base_url = current_app.config.get('DIGICERT_URL') - api_url = "{0}/services/v2/order/certificate/{1}/status".format(base_url, pending_cert.external_id) - payload = { - 'status': 'CANCELED', - 'note': kwargs.get('note') - } + base_url = current_app.config.get("DIGICERT_URL") + api_url = "{0}/services/v2/order/certificate/{1}/status".format( + base_url, pending_cert.external_id + ) + payload = {"status": "CANCELED", "note": kwargs.get("note")} response = self.session.put(api_url, data=json.dumps(payload)) if response.status_code == 404: # not well documented by Digicert, but either the certificate does not exist or we # don't own that order (someone else's order id!). Either way, we can just ignore it # and have it removed from Lemur current_app.logger.warning( - "Digicert Plugin tried to cancel pending certificate {0} but it does not exist!".format(pending_cert.name)) + "Digicert Plugin tried to cancel pending certificate {0} but it does not exist!".format( + pending_cert.name + ) + ) elif response.status_code != 204: - current_app.logger.debug("{0} code {1}".format(response.status_code, response.content)) - raise Exception("Failed to cancel pending certificate {0}".format(pending_cert.name)) + current_app.logger.debug( + "{0} code {1}".format(response.status_code, response.content) + ) + raise Exception( + "Failed to cancel pending certificate {0}".format(pending_cert.name) + ) @staticmethod def create_authority(options): @@ -370,72 +404,81 @@ class DigiCertIssuerPlugin(IssuerPlugin): :param options: :return: """ - role = {'username': '', 'password': '', 'name': 'digicert'} - return current_app.config.get('DIGICERT_ROOT'), "", [role] + role = {"username": "", "password": "", "name": "digicert"} + return current_app.config.get("DIGICERT_ROOT"), "", [role] class DigiCertCISSourcePlugin(SourcePlugin): """Wrap the Digicert CIS Certifcate API.""" - title = 'DigiCert' - slug = 'digicert-cis-source' + + title = "DigiCert" + slug = "digicert-cis-source" description = "Enables the use of Digicert as a source of existing certificates." version = digicert.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" additional_options = [] def __init__(self, *args, **kwargs): """Initialize source with appropriate details.""" required_vars = [ - 'DIGICERT_CIS_API_KEY', - 'DIGICERT_CIS_URL', - 'DIGICERT_CIS_ROOT', - 'DIGICERT_CIS_INTERMEDIATE', - 'DIGICERT_CIS_PROFILE_NAME' + "DIGICERT_CIS_API_KEY", + "DIGICERT_CIS_URL", + "DIGICERT_CIS_ROOTS", + "DIGICERT_CIS_INTERMEDIATES", + "DIGICERT_CIS_PROFILE_NAMES", ] validate_conf(current_app, required_vars) self.session = requests.Session() self.session.headers.update( { - 'X-DC-DEVKEY': current_app.config['DIGICERT_CIS_API_KEY'], - 'Content-Type': 'application/json' + "X-DC-DEVKEY": current_app.config["DIGICERT_CIS_API_KEY"], + "Content-Type": "application/json", } ) self.session.hooks = dict(response=log_status_code) a = requests.adapters.HTTPAdapter(max_retries=3) - self.session.mount('https://', a) + self.session.mount("https://", a) super(DigiCertCISSourcePlugin, self).__init__(*args, **kwargs) def get_certificates(self, options, **kwargs): """Fetch all Digicert certificates.""" - base_url = current_app.config.get('DIGICERT_CIS_URL') + base_url = current_app.config.get("DIGICERT_CIS_URL") # make request - search_url = '{0}/platform/cis/certificate/search'.format(base_url) + search_url = "{0}/platform/cis/certificate/search".format(base_url) certs = [] page = 1 while True: - response = self.session.get(search_url, params={'status': ['issued'], 'page': page}) + response = self.session.get( + search_url, params={"status": ["issued"], "page": page} + ) data = handle_cis_response(response) - for c in data['certificates']: - download_url = '{0}/platform/cis/certificate/{1}'.format(base_url, c['id']) + for c in data["certificates"]: + download_url = "{0}/platform/cis/certificate/{1}".format( + base_url, c["id"] + ) certificate = self.session.get(download_url) # normalize serial - serial = str(int(c['serial_number'], 16)) - cert = {'body': certificate.content, 'serial': serial, 'external_id': c['id']} + serial = str(int(c["serial_number"], 16)) + cert = { + "body": certificate.content, + "serial": serial, + "external_id": c["id"], + } certs.append(cert) - if page == data['total_pages']: + if page == data["total_pages"]: break page += 1 @@ -444,22 +487,23 @@ class DigiCertCISSourcePlugin(SourcePlugin): class DigiCertCISIssuerPlugin(IssuerPlugin): """Wrap the Digicert Certificate Issuing API.""" - title = 'DigiCert CIS' - slug = 'digicert-cis-issuer' + + title = "DigiCert CIS" + slug = "digicert-cis-issuer" description = "Enables the creation of certificates by the DigiCert CIS REST API." version = digicert.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): """Initialize the issuer with the appropriate details.""" required_vars = [ - 'DIGICERT_CIS_API_KEY', - 'DIGICERT_CIS_URL', - 'DIGICERT_CIS_ROOT', - 'DIGICERT_CIS_INTERMEDIATE', - 'DIGICERT_CIS_PROFILE_NAME' + "DIGICERT_CIS_API_KEY", + "DIGICERT_CIS_URL", + "DIGICERT_CIS_ROOTS", + "DIGICERT_CIS_INTERMEDIATES", + "DIGICERT_CIS_PROFILE_NAMES", ] validate_conf(current_app, required_vars) @@ -467,8 +511,8 @@ class DigiCertCISIssuerPlugin(IssuerPlugin): self.session = requests.Session() self.session.headers.update( { - 'X-DC-DEVKEY': current_app.config['DIGICERT_CIS_API_KEY'], - 'Content-Type': 'application/json' + "X-DC-DEVKEY": current_app.config["DIGICERT_CIS_API_KEY"], + "Content-Type": "application/json", } ) @@ -478,41 +522,51 @@ class DigiCertCISIssuerPlugin(IssuerPlugin): def create_certificate(self, csr, issuer_options): """Create a DigiCert certificate.""" - base_url = current_app.config.get('DIGICERT_CIS_URL') + base_url = current_app.config.get("DIGICERT_CIS_URL") # make certificate request - create_url = '{0}/platform/cis/certificate'.format(base_url) + create_url = "{0}/platform/cis/certificate".format(base_url) data = map_cis_fields(issuer_options, csr) response = self.session.post(create_url, data=json.dumps(data)) data = handle_cis_response(response) # retrieve certificate - certificate_pem = get_cis_certificate(self.session, base_url, data['id']) + certificate_pem = get_cis_certificate(self.session, base_url, data["id"]) - self.session.headers.pop('Accept') + self.session.headers.pop("Accept") end_entity = pem.parse(certificate_pem)[0] - if 'ECC' in issuer_options['key_type']: - return "\n".join(str(end_entity).splitlines()), current_app.config.get('DIGICERT_ECC_CIS_INTERMEDIATE'), data['id'] + if "ECC" in issuer_options["key_type"]: + return ( + "\n".join(str(end_entity).splitlines()), + current_app.config.get("DIGICERT_ECC_CIS_INTERMEDIATES", {}).get(issuer_options['authority'].name), + data["id"], + ) # By default return RSA - return "\n".join(str(end_entity).splitlines()), current_app.config.get('DIGICERT_CIS_INTERMEDIATE'), data['id'] + return ( + "\n".join(str(end_entity).splitlines()), + current_app.config.get("DIGICERT_CIS_INTERMEDIATES", {}).get(issuer_options['authority'].name), + data["id"], + ) def revoke_certificate(self, certificate, comments): """Revoke a Digicert certificate.""" - base_url = current_app.config.get('DIGICERT_CIS_URL') + base_url = current_app.config.get("DIGICERT_CIS_URL") # make certificate revoke request - revoke_url = '{0}/platform/cis/certificate/{1}/revoke'.format(base_url, certificate.external_id) - metrics.send('digicert_revoke_certificate_success', 'counter', 1) - response = self.session.put(revoke_url, data=json.dumps({'comments': comments})) + revoke_url = "{0}/platform/cis/certificate/{1}/revoke".format( + base_url, certificate.external_id + ) + metrics.send("digicert_revoke_certificate_success", "counter", 1) + response = self.session.put(revoke_url, data=json.dumps({"comments": comments})) if response.status_code != 204: - metrics.send('digicert_revoke_certificate_failure', 'counter', 1) - raise Exception('Failed to revoke certificate.') + metrics.send("digicert_revoke_certificate_failure", "counter", 1) + raise Exception("Failed to revoke certificate.") - metrics.send('digicert_revoke_certificate_success', 'counter', 1) + metrics.send("digicert_revoke_certificate_success", "counter", 1) @staticmethod def create_authority(options): @@ -525,5 +579,5 @@ class DigiCertCISIssuerPlugin(IssuerPlugin): :param options: :return: """ - role = {'username': '', 'password': '', 'name': 'digicert'} - return current_app.config.get('DIGICERT_CIS_ROOT'), "", [role] + role = {"username": "", "password": "", "name": "digicert"} + return current_app.config.get("DIGICERT_CIS_ROOTS", {}).get(options['authority'].name), "", [role] diff --git a/lemur/plugins/lemur_digicert/tests/test_digicert.py b/lemur/plugins/lemur_digicert/tests/test_digicert.py index d8d1519d..77b0a1fa 100644 --- a/lemur/plugins/lemur_digicert/tests/test_digicert.py +++ b/lemur/plugins/lemur_digicert/tests/test_digicert.py @@ -13,144 +13,131 @@ from cryptography import x509 def test_map_fields_with_validity_end_and_start(app): from lemur.plugins.lemur_digicert.plugin import map_fields - names = [u'one.example.com', u'two.example.com', u'three.example.com'] + names = [u"one.example.com", u"two.example.com", u"three.example.com"] options = { - 'common_name': 'example.com', - 'owner': 'bob@example.com', - 'description': 'test certificate', - 'extensions': { - 'sub_alt_names': { - 'names': [x509.DNSName(x) for x in names] - } - }, - 'validity_end': arrow.get(2017, 5, 7), - 'validity_start': arrow.get(2016, 10, 30) + "common_name": "example.com", + "owner": "bob@example.com", + "description": "test certificate", + "extensions": {"sub_alt_names": {"names": [x509.DNSName(x) for x in names]}}, + "validity_end": arrow.get(2017, 5, 7), + "validity_start": arrow.get(2016, 10, 30), } data = map_fields(options, CSR_STR) assert data == { - 'certificate': { - 'csr': CSR_STR, - 'common_name': 'example.com', - 'dns_names': names, - 'signature_hash': 'sha256' + "certificate": { + "csr": CSR_STR, + "common_name": "example.com", + "dns_names": names, + "signature_hash": "sha256", }, - 'organization': {'id': 111111}, - 'custom_expiration_date': arrow.get(2017, 5, 7).format('YYYY-MM-DD') + "organization": {"id": 111111}, + "custom_expiration_date": arrow.get(2017, 5, 7).format("YYYY-MM-DD"), } def test_map_fields_with_validity_years(app): from lemur.plugins.lemur_digicert.plugin import map_fields - names = [u'one.example.com', u'two.example.com', u'three.example.com'] + names = [u"one.example.com", u"two.example.com", u"three.example.com"] options = { - 'common_name': 'example.com', - 'owner': 'bob@example.com', - 'description': 'test certificate', - 'extensions': { - 'sub_alt_names': { - 'names': [x509.DNSName(x) for x in names] - } - }, - 'validity_years': 2, - 'validity_end': arrow.get(2017, 10, 30) + "common_name": "example.com", + "owner": "bob@example.com", + "description": "test certificate", + "extensions": {"sub_alt_names": {"names": [x509.DNSName(x) for x in names]}}, + "validity_years": 2, + "validity_end": arrow.get(2017, 10, 30), } data = map_fields(options, CSR_STR) assert data == { - 'certificate': { - 'csr': CSR_STR, - 'common_name': 'example.com', - 'dns_names': names, - 'signature_hash': 'sha256' + "certificate": { + "csr": CSR_STR, + "common_name": "example.com", + "dns_names": names, + "signature_hash": "sha256", }, - 'organization': {'id': 111111}, - 'validity_years': 2 + "organization": {"id": 111111}, + "validity_years": 2, } -def test_map_cis_fields(app): +def test_map_cis_fields(app, authority): from lemur.plugins.lemur_digicert.plugin import map_cis_fields - names = [u'one.example.com', u'two.example.com', u'three.example.com'] + names = [u"one.example.com", u"two.example.com", u"three.example.com"] options = { - 'common_name': 'example.com', - 'owner': 'bob@example.com', - 'description': 'test certificate', - 'extensions': { - 'sub_alt_names': { - 'names': [x509.DNSName(x) for x in names] - } - }, - 'organization': 'Example, Inc.', - 'organizational_unit': 'Example Org', - 'validity_end': arrow.get(2017, 5, 7), - 'validity_start': arrow.get(2016, 10, 30) + "common_name": "example.com", + "owner": "bob@example.com", + "description": "test certificate", + "extensions": {"sub_alt_names": {"names": [x509.DNSName(x) for x in names]}}, + "organization": "Example, Inc.", + "organizational_unit": "Example Org", + "validity_end": arrow.get(2017, 5, 7), + "validity_start": arrow.get(2016, 10, 30), + "authority": authority, } data = map_cis_fields(options, CSR_STR) assert data == { - 'common_name': 'example.com', - 'csr': CSR_STR, - 'additional_dns_names': names, - 'signature_hash': 'sha256', - 'organization': {'name': 'Example, Inc.', 'units': ['Example Org']}, - 'validity': { - 'valid_to': arrow.get(2017, 5, 7).format('YYYY-MM-DDTHH:MM') + 'Z' + "common_name": "example.com", + "csr": CSR_STR, + "additional_dns_names": names, + "signature_hash": "sha256", + "organization": {"name": "Example, Inc.", "units": ["Example Org"]}, + "validity": { + "valid_to": arrow.get(2017, 5, 7).format("YYYY-MM-DDTHH:MM") + "Z" }, - 'profile_name': None + "profile_name": None, } options = { - 'common_name': 'example.com', - 'owner': 'bob@example.com', - 'description': 'test certificate', - 'extensions': { - 'sub_alt_names': { - 'names': [x509.DNSName(x) for x in names] - } - }, - 'organization': 'Example, Inc.', - 'organizational_unit': 'Example Org', - 'validity_years': 2 + "common_name": "example.com", + "owner": "bob@example.com", + "description": "test certificate", + "extensions": {"sub_alt_names": {"names": [x509.DNSName(x) for x in names]}}, + "organization": "Example, Inc.", + "organizational_unit": "Example Org", + "validity_years": 2, + "authority": authority, } with freeze_time(time_to_freeze=arrow.get(2016, 11, 3).datetime): data = map_cis_fields(options, CSR_STR) assert data == { - 'common_name': 'example.com', - 'csr': CSR_STR, - 'additional_dns_names': names, - 'signature_hash': 'sha256', - 'organization': {'name': 'Example, Inc.', 'units': ['Example Org']}, - 'validity': { - 'valid_to': arrow.get(2018, 11, 3).format('YYYY-MM-DDTHH:MM') + 'Z' + "common_name": "example.com", + "csr": CSR_STR, + "additional_dns_names": names, + "signature_hash": "sha256", + "organization": {"name": "Example, Inc.", "units": ["Example Org"]}, + "validity": { + "valid_to": arrow.get(2018, 11, 3).format("YYYY-MM-DDTHH:MM") + "Z" }, - 'profile_name': None + "profile_name": None, } def test_signature_hash(app): from lemur.plugins.lemur_digicert.plugin import signature_hash - assert signature_hash(None) == 'sha256' - assert signature_hash('sha256WithRSA') == 'sha256' - assert signature_hash('sha384WithRSA') == 'sha384' - assert signature_hash('sha512WithRSA') == 'sha512' + assert signature_hash(None) == "sha256" + assert signature_hash("sha256WithRSA") == "sha256" + assert signature_hash("sha384WithRSA") == "sha384" + assert signature_hash("sha512WithRSA") == "sha512" with pytest.raises(Exception): - signature_hash('sdfdsf') + signature_hash("sdfdsf") -def test_issuer_plugin_create_certificate(certificate_="""\ +def test_issuer_plugin_create_certificate( + certificate_="""\ -----BEGIN CERTIFICATE----- abc -----END CERTIFICATE----- @@ -160,7 +147,8 @@ def -----BEGIN CERTIFICATE----- ghi -----END CERTIFICATE----- -"""): +""" +): import requests_mock from lemur.plugins.lemur_digicert.plugin import DigiCertIssuerPlugin @@ -168,12 +156,26 @@ ghi subject = DigiCertIssuerPlugin() adapter = requests_mock.Adapter() - adapter.register_uri('POST', 'mock://www.digicert.com/services/v2/order/certificate/ssl_plus', text=json.dumps({'id': 'id123'})) - adapter.register_uri('GET', 'mock://www.digicert.com/services/v2/order/certificate/id123', text=json.dumps({'status': 'issued', 'certificate': {'id': 'cert123'}})) - adapter.register_uri('GET', 'mock://www.digicert.com/services/v2/certificate/cert123/download/format/pem_all', text=pem_fixture) - subject.session.mount('mock', adapter) + adapter.register_uri( + "POST", + "mock://www.digicert.com/services/v2/order/certificate/ssl_plus", + text=json.dumps({"id": "id123"}), + ) + adapter.register_uri( + "GET", + "mock://www.digicert.com/services/v2/order/certificate/id123", + text=json.dumps({"status": "issued", "certificate": {"id": "cert123"}}), + ) + adapter.register_uri( + "GET", + "mock://www.digicert.com/services/v2/certificate/cert123/download/format/pem_all", + text=pem_fixture, + ) + subject.session.mount("mock", adapter) - cert, intermediate, external_id = subject.create_certificate("", {'common_name': 'test.com'}) + cert, intermediate, external_id = subject.create_certificate( + "", {"common_name": "test.com"} + ) assert cert == "-----BEGIN CERTIFICATE-----\nabc\n-----END CERTIFICATE-----" assert intermediate == "-----BEGIN CERTIFICATE-----\ndef\n-----END CERTIFICATE-----" @@ -187,10 +189,18 @@ def test_cancel_ordered_certificate(mock_pending_cert): mock_pending_cert.external_id = 1234 subject = DigiCertIssuerPlugin() adapter = requests_mock.Adapter() - adapter.register_uri('PUT', 'mock://www.digicert.com/services/v2/order/certificate/1234/status', status_code=204) - adapter.register_uri('PUT', 'mock://www.digicert.com/services/v2/order/certificate/111/status', status_code=404) - subject.session.mount('mock', adapter) - data = {'note': 'Test'} + adapter.register_uri( + "PUT", + "mock://www.digicert.com/services/v2/order/certificate/1234/status", + status_code=204, + ) + adapter.register_uri( + "PUT", + "mock://www.digicert.com/services/v2/order/certificate/111/status", + status_code=404, + ) + subject.session.mount("mock", adapter) + data = {"note": "Test"} subject.cancel_ordered_certificate(mock_pending_cert, **data) # A non-existing order id, does not raise exception because if it doesn't exist, then it doesn't matter diff --git a/lemur/plugins/lemur_email/__init__.py b/lemur/plugins/lemur_email/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_email/__init__.py +++ b/lemur/plugins/lemur_email/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_email/plugin.py b/lemur/plugins/lemur_email/plugin.py index 18007b99..241aa1b0 100644 --- a/lemur/plugins/lemur_email/plugin.py +++ b/lemur/plugins/lemur_email/plugin.py @@ -27,8 +27,10 @@ def render_html(template_name, message): :param message: :return: """ - template = env.get_template('{}.html'.format(template_name)) - return template.render(dict(message=message, hostname=current_app.config.get('LEMUR_HOSTNAME'))) + template = env.get_template("{}.html".format(template_name)) + return template.render( + dict(message=message, hostname=current_app.config.get("LEMUR_HOSTNAME")) + ) def send_via_smtp(subject, body, targets): @@ -40,7 +42,9 @@ def send_via_smtp(subject, body, targets): :param targets: :return: """ - msg = Message(subject, recipients=targets, sender=current_app.config.get("LEMUR_EMAIL")) + msg = Message( + subject, recipients=targets, sender=current_app.config.get("LEMUR_EMAIL") + ) msg.body = "" # kinda a weird api for sending html emails msg.html = body smtp_mail.send(msg) @@ -54,65 +58,55 @@ def send_via_ses(subject, body, targets): :param targets: :return: """ - client = boto3.client('ses', region_name='us-east-1') + client = boto3.client("ses", region_name="us-east-1") client.send_email( - Source=current_app.config.get('LEMUR_EMAIL'), - Destination={ - 'ToAddresses': targets - }, + Source=current_app.config.get("LEMUR_EMAIL"), + Destination={"ToAddresses": targets}, Message={ - 'Subject': { - 'Data': subject, - 'Charset': 'UTF-8' - }, - 'Body': { - 'Html': { - 'Data': body, - 'Charset': 'UTF-8' - } - } - } + "Subject": {"Data": subject, "Charset": "UTF-8"}, + "Body": {"Html": {"Data": body, "Charset": "UTF-8"}}, + }, ) class EmailNotificationPlugin(ExpirationNotificationPlugin): - title = 'Email' - slug = 'email-notification' - description = 'Sends expiration email notifications' + title = "Email" + slug = "email-notification" + description = "Sends expiration email notifications" version = email.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur" additional_options = [ { - 'name': 'recipients', - 'type': 'str', - 'required': True, - 'validation': '^([\w+-.%]+@[\w-.]+\.[A-Za-z]{2,4},?)+$', - 'helpMessage': 'Comma delimited list of email addresses', - }, + "name": "recipients", + "type": "str", + "required": True, + "validation": "^([\w+-.%]+@[\w-.]+\.[A-Za-z]{2,4},?)+$", + "helpMessage": "Comma delimited list of email addresses", + } ] def __init__(self, *args, **kwargs): """Initialize the plugin with the appropriate details.""" - sender = current_app.config.get('LEMUR_EMAIL_SENDER', 'ses').lower() + sender = current_app.config.get("LEMUR_EMAIL_SENDER", "ses").lower() - if sender not in ['ses', 'smtp']: - raise InvalidConfiguration('Email sender type {0} is not recognized.') + if sender not in ["ses", "smtp"]: + raise InvalidConfiguration("Email sender type {0} is not recognized.") @staticmethod def send(notification_type, message, targets, options, **kwargs): - subject = 'Lemur: {0} Notification'.format(notification_type.capitalize()) + subject = "Lemur: {0} Notification".format(notification_type.capitalize()) - data = {'options': options, 'certificates': message} + data = {"options": options, "certificates": message} body = render_html(notification_type, data) - s_type = current_app.config.get("LEMUR_EMAIL_SENDER", 'ses').lower() + s_type = current_app.config.get("LEMUR_EMAIL_SENDER", "ses").lower() - if s_type == 'ses': + if s_type == "ses": send_via_ses(subject, body, targets) - elif s_type == 'smtp': + elif s_type == "smtp": send_via_smtp(subject, body, targets) diff --git a/lemur/plugins/lemur_email/templates/config.py b/lemur/plugins/lemur_email/templates/config.py index 2ec8a6c2..3d877fe0 100644 --- a/lemur/plugins/lemur_email/templates/config.py +++ b/lemur/plugins/lemur_email/templates/config.py @@ -5,22 +5,24 @@ from jinja2 import Environment, FileSystemLoader, select_autoescape from lemur.plugins.utils import get_plugin_option loader = FileSystemLoader(searchpath=os.path.dirname(os.path.realpath(__file__))) -env = Environment(loader=loader, # nosec: potentially dangerous types esc. - autoescape=select_autoescape(['html', 'xml'])) +env = Environment( + loader=loader, # nosec: potentially dangerous types esc. + autoescape=select_autoescape(["html", "xml"]), +) def human_time(time): - return arrow.get(time).format('dddd, MMMM D, YYYY') + return arrow.get(time).format("dddd, MMMM D, YYYY") def interval(options): - return get_plugin_option('interval', options) + return get_plugin_option("interval", options) def unit(options): - return get_plugin_option('unit', options) + return get_plugin_option("unit", options) -env.filters['time'] = human_time -env.filters['interval'] = interval -env.filters['unit'] = unit +env.filters["time"] = human_time +env.filters["interval"] = interval +env.filters["unit"] = unit diff --git a/lemur/plugins/lemur_email/templates/expiration.html b/lemur/plugins/lemur_email/templates/expiration.html index 3c500c38..f5185acd 100644 --- a/lemur/plugins/lemur_email/templates/expiration.html +++ b/lemur/plugins/lemur_email/templates/expiration.html @@ -106,7 +106,13 @@ - If the above certificates are still in use. You should re-issue and deploy new certificates as soon as possible. + Your action is required if the above certificates are still needed for your service. +

+ If your endpoints are still in use, you can access your certificate in Lemur, and enable Auto Rotate under the Action->Edit menu. + Lemur will take care of re-issuance and rotation of the certificate on the listed endpoints within one day. +

+ If your certificate is deployed with your service, you should re-issue and manually deploy a new certificate as soon as possible. + diff --git a/lemur/plugins/lemur_email/tests/test_email.py b/lemur/plugins/lemur_email/tests/test_email.py index 9d58402f..43168cab 100644 --- a/lemur/plugins/lemur_email/tests/test_email.py +++ b/lemur/plugins/lemur_email/tests/test_email.py @@ -13,21 +13,24 @@ def test_render(certificate, endpoint): new_cert.replaces.append(certificate) data = { - 'certificates': [certificate_notification_output_schema.dump(certificate).data], - 'options': [{'name': 'interval', 'value': 10}, {'name': 'unit', 'value': 'days'}] + "certificates": [certificate_notification_output_schema.dump(certificate).data], + "options": [ + {"name": "interval", "value": 10}, + {"name": "unit", "value": "days"}, + ], } - template = env.get_template('{}.html'.format('expiration')) + template = env.get_template("{}.html".format("expiration")) - body = template.render(dict(message=data, hostname='lemur.test.example.com')) + body = template.render(dict(message=data, hostname="lemur.test.example.com")) - template = env.get_template('{}.html'.format('rotation')) + template = env.get_template("{}.html".format("rotation")) certificate.endpoints.append(endpoint) body = template.render( dict( certificate=certificate_notification_output_schema.dump(certificate).data, - hostname='lemur.test.example.com' + hostname="lemur.test.example.com", ) ) diff --git a/lemur/plugins/lemur_java/__init__.py b/lemur/plugins/lemur_java/__init__.py deleted file mode 100644 index 8ce5a7f3..00000000 --- a/lemur/plugins/lemur_java/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version -except Exception as e: - VERSION = 'unknown' diff --git a/lemur/plugins/lemur_java/plugin.py b/lemur/plugins/lemur_java/plugin.py deleted file mode 100644 index 151794da..00000000 --- a/lemur/plugins/lemur_java/plugin.py +++ /dev/null @@ -1,252 +0,0 @@ -""" -.. module: lemur.plugins.lemur_java.plugin - :platform: Unix - :copyright: (c) 2018 by Netflix Inc., see AUTHORS for more - :license: Apache, see LICENSE for more details. - -.. moduleauthor:: Kevin Glisson -""" -import subprocess - -from flask import current_app - -from cryptography.fernet import Fernet - -from lemur.utils import mktempfile, mktemppath -from lemur.plugins.bases import ExportPlugin -from lemur.plugins import lemur_java as java - - -def run_process(command): - """ - Runs a given command with pOpen and wraps some - error handling around it. - :param command: - :return: - """ - p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = p.communicate() - - if p.returncode != 0: - current_app.logger.debug(" ".join(command)) - current_app.logger.error(stderr) - current_app.logger.error(stdout) - raise Exception(stderr) - - -def split_chain(chain): - """ - Split the chain into individual certificates for import into keystore - - :param chain: - :return: - """ - certs = [] - - if not chain: - return certs - - lines = chain.split('\n') - - cert = [] - for line in lines: - cert.append(line + '\n') - if line == '-----END CERTIFICATE-----': - certs.append("".join(cert)) - cert = [] - - return certs - - -def create_truststore(cert, chain, jks_tmp, alias, passphrase): - if isinstance(cert, bytes): - cert = cert.decode('utf-8') - - if isinstance(chain, bytes): - chain = chain.decode('utf-8') - - with mktempfile() as cert_tmp: - with open(cert_tmp, 'w') as f: - f.write(cert) - - run_process([ - "keytool", - "-importcert", - "-file", cert_tmp, - "-keystore", jks_tmp, - "-alias", "{0}_cert".format(alias), - "-storepass", passphrase, - "-noprompt" - ]) - - # Import the entire chain - for idx, cert in enumerate(split_chain(chain)): - with mktempfile() as c_tmp: - with open(c_tmp, 'w') as f: - f.write(cert) - - # Import signed cert in to JKS keystore - run_process([ - "keytool", - "-importcert", - "-file", c_tmp, - "-keystore", jks_tmp, - "-alias", "{0}_cert_{1}".format(alias, idx), - "-storepass", passphrase, - "-noprompt" - ]) - - -def create_keystore(cert, chain, jks_tmp, key, alias, passphrase): - if isinstance(cert, bytes): - cert = cert.decode('utf-8') - - if isinstance(chain, bytes): - chain = chain.decode('utf-8') - - if isinstance(key, bytes): - key = key.decode('utf-8') - - # Create PKCS12 keystore from private key and public certificate - with mktempfile() as cert_tmp: - with open(cert_tmp, 'w') as f: - if chain: - f.writelines([key.strip() + "\n", cert.strip() + "\n", chain.strip() + "\n"]) - else: - f.writelines([key.strip() + "\n", cert.strip() + "\n"]) - - with mktempfile() as p12_tmp: - run_process([ - "openssl", - "pkcs12", - "-export", - "-nodes", - "-name", alias, - "-in", cert_tmp, - "-out", p12_tmp, - "-password", "pass:{}".format(passphrase) - ]) - - # Convert PKCS12 keystore into a JKS keystore - run_process([ - "keytool", - "-importkeystore", - "-destkeystore", jks_tmp, - "-srckeystore", p12_tmp, - "-srcstoretype", "pkcs12", - "-deststoretype", "JKS", - "-alias", alias, - "-srcstorepass", passphrase, - "-deststorepass", passphrase - ]) - - -class JavaTruststoreExportPlugin(ExportPlugin): - title = 'Java Truststore (JKS)' - slug = 'java-truststore-jks' - description = 'Attempts to generate a JKS truststore' - requires_key = False - version = java.VERSION - - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' - - options = [ - { - 'name': 'alias', - 'type': 'str', - 'required': False, - 'helpMessage': 'Enter the alias you wish to use for the truststore.', - }, - { - 'name': 'passphrase', - 'type': 'str', - 'required': False, - 'helpMessage': 'If no passphrase is given one will be generated for you, we highly recommend this. Minimum length is 8.', - 'validation': '' - }, - ] - - def export(self, body, chain, key, options, **kwargs): - """ - Generates a Java Truststore - - :param key: - :param chain: - :param body: - :param options: - :param kwargs: - """ - - if self.get_option('alias', options): - alias = self.get_option('alias', options) - else: - alias = "blah" - - if self.get_option('passphrase', options): - passphrase = self.get_option('passphrase', options) - else: - passphrase = Fernet.generate_key().decode('utf-8') - - with mktemppath() as jks_tmp: - create_truststore(body, chain, jks_tmp, alias, passphrase) - - with open(jks_tmp, 'rb') as f: - raw = f.read() - - return "jks", passphrase, raw - - -class JavaKeystoreExportPlugin(ExportPlugin): - title = 'Java Keystore (JKS)' - slug = 'java-keystore-jks' - description = 'Attempts to generate a JKS keystore' - version = java.VERSION - - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' - - options = [ - { - 'name': 'passphrase', - 'type': 'str', - 'required': False, - 'helpMessage': 'If no passphrase is given one will be generated for you, we highly recommend this. Minimum length is 8.', - 'validation': '' - }, - { - 'name': 'alias', - 'type': 'str', - 'required': False, - 'helpMessage': 'Enter the alias you wish to use for the keystore.', - } - ] - - def export(self, body, chain, key, options, **kwargs): - """ - Generates a Java Keystore - - :param key: - :param chain: - :param body: - :param options: - :param kwargs: - """ - - if self.get_option('passphrase', options): - passphrase = self.get_option('passphrase', options) - else: - passphrase = Fernet.generate_key().decode('utf-8') - - if self.get_option('alias', options): - alias = self.get_option('alias', options) - else: - alias = "blah" - - with mktemppath() as jks_tmp: - create_keystore(body, chain, jks_tmp, key, alias, passphrase) - - with open(jks_tmp, 'rb') as f: - raw = f.read() - - return "jks", passphrase, raw diff --git a/lemur/plugins/lemur_java/tests/test_java.py b/lemur/plugins/lemur_java/tests/test_java.py deleted file mode 100644 index 2b8598b8..00000000 --- a/lemur/plugins/lemur_java/tests/test_java.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest - -from lemur.tests.vectors import INTERNAL_CERTIFICATE_A_STR, INTERNAL_PRIVATE_KEY_A_STR - - -@pytest.mark.skip(reason="no way of currently testing this") -def test_export_truststore(app): - from lemur.plugins.base import plugins - - p = plugins.get('java-truststore-jks') - options = [{'name': 'passphrase', 'value': 'test1234'}] - actual = p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) - - assert actual[0] == 'jks' - assert actual[1] == 'test1234' - assert isinstance(actual[2], bytes) - - -@pytest.mark.skip(reason="no way of currently testing this") -def test_export_truststore_default_password(app): - from lemur.plugins.base import plugins - - p = plugins.get('java-truststore-jks') - options = [] - actual = p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) - - assert actual[0] == 'jks' - assert isinstance(actual[1], str) - assert isinstance(actual[2], bytes) - - -@pytest.mark.skip(reason="no way of currently testing this") -def test_export_keystore(app): - from lemur.plugins.base import plugins - - p = plugins.get('java-keystore-jks') - options = [{'name': 'passphrase', 'value': 'test1234'}] - - with pytest.raises(Exception): - p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) - - actual = p.export(INTERNAL_CERTIFICATE_A_STR, "", INTERNAL_PRIVATE_KEY_A_STR, options) - - assert actual[0] == 'jks' - assert actual[1] == 'test1234' - assert isinstance(actual[2], bytes) - - -@pytest.mark.skip(reason="no way of currently testing this") -def test_export_keystore_default_password(app): - from lemur.plugins.base import plugins - - p = plugins.get('java-keystore-jks') - options = [] - - with pytest.raises(Exception): - p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) - - actual = p.export(INTERNAL_CERTIFICATE_A_STR, "", INTERNAL_PRIVATE_KEY_A_STR, options) - - assert actual[0] == 'jks' - assert isinstance(actual[1], str) - assert isinstance(actual[2], bytes) diff --git a/lemur/plugins/lemur_jks/__init__.py b/lemur/plugins/lemur_jks/__init__.py new file mode 100644 index 00000000..f8afd7e3 --- /dev/null +++ b/lemur/plugins/lemur_jks/__init__.py @@ -0,0 +1,4 @@ +try: + VERSION = __import__("pkg_resources").get_distribution(__name__).version +except Exception as e: + VERSION = "unknown" diff --git a/lemur/plugins/lemur_jks/plugin.py b/lemur/plugins/lemur_jks/plugin.py new file mode 100644 index 00000000..7134faeb --- /dev/null +++ b/lemur/plugins/lemur_jks/plugin.py @@ -0,0 +1,140 @@ +""" +.. module: lemur.plugins.lemur_jks.plugin + :platform: Unix + :copyright: (c) 2018 by Netflix Inc., see AUTHORS for more + :license: Apache, see LICENSE for more details. + +.. moduleauthor:: Marti Raudsepp +""" + +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import serialization +from jks import PrivateKeyEntry, KeyStore, TrustedCertEntry + +from lemur.common.defaults import common_name +from lemur.common.utils import parse_certificate, parse_cert_chain, parse_private_key +from lemur.plugins import lemur_jks as jks +from lemur.plugins.bases import ExportPlugin + + +def cert_chain_as_der(cert, chain): + """Return a certificate and its chain in a list format, as expected by pyjks.""" + + certs = [parse_certificate(cert)] + certs.extend(parse_cert_chain(chain)) + # certs (list) – A list of certificates, as byte strings. The first one should be the one belonging to the private + # key, the others the chain (in correct order). + return [cert.public_bytes(encoding=serialization.Encoding.DER) for cert in certs] + + +def create_truststore(cert, chain, alias, passphrase): + entries = [] + for idx, cert_bytes in enumerate(cert_chain_as_der(cert, chain)): + # The original cert gets name _cert, first chain element is _cert_1, etc. + cert_alias = alias + "_cert" + ("_{}".format(idx) if idx else "") + entries.append(TrustedCertEntry.new(cert_alias, cert_bytes)) + + return KeyStore.new("jks", entries).saves(passphrase) + + +def create_keystore(cert, chain, key, alias, passphrase): + certs_bytes = cert_chain_as_der(cert, chain) + key_bytes = parse_private_key(key).private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + entry = PrivateKeyEntry.new(alias, certs_bytes, key_bytes) + + return KeyStore.new("jks", [entry]).saves(passphrase) + + +class JavaTruststoreExportPlugin(ExportPlugin): + title = "Java Truststore (JKS)" + slug = "java-truststore-jks" + description = "Generates a JKS truststore" + requires_key = False + version = jks.VERSION + + author = "Marti Raudsepp" + author_url = "https://github.com/intgr" + + options = [ + { + "name": "alias", + "type": "str", + "required": False, + "helpMessage": "Enter the alias you wish to use for the truststore.", + }, + { + "name": "passphrase", + "type": "str", + "required": False, + "helpMessage": "If no passphrase is given one will be generated for you, we highly recommend this.", + "validation": "", + }, + ] + + def export(self, body, chain, key, options, **kwargs): + """ + Generates a Java Truststore + """ + + if self.get_option("alias", options): + alias = self.get_option("alias", options) + else: + alias = common_name(parse_certificate(body)) + + if self.get_option("passphrase", options): + passphrase = self.get_option("passphrase", options) + else: + passphrase = Fernet.generate_key().decode("utf-8") + + raw = create_truststore(body, chain, alias, passphrase) + + return "jks", passphrase, raw + + +class JavaKeystoreExportPlugin(ExportPlugin): + title = "Java Keystore (JKS)" + slug = "java-keystore-jks" + description = "Generates a JKS keystore" + version = jks.VERSION + + author = "Marti Raudsepp" + author_url = "https://github.com/intgr" + + options = [ + { + "name": "passphrase", + "type": "str", + "required": False, + "helpMessage": "If no passphrase is given one will be generated for you, we highly recommend this.", + "validation": "", + }, + { + "name": "alias", + "type": "str", + "required": False, + "helpMessage": "Enter the alias you wish to use for the keystore.", + }, + ] + + def export(self, body, chain, key, options, **kwargs): + """ + Generates a Java Keystore + """ + + if self.get_option("passphrase", options): + passphrase = self.get_option("passphrase", options) + else: + passphrase = Fernet.generate_key().decode("utf-8") + + if self.get_option("alias", options): + alias = self.get_option("alias", options) + else: + alias = common_name(parse_certificate(body)) + + raw = create_keystore(body, chain, key, alias, passphrase) + + return "jks", passphrase, raw diff --git a/lemur/plugins/lemur_java/tests/conftest.py b/lemur/plugins/lemur_jks/tests/conftest.py similarity index 100% rename from lemur/plugins/lemur_java/tests/conftest.py rename to lemur/plugins/lemur_jks/tests/conftest.py diff --git a/lemur/plugins/lemur_jks/tests/test_jks.py b/lemur/plugins/lemur_jks/tests/test_jks.py new file mode 100644 index 00000000..b9fe9b33 --- /dev/null +++ b/lemur/plugins/lemur_jks/tests/test_jks.py @@ -0,0 +1,105 @@ +import pytest +from jks import KeyStore, TrustedCertEntry, PrivateKeyEntry + +from lemur.tests.vectors import ( + INTERNAL_CERTIFICATE_A_STR, + SAN_CERT_STR, + INTERMEDIATE_CERT_STR, + ROOTCA_CERT_STR, + SAN_CERT_KEY, +) + + +def test_export_truststore(app): + from lemur.plugins.base import plugins + + p = plugins.get("java-truststore-jks") + options = [ + {"name": "passphrase", "value": "hunter2"}, + {"name": "alias", "value": "AzureDiamond"}, + ] + chain = INTERMEDIATE_CERT_STR + "\n" + ROOTCA_CERT_STR + ext, password, raw = p.export(SAN_CERT_STR, chain, SAN_CERT_KEY, options) + + assert ext == "jks" + assert password == "hunter2" + assert isinstance(raw, bytes) + + ks = KeyStore.loads(raw, "hunter2") + assert ks.store_type == "jks" + # JKS lower-cases alias strings + assert ks.entries.keys() == { + "azurediamond_cert", + "azurediamond_cert_1", + "azurediamond_cert_2", + } + assert isinstance(ks.entries["azurediamond_cert"], TrustedCertEntry) + + +def test_export_truststore_defaults(app): + from lemur.plugins.base import plugins + + p = plugins.get("java-truststore-jks") + options = [] + ext, password, raw = p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) + + assert ext == "jks" + assert isinstance(password, str) + assert isinstance(raw, bytes) + + ks = KeyStore.loads(raw, password) + assert ks.store_type == "jks" + # JKS lower-cases alias strings + assert ks.entries.keys() == {"acommonname_cert"} + assert isinstance(ks.entries["acommonname_cert"], TrustedCertEntry) + + +def test_export_keystore(app): + from lemur.plugins.base import plugins + + p = plugins.get("java-keystore-jks") + options = [ + {"name": "passphrase", "value": "hunter2"}, + {"name": "alias", "value": "AzureDiamond"}, + ] + + chain = INTERMEDIATE_CERT_STR + "\n" + ROOTCA_CERT_STR + with pytest.raises(Exception): + p.export(INTERNAL_CERTIFICATE_A_STR, chain, "", options) + + ext, password, raw = p.export(SAN_CERT_STR, chain, SAN_CERT_KEY, options) + + assert ext == "jks" + assert password == "hunter2" + assert isinstance(raw, bytes) + + ks = KeyStore.loads(raw, password) + assert ks.store_type == "jks" + # JKS lower-cases alias strings + assert ks.entries.keys() == {"azurediamond"} + entry = ks.entries["azurediamond"] + assert isinstance(entry, PrivateKeyEntry) + assert len(entry.cert_chain) == 3 # Cert and chain were provided + + +def test_export_keystore_defaults(app): + from lemur.plugins.base import plugins + + p = plugins.get("java-keystore-jks") + options = [] + + with pytest.raises(Exception): + p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) + + ext, password, raw = p.export(SAN_CERT_STR, "", SAN_CERT_KEY, options) + + assert ext == "jks" + assert isinstance(password, str) + assert isinstance(raw, bytes) + + ks = KeyStore.loads(raw, password) + assert ks.store_type == "jks" + assert ks.entries.keys() == {"san.example.org"} + entry = ks.entries["san.example.org"] + assert isinstance(entry, PrivateKeyEntry) + assert len(entry.cert_chain) == 1 # Only cert itself, no chain was provided diff --git a/lemur/plugins/lemur_kubernetes/__init__.py b/lemur/plugins/lemur_kubernetes/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_kubernetes/__init__.py +++ b/lemur/plugins/lemur_kubernetes/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_kubernetes/plugin.py b/lemur/plugins/lemur_kubernetes/plugin.py index ee466596..62ffffda 100644 --- a/lemur/plugins/lemur_kubernetes/plugin.py +++ b/lemur/plugins/lemur_kubernetes/plugin.py @@ -11,31 +11,37 @@ .. moduleauthor:: Mikhail Khodorovskiy """ import base64 -import os -import urllib -import requests import itertools +import os -from lemur.certificates.models import Certificate +import requests +from flask import current_app + +from lemur.common.defaults import common_name +from lemur.common.utils import parse_certificate from lemur.plugins.bases import DestinationPlugin -DEFAULT_API_VERSION = 'v1' +DEFAULT_API_VERSION = "v1" def ensure_resource(k8s_api, k8s_base_uri, namespace, kind, name, data): - # _resolve_uri(k8s_base_uri, namespace, kind, name, api_ver=DEFAULT_API_VERSION) url = _resolve_uri(k8s_base_uri, namespace, kind) + current_app.logger.debug("K8S POST request URL: %s", url) create_resp = k8s_api.post(url, json=data) + current_app.logger.debug("K8S POST response: %s", create_resp) if 200 <= create_resp.status_code <= 299: return None - - elif create_resp.json()['reason'] != 'AlreadyExists': + elif create_resp.json().get("reason", "") != "AlreadyExists": return create_resp.content - update_resp = k8s_api.put(_resolve_uri(k8s_base_uri, namespace, kind, name), json=data) + url = _resolve_uri(k8s_base_uri, namespace, kind, name) + current_app.logger.debug("K8S PUT request URL: %s", url) + + update_resp = k8s_api.put(url, json=data) + current_app.logger.debug("K8S PUT response: %s", update_resp) if not 200 <= update_resp.status_code <= 299: return update_resp.content @@ -43,62 +49,145 @@ def ensure_resource(k8s_api, k8s_base_uri, namespace, kind, name, data): return -def _resolve_ns(k8s_base_uri, namespace, api_ver=DEFAULT_API_VERSION,): - api_group = 'api' - if '/' in api_ver: - api_group = 'apis' - return '{base}/{api_group}/{api_ver}/namespaces'.format(base=k8s_base_uri, api_group=api_group, api_ver=api_ver) + ('/' + namespace if namespace else '') +def _resolve_ns(k8s_base_uri, namespace, api_ver=DEFAULT_API_VERSION): + api_group = "api" + if "/" in api_ver: + api_group = "apis" + return "{base}/{api_group}/{api_ver}/namespaces".format( + base=k8s_base_uri, api_group=api_group, api_ver=api_ver + ) + ("/" + namespace if namespace else "") def _resolve_uri(k8s_base_uri, namespace, kind, name=None, api_ver=DEFAULT_API_VERSION): if not namespace: - namespace = 'default' + namespace = "default" - return "/".join(itertools.chain.from_iterable([ - (_resolve_ns(k8s_base_uri, namespace, api_ver=api_ver),), - ((kind + 's').lower(),), - (name,) if name else (), - ])) + return "/".join( + itertools.chain.from_iterable( + [ + (_resolve_ns(k8s_base_uri, namespace, api_ver=api_ver),), + ((kind + "s").lower(),), + (name,) if name else (), + ] + ) + ) + + +# Performs Base64 encoding of string to string using the base64.b64encode() function +# which encodes bytes to bytes. +def base64encode(string): + return base64.b64encode(string.encode()).decode() + + +def build_secret(secret_format, secret_name, body, private_key, cert_chain): + secret = { + "apiVersion": "v1", + "kind": "Secret", + "type": "Opaque", + "metadata": {"name": secret_name}, + } + if secret_format == "Full": + secret["data"] = { + "combined.pem": base64encode("%s\n%s" % (body, private_key)), + "ca.crt": base64encode(cert_chain), + "service.key": base64encode(private_key), + "service.crt": base64encode(body), + } + if secret_format == "TLS": + secret["type"] = "kubernetes.io/tls" + secret["data"] = { + "tls.crt": base64encode(cert_chain), + "tls.key": base64encode(private_key), + } + if secret_format == "Certificate": + secret["data"] = {"tls.crt": base64encode(cert_chain)} + return secret class KubernetesDestinationPlugin(DestinationPlugin): - title = 'Kubernetes' - slug = 'kubernetes-destination' - description = 'Allow the uploading of certificates to Kubernetes as secret' + title = "Kubernetes" + slug = "kubernetes-destination" + description = "Allow the uploading of certificates to Kubernetes as secret" - author = 'Mikhail Khodorovskiy' - author_url = 'https://github.com/mik373/lemur' + author = "Mikhail Khodorovskiy" + author_url = "https://github.com/mik373/lemur" options = [ { - 'name': 'kubernetesURL', - 'type': 'str', - 'required': True, - 'validation': '@(https?|http)://(-\.)?([^\s/?\.#-]+\.?)+(/[^\s]*)?$@iS', - 'helpMessage': 'Must be a valid Kubernetes server URL!', + "name": "secretNameFormat", + "type": "str", + "required": False, + # Validation is difficult. This regex is used by kubectl to validate secret names: + # [a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)* + # Allowing the insertion of "{common_name}" (or any other such placeholder} + # at any point in the string proved very challenging and had a tendency to + # cause my browser to hang. The specified expression will allow any valid string + # but will also accept many invalid strings. + "validation": "(?:[a-z0-9.-]|\\{common_name\\})+", + "helpMessage": 'Must be a valid secret name, possibly including "{common_name}"', + "default": "{common_name}", }, { - 'name': 'kubernetesAuthToken', - 'type': 'str', - 'required': True, - 'validation': '/^$|\s+/', - 'helpMessage': 'Must be a valid Kubernetes server Token!', + "name": "kubernetesURL", + "type": "str", + "required": False, + "validation": "https?://[a-zA-Z0-9.-]+(?::[0-9]+)?", + "helpMessage": "Must be a valid Kubernetes server URL!", + "default": "https://kubernetes.default", }, { - 'name': 'kubernetesServerCertificate', - 'type': 'str', - 'required': True, - 'validation': '/^$|\s+/', - 'helpMessage': 'Must be a valid Kubernetes server Certificate!', + "name": "kubernetesAuthToken", + "type": "str", + "required": False, + "validation": "[0-9a-zA-Z-_.]+", + "helpMessage": "Must be a valid Kubernetes server Token!", }, { - 'name': 'kubernetesNamespace', - 'type': 'str', - 'required': True, - 'validation': '/^$|\s+/', - 'helpMessage': 'Must be a valid Kubernetes Namespace!', + "name": "kubernetesAuthTokenFile", + "type": "str", + "required": False, + "validation": "(/[^/]+)+", + "helpMessage": "Must be a valid file path!", + "default": "/var/run/secrets/kubernetes.io/serviceaccount/token", + }, + { + "name": "kubernetesServerCertificate", + "type": "textarea", + "required": False, + "validation": "-----BEGIN CERTIFICATE-----[a-zA-Z0-9/+\\s\\r\\n]+-----END CERTIFICATE-----", + "helpMessage": "Must be a valid Kubernetes server Certificate!", + }, + { + "name": "kubernetesServerCertificateFile", + "type": "str", + "required": False, + "validation": "(/[^/]+)+", + "helpMessage": "Must be a valid file path!", + "default": "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt", + }, + { + "name": "kubernetesNamespace", + "type": "str", + "required": False, + "validation": "[a-z0-9]([-a-z0-9]*[a-z0-9])?", + "helpMessage": "Must be a valid Kubernetes Namespace!", + }, + { + "name": "kubernetesNamespaceFile", + "type": "str", + "required": False, + "validation": "(/[^/]+)+", + "helpMessage": "Must be a valid file path!", + "default": "/var/run/secrets/kubernetes.io/serviceaccount/namespace", + }, + { + "name": "secretFormat", + "type": "select", + "required": True, + "available": ["Full", "TLS", "Certificate"], + "helpMessage": "The type of Secret to create.", + "default": "Full", }, - ] def __init__(self, *args, **kwargs): @@ -106,56 +195,129 @@ class KubernetesDestinationPlugin(DestinationPlugin): def upload(self, name, body, private_key, cert_chain, options, **kwargs): - k8_bearer = self.get_option('kubernetesAuthToken', options) - k8_cert = self.get_option('kubernetesServerCertificate', options) - k8_namespace = self.get_option('kubernetesNamespace', options) - k8_base_uri = self.get_option('kubernetesURL', options) + try: + k8_base_uri = self.get_option("kubernetesURL", options) + secret_format = self.get_option("secretFormat", options) + k8s_api = K8sSession(self.k8s_bearer(options), self.k8s_cert(options)) + cn = common_name(parse_certificate(body)) + secret_name_format = self.get_option("secretNameFormat", options) + secret_name = secret_name_format.format(common_name=cn) + secret = build_secret( + secret_format, secret_name, body, private_key, cert_chain + ) + err = ensure_resource( + k8s_api, + k8s_base_uri=k8_base_uri, + namespace=self.k8s_namespace(options), + kind="secret", + name=secret_name, + data=secret, + ) - k8s_api = K8sSession(k8_bearer, k8_cert) - - cert = Certificate(body=body) - - # in the future once runtime properties can be passed-in - use passed-in secret name - secret_name = 'certs-' + urllib.quote_plus(cert.name) - - err = ensure_resource(k8s_api, k8s_base_uri=k8_base_uri, namespace=k8_namespace, kind="secret", name=secret_name, data={ - 'apiVersion': 'v1', - 'kind': 'Secret', - 'metadata': { - 'name': secret_name, - }, - 'data': { - 'combined.pem': base64.b64encode(body + private_key), - 'ca.crt': base64.b64encode(cert_chain), - 'service.key': base64.b64encode(private_key), - 'service.crt': base64.b64encode(body), - } - }) + except Exception as e: + current_app.logger.exception( + "Exception in upload: {}".format(e), exc_info=True + ) + raise if err is not None: + current_app.logger.error("Error deploying resource: %s", err) raise Exception("Error uploading secret: " + err) + def k8s_bearer(self, options): + bearer = self.get_option("kubernetesAuthToken", options) + if not bearer: + bearer_file = self.get_option("kubernetesAuthTokenFile", options) + with open(bearer_file, "r") as file: + bearer = file.readline() + if bearer: + current_app.logger.debug("Using token read from %s", bearer_file) + else: + raise Exception( + "Unable to locate token in options or from %s", bearer_file + ) + else: + current_app.logger.debug("Using token from options") + return bearer + + def k8s_cert(self, options): + cert_file = self.get_option("kubernetesServerCertificateFile", options) + cert = self.get_option("kubernetesServerCertificate", options) + if cert: + cert_file = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "k8.cert" + ) + with open(cert_file, "w") as text_file: + text_file.write(cert) + current_app.logger.debug("Using certificate from options") + else: + current_app.logger.debug("Using certificate from %s", cert_file) + return cert_file + + def k8s_namespace(self, options): + namespace = self.get_option("kubernetesNamespace", options) + if not namespace: + namespace_file = self.get_option("kubernetesNamespaceFile", options) + with open(namespace_file, "r") as file: + namespace = file.readline() + if namespace: + current_app.logger.debug( + "Using namespace %s from %s", namespace, namespace_file + ) + else: + raise Exception( + "Unable to locate namespace in options or from %s", namespace_file + ) + else: + current_app.logger.debug("Using namespace %s from options", namespace) + return namespace + class K8sSession(requests.Session): - - def __init__(self, bearer, cert): + def __init__(self, bearer, cert_file): super(K8sSession, self).__init__() - self.headers.update({ - 'Authorization': 'Bearer %s' % bearer - }) + self.headers.update({"Authorization": "Bearer %s" % bearer}) - k8_ca = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'k8.cert') + self.verify = cert_file - with open(k8_ca, "w") as text_file: - text_file.write(cert) - - self.verify = k8_ca - - def request(self, method, url, params=None, data=None, headers=None, cookies=None, files=None, auth=None, timeout=30, allow_redirects=True, proxies=None, - hooks=None, stream=None, verify=None, cert=None, json=None): + def request( + self, + method, + url, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + timeout=30, + allow_redirects=True, + proxies=None, + hooks=None, + stream=None, + verify=None, + cert=None, + json=None, + ): """ This method overrides the default timeout to be 10s. """ - return super(K8sSession, self).request(method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, - verify, cert, json) + return super(K8sSession, self).request( + method, + url, + params, + data, + headers, + cookies, + files, + auth, + timeout, + allow_redirects, + proxies, + hooks, + stream, + verify, + cert, + json, + ) diff --git a/lemur/plugins/lemur_openssl/__init__.py b/lemur/plugins/lemur_openssl/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_openssl/__init__.py +++ b/lemur/plugins/lemur_openssl/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_openssl/plugin.py b/lemur/plugins/lemur_openssl/plugin.py index d50b4e43..02da311b 100644 --- a/lemur/plugins/lemur_openssl/plugin.py +++ b/lemur/plugins/lemur_openssl/plugin.py @@ -14,7 +14,8 @@ from flask import current_app from lemur.utils import mktempfile, mktemppath from lemur.plugins.bases import ExportPlugin from lemur.plugins import lemur_openssl as openssl -from lemur.common.utils import get_psuedo_random_string +from lemur.common.utils import get_psuedo_random_string, parse_certificate +from lemur.common.defaults import common_name def run_process(command): @@ -44,69 +45,71 @@ def create_pkcs12(cert, chain, p12_tmp, key, alias, passphrase): :param alias: :param passphrase: """ - if isinstance(cert, bytes): - cert = cert.decode('utf-8') - - if isinstance(chain, bytes): - chain = chain.decode('utf-8') - - if isinstance(key, bytes): - key = key.decode('utf-8') + assert isinstance(cert, str) + assert isinstance(chain, str) + assert isinstance(key, str) with mktempfile() as key_tmp: - with open(key_tmp, 'w') as f: + with open(key_tmp, "w") as f: f.write(key) # Create PKCS12 keystore from private key and public certificate with mktempfile() as cert_tmp: - with open(cert_tmp, 'w') as f: + with open(cert_tmp, "w") as f: if chain: f.writelines([cert.strip() + "\n", chain.strip() + "\n"]) else: f.writelines([cert.strip() + "\n"]) - run_process([ - "openssl", - "pkcs12", - "-export", - "-name", alias, - "-in", cert_tmp, - "-inkey", key_tmp, - "-out", p12_tmp, - "-password", "pass:{}".format(passphrase) - ]) + run_process( + [ + "openssl", + "pkcs12", + "-export", + "-name", + alias, + "-in", + cert_tmp, + "-inkey", + key_tmp, + "-out", + p12_tmp, + "-password", + "pass:{}".format(passphrase), + ] + ) class OpenSSLExportPlugin(ExportPlugin): - title = 'OpenSSL' - slug = 'openssl-export' - description = 'Is a loose interface to openssl and support various formats' + title = "OpenSSL" + slug = "openssl-export" + description = "Is a loose interface to openssl and support various formats" version = openssl.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur" options = [ { - 'name': 'type', - 'type': 'select', - 'required': True, - 'available': ['PKCS12 (.p12)'], - 'helpMessage': 'Choose the format you wish to export', + "name": "type", + "type": "select", + "required": True, + "available": ["PKCS12 (.p12)"], + "helpMessage": "Choose the format you wish to export", }, { - 'name': 'passphrase', - 'type': 'str', - 'required': False, - 'helpMessage': 'If no passphrase is given one will be generated for you, we highly recommend this.', - 'validation': '' + "name": "passphrase", + "type": "str", + "required": False, + "helpMessage": "If no passphrase is given one will be generated for you, we highly recommend this.", + "validation": "", }, { - 'name': 'alias', - 'type': 'str', - 'required': False, - 'helpMessage': 'Enter the alias you wish to use for the keystore.', - } + "name": "alias", + "type": "str", + "required": False, + "helpMessage": "Enter the alias you wish to use for the keystore.", + }, ] def export(self, body, chain, key, options, **kwargs): @@ -119,20 +122,20 @@ class OpenSSLExportPlugin(ExportPlugin): :param options: :param kwargs: """ - if self.get_option('passphrase', options): - passphrase = self.get_option('passphrase', options) + if self.get_option("passphrase", options): + passphrase = self.get_option("passphrase", options) else: passphrase = get_psuedo_random_string() - if self.get_option('alias', options): - alias = self.get_option('alias', options) + if self.get_option("alias", options): + alias = self.get_option("alias", options) else: - alias = "blah" + alias = common_name(parse_certificate(body)) - type = self.get_option('type', options) + type = self.get_option("type", options) with mktemppath() as output_tmp: - if type == 'PKCS12 (.p12)': + if type == "PKCS12 (.p12)": if not key: raise Exception("Private Key required by {0}".format(type)) @@ -141,7 +144,7 @@ class OpenSSLExportPlugin(ExportPlugin): else: raise Exception("Unable to export, unsupported type: {0}".format(type)) - with open(output_tmp, 'rb') as f: + with open(output_tmp, "rb") as f: raw = f.read() return extension, passphrase, raw diff --git a/lemur/plugins/lemur_openssl/tests/test_openssl.py b/lemur/plugins/lemur_openssl/tests/test_openssl.py index e24033e8..c332f941 100644 --- a/lemur/plugins/lemur_openssl/tests/test_openssl.py +++ b/lemur/plugins/lemur_openssl/tests/test_openssl.py @@ -4,8 +4,12 @@ from lemur.tests.vectors import INTERNAL_PRIVATE_KEY_A_STR, INTERNAL_CERTIFICATE def test_export_certificate_to_pkcs12(app): from lemur.plugins.base import plugins - p = plugins.get('openssl-export') - options = [{'name': 'passphrase', 'value': 'test1234'}, {'name': 'type', 'value': 'PKCS12 (.p12)'}] + + p = plugins.get("openssl-export") + options = [ + {"name": "passphrase", "value": "test1234"}, + {"name": "type", "value": "PKCS12 (.p12)"}, + ] with pytest.raises(Exception): p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) diff --git a/lemur/plugins/lemur_sftp/__init__.py b/lemur/plugins/lemur_sftp/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_sftp/__init__.py +++ b/lemur/plugins/lemur_sftp/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_sftp/plugin.py b/lemur/plugins/lemur_sftp/plugin.py index d74effc5..66784048 100644 --- a/lemur/plugins/lemur_sftp/plugin.py +++ b/lemur/plugins/lemur_sftp/plugin.py @@ -27,107 +27,105 @@ from lemur.plugins.bases import DestinationPlugin class SFTPDestinationPlugin(DestinationPlugin): - title = 'SFTP' - slug = 'sftp-destination' - description = 'Allow the uploading of certificates to SFTP' + title = "SFTP" + slug = "sftp-destination" + description = "Allow the uploading of certificates to SFTP" version = lemur_sftp.VERSION - author = 'Dmitry Zykov' - author_url = 'https://github.com/DmitryZykov' + author = "Dmitry Zykov" + author_url = "https://github.com/DmitryZykov" options = [ { - 'name': 'host', - 'type': 'str', - 'required': True, - 'helpMessage': 'The SFTP host.' + "name": "host", + "type": "str", + "required": True, + "helpMessage": "The SFTP host.", }, { - 'name': 'port', - 'type': 'int', - 'required': True, - 'helpMessage': 'The SFTP port, default is 22.', - 'validation': '^(6553[0-5]|655[0-2][0-9]\d|65[0-4](\d){2}|6[0-4](\d){3}|[1-5](\d){4}|[1-9](\d){0,3})', - 'default': '22' + "name": "port", + "type": "int", + "required": True, + "helpMessage": "The SFTP port, default is 22.", + "validation": "^(6553[0-5]|655[0-2][0-9]\d|65[0-4](\d){2}|6[0-4](\d){3}|[1-5](\d){4}|[1-9](\d){0,3})", + "default": "22", }, { - 'name': 'user', - 'type': 'str', - 'required': True, - 'helpMessage': 'The SFTP user. Default is root.', - 'default': 'root' + "name": "user", + "type": "str", + "required": True, + "helpMessage": "The SFTP user. Default is root.", + "default": "root", }, { - 'name': 'password', - 'type': 'str', - 'required': False, - 'helpMessage': 'The SFTP password (optional when the private key is used).', - 'default': None + "name": "password", + "type": "str", + "required": False, + "helpMessage": "The SFTP password (optional when the private key is used).", + "default": None, }, { - 'name': 'privateKeyPath', - 'type': 'str', - 'required': False, - 'helpMessage': 'The path to the RSA private key on the Lemur server (optional).', - 'default': None + "name": "privateKeyPath", + "type": "str", + "required": False, + "helpMessage": "The path to the RSA private key on the Lemur server (optional).", + "default": None, }, { - 'name': 'privateKeyPass', - 'type': 'str', - 'required': False, - 'helpMessage': 'The password for the encrypted RSA private key (optional).', - 'default': None + "name": "privateKeyPass", + "type": "str", + "required": False, + "helpMessage": "The password for the encrypted RSA private key (optional).", + "default": None, }, { - 'name': 'destinationPath', - 'type': 'str', - 'required': True, - 'helpMessage': 'The SFTP path where certificates will be uploaded.', - 'default': '/etc/nginx/certs' + "name": "destinationPath", + "type": "str", + "required": True, + "helpMessage": "The SFTP path where certificates will be uploaded.", + "default": "/etc/nginx/certs", }, { - 'name': 'exportFormat', - 'required': True, - 'value': 'NGINX', - 'helpMessage': 'The export format for certificates.', - 'type': 'select', - 'available': [ - 'NGINX', - 'Apache' - ] - } + "name": "exportFormat", + "required": True, + "value": "NGINX", + "helpMessage": "The export format for certificates.", + "type": "select", + "available": ["NGINX", "Apache"], + }, ] def upload(self, name, body, private_key, cert_chain, options, **kwargs): - current_app.logger.debug('SFTP destination plugin is started') + current_app.logger.debug("SFTP destination plugin is started") cn = common_name(parse_certificate(body)) - host = self.get_option('host', options) - port = self.get_option('port', options) - user = self.get_option('user', options) - password = self.get_option('password', options) - ssh_priv_key = self.get_option('privateKeyPath', options) - ssh_priv_key_pass = self.get_option('privateKeyPass', options) - dst_path = self.get_option('destinationPath', options) - export_format = self.get_option('exportFormat', options) + host = self.get_option("host", options) + port = self.get_option("port", options) + user = self.get_option("user", options) + password = self.get_option("password", options) + ssh_priv_key = self.get_option("privateKeyPath", options) + ssh_priv_key_pass = self.get_option("privateKeyPass", options) + dst_path = self.get_option("destinationPath", options) + export_format = self.get_option("exportFormat", options) # prepare files for upload - files = {cn + '.key': private_key, - cn + '.pem': body} + files = {cn + ".key": private_key, cn + ".pem": body} if cert_chain: - if export_format == 'NGINX': + if export_format == "NGINX": # assemble body + chain in the single file - files[cn + '.pem'] += '\n' + cert_chain + files[cn + ".pem"] += "\n" + cert_chain - elif export_format == 'Apache': + elif export_format == "Apache": # store chain in the separate file - files[cn + '.ca.bundle.pem'] = cert_chain + files[cn + ".ca.bundle.pem"] = cert_chain # upload files try: - current_app.logger.debug('Connecting to {0}@{1}:{2}'.format(user, host, port)) + current_app.logger.debug( + "Connecting to {0}@{1}:{2}".format(user, host, port) + ) ssh = paramiko.SSHClient() # allow connection to the new unknown host @@ -135,14 +133,18 @@ class SFTPDestinationPlugin(DestinationPlugin): # open the ssh connection if password: - current_app.logger.debug('Using password') + current_app.logger.debug("Using password") ssh.connect(host, username=user, port=port, password=password) elif ssh_priv_key: - current_app.logger.debug('Using RSA private key') - pkey = paramiko.RSAKey.from_private_key_file(ssh_priv_key, ssh_priv_key_pass) + current_app.logger.debug("Using RSA private key") + pkey = paramiko.RSAKey.from_private_key_file( + ssh_priv_key, ssh_priv_key_pass + ) ssh.connect(host, username=user, port=port, pkey=pkey) else: - current_app.logger.error("No password or private key provided. Can't proceed") + current_app.logger.error( + "No password or private key provided. Can't proceed" + ) raise paramiko.ssh_exception.AuthenticationException # open the sftp session inside the ssh connection @@ -150,29 +152,42 @@ class SFTPDestinationPlugin(DestinationPlugin): # make sure that the destination path exist try: - current_app.logger.debug('Creating {0}'.format(dst_path)) + current_app.logger.debug("Creating {0}".format(dst_path)) sftp.mkdir(dst_path) except IOError: - current_app.logger.debug('{0} already exist, resuming'.format(dst_path)) + current_app.logger.debug("{0} already exist, resuming".format(dst_path)) try: - dst_path_cn = dst_path + '/' + cn - current_app.logger.debug('Creating {0}'.format(dst_path_cn)) + dst_path_cn = dst_path + "/" + cn + current_app.logger.debug("Creating {0}".format(dst_path_cn)) sftp.mkdir(dst_path_cn) except IOError: - current_app.logger.debug('{0} already exist, resuming'.format(dst_path_cn)) + current_app.logger.debug( + "{0} already exist, resuming".format(dst_path_cn) + ) # upload certificate files to the sftp destination for filename, data in files.items(): - current_app.logger.debug('Uploading {0} to {1}'.format(filename, dst_path_cn)) - with sftp.open(dst_path_cn + '/' + filename, 'w') as f: - f.write(data) + current_app.logger.debug( + "Uploading {0} to {1}".format(filename, dst_path_cn) + ) + try: + with sftp.open(dst_path_cn + "/" + filename, "w") as f: + f.write(data) + except (PermissionError) as permerror: + if permerror.errno == 13: + current_app.logger.debug( + "Uploading {0} to {1} returned Permission Denied Error, making file writable and retrying".format(filename, dst_path_cn) + ) + sftp.chmod(dst_path_cn + "/" + filename, 0o600) + with sftp.open(dst_path_cn + "/" + filename, "w") as f: + f.write(data) # read only for owner, -r-------- - sftp.chmod(dst_path_cn + '/' + filename, 0o400) + sftp.chmod(dst_path_cn + "/" + filename, 0o400) ssh.close() except Exception as e: - current_app.logger.error('ERROR in {0}: {1}'.format(e.__class__, e)) + current_app.logger.error("ERROR in {0}: {1}".format(e.__class__, e)) try: ssh.close() except BaseException: diff --git a/lemur/plugins/lemur_slack/__init__.py b/lemur/plugins/lemur_slack/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_slack/__init__.py +++ b/lemur/plugins/lemur_slack/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_slack/plugin.py b/lemur/plugins/lemur_slack/plugin.py index a986aa9a..7569d295 100644 --- a/lemur/plugins/lemur_slack/plugin.py +++ b/lemur/plugins/lemur_slack/plugin.py @@ -17,102 +17,101 @@ import requests def create_certificate_url(name): - return 'https://{hostname}/#/certificates/{name}'.format( - hostname=current_app.config.get('LEMUR_HOSTNAME'), - name=name + return "https://{hostname}/#/certificates/{name}".format( + hostname=current_app.config.get("LEMUR_HOSTNAME"), name=name ) def create_expiration_attachments(certificates): attachments = [] for certificate in certificates: - attachments.append({ - 'title': certificate['name'], - 'title_link': create_certificate_url(certificate['name']), - 'color': 'danger', - 'fallback': '', - 'fields': [ - { - 'title': 'Owner', - 'value': certificate['owner'], - 'short': True - }, - { - 'title': 'Expires', - 'value': arrow.get(certificate['validityEnd']).format('dddd, MMMM D, YYYY'), - 'short': True - }, - { - 'title': 'Endpoints Detected', - 'value': len(certificate['endpoints']), - 'short': True - } - ], - 'text': '', - 'mrkdwn_in': ['text'] - }) + attachments.append( + { + "title": certificate["name"], + "title_link": create_certificate_url(certificate["name"]), + "color": "danger", + "fallback": "", + "fields": [ + {"title": "Owner", "value": certificate["owner"], "short": True}, + { + "title": "Expires", + "value": arrow.get(certificate["validityEnd"]).format( + "dddd, MMMM D, YYYY" + ), + "short": True, + }, + { + "title": "Endpoints Detected", + "value": len(certificate["endpoints"]), + "short": True, + }, + ], + "text": "", + "mrkdwn_in": ["text"], + } + ) return attachments def create_rotation_attachments(certificate): return { - 'title': certificate['name'], - 'title_link': create_certificate_url(certificate['name']), - 'fields': [ + "title": certificate["name"], + "title_link": create_certificate_url(certificate["name"]), + "fields": [ { + {"title": "Owner", "value": certificate["owner"], "short": True}, { - 'title': 'Owner', - 'value': certificate['owner'], - 'short': True + "title": "Expires", + "value": arrow.get(certificate["validityEnd"]).format( + "dddd, MMMM D, YYYY" + ), + "short": True, }, { - 'title': 'Expires', - 'value': arrow.get(certificate['validityEnd']).format('dddd, MMMM D, YYYY'), - 'short': True + "title": "Replaced By", + "value": len(certificate["replaced"][0]["name"]), + "short": True, }, { - 'title': 'Replaced By', - 'value': len(certificate['replaced'][0]['name']), - 'short': True + "title": "Endpoints Rotated", + "value": len(certificate["endpoints"]), + "short": True, }, - { - 'title': 'Endpoints Rotated', - 'value': len(certificate['endpoints']), - 'short': True - } } - ] + ], } class SlackNotificationPlugin(ExpirationNotificationPlugin): - title = 'Slack' - slug = 'slack-notification' - description = 'Sends notifications to Slack' + title = "Slack" + slug = "slack-notification" + description = "Sends notifications to Slack" version = slack.VERSION - author = 'Harm Weites' - author_url = 'https://github.com/netflix/lemur' + author = "Harm Weites" + author_url = "https://github.com/netflix/lemur" additional_options = [ { - 'name': 'webhook', - 'type': 'str', - 'required': True, - 'validation': '^https:\/\/hooks\.slack\.com\/services\/.+$', - 'helpMessage': 'The url Slack told you to use for this integration', - }, { - 'name': 'username', - 'type': 'str', - 'validation': '^.+$', - 'helpMessage': 'The great storyteller', - 'default': 'Lemur' - }, { - 'name': 'recipients', - 'type': 'str', - 'required': True, - 'validation': '^(@|#).+$', - 'helpMessage': 'Where to send to, either @username or #channel', + "name": "webhook", + "type": "str", + "required": True, + "validation": "^https:\/\/hooks\.slack\.com\/services\/.+$", + "helpMessage": "The url Slack told you to use for this integration", + }, + { + "name": "username", + "type": "str", + "validation": "^.+$", + "helpMessage": "The great storyteller", + "default": "Lemur", + }, + { + "name": "recipients", + "type": "str", + "required": True, + "validation": "^(@|#).+$", + "helpMessage": "Where to send to, either @username or #channel", }, ] @@ -122,25 +121,27 @@ class SlackNotificationPlugin(ExpirationNotificationPlugin): `lemur notify` """ attachments = None - if notification_type == 'expiration': + if notification_type == "expiration": attachments = create_expiration_attachments(message) - elif notification_type == 'rotation': + elif notification_type == "rotation": attachments = create_rotation_attachments(message) if not attachments: - raise Exception('Unable to create message attachments') + raise Exception("Unable to create message attachments") body = { - 'text': 'Lemur {0} Notification'.format(notification_type.capitalize()), - 'attachments': attachments, - 'channel': self.get_option('recipients', options), - 'username': self.get_option('username', options) + "text": "Lemur {0} Notification".format(notification_type.capitalize()), + "attachments": attachments, + "channel": self.get_option("recipients", options), + "username": self.get_option("username", options), } - r = requests.post(self.get_option('webhook', options), json.dumps(body)) + r = requests.post(self.get_option("webhook", options), json.dumps(body)) if r.status_code not in [200]: - raise Exception('Failed to send message') + raise Exception("Failed to send message") - current_app.logger.error("Slack response: {0} Message Body: {1}".format(r.status_code, body)) + current_app.logger.error( + "Slack response: {0} Message Body: {1}".format(r.status_code, body) + ) diff --git a/lemur/plugins/lemur_slack/tests/test_slack.py b/lemur/plugins/lemur_slack/tests/test_slack.py index 701f69d9..86add25f 100644 --- a/lemur/plugins/lemur_slack/tests/test_slack.py +++ b/lemur/plugins/lemur_slack/tests/test_slack.py @@ -1,33 +1,23 @@ - - def test_formatting(certificate): from lemur.plugins.lemur_slack.plugin import create_expiration_attachments from lemur.certificates.schemas import certificate_notification_output_schema + data = [certificate_notification_output_schema.dump(certificate).data] attachment = { - 'title': certificate.name, - 'color': 'danger', - 'fields': [ - { - 'short': True, - 'value': 'joe@example.com', - 'title': 'Owner' - }, - { - 'short': True, - 'value': u'Tuesday, December 31, 2047', - 'title': 'Expires' - }, { - 'short': True, - 'value': 0, - 'title': 'Endpoints Detected' - } + "title": certificate.name, + "color": "danger", + "fields": [ + {"short": True, "value": "joe@example.com", "title": "Owner"}, + {"short": True, "value": u"Tuesday, December 31, 2047", "title": "Expires"}, + {"short": True, "value": 0, "title": "Endpoints Detected"}, ], - 'title_link': 'https://lemur.example.com/#/certificates/{name}'.format(name=certificate.name), - 'mrkdwn_in': ['text'], - 'text': '', - 'fallback': '' + "title_link": "https://lemur.example.com/#/certificates/{name}".format( + name=certificate.name + ), + "mrkdwn_in": ["text"], + "text": "", + "fallback": "", } assert attachment == create_expiration_attachments(data)[0] diff --git a/lemur/plugins/lemur_statsd/lemur_statsd/__init__.py b/lemur/plugins/lemur_statsd/lemur_statsd/__init__.py index 3a751848..b4d708ce 100644 --- a/lemur/plugins/lemur_statsd/lemur_statsd/__init__.py +++ b/lemur/plugins/lemur_statsd/lemur_statsd/__init__.py @@ -1,4 +1,4 @@ try: - VERSION = __import__('pkg_resources').get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'Unknown' + VERSION = "Unknown" diff --git a/lemur/plugins/lemur_statsd/lemur_statsd/plugin.py b/lemur/plugins/lemur_statsd/lemur_statsd/plugin.py index a6a87c66..293b4634 100644 --- a/lemur/plugins/lemur_statsd/lemur_statsd/plugin.py +++ b/lemur/plugins/lemur_statsd/lemur_statsd/plugin.py @@ -6,40 +6,44 @@ from datadog import DogStatsd class StatsdMetricPlugin(MetricPlugin): - title = 'Statsd' - slug = 'statsd-metrics' - description = 'Adds support for sending metrics to Statsd' + title = "Statsd" + slug = "statsd-metrics" + description = "Adds support for sending metrics to Statsd" version = plug.VERSION def __init__(self): - host = current_app.config.get('STATSD_HOST') - port = current_app.config.get('STATSD_PORT') - prefix = current_app.config.get('STATSD_PREFIX') + host = current_app.config.get("STATSD_HOST") + port = current_app.config.get("STATSD_PORT") + prefix = current_app.config.get("STATSD_PREFIX") self.statsd = DogStatsd(host=host, port=port, namespace=prefix) - def submit(self, metric_name, metric_type, metric_value, metric_tags=None, options=None): - valid_types = ['COUNTER', 'GAUGE', 'TIMER'] + def submit( + self, metric_name, metric_type, metric_value, metric_tags=None, options=None + ): + valid_types = ["COUNTER", "GAUGE", "TIMER"] tags = [] if metric_type.upper() not in valid_types: raise Exception( "Invalid Metric Type for Statsd, '{metric}' choose from: {options}".format( - metric=metric_type, options=','.join(valid_types) + metric=metric_type, options=",".join(valid_types) ) ) if metric_tags: if not isinstance(metric_tags, dict): - raise Exception("Invalid Metric Tags for Statsd: Tags must be in dict format") + raise Exception( + "Invalid Metric Tags for Statsd: Tags must be in dict format" + ) else: tags = map(lambda e: "{0}:{1}".format(*e), metric_tags.items()) - if metric_type.upper() == 'COUNTER': + if metric_type.upper() == "COUNTER": self.statsd.increment(metric_name, metric_value, tags) - elif metric_type.upper() == 'GAUGE': + elif metric_type.upper() == "GAUGE": self.statsd.gauge(metric_name, metric_value, tags) - elif metric_type.upper() == 'TIMER': + elif metric_type.upper() == "TIMER": self.statsd.timing(metric_name, metric_value, tags) return diff --git a/lemur/plugins/lemur_statsd/setup.py b/lemur/plugins/lemur_statsd/setup.py index 6c4c2dd6..9b3c5f52 100644 --- a/lemur/plugins/lemur_statsd/setup.py +++ b/lemur/plugins/lemur_statsd/setup.py @@ -2,23 +2,16 @@ from __future__ import absolute_import from setuptools import setup, find_packages -install_requires = [ - 'lemur', - 'datadog' -] +install_requires = ["lemur", "datadog"] setup( - name='lemur_statsd', - version='1.0.0', - author='Cloudflare Security Engineering', - author_email='', + name="lemur_statsd", + version="1.0.0", + author="Cloudflare Security Engineering", + author_email="", include_package_data=True, packages=find_packages(), zip_safe=False, install_requires=install_requires, - entry_points={ - 'lemur.plugins': [ - 'statsd = lemur_statsd.plugin:StatsdMetricPlugin', - ] - } + entry_points={"lemur.plugins": ["statsd = lemur_statsd.plugin:StatsdMetricPlugin"]}, ) diff --git a/lemur/plugins/lemur_vault_dest/__init__.py b/lemur/plugins/lemur_vault_dest/__init__.py new file mode 100644 index 00000000..f8afd7e3 --- /dev/null +++ b/lemur/plugins/lemur_vault_dest/__init__.py @@ -0,0 +1,4 @@ +try: + VERSION = __import__("pkg_resources").get_distribution(__name__).version +except Exception as e: + VERSION = "unknown" diff --git a/lemur/plugins/lemur_vault_dest/plugin.py b/lemur/plugins/lemur_vault_dest/plugin.py new file mode 100755 index 00000000..e1715592 --- /dev/null +++ b/lemur/plugins/lemur_vault_dest/plugin.py @@ -0,0 +1,323 @@ +""" +.. module: lemur.plugins.lemur_vault_dest.plugin + :platform: Unix + :copyright: (c) 2019 + :license: Apache, see LICENCE for more details. + + Plugin for uploading certificates and private key as secret to hashi vault + that can be pulled down by end point nodes. + +.. moduleauthor:: Christopher Jolley +""" +import os +import re +import hvac +from flask import current_app + +from lemur.common.defaults import common_name +from lemur.common.utils import parse_certificate +from lemur.plugins.bases import DestinationPlugin +from lemur.plugins.bases import SourcePlugin + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend + + +class VaultSourcePlugin(SourcePlugin): + """ Class for importing certificates from Hashicorp Vault""" + + title = "Vault" + slug = "vault-source" + description = "Discovers all certificates in a given path" + + author = "Christopher Jolley" + author_url = "https://github.com/alwaysjolley/lemur" + + options = [ + { + "name": "vaultUrl", + "type": "str", + "required": True, + "validation": "^https?://[a-zA-Z0-9.:-]+$", + "helpMessage": "Valid URL to Hashi Vault instance", + }, + { + "name": "vaultKvApiVersion", + "type": "select", + "value": "2", + "available": ["1", "2"], + "required": True, + "helpMessage": "Version of the Vault KV API to use", + }, + { + "name": "vaultAuthTokenFile", + "type": "str", + "required": True, + "validation": "(/[^/]+)+", + "helpMessage": "Must be a valid file path!", + }, + { + "name": "vaultMount", + "type": "str", + "required": True, + "validation": r"^\S+$", + "helpMessage": "Must be a valid Vault secrets mount name!", + }, + { + "name": "vaultPath", + "type": "str", + "required": True, + "validation": "^([a-zA-Z0-9._-]+/?)+$", + "helpMessage": "Must be a valid Vault secrets path", + }, + { + "name": "objectName", + "type": "str", + "required": True, + "validation": "[0-9a-zA-Z.:_-]+", + "helpMessage": "Object Name to search", + }, + ] + + def get_certificates(self, options, **kwargs): + """Pull certificates from objects in Hashicorp Vault""" + data = [] + cert = [] + body = "" + url = self.get_option("vaultUrl", options) + token_file = self.get_option("vaultAuthTokenFile", options) + mount = self.get_option("vaultMount", options) + path = self.get_option("vaultPath", options) + obj_name = self.get_option("objectName", options) + api_version = self.get_option("vaultKvApiVersion", options) + cert_filter = "-----BEGIN CERTIFICATE-----" + cert_delimiter = "-----END CERTIFICATE-----" + + with open(token_file, "r") as tfile: + token = tfile.readline().rstrip("\n") + + client = hvac.Client(url=url, token=token) + client.secrets.kv.default_kv_version = api_version + + path = "{0}/{1}".format(path, obj_name) + + secret = get_secret(client, mount, path) + for cname in secret["data"]: + if "crt" in secret["data"][cname]: + cert = secret["data"][cname]["crt"].split(cert_delimiter + "\n") + elif "pem" in secret["data"][cname]: + cert = secret["data"][cname]["pem"].split(cert_delimiter + "\n") + else: + for key in secret["data"][cname]: + if secret["data"][cname][key].startswith(cert_filter): + cert = secret["data"][cname][key].split(cert_delimiter + "\n") + break + body = cert[0] + cert_delimiter + if "chain" in secret["data"][cname]: + chain = secret["data"][cname]["chain"] + elif len(cert) > 1: + if cert[1].startswith(cert_filter): + chain = cert[1] + cert_delimiter + else: + chain = None + else: + chain = None + data.append({"body": body, "chain": chain, "name": cname}) + return [ + dict(body=c["body"], chain=c.get("chain"), name=c["name"]) for c in data + ] + + def get_endpoints(self, options, **kwargs): + """ Not implemented yet """ + endpoints = [] + return endpoints + + +class VaultDestinationPlugin(DestinationPlugin): + """Hashicorp Vault Destination plugin for Lemur""" + + title = "Vault" + slug = "hashi-vault-destination" + description = "Allow the uploading of certificates to Hashi Vault as secret" + + author = "Christopher Jolley" + author_url = "https://github.com/alwaysjolley/lemur" + + options = [ + { + "name": "vaultUrl", + "type": "str", + "required": True, + "validation": "^https?://[a-zA-Z0-9.:-]+$", + "helpMessage": "Valid URL to Hashi Vault instance", + }, + { + "name": "vaultKvApiVersion", + "type": "select", + "value": "2", + "available": ["1", "2"], + "required": True, + "helpMessage": "Version of the Vault KV API to use", + }, + { + "name": "vaultAuthTokenFile", + "type": "str", + "required": True, + "validation": "(/[^/]+)+", + "helpMessage": "Must be a valid file path!", + }, + { + "name": "vaultMount", + "type": "str", + "required": True, + "validation": r"^\S+$", + "helpMessage": "Must be a valid Vault secrets mount name!", + }, + { + "name": "vaultPath", + "type": "str", + "required": True, + "validation": "^([a-zA-Z0-9._-]+/?)+$", + "helpMessage": "Must be a valid Vault secrets path", + }, + { + "name": "objectName", + "type": "str", + "required": False, + "validation": "[0-9a-zA-Z.:_-]+", + "helpMessage": "Name to bundle certs under, if blank use cn", + }, + { + "name": "bundleChain", + "type": "select", + "value": "cert only", + "available": ["Nginx", "Apache", "PEM", "no chain"], + "required": True, + "helpMessage": "Bundle the chain into the certificate", + }, + { + "name": "sanFilter", + "type": "str", + "value": ".*", + "required": False, + "validation": ".*", + "helpMessage": "Valid regex filter", + }, + ] + + def __init__(self, *args, **kwargs): + super(VaultDestinationPlugin, self).__init__(*args, **kwargs) + + def upload(self, name, body, private_key, cert_chain, options, **kwargs): + """ + Upload certificate and private key + + :param private_key: + :param cert_chain: + :return: + """ + cname = common_name(parse_certificate(body)) + + url = self.get_option("vaultUrl", options) + token_file = self.get_option("vaultAuthTokenFile", options) + mount = self.get_option("vaultMount", options) + path = self.get_option("vaultPath", options) + bundle = self.get_option("bundleChain", options) + obj_name = self.get_option("objectName", options) + api_version = self.get_option("vaultKvApiVersion", options) + san_filter = self.get_option("sanFilter", options) + + san_list = get_san_list(body) + if san_filter: + for san in san_list: + try: + if not re.match(san_filter, san, flags=re.IGNORECASE): + current_app.logger.exception( + "Exception uploading secret to vault: invalid SAN: {}".format( + san + ), + exc_info=True, + ) + os._exit(1) + except re.error: + current_app.logger.exception( + "Exception compiling regex filter: invalid filter", + exc_info=True, + ) + + with open(token_file, "r") as tfile: + token = tfile.readline().rstrip("\n") + + client = hvac.Client(url=url, token=token) + client.secrets.kv.default_kv_version = api_version + + if obj_name: + path = "{0}/{1}".format(path, obj_name) + else: + path = "{0}/{1}".format(path, cname) + + secret = get_secret(client, mount, path) + secret["data"][cname] = {} + + if not cert_chain: + chain = '' + else: + chain = cert_chain + + if bundle == "Nginx": + secret["data"][cname]["crt"] = "{0}\n{1}".format(body, chain) + secret["data"][cname]["key"] = private_key + elif bundle == "Apache": + secret["data"][cname]["crt"] = body + secret["data"][cname]["chain"] = chain + secret["data"][cname]["key"] = private_key + elif bundle == "PEM": + secret["data"][cname]["pem"] = "{0}\n{1}\n{2}".format( + body, chain, private_key + ) + else: + secret["data"][cname]["crt"] = body + secret["data"][cname]["key"] = private_key + if isinstance(san_list, list): + secret["data"][cname]["san"] = san_list + try: + client.secrets.kv.create_or_update_secret( + path=path, mount_point=mount, secret=secret["data"] + ) + except ConnectionError as err: + current_app.logger.exception( + "Exception uploading secret to vault: {0}".format(err), exc_info=True + ) + + +def get_san_list(body): + """ parse certificate for SAN names and return list, return empty list on error """ + san_list = [] + try: + byte_body = body.encode("utf-8") + cert = x509.load_pem_x509_certificate(byte_body, default_backend()) + ext = cert.extensions.get_extension_for_oid( + x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ) + san_list = ext.value.get_values_for_type(x509.DNSName) + except x509.extensions.ExtensionNotFound: + pass + finally: + return san_list + + +def get_secret(client, mount, path): + """ retreive existing data from mount path and return dictionary """ + result = {"data": {}} + try: + if client.secrets.kv.default_kv_version == "1": + result = client.secrets.kv.v1.read_secret(path=path, mount_point=mount) + else: + result = client.secrets.kv.v2.read_secret_version( + path=path, mount_point=mount + ) + result = result['data'] + except ConnectionError: + pass + finally: + return result diff --git a/lemur/plugins/lemur_vault_dest/tests/conftest.py b/lemur/plugins/lemur_vault_dest/tests/conftest.py new file mode 100644 index 00000000..0e1cd89f --- /dev/null +++ b/lemur/plugins/lemur_vault_dest/tests/conftest.py @@ -0,0 +1 @@ +from lemur.tests.conftest import * # noqa diff --git a/lemur/plugins/lemur_verisign/__init__.py b/lemur/plugins/lemur_verisign/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_verisign/__init__.py +++ b/lemur/plugins/lemur_verisign/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_verisign/plugin.py b/lemur/plugins/lemur_verisign/plugin.py index 3e672a43..7bf517b7 100644 --- a/lemur/plugins/lemur_verisign/plugin.py +++ b/lemur/plugins/lemur_verisign/plugin.py @@ -14,7 +14,7 @@ from cryptography import x509 from flask import current_app from lemur.common.utils import get_psuedo_random_string -from lemur.extensions import metrics +from lemur.extensions import metrics, sentry from lemur.plugins import lemur_verisign as verisign from lemur.plugins.bases import IssuerPlugin, SourcePlugin @@ -58,7 +58,7 @@ VERISIGN_ERRORS = { "0x300a": "Domain/SubjectAltName Mismatched -- make sure that the SANs have the proper domain suffix", "0x950e": "Invalid Common Name -- make sure the CN has a proper domain suffix", "0xa00e": "Pending. (Insufficient number of tokens.)", - "0x8134": "Pending. (Domain failed CAA validation.)" + "0x8134": "Pending. (Domain failed CAA validation.)", } @@ -71,7 +71,7 @@ def log_status_code(r, *args, **kwargs): :param kwargs: :return: """ - metrics.send('symantec_status_code_{}'.format(r.status_code), 'counter', 1) + metrics.send("symantec_status_code_{}".format(r.status_code), "counter", 1) def get_additional_names(options): @@ -83,8 +83,8 @@ def get_additional_names(options): """ names = [] # add SANs if present - if options.get('extensions'): - for san in options['extensions']['sub_alt_names']: + if options.get("extensions"): + for san in options["extensions"]["sub_alt_names"]: if isinstance(san, x509.DNSName): names.append(san.value) return names @@ -99,28 +99,41 @@ def process_options(options): :return: dict or valid verisign options """ data = { - 'challenge': get_psuedo_random_string(), - 'serverType': 'Apache', - 'certProductType': 'Server', - 'firstName': current_app.config.get("VERISIGN_FIRST_NAME"), - 'lastName': current_app.config.get("VERISIGN_LAST_NAME"), - 'signatureAlgorithm': 'sha256WithRSAEncryption', - 'email': current_app.config.get("VERISIGN_EMAIL"), - 'ctLogOption': current_app.config.get("VERISIGN_CS_LOG_OPTION", "public"), + "challenge": get_psuedo_random_string(), + "serverType": "Apache", + "certProductType": "Server", + "firstName": current_app.config.get("VERISIGN_FIRST_NAME"), + "lastName": current_app.config.get("VERISIGN_LAST_NAME"), + "signatureAlgorithm": "sha256WithRSAEncryption", + "email": current_app.config.get("VERISIGN_EMAIL"), + "ctLogOption": current_app.config.get("VERISIGN_CS_LOG_OPTION", "public"), } - data['subject_alt_names'] = ",".join(get_additional_names(options)) + data["subject_alt_names"] = ",".join(get_additional_names(options)) - if options.get('validity_end'): - period = get_default_issuance(options) - data['specificEndDate'] = options['validity_end'].format("MM/DD/YYYY") - data['validityPeriod'] = period + if options.get("validity_end") > arrow.utcnow().shift(years=2): + raise Exception( + "Verisign issued certificates cannot exceed two years in validity" + ) - elif options.get('validity_years'): - if options['validity_years'] in [1, 2]: - data['validityPeriod'] = str(options['validity_years']) + 'Y' + if options.get("validity_end"): + # VeriSign (Symantec) only accepts strictly smaller than 2 year end date + if options.get("validity_end") < arrow.utcnow().shift(years=2, days=-1): + period = get_default_issuance(options) + data["specificEndDate"] = options["validity_end"].format("MM/DD/YYYY") + data["validityPeriod"] = period else: - raise Exception("Verisign issued certificates cannot exceed two years in validity") + # allowing Symantec website setting the end date, given the validity period + data["validityPeriod"] = str(get_default_issuance(options)) + options.pop("validity_end", None) + + elif options.get("validity_years"): + if options["validity_years"] in [1, 2]: + data["validityPeriod"] = str(options["validity_years"]) + "Y" + else: + raise Exception( + "Verisign issued certificates cannot exceed two years in validity" + ) return data @@ -134,12 +147,14 @@ def get_default_issuance(options): """ now = arrow.utcnow() - if options['validity_end'] < now.replace(years=+1): - validity_period = '1Y' - elif options['validity_end'] < now.replace(years=+2): - validity_period = '2Y' + if options["validity_end"] < now.shift(years=+1): + validity_period = "1Y" + elif options["validity_end"] < now.shift(years=+2): + validity_period = "2Y" else: - raise Exception("Verisign issued certificates cannot exceed two years in validity") + raise Exception( + "Verisign issued certificates cannot exceed two years in validity" + ) return validity_period @@ -152,27 +167,27 @@ def handle_response(content): """ d = xmltodict.parse(content) global VERISIGN_ERRORS - if d.get('Error'): - status_code = d['Error']['StatusCode'] - elif d.get('Response'): - status_code = d['Response']['StatusCode'] + if d.get("Error"): + status_code = d["Error"]["StatusCode"] + elif d.get("Response"): + status_code = d["Response"]["StatusCode"] if status_code in VERISIGN_ERRORS.keys(): raise Exception(VERISIGN_ERRORS[status_code]) return d class VerisignIssuerPlugin(IssuerPlugin): - title = 'Verisign' - slug = 'verisign-issuer' - description = 'Enables the creation of certificates by the VICE2.0 verisign API.' + title = "Verisign" + slug = "verisign-issuer" + description = "Enables the creation of certificates by the VICE2.0 verisign API." version = verisign.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): self.session = requests.Session() - self.session.cert = current_app.config.get('VERISIGN_PEM_PATH') + self.session.cert = current_app.config.get("VERISIGN_PEM_PATH") self.session.hooks = dict(response=log_status_code) super(VerisignIssuerPlugin, self).__init__(*args, **kwargs) @@ -184,17 +199,31 @@ class VerisignIssuerPlugin(IssuerPlugin): :param issuer_options: :return: :raise Exception: """ - url = current_app.config.get("VERISIGN_URL") + '/rest/services/enroll' + url = current_app.config.get("VERISIGN_URL") + "/rest/services/enroll" data = process_options(issuer_options) - data['csr'] = csr + data["csr"] = csr - current_app.logger.info("Requesting a new verisign certificate: {0}".format(data)) + current_app.logger.info( + "Requesting a new verisign certificate: {0}".format(data) + ) response = self.session.post(url, data=data) - cert = handle_response(response.content)['Response']['Certificate'] + try: + cert = handle_response(response.content)["Response"]["Certificate"] + except KeyError: + metrics.send( + "verisign_create_certificate_error", + "counter", + 1, + metric_tags={"common_name": issuer_options.get("common_name", "")}, + ) + sentry.captureException( + extra={"common_name": issuer_options.get("common_name", "")} + ) + raise Exception(f"Error with Verisign: {response.content}") # TODO add external id - return cert, current_app.config.get('VERISIGN_INTERMEDIATE'), None + return cert, current_app.config.get("VERISIGN_INTERMEDIATE"), None @staticmethod def create_authority(options): @@ -205,8 +234,8 @@ class VerisignIssuerPlugin(IssuerPlugin): :param options: :return: """ - role = {'username': '', 'password': '', 'name': 'verisign'} - return current_app.config.get('VERISIGN_ROOT'), "", [role] + role = {"username": "", "password": "", "name": "verisign"} + return current_app.config.get("VERISIGN_ROOT"), "", [role] def get_available_units(self): """ @@ -215,9 +244,11 @@ class VerisignIssuerPlugin(IssuerPlugin): :return: """ - url = current_app.config.get("VERISIGN_URL") + '/rest/services/getTokens' - response = self.session.post(url, headers={'content-type': 'application/x-www-form-urlencoded'}) - return handle_response(response.content)['Response']['Order'] + url = current_app.config.get("VERISIGN_URL") + "/rest/services/getTokens" + response = self.session.post( + url, headers={"content-type": "application/x-www-form-urlencoded"} + ) + return handle_response(response.content)["Response"]["Order"] def clear_pending_certificates(self): """ @@ -225,52 +256,54 @@ class VerisignIssuerPlugin(IssuerPlugin): :return: """ - url = current_app.config.get('VERISIGN_URL') + '/reportingws' + url = current_app.config.get("VERISIGN_URL") + "/reportingws" end = arrow.now() - start = end.replace(days=-7) + start = end.shift(days=-7) data = { - 'reportType': 'detail', - 'certProductType': 'Server', - 'certStatus': 'Pending', - 'startDate': start.format("MM/DD/YYYY"), - 'endDate': end.format("MM/DD/YYYY") + "reportType": "detail", + "certProductType": "Server", + "certStatus": "Pending", + "startDate": start.format("MM/DD/YYYY"), + "endDate": end.format("MM/DD/YYYY"), } response = self.session.post(url, data=data) - url = current_app.config.get('VERISIGN_URL') + '/rest/services/reject' - for order_id in response.json()['orderNumber']: - response = self.session.get(url, params={'transaction_id': order_id}) + url = current_app.config.get("VERISIGN_URL") + "/rest/services/reject" + for order_id in response.json()["orderNumber"]: + response = self.session.get(url, params={"transaction_id": order_id}) if response.status_code == 200: print("Rejecting certificate. TransactionId: {}".format(order_id)) class VerisignSourcePlugin(SourcePlugin): - title = 'Verisign' - slug = 'verisign-source' - description = 'Allows for the polling of issued certificates from the VICE2.0 verisign API.' + title = "Verisign" + slug = "verisign-source" + description = ( + "Allows for the polling of issued certificates from the VICE2.0 verisign API." + ) version = verisign.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): self.session = requests.Session() - self.session.cert = current_app.config.get('VERISIGN_PEM_PATH') + self.session.cert = current_app.config.get("VERISIGN_PEM_PATH") super(VerisignSourcePlugin, self).__init__(*args, **kwargs) def get_certificates(self): - url = current_app.config.get('VERISIGN_URL') + '/reportingws' + url = current_app.config.get("VERISIGN_URL") + "/reportingws" end = arrow.now() - start = end.replace(years=-5) + start = end.shift(years=-5) data = { - 'reportType': 'detail', - 'startDate': start.format("MM/DD/YYYY"), - 'endDate': end.format("MM/DD/YYYY"), - 'structuredRecord': 'Y', - 'certStatus': 'Valid', + "reportType": "detail", + "startDate": start.format("MM/DD/YYYY"), + "endDate": end.format("MM/DD/YYYY"), + "structuredRecord": "Y", + "certStatus": "Valid", } current_app.logger.debug(data) response = self.session.post(url, data=data) diff --git a/lemur/plugins/lemur_verisign/tests/test_verisign.py b/lemur/plugins/lemur_verisign/tests/test_verisign.py index 8c4f1d81..42c528e8 100644 --- a/lemur/plugins/lemur_verisign/tests/test_verisign.py +++ b/lemur/plugins/lemur_verisign/tests/test_verisign.py @@ -1,4 +1,4 @@ - def test_get_certificates(app): from lemur.plugins.base import plugins - p = plugins.get('verisign-issuer') + + p = plugins.get("verisign-issuer") diff --git a/lemur/plugins/utils.py b/lemur/plugins/utils.py index a1914dd7..19655519 100644 --- a/lemur/plugins/utils.py +++ b/lemur/plugins/utils.py @@ -17,5 +17,15 @@ def get_plugin_option(name, options): :return: """ for o in options: - if o.get('name') == name: - return o['value'] + if o.get("name") == name: + return o.get("value", o.get("default")) + + +def set_plugin_option(name, value, options): + """ + Set value for option name for options dict. + :param options: + """ + for o in options: + if o.get("name") == name: + o.update({"value": value}) diff --git a/lemur/plugins/views.py b/lemur/plugins/views.py index dbdfccab..605b234a 100644 --- a/lemur/plugins/views.py +++ b/lemur/plugins/views.py @@ -15,12 +15,13 @@ from lemur.schemas import plugins_output_schema, plugin_output_schema from lemur.common.schema import validate_schema from lemur.plugins.base import plugins -mod = Blueprint('plugins', __name__) +mod = Blueprint("plugins", __name__) api = Api(mod) class PluginsList(AuthenticatedResource): """ Defines the 'plugins' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(PluginsList, self).__init__() @@ -69,17 +70,18 @@ class PluginsList(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - self.reqparse.add_argument('type', type=str, location='args') + self.reqparse.add_argument("type", type=str, location="args") args = self.reqparse.parse_args() - if args['type']: - return list(plugins.all(plugin_type=args['type'])) + if args["type"]: + return list(plugins.all(plugin_type=args["type"])) return list(plugins.all()) class Plugins(AuthenticatedResource): """ Defines the 'plugins' endpoint """ + def __init__(self): super(Plugins, self).__init__() @@ -118,5 +120,5 @@ class Plugins(AuthenticatedResource): return plugins.get(name) -api.add_resource(PluginsList, '/plugins', endpoint='plugins') -api.add_resource(Plugins, '/plugins/', endpoint='pluginName') +api.add_resource(PluginsList, "/plugins", endpoint="plugins") +api.add_resource(Plugins, "/plugins/", endpoint="pluginName") diff --git a/lemur/policies/cli.py b/lemur/policies/cli.py index 725c1583..317f3414 100644 --- a/lemur/policies/cli.py +++ b/lemur/policies/cli.py @@ -12,8 +12,8 @@ from lemur.policies import service as policy_service manager = Manager(usage="Handles all policy related tasks.") -@manager.option('-d', '--days', dest='days', help='Number of days before expiration.') -@manager.option('-n', '--name', dest='name', help='Policy name.') +@manager.option("-d", "--days", dest="days", help="Number of days before expiration.") +@manager.option("-n", "--name", dest="name", help="Policy name.") def create(days, name): """ Create a new certificate rotation policy diff --git a/lemur/policies/models.py b/lemur/policies/models.py index 2329a347..a17d3ca1 100644 --- a/lemur/policies/models.py +++ b/lemur/policies/models.py @@ -12,10 +12,12 @@ from lemur.database import db class RotationPolicy(db.Model): - __tablename__ = 'rotation_policies' + __tablename__ = "rotation_policies" id = Column(Integer, primary_key=True) name = Column(String) days = Column(Integer) def __repr__(self): - return "RotationPolicy(days={days}, name={name})".format(days=self.days, name=self.name) + return "RotationPolicy(days={days}, name={name})".format( + days=self.days, name=self.name + ) diff --git a/lemur/policies/service.py b/lemur/policies/service.py index 10e9053b..cb43d52e 100644 --- a/lemur/policies/service.py +++ b/lemur/policies/service.py @@ -24,7 +24,7 @@ def get_by_name(policy_name): :param policy_name: :return: """ - return database.get_all(RotationPolicy, policy_name, field='name').all() + return database.get_all(RotationPolicy, policy_name, field="name").all() def delete(policy_id): diff --git a/lemur/reporting/cli.py b/lemur/reporting/cli.py index 8f797c33..c92b79cd 100644 --- a/lemur/reporting/cli.py +++ b/lemur/reporting/cli.py @@ -13,49 +13,73 @@ from lemur.reporting.service import fqdns, expiring_certificates manager = Manager(usage="Reporting related tasks.") -@manager.option('-v', '--validity', dest='validity', choices=['all', 'expired', 'valid'], default='all', help='Filter certificates by validity.') -@manager.option('-d', '--deployment', dest='deployment', choices=['all', 'deployed', 'ready'], default='all', help='Filter by deployment status.') +@manager.option( + "-v", + "--validity", + dest="validity", + choices=["all", "expired", "valid"], + default="all", + help="Filter certificates by validity.", +) +@manager.option( + "-d", + "--deployment", + dest="deployment", + choices=["all", "deployed", "ready"], + default="all", + help="Filter by deployment status.", +) def fqdn(deployment, validity): """ Generates a report in order to determine the number of FQDNs covered by Lemur issued certificates. """ - headers = ['FQDN', 'Root Domain', 'Issuer', 'Owner', 'Validity End', 'Total Length (days), Time Until Expiration (days)'] + headers = [ + "FQDN", + "Root Domain", + "Issuer", + "Owner", + "Validity End", + "Total Length (days), Time Until Expiration (days)", + ] rows = [] for cert in fqdns(validity=validity, deployment=deployment).all(): for domain in cert.domains: - rows.append([ - domain.name, - '.'.join(domain.name.split('.')[1:]), - cert.issuer, - cert.owner, - cert.not_after, - cert.validity_range.days, - cert.validity_remaining.days - ]) + rows.append( + [ + domain.name, + ".".join(domain.name.split(".")[1:]), + cert.issuer, + cert.owner, + cert.not_after, + cert.validity_range.days, + cert.validity_remaining.days, + ] + ) print(tabulate(rows, headers=headers)) -@manager.option('-ttl', '--ttl', dest='ttl', default=30, help='Days til expiration.') -@manager.option('-d', '--deployment', dest='deployment', choices=['all', 'deployed', 'ready'], default='all', help='Filter by deployment status.') +@manager.option("-ttl", "--ttl", dest="ttl", default=30, help="Days til expiration.") +@manager.option( + "-d", + "--deployment", + dest="deployment", + choices=["all", "deployed", "ready"], + default="all", + help="Filter by deployment status.", +) def expiring(ttl, deployment): """ Returns certificates expiring in the next n days. """ - headers = ['Common Name', 'Owner', 'Issuer', 'Validity End', 'Endpoint'] + headers = ["Common Name", "Owner", "Issuer", "Validity End", "Endpoint"] rows = [] for cert in expiring_certificates(ttl=ttl, deployment=deployment).all(): for endpoint in cert.endpoints: rows.append( - [ - cert.cn, - cert.owner, - cert.issuer, - cert.not_after, - endpoint.dnsname - ] + [cert.cn, cert.owner, cert.issuer, cert.not_after, endpoint.dnsname] ) print(tabulate(rows, headers=headers)) diff --git a/lemur/reporting/service.py b/lemur/reporting/service.py index 348cf2f4..77eb7b3e 100644 --- a/lemur/reporting/service.py +++ b/lemur/reporting/service.py @@ -9,10 +9,10 @@ from lemur.certificates.models import Certificate def filter_by_validity(query, validity=None): - if validity == 'expired': + if validity == "expired": query = query.filter(Certificate.expired == True) # noqa - elif validity == 'valid': + elif validity == "valid": query = query.filter(Certificate.expired == False) # noqa return query @@ -33,10 +33,10 @@ def filter_by_issuer(query, issuer=None): def filter_by_deployment(query, deployment=None): - if deployment == 'deployed': + if deployment == "deployed": query = query.filter(Certificate.endpoints.any()) - elif deployment == 'ready': + elif deployment == "ready": query = query.filter(not_(Certificate.endpoints.any())) return query @@ -55,8 +55,8 @@ def fqdns(**kwargs): :return: """ query = database.session_query(Certificate) - query = filter_by_deployment(query, deployment=kwargs.get('deployed')) - query = filter_by_validity(query, validity=kwargs.get('validity')) + query = filter_by_deployment(query, deployment=kwargs.get("deployed")) + query = filter_by_validity(query, validity=kwargs.get("validity")) return query @@ -65,13 +65,13 @@ def expiring_certificates(**kwargs): Returns an Expiring report. :return: """ - ttl = kwargs.get('ttl', 30) + ttl = kwargs.get("ttl", 30) now = arrow.utcnow() validity_end = now + timedelta(days=ttl) query = database.session_query(Certificate) - query = filter_by_deployment(query, deployment=kwargs.get('deployed')) - query = filter_by_validity(query, validity='valid') + query = filter_by_deployment(query, deployment=kwargs.get("deployed")) + query = filter_by_validity(query, validity="valid") query = filter_by_validity_end(query, validity_end=validity_end) return query diff --git a/lemur/roles/models.py b/lemur/roles/models.py index 85bf1bf1..91b5d58c 100644 --- a/lemur/roles/models.py +++ b/lemur/roles/models.py @@ -14,26 +14,42 @@ from sqlalchemy import Boolean, Column, Integer, String, Text, ForeignKey from lemur.database import db from lemur.utils import Vault -from lemur.models import roles_users, roles_authorities, roles_certificates, \ - pending_cert_role_associations +from lemur.models import ( + roles_users, + roles_authorities, + roles_certificates, + pending_cert_role_associations, +) class Role(db.Model): - __tablename__ = 'roles' + __tablename__ = "roles" id = Column(Integer, primary_key=True) name = Column(String(128), unique=True) username = Column(String(128)) password = Column(Vault) description = Column(Text) - authority_id = Column(Integer, ForeignKey('authorities.id')) - authorities = relationship("Authority", secondary=roles_authorities, passive_deletes=True, backref="role", cascade='all,delete') - user_id = Column(Integer, ForeignKey('users.id')) + authority_id = Column(Integer, ForeignKey("authorities.id")) + authorities = relationship( + "Authority", + secondary=roles_authorities, + passive_deletes=True, + backref="role", + cascade="all,delete", + ) + user_id = Column(Integer, ForeignKey("users.id")) third_party = Column(Boolean) - users = relationship("User", secondary=roles_users, passive_deletes=True, backref="role") - certificates = relationship("Certificate", secondary=roles_certificates, backref="role") - pending_certificates = relationship("PendingCertificate", secondary=pending_cert_role_associations, backref="role") + users = relationship( + "User", secondary=roles_users, passive_deletes=True, backref="role" + ) + certificates = relationship( + "Certificate", secondary=roles_certificates, backref="role" + ) + pending_certificates = relationship( + "PendingCertificate", secondary=pending_cert_role_associations, backref="role" + ) - sensitive_fields = ('password',) + sensitive_fields = ("password",) def __repr__(self): return "Role(name={name})".format(name=self.name) diff --git a/lemur/roles/service.py b/lemur/roles/service.py index bbeef1ce..51597d6e 100644 --- a/lemur/roles/service.py +++ b/lemur/roles/service.py @@ -47,7 +47,9 @@ def set_third_party(role_id, third_party_status=False): return role -def create(name, password=None, description=None, username=None, users=None, third_party=False): +def create( + name, password=None, description=None, username=None, users=None, third_party=False +): """ Create a new role @@ -58,7 +60,13 @@ def create(name, password=None, description=None, username=None, users=None, thi :param password: :return: """ - role = Role(name=name, description=description, username=username, password=password, third_party=third_party) + role = Role( + name=name, + description=description, + username=username, + password=password, + third_party=third_party, + ) if users: role.users = users @@ -83,7 +91,7 @@ def get_by_name(role_name): :param role_name: :return: """ - return database.get(Role, role_name, field='name') + return database.get(Role, role_name, field="name") def delete(role_id): @@ -105,9 +113,9 @@ def render(args): :return: """ query = database.session_query(Role) - filt = args.pop('filter') - user_id = args.pop('user_id', None) - authority_id = args.pop('authority_id', None) + filt = args.pop("filter") + user_id = args.pop("user_id", None) + authority_id = args.pop("authority_id", None) if user_id: query = query.filter(Role.users.any(User.id == user_id)) @@ -116,7 +124,7 @@ def render(args): query = query.filter(Role.authority_id == authority_id) if filt: - terms = filt.split(';') + terms = filt.split(";") query = database.filter(query, Role, terms) return database.sort_and_page(query, Role, args) diff --git a/lemur/roles/views.py b/lemur/roles/views.py index a635fdba..1e12f24b 100644 --- a/lemur/roles/views.py +++ b/lemur/roles/views.py @@ -17,15 +17,20 @@ from lemur.auth.permissions import RoleMemberPermission, admin_permission from lemur.common.utils import paginated_parser from lemur.common.schema import validate_schema -from lemur.roles.schemas import role_input_schema, role_output_schema, roles_output_schema +from lemur.roles.schemas import ( + role_input_schema, + role_output_schema, + roles_output_schema, +) -mod = Blueprint('roles', __name__) +mod = Blueprint("roles", __name__) api = Api(mod) class RolesList(AuthenticatedResource): """ Defines the 'roles' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(RolesList, self).__init__() @@ -79,11 +84,11 @@ class RolesList(AuthenticatedResource): :statuscode 403: unauthenticated """ parser = paginated_parser.copy() - parser.add_argument('owner', type=str, location='args') - parser.add_argument('id', type=str, location='args') + parser.add_argument("owner", type=str, location="args") + parser.add_argument("id", type=str, location="args") args = parser.parse_args() - args['user'] = g.current_user + args["user"] = g.current_user return service.render(args) @admin_permission.require(http_exception=403) @@ -135,8 +140,13 @@ class RolesList(AuthenticatedResource): :statuscode 200: no error :statuscode 403: unauthenticated """ - return service.create(data['name'], data.get('password'), data.get('description'), data.get('username'), - data.get('users')) + return service.create( + data["name"], + data.get("password"), + data.get("description"), + data.get("username"), + data.get("users"), + ) class RoleViewCredentials(AuthenticatedResource): @@ -177,11 +187,18 @@ class RoleViewCredentials(AuthenticatedResource): permission = RoleMemberPermission(role_id) if permission.can(): role = service.get(role_id) - response = make_response(jsonify(username=role.username, password=role.password), 200) - response.headers['cache-control'] = 'private, max-age=0, no-cache, no-store' - response.headers['pragma'] = 'no-cache' + response = make_response( + jsonify(username=role.username, password=role.password), 200 + ) + response.headers["cache-control"] = "private, max-age=0, no-cache, no-store" + response.headers["pragma"] = "no-cache" return response - return dict(message='You are not authorized to view the credentials for this role.'), 403 + return ( + dict( + message="You are not authorized to view the credentials for this role." + ), + 403, + ) class Roles(AuthenticatedResource): @@ -227,7 +244,12 @@ class Roles(AuthenticatedResource): if permission.can(): return service.get(role_id) - return dict(message="You are not allowed to view a role which you are not a member of."), 403 + return ( + dict( + message="You are not allowed to view a role which you are not a member of." + ), + 403, + ) @validate_schema(role_input_schema, role_output_schema) def put(self, role_id, data=None): @@ -269,8 +291,10 @@ class Roles(AuthenticatedResource): """ permission = RoleMemberPermission(role_id) if permission.can(): - return service.update(role_id, data['name'], data.get('description'), data.get('users')) - return dict(message='You are not authorized to modify this role.'), 403 + return service.update( + role_id, data["name"], data.get("description"), data.get("users") + ) + return dict(message="You are not authorized to modify this role."), 403 @admin_permission.require(http_exception=403) def delete(self, role_id): @@ -304,11 +328,12 @@ class Roles(AuthenticatedResource): :statuscode 403: unauthenticated """ service.delete(role_id) - return {'message': 'ok'} + return {"message": "ok"} class UserRolesList(AuthenticatedResource): """ Defines the 'roles' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(UserRolesList, self).__init__() @@ -362,12 +387,13 @@ class UserRolesList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['user_id'] = user_id + args["user_id"] = user_id return service.render(args) class AuthorityRolesList(AuthenticatedResource): """ Defines the 'roles' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(AuthorityRolesList, self).__init__() @@ -421,12 +447,18 @@ class AuthorityRolesList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['authority_id'] = authority_id + args["authority_id"] = authority_id return service.render(args) -api.add_resource(RolesList, '/roles', endpoint='roles') -api.add_resource(Roles, '/roles/', endpoint='role') -api.add_resource(RoleViewCredentials, '/roles//credentials', endpoint='roleCredentials`') -api.add_resource(AuthorityRolesList, '/authorities//roles', endpoint='authorityRoles') -api.add_resource(UserRolesList, '/users//roles', endpoint='userRoles') +api.add_resource(RolesList, "/roles", endpoint="roles") +api.add_resource(Roles, "/roles/", endpoint="role") +api.add_resource( + RoleViewCredentials, "/roles//credentials", endpoint="roleCredentials`" +) +api.add_resource( + AuthorityRolesList, + "/authorities//roles", + endpoint="authorityRoles", +) +api.add_resource(UserRolesList, "/users//roles", endpoint="userRoles") diff --git a/lemur/schemas.py b/lemur/schemas.py index ffdfe66f..e7b0fd64 100644 --- a/lemur/schemas.py +++ b/lemur/schemas.py @@ -14,7 +14,12 @@ from marshmallow.exceptions import ValidationError from lemur.common import validators from lemur.common.schema import LemurSchema, LemurInputSchema, LemurOutputSchema -from lemur.common.fields import KeyUsageExtension, ExtendedKeyUsageExtension, BasicConstraintsExtension, SubjectAlternativeNameExtension +from lemur.common.fields import ( + KeyUsageExtension, + ExtendedKeyUsageExtension, + BasicConstraintsExtension, + SubjectAlternativeNameExtension, +) from lemur.plugins import plugins from lemur.plugins.utils import get_plugin_option @@ -34,40 +39,42 @@ def validate_options(options): :param options: :return: """ - interval = get_plugin_option('interval', options) - unit = get_plugin_option('unit', options) + interval = get_plugin_option("interval", options) + unit = get_plugin_option("unit", options) if not interval and not unit: return - if unit == 'month': + if unit == "month": interval *= 30 - elif unit == 'week': + elif unit == "week": interval *= 7 if interval > 90: - raise ValidationError('Notification cannot be more than 90 days into the future.') + raise ValidationError( + "Notification cannot be more than 90 days into the future." + ) def get_object_attribute(data, many=False): if many: - ids = [d.get('id') for d in data] - names = [d.get('name') for d in data] + ids = [d.get("id") for d in data] + names = [d.get("name") for d in data] if None in ids: if None in names: - raise ValidationError('Associated object require a name or id.') + raise ValidationError("Associated object require a name or id.") else: - return 'name' - return 'id' + return "name" + return "id" else: - if data.get('id'): - return 'id' - elif data.get('name'): - return 'name' + if data.get("id"): + return "id" + elif data.get("name"): + return "name" else: - raise ValidationError('Associated object require a name or id.') + raise ValidationError("Associated object require a name or id.") def fetch_objects(model, data, many=False): @@ -80,10 +87,11 @@ def fetch_objects(model, data, many=False): diff = set(values).symmetric_difference(set(found)) if diff: - raise ValidationError('Unable to locate {model} with {attr} {diff}'.format( - model=model, - attr=attr, - diff=",".join(list(diff)))) + raise ValidationError( + "Unable to locate {model} with {attr} {diff}".format( + model=model, attr=attr, diff=",".join(list(diff)) + ) + ) return items @@ -91,10 +99,11 @@ def fetch_objects(model, data, many=False): try: return model.query.filter(getattr(model, attr) == data[attr]).one() except NoResultFound: - raise ValidationError('Unable to find {model} with {attr}: {data}'.format( - model=model, - attr=attr, - data=data[attr])) + raise ValidationError( + "Unable to find {model} with {attr}: {data}".format( + model=model, attr=attr, data=data[attr] + ) + ) class AssociatedAuthoritySchema(LemurInputSchema): @@ -178,17 +187,19 @@ class PluginInputSchema(LemurInputSchema): @post_load def get_object(self, data, many=False): try: - data['plugin_object'] = plugins.get(data['slug']) + data["plugin_object"] = plugins.get(data["slug"]) # parse any sub-plugins - for option in data.get('plugin_options', []): - if 'plugin' in option.get('type', []): - sub_data, errors = PluginInputSchema().load(option['value']) - option['value'] = sub_data + for option in data.get("plugin_options", []): + if "plugin" in option.get("type", []): + sub_data, errors = PluginInputSchema().load(option["value"]) + option["value"] = sub_data return data except Exception as e: - raise ValidationError('Unable to find plugin. Slug: {0} Reason: {1}'.format(data['slug'], e)) + raise ValidationError( + "Unable to find plugin. Slug: {0} Reason: {1}".format(data["slug"], e) + ) class PluginOutputSchema(LemurOutputSchema): @@ -196,7 +207,7 @@ class PluginOutputSchema(LemurOutputSchema): label = fields.String() description = fields.String() active = fields.Boolean() - options = fields.List(fields.Dict(), dump_to='pluginOptions') + options = fields.List(fields.Dict(), dump_to="pluginOptions") slug = fields.String() title = fields.String() @@ -227,7 +238,7 @@ class CertificateInfoAccessSchema(BaseExtensionSchema): @post_dump def handle_keys(self, data): - return {'includeAIA': data['include_aia']} + return {"includeAIA": data["include_aia"]} class CRLDistributionPointsSchema(BaseExtensionSchema): @@ -235,7 +246,7 @@ class CRLDistributionPointsSchema(BaseExtensionSchema): @post_dump def handle_keys(self, data): - return {'includeCRLDP': data['include_crl_dp']} + return {"includeCRLDP": data["include_crl_dp"]} class SubjectKeyIdentifierSchema(BaseExtensionSchema): @@ -243,7 +254,7 @@ class SubjectKeyIdentifierSchema(BaseExtensionSchema): @post_dump def handle_keys(self, data): - return {'includeSKI': data['include_ski']} + return {"includeSKI": data["include_ski"]} class CustomOIDSchema(BaseExtensionSchema): @@ -258,14 +269,18 @@ class NamesSchema(BaseExtensionSchema): class ExtensionSchema(BaseExtensionSchema): - basic_constraints = BasicConstraintsExtension() # some devices balk on default basic constraints + basic_constraints = ( + BasicConstraintsExtension() + ) # some devices balk on default basic constraints key_usage = KeyUsageExtension() extended_key_usage = ExtendedKeyUsageExtension() subject_key_identifier = fields.Nested(SubjectKeyIdentifierSchema) sub_alt_names = fields.Nested(NamesSchema) authority_key_identifier = fields.Nested(AuthorityKeyIdentifierSchema) certificate_info_access = fields.Nested(CertificateInfoAccessSchema) - crl_distribution_points = fields.Nested(CRLDistributionPointsSchema, dump_to='cRL_distribution_points') + crl_distribution_points = fields.Nested( + CRLDistributionPointsSchema, dump_to="cRL_distribution_points" + ) # FIXME: Convert custom OIDs to a custom field in fields.py like other Extensions # FIXME: Remove support in UI for Critical custom extensions https://github.com/Netflix/lemur/issues/665 custom = fields.List(fields.Nested(CustomOIDSchema)) diff --git a/lemur/sources/cli.py b/lemur/sources/cli.py index 0ab8c9f8..c41a1cf7 100644 --- a/lemur/sources/cli.py +++ b/lemur/sources/cli.py @@ -35,24 +35,32 @@ def validate_sources(source_strings): table.append([source.label, source.active, source.description]) print("No source specified choose from below:") - print(tabulate(table, headers=['Label', 'Active', 'Description'])) + print(tabulate(table, headers=["Label", "Active", "Description"])) sys.exit(1) - if 'all' in source_strings: + if "all" in source_strings: sources = source_service.get_all() else: for source_str in source_strings: source = source_service.get_by_label(source_str) if not source: - print("Unable to find specified source with label: {0}".format(source_str)) + print( + "Unable to find specified source with label: {0}".format(source_str) + ) sys.exit(1) sources.append(source) return sources -@manager.option('-s', '--sources', dest='source_strings', action='append', help='Sources to operate on.') +@manager.option( + "-s", + "--sources", + dest="source_strings", + action="append", + help="Sources to operate on.", +) def sync(source_strings): sources = validate_sources(source_strings) for source in sources: @@ -61,26 +69,23 @@ def sync(source_strings): start_time = time.time() print("[+] Staring to sync source: {label}!\n".format(label=source.label)) - user = user_service.get_by_username('lemur') + user = user_service.get_by_username("lemur") try: data = source_service.sync(source, user) print( "[+] Certificates: New: {new} Updated: {updated}".format( - new=data['certificates'][0], - updated=data['certificates'][1] + new=data["certificates"][0], updated=data["certificates"][1] ) ) print( "[+] Endpoints: New: {new} Updated: {updated}".format( - new=data['endpoints'][0], - updated=data['endpoints'][1] + new=data["endpoints"][0], updated=data["endpoints"][1] ) ) print( "[+] Finished syncing source: {label}. Run Time: {time}".format( - label=source.label, - time=(time.time() - start_time) + label=source.label, time=(time.time() - start_time) ) ) status = SUCCESS_METRIC_STATUS @@ -88,27 +93,50 @@ def sync(source_strings): except Exception as e: current_app.logger.exception(e) - print( - "[X] Failed syncing source {label}!\n".format(label=source.label) - ) + print("[X] Failed syncing source {label}!\n".format(label=source.label)) sentry.captureException() - metrics.send('source_sync_fail', 'counter', 1, metric_tags={'source': source.label, 'status': status}) + metrics.send( + "source_sync_fail", + "counter", + 1, + metric_tags={"source": source.label, "status": status}, + ) - metrics.send('source_sync', 'counter', 1, metric_tags={'source': source.label, 'status': status}) + metrics.send( + "source_sync", + "counter", + 1, + metric_tags={"source": source.label, "status": status}, + ) -@manager.option('-s', '--sources', dest='source_strings', action='append', help='Sources to operate on.') -@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') +@manager.option( + "-s", + "--sources", + dest="source_strings", + action="append", + help="Sources to operate on.", +) +@manager.option( + "-c", + "--commit", + dest="commit", + action="store_true", + default=False, + help="Persist changes.", +) def clean(source_strings, commit): sources = validate_sources(source_strings) for source in sources: s = plugins.get(source.plugin_name) - if not hasattr(s, 'clean'): - print("Cannot clean source: {0}, source plugin does not implement 'clean()'".format( - source.label - )) + if not hasattr(s, "clean"): + print( + "Cannot clean source: {0}, source plugin does not implement 'clean()'".format( + source.label + ) + ) continue start_time = time.time() @@ -128,19 +156,23 @@ def clean(source_strings, commit): current_app.logger.exception(e) sentry.captureException() - metrics.send('clean', 'counter', 1, metric_tags={'source': source.label, 'status': status}) + metrics.send( + "clean", + "counter", + 1, + metric_tags={"source": source.label, "status": status}, + ) - current_app.logger.warning("Removed {0} from source {1} during cleaning".format( - certificate.name, - source.label - )) + current_app.logger.warning( + "Removed {0} from source {1} during cleaning".format( + certificate.name, source.label + ) + ) cleaned += 1 print( "[+] Finished cleaning source: {label}. Removed {cleaned} certificates from source. Run Time: {time}\n".format( - label=source.label, - time=(time.time() - start_time), - cleaned=cleaned + label=source.label, time=(time.time() - start_time), cleaned=cleaned ) ) diff --git a/lemur/sources/models.py b/lemur/sources/models.py index 071688d1..78dbb213 100644 --- a/lemur/sources/models.py +++ b/lemur/sources/models.py @@ -15,7 +15,7 @@ from sqlalchemy_utils import ArrowType class Source(db.Model): - __tablename__ = 'sources' + __tablename__ = "sources" id = Column(Integer, primary_key=True) label = Column(String(32), unique=True) options = Column(JSONType) diff --git a/lemur/sources/schemas.py b/lemur/sources/schemas.py index 028fdb32..5531293f 100644 --- a/lemur/sources/schemas.py +++ b/lemur/sources/schemas.py @@ -30,7 +30,7 @@ class SourceOutputSchema(LemurOutputSchema): @post_dump def fill_object(self, data): if data: - data['plugin']['pluginOptions'] = data['options'] + data["plugin"]["pluginOptions"] = data["options"] return data diff --git a/lemur/sources/service.py b/lemur/sources/service.py index 227f1bce..f4783313 100644 --- a/lemur/sources/service.py +++ b/lemur/sources/service.py @@ -6,6 +6,7 @@ .. moduleauthor:: Kevin Glisson """ import arrow +import copy from flask import current_app @@ -14,22 +15,26 @@ from lemur.sources.models import Source from lemur.certificates.models import Certificate from lemur.certificates import service as certificate_service from lemur.endpoints import service as endpoint_service +from lemur.extensions import metrics, sentry from lemur.destinations import service as destination_service from lemur.certificates.schemas import CertificateUploadInputSchema -from lemur.common.utils import parse_certificate +from lemur.common.utils import find_matching_certificates_by_hash, parse_certificate from lemur.common.defaults import serial from lemur.plugins.base import plugins +from lemur.plugins.utils import get_plugin_option, set_plugin_option def certificate_create(certificate, source): data, errors = CertificateUploadInputSchema().load(certificate) if errors: - raise Exception("Unable to import certificate: {reasons}".format(reasons=errors)) + raise Exception( + "Unable to import certificate: {reasons}".format(reasons=errors) + ) - data['creator'] = certificate['creator'] + data["creator"] = certificate["creator"] cert = certificate_service.import_certificate(**data) cert.description = "This certificate was automatically discovered by Lemur" @@ -61,40 +66,88 @@ def sync_update_destination(certificate, source): def sync_endpoints(source): - new, updated = 0, 0 + new, updated, updated_by_hash = 0, 0, 0 current_app.logger.debug("Retrieving endpoints from {0}".format(source.label)) s = plugins.get(source.plugin_name) try: endpoints = s.get_endpoints(source.options) except NotImplementedError: - current_app.logger.warning("Unable to sync endpoints for source {0} plugin has not implemented 'get_endpoints'".format(source.label)) - return new, updated + current_app.logger.warning( + "Unable to sync endpoints for source {0} plugin has not implemented 'get_endpoints'".format( + source.label + ) + ) + return new, updated, updated_by_hash for endpoint in endpoints: - exists = endpoint_service.get_by_dnsname_and_port(endpoint['dnsname'], endpoint['port']) + exists = endpoint_service.get_by_dnsname_and_port( + endpoint["dnsname"], endpoint["port"] + ) - certificate_name = endpoint.pop('certificate_name') + certificate_name = endpoint.pop("certificate_name") - endpoint['certificate'] = certificate_service.get_by_name(certificate_name) + endpoint["certificate"] = certificate_service.get_by_name(certificate_name) - if not endpoint['certificate']: + # if get cert by name failed, we attempt a search via serial number and hash comparison + # and link the endpoint certificate to Lemur certificate + if not endpoint["certificate"]: + certificate_attached_to_endpoint = None + try: + certificate_attached_to_endpoint = s.get_certificate_by_name(certificate_name, source.options) + except NotImplementedError: + current_app.logger.warning( + "Unable to describe server certificate for endpoints in source {0}:" + " plugin has not implemented 'get_certificate_by_name'".format( + source.label + ) + ) + sentry.captureException() + + if certificate_attached_to_endpoint: + lemur_matching_cert, updated_by_hash_tmp = find_cert(certificate_attached_to_endpoint) + updated_by_hash += updated_by_hash_tmp + + if lemur_matching_cert: + endpoint["certificate"] = lemur_matching_cert[0] + + if len(lemur_matching_cert) > 1: + current_app.logger.error( + "Too Many Certificates Found{0}. Name: {1} Endpoint: {2}".format( + len(lemur_matching_cert), certificate_name, endpoint["name"] + ) + ) + metrics.send("endpoint.certificate.conflict", + "gauge", len(lemur_matching_cert), + metric_tags={"cert": certificate_name, "endpoint": endpoint["name"], + "acct": s.get_option("accountNumber", source.options)}) + + if not endpoint["certificate"]: current_app.logger.error( - "Certificate Not Found. Name: {0} Endpoint: {1}".format(certificate_name, endpoint['name'])) + "Certificate Not Found. Name: {0} Endpoint: {1}".format( + certificate_name, endpoint["name"] + ) + ) + metrics.send("endpoint.certificate.not.found", + "counter", 1, + metric_tags={"cert": certificate_name, "endpoint": endpoint["name"], + "acct": s.get_option("accountNumber", source.options)}) continue - policy = endpoint.pop('policy') + policy = endpoint.pop("policy") policy_ciphers = [] - for nc in policy['ciphers']: + for nc in policy["ciphers"]: policy_ciphers.append(endpoint_service.get_or_create_cipher(name=nc)) - policy['ciphers'] = policy_ciphers - endpoint['policy'] = endpoint_service.get_or_create_policy(**policy) - endpoint['source'] = source + policy["ciphers"] = policy_ciphers + endpoint["policy"] = endpoint_service.get_or_create_policy(**policy) + endpoint["source"] = source if not exists: - current_app.logger.debug("Endpoint Created: Name: {name}".format(name=endpoint['name'])) + current_app.logger.debug( + "Endpoint Created: Name: {name}".format(name=endpoint["name"]) + ) endpoint_service.create(**endpoint) new += 1 @@ -103,36 +156,50 @@ def sync_endpoints(source): endpoint_service.update(exists.id, **endpoint) updated += 1 - return new, updated + return new, updated, updated_by_hash + + +def find_cert(certificate): + updated_by_hash = 0 + exists = False + + if certificate.get("search", None): + conditions = certificate.pop("search") + exists = certificate_service.get_by_attributes(conditions) + + if not exists and certificate.get("name"): + result = certificate_service.get_by_name(certificate["name"]) + if result: + exists = [result] + + if not exists and certificate.get("serial"): + exists = certificate_service.get_by_serial(certificate["serial"]) + + if not exists: + cert = parse_certificate(certificate["body"]) + matching_serials = certificate_service.get_by_serial(serial(cert)) + exists = find_matching_certificates_by_hash(cert, matching_serials) + updated_by_hash += 1 + + exists = [x for x in exists if x] + return exists, updated_by_hash # TODO this is very slow as we don't batch update certificates def sync_certificates(source, user): - new, updated = 0, 0 + new, updated, updated_by_hash = 0, 0, 0 current_app.logger.debug("Retrieving certificates from {0}".format(source.label)) s = plugins.get(source.plugin_name) certificates = s.get_certificates(source.options) for certificate in certificates: - exists = False - if certificate.get('name'): - result = certificate_service.get_by_name(certificate['name']) - if result: - exists = [result] + exists, updated_by_hash = find_cert(certificate) - if not exists and certificate.get('serial'): - exists = certificate_service.get_by_serial(certificate['serial']) + if not certificate.get("owner"): + certificate["owner"] = user.email - if not exists: - cert = parse_certificate(certificate['body']) - exists = certificate_service.get_by_serial(serial(cert)) - - if not certificate.get('owner'): - certificate['owner'] = user.email - - certificate['creator'] = user - exists = [x for x in exists if x] + certificate["creator"] = user if not exists: certificate_create(certificate, source) @@ -140,24 +207,35 @@ def sync_certificates(source, user): else: for e in exists: - if certificate.get('external_id'): - e.external_id = certificate['external_id'] - if certificate.get('authority_id'): - e.authority_id = certificate['authority_id'] + if certificate.get("external_id"): + e.external_id = certificate["external_id"] + if certificate.get("authority_id"): + e.authority_id = certificate["authority_id"] certificate_update(e, source) updated += 1 - return new, updated + return new, updated, updated_by_hash def sync(source, user): - new_certs, updated_certs = sync_certificates(source, user) - new_endpoints, updated_endpoints = sync_endpoints(source) + new_certs, updated_certs, updated_certs_by_hash = sync_certificates(source, user) + new_endpoints, updated_endpoints, updated_endpoints_by_hash = sync_endpoints(source) + + metrics.send("sync.updated_certs_by_hash", + "gauge", updated_certs_by_hash, + metric_tags={"source": source.label}) + + metrics.send("sync.updated_endpoints_by_hash", + "gauge", updated_endpoints_by_hash, + metric_tags={"source": source.label}) source.last_run = arrow.utcnow() database.update(source) - return {'endpoints': (new_endpoints, updated_endpoints), 'certificates': (new_certs, updated_certs)} + return { + "endpoints": (new_endpoints, updated_endpoints), + "certificates": (new_certs, updated_certs), + } def create(label, plugin_name, options, description=None): @@ -171,7 +249,9 @@ def create(label, plugin_name, options, description=None): :rtype : Source :return: New source """ - source = Source(label=label, options=options, plugin_name=plugin_name, description=description) + source = Source( + label=label, options=options, plugin_name=plugin_name, description=description + ) return database.create(source) @@ -222,7 +302,7 @@ def get_by_label(label): :param label: :return: """ - return database.get(Source, label, field='label') + return database.get(Source, label, field="label") def get_all(): @@ -236,8 +316,8 @@ def get_all(): def render(args): - filt = args.pop('filter') - certificate_id = args.pop('certificate_id', None) + filt = args.pop("filter") + certificate_id = args.pop("certificate_id", None) if certificate_id: query = database.session_query(Source).join(Certificate, Source.certificate) @@ -246,7 +326,45 @@ def render(args): query = database.session_query(Source) if filt: - terms = filt.split(';') + terms = filt.split(";") query = database.filter(query, Source, terms) return database.sort_and_page(query, Source, args) + + +def add_aws_destination_to_sources(dst): + """ + Given a destination check, if it can be added as sources, and included it if not already a source + We identify qualified destinations based on the sync_as_source attributed of the plugin. + The destination sync_as_source_name reveals the name of the suitable source-plugin. + We rely on account numbers to avoid duplicates. + :return: true for success and false for not adding the destination as source + """ + # a set of all accounts numbers available as sources + src_accounts = set() + sources = get_all() + for src in sources: + src_accounts.add(get_plugin_option("accountNumber", src.options)) + + # check + destination_plugin = plugins.get(dst.plugin_name) + account_number = get_plugin_option("accountNumber", dst.options) + if ( + account_number is not None + and destination_plugin.sync_as_source is not None + and destination_plugin.sync_as_source + and (account_number not in src_accounts) + ): + src_options = copy.deepcopy( + plugins.get(destination_plugin.sync_as_source_name).options + ) + set_plugin_option("accountNumber", account_number, src_options) + create( + label=dst.label, + plugin_name=destination_plugin.sync_as_source_name, + options=src_options, + description=dst.description, + ) + return True + + return False diff --git a/lemur/sources/views.py b/lemur/sources/views.py index abf68109..b74c4d80 100644 --- a/lemur/sources/views.py +++ b/lemur/sources/views.py @@ -11,19 +11,24 @@ from flask_restful import Api, reqparse from lemur.sources import service from lemur.common.schema import validate_schema -from lemur.sources.schemas import source_input_schema, source_output_schema, sources_output_schema +from lemur.sources.schemas import ( + source_input_schema, + source_output_schema, + sources_output_schema, +) from lemur.auth.service import AuthenticatedResource from lemur.auth.permissions import admin_permission from lemur.common.utils import paginated_parser -mod = Blueprint('sources', __name__) +mod = Blueprint("sources", __name__) api = Api(mod) class SourcesList(AuthenticatedResource): """ Defines the 'sources' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(SourcesList, self).__init__() @@ -151,7 +156,12 @@ class SourcesList(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.create(data['label'], data['plugin']['slug'], data['plugin']['plugin_options'], data['description']) + return service.create( + data["label"], + data["plugin"]["slug"], + data["plugin"]["plugin_options"], + data["description"], + ) class Sources(AuthenticatedResource): @@ -271,16 +281,22 @@ class Sources(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.update(source_id, data['label'], data['plugin']['plugin_options'], data['description']) + return service.update( + source_id, + data["label"], + data["plugin"]["plugin_options"], + data["description"], + ) @admin_permission.require(http_exception=403) def delete(self, source_id): service.delete(source_id) - return {'result': True} + return {"result": True} class CertificateSources(AuthenticatedResource): """ Defines the 'certificate/', endpoint='account') -api.add_resource(CertificateSources, '/certificates//sources', - endpoint='certificateSources') +api.add_resource(SourcesList, "/sources", endpoint="sources") +api.add_resource(Sources, "/sources/", endpoint="account") +api.add_resource( + CertificateSources, + "/certificates//sources", + endpoint="certificateSources", +) diff --git a/lemur/static/app/angular/certificates/certificate/certificate.js b/lemur/static/app/angular/certificates/certificate/certificate.js index 273fc9d5..21f61f22 100644 --- a/lemur/static/app/angular/certificates/certificate/certificate.js +++ b/lemur/static/app/angular/certificates/certificate/certificate.js @@ -371,4 +371,12 @@ angular.module('lemur') }); }); }; -}); +}) +.controller('CertificateInfoController', function ($scope, CertificateApi) { + $scope.fetchFullCertificate = function (certId) { + CertificateApi.get(certId).then(function (certificate) { + $scope.certificate = certificate; + }); + }; +}) +; diff --git a/lemur/static/app/angular/certificates/certificate/tracking.tpl.html b/lemur/static/app/angular/certificates/certificate/tracking.tpl.html index b64f6e3d..7ac2107f 100644 --- a/lemur/static/app/angular/certificates/certificate/tracking.tpl.html +++ b/lemur/static/app/angular/certificates/certificate/tracking.tpl.html @@ -30,9 +30,11 @@

@@ -131,7 +133,7 @@

@@ -139,8 +141,6 @@ - -
diff --git a/lemur/static/app/angular/certificates/certificate/upload.tpl.html b/lemur/static/app/angular/certificates/certificate/upload.tpl.html index c3339051..bf897a60 100644 --- a/lemur/static/app/angular/certificates/certificate/upload.tpl.html +++ b/lemur/static/app/angular/certificates/certificate/upload.tpl.html @@ -62,6 +62,19 @@ a valid certificate.

+
+ +
+ +

Enter a valid certificate signing request.

+
+
+
+ +
@@ -47,7 +52,7 @@
- Permalink + Permalink @@ -66,7 +71,7 @@
- + @@ -83,6 +88,8 @@
+
Distinguished Name
+
{{ certificate.distinguishedName }}
Certificate Authority
{{ certificate.authority ? certificate.authority.name : "Imported" }} ({{ certificate.issuer }})
Serial
@@ -196,10 +203,10 @@
{{ certificate.body }}
- + Private Key - +
{{ certificate.privateKey }}
diff --git a/lemur/static/app/angular/pager.html b/lemur/static/app/angular/pager.html index 3dc8a7d0..d9ee5204 100644 --- a/lemur/static/app/angular/pager.html +++ b/lemur/static/app/angular/pager.html @@ -4,6 +4,9 @@ +
+  Found +
diff --git a/lemur/static/app/angular/pending_certificates/pending_certificate/upload.js b/lemur/static/app/angular/pending_certificates/pending_certificate/upload.js new file mode 100644 index 00000000..10e92e0f --- /dev/null +++ b/lemur/static/app/angular/pending_certificates/pending_certificate/upload.js @@ -0,0 +1,34 @@ +'use strict'; + +angular.module('lemur') + .controller('PendingCertificateUploadController', function ($scope, $uibModalInstance, PendingCertificateApi, PendingCertificateService, toaster, uploadId) { + PendingCertificateApi.get(uploadId).then(function (pendingCertificate) { + $scope.pendingCertificate = pendingCertificate; + }); + + $scope.upload = PendingCertificateService.upload; + $scope.save = function (pendingCertificate) { + PendingCertificateService.upload(pendingCertificate).then( + function () { + toaster.pop({ + type: 'success', + title: pendingCertificate.name, + body: 'Successfully uploaded!' + }); + $uibModalInstance.close(); + }, + function (response) { + toaster.pop({ + type: 'error', + title: pendingCertificate.name, + body: 'Failed to upload ' + response.data.message, + timeout: 100000 + }); + }); + }; + + $scope.cancel = function () { + $uibModalInstance.dismiss('cancel'); + }; + + }); diff --git a/lemur/static/app/angular/pending_certificates/pending_certificate/upload.tpl.html b/lemur/static/app/angular/pending_certificates/pending_certificate/upload.tpl.html new file mode 100644 index 00000000..ba3c6a4c --- /dev/null +++ b/lemur/static/app/angular/pending_certificates/pending_certificate/upload.tpl.html @@ -0,0 +1,41 @@ + + + + +
diff --git a/lemur/static/app/angular/pending_certificates/services.js b/lemur/static/app/angular/pending_certificates/services.js index 32b335ac..4e1b23e4 100644 --- a/lemur/static/app/angular/pending_certificates/services.js +++ b/lemur/static/app/angular/pending_certificates/services.js @@ -245,5 +245,9 @@ angular.module('lemur') return pending_certificate.customOperation('remove', null, {}, {'Content-Type': 'application/json'}, options); }; + PendingCertificateService.upload = function (pending_certificate) { + return pending_certificate.customPOST({'body': pending_certificate.body, 'chain': pending_certificate.chain}, 'upload'); + }; + return PendingCertificateService; }); diff --git a/lemur/static/app/angular/pending_certificates/view/view.js b/lemur/static/app/angular/pending_certificates/view/view.js index 9ada8845..c46d6c74 100644 --- a/lemur/static/app/angular/pending_certificates/view/view.js +++ b/lemur/static/app/angular/pending_certificates/view/view.js @@ -99,4 +99,23 @@ angular.module('lemur') $scope.pendingCertificateTable.reload(); }); }; + + $scope.upload = function (pendingCertificateId) { + var uibModalInstance = $uibModal.open({ + animation: true, + controller: 'PendingCertificateUploadController', + templateUrl: '/angular/pending_certificates/pending_certificate/upload.tpl.html', + size: 'lg', + backdrop: 'static', + resolve: { + uploadId: function () { + return pendingCertificateId; + } + } + }); + uibModalInstance.result.then(function () { + $scope.pendingCertificateTable.reload(); + }); + }; + }); diff --git a/lemur/static/app/angular/pending_certificates/view/view.tpl.html b/lemur/static/app/angular/pending_certificates/view/view.tpl.html index 1f028793..d9c1b461 100644 --- a/lemur/static/app/angular/pending_certificates/view/view.tpl.html +++ b/lemur/static/app/angular/pending_certificates/view/view.tpl.html @@ -51,6 +51,7 @@ diff --git a/lemur/tests/conf.py b/lemur/tests/conf.py index bbe155cd..af0c09ce 100644 --- a/lemur/tests/conf.py +++ b/lemur/tests/conf.py @@ -15,49 +15,51 @@ debug = False TESTING = True # this is the secret key used by flask session management -SECRET_KEY = 'I/dVhOZNSMZMqrFJa5tWli6VQccOGudKerq3eWPMSzQNmHHVhMAQfQ==' +SECRET_KEY = "I/dVhOZNSMZMqrFJa5tWli6VQccOGudKerq3eWPMSzQNmHHVhMAQfQ==" # You should consider storing these separately from your config -LEMUR_TOKEN_SECRET = 'test' -LEMUR_ENCRYPTION_KEYS = 'o61sBLNBSGtAckngtNrfVNd8xy8Hp9LBGDstTbMbqCY=' +LEMUR_TOKEN_SECRET = "test" +LEMUR_ENCRYPTION_KEYS = "o61sBLNBSGtAckngtNrfVNd8xy8Hp9LBGDstTbMbqCY=" # List of domain regular expressions that non-admin users can issue LEMUR_WHITELISTED_DOMAINS = [ - '^[a-zA-Z0-9-]+\.example\.com$', - '^[a-zA-Z0-9-]+\.example\.org$', - '^example\d+\.long\.com$', + "^[a-zA-Z0-9-]+\.example\.com$", + "^[a-zA-Z0-9-]+\.example\.org$", + "^example\d+\.long\.com$", ] # Mail Server # Lemur currently only supports SES for sending email, this address # needs to be verified -LEMUR_EMAIL = '' -LEMUR_SECURITY_TEAM_EMAIL = ['security@example.com'] +LEMUR_EMAIL = "" +LEMUR_SECURITY_TEAM_EMAIL = ["security@example.com"] -LEMUR_HOSTNAME = 'lemur.example.com' +LEMUR_HOSTNAME = "lemur.example.com" # Logging LOG_LEVEL = "DEBUG" LOG_FILE = "lemur.log" -LEMUR_DEFAULT_COUNTRY = 'US' -LEMUR_DEFAULT_STATE = 'California' -LEMUR_DEFAULT_LOCATION = 'Los Gatos' -LEMUR_DEFAULT_ORGANIZATION = 'Example, Inc.' -LEMUR_DEFAULT_ORGANIZATIONAL_UNIT = 'Example' +LEMUR_DEFAULT_COUNTRY = "US" +LEMUR_DEFAULT_STATE = "California" +LEMUR_DEFAULT_LOCATION = "Los Gatos" +LEMUR_DEFAULT_ORGANIZATION = "Example, Inc." +LEMUR_DEFAULT_ORGANIZATIONAL_UNIT = "Example" LEMUR_ALLOW_WEEKEND_EXPIRATION = False # Database # modify this if you are not using a local database -SQLALCHEMY_DATABASE_URI = os.getenv('SQLALCHEMY_DATABASE_URI', 'postgresql://lemur:lemur@localhost:5432/lemur') +SQLALCHEMY_DATABASE_URI = os.getenv( + "SQLALCHEMY_DATABASE_URI", "postgresql://lemur:lemur@localhost:5432/lemur" +) SQLALCHEMY_TRACK_MODIFICATIONS = False # AWS -LEMUR_INSTANCE_PROFILE = 'Lemur' +LEMUR_INSTANCE_PROFILE = "Lemur" # Issuers @@ -72,21 +74,28 @@ LEMUR_INSTANCE_PROFILE = 'Lemur' # CLOUDCA_DEFAULT_VALIDITY = 2 -DIGICERT_URL = 'mock://www.digicert.com' -DIGICERT_ORDER_TYPE = 'ssl_plus' -DIGICERT_API_KEY = 'api-key' +DIGICERT_URL = "mock://www.digicert.com" +DIGICERT_ORDER_TYPE = "ssl_plus" +DIGICERT_API_KEY = "api-key" DIGICERT_ORG_ID = 111111 DIGICERT_ROOT = "ROOT" -VERISIGN_URL = 'http://example.com' -VERISIGN_PEM_PATH = '~/' -VERISIGN_FIRST_NAME = 'Jim' -VERISIGN_LAST_NAME = 'Bob' -VERSIGN_EMAIL = 'jim@example.com' +DIGICERT_CIS_URL = "mock://www.digicert.com" +DIGICERT_CIS_PROFILE_NAMES = {"sha2-rsa-ecc-root": "ssl_plus"} +DIGICERT_CIS_API_KEY = "api-key" +DIGICERT_CIS_ROOTS = {"root": "ROOT"} +DIGICERT_CIS_INTERMEDIATES = {"inter": "INTERMEDIATE_CA_CERT"} -ACME_AWS_ACCOUNT_NUMBER = '11111111111' -ACME_PRIVATE_KEY = ''' +VERISIGN_URL = "http://example.com" +VERISIGN_PEM_PATH = "~/" +VERISIGN_FIRST_NAME = "Jim" +VERISIGN_LAST_NAME = "Bob" +VERSIGN_EMAIL = "jim@example.com" + +ACME_AWS_ACCOUNT_NUMBER = "11111111111" + +ACME_PRIVATE_KEY = """ -----BEGIN RSA PRIVATE KEY----- MIIJJwIBAAKCAgEA0+jySNCc1i73LwDZEuIdSkZgRYQ4ZQVIioVf38RUhDElxy51 4gdWZwp8/TDpQ8cVXMj6QhdRpTVLluOz71hdvBAjxXTISRCRlItzizTgBD9CLXRh @@ -138,7 +147,7 @@ cRe4df5/EbRiUOyx/ZBepttB1meTnsH6cGPN0JnmTMQHQvanL3jjtjrC13408ONK omsEEjDt4qVqGvSyy+V/1EhqGPzm9ri3zapnorf69rscuXYYsMBZ8M6AtSio4ldB LjCRNS1lR6/mV8AqUNR9Kn2NLQyJ76yDoEVLulKZqGUsC9STN4oGJLUeFw== -----END RSA PRIVATE KEY----- -''' +""" ACME_ROOT = """ -----BEGIN CERTIFICATE----- @@ -174,15 +183,17 @@ PB0t6JzUA81mSqM3kxl5e+IZwhYAyO0OTg3/fs8HqGTNKd9BqoUwSRBzp06JMg5b rUCGwbCUDI0mxadJ3Bz4WxR6fyNpBK2yAinWEsikxqEt -----END CERTIFICATE----- """ -ACME_URL = 'https://acme-v01.api.letsencrypt.org' -ACME_EMAIL = 'jim@example.com' -ACME_TEL = '4088675309' -ACME_DIRECTORY_URL = 'https://acme-v01.api.letsencrypt.org' +ACME_URL = "https://acme-v01.api.letsencrypt.org" +ACME_EMAIL = "jim@example.com" +ACME_TEL = "4088675309" +ACME_DIRECTORY_URL = "https://acme-v01.api.letsencrypt.org" ACME_DISABLE_AUTORESOLVE = True LDAP_AUTH = True -LDAP_BIND_URI = 'ldap://localhost' -LDAP_BASE_DN = 'dc=example,dc=com' -LDAP_EMAIL_DOMAIN = 'example.com' -LDAP_REQUIRED_GROUP = 'Lemur Access' -LDAP_DEFAULT_ROLE = 'role1' +LDAP_BIND_URI = "ldap://localhost" +LDAP_BASE_DN = "dc=example,dc=com" +LDAP_EMAIL_DOMAIN = "example.com" +LDAP_REQUIRED_GROUP = "Lemur Access" +LDAP_DEFAULT_ROLE = "role1" + +ALLOW_CERT_DELETION = True diff --git a/lemur/tests/conftest.py b/lemur/tests/conftest.py index d0175c83..2efd65d9 100644 --- a/lemur/tests/conftest.py +++ b/lemur/tests/conftest.py @@ -4,22 +4,43 @@ import datetime import pytest from cryptography import x509 from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.serialization import load_pem_private_key +from cryptography.hazmat.primitives import hashes from flask import current_app from flask_principal import identity_changed, Identity +from sqlalchemy.sql import text from lemur import create_app +from lemur.common.utils import parse_private_key from lemur.database import db as _db from lemur.auth.service import create_token -from lemur.tests.vectors import SAN_CERT_KEY +from lemur.tests.vectors import ( + SAN_CERT_KEY, + INTERMEDIATE_KEY, + ROOTCA_CERT_STR, + ROOTCA_KEY, +) -from .factories import ApiKeyFactory, AuthorityFactory, NotificationFactory, DestinationFactory, \ - CertificateFactory, UserFactory, RoleFactory, SourceFactory, EndpointFactory, \ - RotationPolicyFactory, PendingCertificateFactory, AsyncAuthorityFactory +from .factories import ( + ApiKeyFactory, + AuthorityFactory, + NotificationFactory, + DestinationFactory, + CertificateFactory, + UserFactory, + RoleFactory, + SourceFactory, + EndpointFactory, + RotationPolicyFactory, + PendingCertificateFactory, + AsyncAuthorityFactory, + InvalidCertificateFactory, + CryptoAuthorityFactory, + CACertificateFactory, +) def pytest_runtest_setup(item): - if 'slow' in item.keywords and not item.config.getoption("--runslow"): + if "slow" in item.keywords and not item.config.getoption("--runslow"): pytest.skip("need --runslow option to run") if "incremental" in item.keywords: @@ -41,7 +62,9 @@ def app(request): Creates a new Flask application for a test duration. Uses application factory `create_app`. """ - _app = create_app(os.path.dirname(os.path.realpath(__file__)) + '/conf.py') + _app = create_app( + config_path=os.path.dirname(os.path.realpath(__file__)) + "/conf.py" + ) ctx = _app.app_context() ctx.push() @@ -53,14 +76,15 @@ def app(request): @pytest.yield_fixture(scope="session") def db(app, request): _db.drop_all() + _db.engine.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm")) _db.create_all() _db.app = app UserFactory() - r = RoleFactory(name='admin') + r = RoleFactory(name="admin") u = UserFactory(roles=[r]) - rp = RotationPolicyFactory(name='default') + rp = RotationPolicyFactory(name="default") ApiKeyFactory(user=u) _db.session.commit() @@ -91,6 +115,13 @@ def authority(session): return a +@pytest.fixture +def crypto_authority(session): + a = CryptoAuthorityFactory() + session.commit() + return a + + @pytest.fixture def async_authority(session): a = AsyncAuthorityFactory() @@ -148,8 +179,8 @@ def user(session): u = UserFactory() session.commit() user_token = create_token(u) - token = {'Authorization': 'Basic ' + user_token} - return {'user': u, 'token': token} + token = {"Authorization": "Basic " + user_token} + return {"user": u, "token": token} @pytest.fixture @@ -161,21 +192,50 @@ def pending_certificate(session): return p +@pytest.fixture +def pending_certificate_from_full_chain_ca(session): + u = UserFactory() + a = AuthorityFactory() + p = PendingCertificateFactory(user=u, authority=a) + session.commit() + return p + + +@pytest.fixture +def pending_certificate_from_partial_chain_ca(session): + u = UserFactory() + c = CACertificateFactory(body=ROOTCA_CERT_STR, private_key=ROOTCA_KEY, chain=None) + a = AuthorityFactory(authority_certificate=c) + p = PendingCertificateFactory(user=u, authority=a) + session.commit() + return p + + +@pytest.fixture +def invalid_certificate(session): + u = UserFactory() + a = AsyncAuthorityFactory() + i = InvalidCertificateFactory(user=u, authority=a) + session.commit() + return i + + @pytest.fixture def admin_user(session): u = UserFactory() - admin_role = RoleFactory(name='admin') + admin_role = RoleFactory(name="admin") u.roles.append(admin_role) session.commit() user_token = create_token(u) - token = {'Authorization': 'Basic ' + user_token} - return {'user': u, 'token': token} + token = {"Authorization": "Basic " + user_token} + return {"user": u, "token": token} @pytest.fixture def async_issuer_plugin(): from lemur.plugins.base import register from .plugins.issuer_plugin import TestAsyncIssuerPlugin + register(TestAsyncIssuerPlugin) return TestAsyncIssuerPlugin @@ -184,6 +244,7 @@ def async_issuer_plugin(): def issuer_plugin(): from lemur.plugins.base import register from .plugins.issuer_plugin import TestIssuerPlugin + register(TestIssuerPlugin) return TestIssuerPlugin @@ -192,6 +253,7 @@ def issuer_plugin(): def notification_plugin(): from lemur.plugins.base import register from .plugins.notification_plugin import TestNotificationPlugin + register(TestNotificationPlugin) return TestNotificationPlugin @@ -200,6 +262,7 @@ def notification_plugin(): def destination_plugin(): from lemur.plugins.base import register from .plugins.destination_plugin import TestDestinationPlugin + register(TestDestinationPlugin) return TestDestinationPlugin @@ -208,6 +271,7 @@ def destination_plugin(): def source_plugin(): from lemur.plugins.base import register from .plugins.source_plugin import TestSourcePlugin + register(TestSourcePlugin) return TestSourcePlugin @@ -228,23 +292,40 @@ def logged_in_admin(session, app): @pytest.fixture def private_key(): - return load_pem_private_key(SAN_CERT_KEY.encode(), password=None, backend=default_backend()) + return parse_private_key(SAN_CERT_KEY) + + +@pytest.fixture +def issuer_private_key(): + return parse_private_key(INTERMEDIATE_KEY) @pytest.fixture def cert_builder(private_key): - return (x509.CertificateBuilder() - .subject_name(x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, 'foo.com')])) - .issuer_name(x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, 'foo.com')])) - .serial_number(1) - .public_key(private_key.public_key()) - .not_valid_before(datetime.datetime(2017, 12, 22)) - .not_valid_after(datetime.datetime(2040, 1, 1))) + return ( + x509.CertificateBuilder() + .subject_name( + x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, "foo.com")]) + ) + .issuer_name( + x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, "foo.com")]) + ) + .serial_number(1) + .public_key(private_key.public_key()) + .not_valid_before(datetime.datetime(2017, 12, 22)) + .not_valid_after(datetime.datetime(2040, 1, 1)) + ) -@pytest.fixture(scope='function') +@pytest.fixture +def selfsigned_cert(cert_builder, private_key): + # cert_builder uses the same cert public key as 'private_key' + return cert_builder.sign(private_key, hashes.SHA256(), default_backend()) + + +@pytest.fixture(scope="function") def aws_credentials(): - os.environ['AWS_ACCESS_KEY_ID'] = 'testing' - os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing' - os.environ['AWS_SECURITY_TOKEN'] = 'testing' - os.environ['AWS_SESSION_TOKEN'] = 'testing' + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" diff --git a/lemur/tests/factories.py b/lemur/tests/factories.py index cae2c354..fea4c59a 100644 --- a/lemur/tests/factories.py +++ b/lemur/tests/factories.py @@ -1,4 +1,3 @@ - from datetime import date from factory import Sequence, post_generation, SubFactory @@ -19,8 +18,16 @@ from lemur.endpoints.models import Policy, Endpoint from lemur.policies.models import RotationPolicy from lemur.api_keys.models import ApiKey -from .vectors import SAN_CERT_STR, SAN_CERT_KEY, CSR_STR, INTERMEDIATE_CERT_STR, ROOTCA_CERT_STR, INTERMEDIATE_KEY, \ - WILDCARD_CERT_KEY +from .vectors import ( + SAN_CERT_STR, + SAN_CERT_KEY, + CSR_STR, + INTERMEDIATE_CERT_STR, + ROOTCA_CERT_STR, + INTERMEDIATE_KEY, + WILDCARD_CERT_KEY, + INVALID_CERT_STR, +) class BaseFactory(SQLAlchemyModelFactory): @@ -28,28 +35,32 @@ class BaseFactory(SQLAlchemyModelFactory): class Meta: """Factory configuration.""" + abstract = True sqlalchemy_session = db.session class RotationPolicyFactory(BaseFactory): """Rotation Factory.""" - name = Sequence(lambda n: 'policy{0}'.format(n)) + + name = Sequence(lambda n: "policy{0}".format(n)) days = 30 class Meta: """Factory configuration.""" + model = RotationPolicy class CertificateFactory(BaseFactory): """Certificate factory.""" - name = Sequence(lambda n: 'certificate{0}'.format(n)) + + name = Sequence(lambda n: "certificate{0}".format(n)) chain = INTERMEDIATE_CERT_STR body = SAN_CERT_STR private_key = SAN_CERT_KEY - owner = 'joe@example.com' - status = FuzzyChoice(['valid', 'revoked', 'unknown']) + owner = "joe@example.com" + status = FuzzyChoice(["valid", "revoked", "unknown"]) deleted = False description = FuzzyText(length=128) active = True @@ -58,6 +69,7 @@ class CertificateFactory(BaseFactory): class Meta: """Factory Configuration.""" + model = Certificate @post_generation @@ -137,16 +149,24 @@ class CACertificateFactory(CertificateFactory): private_key = INTERMEDIATE_KEY +class InvalidCertificateFactory(CertificateFactory): + body = INVALID_CERT_STR + private_key = "" + chain = "" + + class AuthorityFactory(BaseFactory): """Authority factory.""" - name = Sequence(lambda n: 'authority{0}'.format(n)) - owner = 'joe@example.com' - plugin = {'slug': 'test-issuer'} + + name = Sequence(lambda n: "authority{0}".format(n)) + owner = "joe@example.com" + plugin = {"slug": "test-issuer"} description = FuzzyText(length=128) authority_certificate = SubFactory(CACertificateFactory) class Meta: """Factory configuration.""" + model = Authority @post_generation @@ -161,49 +181,64 @@ class AuthorityFactory(BaseFactory): class AsyncAuthorityFactory(AuthorityFactory): """Async Authority factory.""" - name = Sequence(lambda n: 'authority{0}'.format(n)) - owner = 'joe@example.com' - plugin = {'slug': 'test-issuer-async'} + + name = Sequence(lambda n: "authority{0}".format(n)) + owner = "joe@example.com" + plugin = {"slug": "test-issuer-async"} description = FuzzyText(length=128) authority_certificate = SubFactory(CertificateFactory) +class CryptoAuthorityFactory(AuthorityFactory): + """Authority factory based on 'cryptography' plugin.""" + + plugin = {"slug": "cryptography-issuer"} + + class DestinationFactory(BaseFactory): """Destination factory.""" - plugin_name = 'test-destination' - label = Sequence(lambda n: 'destination{0}'.format(n)) + + plugin_name = "test-destination" + label = Sequence(lambda n: "destination{0}".format(n)) class Meta: """Factory Configuration.""" + model = Destination class SourceFactory(BaseFactory): """Source factory.""" - plugin_name = 'test-source' - label = Sequence(lambda n: 'source{0}'.format(n)) + + plugin_name = "test-source" + label = Sequence(lambda n: "source{0}".format(n)) class Meta: """Factory Configuration.""" + model = Source class NotificationFactory(BaseFactory): """Notification factory.""" - plugin_name = 'test-notification' - label = Sequence(lambda n: 'notification{0}'.format(n)) + + plugin_name = "test-notification" + label = Sequence(lambda n: "notification{0}".format(n)) class Meta: """Factory Configuration.""" + model = Notification class RoleFactory(BaseFactory): """Role factory.""" - name = Sequence(lambda n: 'role{0}'.format(n)) + + name = Sequence(lambda n: "role{0}".format(n)) class Meta: """Factory Configuration.""" + model = Role @post_generation @@ -218,14 +253,16 @@ class RoleFactory(BaseFactory): class UserFactory(BaseFactory): """User Factory.""" - username = Sequence(lambda n: 'user{0}'.format(n)) - email = Sequence(lambda n: 'user{0}@example.com'.format(n)) + + username = Sequence(lambda n: "user{0}".format(n)) + email = Sequence(lambda n: "user{0}@example.com".format(n)) active = True password = FuzzyText(length=24) certificates = [] class Meta: """Factory Configuration.""" + model = User @post_generation @@ -258,39 +295,45 @@ class UserFactory(BaseFactory): class PolicyFactory(BaseFactory): """Policy Factory.""" - name = Sequence(lambda n: 'endpoint{0}'.format(n)) + + name = Sequence(lambda n: "endpoint{0}".format(n)) class Meta: """Factory Configuration.""" + model = Policy class EndpointFactory(BaseFactory): """Endpoint Factory.""" - owner = 'joe@example.com' - name = Sequence(lambda n: 'endpoint{0}'.format(n)) - type = FuzzyChoice(['elb']) + + owner = "joe@example.com" + name = Sequence(lambda n: "endpoint{0}".format(n)) + type = FuzzyChoice(["elb"]) active = True port = FuzzyInteger(0, high=65535) - dnsname = 'endpoint.example.com' + dnsname = "endpoint.example.com" policy = SubFactory(PolicyFactory) certificate = SubFactory(CertificateFactory) source = SubFactory(SourceFactory) class Meta: """Factory Configuration.""" + model = Endpoint class ApiKeyFactory(BaseFactory): """Api Key Factory.""" - name = Sequence(lambda n: 'api_key_{0}'.format(n)) + + name = Sequence(lambda n: "api_key_{0}".format(n)) revoked = False ttl = -1 issued_at = 1 class Meta: """Factory Configuration.""" + model = ApiKey @post_generation @@ -304,13 +347,14 @@ class ApiKeyFactory(BaseFactory): class PendingCertificateFactory(BaseFactory): """PendingCertificate factory.""" - name = Sequence(lambda n: 'pending_certificate{0}'.format(n)) + + name = Sequence(lambda n: "pending_certificate{0}".format(n)) external_id = 12345 csr = CSR_STR chain = INTERMEDIATE_CERT_STR private_key = WILDCARD_CERT_KEY - owner = 'joe@example.com' - status = FuzzyChoice(['valid', 'revoked', 'unknown']) + owner = "joe@example.com" + status = FuzzyChoice(["valid", "revoked", "unknown"]) deleted = False description = FuzzyText(length=128) date_created = FuzzyDate(date(2016, 1, 1), date(2020, 1, 1)) @@ -319,6 +363,7 @@ class PendingCertificateFactory(BaseFactory): class Meta: """Factory Configuration.""" + model = PendingCertificate @post_generation diff --git a/lemur/tests/plugins/destination_plugin.py b/lemur/tests/plugins/destination_plugin.py index f77085ec..d1eb6711 100644 --- a/lemur/tests/plugins/destination_plugin.py +++ b/lemur/tests/plugins/destination_plugin.py @@ -2,12 +2,12 @@ from lemur.plugins.bases import DestinationPlugin class TestDestinationPlugin(DestinationPlugin): - title = 'Test' - slug = 'test-destination' - description = 'Enables testing' + title = "Test" + slug = "test-destination" + description = "Enables testing" - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): super(TestDestinationPlugin, self).__init__(*args, **kwargs) diff --git a/lemur/tests/plugins/issuer_plugin.py b/lemur/tests/plugins/issuer_plugin.py index 3fda83ae..5f5c732b 100644 --- a/lemur/tests/plugins/issuer_plugin.py +++ b/lemur/tests/plugins/issuer_plugin.py @@ -4,12 +4,12 @@ from lemur.tests.vectors import SAN_CERT_STR, INTERMEDIATE_CERT_STR class TestIssuerPlugin(IssuerPlugin): - title = 'Test' - slug = 'test-issuer' - description = 'Enables testing' + title = "Test" + slug = "test-issuer" + description = "Enables testing" - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): super(TestIssuerPlugin, self).__init__(*args, **kwargs) @@ -20,17 +20,17 @@ class TestIssuerPlugin(IssuerPlugin): @staticmethod def create_authority(options): - role = {'username': '', 'password': '', 'name': 'test'} + role = {"username": "", "password": "", "name": "test"} return SAN_CERT_STR, "", [role] class TestAsyncIssuerPlugin(IssuerPlugin): - title = 'Test Async' - slug = 'test-issuer-async' - description = 'Enables testing with pending certificates' + title = "Test Async" + slug = "test-issuer-async" + description = "Enables testing with pending certificates" - author = 'James Chuong' - author_url = 'https://github.com/jchuong' + author = "James Chuong" + author_url = "https://github.com/jchuong" def __init__(self, *args, **kwargs): super(TestAsyncIssuerPlugin, self).__init__(*args, **kwargs) @@ -43,7 +43,7 @@ class TestAsyncIssuerPlugin(IssuerPlugin): @staticmethod def create_authority(options): - role = {'username': '', 'password': '', 'name': 'test'} + role = {"username": "", "password": "", "name": "test"} return SAN_CERT_STR, "", [role] def cancel_ordered_certificate(self, pending_certificate, **kwargs): diff --git a/lemur/tests/plugins/notification_plugin.py b/lemur/tests/plugins/notification_plugin.py index ad393d60..4ad79704 100644 --- a/lemur/tests/plugins/notification_plugin.py +++ b/lemur/tests/plugins/notification_plugin.py @@ -2,12 +2,12 @@ from lemur.plugins.bases import NotificationPlugin class TestNotificationPlugin(NotificationPlugin): - title = 'Test' - slug = 'test-notification' - description = 'Enables testing' + title = "Test" + slug = "test-notification" + description = "Enables testing" - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): super(TestNotificationPlugin, self).__init__(*args, **kwargs) diff --git a/lemur/tests/plugins/source_plugin.py b/lemur/tests/plugins/source_plugin.py index 10402576..21ce245d 100644 --- a/lemur/tests/plugins/source_plugin.py +++ b/lemur/tests/plugins/source_plugin.py @@ -2,12 +2,12 @@ from lemur.plugins.bases import SourcePlugin class TestSourcePlugin(SourcePlugin): - title = 'Test' - slug = 'test-source' - description = 'Enables testing' + title = "Test" + slug = "test-source" + description = "Enables testing" - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): super(TestSourcePlugin, self).__init__(*args, **kwargs) diff --git a/lemur/tests/test_api_keys.py b/lemur/tests/test_api_keys.py index e60773bf..9e293be2 100644 --- a/lemur/tests/test_api_keys.py +++ b/lemur/tests/test_api_keys.py @@ -4,219 +4,398 @@ import pytest from lemur.api_keys.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_api_key_list_get(client, token, status): assert client.get(api.url_for(ApiKeyList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_api_key_list_post_invalid(client, token, status): - assert client.post(api.url_for(ApiKeyList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(ApiKeyList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,user_id,status", [ - (VALID_USER_HEADER_TOKEN, 1, 200), - (VALID_ADMIN_HEADER_TOKEN, 2, 200), - (VALID_ADMIN_API_TOKEN, 2, 200), - ('', 0, 401) -]) +@pytest.mark.parametrize( + "token,user_id,status", + [ + (VALID_USER_HEADER_TOKEN, 1, 200), + (VALID_ADMIN_HEADER_TOKEN, 2, 200), + (VALID_ADMIN_API_TOKEN, 2, 200), + ("", 0, 401), + ], +) def test_api_key_list_post_valid_self(client, user_id, token, status): - assert client.post(api.url_for(ApiKeyList), data=json.dumps({'name': 'a test token', 'user': {'id': user_id, 'username': 'example', 'email': 'example@test.net'}, 'ttl': -1}), headers=token).status_code == status + assert ( + client.post( + api.url_for(ApiKeyList), + data=json.dumps( + { + "name": "a test token", + "user": { + "id": user_id, + "username": "example", + "email": "example@test.net", + }, + "ttl": -1, + } + ), + headers=token, + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_api_key_list_post_valid_no_permission(client, token, status): - assert client.post(api.url_for(ApiKeyList), data=json.dumps({'name': 'a test token', 'user': {'id': 2, 'username': 'example', 'email': 'example@test.net'}, 'ttl': -1}), headers=token).status_code == status + assert ( + client.post( + api.url_for(ApiKeyList), + data=json.dumps( + { + "name": "a test token", + "user": { + "id": 2, + "username": "example", + "email": "example@test.net", + }, + "ttl": -1, + } + ), + headers=token, + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_api_key_list_patch(client, token, status): - assert client.patch(api.url_for(ApiKeyList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(ApiKeyList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_api_key_list_delete(client, token, status): assert client.delete(api.url_for(ApiKeyList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_user_api_key_list_get(client, token, status): - assert client.get(api.url_for(ApiKeyUserList, user_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(ApiKeyUserList, user_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_user_api_key_list_post_invalid(client, token, status): - assert client.post(api.url_for(ApiKeyUserList, user_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(ApiKeyUserList, user_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,user_id,status", [ - (VALID_USER_HEADER_TOKEN, 1, 200), - (VALID_ADMIN_HEADER_TOKEN, 2, 200), - (VALID_ADMIN_API_TOKEN, 2, 200), - ('', 0, 401) -]) +@pytest.mark.parametrize( + "token,user_id,status", + [ + (VALID_USER_HEADER_TOKEN, 1, 200), + (VALID_ADMIN_HEADER_TOKEN, 2, 200), + (VALID_ADMIN_API_TOKEN, 2, 200), + ("", 0, 401), + ], +) def test_user_api_key_list_post_valid_self(client, user_id, token, status): - assert client.post(api.url_for(ApiKeyUserList, user_id=1), data=json.dumps({'name': 'a test token', 'user': {'id': user_id}, 'ttl': -1}), headers=token).status_code == status + assert ( + client.post( + api.url_for(ApiKeyUserList, user_id=1), + data=json.dumps( + {"name": "a test token", "user": {"id": user_id}, "ttl": -1} + ), + headers=token, + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_user_api_key_list_post_valid_no_permission(client, token, status): - assert client.post(api.url_for(ApiKeyUserList, user_id=2), data=json.dumps({'name': 'a test token', 'user': {'id': 2}, 'ttl': -1}), headers=token).status_code == status + assert ( + client.post( + api.url_for(ApiKeyUserList, user_id=2), + data=json.dumps({"name": "a test token", "user": {"id": 2}, "ttl": -1}), + headers=token, + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_api_key_list_patch(client, token, status): - assert client.patch(api.url_for(ApiKeyUserList, user_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(ApiKeyUserList, user_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_api_key_list_delete(client, token, status): - assert client.delete(api.url_for(ApiKeyUserList, user_id=1), headers=token).status_code == status + assert ( + client.delete(api.url_for(ApiKeyUserList, user_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) -@pytest.mark.skip(reason="no way of getting an actual user onto the access key to generate a jwt") +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) +@pytest.mark.skip( + reason="no way of getting an actual user onto the access key to generate a jwt" +) def test_api_key_get(client, token, status): assert client.get(api.url_for(ApiKeys, aid=1), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_api_key_post(client, token, status): assert client.post(api.url_for(ApiKeys, aid=1), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_api_key_patch(client, token, status): - assert client.patch(api.url_for(ApiKeys, aid=1), headers=token).status_code == status + assert ( + client.patch(api.url_for(ApiKeys, aid=1), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) -@pytest.mark.skip(reason="no way of getting an actual user onto the access key to generate a jwt") +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) +@pytest.mark.skip( + reason="no way of getting an actual user onto the access key to generate a jwt" +) def test_api_key_put_permssions(client, token, status): - assert client.put(api.url_for(ApiKeys, aid=1), data=json.dumps({'name': 'Test', 'revoked': False, 'ttl': -1}), headers=token).status_code == status + assert ( + client.put( + api.url_for(ApiKeys, aid=1), + data=json.dumps({"name": "Test", "revoked": False, "ttl": -1}), + headers=token, + ).status_code + == status + ) # This test works while the other doesn't because the schema allows user id to be null. -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_api_key_described_get(client, token, status): - assert client.get(api.url_for(ApiKeysDescribed, aid=1), headers=token).status_code == status + assert ( + client.get(api.url_for(ApiKeysDescribed, aid=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) -@pytest.mark.skip(reason="no way of getting an actual user onto the access key to generate a jwt") +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) +@pytest.mark.skip( + reason="no way of getting an actual user onto the access key to generate a jwt" +) def test_user_api_key_get(client, token, status): - assert client.get(api.url_for(UserApiKeys, uid=1, aid=1), headers=token).status_code == status + assert ( + client.get(api.url_for(UserApiKeys, uid=1, aid=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_api_key_post(client, token, status): - assert client.post(api.url_for(UserApiKeys, uid=2, aid=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(UserApiKeys, uid=2, aid=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_api_key_patch(client, token, status): - assert client.patch(api.url_for(UserApiKeys, uid=2, aid=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(UserApiKeys, uid=2, aid=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) -@pytest.mark.skip(reason="no way of getting an actual user onto the access key to generate a jwt") +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) +@pytest.mark.skip( + reason="no way of getting an actual user onto the access key to generate a jwt" +) def test_user_api_key_put_permssions(client, token, status): - assert client.put(api.url_for(UserApiKeys, uid=2, aid=1), data=json.dumps({'name': 'Test', 'revoked': False, 'ttl': -1}), headers=token).status_code == status + assert ( + client.put( + api.url_for(UserApiKeys, uid=2, aid=1), + data=json.dumps({"name": "Test", "revoked": False, "ttl": -1}), + headers=token, + ).status_code + == status + ) diff --git a/lemur/tests/test_authorities.py b/lemur/tests/test_authorities.py index e865ab41..9649e949 100644 --- a/lemur/tests/test_authorities.py +++ b/lemur/tests/test_authorities.py @@ -4,22 +4,29 @@ import pytest from lemur.authorities.views import * # noqa from lemur.tests.factories import AuthorityFactory, RoleFactory -from lemur.tests.vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from lemur.tests.vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_authority_input_schema(client, role, issuer_plugin, logged_in_user): from lemur.authorities.schemas import AuthorityInputSchema input_data = { - 'name': 'Example Authority', - 'owner': 'jim@example.com', - 'description': 'An example authority.', - 'commonName': 'An Example Authority', - 'plugin': {'slug': 'test-issuer', 'plugin_options': [{'name': 'test', 'value': 'blah'}]}, - 'type': 'root', - 'signingAlgorithm': 'sha256WithRSA', - 'keyType': 'RSA2048', - 'sensitivity': 'medium' + "name": "Example Authority", + "owner": "jim@example.com", + "description": "An example authority.", + "commonName": "An Example Authority", + "plugin": { + "slug": "test-issuer", + "plugin_options": [{"name": "test", "value": "blah"}], + }, + "type": "root", + "signingAlgorithm": "sha256WithRSA", + "keyType": "RSA2048", + "sensitivity": "medium", } data, errors = AuthorityInputSchema().load(input_data) @@ -28,179 +35,286 @@ def test_authority_input_schema(client, role, issuer_plugin, logged_in_user): def test_user_authority(session, client, authority, role, user, issuer_plugin): - u = user['user'] + u = user["user"] u.roles.append(role) authority.roles.append(role) session.commit() - assert client.get(api.url_for(AuthoritiesList), headers=user['token']).json['total'] == 1 + assert ( + client.get(api.url_for(AuthoritiesList), headers=user["token"]).json["total"] + == 1 + ) u.roles.remove(role) session.commit() - assert client.get(api.url_for(AuthoritiesList), headers=user['token']).json['total'] == 0 + assert ( + client.get(api.url_for(AuthoritiesList), headers=user["token"]).json["total"] + == 0 + ) def test_create_authority(issuer_plugin, user): from lemur.authorities.service import create - authority = create(plugin={'plugin_object': issuer_plugin, 'slug': issuer_plugin.slug}, owner='jim@example.com', type='root', creator=user['user']) + + authority = create( + plugin={"plugin_object": issuer_plugin, "slug": issuer_plugin.slug}, + owner="jim@example.com", + type="root", + creator=user["user"], + ) assert authority.authority_certificate -@pytest.mark.parametrize("token, count", [ - (VALID_USER_HEADER_TOKEN, 0), - (VALID_ADMIN_HEADER_TOKEN, 3), - (VALID_ADMIN_API_TOKEN, 3), -]) +@pytest.mark.parametrize( + "token, count", + [ + (VALID_USER_HEADER_TOKEN, 0), + (VALID_ADMIN_HEADER_TOKEN, 3), + (VALID_ADMIN_API_TOKEN, 3), + ], +) def test_admin_authority(client, authority, issuer_plugin, token, count): - assert client.get(api.url_for(AuthoritiesList), headers=token).json['total'] == count + assert ( + client.get(api.url_for(AuthoritiesList), headers=token).json["total"] == count + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_authority_get(client, token, status): - assert client.get(api.url_for(Authorities, authority_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(Authorities, authority_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authority_post(client, token, status): - assert client.post(api.url_for(Authorities, authority_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Authorities, authority_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_authority_put(client, token, status): - assert client.put(api.url_for(Authorities, authority_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Authorities, authority_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authority_delete(client, token, status): - assert client.delete(api.url_for(Authorities, authority_id=1), headers=token).status_code == status + assert ( + client.delete( + api.url_for(Authorities, authority_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authority_patch(client, token, status): - assert client.patch(api.url_for(Authorities, authority_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Authorities, authority_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_authorities_get(client, token, status): assert client.get(api.url_for(AuthoritiesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_authorities_post(client, token, status): - assert client.post(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authorities_put(client, token, status): - assert client.put(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authorities_delete(client, token, status): - assert client.delete(api.url_for(AuthoritiesList), headers=token).status_code == status + assert ( + client.delete(api.url_for(AuthoritiesList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authorities_patch(client, token, status): - assert client.patch(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_certificate_authorities_get(client, token, status): assert client.get(api.url_for(AuthoritiesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_certificate_authorities_post(client, token, status): - assert client.post(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_authorities_put(client, token, status): - assert client.put(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_authorities_delete(client, token, status): - assert client.delete(api.url_for(AuthoritiesList), headers=token).status_code == status + assert ( + client.delete(api.url_for(AuthoritiesList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_authorities_patch(client, token, status): - assert client.patch(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) def test_authority_roles(client, session, issuer_plugin): @@ -209,23 +323,29 @@ def test_authority_roles(client, session, issuer_plugin): session.flush() data = { - 'owner': auth.owner, - 'name': auth.name, - 'description': auth.description, - 'active': True, - 'roles': [ - {'id': role.id}, - ], + "owner": auth.owner, + "name": auth.name, + "description": auth.description, + "active": True, + "roles": [{"id": role.id}], } # Add role - resp = client.put(api.url_for(Authorities, authority_id=auth.id), data=json.dumps(data), headers=VALID_ADMIN_HEADER_TOKEN) + resp = client.put( + api.url_for(Authorities, authority_id=auth.id), + data=json.dumps(data), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 - assert len(resp.json['roles']) == 1 + assert len(resp.json["roles"]) == 1 assert set(auth.roles) == {role} # Remove role - del data['roles'][0] - resp = client.put(api.url_for(Authorities, authority_id=auth.id), data=json.dumps(data), headers=VALID_ADMIN_HEADER_TOKEN) + del data["roles"][0] + resp = client.put( + api.url_for(Authorities, authority_id=auth.id), + data=json.dumps(data), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 - assert len(resp.json['roles']) == 0 + assert len(resp.json["roles"]) == 0 diff --git a/lemur/tests/test_certificates.py b/lemur/tests/test_certificates.py index 87416a7a..adafa605 100644 --- a/lemur/tests/test_certificates.py +++ b/lemur/tests/test_certificates.py @@ -17,28 +17,133 @@ from lemur.common import utils from lemur.domains.models import Domain -from lemur.tests.vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN, CSR_STR, \ - INTERMEDIATE_CERT_STR, SAN_CERT_STR, SAN_CERT_KEY +from lemur.tests.vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, + CSR_STR, + INTERMEDIATE_CERT_STR, + SAN_CERT_STR, + SAN_CERT_CSR, + SAN_CERT_KEY, + ROOTCA_KEY, + ROOTCA_CERT_STR, +) def test_get_or_increase_name(session, certificate): from lemur.certificates.models import get_or_increase_name from lemur.tests.factories import CertificateFactory - serial = 'AFF2DB4F8D2D4D8E80FA382AE27C2333' + serial = "AFF2DB4F8D2D4D8E80FA382AE27C2333" - assert get_or_increase_name(certificate.name, certificate.serial) == '{0}-{1}'.format(certificate.name, serial) + assert get_or_increase_name( + certificate.name, certificate.serial + ) == "{0}-{1}".format(certificate.name, serial) - certificate.name = 'test-cert-11111111' - assert get_or_increase_name(certificate.name, certificate.serial) == 'test-cert-11111111-' + serial + certificate.name = "test-cert-11111111" + assert ( + get_or_increase_name(certificate.name, certificate.serial) + == "test-cert-11111111-" + serial + ) - certificate.name = 'test-cert-11111111-1' - assert get_or_increase_name('test-cert-11111111-1', certificate.serial) == 'test-cert-11111111-1-' + serial + certificate.name = "test-cert-11111111-1" + assert ( + get_or_increase_name("test-cert-11111111-1", certificate.serial) + == "test-cert-11111111-1-" + serial + ) - cert2 = CertificateFactory(name='certificate1-' + serial) + CertificateFactory(name="certificate1") + CertificateFactory(name="certificate1-" + serial) session.commit() - assert get_or_increase_name('certificate1', int(serial, 16)) == 'certificate1-{}-1'.format(serial) + assert get_or_increase_name( + "certificate1", int(serial, 16) + ) == "certificate1-{}-1".format(serial) + + +def test_get_all_certs(session, certificate): + from lemur.certificates.service import get_all_certs + + assert len(get_all_certs()) > 1 + + +def test_get_by_name(session, certificate): + from lemur.certificates.service import get_by_name + + found = get_by_name(certificate.name) + + assert found + + +def test_get_by_serial(session, certificate): + from lemur.certificates.service import get_by_serial + + found = get_by_serial(certificate.serial) + + assert found + + +def test_delete_cert(session): + from lemur.certificates.service import delete, get + from lemur.tests.factories import CertificateFactory + + delete_this = CertificateFactory(name="DELETEME") + session.commit() + + cert_exists = get(delete_this.id) + + # it needs to exist first + assert cert_exists + + delete(delete_this.id) + cert_exists = get(delete_this.id) + + # then not exist after delete + assert not cert_exists + + +def test_get_by_attributes(session, certificate): + from lemur.certificates.service import get_by_attributes + + # Should get one cert + certificate1 = get_by_attributes( + { + "name": "SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231" + } + ) + + # Should get one cert using multiple attrs + certificate2 = get_by_attributes( + {"name": "test-cert-11111111-1", "cn": "san.example.org"} + ) + + # Should get multiple certs + multiple = get_by_attributes( + { + "cn": "LemurTrust Unittests Class 1 CA 2018", + "issuer": "LemurTrustUnittestsRootCA2018", + } + ) + + assert len(certificate1) == 1 + assert len(certificate2) == 1 + assert len(multiple) > 1 + + +def test_find_duplicates(session): + from lemur.certificates.service import find_duplicates + + cert = {"body": SAN_CERT_STR, "chain": INTERMEDIATE_CERT_STR} + + dups1 = find_duplicates(cert) + + cert["chain"] = "" + + dups2 = find_duplicates(cert) + + assert len(dups1) > 0 + assert len(dups2) > 0 def test_get_certificate_primitives(certificate): @@ -55,13 +160,15 @@ def test_certificate_output_schema(session, certificate, issuer_plugin): from lemur.certificates.schemas import CertificateOutputSchema # Clear the cached attribute first - if 'parsed_cert' in certificate.__dict__: - del certificate.__dict__['parsed_cert'] + if "parsed_cert" in certificate.__dict__: + del certificate.__dict__["parsed_cert"] # Make sure serialization parses the cert only once (uses cached 'parsed_cert' attribute) - with patch('lemur.common.utils.parse_certificate', side_effect=utils.parse_certificate) as wrapper: + with patch( + "lemur.common.utils.parse_certificate", side_effect=utils.parse_certificate + ) as wrapper: data, errors = CertificateOutputSchema().dump(certificate) - assert data['issuer'] == 'LemurTrustUnittestsClass1CA2018' + assert data["issuer"] == "LemurTrustUnittestsClass1CA2018" assert wrapper.call_count == 1 @@ -69,24 +176,21 @@ def test_certificate_output_schema(session, certificate, issuer_plugin): def test_certificate_edit_schema(session): from lemur.certificates.schemas import CertificateEditInputSchema - input_data = {'owner': 'bob@example.com'} + input_data = {"owner": "bob@example.com"} data, errors = CertificateEditInputSchema().load(input_data) - assert len(data['notifications']) == 3 + assert len(data["notifications"]) == 3 def test_authority_key_identifier_schema(): from lemur.schemas import AuthorityKeyIdentifierSchema - input_data = { - 'useKeyIdentifier': True, - 'useAuthorityCert': True - } + + input_data = {"useKeyIdentifier": True, "useAuthorityCert": True} data, errors = AuthorityKeyIdentifierSchema().load(input_data) - assert sorted(data) == sorted({ - 'use_key_identifier': True, - 'use_authority_cert': True - }) + assert sorted(data) == sorted( + {"use_key_identifier": True, "use_authority_cert": True} + ) assert not errors data, errors = AuthorityKeyIdentifierSchema().dumps(data) @@ -96,11 +200,12 @@ def test_authority_key_identifier_schema(): def test_certificate_info_access_schema(): from lemur.schemas import CertificateInfoAccessSchema - input_data = {'includeAIA': True} + + input_data = {"includeAIA": True} data, errors = CertificateInfoAccessSchema().load(input_data) assert not errors - assert data == {'include_aia': True} + assert data == {"include_aia": True} data, errors = CertificateInfoAccessSchema().dump(data) assert not errors @@ -110,11 +215,11 @@ def test_certificate_info_access_schema(): def test_subject_key_identifier_schema(): from lemur.schemas import SubjectKeyIdentifierSchema - input_data = {'includeSKI': True} + input_data = {"includeSKI": True} data, errors = SubjectKeyIdentifierSchema().load(input_data) assert not errors - assert data == {'include_ski': True} + assert data == {"include_ski": True} data, errors = SubjectKeyIdentifierSchema().dump(data) assert not errors assert data == input_data @@ -124,16 +229,9 @@ def test_extension_schema(client): from lemur.certificates.schemas import ExtensionSchema input_data = { - 'keyUsage': { - 'useKeyEncipherment': True, - 'useDigitalSignature': True - }, - 'extendedKeyUsage': { - 'useServerAuthentication': True - }, - 'subjectKeyIdentifier': { - 'includeSKI': True - } + "keyUsage": {"useKeyEncipherment": True, "useDigitalSignature": True}, + "extendedKeyUsage": {"useServerAuthentication": True}, + "subjectKeyIdentifier": {"includeSKI": True}, } data, errors = ExtensionSchema().load(input_data) @@ -147,24 +245,24 @@ def test_certificate_input_schema(client, authority): from lemur.certificates.schemas import CertificateInputSchema input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': arrow.get(2018, 11, 9).isoformat(), - 'validityEnd': arrow.get(2019, 11, 9).isoformat(), - 'dnsProvider': None, + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": arrow.get(2018, 11, 9).isoformat(), + "validityEnd": arrow.get(2019, 11, 9).isoformat(), + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) assert not errors - assert data['authority'].id == authority.id + assert data["authority"].id == authority.id # make sure the defaults got set - assert data['common_name'] == 'test.example.com' - assert data['country'] == 'US' - assert data['location'] == 'Los Gatos' + assert data["common_name"] == "test.example.com" + assert data["country"] == "US" + assert data["location"] == "Los Gatos" assert len(data.keys()) == 19 @@ -173,54 +271,86 @@ def test_certificate_input_with_extensions(client, authority): from lemur.certificates.schemas import CertificateInputSchema input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'extensions': { - 'keyUsage': { - 'digital_signature': True + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "extensions": { + "keyUsage": {"digital_signature": True}, + "extendedKeyUsage": { + "useClientAuthentication": True, + "useServerAuthentication": True, }, - 'extendedKeyUsage': { - 'useClientAuthentication': True, - 'useServerAuthentication': True + "subjectKeyIdentifier": {"includeSKI": True}, + "subAltNames": { + "names": [{"nameType": "DNSName", "value": "test.example.com"}] }, - 'subjectKeyIdentifier': { - 'includeSKI': True - }, - 'subAltNames': { - 'names': [ - {'nameType': 'DNSName', 'value': 'test.example.com'} - ] - } }, - 'dnsProvider': None, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) assert not errors +def test_certificate_input_schema_parse_csr(authority): + from lemur.certificates.schemas import CertificateInputSchema + + test_san_dns = "foobar.com" + extensions = { + "sub_alt_names": { + "names": x509.SubjectAlternativeName([x509.DNSName(test_san_dns)]) + } + } + csr, private_key = create_csr( + owner="joe@example.com", + common_name="ACommonName", + organization="test", + organizational_unit="Meters", + country="NL", + state="Noord-Holland", + location="Amsterdam", + key_type="RSA2048", + extensions=extensions, + ) + + input_data = { + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "csr": csr, + "dnsProvider": None, + } + + data, errors = CertificateInputSchema().load(input_data) + + for san in data["extensions"]["sub_alt_names"]["names"]: + assert san.value == test_san_dns + assert not errors + + def test_certificate_out_of_range_date(client, authority): from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityYears': 100, - 'dnsProvider': None, + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityYears": 100, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) assert errors - input_data['validityStart'] = '2017-04-30T00:12:34.513631' + input_data["validityStart"] = "2017-04-30T00:12:34.513631" data, errors = CertificateInputSchema().load(input_data) assert errors - input_data['validityEnd'] = '2018-04-30T00:12:34.513631' + input_data["validityEnd"] = "2018-04-30T00:12:34.513631" data, errors = CertificateInputSchema().load(input_data) assert errors @@ -228,13 +358,14 @@ def test_certificate_out_of_range_date(client, authority): def test_certificate_valid_years(client, authority): from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityYears': 1, - 'dnsProvider': None, + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityYears": 1, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) @@ -243,14 +374,15 @@ def test_certificate_valid_years(client, authority): def test_certificate_valid_dates(client, authority): from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'dnsProvider': None, + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) @@ -260,14 +392,15 @@ def test_certificate_valid_dates(client, authority): def test_certificate_cn_admin(client, authority, logged_in_admin): """Admin is exempt from CN/SAN domain restrictions.""" from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': '*.admin-overrides-whitelist.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'dnsProvider': None, + "commonName": "*.admin-overrides-whitelist.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) @@ -277,22 +410,23 @@ def test_certificate_cn_admin(client, authority, logged_in_admin): def test_certificate_allowed_names(client, authority, session, logged_in_user): """Test for allowed CN and SAN values.""" from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': 'Names with spaces are not checked', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'extensions': { - 'subAltNames': { - 'names': [ - {'nameType': 'DNSName', 'value': 'allowed.example.com'}, - {'nameType': 'IPAddress', 'value': '127.0.0.1'}, + "commonName": "Names with spaces are not checked", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "extensions": { + "subAltNames": { + "names": [ + {"nameType": "DNSName", "value": "allowed.example.com"}, + {"nameType": "IPAddress", "value": "127.0.0.1"}, ] } }, - 'dnsProvider': None, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) @@ -307,79 +441,209 @@ def test_certificate_incative_authority(client, authority, session, logged_in_us session.add(authority) input_data = { - 'commonName': 'foo.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'dnsProvider': None, + "commonName": "foo.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) - assert errors['authority'][0] == "The authority is inactive." + assert errors["authority"][0] == "The authority is inactive." def test_certificate_disallowed_names(client, authority, session, logged_in_user): """The CN and SAN are disallowed by LEMUR_WHITELISTED_DOMAINS.""" from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': '*.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'extensions': { - 'subAltNames': { - 'names': [ - {'nameType': 'DNSName', 'value': 'allowed.example.com'}, - {'nameType': 'DNSName', 'value': 'evilhacker.org'}, + "commonName": "*.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "extensions": { + "subAltNames": { + "names": [ + {"nameType": "DNSName", "value": "allowed.example.com"}, + {"nameType": "DNSName", "value": "evilhacker.org"}, ] } }, - 'dnsProvider': None, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) - assert errors['common_name'][0].startswith("Domain *.example.com does not match whitelisted domain patterns") - assert (errors['extensions']['sub_alt_names']['names'][0] - .startswith("Domain evilhacker.org does not match whitelisted domain patterns")) + assert errors["common_name"][0].startswith( + "Domain *.example.com does not match whitelisted domain patterns" + ) + assert errors["extensions"]["sub_alt_names"]["names"][0].startswith( + "Domain evilhacker.org does not match whitelisted domain patterns" + ) def test_certificate_sensitive_name(client, authority, session, logged_in_user): """The CN is disallowed by 'sensitive' flag on Domain model.""" from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': 'sensitive.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'dnsProvider': None, + "commonName": "sensitive.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "dnsProvider": None, } - session.add(Domain(name='sensitive.example.com', sensitive=True)) + session.add(Domain(name="sensitive.example.com", sensitive=True)) data, errors = CertificateInputSchema().load(input_data) - assert errors['common_name'][0].startswith("Domain sensitive.example.com has been marked as sensitive") + assert errors["common_name"][0].startswith( + "Domain sensitive.example.com has been marked as sensitive" + ) + + +def test_certificate_upload_schema_ok(client): + from lemur.certificates.schemas import CertificateUploadInputSchema + + data = { + "name": "Jane", + "owner": "pwner@example.com", + "body": SAN_CERT_STR, + "privateKey": SAN_CERT_KEY, + "chain": INTERMEDIATE_CERT_STR, + "csr": SAN_CERT_CSR, + "external_id": "1234", + } + data, errors = CertificateUploadInputSchema().load(data) + assert not errors + + +def test_certificate_upload_schema_minimal(client): + from lemur.certificates.schemas import CertificateUploadInputSchema + + data = {"owner": "pwner@example.com", "body": SAN_CERT_STR} + data, errors = CertificateUploadInputSchema().load(data) + assert not errors + + +def test_certificate_upload_schema_long_chain(client): + from lemur.certificates.schemas import CertificateUploadInputSchema + + data = { + "owner": "pwner@example.com", + "body": SAN_CERT_STR, + "chain": INTERMEDIATE_CERT_STR + "\n" + ROOTCA_CERT_STR, + } + data, errors = CertificateUploadInputSchema().load(data) + assert not errors + + +def test_certificate_upload_schema_invalid_body(client): + from lemur.certificates.schemas import CertificateUploadInputSchema + + data = { + "owner": "pwner@example.com", + "body": "Hereby I certify that this is a valid body", + } + data, errors = CertificateUploadInputSchema().load(data) + assert errors == {"body": ["Public certificate presented is not valid."]} + + +def test_certificate_upload_schema_invalid_pkey(client): + from lemur.certificates.schemas import CertificateUploadInputSchema + + data = { + "owner": "pwner@example.com", + "body": SAN_CERT_STR, + "privateKey": "Look at me Im a private key!!111", + } + data, errors = CertificateUploadInputSchema().load(data) + assert errors == {"private_key": ["Private key presented is not valid."]} + + +def test_certificate_upload_schema_invalid_chain(client): + from lemur.certificates.schemas import CertificateUploadInputSchema + + data = {"body": SAN_CERT_STR, "chain": "CHAINSAW", "owner": "pwner@example.com"} + data, errors = CertificateUploadInputSchema().load(data) + assert errors == {"chain": ["Invalid certificate in certificate chain."]} + + +def test_certificate_upload_schema_wrong_pkey(client): + from lemur.certificates.schemas import CertificateUploadInputSchema + + data = { + "body": SAN_CERT_STR, + "privateKey": ROOTCA_KEY, + "chain": INTERMEDIATE_CERT_STR, + "owner": "pwner@example.com", + } + data, errors = CertificateUploadInputSchema().load(data) + assert errors == {"_schema": ["Private key does not match certificate."]} + + +def test_certificate_upload_schema_wrong_chain(client): + from lemur.certificates.schemas import CertificateUploadInputSchema + + data = { + "owner": "pwner@example.com", + "body": SAN_CERT_STR, + "chain": ROOTCA_CERT_STR, + } + data, errors = CertificateUploadInputSchema().load(data) + assert errors == { + "_schema": [ + "Incorrect chain certificate(s) provided: 'san.example.org' is not signed by " + "'LemurTrust Unittests Root CA 2018'" + ] + } + + +def test_certificate_upload_schema_wrong_chain_2nd(client): + from lemur.certificates.schemas import CertificateUploadInputSchema + + data = { + "owner": "pwner@example.com", + "body": SAN_CERT_STR, + "chain": INTERMEDIATE_CERT_STR + "\n" + SAN_CERT_STR, + } + data, errors = CertificateUploadInputSchema().load(data) + assert errors == { + "_schema": [ + "Incorrect chain certificate(s) provided: 'LemurTrust Unittests Class 1 CA 2018' is " + "not signed by 'san.example.org'" + ] + } def test_create_basic_csr(client): csr_config = dict( - common_name='example.com', - organization='Example, Inc.', - organizational_unit='Operations', - country='US', - state='CA', - location='A place', - owner='joe@example.com', - key_type='RSA2048', - extensions=dict(names=dict(sub_alt_names=x509.SubjectAlternativeName([x509.DNSName('test.example.com'), x509.DNSName('test2.example.com')]))) + common_name="example.com", + organization="Example, Inc.", + organizational_unit="Operations", + country="US", + state="CA", + location="A place", + owner="joe@example.com", + key_type="RSA2048", + extensions=dict( + names=dict( + sub_alt_names=x509.SubjectAlternativeName( + [ + x509.DNSName("test.example.com"), + x509.DNSName("test2.example.com"), + ] + ) + ) + ), ) csr, pem = create_csr(**csr_config) - csr = x509.load_pem_x509_csr(csr.encode('utf-8'), default_backend()) + csr = x509.load_pem_x509_csr(csr.encode("utf-8"), default_backend()) for name in csr.subject: assert name.value in csr_config.values() @@ -391,13 +655,13 @@ def test_csr_empty_san(client): """ csr_text, pkey = create_csr( - common_name='daniel-san.example.com', - owner='daniel-san@example.com', - key_type='RSA2048', - extensions={'sub_alt_names': {'names': x509.SubjectAlternativeName([])}} + common_name="daniel-san.example.com", + owner="daniel-san@example.com", + key_type="RSA2048", + extensions={"sub_alt_names": {"names": x509.SubjectAlternativeName([])}}, ) - csr = x509.load_pem_x509_csr(csr_text.encode('utf-8'), default_backend()) + csr = x509.load_pem_x509_csr(csr_text.encode("utf-8"), default_backend()) with pytest.raises(x509.ExtensionNotFound): csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) @@ -408,13 +672,13 @@ def test_csr_disallowed_cn(client, logged_in_user): from lemur.common import validators request, pkey = create_csr( - common_name='evilhacker.org', - owner='joe@example.com', - key_type='RSA2048', + common_name="evilhacker.org", owner="joe@example.com", key_type="RSA2048" ) with pytest.raises(ValidationError) as err: validators.csr(request) - assert str(err.value).startswith('Domain evilhacker.org does not match whitelisted domain patterns') + assert str(err.value).startswith( + "Domain evilhacker.org does not match whitelisted domain patterns" + ) def test_csr_disallowed_san(client, logged_in_user): @@ -423,315 +687,585 @@ def test_csr_disallowed_san(client, logged_in_user): request, pkey = create_csr( common_name="CN with spaces isn't a domain and is thus allowed", - owner='joe@example.com', - key_type='RSA2048', - extensions={'sub_alt_names': {'names': x509.SubjectAlternativeName([x509.DNSName('evilhacker.org')])}} + owner="joe@example.com", + key_type="RSA2048", + extensions={ + "sub_alt_names": { + "names": x509.SubjectAlternativeName([x509.DNSName("evilhacker.org")]) + } + }, ) with pytest.raises(ValidationError) as err: validators.csr(request) - assert str(err.value).startswith('Domain evilhacker.org does not match whitelisted domain patterns') + assert str(err.value).startswith( + "Domain evilhacker.org does not match whitelisted domain patterns" + ) def test_get_name_from_arn(client): from lemur.certificates.service import get_name_from_arn - arn = 'arn:aws:iam::11111111:server-certificate/mycertificate' - assert get_name_from_arn(arn) == 'mycertificate' + + arn = "arn:aws:iam::11111111:server-certificate/mycertificate" + assert get_name_from_arn(arn) == "mycertificate" def test_get_account_number(client): from lemur.certificates.service import get_account_number - arn = 'arn:aws:iam::11111111:server-certificate/mycertificate' - assert get_account_number(arn) == '11111111' + + arn = "arn:aws:iam::11111111:server-certificate/mycertificate" + assert get_account_number(arn) == "11111111" def test_mint_certificate(issuer_plugin, authority): from lemur.certificates.service import mint - cert_body, private_key, chain, external_id, csr = mint(authority=authority, csr=CSR_STR) + + cert_body, private_key, chain, external_id, csr = mint( + authority=authority, csr=CSR_STR + ) assert cert_body == SAN_CERT_STR def test_create_certificate(issuer_plugin, authority, user): from lemur.certificates.service import create - cert = create(authority=authority, csr=CSR_STR, owner='joe@example.com', creator=user['user']) - assert str(cert.not_after) == '2047-12-31T22:00:00+00:00' - assert str(cert.not_before) == '2017-12-31T22:00:00+00:00' - assert cert.issuer == 'LemurTrustUnittestsClass1CA2018' - assert cert.name == 'SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231-AFF2DB4F8D2D4D8E80FA382AE27C2333' - cert = create(authority=authority, csr=CSR_STR, owner='joe@example.com', name='ACustomName1', creator=user['user']) - assert cert.name == 'ACustomName1' + cert = create( + authority=authority, csr=CSR_STR, owner="joe@example.com", creator=user["user"] + ) + assert str(cert.not_after) == "2047-12-31T22:00:00+00:00" + assert str(cert.not_before) == "2017-12-31T22:00:00+00:00" + assert cert.issuer == "LemurTrustUnittestsClass1CA2018" + assert ( + cert.name + == "SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231-AFF2DB4F8D2D4D8E80FA382AE27C2333" + ) + + cert = create( + authority=authority, + csr=CSR_STR, + owner="joe@example.com", + name="ACustomName1", + creator=user["user"], + ) + assert cert.name == "ACustomName1" -def test_reissue_certificate(issuer_plugin, authority, certificate): +def test_reissue_certificate( + issuer_plugin, crypto_authority, certificate, logged_in_user +): from lemur.certificates.service import reissue_certificate + + # test-authority would return a mismatching private key, so use 'cryptography-issuer' plugin instead. + certificate.authority = crypto_authority new_cert = reissue_certificate(certificate) assert new_cert def test_create_csr(): - csr, private_key = create_csr(owner='joe@example.com', common_name='ACommonName', organization='test', organizational_unit='Meters', country='US', - state='CA', location='Here', key_type='RSA2048') + csr, private_key = create_csr( + owner="joe@example.com", + common_name="ACommonName", + organization="test", + organizational_unit="Meters", + country="US", + state="CA", + location="Here", + key_type="RSA2048", + ) assert csr assert private_key - extensions = {'sub_alt_names': {'names': x509.SubjectAlternativeName([x509.DNSName('AnotherCommonName')])}} - csr, private_key = create_csr(owner='joe@example.com', common_name='ACommonName', organization='test', organizational_unit='Meters', country='US', - state='CA', location='Here', extensions=extensions, key_type='RSA2048') + extensions = { + "sub_alt_names": { + "names": x509.SubjectAlternativeName([x509.DNSName("AnotherCommonName")]) + } + } + csr, private_key = create_csr( + owner="joe@example.com", + common_name="ACommonName", + organization="test", + organizational_unit="Meters", + country="US", + state="CA", + location="Here", + extensions=extensions, + key_type="RSA2048", + ) assert csr assert private_key def test_import(user): from lemur.certificates.service import import_certificate - cert = import_certificate(body=SAN_CERT_STR, chain=INTERMEDIATE_CERT_STR, private_key=SAN_CERT_KEY, creator=user['user']) - assert str(cert.not_after) == '2047-12-31T22:00:00+00:00' - assert str(cert.not_before) == '2017-12-31T22:00:00+00:00' - assert cert.issuer == 'LemurTrustUnittestsClass1CA2018' - assert cert.name == 'SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231-AFF2DB4F8D2D4D8E80FA382AE27C2333-2' - cert = import_certificate(body=SAN_CERT_STR, chain=INTERMEDIATE_CERT_STR, private_key=SAN_CERT_KEY, owner='joe@example.com', name='ACustomName2', creator=user['user']) - assert cert.name == 'ACustomName2' + cert = import_certificate( + body=SAN_CERT_STR, + chain=INTERMEDIATE_CERT_STR, + private_key=SAN_CERT_KEY, + creator=user["user"], + ) + assert str(cert.not_after) == "2047-12-31T22:00:00+00:00" + assert str(cert.not_before) == "2017-12-31T22:00:00+00:00" + assert cert.issuer == "LemurTrustUnittestsClass1CA2018" + assert cert.name.startswith( + "SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231" + ) + + cert = import_certificate( + body=SAN_CERT_STR, + chain=INTERMEDIATE_CERT_STR, + private_key=SAN_CERT_KEY, + owner="joe@example.com", + name="ACustomName2", + creator=user["user"], + ) + assert cert.name == "ACustomName2" @pytest.mark.skip def test_upload(user): from lemur.certificates.service import upload - cert = upload(body=SAN_CERT_STR, chain=INTERMEDIATE_CERT_STR, private_key=SAN_CERT_KEY, owner='joe@example.com', creator=user['user']) - assert str(cert.not_after) == '2040-01-01T20:30:52+00:00' - assert str(cert.not_before) == '2015-06-26T20:30:52+00:00' - assert cert.issuer == 'Example' - assert cert.name == 'long.lived.com-Example-20150626-20400101-3' - cert = upload(body=SAN_CERT_STR, chain=INTERMEDIATE_CERT_STR, private_key=SAN_CERT_KEY, owner='joe@example.com', name='ACustomName', creator=user['user']) - assert 'ACustomName' in cert.name + cert = upload( + body=SAN_CERT_STR, + chain=INTERMEDIATE_CERT_STR, + private_key=SAN_CERT_KEY, + owner="joe@example.com", + creator=user["user"], + ) + assert str(cert.not_after) == "2040-01-01T20:30:52+00:00" + assert str(cert.not_before) == "2015-06-26T20:30:52+00:00" + assert cert.issuer == "Example" + assert cert.name == "long.lived.com-Example-20150626-20400101-3" + + cert = upload( + body=SAN_CERT_STR, + chain=INTERMEDIATE_CERT_STR, + private_key=SAN_CERT_KEY, + owner="joe@example.com", + name="ACustomName", + creator=user["user"], + ) + assert "ACustomName" in cert.name # verify upload with a private key as a str def test_upload_private_key_str(user): from lemur.certificates.service import upload - cert = upload(body=SAN_CERT_STR, chain=INTERMEDIATE_CERT_STR, private_key=SAN_CERT_KEY, owner='joe@example.com', name='ACustomName', creator=user['user']) + + cert = upload( + body=SAN_CERT_STR, + chain=INTERMEDIATE_CERT_STR, + private_key=SAN_CERT_KEY, + owner="joe@example.com", + name="ACustomName", + creator=user["user"], + ) assert cert -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_certificate_get_private_key(client, token, status): - assert client.get(api.url_for(Certificates, certificate_id=1), headers=token).status_code == status + assert ( + client.get( + api.url_for(Certificates, certificate_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_certificate_get(client, token, status): - assert client.get(api.url_for(Certificates, certificate_id=1), headers=token).status_code == status + assert ( + client.get( + api.url_for(Certificates, certificate_id=1), headers=token + ).status_code + == status + ) def test_certificate_get_body(client): - response_body = client.get(api.url_for(Certificates, certificate_id=1), headers=VALID_USER_HEADER_TOKEN).json - assert response_body['serial'] == '211983098819107449768450703123665283596' - assert response_body['serialHex'] == '9F7A75B39DAE4C3F9524C68B06DA6A0C' + response_body = client.get( + api.url_for(Certificates, certificate_id=1), headers=VALID_USER_HEADER_TOKEN + ).json + assert response_body["serial"] == "211983098819107449768450703123665283596" + assert response_body["serialHex"] == "9F7A75B39DAE4C3F9524C68B06DA6A0C" + assert response_body["distinguishedName"] == ( + "CN=LemurTrust Unittests Class 1 CA 2018," + "O=LemurTrust Enterprises Ltd," + "OU=Unittesting Operations Center," + "C=EE," + "ST=N/A," + "L=Earth" + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_post(client, token, status): - assert client.post(api.url_for(Certificates, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Certificates, certificate_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_certificate_put(client, token, status): - assert client.put(api.url_for(Certificates, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Certificates, certificate_id=1), data={}, headers=token + ).status_code + == status + ) def test_certificate_put_with_data(client, certificate, issuer_plugin): - resp = client.put(api.url_for(Certificates, certificate_id=certificate.id), data=json.dumps({'owner': 'bob@example.com', 'description': 'test', 'notify': True}), headers=VALID_ADMIN_HEADER_TOKEN) + resp = client.put( + api.url_for(Certificates, certificate_id=certificate.id), + data=json.dumps( + {"owner": "bob@example.com", "description": "test", "notify": True} + ), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 204), + (VALID_ADMIN_API_TOKEN, 412), + ("", 401), + ], +) def test_certificate_delete(client, token, status): - assert client.delete(api.url_for(Certificates, certificate_id=1), headers=token).status_code == status + assert ( + client.delete( + api.url_for(Certificates, certificate_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 204), + (VALID_ADMIN_API_TOKEN, 204), + ("", 401), + ], +) +def test_invalid_certificate_delete(client, invalid_certificate, token, status): + assert ( + client.delete( + api.url_for(Certificates, certificate_id=invalid_certificate.id), + headers=token, + ).status_code + == status + ) + + +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_patch(client, token, status): - assert client.patch(api.url_for(Certificates, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Certificates, certificate_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_certificates_get(client, token, status): - assert client.get(api.url_for(CertificatesList), headers=token).status_code == status + assert ( + client.get(api.url_for(CertificatesList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_certificates_post(client, token, status): - assert client.post(api.url_for(CertificatesList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(CertificatesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_put(client, token, status): - assert client.put(api.url_for(CertificatesList), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(CertificatesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_delete(client, token, status): - assert client.delete(api.url_for(CertificatesList), headers=token).status_code == status + assert ( + client.delete(api.url_for(CertificatesList), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_patch(client, token, status): - assert client.patch(api.url_for(CertificatesList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(CertificatesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_credentials_post(client, token, status): - assert client.post(api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_credentials_put(client, token, status): - assert client.put(api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_credentials_delete(client, token, status): - assert client.delete(api.url_for(CertificatePrivateKey, certificate_id=1), headers=token).status_code == status + assert ( + client.delete( + api.url_for(CertificatePrivateKey, certificate_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_credentials_patch(client, token, status): - assert client.patch(api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_upload_get(client, token, status): - assert client.get(api.url_for(CertificatesUpload), headers=token).status_code == status + assert ( + client.get(api.url_for(CertificatesUpload), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_certificates_upload_post(client, token, status): - assert client.post(api.url_for(CertificatesUpload), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(CertificatesUpload), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_upload_put(client, token, status): - assert client.put(api.url_for(CertificatesUpload), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(CertificatesUpload), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_upload_delete(client, token, status): - assert client.delete(api.url_for(CertificatesUpload), headers=token).status_code == status + assert ( + client.delete(api.url_for(CertificatesUpload), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_upload_patch(client, token, status): - assert client.patch(api.url_for(CertificatesUpload), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(CertificatesUpload), data={}, headers=token + ).status_code + == status + ) def test_sensitive_sort(client): - resp = client.get(api.url_for(CertificatesList) + '?sortBy=private_key&sortDir=asc', headers=VALID_ADMIN_HEADER_TOKEN) - assert "'private_key' is not sortable or filterable" in resp.json['message'] + resp = client.get( + api.url_for(CertificatesList) + "?sortBy=private_key&sortDir=asc", + headers=VALID_ADMIN_HEADER_TOKEN, + ) + assert "'private_key' is not sortable or filterable" in resp.json["message"] def test_boolean_filter(client): - resp = client.get(api.url_for(CertificatesList) + '?filter=notify;true', headers=VALID_ADMIN_HEADER_TOKEN) + resp = client.get( + api.url_for(CertificatesList) + "?filter=notify;true", + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 # Also don't crash with invalid input (we currently treat that as false) - resp = client.get(api.url_for(CertificatesList) + '?filter=notify;whatisthis', headers=VALID_ADMIN_HEADER_TOKEN) + resp = client.get( + api.url_for(CertificatesList) + "?filter=notify;whatisthis", + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 diff --git a/lemur/tests/test_defaults.py b/lemur/tests/test_defaults.py index 918e1ab8..b8daa575 100644 --- a/lemur/tests/test_defaults.py +++ b/lemur/tests/test_defaults.py @@ -1,17 +1,25 @@ +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes + from .vectors import SAN_CERT, WILDCARD_CERT, INTERMEDIATE_CERT def test_cert_get_cn(client): from lemur.common.defaults import common_name - assert common_name(SAN_CERT) == 'san.example.org' + assert common_name(SAN_CERT) == "san.example.org" def test_cert_sub_alt_domains(client): from lemur.common.defaults import domains assert domains(INTERMEDIATE_CERT) == [] - assert domains(SAN_CERT) == ['san.example.org', 'san2.example.org', 'daniel-san.example.org'] + assert domains(SAN_CERT) == [ + "san.example.org", + "san2.example.org", + "daniel-san.example.org", + ] def test_cert_is_san(client): @@ -24,54 +32,119 @@ def test_cert_is_san(client): def test_cert_is_wildcard(client): from lemur.common.defaults import is_wildcard + assert is_wildcard(WILDCARD_CERT) assert not is_wildcard(INTERMEDIATE_CERT) def test_cert_bitstrength(client): from lemur.common.defaults import bitstrength + assert bitstrength(INTERMEDIATE_CERT) == 2048 def test_cert_issuer(client): from lemur.common.defaults import issuer - assert issuer(INTERMEDIATE_CERT) == 'LemurTrustUnittestsRootCA2018' + + assert issuer(INTERMEDIATE_CERT) == "LemurTrustUnittestsRootCA2018" def test_text_to_slug(client): from lemur.common.defaults import text_to_slug - assert text_to_slug('test - string') == 'test-string' + + assert text_to_slug("test - string") == "test-string" + assert text_to_slug("test - string", "") == "teststring" # Accented characters are decomposed - assert text_to_slug('föö bär') == 'foo-bar' + assert text_to_slug("föö bär") == "foo-bar" # Melt away the Unicode Snowman - assert text_to_slug('\u2603') == '' - assert text_to_slug('\u2603test\u2603') == 'test' - assert text_to_slug('snow\u2603man') == 'snow-man' + assert text_to_slug("\u2603") == "" + assert text_to_slug("\u2603test\u2603") == "test" + assert text_to_slug("snow\u2603man") == "snow-man" + assert text_to_slug("snow\u2603man", "") == "snowman" # IDNA-encoded domain names should be kept as-is - assert text_to_slug('xn--i1b6eqas.xn--xmpl-loa9b3671b.com') == 'xn--i1b6eqas.xn--xmpl-loa9b3671b.com' + assert ( + text_to_slug("xn--i1b6eqas.xn--xmpl-loa9b3671b.com") + == "xn--i1b6eqas.xn--xmpl-loa9b3671b.com" + ) def test_create_name(client): from lemur.common.defaults import certificate_name from datetime import datetime - assert certificate_name( - 'example.com', - 'Example Inc,', - datetime(2015, 5, 7, 0, 0, 0), - datetime(2015, 5, 12, 0, 0, 0), - False - ) == 'example.com-ExampleInc-20150507-20150512' - assert certificate_name( - 'example.com', - 'Example Inc,', - datetime(2015, 5, 7, 0, 0, 0), - datetime(2015, 5, 12, 0, 0, 0), - True - ) == 'SAN-example.com-ExampleInc-20150507-20150512' - assert certificate_name( - 'xn--mnchen-3ya.de', - 'Vertrauenswürdig Autorität', - datetime(2015, 5, 7, 0, 0, 0), - datetime(2015, 5, 12, 0, 0, 0), - False - ) == 'xn--mnchen-3ya.de-VertrauenswurdigAutoritat-20150507-20150512' + + assert ( + certificate_name( + "example.com", + "Example Inc,", + datetime(2015, 5, 7, 0, 0, 0), + datetime(2015, 5, 12, 0, 0, 0), + False, + ) + == "example.com-ExampleInc-20150507-20150512" + ) + assert ( + certificate_name( + "example.com", + "Example Inc,", + datetime(2015, 5, 7, 0, 0, 0), + datetime(2015, 5, 12, 0, 0, 0), + True, + ) + == "SAN-example.com-ExampleInc-20150507-20150512" + ) + assert ( + certificate_name( + "xn--mnchen-3ya.de", + "Vertrauenswürdig Autorität", + datetime(2015, 5, 7, 0, 0, 0), + datetime(2015, 5, 12, 0, 0, 0), + False, + ) + == "xn--mnchen-3ya.de-VertrauenswurdigAutoritat-20150507-20150512" + ) + assert ( + certificate_name( + "selfie.example.org", + "", + datetime(2015, 5, 7, 0, 0, 0), + datetime(2025, 5, 12, 13, 37, 0), + False, + ) + == "selfie.example.org-selfsigned-20150507-20250512" + ) + + +def test_issuer(client, cert_builder, issuer_private_key): + from lemur.common.defaults import issuer + + assert issuer(INTERMEDIATE_CERT) == "LemurTrustUnittestsRootCA2018" + + # We need to override builder's issuer name + cert_builder._issuer_name = None + # Unicode issuer name + cert = cert_builder.issuer_name( + x509.Name( + [x509.NameAttribute(x509.NameOID.COMMON_NAME, "Vertrauenswürdig Autorität")] + ) + ).sign(issuer_private_key, hashes.SHA256(), default_backend()) + assert issuer(cert) == "VertrauenswurdigAutoritat" + + # Fallback to 'Organization' field when issuer CN is missing + cert = cert_builder.issuer_name( + x509.Name( + [x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "No Such Organization")] + ) + ).sign(issuer_private_key, hashes.SHA256(), default_backend()) + assert issuer(cert) == "NoSuchOrganization" + + # Missing issuer name + cert = cert_builder.issuer_name(x509.Name([])).sign( + issuer_private_key, hashes.SHA256(), default_backend() + ) + assert issuer(cert) == "" + + +def test_issuer_selfsigned(selfsigned_cert): + from lemur.common.defaults import issuer + + assert issuer(selfsigned_cert) == "" diff --git a/lemur/tests/test_destinations.py b/lemur/tests/test_destinations.py index 11f03d9e..d17c703b 100644 --- a/lemur/tests/test_destinations.py +++ b/lemur/tests/test_destinations.py @@ -3,20 +3,22 @@ import pytest from lemur.destinations.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_destination_input_schema(client, destination_plugin, destination): from lemur.destinations.schemas import DestinationInputSchema input_data = { - 'label': 'destination1', - 'options': {}, - 'description': 'my destination', - 'active': True, - 'plugin': { - 'slug': 'test-destination' - } + "label": "destination1", + "options": {}, + "description": "my destination", + "active": True, + "plugin": {"slug": "test-destination"}, } data, errors = DestinationInputSchema().load(input_data) @@ -24,91 +26,154 @@ def test_destination_input_schema(client, destination_plugin, destination): assert not errors -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 404), - (VALID_ADMIN_HEADER_TOKEN, 404), - (VALID_ADMIN_API_TOKEN, 404), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 404), + (VALID_ADMIN_HEADER_TOKEN, 404), + (VALID_ADMIN_API_TOKEN, 404), + ("", 401), + ], +) def test_destination_get(client, token, status): - assert client.get(api.url_for(Destinations, destination_id=1), headers=token).status_code == status + assert ( + client.get( + api.url_for(Destinations, destination_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_destination_post_(client, token, status): - assert client.post(api.url_for(Destinations, destination_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Destinations, destination_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_destination_put(client, token, status): - assert client.put(api.url_for(Destinations, destination_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Destinations, destination_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_destination_delete(client, token, status): - assert client.delete(api.url_for(Destinations, destination_id=1), headers=token).status_code == status + assert ( + client.delete( + api.url_for(Destinations, destination_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_destination_patch(client, token, status): - assert client.patch(api.url_for(Destinations, destination_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Destinations, destination_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_destination_list_post_(client, token, status): - assert client.post(api.url_for(DestinationsList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(DestinationsList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_destination_list_get(client, token, status): - assert client.get(api.url_for(DestinationsList), headers=token).status_code == status + assert ( + client.get(api.url_for(DestinationsList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_destination_list_delete(client, token, status): - assert client.delete(api.url_for(DestinationsList), headers=token).status_code == status + assert ( + client.delete(api.url_for(DestinationsList), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_destination_list_patch(client, token, status): - assert client.patch(api.url_for(DestinationsList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(DestinationsList), data={}, headers=token).status_code + == status + ) diff --git a/lemur/tests/test_domains.py b/lemur/tests/test_domains.py index 873412b2..47023f8c 100644 --- a/lemur/tests/test_domains.py +++ b/lemur/tests/test_domains.py @@ -3,94 +3,152 @@ import pytest from lemur.domains.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_domain_get(client, token, status): - assert client.get(api.url_for(Domains, domain_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(Domains, domain_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_domain_post_(client, token, status): - assert client.post(api.url_for(Domains, domain_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Domains, domain_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_domain_put(client, token, status): - assert client.put(api.url_for(Domains, domain_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Domains, domain_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_domain_delete(client, token, status): - assert client.delete(api.url_for(Domains, domain_id=1), headers=token).status_code == status + assert ( + client.delete(api.url_for(Domains, domain_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_domain_patch(client, token, status): - assert client.patch(api.url_for(Domains, domain_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Domains, domain_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_domain_list_post_(client, token, status): - assert client.post(api.url_for(DomainsList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(DomainsList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_domain_list_get(client, token, status): assert client.get(api.url_for(DomainsList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_domain_list_delete(client, token, status): assert client.delete(api.url_for(DomainsList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_domain_list_patch(client, token, status): - assert client.patch(api.url_for(DomainsList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(DomainsList), data={}, headers=token).status_code + == status + ) diff --git a/lemur/tests/test_endpoints.py b/lemur/tests/test_endpoints.py index 4ea0a4aa..af073e53 100644 --- a/lemur/tests/test_endpoints.py +++ b/lemur/tests/test_endpoints.py @@ -4,11 +4,16 @@ from lemur.endpoints.views import * # noqa from lemur.tests.factories import EndpointFactory, CertificateFactory -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_rotate_certificate(client, source_plugin): from lemur.deployment.service import rotate_certificate + new_certificate = CertificateFactory() endpoint = EndpointFactory() @@ -16,91 +21,147 @@ def test_rotate_certificate(client, source_plugin): assert endpoint.certificate == new_certificate -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 404), - (VALID_ADMIN_HEADER_TOKEN, 404), - (VALID_ADMIN_API_TOKEN, 404), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 404), + (VALID_ADMIN_HEADER_TOKEN, 404), + (VALID_ADMIN_API_TOKEN, 404), + ("", 401), + ], +) def test_endpoint_get(client, token, status): - assert client.get(api.url_for(Endpoints, endpoint_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(Endpoints, endpoint_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_post_(client, token, status): - assert client.post(api.url_for(Endpoints, endpoint_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Endpoints, endpoint_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_put(client, token, status): - assert client.put(api.url_for(Endpoints, endpoint_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Endpoints, endpoint_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_delete(client, token, status): - assert client.delete(api.url_for(Endpoints, endpoint_id=1), headers=token).status_code == status + assert ( + client.delete(api.url_for(Endpoints, endpoint_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_patch(client, token, status): - assert client.patch(api.url_for(Endpoints, endpoint_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Endpoints, endpoint_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_list_post_(client, token, status): - assert client.post(api.url_for(EndpointsList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(EndpointsList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_endpoint_list_get(client, token, status): assert client.get(api.url_for(EndpointsList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_list_delete(client, token, status): - assert client.delete(api.url_for(EndpointsList), headers=token).status_code == status + assert ( + client.delete(api.url_for(EndpointsList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_list_patch(client, token, status): - assert client.patch(api.url_for(EndpointsList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(EndpointsList), data={}, headers=token).status_code + == status + ) diff --git a/lemur/tests/test_ldap.py b/lemur/tests/test_ldap.py index a636afdc..8e4027a9 100644 --- a/lemur/tests/test_ldap.py +++ b/lemur/tests/test_ldap.py @@ -1,51 +1,69 @@ import pytest -from lemur.auth.ldap import * # noqa +from lemur.auth.ldap import * # noqa from mock import patch, MagicMock class LdapPrincipalTester(LdapPrincipal): - def __init__(self, args): super().__init__(args) - self.ldap_server = 'ldap://localhost' + self.ldap_server = "ldap://localhost" def bind_test(self): - groups = [('user', {'memberOf': ['CN=Lemur Access,OU=Groups,DC=example,DC=com'.encode('utf-8'), - 'CN=Pen Pushers,OU=Groups,DC=example,DC=com'.encode('utf-8')]})] + groups = [ + ( + "user", + { + "memberOf": [ + "CN=Lemur Access,OU=Groups,DC=example,DC=com".encode("utf-8"), + "CN=Pen Pushers,OU=Groups,DC=example,DC=com".encode("utf-8"), + ] + }, + ) + ] self.ldap_client = MagicMock() self.ldap_client.search_s.return_value = groups self._bind() def authorize_test_groups_to_roles_admin(self): - self.ldap_groups = ''.join(['CN=Pen Pushers,OU=Groups,DC=example,DC=com', - 'CN=Lemur Admins,OU=Groups,DC=example,DC=com', - 'CN=Lemur Read Only,OU=Groups,DC=example,DC=com']) + self.ldap_groups = "".join( + [ + "CN=Pen Pushers,OU=Groups,DC=example,DC=com", + "CN=Lemur Admins,OU=Groups,DC=example,DC=com", + "CN=Lemur Read Only,OU=Groups,DC=example,DC=com", + ] + ) self.ldap_required_group = None - self.ldap_groups_to_roles = {'Lemur Admins': 'admin', 'Lemur Read Only': 'read-only'} + self.ldap_groups_to_roles = { + "Lemur Admins": "admin", + "Lemur Read Only": "read-only", + } return self._authorize() def authorize_test_required_group(self, group): - self.ldap_groups = ''.join(['CN=Lemur Access,OU=Groups,DC=example,DC=com', - 'CN=Pen Pushers,OU=Groups,DC=example,DC=com']) + self.ldap_groups = "".join( + [ + "CN=Lemur Access,OU=Groups,DC=example,DC=com", + "CN=Pen Pushers,OU=Groups,DC=example,DC=com", + ] + ) self.ldap_required_group = group return self._authorize() @pytest.fixture() def principal(session): - args = {'username': 'user', 'password': 'p4ssw0rd'} + args = {"username": "user", "password": "p4ssw0rd"} yield LdapPrincipalTester(args) class TestLdapPrincipal: - - @patch('ldap.initialize') + @patch("ldap.initialize") def test_bind(self, app, principal): self.test_ldap_user = principal self.test_ldap_user.bind_test() - group = 'Pen Pushers' + group = "Pen Pushers" assert group in self.test_ldap_user.ldap_groups - assert self.test_ldap_user.ldap_principal == 'user@example.com' + assert self.test_ldap_user.ldap_principal == "user@example.com" def test_authorize_groups_to_roles_admin(self, app, principal): self.test_ldap_user = principal @@ -54,11 +72,11 @@ class TestLdapPrincipal: def test_authorize_required_group_missing(self, app, principal): self.test_ldap_user = principal - roles = self.test_ldap_user.authorize_test_required_group('Not Allowed') + roles = self.test_ldap_user.authorize_test_required_group("Not Allowed") assert not roles def test_authorize_required_group_access(self, session, principal): self.test_ldap_user = principal - roles = self.test_ldap_user.authorize_test_required_group('Lemur Access') + roles = self.test_ldap_user.authorize_test_required_group("Lemur Access") assert len(roles) >= 1 assert any(x.name == "user@example.com" for x in roles) diff --git a/lemur/tests/test_logs.py b/lemur/tests/test_logs.py index 516f5bb7..6705ffca 100644 --- a/lemur/tests/test_logs.py +++ b/lemur/tests/test_logs.py @@ -1,21 +1,32 @@ import pytest -from lemur.tests.vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from lemur.tests.vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) from lemur.logs.views import * # noqa def test_private_key_audit(client, certificate): from lemur.certificates.views import CertificatePrivateKey, api + assert len(certificate.logs) == 0 - client.get(api.url_for(CertificatePrivateKey, certificate_id=certificate.id), headers=VALID_ADMIN_HEADER_TOKEN) + client.get( + api.url_for(CertificatePrivateKey, certificate_id=certificate.id), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert len(certificate.logs) == 1 -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_get_logs(client, token, status): assert client.get(api.url_for(LogsList), headers=token).status_code == status diff --git a/lemur/tests/test_messaging.py b/lemur/tests/test_messaging.py index fc0e62da..98e9ebf3 100644 --- a/lemur/tests/test_messaging.py +++ b/lemur/tests/test_messaging.py @@ -8,14 +8,21 @@ from moto import mock_ses def test_needs_notification(app, certificate, notification): from lemur.notifications.messaging import needs_notification + assert not needs_notification(certificate) with pytest.raises(Exception): - notification.options = [{'name': 'interval', 'value': 10}, {'name': 'unit', 'value': 'min'}] + notification.options = [ + {"name": "interval", "value": 10}, + {"name": "unit", "value": "min"}, + ] certificate.notifications.append(notification) needs_notification(certificate) - certificate.notifications[0].options = [{'name': 'interval', 'value': 10}, {'name': 'unit', 'value': 'days'}] + certificate.notifications[0].options = [ + {"name": "interval", "value": 10}, + {"name": "unit", "value": "days"}, + ] assert not needs_notification(certificate) delta = certificate.not_after - timedelta(days=10) @@ -30,7 +37,8 @@ def test_get_certificates(app, certificate, notification): delta = certificate.not_after - timedelta(days=2) notification.options = [ - {'name': 'interval', 'value': 2}, {'name': 'unit', 'value': 'days'} + {"name": "interval", "value": 2}, + {"name": "unit", "value": "days"}, ] with freeze_time(delta.datetime): @@ -55,11 +63,16 @@ def test_get_eligible_certificates(app, certificate, notification): from lemur.notifications.messaging import get_eligible_certificates certificate.notifications.append(notification) - certificate.notifications[0].options = [{'name': 'interval', 'value': 10}, {'name': 'unit', 'value': 'days'}] + certificate.notifications[0].options = [ + {"name": "interval", "value": 10}, + {"name": "unit", "value": "days"}, + ] delta = certificate.not_after - timedelta(days=10) with freeze_time(delta.datetime): - assert get_eligible_certificates() == {certificate.owner: {notification.label: [(notification, certificate)]}} + assert get_eligible_certificates() == { + certificate.owner: {notification.label: [(notification, certificate)]} + } @mock_ses @@ -67,7 +80,10 @@ def test_send_expiration_notification(certificate, notification, notification_pl from lemur.notifications.messaging import send_expiration_notifications certificate.notifications.append(notification) - certificate.notifications[0].options = [{'name': 'interval', 'value': 10}, {'name': 'unit', 'value': 'days'}] + certificate.notifications[0].options = [ + {"name": "interval", "value": 10}, + {"name": "unit", "value": "days"}, + ] delta = certificate.not_after - timedelta(days=10) with freeze_time(delta.datetime): @@ -75,7 +91,9 @@ def test_send_expiration_notification(certificate, notification, notification_pl @mock_ses -def test_send_expiration_notification_with_no_notifications(certificate, notification, notification_plugin): +def test_send_expiration_notification_with_no_notifications( + certificate, notification, notification_plugin +): from lemur.notifications.messaging import send_expiration_notifications delta = certificate.not_after - timedelta(days=10) @@ -86,4 +104,5 @@ def test_send_expiration_notification_with_no_notifications(certificate, notific @mock_ses def test_send_rotation_notification(notification_plugin, certificate): from lemur.notifications.messaging import send_rotation_notification + send_rotation_notification(certificate, notification_plugin=notification_plugin) diff --git a/lemur/tests/test_missing.py b/lemur/tests/test_missing.py index 4f2c20c6..59bac2d6 100644 --- a/lemur/tests/test_missing.py +++ b/lemur/tests/test_missing.py @@ -9,9 +9,12 @@ def test_convert_validity_years(session): with freeze_time("2016-01-01"): data = convert_validity_years(dict(validity_years=2)) - assert data['validity_start'] == arrow.utcnow().isoformat() - assert data['validity_end'] == arrow.utcnow().replace(years=+2).isoformat() + assert data["validity_start"] == arrow.utcnow().isoformat() + assert data["validity_end"] == arrow.utcnow().shift(years=+2).isoformat() with freeze_time("2015-01-10"): data = convert_validity_years(dict(validity_years=1)) - assert data['validity_end'] == arrow.utcnow().replace(years=+1, days=-2).isoformat() + assert ( + data["validity_end"] + == arrow.utcnow().shift(years=+1, days=-2).isoformat() + ) diff --git a/lemur/tests/test_notifications.py b/lemur/tests/test_notifications.py index 6daee0a8..20079f97 100644 --- a/lemur/tests/test_notifications.py +++ b/lemur/tests/test_notifications.py @@ -3,20 +3,22 @@ import pytest from lemur.notifications.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_notification_input_schema(client, notification_plugin, notification): from lemur.notifications.schemas import NotificationInputSchema input_data = { - 'label': 'notification1', - 'options': {}, - 'description': 'my notification', - 'active': True, - 'plugin': { - 'slug': 'test-notification' - } + "label": "notification1", + "options": {}, + "description": "my notification", + "active": True, + "plugin": {"slug": "test-notification"}, } data, errors = NotificationInputSchema().load(input_data) @@ -24,91 +26,156 @@ def test_notification_input_schema(client, notification_plugin, notification): assert not errors -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_notification_get(client, notification_plugin, notification, token, status): - assert client.get(api.url_for(Notifications, notification_id=notification.id), headers=token).status_code == status + assert ( + client.get( + api.url_for(Notifications, notification_id=notification.id), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_notification_post_(client, token, status): - assert client.post(api.url_for(Notifications, notification_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Notifications, notification_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_notification_put(client, token, status): - assert client.put(api.url_for(Notifications, notification_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Notifications, notification_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_notification_delete(client, token, status): - assert client.delete(api.url_for(Notifications, notification_id=1), headers=token).status_code == status + assert ( + client.delete( + api.url_for(Notifications, notification_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_notification_patch(client, token, status): - assert client.patch(api.url_for(Notifications, notification_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Notifications, notification_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_notification_list_post_(client, token, status): - assert client.post(api.url_for(NotificationsList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(NotificationsList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) -def test_notification_list_get(client, notification_plugin, notification, token, status): - assert client.get(api.url_for(NotificationsList), headers=token).status_code == status +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) +def test_notification_list_get( + client, notification_plugin, notification, token, status +): + assert ( + client.get(api.url_for(NotificationsList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_notification_list_delete(client, token, status): - assert client.delete(api.url_for(NotificationsList), headers=token).status_code == status + assert ( + client.delete(api.url_for(NotificationsList), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_notification_list_patch(client, token, status): - assert client.patch(api.url_for(NotificationsList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(NotificationsList), data={}, headers=token).status_code + == status + ) diff --git a/lemur/tests/test_pending_certificates.py b/lemur/tests/test_pending_certificates.py index 7accf7d9..3e755574 100644 --- a/lemur/tests/test_pending_certificates.py +++ b/lemur/tests/test_pending_certificates.py @@ -2,13 +2,21 @@ import json import pytest +from marshmallow import ValidationError from lemur.pending_certificates.views import * # noqa -from .vectors import CSR_STR, INTERMEDIATE_CERT_STR, VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, \ - VALID_USER_HEADER_TOKEN, WILDCARD_CERT_STR +from .vectors import ( + CSR_STR, + INTERMEDIATE_CERT_STR, + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, + WILDCARD_CERT_STR, +) def test_increment_attempt(pending_certificate): from lemur.pending_certificates.service import increment_attempt + initial_attempt = pending_certificate.number_attempts attempts = increment_attempt(pending_certificate) assert attempts == initial_attempt + 1 @@ -16,37 +24,93 @@ def test_increment_attempt(pending_certificate): def test_create_pending_certificate(async_issuer_plugin, async_authority, user): from lemur.certificates.service import create - pending_cert = create(authority=async_authority, csr=CSR_STR, owner='joe@example.com', creator=user['user'], - common_name='ACommonName') - assert pending_cert.external_id == '12345' + + pending_cert = create( + authority=async_authority, + csr=CSR_STR, + owner="joe@example.com", + creator=user["user"], + common_name="ACommonName", + ) + assert pending_cert.external_id == "12345" def test_create_pending(pending_certificate, user, session): import copy from lemur.pending_certificates.service import create_certificate, get - cert = {'body': WILDCARD_CERT_STR, - 'chain': INTERMEDIATE_CERT_STR, - 'external_id': '54321'} + + cert = { + "body": WILDCARD_CERT_STR, + "chain": INTERMEDIATE_CERT_STR, + "external_id": "54321", + } # Weird copy because the session behavior. pending_certificate is a valid object but the # return of vars(pending_certificate) is a sessionobject, and so nothing from the pending_cert # is used to create the certificate. Maybe a bug due to using vars(), and should copy every # field explicitly. pending_certificate = copy.copy(get(pending_certificate.id)) - real_cert = create_certificate(pending_certificate, cert, user['user']) + real_cert = create_certificate(pending_certificate, cert, user["user"]) assert real_cert.owner == pending_certificate.owner assert real_cert.notify == pending_certificate.notify assert real_cert.private_key == pending_certificate.private_key - assert real_cert.external_id == '54321' + assert real_cert.external_id == "54321" -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 204), - (VALID_ADMIN_API_TOKEN, 204), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 204), + (VALID_ADMIN_API_TOKEN, 204), + ("", 401), + ], +) def test_pending_cancel(client, pending_certificate, token, status): - assert client.delete(api.url_for(PendingCertificates, pending_certificate_id=pending_certificate.id), - data=json.dumps({'note': "unit test", 'send_email': False}), - headers=token).status_code == status + assert ( + client.delete( + api.url_for( + PendingCertificates, pending_certificate_id=pending_certificate.id + ), + data=json.dumps({"note": "unit test", "send_email": False}), + headers=token, + ).status_code + == status + ) + + +def test_pending_upload(pending_certificate_from_full_chain_ca): + from lemur.pending_certificates.service import upload + from lemur.certificates.service import get + + cert = {"body": WILDCARD_CERT_STR, "chain": None, "external_id": None} + + pending_cert = upload(pending_certificate_from_full_chain_ca.id, **cert) + assert pending_cert.resolved + assert get(pending_cert.resolved_cert_id) + + +def test_pending_upload_with_chain(pending_certificate_from_partial_chain_ca): + from lemur.pending_certificates.service import upload + from lemur.certificates.service import get + + cert = { + "body": WILDCARD_CERT_STR, + "chain": INTERMEDIATE_CERT_STR, + "external_id": None, + } + + pending_cert = upload(pending_certificate_from_partial_chain_ca.id, **cert) + assert pending_cert.resolved + assert get(pending_cert.resolved_cert_id) + + +def test_invalid_pending_upload_with_chain(pending_certificate_from_partial_chain_ca): + from lemur.pending_certificates.service import upload + + cert = {"body": WILDCARD_CERT_STR, "chain": None, "external_id": None} + with pytest.raises(ValidationError) as err: + upload(pending_certificate_from_partial_chain_ca.id, **cert) + assert str(err.value).startswith( + "Incorrect chain certificate(s) provided: '*.wild.example.org' is not signed by 'LemurTrust Unittests Root CA 2018" + ) diff --git a/lemur/tests/test_redis.py b/lemur/tests/test_redis.py new file mode 100644 index 00000000..aab2e397 --- /dev/null +++ b/lemur/tests/test_redis.py @@ -0,0 +1,13 @@ +import fakeredis +import time +import sys + + +def test_write_and_read_from_redis(): + function = f"{__name__}.{sys._getframe().f_code.co_name}" + + red = fakeredis.FakeStrictRedis() + key = f"{function}.last_success" + value = int(time.time()) + assert red.set(key, value) is True + assert (int(red.get(key)) == value) is True diff --git a/lemur/tests/test_roles.py b/lemur/tests/test_roles.py index e5483e00..6e612062 100644 --- a/lemur/tests/test_roles.py +++ b/lemur/tests/test_roles.py @@ -3,16 +3,23 @@ import json import pytest from lemur.roles.views import * # noqa -from lemur.tests.factories import RoleFactory, AuthorityFactory, CertificateFactory, UserFactory -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from lemur.tests.factories import ( + RoleFactory, + AuthorityFactory, + CertificateFactory, + UserFactory, +) +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_role_input_schema(client): from lemur.roles.schemas import RoleInputSchema - input_data = { - 'name': 'myRole' - } + input_data = {"name": "myRole"} data, errors = RoleInputSchema().load(input_data) @@ -38,60 +45,80 @@ def test_multiple_authority_certificate_association(session, client): assert role.certificates[1].name == certificate1.name -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_role_get(client, token, status): - assert client.get(api.url_for(Roles, role_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(Roles, role_id=1), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_role_post_(client, token, status): - assert client.post(api.url_for(Roles, role_id=1), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(Roles, role_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_role_put(client, token, status): - assert client.put(api.url_for(Roles, role_id=1), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(Roles, role_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_role_put_with_data(client, session, token, status): user = UserFactory() role = RoleFactory() session.commit() - data = { - 'users': [ - {'id': user.id} - ], - 'id': role.id, - 'name': role.name - } + data = {"users": [{"id": user.id}], "id": role.id, "name": role.name} - assert client.put(api.url_for(Roles, role_id=role.id), data=json.dumps(data), headers=token).status_code == status + assert ( + client.put( + api.url_for(Roles, role_id=role.id), data=json.dumps(data), headers=token + ).status_code + == status + ) def test_role_put_with_data_and_user(client, session): from lemur.auth.service import create_token + user = UserFactory() role = RoleFactory(users=[user]) role1 = RoleFactory() @@ -99,83 +126,119 @@ def test_role_put_with_data_and_user(client, session): session.commit() headers = { - 'Authorization': 'Basic ' + create_token(user), - 'Content-Type': 'application/json' + "Authorization": "Basic " + create_token(user), + "Content-Type": "application/json", } data = { - 'users': [ - {'id': user1.id}, - {'id': user.id} - ], - 'id': role.id, - 'name': role.name + "users": [{"id": user1.id}, {"id": user.id}], + "id": role.id, + "name": role.name, } - assert client.put(api.url_for(Roles, role_id=role.id), data=json.dumps(data), headers=headers).status_code == 200 - assert client.get(api.url_for(RolesList), data={}, headers=headers).json['total'] > 1 + assert ( + client.put( + api.url_for(Roles, role_id=role.id), data=json.dumps(data), headers=headers + ).status_code + == 200 + ) + assert ( + client.get(api.url_for(RolesList), data={}, headers=headers).json["total"] > 1 + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_role_delete(client, token, status, role): - assert client.delete(api.url_for(Roles, role_id=role.id), headers=token).status_code == status + assert ( + client.delete(api.url_for(Roles, role_id=role.id), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_role_patch(client, token, status): - assert client.patch(api.url_for(Roles, role_id=1), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(Roles, role_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_role_list_post_(client, token, status): - assert client.post(api.url_for(RolesList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(RolesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_role_list_get(client, token, status): assert client.get(api.url_for(RolesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_role_list_delete(client, token, status): assert client.delete(api.url_for(RolesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_role_list_patch(client, token, status): - assert client.patch(api.url_for(RolesList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(RolesList), data={}, headers=token).status_code + == status + ) def test_sensitive_filter(client): - resp = client.get(api.url_for(RolesList) + '?filter=password;a', headers=VALID_ADMIN_HEADER_TOKEN) - assert "'password' is not sortable or filterable" in resp.json['message'] + resp = client.get( + api.url_for(RolesList) + "?filter=password;a", headers=VALID_ADMIN_HEADER_TOKEN + ) + assert "'password' is not sortable or filterable" in resp.json["message"] diff --git a/lemur/tests/test_schemas.py b/lemur/tests/test_schemas.py index e2a05213..2c085849 100644 --- a/lemur/tests/test_schemas.py +++ b/lemur/tests/test_schemas.py @@ -14,15 +14,15 @@ def test_get_object_attribute(): get_object_attribute([{}], many=True) with pytest.raises(ValidationError): - get_object_attribute([{}, {'id': 1}], many=True) + get_object_attribute([{}, {"id": 1}], many=True) with pytest.raises(ValidationError): - get_object_attribute([{}, {'name': 'test'}], many=True) + get_object_attribute([{}, {"name": "test"}], many=True) - assert get_object_attribute({'name': 'test'}) == 'name' - assert get_object_attribute({'id': 1}) == 'id' - assert get_object_attribute([{'name': 'test'}], many=True) == 'name' - assert get_object_attribute([{'id': 1}], many=True) == 'id' + assert get_object_attribute({"name": "test"}) == "name" + assert get_object_attribute({"id": 1}) == "id" + assert get_object_attribute([{"name": "test"}], many=True) == "name" + assert get_object_attribute([{"id": 1}], many=True) == "id" def test_fetch_objects(session): @@ -33,26 +33,26 @@ def test_fetch_objects(session): role1 = RoleFactory() session.commit() - data = {'id': role.id} + data = {"id": role.id} found_role = fetch_objects(Role, data) assert found_role == role - data = {'name': role.name} + data = {"name": role.name} found_role = fetch_objects(Role, data) assert found_role == role - data = [{'id': role.id}, {'id': role1.id}] + data = [{"id": role.id}, {"id": role1.id}] found_roles = fetch_objects(Role, data, many=True) assert found_roles == [role, role1] - data = [{'name': role.name}, {'name': role1.name}] + data = [{"name": role.name}, {"name": role1.name}] found_roles = fetch_objects(Role, data, many=True) assert found_roles == [role, role1] with pytest.raises(ValidationError): - data = [{'name': 'blah'}, {'name': role1.name}] + data = [{"name": "blah"}, {"name": role1.name}] fetch_objects(Role, data, many=True) with pytest.raises(ValidationError): - data = {'name': 'nah'} + data = {"name": "nah"} fetch_objects(Role, data) diff --git a/lemur/tests/test_sources.py b/lemur/tests/test_sources.py index 1ce0d9ba..312c008f 100644 --- a/lemur/tests/test_sources.py +++ b/lemur/tests/test_sources.py @@ -2,17 +2,22 @@ import pytest from lemur.sources.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN, WILDCARD_CERT_STR, \ - WILDCARD_CERT_KEY +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, + WILDCARD_CERT_STR, + WILDCARD_CERT_KEY, +) def validate_source_schema(client): from lemur.sources.schemas import SourceInputSchema input_data = { - 'label': 'exampleSource', - 'options': {}, - 'plugin': {'slug': 'aws-source'} + "label": "exampleSource", + "options": {}, + "plugin": {"slug": "aws-source"}, } data, errors = SourceInputSchema().load(input_data) @@ -26,111 +31,171 @@ def test_create_certificate(user, source): certificate_create({}, source) data = { - 'body': WILDCARD_CERT_STR, - 'private_key': WILDCARD_CERT_KEY, - 'owner': 'bob@example.com', - 'creator': user['user'] + "body": WILDCARD_CERT_STR, + "private_key": WILDCARD_CERT_KEY, + "owner": "bob@example.com", + "creator": user["user"], } cert = certificate_create(data, source) assert cert.notifications -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 404), - (VALID_ADMIN_HEADER_TOKEN, 404), - (VALID_ADMIN_API_TOKEN, 404), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 404), + (VALID_ADMIN_HEADER_TOKEN, 404), + (VALID_ADMIN_API_TOKEN, 404), + ("", 401), + ], +) def test_source_get(client, source_plugin, token, status): - assert client.get(api.url_for(Sources, source_id=43543), headers=token).status_code == status + assert ( + client.get(api.url_for(Sources, source_id=43543), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_source_post_(client, token, status): - assert client.post(api.url_for(Sources, source_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Sources, source_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_source_put(client, token, status): - assert client.put(api.url_for(Sources, source_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Sources, source_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_source_delete(client, token, status): - assert client.delete(api.url_for(Sources, source_id=1), headers=token).status_code == status + assert ( + client.delete(api.url_for(Sources, source_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_source_patch(client, token, status): - assert client.patch(api.url_for(Sources, source_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Sources, source_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_sources_list_get(client, source_plugin, token, status): assert client.get(api.url_for(SourcesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_sources_list_post(client, token, status): - assert client.post(api.url_for(SourcesList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(SourcesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_sources_list_put(client, token, status): - assert client.put(api.url_for(SourcesList), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(SourcesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_sources_list_delete(client, token, status): assert client.delete(api.url_for(SourcesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_sources_list_patch(client, token, status): - assert client.patch(api.url_for(SourcesList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(SourcesList), data={}, headers=token).status_code + == status + ) diff --git a/lemur/tests/test_users.py b/lemur/tests/test_users.py index 61db93bf..9e67f868 100644 --- a/lemur/tests/test_users.py +++ b/lemur/tests/test_users.py @@ -4,16 +4,20 @@ import pytest from lemur.tests.factories import UserFactory, RoleFactory from lemur.users.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_user_input_schema(client): from lemur.users.schemas import UserInputSchema input_data = { - 'username': 'example', - 'password': '1233432', - 'email': 'example@example.com' + "username": "example", + "password": "1233432", + "email": "example@example.com", } data, errors = UserInputSchema().load(input_data) @@ -21,104 +25,156 @@ def test_user_input_schema(client): assert not errors -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_user_get(client, token, status): - assert client.get(api.url_for(Users, user_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(Users, user_id=1), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_post_(client, token, status): - assert client.post(api.url_for(Users, user_id=1), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(Users, user_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_user_put(client, token, status): - assert client.put(api.url_for(Users, user_id=1), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(Users, user_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_delete(client, token, status): - assert client.delete(api.url_for(Users, user_id=1), headers=token).status_code == status + assert ( + client.delete(api.url_for(Users, user_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_patch(client, token, status): - assert client.patch(api.url_for(Users, user_id=1), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(Users, user_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_user_list_post_(client, token, status): - assert client.post(api.url_for(UsersList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(UsersList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_user_list_get(client, token, status): assert client.get(api.url_for(UsersList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_list_delete(client, token, status): assert client.delete(api.url_for(UsersList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_list_patch(client, token, status): - assert client.patch(api.url_for(UsersList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(UsersList), data={}, headers=token).status_code + == status + ) def test_sensitive_filter(client): - resp = client.get(api.url_for(UsersList) + '?filter=password;a', headers=VALID_ADMIN_HEADER_TOKEN) - assert "'password' is not sortable or filterable" in resp.json['message'] + resp = client.get( + api.url_for(UsersList) + "?filter=password;a", headers=VALID_ADMIN_HEADER_TOKEN + ) + assert "'password' is not sortable or filterable" in resp.json["message"] def test_sensitive_sort(client): - resp = client.get(api.url_for(UsersList) + '?sortBy=password&sortDir=asc', headers=VALID_ADMIN_HEADER_TOKEN) - assert "'password' is not sortable or filterable" in resp.json['message'] + resp = client.get( + api.url_for(UsersList) + "?sortBy=password&sortDir=asc", + headers=VALID_ADMIN_HEADER_TOKEN, + ) + assert "'password' is not sortable or filterable" in resp.json["message"] def test_user_role_changes(client, session): @@ -128,25 +184,30 @@ def test_user_role_changes(client, session): session.flush() data = { - 'active': True, - 'id': user.id, - 'username': user.username, - 'email': user.email, - 'roles': [ - {'id': role1.id}, - {'id': role2.id}, - ], + "active": True, + "id": user.id, + "username": user.username, + "email": user.email, + "roles": [{"id": role1.id}, {"id": role2.id}], } # PUT two roles - resp = client.put(api.url_for(Users, user_id=user.id), data=json.dumps(data), headers=VALID_ADMIN_HEADER_TOKEN) + resp = client.put( + api.url_for(Users, user_id=user.id), + data=json.dumps(data), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 - assert len(resp.json['roles']) == 2 + assert len(resp.json["roles"]) == 2 assert set(user.roles) == {role1, role2} # Remove one role and PUT again - del data['roles'][1] - resp = client.put(api.url_for(Users, user_id=user.id), data=json.dumps(data), headers=VALID_ADMIN_HEADER_TOKEN) + del data["roles"][1] + resp = client.put( + api.url_for(Users, user_id=user.id), + data=json.dumps(data), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 - assert len(resp.json['roles']) == 1 + assert len(resp.json["roles"]) == 1 assert set(user.roles) == {role1} diff --git a/lemur/tests/test_utils.py b/lemur/tests/test_utils.py index 62d021a4..2e117d25 100644 --- a/lemur/tests/test_utils.py +++ b/lemur/tests/test_utils.py @@ -1,38 +1,49 @@ import pytest +from lemur.tests.vectors import ( + SAN_CERT, + INTERMEDIATE_CERT, + ROOTCA_CERT, + EC_CERT_EXAMPLE, + ECDSA_PRIME256V1_CERT, + ECDSA_SECP384r1_CERT, + DSA_CERT, +) + def test_generate_private_key(): from lemur.common.utils import generate_private_key - assert generate_private_key('RSA2048') - assert generate_private_key('RSA4096') - assert generate_private_key('ECCPRIME192V1') - assert generate_private_key('ECCPRIME256V1') - assert generate_private_key('ECCSECP192R1') - assert generate_private_key('ECCSECP224R1') - assert generate_private_key('ECCSECP256R1') - assert generate_private_key('ECCSECP384R1') - assert generate_private_key('ECCSECP521R1') - assert generate_private_key('ECCSECP256K1') - assert generate_private_key('ECCSECT163K1') - assert generate_private_key('ECCSECT233K1') - assert generate_private_key('ECCSECT283K1') - assert generate_private_key('ECCSECT409K1') - assert generate_private_key('ECCSECT571K1') - assert generate_private_key('ECCSECT163R2') - assert generate_private_key('ECCSECT233R1') - assert generate_private_key('ECCSECT283R1') - assert generate_private_key('ECCSECT409R1') - assert generate_private_key('ECCSECT571R2') + assert generate_private_key("RSA2048") + assert generate_private_key("RSA4096") + assert generate_private_key("ECCPRIME192V1") + assert generate_private_key("ECCPRIME256V1") + assert generate_private_key("ECCSECP192R1") + assert generate_private_key("ECCSECP224R1") + assert generate_private_key("ECCSECP256R1") + assert generate_private_key("ECCSECP384R1") + assert generate_private_key("ECCSECP521R1") + assert generate_private_key("ECCSECP256K1") + assert generate_private_key("ECCSECT163K1") + assert generate_private_key("ECCSECT233K1") + assert generate_private_key("ECCSECT283K1") + assert generate_private_key("ECCSECT409K1") + assert generate_private_key("ECCSECT571K1") + assert generate_private_key("ECCSECT163R2") + assert generate_private_key("ECCSECT233R1") + assert generate_private_key("ECCSECT283R1") + assert generate_private_key("ECCSECT409R1") + assert generate_private_key("ECCSECT571R2") with pytest.raises(Exception): - generate_private_key('LEMUR') + generate_private_key("LEMUR") def test_get_authority_key(): - '''test get authority key function''' + """test get authority key function""" from lemur.common.utils import get_authority_key - test_cert = '''-----BEGIN CERTIFICATE----- + + test_cert = """-----BEGIN CERTIFICATE----- MIIGYjCCBEqgAwIBAgIUVS7mn6LR5XlQyEGxQ4w9YAWL/XIwDQYJKoZIhvcNAQEN BQAweTELMAkGA1UEBhMCREUxDTALBgNVBAgTBEJvbm4xEDAOBgNVBAcTB0dlcm1h bnkxITAfBgNVBAoTGFRlbGVrb20gRGV1dHNjaGxhbmQgR21iSDELMAkGA1UECxMC @@ -68,6 +79,24 @@ zc75IDsn5wP6A3KflduWW7ri0bYUiKe5higMcbUM0aXzTEAVxsxPk8aEsR9dazF7 y4L/msew3UjFE3ovDHgStjWM1NBMxuIvJEbWOsiB2WA2l3FiT8HvFi0eX/0hbkGi 5LL+oz7nvm9Of7te/BV6Rq0rXWN4d6asO+QlLkTqbmAH6rwunmPCY7MbLXXtP/qM KFfxwrO1 ------END CERTIFICATE-----''' +-----END CERTIFICATE-----""" authority_key = get_authority_key(test_cert) - assert authority_key == 'feacb541be81771293affa412d8dc9f66a3ebb80' + assert authority_key == "feacb541be81771293affa412d8dc9f66a3ebb80" + + +def test_is_selfsigned(selfsigned_cert): + from lemur.common.utils import is_selfsigned + + assert is_selfsigned(selfsigned_cert) is True + assert is_selfsigned(SAN_CERT) is False + assert is_selfsigned(INTERMEDIATE_CERT) is False + # Root CA certificates are also technically self-signed + assert is_selfsigned(ROOTCA_CERT) is True + assert is_selfsigned(EC_CERT_EXAMPLE) is False + + # selfsigned certs + assert is_selfsigned(ECDSA_PRIME256V1_CERT) is True + assert is_selfsigned(ECDSA_SECP384r1_CERT) is True + # unsupported algorithm (DSA) + with pytest.raises(Exception): + is_selfsigned(DSA_CERT) diff --git a/lemur/tests/test_validators.py b/lemur/tests/test_validators.py index 815b7c9d..77148079 100644 --- a/lemur/tests/test_validators.py +++ b/lemur/tests/test_validators.py @@ -1,23 +1,35 @@ -import pytest from datetime import datetime -from .vectors import SAN_CERT_KEY + +import pytest from marshmallow.exceptions import ValidationError +from lemur.common.utils import parse_private_key +from lemur.common.validators import verify_private_key_match +from lemur.tests.vectors import INTERMEDIATE_CERT, SAN_CERT, SAN_CERT_KEY + def test_private_key(session): - from lemur.common.validators import private_key + parse_private_key(SAN_CERT_KEY) - private_key(SAN_CERT_KEY) + with pytest.raises(ValueError): + parse_private_key("invalid_private_key") + + +def test_validate_private_key(session): + key = parse_private_key(SAN_CERT_KEY) + + verify_private_key_match(key, SAN_CERT) with pytest.raises(ValidationError): - private_key('invalid_private_key') + # Wrong key for certificate + verify_private_key_match(key, INTERMEDIATE_CERT) def test_sub_alt_type(session): from lemur.common.validators import sub_alt_type with pytest.raises(ValidationError): - sub_alt_type('CNAME') + sub_alt_type("CNAME") def test_dates(session): @@ -32,7 +44,13 @@ def test_dates(session): dates(dict(validity_end=datetime(2016, 1, 1))) with pytest.raises(ValidationError): - dates(dict(validity_start=datetime(2016, 1, 5), validity_end=datetime(2016, 1, 1))) + dates( + dict(validity_start=datetime(2016, 1, 5), validity_end=datetime(2016, 1, 1)) + ) with pytest.raises(ValidationError): - dates(dict(validity_start=datetime(2016, 1, 1), validity_end=datetime(2016, 1, 10))) + dates( + dict( + validity_start=datetime(2016, 1, 1), validity_end=datetime(2016, 1, 10) + ) + ) diff --git a/lemur/tests/test_verify.py b/lemur/tests/test_verify.py index a1f0f5eb..348f6559 100644 --- a/lemur/tests/test_verify.py +++ b/lemur/tests/test_verify.py @@ -13,20 +13,24 @@ from .vectors import INTERMEDIATE_CERT_STR def test_verify_simple_cert(): """Simple certificate without CRL or OCSP.""" # Verification returns None if there are no means to verify a cert - assert verify_string(INTERMEDIATE_CERT_STR, '') is None + assert verify_string(INTERMEDIATE_CERT_STR, "") is None def test_verify_crl_unknown_scheme(cert_builder, private_key): """Unknown distribution point URI schemes should be ignored.""" - ldap_uri = 'ldap://ldap.example.org/cn=Example%20Certificate%20Authority?certificateRevocationList;binary' - crl_dp = x509.DistributionPoint([UniformResourceIdentifier(ldap_uri)], - relative_name=None, reasons=None, crl_issuer=None) - cert = (cert_builder - .add_extension(x509.CRLDistributionPoints([crl_dp]), critical=False) - .sign(private_key, hashes.SHA256(), default_backend())) + ldap_uri = "ldap://ldap.example.org/cn=Example%20Certificate%20Authority?certificateRevocationList;binary" + crl_dp = x509.DistributionPoint( + [UniformResourceIdentifier(ldap_uri)], + relative_name=None, + reasons=None, + crl_issuer=None, + ) + cert = cert_builder.add_extension( + x509.CRLDistributionPoints([crl_dp]), critical=False + ).sign(private_key, hashes.SHA256(), default_backend()) with mktempfile() as cert_tmp: - with open(cert_tmp, 'wb') as f: + with open(cert_tmp, "wb") as f: f.write(cert.public_bytes(serialization.Encoding.PEM)) # Must not raise exception @@ -35,15 +39,19 @@ def test_verify_crl_unknown_scheme(cert_builder, private_key): def test_verify_crl_unreachable(cert_builder, private_key): """Unreachable CRL distribution point results in error.""" - ldap_uri = 'http://invalid.example.org/crl/foobar.crl' - crl_dp = x509.DistributionPoint([UniformResourceIdentifier(ldap_uri)], - relative_name=None, reasons=None, crl_issuer=None) - cert = (cert_builder - .add_extension(x509.CRLDistributionPoints([crl_dp]), critical=False) - .sign(private_key, hashes.SHA256(), default_backend())) + ldap_uri = "http://invalid.example.org/crl/foobar.crl" + crl_dp = x509.DistributionPoint( + [UniformResourceIdentifier(ldap_uri)], + relative_name=None, + reasons=None, + crl_issuer=None, + ) + cert = cert_builder.add_extension( + x509.CRLDistributionPoints([crl_dp]), critical=False + ).sign(private_key, hashes.SHA256(), default_backend()) with mktempfile() as cert_tmp: - with open(cert_tmp, 'wb') as f: + with open(cert_tmp, "wb") as f: f.write(cert.public_bytes(serialization.Encoding.PEM)) with pytest.raises(Exception, match="Unable to retrieve CRL:"): diff --git a/lemur/tests/vectors.py b/lemur/tests/vectors.py index 6a836b30..0768cdac 100644 --- a/lemur/tests/vectors.py +++ b/lemur/tests/vectors.py @@ -1,20 +1,23 @@ from lemur.common.utils import parse_certificate VALID_USER_HEADER_TOKEN = { - 'Authorization': 'Basic ' + 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE1MjE2NTIwMjIsImV4cCI6MjM4NTY1MjAyMiwic3ViIjoxfQ.uK4PZjVAs0gt6_9h2EkYkKd64nFXdOq-rHsJZzeQicc', - 'Content-Type': 'application/json' + "Authorization": "Basic " + + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE1MjE2NTIwMjIsImV4cCI6MjM4NTY1MjAyMiwic3ViIjoxfQ.uK4PZjVAs0gt6_9h2EkYkKd64nFXdOq-rHsJZzeQicc", + "Content-Type": "application/json", } VALID_ADMIN_HEADER_TOKEN = { - 'Authorization': 'Basic ' + 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpYXQiOjE1MjE2NTE2NjMsInN1YiI6MiwiYWlkIjoxfQ.wyf5PkQNcggLrMFqxDfzjY-GWPw_XsuWvU2GmQaC5sg', - 'Content-Type': 'application/json' + "Authorization": "Basic " + + "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpYXQiOjE1MjE2NTE2NjMsInN1YiI6MiwiYWlkIjoxfQ.wyf5PkQNcggLrMFqxDfzjY-GWPw_XsuWvU2GmQaC5sg", + "Content-Type": "application/json", } VALID_ADMIN_API_TOKEN = { - 'Authorization': 'Basic ' + 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjIsImFpZCI6MSwiaWF0IjoxNDM1MjMzMzY5fQ.umW0I_oh4MVZ2qrClzj9SfYnQl6cd0HGzh9EwkDW60I', - 'Content-Type': 'application/json' + "Authorization": "Basic " + + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjIsImFpZCI6MSwiaWF0IjoxNDM1MjMzMzY5fQ.umW0I_oh4MVZ2qrClzj9SfYnQl6cd0HGzh9EwkDW60I", + "Content-Type": "application/json", } @@ -45,6 +48,7 @@ ssvobJ6Xe2D4cCVjUmsqtFEztMgdqgmlcWyGdUKeXdi7CMoeTb4uO+9qRQq46wYW n7K1z+W0Kp5yhnnPAoOioAP4vjASDx3z3RnLaZvMmcO7YdCIwhE5oGV0 -----END CERTIFICATE----- """ +ROOTCA_CERT = parse_certificate(ROOTCA_CERT_STR) ROOTCA_KEY = """\ -----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEAvyVpe0tfIzri3l3PYH2r7hW86wKF58GLY+Ua52rEO5E3eXQq @@ -136,6 +140,26 @@ eMVHHbWm1CpGO294R+vMBv4jcuhIBOx63KZE4VaoJuaazF6TE5czDw== #: CN=san.example.org, issued by LemurTrust Unittests Class 1 CA 2018 +SAN_CERT_CSR = """\ +-----BEGIN CERTIFICATE REQUEST----- +MIICvTCCAaUCAQAweDELMAkGA1UEBhMCRUUxDDAKBgNVBAgMA04vQTEOMAwGA1UE +BwwFRWFydGgxGDAWBgNVBAoMD0RhbmllbCBTYW4gJiBjbzEXMBUGA1UECwwOS2Fy +YXRlIExlc3NvbnMxGDAWBgNVBAMMD3Nhbi5leGFtcGxlLm9yZzCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAMia9BcpypZUU9xJoknzdEp+AevQE93XSAyl +IlXji80ZlYS/T/mVWtu6hNwz2IJDBFh6nPaHT1Ud/AI4YanDMa+fF4KJxzlkKPbY +quWx4EOjTZ2sFBBCivwxlo1So8r5Hf4NZ9Ewu4AIma3zmk+dzxJTpnWbTIFJGsDG +LwJO9iu6uqf79VdYkGELCusq3dyF2j2DNDiGHoRcQYFMMhDKR6uYmCTYvwjf0+sf +6k1zk2EK1X+ZWUyjP+Nl2NB6bpL0TydF75fuplWROczceiO6BKO4YT2uNPdF4BAH +p/kQCkqnjw5FCX7PONRT4wTW/AjDkt5WOgY+AB90zQBPxvXWbUMCAwEAAaAAMA0G +CSqGSIb3DQEBCwUAA4IBAQAFYgEafwRmsqdK1i1xrLFYbNNLkzmAZyL+6gXUBVIJ +TbGVVWSNNIcEmHIX8O9X4lN52qDYWOsxH/OKPVxpXqoHm/ztczFlte76wOYg+VAS +yK8DwQRP/+n+j6J40o1cZwnilPWqHgee5zbIL7lpCVxuFDofWpskwP5PLbxibFq8 +4TWynhjKKUw4+q4h4iCHG3PQhbV0ExWOyqX05QyDtJdkEwgJUWz1m9caHU2Jl7kX +5bWKOtXORpCYA7ed3WqktKQIxBD6vCVbQ+LuLZPYeWzGHYjfOejL6usD32KmNa2E +ZhDsC0fjqSX0FJKz6gOhP88bkbbapyHuGB71o2dwhCKV +-----END CERTIFICATE REQUEST----- +""" + SAN_CERT_STR = """\ -----BEGIN CERTIFICATE----- MIIESjCCAzKgAwIBAgIRAK/y20+NLU2OgPo4KuJ8IzMwDQYJKoZIhvcNAQELBQAw @@ -393,3 +417,98 @@ zm3Cn4Ul8DO26w9QS4fmZjmnPOZFXYMWoOR6osHzb62PWQ8FBMqXcdToBV2Q9Iw4 PiFAxlc0tVjlLqQ= -----END CERTIFICATE REQUEST----- """ + + +EC_CERT_STR = """ +-----BEGIN CERTIFICATE----- +MIIDxzCCAq+gAwIBAgIIHsJeci1JWAkwDQYJKoZIhvcNAQELBQAwVDELMAkGA1UE +BhMCVVMxHjAcBgNVBAoTFUdvb2dsZSBUcnVzdCBTZXJ2aWNlczElMCMGA1UEAxMc +R29vZ2xlIEludGVybmV0IEF1dGhvcml0eSBHMzAeFw0xOTAyMTMxNTM1NTdaFw0x +OTA1MDgxNTM1MDBaMGgxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlh +MRYwFAYDVQQHDA1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKDApHb29nbGUgTExDMRcw +FQYDVQQDDA53d3cuZ29vZ2xlLmNvbTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IA +BKwMlIbd4rAwf6eWoa6RrR2w0s5k1M40XOORPf96PByPmld+qhjRMLvA/xcAxdCR +XdcMfaX6EUr0Zw8CepitMB2jggFSMIIBTjATBgNVHSUEDDAKBggrBgEFBQcDATAO +BgNVHQ8BAf8EBAMCB4AwGQYDVR0RBBIwEIIOd3d3Lmdvb2dsZS5jb20waAYIKwYB +BQUHAQEEXDBaMC0GCCsGAQUFBzAChiFodHRwOi8vcGtpLmdvb2cvZ3NyMi9HVFNH +SUFHMy5jcnQwKQYIKwYBBQUHMAGGHWh0dHA6Ly9vY3NwLnBraS5nb29nL0dUU0dJ +QUczMB0GA1UdDgQWBBQLovm8GG0oG91gOGCL58YPNoAlejAMBgNVHRMBAf8EAjAA +MB8GA1UdIwQYMBaAFHfCuFCaZ3Z2sS3ChtCDoH6mfrpLMCEGA1UdIAQaMBgwDAYK +KwYBBAHWeQIFAzAIBgZngQwBAgIwMQYDVR0fBCowKDAmoCSgIoYgaHR0cDovL2Ny +bC5wa2kuZ29vZy9HVFNHSUFHMy5jcmwwDQYJKoZIhvcNAQELBQADggEBAKFbmNOA +e3pJ7UVI5EmkAMZgSDRdrsLHV6F7WluuyYCyE/HFpZjBd6y8xgGtYWcask6edwrq +zrcXNEN/GY34AYre0M+p0xAs+lKSwkrJd2sCgygmzsBFtGwjW6lhjm+rg83zPHhH +mQZ0ShUR1Kp4TvzXgxj44RXOsS5ZyDe3slGiG4aw/hl+igO8Y8JMvcv/Tpzo+V75 +BkDAFmLRi08NayfeyCqK/TcRpzxKMKhS7jEHK8Pzu5P+FyFHKqIsobi+BA+psOix +5nZLhrweLdKNz387mE2lSSKzr7qeLGHSOMt+ajQtZio4YVyZqJvg4Y++J0n5+Rjw +MXp8GrvTfn1DQ+o= +-----END CERTIFICATE----- +""" +EC_CERT_EXAMPLE = parse_certificate(EC_CERT_STR) + + +ECDSA_PRIME256V1_CERT_STR = """ +-----BEGIN CERTIFICATE----- +MIICUTCCAfYCCQCvH7H/e2nuiDAKBggqhkjOPQQDAjCBrzELMAkGA1UEBhMCVVMx +EzARBgNVBAgMCkNhbGlmb3JuaWExEjAQBgNVBAcMCUxvcyBHYXRvczEjMCEGA1UE +CgwaTGVtdXJUcnVzdCBFbnRlcnByaXNlcyBMdGQxJjAkBgNVBAsMHVVuaXR0ZXN0 +aW5nIE9wZXJhdGlvbnMgQ2VudGVyMSowKAYDVQQDDCFMZW11clRydXN0IFVuaXR0 +ZXN0cyBSb290IENBIDIwMTkwHhcNMTkwMjI2MTgxMTUyWhcNMjkwMjIzMTgxMTUy +WjCBrzELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExEjAQBgNVBAcM +CUxvcyBHYXRvczEjMCEGA1UECgwaTGVtdXJUcnVzdCBFbnRlcnByaXNlcyBMdGQx +JjAkBgNVBAsMHVVuaXR0ZXN0aW5nIE9wZXJhdGlvbnMgQ2VudGVyMSowKAYDVQQD +DCFMZW11clRydXN0IFVuaXR0ZXN0cyBSb290IENBIDIwMTkwWTATBgcqhkjOPQIB +BggqhkjOPQMBBwNCAAQsnAVUtpDCFMK/k9Chynu8BWRVUBUYbGQ9Q9xeLR60J4fD +uBt48YpTqg5RMZEclVknMReXqTmqphOBo37/YVdlMAoGCCqGSM49BAMCA0kAMEYC +IQDQZ6xfBiCTHxY4GM4+zLeG1iPBUSfIJOjkFNViFZY/XAIhAJYmrkVQb/YjWCdd +Vl89McYhmV4IV7WDgUmUhkUSFXgy +-----END CERTIFICATE----- +""" +ECDSA_PRIME256V1_CERT = parse_certificate(ECDSA_PRIME256V1_CERT_STR) + + +ECDSA_SECP384r1_CERT_STR = """ +-----BEGIN CERTIFICATE----- +MIICjjCCAhMCCQD2UadeQ7ub1jAKBggqhkjOPQQDAjCBrzELMAkGA1UEBhMCVVMx +EzARBgNVBAgMCkNhbGlmb3JuaWExEjAQBgNVBAcMCUxvcyBHYXRvczEjMCEGA1UE +CgwaTGVtdXJUcnVzdCBFbnRlcnByaXNlcyBMdGQxJjAkBgNVBAsMHVVuaXR0ZXN0 +aW5nIE9wZXJhdGlvbnMgQ2VudGVyMSowKAYDVQQDDCFMZW11clRydXN0IFVuaXR0 +ZXN0cyBSb290IENBIDIwMTgwHhcNMTkwMjI2MTgxODU2WhcNMjkwMjIzMTgxODU2 +WjCBrzELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExEjAQBgNVBAcM +CUxvcyBHYXRvczEjMCEGA1UECgwaTGVtdXJUcnVzdCBFbnRlcnByaXNlcyBMdGQx +JjAkBgNVBAsMHVVuaXR0ZXN0aW5nIE9wZXJhdGlvbnMgQ2VudGVyMSowKAYDVQQD +DCFMZW11clRydXN0IFVuaXR0ZXN0cyBSb290IENBIDIwMTgwdjAQBgcqhkjOPQIB +BgUrgQQAIgNiAARuKyHIRp2e6PB5UcY8L/bUdavkL5Zf3IegNKvaAsvkDenhDGAI +zwWgsk3rOo7jmpMibn7yJQn404uZovwyeKcApn8uVv8ltheeYAx+ySzzn/APxNGy +cye/nv1D9cDW628wCgYIKoZIzj0EAwIDaQAwZgIxANl1ljDH4ykNK2OaRqKOkBOW +cKk1SvtiEZDS/wytiZGCeaxYteSYF+3GE8V2W1geWAIxAI8D7DY0HU5zw+oxAlTD +Uw/TeHA6q0QV4otPvrINW3V09iXDwFSPe265fTkHSfT6hQ== +-----END CERTIFICATE----- +""" +ECDSA_SECP384r1_CERT = parse_certificate(ECDSA_SECP384r1_CERT_STR) + +DSA_CERT_STR = """ +-----BEGIN CERTIFICATE----- +MIIDmTCCA1YCCQD5h/cM7xYO9jALBglghkgBZQMEAwIwga8xCzAJBgNVBAYTAlVT +MRMwEQYDVQQIDApDYWxpZm9ybmlhMRIwEAYDVQQHDAlMb3MgR2F0b3MxIzAhBgNV +BAoMGkxlbXVyVHJ1c3QgRW50ZXJwcmlzZXMgTHRkMSYwJAYDVQQLDB1Vbml0dGVz +dGluZyBPcGVyYXRpb25zIENlbnRlcjEqMCgGA1UEAwwhTGVtdXJUcnVzdCBVbml0 +dGVzdHMgUm9vdCBDQSAyMDE4MB4XDTE5MDIyNjE4MjUyMloXDTI5MDIyMzE4MjUy +Mlowga8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRIwEAYDVQQH +DAlMb3MgR2F0b3MxIzAhBgNVBAoMGkxlbXVyVHJ1c3QgRW50ZXJwcmlzZXMgTHRk +MSYwJAYDVQQLDB1Vbml0dGVzdGluZyBPcGVyYXRpb25zIENlbnRlcjEqMCgGA1UE +AwwhTGVtdXJUcnVzdCBVbml0dGVzdHMgUm9vdCBDQSAyMDE4MIIBtjCCASsGByqG +SM44BAEwggEeAoGBAO2+6wO20rn9K7RtXJ7/kCSVFzYZsY1RKvmJ6BBkMFIepBkz +2pk62tRhJgNH07GKF7pyTPRRKqt38CaPK4ERUpavx3Ok6vZ3PKq8tMac/PMKBmT1 +Xfpch54KDlCdreEMJqYiCwbIyiSCR4+PCH+7xC5Uh0PIZo6otNWe3Wkk53CfAhUA +8d4YAtto6D30f7qkEa7DMAccUS8CgYAiv8r0k0aUEaeioblcCAjmhvE0v8/tD5u1 +anHO4jZIIv7uOrNFIGfqcNEOBs5AQkt5Bxn6x0b/VvtZ0FSrD0j4f36pTgro6noG +/0oRt0JngxsMSfo0LV4+bY62v21A0SneNgTgY+ugdfgGWvb0+9tpsIhiY69T+7c8 +Oa0S6OWSPAOBhAACgYB5wa+nJJNZPoTWFum27JlWGYLO2flg5EpWlOvcEE0o5RfB +FPnMM033kKQQEI0YpCAq9fIMKhhUMk1X4mKUBUTt+Nrn1pY2l/wt5G6AQdHI8QXz +P1ecBbHPNZtWe3iVnfOgz/Pd8tU9slcXP9z5XbZ7R/oGcF/TPRTtbLEkYZNaDDAL +BglghkgBZQMEAwIDMAAwLQIVANubSNMSLt8plN9ZV3cp4pe3lMYCAhQPLLE7rTgm +92X+hWfyz000QEpYEQ== +-----END CERTIFICATE----- +""" +DSA_CERT = parse_certificate(DSA_CERT_STR) diff --git a/lemur/users/models.py b/lemur/users/models.py index 79125b9c..d7b900dc 100644 --- a/lemur/users/models.py +++ b/lemur/users/models.py @@ -33,7 +33,7 @@ def hash_password(mapper, connect, target): class User(db.Model): - __tablename__ = 'users' + __tablename__ = "users" id = Column(Integer, primary_key=True) password = Column(String(128)) active = Column(Boolean()) @@ -41,14 +41,24 @@ class User(db.Model): username = Column(String(255), nullable=False, unique=True) email = Column(String(128), unique=True) profile_picture = Column(String(255)) - roles = relationship('Role', secondary=roles_users, passive_deletes=True, backref=db.backref('user'), lazy='dynamic') - certificates = relationship('Certificate', backref=db.backref('user'), lazy='dynamic') - pending_certificates = relationship('PendingCertificate', backref=db.backref('user'), lazy='dynamic') - authorities = relationship('Authority', backref=db.backref('user'), lazy='dynamic') - keys = relationship('ApiKey', backref=db.backref('user'), lazy='dynamic') - logs = relationship('Log', backref=db.backref('user'), lazy='dynamic') + roles = relationship( + "Role", + secondary=roles_users, + passive_deletes=True, + backref=db.backref("user"), + lazy="dynamic", + ) + certificates = relationship( + "Certificate", backref=db.backref("user"), lazy="dynamic" + ) + pending_certificates = relationship( + "PendingCertificate", backref=db.backref("user"), lazy="dynamic" + ) + authorities = relationship("Authority", backref=db.backref("user"), lazy="dynamic") + keys = relationship("ApiKey", backref=db.backref("user"), lazy="dynamic") + logs = relationship("Log", backref=db.backref("user"), lazy="dynamic") - sensitive_fields = ('password',) + sensitive_fields = ("password",) def check_password(self, password): """ @@ -68,7 +78,7 @@ class User(db.Model): :return: """ if self.password: - self.password = bcrypt.generate_password_hash(self.password).decode('utf-8') + self.password = bcrypt.generate_password_hash(self.password).decode("utf-8") @property def is_admin(self): @@ -79,11 +89,11 @@ class User(db.Model): :return: """ for role in self.roles: - if role.name == 'admin': + if role.name == "admin": return True def __repr__(self): return "User(username={username})".format(username=self.username) -listen(User, 'before_insert', hash_password) +listen(User, "before_insert", hash_password) diff --git a/lemur/users/schemas.py b/lemur/users/schemas.py index b5a21127..74bd93e9 100644 --- a/lemur/users/schemas.py +++ b/lemur/users/schemas.py @@ -8,7 +8,11 @@ from marshmallow import fields from lemur.common.schema import LemurInputSchema, LemurOutputSchema -from lemur.schemas import AssociatedRoleSchema, AssociatedCertificateSchema, AssociatedAuthoritySchema +from lemur.schemas import ( + AssociatedRoleSchema, + AssociatedCertificateSchema, + AssociatedAuthoritySchema, +) class UserInputSchema(LemurInputSchema): diff --git a/lemur/users/service.py b/lemur/users/service.py index c6557cb9..8fb91aa3 100644 --- a/lemur/users/service.py +++ b/lemur/users/service.py @@ -96,7 +96,7 @@ def get_by_email(email): :param email: :return: """ - return database.get(User, email, field='email') + return database.get(User, email, field="email") def get_by_username(username): @@ -106,7 +106,7 @@ def get_by_username(username): :param username: :return: """ - return database.get(User, username, field='username') + return database.get(User, username, field="username") def get_all(): @@ -129,10 +129,10 @@ def render(args): """ query = database.session_query(User) - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') + terms = filt.split(";") query = database.filter(query, User, terms) return database.sort_and_page(query, User, args) diff --git a/lemur/users/views.py b/lemur/users/views.py index eb67f014..06729177 100644 --- a/lemur/users/views.py +++ b/lemur/users/views.py @@ -18,15 +18,20 @@ from lemur.users import service from lemur.certificates import service as certificate_service from lemur.roles import service as role_service -from lemur.users.schemas import user_input_schema, user_output_schema, users_output_schema +from lemur.users.schemas import ( + user_input_schema, + user_output_schema, + users_output_schema, +) -mod = Blueprint('users', __name__) +mod = Blueprint("users", __name__) api = Api(mod) class UsersList(AuthenticatedResource): """ Defines the 'users' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(UsersList, self).__init__() @@ -83,8 +88,8 @@ class UsersList(AuthenticatedResource): :statuscode 200: no error """ parser = paginated_parser.copy() - parser.add_argument('owner', type=str, location='args') - parser.add_argument('id', type=str, location='args') + parser.add_argument("owner", type=str, location="args") + parser.add_argument("id", type=str, location="args") args = parser.parse_args() return service.render(args) @@ -137,7 +142,14 @@ class UsersList(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.create(data['username'], data['password'], data['email'], data['active'], None, data['roles']) + return service.create( + data["username"], + data["password"], + data["email"], + data["active"], + None, + data["roles"], + ) class Users(AuthenticatedResource): @@ -225,7 +237,14 @@ class Users(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.update(user_id, data['username'], data['email'], data['active'], None, data['roles']) + return service.update( + user_id, + data["username"], + data["email"], + data["active"], + None, + data["roles"], + ) class CertificateUsers(AuthenticatedResource): @@ -365,8 +384,12 @@ class Me(AuthenticatedResource): return g.current_user -api.add_resource(Me, '/auth/me', endpoint='me') -api.add_resource(UsersList, '/users', endpoint='users') -api.add_resource(Users, '/users/', endpoint='user') -api.add_resource(CertificateUsers, '/certificates//creator', endpoint='certificateCreator') -api.add_resource(RoleUsers, '/roles//users', endpoint='roleUsers') +api.add_resource(Me, "/auth/me", endpoint="me") +api.add_resource(UsersList, "/users", endpoint="users") +api.add_resource(Users, "/users/", endpoint="user") +api.add_resource( + CertificateUsers, + "/certificates//creator", + endpoint="certificateCreator", +) +api.add_resource(RoleUsers, "/roles//users", endpoint="roleUsers") diff --git a/lemur/utils.py b/lemur/utils.py index 1661e3f7..909d959a 100644 --- a/lemur/utils.py +++ b/lemur/utils.py @@ -31,7 +31,9 @@ def mktempfile(): @contextmanager def mktemppath(): try: - path = os.path.join(tempfile._get_default_tempdir(), next(tempfile._get_candidate_names())) + path = os.path.join( + tempfile._get_default_tempdir(), next(tempfile._get_candidate_names()) + ) yield path finally: try: @@ -53,7 +55,7 @@ def get_keys(): # when running lemur create_config, this code needs to work despite # the fact that there is not a current_app with a config at that point - keys = current_app.config.get('LEMUR_ENCRYPTION_KEYS', []) + keys = current_app.config.get("LEMUR_ENCRYPTION_KEYS", []) # this function is expected to return a list of keys, but we want # to let people just specify a single key @@ -97,7 +99,7 @@ class Vault(types.TypeDecorator): # ensure bytes for fernet if isinstance(value, str): - value = value.encode('utf-8') + value = value.encode("utf-8") return MultiFernet(self.keys).encrypt(value) @@ -117,4 +119,4 @@ class Vault(types.TypeDecorator): if not value: return - return MultiFernet(self.keys).decrypt(value).decode('utf8') + return MultiFernet(self.keys).decrypt(value).decode("utf8") diff --git a/package.json b/package.json index f47978db..9b899176 100644 --- a/package.json +++ b/package.json @@ -7,9 +7,8 @@ }, "dependencies": { "bower": "^1.8.2", - "browser-sync": "^2.3.1", + "browser-sync": "^2.26.7", "del": "^2.2.2", - "gulp": "^3.8.11", "gulp-autoprefixer": "^3.1.1", "gulp-cache": "^0.4.5", "gulp-concat": "^2.4.1", @@ -26,7 +25,7 @@ "gulp-minify-css": "^1.2.4", "gulp-minify-html": "~1.0.6", "gulp-ng-annotate": "~2.0.0", - "gulp-ng-html2js": "~0.2.2", + "gulp-ng-html2js": "^0.2.3", "gulp-notify": "^2.2.0", "gulp-plumber": "^1.1.0", "gulp-print": "^2.0.1", @@ -60,6 +59,7 @@ "test": "gulp test" }, "devDependencies": { + "gulp": "^3.9.1", "jshint": "^2.8.0", "karma-chrome-launcher": "^2.0.0" } diff --git a/requirements-dev.in b/requirements-dev.in index 84104679..2ffc5488 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -4,4 +4,5 @@ flake8==3.5.0 # flake8 3.6.0 is giving erroneous "W605 invalid escape sequence" pre-commit invoke twine -nodeenv \ No newline at end of file +nodeenv +pyyaml>=4.2b1 \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 7b427b20..d1423888 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,35 +2,40 @@ # This file is autogenerated by pip-compile # To update, run: # -# pip-compile --no-index --output-file requirements-dev.txt requirements-dev.in +# pip-compile --no-index --output-file=requirements-dev.txt requirements-dev.in # -aspy.yaml==1.1.1 # via pre-commit -bleach==3.0.2 # via readme-renderer -cached-property==1.5.1 # via pre-commit -certifi==2018.11.29 # via requests -cfgv==1.1.0 # via pre-commit +aspy.yaml==1.3.0 # via pre-commit +bleach==3.1.0 # via readme-renderer +certifi==2019.11.28 # via requests +cfgv==2.0.1 # via pre-commit chardet==3.0.4 # via requests -docutils==0.14 # via readme-renderer +docutils==0.15.2 # via readme-renderer flake8==3.5.0 -identify==1.1.7 # via pre-commit +identify==1.4.9 # via pre-commit idna==2.8 # via requests -importlib-metadata==0.7 # via pre-commit -invoke==1.2.0 +importlib-metadata==1.3.0 # via keyring, pre-commit, twine +invoke==1.3.0 +keyring==21.0.0 # via twine mccabe==0.6.1 # via flake8 +more-itertools==8.0.2 # via zipp nodeenv==1.3.3 -pkginfo==1.4.2 # via twine -pre-commit==1.12.0 +pkginfo==1.5.0.1 # via twine +pre-commit==1.21.0 pycodestyle==2.3.1 # via flake8 pyflakes==1.6.0 # via flake8 -pygments==2.3.1 # via readme-renderer -pyyaml==3.13 # via aspy.yaml, pre-commit +pygments==2.5.2 # via readme-renderer +pyyaml==5.2 readme-renderer==24.0 # via twine -requests-toolbelt==0.8.0 # via twine -requests==2.21.0 # via requests-toolbelt, twine -six==1.12.0 # via bleach, cfgv, pre-commit, readme-renderer +requests-toolbelt==0.9.1 # via twine +requests==2.22.0 # via requests-toolbelt, twine +six==1.13.0 # via bleach, cfgv, pre-commit, readme-renderer toml==0.10.0 # via pre-commit -tqdm==4.28.1 # via twine -twine==1.12.1 -urllib3==1.24.1 # via requests -virtualenv==16.1.0 # via pre-commit +tqdm==4.41.1 # via twine +twine==3.1.1 +urllib3==1.25.7 # via requests +virtualenv==16.7.9 # via pre-commit webencodings==0.5.1 # via bleach +zipp==0.6.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements-docs.txt b/requirements-docs.txt index 3f036915..893965ca 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -2,95 +2,113 @@ # This file is autogenerated by pip-compile # To update, run: # -# pip-compile --no-index --output-file requirements-docs.txt requirements-docs.in +# pip-compile --no-index --output-file=requirements-docs.txt requirements-docs.in # -acme==0.29.1 +acme==1.0.0 alabaster==0.7.12 # via sphinx alembic-autogenerate-enums==0.0.2 -alembic==1.0.5 -amqp==2.3.2 -aniso8601==4.0.1 -arrow==0.12.1 -asn1crypto==0.24.0 +alembic==1.3.2 +amqp==2.5.2 +aniso8601==8.0.0 +arrow==0.15.5 asyncpool==1.0 -babel==2.6.0 # via sphinx -bcrypt==3.1.4 -billiard==3.5.0.5 +babel==2.8.0 # via sphinx +bcrypt==3.1.7 +billiard==3.6.1.0 blinker==1.4 -boto3==1.9.60 -botocore==1.12.60 -celery[redis]==4.2.1 -certifi==2018.11.29 -cffi==1.11.5 +boto3==1.10.46 +botocore==1.13.46 +celery[redis]==4.4.0 +certifi==2019.11.28 +certsrv==2.1.1 +cffi==1.13.2 chardet==3.0.4 click==7.0 -cloudflare==2.1.0 -cryptography==2.4.2 +cloudflare==2.3.1 +cryptography==2.8 dnspython3==1.15.0 dnspython==1.15.0 -docutils==0.14 +docutils==0.15.2 dyn==1.8.1 flask-bcrypt==0.7.1 -flask-cors==3.0.7 +flask-cors==3.0.8 flask-mail==0.9.1 -flask-migrate==2.3.1 +flask-migrate==2.5.2 flask-principal==0.4.0 -flask-restful==0.3.6 +flask-replicated==1.3 +flask-restful==0.3.7 flask-script==2.0.6 -flask-sqlalchemy==2.3.2 -flask==1.0.2 -future==0.17.1 -gunicorn==19.9.0 -idna==2.7 -imagesize==1.1.0 # via sphinx +flask-sqlalchemy==2.4.1 +flask==1.1.1 +future==0.18.2 +gunicorn==20.0.4 +hvac==0.9.6 +idna==2.8 +imagesize==1.2.0 # via sphinx +importlib-metadata==1.3.0 inflection==0.3.1 itsdangerous==1.1.0 -jinja2==2.10 -jmespath==0.9.3 -josepy==1.1.0 +javaobj-py3==0.4.0.1 +jinja2==2.10.3 +jmespath==0.9.4 +josepy==1.2.0 jsonlines==1.2.0 -kombu==4.2.2 +kombu==4.6.7 lockfile==0.12.2 -mako==1.0.7 -markupsafe==1.1.0 -marshmallow-sqlalchemy==0.15.0 -marshmallow==2.16.3 -mock==2.0.0 +logmatic-python==0.1.7 +mako==1.1.0 +markupsafe==1.1.1 +marshmallow-sqlalchemy==0.21.0 +marshmallow==2.20.4 +mock==3.0.5 +more-itertools==8.0.2 ndg-httpsclient==0.5.1 -packaging==18.0 # via sphinx -paramiko==2.4.2 -pbr==5.1.1 -pem==18.2.0 -psycopg2==2.7.6.1 -pyasn1-modules==0.2.2 -pyasn1==0.4.4 +packaging==19.2 # via sphinx +paramiko==2.7.1 +pem==19.3.0 +psycopg2==2.8.4 +pyasn1-modules==0.2.7 +pyasn1==0.4.8 pycparser==2.19 -pygments==2.3.1 # via sphinx -pyjwt==1.7.0 +pycryptodomex==3.9.4 +pygments==2.5.2 # via sphinx +pyjks==19.0.0 +pyjwt==1.7.1 pynacl==1.3.0 -pyopenssl==18.0.0 -pyparsing==2.3.0 # via packaging +pyopenssl==19.1.0 +pyparsing==2.4.6 # via packaging pyrfc3339==1.1 -python-dateutil==2.7.5 -python-editor==1.0.3 -pytz==2018.7 -pyyaml==3.13 -raven[flask]==6.9.0 -redis==2.10.6 -requests-toolbelt==0.8.0 -requests[security]==2.20.1 +python-dateutil==2.8.1 +python-editor==1.0.4 +python-json-logger==0.1.11 +pytz==2019.3 +pyyaml==5.2 +raven[flask]==6.10.0 +redis==3.3.11 +requests-toolbelt==0.9.1 +requests[security]==2.22.0 retrying==1.3.3 -s3transfer==0.1.13 -six==1.11.0 -snowballstemmer==1.2.1 # via sphinx -sphinx-rtd-theme==0.4.2 -sphinx==1.8.2 +s3transfer==0.2.1 +six==1.13.0 +snowballstemmer==2.0.0 # via sphinx +sphinx-rtd-theme==0.4.3 +sphinx==2.3.1 +sphinxcontrib-applehelp==1.0.1 # via sphinx +sphinxcontrib-devhelp==1.0.1 # via sphinx +sphinxcontrib-htmlhelp==1.0.2 # via sphinx sphinxcontrib-httpdomain==1.7.0 -sphinxcontrib-websupport==1.1.0 # via sphinx -sqlalchemy-utils==0.33.9 -sqlalchemy==1.2.14 -tabulate==0.8.2 -urllib3==1.24.1 -vine==1.1.4 -werkzeug==0.14.1 -xmltodict==0.11.0 +sphinxcontrib-jsmath==1.0.1 # via sphinx +sphinxcontrib-qthelp==1.0.2 # via sphinx +sphinxcontrib-serializinghtml==1.1.3 # via sphinx +sqlalchemy-utils==0.36.1 +sqlalchemy==1.3.12 +tabulate==0.8.6 +twofish==0.3.0 +urllib3==1.25.7 +vine==1.3.0 +werkzeug==0.16.0 +xmltodict==0.12.0 +zipp==0.6.0 + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements-tests.in b/requirements-tests.in index 02a2b0ae..610f26f9 100644 --- a/requirements-tests.in +++ b/requirements-tests.in @@ -1,8 +1,11 @@ # Run `make up-reqs` to update pinned dependencies in requirement text files +bandit +black coverage factory-boy Faker +fakeredis freezegun moto nose @@ -11,3 +14,4 @@ pytest pytest-flask pytest-mock requests-mock +pyyaml>=4.2b1 \ No newline at end of file diff --git a/requirements-tests.txt b/requirements-tests.txt index 59c626f7..293bd350 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -2,63 +2,89 @@ # This file is autogenerated by pip-compile # To update, run: # -# pip-compile --no-index --output-file requirements-tests.txt requirements-tests.in +# pip-compile --no-index --output-file=requirements-tests.txt requirements-tests.in # -asn1crypto==0.24.0 # via cryptography -atomicwrites==1.2.1 # via pytest -attrs==18.2.0 # via pytest -aws-xray-sdk==0.95 # via moto -boto3==1.9.67 # via moto +appdirs==1.4.3 # via black +attrs==19.3.0 # via black, jsonschema, pytest +aws-sam-translator==1.19.1 # via cfn-lint +aws-xray-sdk==2.4.3 # via moto +bandit==1.6.2 +black==19.10b0 +boto3==1.10.46 # via aws-sam-translator, moto boto==2.49.0 # via moto -botocore==1.12.67 # via boto3, moto, s3transfer -certifi==2018.11.29 # via requests -cffi==1.11.5 # via cryptography +botocore==1.13.46 # via aws-xray-sdk, boto3, moto, s3transfer +certifi==2019.11.28 # via requests +cffi==1.13.2 # via cryptography +cfn-lint==0.26.2 # via moto chardet==3.0.4 # via requests -click==7.0 # via flask -coverage==4.5.2 -cryptography==2.4.2 # via moto -docker-pycreds==0.4.0 # via docker -docker==3.6.0 # via moto -docutils==0.14 # via botocore -ecdsa==0.13 # via python-jose -factory-boy==2.11.1 -faker==1.0.1 -flask==1.0.2 # via pytest-flask -freezegun==0.3.11 -future==0.17.1 # via python-jose -idna==2.8 # via cryptography, requests +click==7.0 # via black, flask +coverage==5.0.1 +cryptography==2.8 # via moto, sshpubkeys +docker==4.1.0 # via moto +docutils==0.15.2 # via botocore +ecdsa==0.15 # via python-jose, sshpubkeys +factory-boy==2.12.0 +faker==3.0.0 +fakeredis==1.1.0 +flask==1.1.1 # via pytest-flask +freezegun==0.3.12 +future==0.18.2 # via aws-xray-sdk +gitdb2==2.0.6 # via gitpython +gitpython==3.0.5 # via bandit +idna==2.8 # via moto, requests +importlib-metadata==1.3.0 # via jsonschema, pluggy, pytest itsdangerous==1.1.0 # via flask -jinja2==2.10 # via flask, moto -jmespath==0.9.3 # via boto3, botocore -jsondiff==1.1.1 # via moto -jsonpickle==1.0 # via aws-xray-sdk -markupsafe==1.1.0 # via jinja2 -mock==2.0.0 # via moto -more-itertools==4.3.0 # via pytest -moto==1.3.7 +jinja2==2.10.3 # via flask, moto +jmespath==0.9.4 # via boto3, botocore +jsondiff==1.1.2 # via moto +jsonpatch==1.24 # via cfn-lint +jsonpickle==1.2 # via aws-xray-sdk +jsonpointer==2.0 # via jsonpatch +jsonschema==3.2.0 # via aws-sam-translator, cfn-lint +markupsafe==1.1.1 # via jinja2 +mock==3.0.5 # via moto +more-itertools==8.0.2 # via pytest, zipp +moto==1.3.14 nose==1.3.7 -pbr==5.1.1 # via mock -pluggy==0.8.0 # via pytest -py==1.7.0 # via pytest -pyaml==18.11.0 # via moto +packaging==19.2 # via pytest +pathspec==0.7.0 # via black +pbr==5.4.4 # via stevedore +pluggy==0.13.1 # via pytest +py==1.8.1 # via pytest +pyasn1==0.4.8 # via python-jose, rsa pycparser==2.19 # via cffi -pycryptodome==3.7.2 # via python-jose -pyflakes==2.0.0 -pytest-flask==0.14.0 -pytest-mock==1.10.0 -pytest==4.0.2 -python-dateutil==2.7.5 # via botocore, faker, freezegun, moto -python-jose==2.0.2 # via moto -pytz==2018.7 # via moto -pyyaml==3.13 # via pyaml -requests-mock==1.5.2 -requests==2.21.0 # via aws-xray-sdk, docker, moto, requests-mock, responses -responses==0.10.5 # via moto -s3transfer==0.1.13 # via boto3 -six==1.12.0 # via cryptography, docker, docker-pycreds, faker, freezegun, mock, more-itertools, moto, pytest, python-dateutil, python-jose, requests-mock, responses, websocket-client -text-unidecode==1.2 # via faker -urllib3==1.24.1 # via botocore, requests -websocket-client==0.54.0 # via docker -werkzeug==0.14.1 # via flask, moto, pytest-flask -wrapt==1.10.11 # via aws-xray-sdk -xmltodict==0.11.0 # via moto +pyflakes==2.1.1 +pyparsing==2.4.6 # via packaging +pyrsistent==0.15.6 # via jsonschema +pytest-flask==0.15.0 +pytest-mock==1.13.0 +pytest==5.3.2 +python-dateutil==2.8.1 # via botocore, faker, freezegun, moto +python-jose==3.1.0 # via moto +pytz==2019.3 # via moto +pyyaml==5.2 +redis==3.3.11 # via fakeredis +regex==2019.12.20 # via black +requests-mock==1.7.0 +requests==2.22.0 # via docker, moto, requests-mock, responses +responses==0.10.9 # via moto +rsa==4.0 # via python-jose +s3transfer==0.2.1 # via boto3 +six==1.13.0 # via aws-sam-translator, bandit, cfn-lint, cryptography, docker, ecdsa, faker, fakeredis, freezegun, jsonschema, mock, moto, packaging, pyrsistent, python-dateutil, python-jose, requests-mock, responses, stevedore, websocket-client +smmap2==2.0.5 # via gitdb2 +sortedcontainers==2.1.0 # via fakeredis +sshpubkeys==3.1.0 # via moto +stevedore==1.31.0 # via bandit +text-unidecode==1.3 # via faker +toml==0.10.0 # via black +typed-ast==1.4.0 # via black +urllib3==1.25.7 # via botocore, requests +wcwidth==0.1.8 # via pytest +websocket-client==0.57.0 # via docker +werkzeug==0.16.0 # via flask, moto, pytest-flask +wrapt==1.11.2 # via aws-xray-sdk +xmltodict==0.12.0 # via moto +zipp==0.6.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements.in b/requirements.in index 9824650b..ed2093c9 100644 --- a/requirements.in +++ b/requirements.in @@ -8,6 +8,7 @@ boto3 botocore celery[redis] certifi +certsrv CloudFlare cryptography dnspython3 @@ -21,25 +22,30 @@ Flask-Script Flask-SQLAlchemy Flask Flask-Cors +flask_replicated future gunicorn +hvac # required for the vault destination plugin inflection jinja2 lockfile +logmatic-python marshmallow-sqlalchemy -marshmallow +marshmallow<2.20.5 #schema duplicate issues https://github.com/marshmallow-code/marshmallow-sqlalchemy/issues/121 ndg-httpsclient paramiko # required for the SFTP destination plugin pem psycopg2 +pyjks >= 19 # pyjks < 19 depends on pycryptodome, which conflicts with dyn's usage of pycrypto pyjwt pyOpenSSL +pyyaml>=4.2b1 #high severity alert python_ldap raven[flask] -redis<3 # redis>=3 is not compatible with celery +redis requests retrying six SQLAlchemy-Utils tabulate -xmltodict \ No newline at end of file +xmltodict diff --git a/requirements.txt b/requirements.txt index 7ee9a167..639c9377 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,85 +2,98 @@ # This file is autogenerated by pip-compile # To update, run: # -# pip-compile --no-index --output-file requirements.txt requirements.in +# pip-compile --no-index --output-file=requirements.txt requirements.in # -acme==0.29.1 +acme==1.0.0 alembic-autogenerate-enums==0.0.2 -alembic==1.0.5 # via flask-migrate -amqp==2.3.2 # via kombu -aniso8601==4.0.1 # via flask-restful -arrow==0.12.1 -asn1crypto==0.24.0 # via cryptography +alembic==1.3.2 # via flask-migrate +amqp==2.5.2 # via kombu +aniso8601==8.0.0 # via flask-restful +arrow==0.15.5 asyncpool==1.0 -bcrypt==3.1.5 # via flask-bcrypt, paramiko -billiard==3.5.0.5 # via celery +bcrypt==3.1.7 # via flask-bcrypt, paramiko +billiard==3.6.1.0 # via celery blinker==1.4 # via flask-mail, flask-principal, raven -boto3==1.9.67 -botocore==1.12.67 -celery[redis]==4.2.1 -certifi==2018.11.29 -cffi==1.11.5 # via bcrypt, cryptography, pynacl +boto3==1.10.46 +botocore==1.13.46 +celery[redis]==4.4.0 +certifi==2019.11.28 +certsrv==2.1.1 +cffi==1.13.2 # via bcrypt, cryptography, pynacl chardet==3.0.4 # via requests click==7.0 # via flask -cloudflare==2.1.0 -cryptography==2.4.2 +cloudflare==2.3.1 +cryptography==2.8 dnspython3==1.15.0 dnspython==1.15.0 # via dnspython3 -docutils==0.14 # via botocore +docutils==0.15.2 # via botocore dyn==1.8.1 flask-bcrypt==0.7.1 -flask-cors==3.0.7 +flask-cors==3.0.8 flask-mail==0.9.1 -flask-migrate==2.3.1 +flask-migrate==2.5.2 flask-principal==0.4.0 +flask-replicated==1.3 flask-restful==0.3.7 flask-script==2.0.6 -flask-sqlalchemy==2.3.2 -flask==1.0.2 -future==0.17.1 -gunicorn==19.9.0 -idna==2.8 # via cryptography, requests +flask-sqlalchemy==2.4.1 +flask==1.1.1 +future==0.18.2 +gunicorn==20.0.4 +hvac==0.9.6 +idna==2.8 # via requests +importlib-metadata==1.3.0 # via kombu inflection==0.3.1 itsdangerous==1.1.0 # via flask -jinja2==2.10 -jmespath==0.9.3 # via boto3, botocore -josepy==1.1.0 # via acme +javaobj-py3==0.4.0.1 # via pyjks +jinja2==2.10.3 +jmespath==0.9.4 # via boto3, botocore +josepy==1.2.0 # via acme jsonlines==1.2.0 # via cloudflare -kombu==4.2.2 # via celery +kombu==4.6.7 # via celery lockfile==0.12.2 -mako==1.0.7 # via alembic -markupsafe==1.1.0 # via jinja2, mako -marshmallow-sqlalchemy==0.15.0 -marshmallow==2.16.3 -mock==2.0.0 # via acme +logmatic-python==0.1.7 +mako==1.1.0 # via alembic +markupsafe==1.1.1 # via jinja2, mako +marshmallow-sqlalchemy==0.21.0 +marshmallow==2.20.4 +mock==3.0.5 # via acme +more-itertools==8.0.2 # via zipp ndg-httpsclient==0.5.1 -paramiko==2.4.2 -pbr==5.1.1 # via mock -pem==18.2.0 -psycopg2==2.7.6.1 -pyasn1-modules==0.2.2 # via python-ldap -pyasn1==0.4.4 # via ndg-httpsclient, paramiko, pyasn1-modules, python-ldap +paramiko==2.7.1 +pem==19.3.0 +psycopg2==2.8.4 +pyasn1-modules==0.2.7 # via pyjks, python-ldap +pyasn1==0.4.8 # via ndg-httpsclient, pyasn1-modules, pyjks, python-ldap pycparser==2.19 # via cffi +pycryptodomex==3.9.4 # via pyjks +pyjks==19.0.0 pyjwt==1.7.1 pynacl==1.3.0 # via paramiko -pyopenssl==18.0.0 +pyopenssl==19.1.0 pyrfc3339==1.1 # via acme -python-dateutil==2.7.5 # via alembic, arrow, botocore -python-editor==1.0.3 # via alembic -python-ldap==3.1.0 -pytz==2018.7 # via acme, celery, flask-restful, pyrfc3339 -pyyaml==3.13 # via cloudflare -raven[flask]==6.9.0 -redis==2.10.6 -requests-toolbelt==0.8.0 # via acme -requests[security]==2.21.0 +python-dateutil==2.8.1 # via alembic, arrow, botocore +python-editor==1.0.4 # via alembic +python-json-logger==0.1.11 # via logmatic-python +python-ldap==3.2.0 +pytz==2019.3 # via acme, celery, flask-restful, pyrfc3339 +pyyaml==5.2 +raven[flask]==6.10.0 +redis==3.3.11 +requests-toolbelt==0.9.1 # via acme +requests[security]==2.22.0 retrying==1.3.3 -s3transfer==0.1.13 # via boto3 -six==1.12.0 -sqlalchemy-utils==0.33.9 -sqlalchemy==1.2.15 # via alembic, flask-sqlalchemy, marshmallow-sqlalchemy, sqlalchemy-utils -tabulate==0.8.2 -urllib3==1.24.1 # via botocore, requests -vine==1.1.4 # via amqp -werkzeug==0.14.1 # via flask -xmltodict==0.11.0 +s3transfer==0.2.1 # via boto3 +six==1.13.0 +sqlalchemy-utils==0.36.1 +sqlalchemy==1.3.12 # via alembic, flask-sqlalchemy, marshmallow-sqlalchemy, sqlalchemy-utils +tabulate==0.8.6 +twofish==0.3.0 # via pyjks +urllib3==1.25.7 # via botocore, requests +vine==1.3.0 # via amqp, celery +werkzeug==0.16.0 # via flask +xmltodict==0.12.0 +zipp==0.6.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/setup.py b/setup.py index 1511b013..90c0b2f8 100644 --- a/setup.py +++ b/setup.py @@ -143,10 +143,11 @@ setup( 'aws_s3 = lemur.plugins.lemur_aws.plugin:S3DestinationPlugin', 'email_notification = lemur.plugins.lemur_email.plugin:EmailNotificationPlugin', 'slack_notification = lemur.plugins.lemur_slack.plugin:SlackNotificationPlugin', - 'java_truststore_export = lemur.plugins.lemur_java.plugin:JavaTruststoreExportPlugin', - 'java_keystore_export = lemur.plugins.lemur_java.plugin:JavaKeystoreExportPlugin', + 'java_truststore_export = lemur.plugins.lemur_jks.plugin:JavaTruststoreExportPlugin', + 'java_keystore_export = lemur.plugins.lemur_jks.plugin:JavaKeystoreExportPlugin', 'openssl_export = lemur.plugins.lemur_openssl.plugin:OpenSSLExportPlugin', 'atlas_metric = lemur.plugins.lemur_atlas.plugin:AtlasMetricPlugin', + 'atlas_metric_redis = lemur.plugins.lemur_atlas_redis.plugin:AtlasMetricRedisPlugin', 'kubernetes_destination = lemur.plugins.lemur_kubernetes.plugin:KubernetesDestinationPlugin', 'cryptography_issuer = lemur.plugins.lemur_cryptography.plugin:CryptographyIssuerPlugin', 'cfssl_issuer = lemur.plugins.lemur_cfssl.plugin:CfsslIssuerPlugin', @@ -154,7 +155,11 @@ setup( 'digicert_cis_issuer = lemur.plugins.lemur_digicert.plugin:DigiCertCISIssuerPlugin', 'digicert_cis_source = lemur.plugins.lemur_digicert.plugin:DigiCertCISSourcePlugin', 'csr_export = lemur.plugins.lemur_csr.plugin:CSRExportPlugin', - 'sftp_destination = lemur.plugins.lemur_sftp.plugin:SFTPDestinationPlugin' + 'sftp_destination = lemur.plugins.lemur_sftp.plugin:SFTPDestinationPlugin', + 'vault_source = lemur.plugins.lemur_vault_dest.plugin:VaultSourcePlugin', + 'vault_desination = lemur.plugins.lemur_vault_dest.plugin:VaultDestinationPlugin', + 'adcs_issuer = lemur.plugins.lemur_adcs.plugin:ADCSIssuerPlugin', + 'adcs_source = lemur.plugins.lemur_adcs.plugin:ADCSSourcePlugin' ], }, classifiers=[ diff --git a/tox.ini b/tox.ini index fdd2585b..d3ad8944 100644 --- a/tox.ini +++ b/tox.ini @@ -1,2 +1,2 @@ [tox] -envlist = py35 +envlist = py37