diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f3d19151..995a8508 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,3 +8,8 @@ sha: v2.9.5 hooks: - id: jshint +- repo: https://github.com/ambv/black + rev: stable + hooks: + - id: black + language_version: python3.7 \ No newline at end of file diff --git a/lemur/__about__.py b/lemur/__about__.py index d15b7dea..766d3668 100644 --- a/lemur/__about__.py +++ b/lemur/__about__.py @@ -1,12 +1,18 @@ from __future__ import absolute_import, division, print_function __all__ = [ - "__title__", "__summary__", "__uri__", "__version__", "__author__", - "__email__", "__license__", "__copyright__", + "__title__", + "__summary__", + "__uri__", + "__version__", + "__author__", + "__email__", + "__license__", + "__copyright__", ] __title__ = "lemur" -__summary__ = ("Certificate management and orchestration service") +__summary__ = "Certificate management and orchestration service" __uri__ = "https://github.com/Netflix/lemur" __version__ = "0.7.0" diff --git a/lemur/__init__.py b/lemur/__init__.py index 769e0cec..6229a3d1 100644 --- a/lemur/__init__.py +++ b/lemur/__init__.py @@ -32,14 +32,26 @@ from lemur.pending_certificates.views import mod as pending_certificates_bp from lemur.dns_providers.views import mod as dns_providers_bp from lemur.__about__ import ( - __author__, __copyright__, __email__, __license__, __summary__, __title__, - __uri__, __version__ + __author__, + __copyright__, + __email__, + __license__, + __summary__, + __title__, + __uri__, + __version__, ) __all__ = [ - "__title__", "__summary__", "__uri__", "__version__", "__author__", - "__email__", "__license__", "__copyright__", + "__title__", + "__summary__", + "__uri__", + "__version__", + "__author__", + "__email__", + "__license__", + "__copyright__", ] LEMUR_BLUEPRINTS = ( @@ -63,7 +75,9 @@ LEMUR_BLUEPRINTS = ( def create_app(config_path=None): - app = factory.create_app(app_name=__name__, blueprints=LEMUR_BLUEPRINTS, config=config_path) + app = factory.create_app( + app_name=__name__, blueprints=LEMUR_BLUEPRINTS, config=config_path + ) configure_hook(app) return app @@ -93,7 +107,7 @@ def configure_hook(app): @app.after_request def after_request(response): # Return early if we don't have the start time - if not hasattr(g, 'request_start_time'): + if not hasattr(g, "request_start_time"): return response # Get elapsed time in milliseconds @@ -102,12 +116,12 @@ def configure_hook(app): # Collect request/response tags tags = { - 'endpoint': request.endpoint, - 'request_method': request.method.lower(), - 'status_code': response.status_code + "endpoint": request.endpoint, + "request_method": request.method.lower(), + "status_code": response.status_code, } # Record our response time metric - metrics.send('response_time', 'TIMER', elapsed, metric_tags=tags) - metrics.send('status_code_{}'.format(response.status_code), 'counter', 1) + metrics.send("response_time", "TIMER", elapsed, metric_tags=tags) + metrics.send("status_code_{}".format(response.status_code), "counter", 1) return response diff --git a/lemur/api_keys/cli.py b/lemur/api_keys/cli.py index 2259d774..8aed0497 100644 --- a/lemur/api_keys/cli.py +++ b/lemur/api_keys/cli.py @@ -14,23 +14,32 @@ from datetime import datetime manager = Manager(usage="Handles all api key related tasks.") -@manager.option('-u', '--user-id', dest='uid', help='The User ID this access key belongs too.') -@manager.option('-n', '--name', dest='name', help='The name of this API Key.') -@manager.option('-t', '--ttl', dest='ttl', help='The TTL of this API Key. -1 for forever.') +@manager.option( + "-u", "--user-id", dest="uid", help="The User ID this access key belongs too." +) +@manager.option("-n", "--name", dest="name", help="The name of this API Key.") +@manager.option( + "-t", "--ttl", dest="ttl", help="The TTL of this API Key. -1 for forever." +) def create(uid, name, ttl): """ Create a new api key for a user. :return: """ print("[+] Creating a new api key.") - key = api_key_service.create(user_id=uid, name=name, - ttl=ttl, issued_at=int(datetime.utcnow().timestamp()), revoked=False) + key = api_key_service.create( + user_id=uid, + name=name, + ttl=ttl, + issued_at=int(datetime.utcnow().timestamp()), + revoked=False, + ) print("[+] Successfully created a new api key. Generating a JWT...") jwt = create_token(uid, key.id, key.ttl) print("[+] Your JWT is: {jwt}".format(jwt=jwt)) -@manager.option('-a', '--api-key-id', dest='aid', help='The API Key ID to revoke.') +@manager.option("-a", "--api-key-id", dest="aid", help="The API Key ID to revoke.") def revoke(aid): """ Revokes an api key for a user. diff --git a/lemur/api_keys/models.py b/lemur/api_keys/models.py index df77edb1..fbcc3e44 100644 --- a/lemur/api_keys/models.py +++ b/lemur/api_keys/models.py @@ -12,14 +12,19 @@ from lemur.database import db class ApiKey(db.Model): - __tablename__ = 'api_keys' + __tablename__ = "api_keys" id = Column(Integer, primary_key=True) name = Column(String) - user_id = Column(Integer, ForeignKey('users.id')) + user_id = Column(Integer, ForeignKey("users.id")) ttl = Column(BigInteger) issued_at = Column(BigInteger) revoked = Column(Boolean) def __repr__(self): return "ApiKey(name={name}, user_id={user_id}, ttl={ttl}, issued_at={iat}, revoked={revoked})".format( - user_id=self.user_id, name=self.name, ttl=self.ttl, iat=self.issued_at, revoked=self.revoked) + user_id=self.user_id, + name=self.name, + ttl=self.ttl, + iat=self.issued_at, + revoked=self.revoked, + ) diff --git a/lemur/api_keys/schemas.py b/lemur/api_keys/schemas.py index a3c11417..e690b859 100644 --- a/lemur/api_keys/schemas.py +++ b/lemur/api_keys/schemas.py @@ -13,12 +13,18 @@ from lemur.users.schemas import UserNestedOutputSchema, UserInputSchema def current_user_id(): - return {'id': g.current_user.id, 'email': g.current_user.email, 'username': g.current_user.username} + return { + "id": g.current_user.id, + "email": g.current_user.email, + "username": g.current_user.username, + } class ApiKeyInputSchema(LemurInputSchema): name = fields.String(required=False) - user = fields.Nested(UserInputSchema, missing=current_user_id, default=current_user_id) + user = fields.Nested( + UserInputSchema, missing=current_user_id, default=current_user_id + ) ttl = fields.Integer() diff --git a/lemur/api_keys/service.py b/lemur/api_keys/service.py index 5ddb8a3a..ea681a62 100644 --- a/lemur/api_keys/service.py +++ b/lemur/api_keys/service.py @@ -34,7 +34,7 @@ def revoke(aid): :return: """ api_key = get(aid) - setattr(api_key, 'revoked', False) + setattr(api_key, "revoked", False) return database.update(api_key) @@ -80,10 +80,10 @@ def render(args): :return: """ query = database.session_query(ApiKey) - user_id = args.pop('user_id', None) - aid = args.pop('id', None) - has_permission = args.pop('has_permission', False) - requesting_user_id = args.pop('requesting_user_id') + user_id = args.pop("user_id", None) + aid = args.pop("id", None) + has_permission = args.pop("has_permission", False) + requesting_user_id = args.pop("requesting_user_id") if user_id: query = query.filter(ApiKey.user_id == user_id) diff --git a/lemur/api_keys/views.py b/lemur/api_keys/views.py index b7af2944..ee09d3f7 100644 --- a/lemur/api_keys/views.py +++ b/lemur/api_keys/views.py @@ -19,10 +19,16 @@ from lemur.auth.permissions import ApiKeyCreatorPermission from lemur.common.schema import validate_schema from lemur.common.utils import paginated_parser -from lemur.api_keys.schemas import api_key_input_schema, api_key_revoke_schema, api_key_output_schema, \ - api_keys_output_schema, api_key_described_output_schema, user_api_key_input_schema +from lemur.api_keys.schemas import ( + api_key_input_schema, + api_key_revoke_schema, + api_key_output_schema, + api_keys_output_schema, + api_key_described_output_schema, + user_api_key_input_schema, +) -mod = Blueprint('api_keys', __name__) +mod = Blueprint("api_keys", __name__) api = Api(mod) @@ -81,8 +87,8 @@ class ApiKeyList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['has_permission'] = ApiKeyCreatorPermission().can() - args['requesting_user_id'] = g.current_user.id + args["has_permission"] = ApiKeyCreatorPermission().can() + args["requesting_user_id"] = g.current_user.id return service.render(args) @validate_schema(api_key_input_schema, api_key_output_schema) @@ -124,12 +130,26 @@ class ApiKeyList(AuthenticatedResource): :statuscode 403: unauthenticated """ if not ApiKeyCreatorPermission().can(): - if data['user']['id'] != g.current_user.id: - return dict(message="You are not authorized to create tokens for: {0}".format(data['user']['username'])), 403 + if data["user"]["id"] != g.current_user.id: + return ( + dict( + message="You are not authorized to create tokens for: {0}".format( + data["user"]["username"] + ) + ), + 403, + ) - access_token = service.create(name=data['name'], user_id=data['user']['id'], ttl=data['ttl'], - revoked=False, issued_at=int(datetime.utcnow().timestamp())) - return dict(jwt=create_token(access_token.user_id, access_token.id, access_token.ttl)) + access_token = service.create( + name=data["name"], + user_id=data["user"]["id"], + ttl=data["ttl"], + revoked=False, + issued_at=int(datetime.utcnow().timestamp()), + ) + return dict( + jwt=create_token(access_token.user_id, access_token.id, access_token.ttl) + ) class ApiKeyUserList(AuthenticatedResource): @@ -186,9 +206,9 @@ class ApiKeyUserList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['has_permission'] = ApiKeyCreatorPermission().can() - args['requesting_user_id'] = g.current_user.id - args['user_id'] = user_id + args["has_permission"] = ApiKeyCreatorPermission().can() + args["requesting_user_id"] = g.current_user.id + args["user_id"] = user_id return service.render(args) @validate_schema(user_api_key_input_schema, api_key_output_schema) @@ -230,11 +250,25 @@ class ApiKeyUserList(AuthenticatedResource): """ if not ApiKeyCreatorPermission().can(): if user_id != g.current_user.id: - return dict(message="You are not authorized to create tokens for: {0}".format(user_id)), 403 + return ( + dict( + message="You are not authorized to create tokens for: {0}".format( + user_id + ) + ), + 403, + ) - access_token = service.create(name=data['name'], user_id=user_id, ttl=data['ttl'], - revoked=False, issued_at=int(datetime.utcnow().timestamp())) - return dict(jwt=create_token(access_token.user_id, access_token.id, access_token.ttl)) + access_token = service.create( + name=data["name"], + user_id=user_id, + ttl=data["ttl"], + revoked=False, + issued_at=int(datetime.utcnow().timestamp()), + ) + return dict( + jwt=create_token(access_token.user_id, access_token.id, access_token.ttl) + ) class ApiKeys(AuthenticatedResource): @@ -329,7 +363,9 @@ class ApiKeys(AuthenticatedResource): if not ApiKeyCreatorPermission().can(): return dict(message="You are not authorized to update this token!"), 403 - service.update(access_key, name=data['name'], revoked=data['revoked'], ttl=data['ttl']) + service.update( + access_key, name=data["name"], revoked=data["revoked"], ttl=data["ttl"] + ) return dict(jwt=create_token(access_key.user_id, access_key.id, access_key.ttl)) def delete(self, aid): @@ -371,7 +407,7 @@ class ApiKeys(AuthenticatedResource): return dict(message="You are not authorized to delete this token!"), 403 service.delete(access_key) - return {'result': True} + return {"result": True} class UserApiKeys(AuthenticatedResource): @@ -472,7 +508,9 @@ class UserApiKeys(AuthenticatedResource): if access_key.user_id != uid: return dict(message="You are not authorized to update this token!"), 403 - service.update(access_key, name=data['name'], revoked=data['revoked'], ttl=data['ttl']) + service.update( + access_key, name=data["name"], revoked=data["revoked"], ttl=data["ttl"] + ) return dict(jwt=create_token(access_key.user_id, access_key.id, access_key.ttl)) def delete(self, uid, aid): @@ -517,7 +555,7 @@ class UserApiKeys(AuthenticatedResource): return dict(message="You are not authorized to delete this token!"), 403 service.delete(access_key) - return {'result': True} + return {"result": True} class ApiKeysDescribed(AuthenticatedResource): @@ -572,8 +610,12 @@ class ApiKeysDescribed(AuthenticatedResource): return access_key -api.add_resource(ApiKeyList, '/keys', endpoint='api_keys') -api.add_resource(ApiKeys, '/keys/', endpoint='api_key') -api.add_resource(ApiKeysDescribed, '/keys//described', endpoint='api_key_described') -api.add_resource(ApiKeyUserList, '/users//keys', endpoint='user_api_keys') -api.add_resource(UserApiKeys, '/users//keys/', endpoint='user_api_key') +api.add_resource(ApiKeyList, "/keys", endpoint="api_keys") +api.add_resource(ApiKeys, "/keys/", endpoint="api_key") +api.add_resource( + ApiKeysDescribed, "/keys//described", endpoint="api_key_described" +) +api.add_resource(ApiKeyUserList, "/users//keys", endpoint="user_api_keys") +api.add_resource( + UserApiKeys, "/users//keys/", endpoint="user_api_key" +) diff --git a/lemur/auth/ldap.py b/lemur/auth/ldap.py index 7eded060..f4ceab03 100644 --- a/lemur/auth/ldap.py +++ b/lemur/auth/ldap.py @@ -14,35 +14,41 @@ from lemur.roles import service as role_service from lemur.common.utils import validate_conf, get_psuedo_random_string -class LdapPrincipal(): +class LdapPrincipal: """ Provides methods for authenticating against an LDAP server. """ + def __init__(self, args): self._ldap_validate_conf() # setup ldap config - if not args['username']: + if not args["username"]: raise Exception("missing ldap username") - if not args['password']: + if not args["password"]: self.error_message = "missing ldap password" raise Exception("missing ldap password") - self.ldap_principal = args['username'] + self.ldap_principal = args["username"] self.ldap_email_domain = current_app.config.get("LDAP_EMAIL_DOMAIN", None) - if '@' not in self.ldap_principal: - self.ldap_principal = '%s@%s' % (self.ldap_principal, self.ldap_email_domain) - self.ldap_username = args['username'] - if '@' in self.ldap_username: - self.ldap_username = args['username'].split("@")[0] - self.ldap_password = args['password'] - self.ldap_server = current_app.config.get('LDAP_BIND_URI', None) + if "@" not in self.ldap_principal: + self.ldap_principal = "%s@%s" % ( + self.ldap_principal, + self.ldap_email_domain, + ) + self.ldap_username = args["username"] + if "@" in self.ldap_username: + self.ldap_username = args["username"].split("@")[0] + self.ldap_password = args["password"] + self.ldap_server = current_app.config.get("LDAP_BIND_URI", None) self.ldap_base_dn = current_app.config.get("LDAP_BASE_DN", None) self.ldap_use_tls = current_app.config.get("LDAP_USE_TLS", False) self.ldap_cacert_file = current_app.config.get("LDAP_CACERT_FILE", None) self.ldap_default_role = current_app.config.get("LEMUR_DEFAULT_ROLE", None) self.ldap_required_group = current_app.config.get("LDAP_REQUIRED_GROUP", None) self.ldap_groups_to_roles = current_app.config.get("LDAP_GROUPS_TO_ROLES", None) - self.ldap_is_active_directory = current_app.config.get("LDAP_IS_ACTIVE_DIRECTORY", False) - self.ldap_attrs = ['memberOf'] + self.ldap_is_active_directory = current_app.config.get( + "LDAP_IS_ACTIVE_DIRECTORY", False + ) + self.ldap_attrs = ["memberOf"] self.ldap_client = None self.ldap_groups = None @@ -60,8 +66,8 @@ class LdapPrincipal(): get_psuedo_random_string(), self.ldap_principal, True, - '', # thumbnailPhotoUrl - list(roles) + "", # thumbnailPhotoUrl + list(roles), ) else: # we add 'lemur' specific roles, so they do not get marked as removed @@ -76,7 +82,7 @@ class LdapPrincipal(): self.ldap_principal, user.active, user.profile_picture, - list(roles) + list(roles), ) return user @@ -105,9 +111,12 @@ class LdapPrincipal(): # update their 'roles' role = role_service.get_by_name(self.ldap_principal) if not role: - description = "auto generated role based on owner: {0}".format(self.ldap_principal) - role = role_service.create(self.ldap_principal, description=description, - third_party=True) + description = "auto generated role based on owner: {0}".format( + self.ldap_principal + ) + role = role_service.create( + self.ldap_principal, description=description, third_party=True + ) if not role.third_party: role = role_service.set_third_party(role.id, third_party_status=True) roles.add(role) @@ -118,9 +127,15 @@ class LdapPrincipal(): role = role_service.get_by_name(role_name) if role: if ldap_group_name in self.ldap_groups: - current_app.logger.debug("assigning role {0} to ldap user {1}".format(self.ldap_principal, role)) + current_app.logger.debug( + "assigning role {0} to ldap user {1}".format( + self.ldap_principal, role + ) + ) if not role.third_party: - role = role_service.set_third_party(role.id, third_party_status=True) + role = role_service.set_third_party( + role.id, third_party_status=True + ) roles.add(role) return roles @@ -132,7 +147,7 @@ class LdapPrincipal(): self._bind() roles = self._authorize() if not roles: - raise Exception('ldap authorization failed') + raise Exception("ldap authorization failed") return self._update_user(roles) def _bind(self): @@ -141,9 +156,12 @@ class LdapPrincipal(): list groups for a user. raise an exception on error. """ - if '@' not in self.ldap_principal: - self.ldap_principal = '%s@%s' % (self.ldap_principal, self.ldap_email_domain) - ldap_filter = 'userPrincipalName=%s' % self.ldap_principal + if "@" not in self.ldap_principal: + self.ldap_principal = "%s@%s" % ( + self.ldap_principal, + self.ldap_email_domain, + ) + ldap_filter = "userPrincipalName=%s" % self.ldap_principal # query ldap for auth try: @@ -159,37 +177,47 @@ class LdapPrincipal(): self.ldap_client.set_option(ldap.OPT_X_TLS_DEMAND, True) self.ldap_client.set_option(ldap.OPT_DEBUG_LEVEL, 255) if self.ldap_cacert_file: - self.ldap_client.set_option(ldap.OPT_X_TLS_CACERTFILE, self.ldap_cacert_file) + self.ldap_client.set_option( + ldap.OPT_X_TLS_CACERTFILE, self.ldap_cacert_file + ) self.ldap_client.simple_bind_s(self.ldap_principal, self.ldap_password) except ldap.INVALID_CREDENTIALS: self.ldap_client.unbind() - raise Exception('The supplied ldap credentials are invalid') + raise Exception("The supplied ldap credentials are invalid") except ldap.SERVER_DOWN: - raise Exception('ldap server unavailable') + raise Exception("ldap server unavailable") except ldap.LDAPError as e: raise Exception("ldap error: {0}".format(e)) if self.ldap_is_active_directory: # Lookup user DN, needed to search for group membership - userdn = self.ldap_client.search_s(self.ldap_base_dn, - ldap.SCOPE_SUBTREE, ldap_filter, - ['distinguishedName'])[0][1]['distinguishedName'][0] - userdn = userdn.decode('utf-8') + userdn = self.ldap_client.search_s( + self.ldap_base_dn, + ldap.SCOPE_SUBTREE, + ldap_filter, + ["distinguishedName"], + )[0][1]["distinguishedName"][0] + userdn = userdn.decode("utf-8") # Search all groups that have the userDN as a member - groupfilter = '(&(objectclass=group)(member:1.2.840.113556.1.4.1941:={0}))'.format(userdn) - lgroups = self.ldap_client.search_s(self.ldap_base_dn, ldap.SCOPE_SUBTREE, groupfilter, ['cn']) + groupfilter = "(&(objectclass=group)(member:1.2.840.113556.1.4.1941:={0}))".format( + userdn + ) + lgroups = self.ldap_client.search_s( + self.ldap_base_dn, ldap.SCOPE_SUBTREE, groupfilter, ["cn"] + ) # Create a list of group CN's from the result self.ldap_groups = [] for group in lgroups: (dn, values) = group - self.ldap_groups.append(values['cn'][0].decode('ascii')) + self.ldap_groups.append(values["cn"][0].decode("ascii")) else: - lgroups = self.ldap_client.search_s(self.ldap_base_dn, - ldap.SCOPE_SUBTREE, ldap_filter, self.ldap_attrs)[0][1]['memberOf'] + lgroups = self.ldap_client.search_s( + self.ldap_base_dn, ldap.SCOPE_SUBTREE, ldap_filter, self.ldap_attrs + )[0][1]["memberOf"] # lgroups is a list of utf-8 encoded strings # convert to a single string of groups to allow matching - self.ldap_groups = b''.join(lgroups).decode('ascii') + self.ldap_groups = b"".join(lgroups).decode("ascii") self.ldap_client.unbind() @@ -197,9 +225,5 @@ class LdapPrincipal(): """ Confirms required ldap config settings exist. """ - required_vars = [ - 'LDAP_BIND_URI', - 'LDAP_BASE_DN', - 'LDAP_EMAIL_DOMAIN', - ] + required_vars = ["LDAP_BIND_URI", "LDAP_BASE_DN", "LDAP_EMAIL_DOMAIN"] validate_conf(current_app, required_vars) diff --git a/lemur/auth/permissions.py b/lemur/auth/permissions.py index 68c48773..c3c57356 100644 --- a/lemur/auth/permissions.py +++ b/lemur/auth/permissions.py @@ -12,21 +12,21 @@ from collections import namedtuple from flask_principal import Permission, RoleNeed # Permissions -operator_permission = Permission(RoleNeed('operator')) -admin_permission = Permission(RoleNeed('admin')) +operator_permission = Permission(RoleNeed("operator")) +admin_permission = Permission(RoleNeed("admin")) -CertificateOwner = namedtuple('certificate', ['method', 'value']) -CertificateOwnerNeed = partial(CertificateOwner, 'role') +CertificateOwner = namedtuple("certificate", ["method", "value"]) +CertificateOwnerNeed = partial(CertificateOwner, "role") class SensitiveDomainPermission(Permission): def __init__(self): - super(SensitiveDomainPermission, self).__init__(RoleNeed('admin')) + super(SensitiveDomainPermission, self).__init__(RoleNeed("admin")) class CertificatePermission(Permission): def __init__(self, owner, roles): - needs = [RoleNeed('admin'), RoleNeed(owner), RoleNeed('creator')] + needs = [RoleNeed("admin"), RoleNeed(owner), RoleNeed("creator")] for r in roles: needs.append(CertificateOwnerNeed(str(r))) # Backwards compatibility with mixed-case role names @@ -38,29 +38,29 @@ class CertificatePermission(Permission): class ApiKeyCreatorPermission(Permission): def __init__(self): - super(ApiKeyCreatorPermission, self).__init__(RoleNeed('admin')) + super(ApiKeyCreatorPermission, self).__init__(RoleNeed("admin")) -RoleMember = namedtuple('role', ['method', 'value']) -RoleMemberNeed = partial(RoleMember, 'member') +RoleMember = namedtuple("role", ["method", "value"]) +RoleMemberNeed = partial(RoleMember, "member") class RoleMemberPermission(Permission): def __init__(self, role_id): - needs = [RoleNeed('admin'), RoleMemberNeed(role_id)] + needs = [RoleNeed("admin"), RoleMemberNeed(role_id)] super(RoleMemberPermission, self).__init__(*needs) -AuthorityCreator = namedtuple('authority', ['method', 'value']) -AuthorityCreatorNeed = partial(AuthorityCreator, 'authorityUse') +AuthorityCreator = namedtuple("authority", ["method", "value"]) +AuthorityCreatorNeed = partial(AuthorityCreator, "authorityUse") -AuthorityOwner = namedtuple('authority', ['method', 'value']) -AuthorityOwnerNeed = partial(AuthorityOwner, 'role') +AuthorityOwner = namedtuple("authority", ["method", "value"]) +AuthorityOwnerNeed = partial(AuthorityOwner, "role") class AuthorityPermission(Permission): def __init__(self, authority_id, roles): - needs = [RoleNeed('admin'), AuthorityCreatorNeed(str(authority_id))] + needs = [RoleNeed("admin"), AuthorityCreatorNeed(str(authority_id))] for r in roles: needs.append(AuthorityOwnerNeed(str(r))) diff --git a/lemur/auth/service.py b/lemur/auth/service.py index c862aa2e..0e1521b3 100644 --- a/lemur/auth/service.py +++ b/lemur/auth/service.py @@ -39,13 +39,13 @@ def get_rsa_public_key(n, e): :param e: :return: a RSA Public Key in PEM format """ - n = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(n, 'utf-8'))), 16) - e = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(e, 'utf-8'))), 16) + n = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(n, "utf-8"))), 16) + e = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(e, "utf-8"))), 16) pub = RSAPublicNumbers(e, n).public_key(default_backend()) return pub.public_bytes( encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo + format=serialization.PublicFormat.SubjectPublicKeyInfo, ) @@ -57,28 +57,27 @@ def create_token(user, aid=None, ttl=None): :param user: :return: """ - expiration_delta = timedelta(days=int(current_app.config.get('LEMUR_TOKEN_EXPIRATION', 1))) - payload = { - 'iat': datetime.utcnow(), - 'exp': datetime.utcnow() + expiration_delta - } + expiration_delta = timedelta( + days=int(current_app.config.get("LEMUR_TOKEN_EXPIRATION", 1)) + ) + payload = {"iat": datetime.utcnow(), "exp": datetime.utcnow() + expiration_delta} # Handle Just a User ID & User Object. if isinstance(user, int): - payload['sub'] = user + payload["sub"] = user else: - payload['sub'] = user.id + payload["sub"] = user.id if aid is not None: - payload['aid'] = aid + payload["aid"] = aid # Custom TTLs are only supported on Access Keys. if ttl is not None and aid is not None: # Tokens that are forever until revoked. if ttl == -1: - del payload['exp'] + del payload["exp"] else: - payload['exp'] = ttl - token = jwt.encode(payload, current_app.config['LEMUR_TOKEN_SECRET']) - return token.decode('unicode_escape') + payload["exp"] = ttl + token = jwt.encode(payload, current_app.config["LEMUR_TOKEN_SECRET"]) + return token.decode("unicode_escape") def login_required(f): @@ -88,49 +87,54 @@ def login_required(f): :param f: :return: """ + @wraps(f) def decorated_function(*args, **kwargs): - if not request.headers.get('Authorization'): - response = jsonify(message='Missing authorization header') + if not request.headers.get("Authorization"): + response = jsonify(message="Missing authorization header") response.status_code = 401 return response try: - token = request.headers.get('Authorization').split()[1] + token = request.headers.get("Authorization").split()[1] except Exception as e: - return dict(message='Token is invalid'), 403 + return dict(message="Token is invalid"), 403 try: - payload = jwt.decode(token, current_app.config['LEMUR_TOKEN_SECRET']) + payload = jwt.decode(token, current_app.config["LEMUR_TOKEN_SECRET"]) except jwt.DecodeError: - return dict(message='Token is invalid'), 403 + return dict(message="Token is invalid"), 403 except jwt.ExpiredSignatureError: - return dict(message='Token has expired'), 403 + return dict(message="Token has expired"), 403 except jwt.InvalidTokenError: - return dict(message='Token is invalid'), 403 + return dict(message="Token is invalid"), 403 - if 'aid' in payload: - access_key = api_key_service.get(payload['aid']) + if "aid" in payload: + access_key = api_key_service.get(payload["aid"]) if access_key.revoked: - return dict(message='Token has been revoked'), 403 + return dict(message="Token has been revoked"), 403 if access_key.ttl != -1: current_time = datetime.utcnow() - expired_time = datetime.fromtimestamp(access_key.issued_at + access_key.ttl) + expired_time = datetime.fromtimestamp( + access_key.issued_at + access_key.ttl + ) if current_time >= expired_time: - return dict(message='Token has expired'), 403 + return dict(message="Token has expired"), 403 - user = user_service.get(payload['sub']) + user = user_service.get(payload["sub"]) if not user.active: - return dict(message='User is not currently active'), 403 + return dict(message="User is not currently active"), 403 g.current_user = user if not g.current_user: - return dict(message='You are not logged in'), 403 + return dict(message="You are not logged in"), 403 # Tell Flask-Principal the identity changed - identity_changed.send(current_app._get_current_object(), identity=Identity(g.current_user.id)) + identity_changed.send( + current_app._get_current_object(), identity=Identity(g.current_user.id) + ) return f(*args, **kwargs) @@ -144,18 +148,18 @@ def fetch_token_header(token): :param token: :return: :raise jwt.DecodeError: """ - token = token.encode('utf-8') + token = token.encode("utf-8") try: - signing_input, crypto_segment = token.rsplit(b'.', 1) - header_segment, payload_segment = signing_input.split(b'.', 1) + signing_input, crypto_segment = token.rsplit(b".", 1) + header_segment, payload_segment = signing_input.split(b".", 1) except ValueError: - raise jwt.DecodeError('Not enough segments') + raise jwt.DecodeError("Not enough segments") try: - return json.loads(jwt.utils.base64url_decode(header_segment).decode('utf-8')) + return json.loads(jwt.utils.base64url_decode(header_segment).decode("utf-8")) except TypeError as e: current_app.logger.exception(e) - raise jwt.DecodeError('Invalid header padding') + raise jwt.DecodeError("Invalid header padding") @identity_loaded.connect @@ -174,13 +178,13 @@ def on_identity_loaded(sender, identity): identity.provides.add(UserNeed(identity.id)) # identity with the roles that the user provides - if hasattr(user, 'roles'): + if hasattr(user, "roles"): for role in user.roles: identity.provides.add(RoleNeed(role.name)) identity.provides.add(RoleMemberNeed(role.id)) # apply ownership for authorities - if hasattr(user, 'authorities'): + if hasattr(user, "authorities"): for authority in user.authorities: identity.provides.add(AuthorityCreatorNeed(authority.id)) @@ -191,6 +195,7 @@ class AuthenticatedResource(Resource): """ Inherited by all resources that need to be protected by authentication. """ + method_decorators = [login_required] def __init__(self): diff --git a/lemur/auth/views.py b/lemur/auth/views.py index 0c319b5b..e7f87356 100644 --- a/lemur/auth/views.py +++ b/lemur/auth/views.py @@ -24,11 +24,13 @@ from lemur.auth.service import create_token, fetch_token_header, get_rsa_public_ from lemur.auth import ldap -mod = Blueprint('auth', __name__) +mod = Blueprint("auth", __name__) api = Api(mod) -def exchange_for_access_token(code, redirect_uri, client_id, secret, access_token_url=None, verify_cert=True): +def exchange_for_access_token( + code, redirect_uri, client_id, secret, access_token_url=None, verify_cert=True +): """ Exchanges authorization code for access token. @@ -43,28 +45,32 @@ def exchange_for_access_token(code, redirect_uri, client_id, secret, access_toke """ # take the information we have received from the provider to create a new request params = { - 'grant_type': 'authorization_code', - 'scope': 'openid email profile address', - 'code': code, - 'redirect_uri': redirect_uri, - 'client_id': client_id + "grant_type": "authorization_code", + "scope": "openid email profile address", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, } # the secret and cliendId will be given to you when you signup for the provider - token = '{0}:{1}'.format(client_id, secret) + token = "{0}:{1}".format(client_id, secret) - basic = base64.b64encode(bytes(token, 'utf-8')) + basic = base64.b64encode(bytes(token, "utf-8")) headers = { - 'Content-Type': 'application/x-www-form-urlencoded', - 'authorization': 'basic {0}'.format(basic.decode('utf-8')) + "Content-Type": "application/x-www-form-urlencoded", + "authorization": "basic {0}".format(basic.decode("utf-8")), } # exchange authorization code for access token. - r = requests.post(access_token_url, headers=headers, params=params, verify=verify_cert) + r = requests.post( + access_token_url, headers=headers, params=params, verify=verify_cert + ) if r.status_code == 400: - r = requests.post(access_token_url, headers=headers, data=params, verify=verify_cert) - id_token = r.json()['id_token'] - access_token = r.json()['access_token'] + r = requests.post( + access_token_url, headers=headers, data=params, verify=verify_cert + ) + id_token = r.json()["id_token"] + access_token = r.json()["access_token"] return id_token, access_token @@ -83,23 +89,25 @@ def validate_id_token(id_token, client_id, jwks_url): # retrieve the key material as specified by the token header r = requests.get(jwks_url) - for key in r.json()['keys']: - if key['kid'] == header_data['kid']: - secret = get_rsa_public_key(key['n'], key['e']) - algo = header_data['alg'] + for key in r.json()["keys"]: + if key["kid"] == header_data["kid"]: + secret = get_rsa_public_key(key["n"], key["e"]) + algo = header_data["alg"] break else: - return dict(message='Key not found'), 401 + return dict(message="Key not found"), 401 # validate your token based on the key it was signed with try: - jwt.decode(id_token, secret.decode('utf-8'), algorithms=[algo], audience=client_id) + jwt.decode( + id_token, secret.decode("utf-8"), algorithms=[algo], audience=client_id + ) except jwt.DecodeError: - return dict(message='Token is invalid'), 401 + return dict(message="Token is invalid"), 401 except jwt.ExpiredSignatureError: - return dict(message='Token has expired'), 401 + return dict(message="Token has expired"), 401 except jwt.InvalidTokenError: - return dict(message='Token is invalid'), 401 + return dict(message="Token is invalid"), 401 def retrieve_user(user_api_url, access_token): @@ -110,22 +118,18 @@ def retrieve_user(user_api_url, access_token): :param access_token: :return: """ - user_params = dict(access_token=access_token, schema='profile') + user_params = dict(access_token=access_token, schema="profile") headers = {} - if current_app.config.get('PING_INCLUDE_BEARER_TOKEN'): - headers = {'Authorization': f'Bearer {access_token}'} + if current_app.config.get("PING_INCLUDE_BEARER_TOKEN"): + headers = {"Authorization": f"Bearer {access_token}"} # retrieve information about the current user. - r = requests.get( - user_api_url, - params=user_params, - headers=headers, - ) + r = requests.get(user_api_url, params=user_params, headers=headers) profile = r.json() - user = user_service.get_by_email(profile['email']) + user = user_service.get_by_email(profile["email"]) return user, profile @@ -138,31 +142,44 @@ def create_user_roles(profile): roles = [] # update their google 'roles' - if 'googleGroups' in profile: - for group in profile['googleGroups']: + if "googleGroups" in profile: + for group in profile["googleGroups"]: role = role_service.get_by_name(group) if not role: - role = role_service.create(group, description='This is a google group based role created by Lemur', third_party=True) + role = role_service.create( + group, + description="This is a google group based role created by Lemur", + third_party=True, + ) if not role.third_party: role = role_service.set_third_party(role.id, third_party_status=True) roles.append(role) else: - current_app.logger.warning("'googleGroups' not sent by identity provider, no specific roles will assigned to the user.") + current_app.logger.warning( + "'googleGroups' not sent by identity provider, no specific roles will assigned to the user." + ) - role = role_service.get_by_name(profile['email']) + role = role_service.get_by_name(profile["email"]) if not role: - role = role_service.create(profile['email'], description='This is a user specific role', third_party=True) + role = role_service.create( + profile["email"], + description="This is a user specific role", + third_party=True, + ) if not role.third_party: role = role_service.set_third_party(role.id, third_party_status=True) roles.append(role) # every user is an operator (tied to a default role) - if current_app.config.get('LEMUR_DEFAULT_ROLE'): - default = role_service.get_by_name(current_app.config['LEMUR_DEFAULT_ROLE']) + if current_app.config.get("LEMUR_DEFAULT_ROLE"): + default = role_service.get_by_name(current_app.config["LEMUR_DEFAULT_ROLE"]) if not default: - default = role_service.create(current_app.config['LEMUR_DEFAULT_ROLE'], description='This is the default Lemur role.') + default = role_service.create( + current_app.config["LEMUR_DEFAULT_ROLE"], + description="This is the default Lemur role.", + ) if not default.third_party: role_service.set_third_party(default.id, third_party_status=True) roles.append(default) @@ -181,12 +198,12 @@ def update_user(user, profile, roles): # if we get an sso user create them an account if not user: user = user_service.create( - profile['email'], + profile["email"], get_psuedo_random_string(), - profile['email'], + profile["email"], True, - profile.get('thumbnailPhotoUrl'), - roles + profile.get("thumbnailPhotoUrl"), + roles, ) else: @@ -198,11 +215,11 @@ def update_user(user, profile, roles): # update any changes to the user user_service.update( user.id, - profile['email'], - profile['email'], + profile["email"], + profile["email"], True, - profile.get('thumbnailPhotoUrl'), # profile isn't google+ enabled - roles + profile.get("thumbnailPhotoUrl"), # profile isn't google+ enabled + roles, ) @@ -223,6 +240,7 @@ class Login(Resource): on your uses cases but. It is important to not that there is currently no build in method to revoke a users token \ and force re-authentication. """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(Login, self).__init__() @@ -263,23 +281,26 @@ class Login(Resource): :statuscode 401: invalid credentials :statuscode 200: no error """ - self.reqparse.add_argument('username', type=str, required=True, location='json') - self.reqparse.add_argument('password', type=str, required=True, location='json') + self.reqparse.add_argument("username", type=str, required=True, location="json") + self.reqparse.add_argument("password", type=str, required=True, location="json") args = self.reqparse.parse_args() - if '@' in args['username']: - user = user_service.get_by_email(args['username']) + if "@" in args["username"]: + user = user_service.get_by_email(args["username"]) else: - user = user_service.get_by_username(args['username']) + user = user_service.get_by_username(args["username"]) # default to local authentication - if user and user.check_password(args['password']) and user.active: + if user and user.check_password(args["password"]) and user.active: # Tell Flask-Principal the identity changed - identity_changed.send(current_app._get_current_object(), - identity=Identity(user.id)) + identity_changed.send( + current_app._get_current_object(), identity=Identity(user.id) + ) - metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + metrics.send( + "login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS} + ) return dict(token=create_token(user)) # try ldap login @@ -289,19 +310,29 @@ class Login(Resource): user = ldap_principal.authenticate() if user and user.active: # Tell Flask-Principal the identity changed - identity_changed.send(current_app._get_current_object(), - identity=Identity(user.id)) - metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + identity_changed.send( + current_app._get_current_object(), identity=Identity(user.id) + ) + metrics.send( + "login", + "counter", + 1, + metric_tags={"status": SUCCESS_METRIC_STATUS}, + ) return dict(token=create_token(user)) except Exception as e: - current_app.logger.error("ldap error: {0}".format(e)) - ldap_message = 'ldap error: %s' % e - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - return dict(message=ldap_message), 403 + current_app.logger.error("ldap error: {0}".format(e)) + ldap_message = "ldap error: %s" % e + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) + return dict(message=ldap_message), 403 # if not valid user - no certificates for you - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - return dict(message='The supplied credentials are invalid'), 403 + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) + return dict(message="The supplied credentials are invalid"), 403 class Ping(Resource): @@ -314,36 +345,39 @@ class Ping(Resource): provider uses for its callbacks. 2. Add or change the Lemur AngularJS Configuration to point to your new provider """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(Ping, self).__init__() def get(self): - return 'Redirecting...' + return "Redirecting..." def post(self): - self.reqparse.add_argument('clientId', type=str, required=True, location='json') - self.reqparse.add_argument('redirectUri', type=str, required=True, location='json') - self.reqparse.add_argument('code', type=str, required=True, location='json') + self.reqparse.add_argument("clientId", type=str, required=True, location="json") + self.reqparse.add_argument( + "redirectUri", type=str, required=True, location="json" + ) + self.reqparse.add_argument("code", type=str, required=True, location="json") args = self.reqparse.parse_args() # you can either discover these dynamically or simply configure them - access_token_url = current_app.config.get('PING_ACCESS_TOKEN_URL') - user_api_url = current_app.config.get('PING_USER_API_URL') + access_token_url = current_app.config.get("PING_ACCESS_TOKEN_URL") + user_api_url = current_app.config.get("PING_USER_API_URL") - secret = current_app.config.get('PING_SECRET') + secret = current_app.config.get("PING_SECRET") id_token, access_token = exchange_for_access_token( - args['code'], - args['redirectUri'], - args['clientId'], + args["code"], + args["redirectUri"], + args["clientId"], secret, - access_token_url=access_token_url + access_token_url=access_token_url, ) - jwks_url = current_app.config.get('PING_JWKS_URL') - error_code = validate_id_token(id_token, args['clientId'], jwks_url) + jwks_url = current_app.config.get("PING_JWKS_URL") + error_code = validate_id_token(id_token, args["clientId"], jwks_url) if error_code: return error_code user, profile = retrieve_user(user_api_url, access_token) @@ -351,13 +385,19 @@ class Ping(Resource): update_user(user, profile, roles) if not user or not user.active: - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - return dict(message='The supplied credentials are invalid'), 403 + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) + return dict(message="The supplied credentials are invalid"), 403 # Tell Flask-Principal the identity changed - identity_changed.send(current_app._get_current_object(), identity=Identity(user.id)) + identity_changed.send( + current_app._get_current_object(), identity=Identity(user.id) + ) - metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + metrics.send( + "login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS} + ) return dict(token=create_token(user)) @@ -367,33 +407,35 @@ class OAuth2(Resource): super(OAuth2, self).__init__() def get(self): - return 'Redirecting...' + return "Redirecting..." def post(self): - self.reqparse.add_argument('clientId', type=str, required=True, location='json') - self.reqparse.add_argument('redirectUri', type=str, required=True, location='json') - self.reqparse.add_argument('code', type=str, required=True, location='json') + self.reqparse.add_argument("clientId", type=str, required=True, location="json") + self.reqparse.add_argument( + "redirectUri", type=str, required=True, location="json" + ) + self.reqparse.add_argument("code", type=str, required=True, location="json") args = self.reqparse.parse_args() # you can either discover these dynamically or simply configure them - access_token_url = current_app.config.get('OAUTH2_ACCESS_TOKEN_URL') - user_api_url = current_app.config.get('OAUTH2_USER_API_URL') - verify_cert = current_app.config.get('OAUTH2_VERIFY_CERT') + access_token_url = current_app.config.get("OAUTH2_ACCESS_TOKEN_URL") + user_api_url = current_app.config.get("OAUTH2_USER_API_URL") + verify_cert = current_app.config.get("OAUTH2_VERIFY_CERT") - secret = current_app.config.get('OAUTH2_SECRET') + secret = current_app.config.get("OAUTH2_SECRET") id_token, access_token = exchange_for_access_token( - args['code'], - args['redirectUri'], - args['clientId'], + args["code"], + args["redirectUri"], + args["clientId"], secret, access_token_url=access_token_url, - verify_cert=verify_cert + verify_cert=verify_cert, ) - jwks_url = current_app.config.get('PING_JWKS_URL') - error_code = validate_id_token(id_token, args['clientId'], jwks_url) + jwks_url = current_app.config.get("PING_JWKS_URL") + error_code = validate_id_token(id_token, args["clientId"], jwks_url) if error_code: return error_code @@ -402,13 +444,19 @@ class OAuth2(Resource): update_user(user, profile, roles) if not user.active: - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - return dict(message='The supplied credentials are invalid'), 403 + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) + return dict(message="The supplied credentials are invalid"), 403 # Tell Flask-Principal the identity changed - identity_changed.send(current_app._get_current_object(), identity=Identity(user.id)) + identity_changed.send( + current_app._get_current_object(), identity=Identity(user.id) + ) - metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + metrics.send( + "login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS} + ) return dict(token=create_token(user)) @@ -419,44 +467,52 @@ class Google(Resource): super(Google, self).__init__() def post(self): - access_token_url = 'https://accounts.google.com/o/oauth2/token' - people_api_url = 'https://www.googleapis.com/plus/v1/people/me/openIdConnect' + access_token_url = "https://accounts.google.com/o/oauth2/token" + people_api_url = "https://www.googleapis.com/plus/v1/people/me/openIdConnect" - self.reqparse.add_argument('clientId', type=str, required=True, location='json') - self.reqparse.add_argument('redirectUri', type=str, required=True, location='json') - self.reqparse.add_argument('code', type=str, required=True, location='json') + self.reqparse.add_argument("clientId", type=str, required=True, location="json") + self.reqparse.add_argument( + "redirectUri", type=str, required=True, location="json" + ) + self.reqparse.add_argument("code", type=str, required=True, location="json") args = self.reqparse.parse_args() # Step 1. Exchange authorization code for access token payload = { - 'client_id': args['clientId'], - 'grant_type': 'authorization_code', - 'redirect_uri': args['redirectUri'], - 'code': args['code'], - 'client_secret': current_app.config.get('GOOGLE_SECRET') + "client_id": args["clientId"], + "grant_type": "authorization_code", + "redirect_uri": args["redirectUri"], + "code": args["code"], + "client_secret": current_app.config.get("GOOGLE_SECRET"), } r = requests.post(access_token_url, data=payload) token = r.json() # Step 2. Retrieve information about the current user - headers = {'Authorization': 'Bearer {0}'.format(token['access_token'])} + headers = {"Authorization": "Bearer {0}".format(token["access_token"])} r = requests.get(people_api_url, headers=headers) profile = r.json() - user = user_service.get_by_email(profile['email']) + user = user_service.get_by_email(profile["email"]) if not (user and user.active): - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - return dict(message='The supplied credentials are invalid.'), 403 + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) + return dict(message="The supplied credentials are invalid."), 403 if user: - metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + metrics.send( + "login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS} + ) return dict(token=create_token(user)) - metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) + metrics.send( + "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS} + ) class Providers(Resource): @@ -467,47 +523,57 @@ class Providers(Resource): provider = provider.lower() if provider == "google": - active_providers.append({ - 'name': 'google', - 'clientId': current_app.config.get("GOOGLE_CLIENT_ID"), - 'url': api.url_for(Google) - }) + active_providers.append( + { + "name": "google", + "clientId": current_app.config.get("GOOGLE_CLIENT_ID"), + "url": api.url_for(Google), + } + ) elif provider == "ping": - active_providers.append({ - 'name': current_app.config.get("PING_NAME"), - 'url': current_app.config.get('PING_REDIRECT_URI'), - 'redirectUri': current_app.config.get("PING_REDIRECT_URI"), - 'clientId': current_app.config.get("PING_CLIENT_ID"), - 'responseType': 'code', - 'scope': ['openid', 'email', 'profile', 'address'], - 'scopeDelimiter': ' ', - 'authorizationEndpoint': current_app.config.get("PING_AUTH_ENDPOINT"), - 'requiredUrlParams': ['scope'], - 'type': '2.0' - }) + active_providers.append( + { + "name": current_app.config.get("PING_NAME"), + "url": current_app.config.get("PING_REDIRECT_URI"), + "redirectUri": current_app.config.get("PING_REDIRECT_URI"), + "clientId": current_app.config.get("PING_CLIENT_ID"), + "responseType": "code", + "scope": ["openid", "email", "profile", "address"], + "scopeDelimiter": " ", + "authorizationEndpoint": current_app.config.get( + "PING_AUTH_ENDPOINT" + ), + "requiredUrlParams": ["scope"], + "type": "2.0", + } + ) elif provider == "oauth2": - active_providers.append({ - 'name': current_app.config.get("OAUTH2_NAME"), - 'url': current_app.config.get('OAUTH2_REDIRECT_URI'), - 'redirectUri': current_app.config.get("OAUTH2_REDIRECT_URI"), - 'clientId': current_app.config.get("OAUTH2_CLIENT_ID"), - 'responseType': 'code', - 'scope': ['openid', 'email', 'profile', 'groups'], - 'scopeDelimiter': ' ', - 'authorizationEndpoint': current_app.config.get("OAUTH2_AUTH_ENDPOINT"), - 'requiredUrlParams': ['scope', 'state', 'nonce'], - 'state': 'STATE', - 'nonce': get_psuedo_random_string(), - 'type': '2.0' - }) + active_providers.append( + { + "name": current_app.config.get("OAUTH2_NAME"), + "url": current_app.config.get("OAUTH2_REDIRECT_URI"), + "redirectUri": current_app.config.get("OAUTH2_REDIRECT_URI"), + "clientId": current_app.config.get("OAUTH2_CLIENT_ID"), + "responseType": "code", + "scope": ["openid", "email", "profile", "groups"], + "scopeDelimiter": " ", + "authorizationEndpoint": current_app.config.get( + "OAUTH2_AUTH_ENDPOINT" + ), + "requiredUrlParams": ["scope", "state", "nonce"], + "state": "STATE", + "nonce": get_psuedo_random_string(), + "type": "2.0", + } + ) return active_providers -api.add_resource(Login, '/auth/login', endpoint='login') -api.add_resource(Ping, '/auth/ping', endpoint='ping') -api.add_resource(Google, '/auth/google', endpoint='google') -api.add_resource(OAuth2, '/auth/oauth2', endpoint='oauth2') -api.add_resource(Providers, '/auth/providers', endpoint='providers') +api.add_resource(Login, "/auth/login", endpoint="login") +api.add_resource(Ping, "/auth/ping", endpoint="ping") +api.add_resource(Google, "/auth/google", endpoint="google") +api.add_resource(OAuth2, "/auth/oauth2", endpoint="oauth2") +api.add_resource(Providers, "/auth/providers", endpoint="providers") diff --git a/lemur/authorities/models.py b/lemur/authorities/models.py index 6c5f790b..ccd1fab8 100644 --- a/lemur/authorities/models.py +++ b/lemur/authorities/models.py @@ -7,7 +7,17 @@ .. moduleauthor:: Kevin Glisson """ from sqlalchemy.orm import relationship -from sqlalchemy import Column, Integer, String, Text, func, ForeignKey, DateTime, PassiveDefault, Boolean +from sqlalchemy import ( + Column, + Integer, + String, + Text, + func, + ForeignKey, + DateTime, + PassiveDefault, + Boolean, +) from sqlalchemy.dialects.postgresql import JSON from lemur.database import db @@ -16,7 +26,7 @@ from lemur.models import roles_authorities class Authority(db.Model): - __tablename__ = 'authorities' + __tablename__ = "authorities" id = Column(Integer, primary_key=True) owner = Column(String(128), nullable=False) name = Column(String(128), unique=True) @@ -27,22 +37,44 @@ class Authority(db.Model): description = Column(Text) options = Column(JSON) date_created = Column(DateTime, PassiveDefault(func.now()), nullable=False) - roles = relationship('Role', secondary=roles_authorities, passive_deletes=True, backref=db.backref('authority'), lazy='dynamic') - user_id = Column(Integer, ForeignKey('users.id')) - authority_certificate = relationship("Certificate", backref='root_authority', uselist=False, foreign_keys='Certificate.root_authority_id') - certificates = relationship("Certificate", backref='authority', foreign_keys='Certificate.authority_id') + roles = relationship( + "Role", + secondary=roles_authorities, + passive_deletes=True, + backref=db.backref("authority"), + lazy="dynamic", + ) + user_id = Column(Integer, ForeignKey("users.id")) + authority_certificate = relationship( + "Certificate", + backref="root_authority", + uselist=False, + foreign_keys="Certificate.root_authority_id", + ) + certificates = relationship( + "Certificate", backref="authority", foreign_keys="Certificate.authority_id" + ) - authority_pending_certificate = relationship("PendingCertificate", backref='root_authority', uselist=False, foreign_keys='PendingCertificate.root_authority_id') - pending_certificates = relationship('PendingCertificate', backref='authority', foreign_keys='PendingCertificate.authority_id') + authority_pending_certificate = relationship( + "PendingCertificate", + backref="root_authority", + uselist=False, + foreign_keys="PendingCertificate.root_authority_id", + ) + pending_certificates = relationship( + "PendingCertificate", + backref="authority", + foreign_keys="PendingCertificate.authority_id", + ) def __init__(self, **kwargs): - self.owner = kwargs['owner'] - self.roles = kwargs.get('roles', []) - self.name = kwargs.get('name') - self.description = kwargs.get('description') - self.authority_certificate = kwargs['authority_certificate'] - self.plugin_name = kwargs['plugin']['slug'] - self.options = kwargs.get('options') + self.owner = kwargs["owner"] + self.roles = kwargs.get("roles", []) + self.name = kwargs.get("name") + self.description = kwargs.get("description") + self.authority_certificate = kwargs["authority_certificate"] + self.plugin_name = kwargs["plugin"]["slug"] + self.options = kwargs.get("options") @property def plugin(self): diff --git a/lemur/authorities/schemas.py b/lemur/authorities/schemas.py index d1f0adfc..c78aec94 100644 --- a/lemur/authorities/schemas.py +++ b/lemur/authorities/schemas.py @@ -11,7 +11,13 @@ from marshmallow import fields, validates_schema, pre_load from marshmallow import validate from marshmallow.exceptions import ValidationError -from lemur.schemas import PluginInputSchema, PluginOutputSchema, ExtensionSchema, AssociatedAuthoritySchema, AssociatedRoleSchema +from lemur.schemas import ( + PluginInputSchema, + PluginOutputSchema, + ExtensionSchema, + AssociatedAuthoritySchema, + AssociatedRoleSchema, +) from lemur.users.schemas import UserNestedOutputSchema from lemur.common.schema import LemurInputSchema, LemurOutputSchema from lemur.common import validators, missing @@ -30,21 +36,36 @@ class AuthorityInputSchema(LemurInputSchema): validity_years = fields.Integer() # certificate body fields - organizational_unit = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT')) - organization = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATION')) - location = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_LOCATION')) - country = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_COUNTRY')) - state = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_STATE')) + organizational_unit = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATIONAL_UNIT") + ) + organization = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATION") + ) + location = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_LOCATION") + ) + country = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_COUNTRY") + ) + state = fields.String(missing=lambda: current_app.config.get("LEMUR_DEFAULT_STATE")) plugin = fields.Nested(PluginInputSchema) # signing related options - type = fields.String(validate=validate.OneOf(['root', 'subca']), missing='root') + type = fields.String(validate=validate.OneOf(["root", "subca"]), missing="root") parent = fields.Nested(AssociatedAuthoritySchema) - signing_algorithm = fields.String(validate=validate.OneOf(['sha256WithRSA', 'sha1WithRSA']), missing='sha256WithRSA') - key_type = fields.String(validate=validate.OneOf(['RSA2048', 'RSA4096']), missing='RSA2048') + signing_algorithm = fields.String( + validate=validate.OneOf(["sha256WithRSA", "sha1WithRSA"]), + missing="sha256WithRSA", + ) + key_type = fields.String( + validate=validate.OneOf(["RSA2048", "RSA4096"]), missing="RSA2048" + ) key_name = fields.String() - sensitivity = fields.String(validate=validate.OneOf(['medium', 'high']), missing='medium') + sensitivity = fields.String( + validate=validate.OneOf(["medium", "high"]), missing="medium" + ) serial_number = fields.Integer() first_serial = fields.Integer(missing=1) @@ -58,9 +79,11 @@ class AuthorityInputSchema(LemurInputSchema): @validates_schema def validate_subca(self, data): - if data['type'] == 'subca': - if not data.get('parent'): - raise ValidationError("If generating a subca, parent 'authority' must be specified.") + if data["type"] == "subca": + if not data.get("parent"): + raise ValidationError( + "If generating a subca, parent 'authority' must be specified." + ) @pre_load def ensure_dates(self, data): diff --git a/lemur/authorities/service.py b/lemur/authorities/service.py index 41c381e3..c70c6fc5 100644 --- a/lemur/authorities/service.py +++ b/lemur/authorities/service.py @@ -43,7 +43,7 @@ def mint(**kwargs): """ Creates the authority based on the plugin provided. """ - issuer = kwargs['plugin']['plugin_object'] + issuer = kwargs["plugin"]["plugin_object"] values = issuer.create_authority(kwargs) # support older plugins @@ -53,7 +53,12 @@ def mint(**kwargs): elif len(values) == 4: body, private_key, chain, roles = values - roles = create_authority_roles(roles, kwargs['owner'], kwargs['plugin']['plugin_object'].title, kwargs['creator']) + roles = create_authority_roles( + roles, + kwargs["owner"], + kwargs["plugin"]["plugin_object"].title, + kwargs["creator"], + ) return body, private_key, chain, roles @@ -66,16 +71,17 @@ def create_authority_roles(roles, owner, plugin_title, creator): """ role_objs = [] for r in roles: - role = role_service.get_by_name(r['name']) + role = role_service.get_by_name(r["name"]) if not role: role = role_service.create( - r['name'], - password=r['password'], + r["name"], + password=r["password"], description="Auto generated role for {0}".format(plugin_title), - username=r['username']) + username=r["username"], + ) # the user creating the authority should be able to administer it - if role.username == 'admin': + if role.username == "admin": creator.roles.append(role) role_objs.append(role) @@ -84,8 +90,7 @@ def create_authority_roles(roles, owner, plugin_title, creator): owner_role = role_service.get_by_name(owner) if not owner_role: owner_role = role_service.create( - owner, - description="Auto generated role based on owner: {0}".format(owner) + owner, description="Auto generated role based on owner: {0}".format(owner) ) role_objs.append(owner_role) @@ -98,27 +103,29 @@ def create(**kwargs): """ body, private_key, chain, roles = mint(**kwargs) - kwargs['creator'].roles = list(set(list(kwargs['creator'].roles) + roles)) + kwargs["creator"].roles = list(set(list(kwargs["creator"].roles) + roles)) - kwargs['body'] = body - kwargs['private_key'] = private_key - kwargs['chain'] = chain + kwargs["body"] = body + kwargs["private_key"] = private_key + kwargs["chain"] = chain - if kwargs.get('roles'): - kwargs['roles'] += roles + if kwargs.get("roles"): + kwargs["roles"] += roles else: - kwargs['roles'] = roles + kwargs["roles"] = roles cert = upload(**kwargs) - kwargs['authority_certificate'] = cert - if kwargs.get('plugin', {}).get('plugin_options', []): - kwargs['options'] = json.dumps(kwargs['plugin']['plugin_options']) + kwargs["authority_certificate"] = cert + if kwargs.get("plugin", {}).get("plugin_options", []): + kwargs["options"] = json.dumps(kwargs["plugin"]["plugin_options"]) authority = Authority(**kwargs) authority = database.create(authority) - kwargs['creator'].authorities.append(authority) + kwargs["creator"].authorities.append(authority) - metrics.send('authority_created', 'counter', 1, metric_tags=dict(owner=authority.owner)) + metrics.send( + "authority_created", "counter", 1, metric_tags=dict(owner=authority.owner) + ) return authority @@ -150,7 +157,7 @@ def get_by_name(authority_name): :param authority_name: :return: """ - return database.get(Authority, authority_name, field='name') + return database.get(Authority, authority_name, field="name") def get_authority_role(ca_name, creator=None): @@ -173,29 +180,31 @@ def render(args): :return: """ query = database.session_query(Authority) - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') - if 'active' in filt: + terms = filt.split(";") + if "active" in filt: query = query.filter(Authority.active == truthiness(terms[1])) - elif 'cn' in filt: - term = '%{0}%'.format(terms[1]) - sub_query = database.session_query(Certificate.root_authority_id) \ - .filter(Certificate.cn.ilike(term)) \ + elif "cn" in filt: + term = "%{0}%".format(terms[1]) + sub_query = ( + database.session_query(Certificate.root_authority_id) + .filter(Certificate.cn.ilike(term)) .subquery() + ) query = query.filter(Authority.id.in_(sub_query)) else: query = database.filter(query, Authority, terms) # we make sure that a user can only use an authority they either own are a member of - admins can see all - if not args['user'].is_admin: + if not args["user"].is_admin: authority_ids = [] - for authority in args['user'].authorities: + for authority in args["user"].authorities: authority_ids.append(authority.id) - for role in args['user'].roles: + for role in args["user"].roles: for authority in role.authorities: authority_ids.append(authority.id) query = query.filter(Authority.id.in_(authority_ids)) diff --git a/lemur/authorities/views.py b/lemur/authorities/views.py index b85c9b70..49bce63e 100644 --- a/lemur/authorities/views.py +++ b/lemur/authorities/views.py @@ -16,15 +16,21 @@ from lemur.auth.permissions import AuthorityPermission from lemur.certificates import service as certificate_service from lemur.authorities import service -from lemur.authorities.schemas import authority_input_schema, authority_output_schema, authorities_output_schema, authority_update_schema +from lemur.authorities.schemas import ( + authority_input_schema, + authority_output_schema, + authorities_output_schema, + authority_update_schema, +) -mod = Blueprint('authorities', __name__) +mod = Blueprint("authorities", __name__) api = Api(mod) class AuthoritiesList(AuthenticatedResource): """ Defines the 'authorities' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(AuthoritiesList, self).__init__() @@ -107,7 +113,7 @@ class AuthoritiesList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['user'] = g.current_user + args["user"] = g.current_user return service.render(args) @validate_schema(authority_input_schema, authority_output_schema) @@ -220,7 +226,7 @@ class AuthoritiesList(AuthenticatedResource): :statuscode 403: unauthenticated :statuscode 200: no error """ - data['creator'] = g.current_user + data["creator"] = g.current_user return service.create(**data) @@ -388,7 +394,7 @@ class Authorities(AuthenticatedResource): authority = service.get(authority_id) if not authority: - return dict(message='Not Found'), 404 + return dict(message="Not Found"), 404 # all the authority role members should be allowed roles = [x.name for x in authority.roles] @@ -397,10 +403,10 @@ class Authorities(AuthenticatedResource): if permission.can(): return service.update( authority_id, - owner=data['owner'], - description=data['description'], - active=data['active'], - roles=data['roles'] + owner=data["owner"], + description=data["description"], + active=data["active"], + roles=data["roles"], ) return dict(message="You are not authorized to update this authority."), 403 @@ -505,10 +511,21 @@ class AuthorityVisualizations(AuthenticatedResource): ]} """ authority = service.get(authority_id) - return dict(name=authority.name, children=[{"name": c.name} for c in authority.certificates]) + return dict( + name=authority.name, + children=[{"name": c.name} for c in authority.certificates], + ) -api.add_resource(AuthoritiesList, '/authorities', endpoint='authorities') -api.add_resource(Authorities, '/authorities/', endpoint='authority') -api.add_resource(AuthorityVisualizations, '/authorities//visualize', endpoint='authority_visualizations') -api.add_resource(CertificateAuthority, '/certificates//authority', endpoint='certificateAuthority') +api.add_resource(AuthoritiesList, "/authorities", endpoint="authorities") +api.add_resource(Authorities, "/authorities/", endpoint="authority") +api.add_resource( + AuthorityVisualizations, + "/authorities//visualize", + endpoint="authority_visualizations", +) +api.add_resource( + CertificateAuthority, + "/certificates//authority", + endpoint="certificateAuthority", +) diff --git a/lemur/authorizations/models.py b/lemur/authorizations/models.py index d30de7ed..04ac0508 100644 --- a/lemur/authorizations/models.py +++ b/lemur/authorizations/models.py @@ -13,7 +13,7 @@ from lemur.plugins.base import plugins class Authorization(db.Model): - __tablename__ = 'pending_dns_authorizations' + __tablename__ = "pending_dns_authorizations" id = Column(Integer, primary_key=True, autoincrement=True) account_number = Column(String(128)) domains = Column(JSONType) diff --git a/lemur/certificates/cli.py b/lemur/certificates/cli.py index 04b8ec9a..b57ff175 100644 --- a/lemur/certificates/cli.py +++ b/lemur/certificates/cli.py @@ -34,7 +34,7 @@ from lemur.certificates.service import ( get_all_pending_reissue, get_by_name, get_all_certs, - get + get, ) from lemur.certificates.verify import verify_string @@ -56,11 +56,14 @@ def print_certificate_details(details): "\t[+] Authority: {authority_name}\n" "\t[+] Validity Start: {validity_start}\n" "\t[+] Validity End: {validity_end}\n".format( - common_name=details['commonName'], - sans=",".join(x['value'] for x in details['extensions']['subAltNames']['names']) or None, - authority_name=details['authority']['name'], - validity_start=details['validityStart'], - validity_end=details['validityEnd'] + common_name=details["commonName"], + sans=",".join( + x["value"] for x in details["extensions"]["subAltNames"]["names"] + ) + or None, + authority_name=details["authority"]["name"], + validity_start=details["validityStart"], + validity_end=details["validityEnd"], ) ) @@ -120,13 +123,11 @@ def request_rotation(endpoint, certificate, message, commit): except Exception as e: print( "[!] Failed to rotate endpoint {0} to certificate {1} reason: {2}".format( - endpoint.name, - certificate.name, - e + endpoint.name, certificate.name, e ) ) - metrics.send('endpoint_rotation', 'counter', 1, metric_tags={'status': status}) + metrics.send("endpoint_rotation", "counter", 1, metric_tags={"status": status}) def request_reissue(certificate, commit): @@ -154,17 +155,52 @@ def request_reissue(certificate, commit): except Exception as e: sentry.captureException(extra={"certificate_name": str(certificate.name)}) - current_app.logger.exception(f"Error reissuing certificate: {certificate.name}", exc_info=True) + current_app.logger.exception( + f"Error reissuing certificate: {certificate.name}", exc_info=True + ) print(f"[!] Failed to reissue certificate: {certificate.name}. Reason: {e}") - metrics.send('certificate_reissue', 'counter', 1, metric_tags={'status': status, 'certificate': certificate.name}) + metrics.send( + "certificate_reissue", + "counter", + 1, + metric_tags={"status": status, "certificate": certificate.name}, + ) -@manager.option('-e', '--endpoint', dest='endpoint_name', help='Name of the endpoint you wish to rotate.') -@manager.option('-n', '--new-certificate', dest='new_certificate_name', help='Name of the certificate you wish to rotate to.') -@manager.option('-o', '--old-certificate', dest='old_certificate_name', help='Name of the certificate you wish to rotate.') -@manager.option('-a', '--notify', dest='message', action='store_true', help='Send a rotation notification to the certificates owner.') -@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') +@manager.option( + "-e", + "--endpoint", + dest="endpoint_name", + help="Name of the endpoint you wish to rotate.", +) +@manager.option( + "-n", + "--new-certificate", + dest="new_certificate_name", + help="Name of the certificate you wish to rotate to.", +) +@manager.option( + "-o", + "--old-certificate", + dest="old_certificate_name", + help="Name of the certificate you wish to rotate.", +) +@manager.option( + "-a", + "--notify", + dest="message", + action="store_true", + help="Send a rotation notification to the certificates owner.", +) +@manager.option( + "-c", + "--commit", + dest="commit", + action="store_true", + default=False, + help="Persist changes.", +) def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, commit): """ Rotates an endpoint and reissues it if it has not already been replaced. If it has @@ -183,7 +219,9 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c endpoint = validate_endpoint(endpoint_name) if endpoint and new_cert: - print(f"[+] Rotating endpoint: {endpoint.name} to certificate {new_cert.name}") + print( + f"[+] Rotating endpoint: {endpoint.name} to certificate {new_cert.name}" + ) request_rotation(endpoint, new_cert, message, commit) elif old_cert and new_cert: @@ -197,16 +235,27 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c print("[+] Rotating all endpoints that have new certificates available") for endpoint in endpoint_service.get_all_pending_rotation(): if len(endpoint.certificate.replaced) == 1: - print(f"[+] Rotating {endpoint.name} to {endpoint.certificate.replaced[0].name}") - request_rotation(endpoint, endpoint.certificate.replaced[0], message, commit) + print( + f"[+] Rotating {endpoint.name} to {endpoint.certificate.replaced[0].name}" + ) + request_rotation( + endpoint, endpoint.certificate.replaced[0], message, commit + ) else: - metrics.send('endpoint_rotation', 'counter', 1, metric_tags={ - 'status': FAILURE_METRIC_STATUS, - "old_certificate_name": str(old_cert), - "new_certificate_name": str(endpoint.certificate.replaced[0].name), - "endpoint_name": str(endpoint.name), - "message": str(message), - }) + metrics.send( + "endpoint_rotation", + "counter", + 1, + metric_tags={ + "status": FAILURE_METRIC_STATUS, + "old_certificate_name": str(old_cert), + "new_certificate_name": str( + endpoint.certificate.replaced[0].name + ), + "endpoint_name": str(endpoint.name), + "message": str(message), + }, + ) print( f"[!] Failed to rotate endpoint {endpoint.name} reason: " "Multiple replacement certificates found." @@ -222,20 +271,38 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c "new_certificate_name": str(new_certificate_name), "endpoint_name": str(endpoint_name), "message": str(message), - }) + } + ) - metrics.send('endpoint_rotation_job', 'counter', 1, metric_tags={ - "status": status, - "old_certificate_name": str(old_certificate_name), - "new_certificate_name": str(new_certificate_name), - "endpoint_name": str(endpoint_name), - "message": str(message), - "endpoint": str(globals().get("endpoint")) - }) + metrics.send( + "endpoint_rotation_job", + "counter", + 1, + metric_tags={ + "status": status, + "old_certificate_name": str(old_certificate_name), + "new_certificate_name": str(new_certificate_name), + "endpoint_name": str(endpoint_name), + "message": str(message), + "endpoint": str(globals().get("endpoint")), + }, + ) -@manager.option('-o', '--old-certificate', dest='old_certificate_name', help='Name of the certificate you wish to reissue.') -@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') +@manager.option( + "-o", + "--old-certificate", + dest="old_certificate_name", + help="Name of the certificate you wish to reissue.", +) +@manager.option( + "-c", + "--commit", + dest="commit", + action="store_true", + default=False, + help="Persist changes.", +) def reissue(old_certificate_name, commit): """ Reissues certificate with the same parameters as it was originally issued with. @@ -263,76 +330,94 @@ def reissue(old_certificate_name, commit): except Exception as e: sentry.captureException() current_app.logger.exception("Error reissuing certificate.", exc_info=True) - print( - "[!] Failed to reissue certificates. Reason: {}".format( - e - ) - ) + print("[!] Failed to reissue certificates. Reason: {}".format(e)) - metrics.send('certificate_reissue_job', 'counter', 1, metric_tags={'status': status}) + metrics.send( + "certificate_reissue_job", "counter", 1, metric_tags={"status": status} + ) -@manager.option('-f', '--fqdns', dest='fqdns', help='FQDNs to query. Multiple fqdns specified via comma.') -@manager.option('-i', '--issuer', dest='issuer', help='Issuer to query for.') -@manager.option('-o', '--owner', dest='owner', help='Owner to query for.') -@manager.option('-e', '--expired', dest='expired', type=bool, default=False, help='Include expired certificates.') +@manager.option( + "-f", + "--fqdns", + dest="fqdns", + help="FQDNs to query. Multiple fqdns specified via comma.", +) +@manager.option("-i", "--issuer", dest="issuer", help="Issuer to query for.") +@manager.option("-o", "--owner", dest="owner", help="Owner to query for.") +@manager.option( + "-e", + "--expired", + dest="expired", + type=bool, + default=False, + help="Include expired certificates.", +) def query(fqdns, issuer, owner, expired): """Prints certificates that match the query params.""" table = [] q = database.session_query(Certificate) if issuer: - sub_query = database.session_query(Authority.id) \ - .filter(Authority.name.ilike('%{0}%'.format(issuer))) \ + sub_query = ( + database.session_query(Authority.id) + .filter(Authority.name.ilike("%{0}%".format(issuer))) .subquery() + ) q = q.filter( or_( - Certificate.issuer.ilike('%{0}%'.format(issuer)), - Certificate.authority_id.in_(sub_query) + Certificate.issuer.ilike("%{0}%".format(issuer)), + Certificate.authority_id.in_(sub_query), ) ) if owner: - q = q.filter(Certificate.owner.ilike('%{0}%'.format(owner))) + q = q.filter(Certificate.owner.ilike("%{0}%".format(owner))) if not expired: q = q.filter(Certificate.expired == False) # noqa if fqdns: - for f in fqdns.split(','): + for f in fqdns.split(","): q = q.filter( or_( - Certificate.cn.ilike('%{0}%'.format(f)), - Certificate.domains.any(Domain.name.ilike('%{0}%'.format(f))) + Certificate.cn.ilike("%{0}%".format(f)), + Certificate.domains.any(Domain.name.ilike("%{0}%".format(f))), ) ) for c in q.all(): table.append([c.id, c.name, c.owner, c.issuer]) - print(tabulate(table, headers=['Id', 'Name', 'Owner', 'Issuer'], tablefmt='csv')) + print(tabulate(table, headers=["Id", "Name", "Owner", "Issuer"], tablefmt="csv")) def worker(data, commit, reason): - parts = [x for x in data.split(' ') if x] + parts = [x for x in data.split(" ") if x] try: cert = get(int(parts[0].strip())) plugin = plugins.get(cert.authority.plugin_name) - print('[+] Revoking certificate. Id: {0} Name: {1}'.format(cert.id, cert.name)) + print("[+] Revoking certificate. Id: {0} Name: {1}".format(cert.id, cert.name)) if commit: plugin.revoke_certificate(cert, reason) - metrics.send('certificate_revoke', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) + metrics.send( + "certificate_revoke", + "counter", + 1, + metric_tags={"status": SUCCESS_METRIC_STATUS}, + ) except Exception as e: sentry.captureException() - metrics.send('certificate_revoke', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) - print( - "[!] Failed to revoke certificates. Reason: {}".format( - e - ) + metrics.send( + "certificate_revoke", + "counter", + 1, + metric_tags={"status": FAILURE_METRIC_STATUS}, ) + print("[!] Failed to revoke certificates. Reason: {}".format(e)) @manager.command @@ -341,13 +426,22 @@ def clear_pending(): Function clears all pending certificates. :return: """ - v = plugins.get('verisign-issuer') + v = plugins.get("verisign-issuer") v.clear_pending_certificates() -@manager.option('-p', '--path', dest='path', help='Absolute file path to a Lemur query csv.') -@manager.option('-r', '--reason', dest='reason', help='Reason to revoke certificate.') -@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') +@manager.option( + "-p", "--path", dest="path", help="Absolute file path to a Lemur query csv." +) +@manager.option("-r", "--reason", dest="reason", help="Reason to revoke certificate.") +@manager.option( + "-c", + "--commit", + dest="commit", + action="store_true", + default=False, + help="Persist changes.", +) def revoke(path, reason, commit): """ Revokes given certificate. @@ -357,7 +451,7 @@ def revoke(path, reason, commit): print("[+] Starting certificate revocation.") - with open(path, 'r') as f: + with open(path, "r") as f: args = [[x, commit, reason] for x in f.readlines()[2:]] with multiprocessing.Pool(processes=3) as pool: @@ -380,11 +474,11 @@ def check_revoked(): else: status = verify_string(cert.body, "") - cert.status = 'valid' if status else 'revoked' + cert.status = "valid" if status else "revoked" except Exception as e: sentry.captureException() current_app.logger.exception(e) - cert.status = 'unknown' + cert.status = "unknown" database.update(cert) diff --git a/lemur/certificates/hooks.py b/lemur/certificates/hooks.py index 16f6c3b0..93409bb4 100644 --- a/lemur/certificates/hooks.py +++ b/lemur/certificates/hooks.py @@ -12,21 +12,30 @@ import subprocess from flask import current_app -from lemur.certificates.service import csr_created, csr_imported, certificate_issued, certificate_imported +from lemur.certificates.service import ( + csr_created, + csr_imported, + certificate_issued, + certificate_imported, +) def csr_dump_handler(sender, csr, **kwargs): try: - subprocess.run(['openssl', 'req', '-text', '-noout', '-reqopt', 'no_sigdump,no_pubkey'], - input=csr.encode('utf8')) + subprocess.run( + ["openssl", "req", "-text", "-noout", "-reqopt", "no_sigdump,no_pubkey"], + input=csr.encode("utf8"), + ) except Exception as err: current_app.logger.warning("Error inspecting CSR: %s", err) def cert_dump_handler(sender, certificate, **kwargs): try: - subprocess.run(['openssl', 'x509', '-text', '-noout', '-certopt', 'no_sigdump,no_pubkey'], - input=certificate.body.encode('utf8')) + subprocess.run( + ["openssl", "x509", "-text", "-noout", "-certopt", "no_sigdump,no_pubkey"], + input=certificate.body.encode("utf8"), + ) except Exception as err: current_app.logger.warning("Error inspecting certificate: %s", err) diff --git a/lemur/certificates/models.py b/lemur/certificates/models.py index bd6e8b5e..965f79d1 100644 --- a/lemur/certificates/models.py +++ b/lemur/certificates/models.py @@ -12,7 +12,18 @@ from cryptography import x509 from cryptography.hazmat.primitives.asymmetric import rsa from flask import current_app from idna.core import InvalidCodepoint -from sqlalchemy import event, Integer, ForeignKey, String, PassiveDefault, func, Column, Text, Boolean, Index +from sqlalchemy import ( + event, + Integer, + ForeignKey, + String, + PassiveDefault, + func, + Column, + Text, + Boolean, + Index, +) from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import relationship from sqlalchemy.sql.expression import case, extract @@ -25,19 +36,25 @@ from lemur.database import db from lemur.domains.models import Domain from lemur.extensions import metrics from lemur.extensions import sentry -from lemur.models import certificate_associations, certificate_source_associations, \ - certificate_destination_associations, certificate_notification_associations, \ - certificate_replacement_associations, roles_certificates, pending_cert_replacement_associations +from lemur.models import ( + certificate_associations, + certificate_source_associations, + certificate_destination_associations, + certificate_notification_associations, + certificate_replacement_associations, + roles_certificates, + pending_cert_replacement_associations, +) from lemur.plugins.base import plugins from lemur.policies.models import RotationPolicy from lemur.utils import Vault def get_sequence(name): - if '-' not in name: + if "-" not in name: return name, None - parts = name.split('-') + parts = name.split("-") # see if we have an int at the end of our name try: @@ -49,18 +66,22 @@ def get_sequence(name): if len(parts[-1]) == 8: return name, None - root = '-'.join(parts[:-1]) + root = "-".join(parts[:-1]) return root, seq def get_or_increase_name(name, serial): - certificates = Certificate.query.filter(Certificate.name.ilike('{0}%'.format(name))).all() + certificates = Certificate.query.filter( + Certificate.name.ilike("{0}%".format(name)) + ).all() if not certificates: return name - serial_name = '{0}-{1}'.format(name, hex(int(serial))[2:].upper()) - certificates = Certificate.query.filter(Certificate.name.ilike('{0}%'.format(serial_name))).all() + serial_name = "{0}-{1}".format(name, hex(int(serial))[2:].upper()) + certificates = Certificate.query.filter( + Certificate.name.ilike("{0}%".format(serial_name)) + ).all() if not certificates: return serial_name @@ -72,21 +93,29 @@ def get_or_increase_name(name, serial): if end: ends.append(end) - return '{0}-{1}'.format(root, max(ends) + 1) + return "{0}-{1}".format(root, max(ends) + 1) class Certificate(db.Model): - __tablename__ = 'certificates' + __tablename__ = "certificates" __table_args__ = ( - Index('ix_certificates_cn', "cn", - postgresql_ops={"cn": "gin_trgm_ops"}, - postgresql_using='gin'), - Index('ix_certificates_name', "name", - postgresql_ops={"name": "gin_trgm_ops"}, - postgresql_using='gin'), + Index( + "ix_certificates_cn", + "cn", + postgresql_ops={"cn": "gin_trgm_ops"}, + postgresql_using="gin", + ), + Index( + "ix_certificates_name", + "name", + postgresql_ops={"name": "gin_trgm_ops"}, + postgresql_using="gin", + ), ) id = Column(Integer, primary_key=True) - ix = Index('ix_certificates_id_desc', id.desc(), postgresql_using='btree', unique=True) + ix = Index( + "ix_certificates_id_desc", id.desc(), postgresql_using="btree", unique=True + ) external_id = Column(String(128)) owner = Column(String(128), nullable=False) name = Column(String(256), unique=True) @@ -102,7 +131,9 @@ class Certificate(db.Model): serial = Column(String(128)) cn = Column(String(128)) deleted = Column(Boolean, index=True, default=False) - dns_provider_id = Column(Integer(), ForeignKey('dns_providers.id', ondelete='CASCADE'), nullable=True) + dns_provider_id = Column( + Integer(), ForeignKey("dns_providers.id", ondelete="CASCADE"), nullable=True + ) not_before = Column(ArrowType) not_after = Column(ArrowType) @@ -114,34 +145,53 @@ class Certificate(db.Model): san = Column(String(1024)) # TODO this should be migrated to boolean rotation = Column(Boolean, default=False) - user_id = Column(Integer, ForeignKey('users.id')) - authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE")) - root_authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE")) - rotation_policy_id = Column(Integer, ForeignKey('rotation_policies.id')) + user_id = Column(Integer, ForeignKey("users.id")) + authority_id = Column(Integer, ForeignKey("authorities.id", ondelete="CASCADE")) + root_authority_id = Column( + Integer, ForeignKey("authorities.id", ondelete="CASCADE") + ) + rotation_policy_id = Column(Integer, ForeignKey("rotation_policies.id")) - notifications = relationship('Notification', secondary=certificate_notification_associations, backref='certificate') - destinations = relationship('Destination', secondary=certificate_destination_associations, backref='certificate') - sources = relationship('Source', secondary=certificate_source_associations, backref='certificate') - domains = relationship('Domain', secondary=certificate_associations, backref='certificate') - roles = relationship('Role', secondary=roles_certificates, backref='certificate') - replaces = relationship('Certificate', - secondary=certificate_replacement_associations, - primaryjoin=id == certificate_replacement_associations.c.certificate_id, # noqa - secondaryjoin=id == certificate_replacement_associations.c.replaced_certificate_id, # noqa - backref='replaced') + notifications = relationship( + "Notification", + secondary=certificate_notification_associations, + backref="certificate", + ) + destinations = relationship( + "Destination", + secondary=certificate_destination_associations, + backref="certificate", + ) + sources = relationship( + "Source", secondary=certificate_source_associations, backref="certificate" + ) + domains = relationship( + "Domain", secondary=certificate_associations, backref="certificate" + ) + roles = relationship("Role", secondary=roles_certificates, backref="certificate") + replaces = relationship( + "Certificate", + secondary=certificate_replacement_associations, + primaryjoin=id == certificate_replacement_associations.c.certificate_id, # noqa + secondaryjoin=id + == certificate_replacement_associations.c.replaced_certificate_id, # noqa + backref="replaced", + ) - replaced_by_pending = relationship('PendingCertificate', - secondary=pending_cert_replacement_associations, - backref='pending_replace', - viewonly=True) + replaced_by_pending = relationship( + "PendingCertificate", + secondary=pending_cert_replacement_associations, + backref="pending_replace", + viewonly=True, + ) - logs = relationship('Log', backref='certificate') - endpoints = relationship('Endpoint', backref='certificate') + logs = relationship("Log", backref="certificate") + endpoints = relationship("Endpoint", backref="certificate") rotation_policy = relationship("RotationPolicy") - sensitive_fields = ('private_key',) + sensitive_fields = ("private_key",) def __init__(self, **kwargs): - self.body = kwargs['body'].strip() + self.body = kwargs["body"].strip() cert = self.parsed_cert self.issuer = defaults.issuer(cert) @@ -152,36 +202,42 @@ class Certificate(db.Model): self.serial = defaults.serial(cert) # when destinations are appended they require a valid name. - if kwargs.get('name'): - self.name = get_or_increase_name(defaults.text_to_slug(kwargs['name']), self.serial) + if kwargs.get("name"): + self.name = get_or_increase_name( + defaults.text_to_slug(kwargs["name"]), self.serial + ) else: self.name = get_or_increase_name( - defaults.certificate_name(self.cn, self.issuer, self.not_before, self.not_after, self.san), self.serial) + defaults.certificate_name( + self.cn, self.issuer, self.not_before, self.not_after, self.san + ), + self.serial, + ) - self.owner = kwargs['owner'] + self.owner = kwargs["owner"] - if kwargs.get('private_key'): - self.private_key = kwargs['private_key'].strip() + if kwargs.get("private_key"): + self.private_key = kwargs["private_key"].strip() - if kwargs.get('chain'): - self.chain = kwargs['chain'].strip() + if kwargs.get("chain"): + self.chain = kwargs["chain"].strip() - if kwargs.get('csr'): - self.csr = kwargs['csr'].strip() + if kwargs.get("csr"): + self.csr = kwargs["csr"].strip() - self.notify = kwargs.get('notify', True) - self.destinations = kwargs.get('destinations', []) - self.notifications = kwargs.get('notifications', []) - self.description = kwargs.get('description') - self.roles = list(set(kwargs.get('roles', []))) - self.replaces = kwargs.get('replaces', []) - self.rotation = kwargs.get('rotation') - self.rotation_policy = kwargs.get('rotation_policy') + self.notify = kwargs.get("notify", True) + self.destinations = kwargs.get("destinations", []) + self.notifications = kwargs.get("notifications", []) + self.description = kwargs.get("description") + self.roles = list(set(kwargs.get("roles", []))) + self.replaces = kwargs.get("replaces", []) + self.rotation = kwargs.get("rotation") + self.rotation_policy = kwargs.get("rotation_policy") self.signing_algorithm = defaults.signing_algorithm(cert) self.bits = defaults.bitstrength(cert) - self.external_id = kwargs.get('external_id') - self.authority_id = kwargs.get('authority_id') - self.dns_provider_id = kwargs.get('dns_provider_id') + self.external_id = kwargs.get("external_id") + self.authority_id = kwargs.get("authority_id") + self.dns_provider_id = kwargs.get("dns_provider_id") for domain in defaults.domains(cert): self.domains.append(Domain(name=domain)) @@ -195,8 +251,11 @@ class Certificate(db.Model): Integrity checks: Does the cert have a valid chain and matching private key? """ if self.private_key: - validators.verify_private_key_match(utils.parse_private_key(self.private_key), self.parsed_cert, - error_class=AssertionError) + validators.verify_private_key_match( + utils.parse_private_key(self.private_key), + self.parsed_cert, + error_class=AssertionError, + ) if self.chain: chain = [self.parsed_cert] + utils.parse_cert_chain(self.chain) @@ -238,7 +297,9 @@ class Certificate(db.Model): @property def key_type(self): if isinstance(self.parsed_cert.public_key(), rsa.RSAPublicKey): - return 'RSA{key_size}'.format(key_size=self.parsed_cert.public_key().key_size) + return "RSA{key_size}".format( + key_size=self.parsed_cert.public_key().key_size + ) @property def validity_remaining(self): @@ -263,26 +324,16 @@ class Certificate(db.Model): @expired.expression def expired(cls): - return case( - [ - (cls.not_after <= arrow.utcnow(), True) - ], - else_=False - ) + return case([(cls.not_after <= arrow.utcnow(), True)], else_=False) @hybrid_property def revoked(self): - if 'revoked' == self.status: + if "revoked" == self.status: return True @revoked.expression def revoked(cls): - return case( - [ - (cls.status == 'revoked', True) - ], - else_=False - ) + return case([(cls.status == "revoked", True)], else_=False) @hybrid_property def in_rotation_window(self): @@ -305,66 +356,65 @@ class Certificate(db.Model): :return: """ return case( - [ - (extract('day', cls.not_after - func.now()) <= RotationPolicy.days, True) - ], - else_=False + [(extract("day", cls.not_after - func.now()) <= RotationPolicy.days, True)], + else_=False, ) @property def extensions(self): # setup default values - return_extensions = { - 'sub_alt_names': {'names': []} - } + return_extensions = {"sub_alt_names": {"names": []}} try: for extension in self.parsed_cert.extensions: value = extension.value if isinstance(value, x509.BasicConstraints): - return_extensions['basic_constraints'] = value + return_extensions["basic_constraints"] = value elif isinstance(value, x509.SubjectAlternativeName): - return_extensions['sub_alt_names']['names'] = value + return_extensions["sub_alt_names"]["names"] = value elif isinstance(value, x509.ExtendedKeyUsage): - return_extensions['extended_key_usage'] = value + return_extensions["extended_key_usage"] = value elif isinstance(value, x509.KeyUsage): - return_extensions['key_usage'] = value + return_extensions["key_usage"] = value elif isinstance(value, x509.SubjectKeyIdentifier): - return_extensions['subject_key_identifier'] = {'include_ski': True} + return_extensions["subject_key_identifier"] = {"include_ski": True} elif isinstance(value, x509.AuthorityInformationAccess): - return_extensions['certificate_info_access'] = {'include_aia': True} + return_extensions["certificate_info_access"] = {"include_aia": True} elif isinstance(value, x509.AuthorityKeyIdentifier): - aki = { - 'use_key_identifier': False, - 'use_authority_cert': False - } + aki = {"use_key_identifier": False, "use_authority_cert": False} if value.key_identifier: - aki['use_key_identifier'] = True + aki["use_key_identifier"] = True if value.authority_cert_issuer: - aki['use_authority_cert'] = True + aki["use_authority_cert"] = True - return_extensions['authority_key_identifier'] = aki + return_extensions["authority_key_identifier"] = aki elif isinstance(value, x509.CRLDistributionPoints): - return_extensions['crl_distribution_points'] = {'include_crl_dp': value} + return_extensions["crl_distribution_points"] = { + "include_crl_dp": value + } # TODO: Not supporting custom OIDs yet. https://github.com/Netflix/lemur/issues/665 else: - current_app.logger.warning('Custom OIDs not yet supported for clone operation.') + current_app.logger.warning( + "Custom OIDs not yet supported for clone operation." + ) except InvalidCodepoint as e: sentry.captureException() - current_app.logger.warning('Unable to parse extensions due to underscore in dns name') + current_app.logger.warning( + "Unable to parse extensions due to underscore in dns name" + ) except ValueError as e: sentry.captureException() - current_app.logger.warning('Unable to parse') + current_app.logger.warning("Unable to parse") current_app.logger.exception(e) return return_extensions @@ -373,7 +423,7 @@ class Certificate(db.Model): return "Certificate(name={name})".format(name=self.name) -@event.listens_for(Certificate.destinations, 'append') +@event.listens_for(Certificate.destinations, "append") def update_destinations(target, value, initiator): """ Attempt to upload certificate to the new destination @@ -387,17 +437,31 @@ def update_destinations(target, value, initiator): status = FAILURE_METRIC_STATUS try: if target.private_key or not destination_plugin.requires_key: - destination_plugin.upload(target.name, target.body, target.private_key, target.chain, value.options) + destination_plugin.upload( + target.name, + target.body, + target.private_key, + target.chain, + value.options, + ) status = SUCCESS_METRIC_STATUS except Exception as e: sentry.captureException() raise - metrics.send('destination_upload', 'counter', 1, - metric_tags={'status': status, 'certificate': target.name, 'destination': value.label}) + metrics.send( + "destination_upload", + "counter", + 1, + metric_tags={ + "status": status, + "certificate": target.name, + "destination": value.label, + }, + ) -@event.listens_for(Certificate.replaces, 'append') +@event.listens_for(Certificate.replaces, "append") def update_replacement(target, value, initiator): """ When a certificate is marked as 'replaced' we should not notify. diff --git a/lemur/certificates/schemas.py b/lemur/certificates/schemas.py index f4a6fa9a..bf950e70 100644 --- a/lemur/certificates/schemas.py +++ b/lemur/certificates/schemas.py @@ -39,22 +39,26 @@ from lemur.users.schemas import UserNestedOutputSchema class CertificateSchema(LemurInputSchema): owner = fields.Email(required=True) - description = fields.String(missing='', allow_none=True) + description = fields.String(missing="", allow_none=True) class CertificateCreationSchema(CertificateSchema): @post_load def default_notification(self, data): - if not data['notifications']: - data['notifications'] += notification_service.create_default_expiration_notifications( - "DEFAULT_{0}".format(data['owner'].split('@')[0].upper()), - [data['owner']], + if not data["notifications"]: + data[ + "notifications" + ] += notification_service.create_default_expiration_notifications( + "DEFAULT_{0}".format(data["owner"].split("@")[0].upper()), + [data["owner"]], ) - data['notifications'] += notification_service.create_default_expiration_notifications( - 'DEFAULT_SECURITY', - current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL'), - current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL_INTERVALS', None) + data[ + "notifications" + ] += notification_service.create_default_expiration_notifications( + "DEFAULT_SECURITY", + current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL"), + current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL_INTERVALS", None), ) return data @@ -71,37 +75,53 @@ class CertificateInputSchema(CertificateCreationSchema): destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True) notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True) replaces = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) - replacements = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) # deprecated + replacements = fields.Nested( + AssociatedCertificateSchema, missing=[], many=True + ) # deprecated roles = fields.Nested(AssociatedRoleSchema, missing=[], many=True) - dns_provider = fields.Nested(AssociatedDnsProviderSchema, missing=None, allow_none=True, required=False) + dns_provider = fields.Nested( + AssociatedDnsProviderSchema, missing=None, allow_none=True, required=False + ) csr = fields.String(allow_none=True, validate=validators.csr) key_type = fields.String( - validate=validate.OneOf(CERTIFICATE_KEY_TYPES), - missing='RSA2048') + validate=validate.OneOf(CERTIFICATE_KEY_TYPES), missing="RSA2048" + ) notify = fields.Boolean(default=True) rotation = fields.Boolean() - rotation_policy = fields.Nested(AssociatedRotationPolicySchema, missing={'name': 'default'}, allow_none=True, - default={'name': 'default'}) + rotation_policy = fields.Nested( + AssociatedRotationPolicySchema, + missing={"name": "default"}, + allow_none=True, + default={"name": "default"}, + ) # certificate body fields - organizational_unit = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT')) - organization = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATION')) - location = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_LOCATION')) - country = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_COUNTRY')) - state = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_STATE')) + organizational_unit = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATIONAL_UNIT") + ) + organization = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATION") + ) + location = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_LOCATION") + ) + country = fields.String( + missing=lambda: current_app.config.get("LEMUR_DEFAULT_COUNTRY") + ) + state = fields.String(missing=lambda: current_app.config.get("LEMUR_DEFAULT_STATE")) extensions = fields.Nested(ExtensionSchema) @validates_schema def validate_authority(self, data): - if isinstance(data['authority'], str): + if isinstance(data["authority"], str): raise ValidationError("Authority not found.") - if not data['authority'].active: - raise ValidationError("The authority is inactive.", ['authority']) + if not data["authority"].active: + raise ValidationError("The authority is inactive.", ["authority"]) @validates_schema def validate_dates(self, data): @@ -109,23 +129,19 @@ class CertificateInputSchema(CertificateCreationSchema): @pre_load def load_data(self, data): - if data.get('replacements'): - data['replaces'] = data['replacements'] # TODO remove when field is deprecated - if data.get('csr'): - csr_sans = cert_utils.get_sans_from_csr(data['csr']) - if not data.get('extensions'): - data['extensions'] = { - 'subAltNames': { - 'names': [] - } - } - elif not data['extensions'].get('subAltNames'): - data['extensions']['subAltNames'] = { - 'names': [] - } - elif not data['extensions']['subAltNames'].get('names'): - data['extensions']['subAltNames']['names'] = [] - data['extensions']['subAltNames']['names'] += csr_sans + if data.get("replacements"): + data["replaces"] = data[ + "replacements" + ] # TODO remove when field is deprecated + if data.get("csr"): + csr_sans = cert_utils.get_sans_from_csr(data["csr"]) + if not data.get("extensions"): + data["extensions"] = {"subAltNames": {"names": []}} + elif not data["extensions"].get("subAltNames"): + data["extensions"]["subAltNames"] = {"names": []} + elif not data["extensions"]["subAltNames"].get("names"): + data["extensions"]["subAltNames"]["names"] = [] + data["extensions"]["subAltNames"]["names"] += csr_sans return missing.convert_validity_years(data) @@ -138,13 +154,17 @@ class CertificateEditInputSchema(CertificateSchema): destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True) notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True) replaces = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) - replacements = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) # deprecated + replacements = fields.Nested( + AssociatedCertificateSchema, missing=[], many=True + ) # deprecated roles = fields.Nested(AssociatedRoleSchema, missing=[], many=True) @pre_load def load_data(self, data): - if data.get('replacements'): - data['replaces'] = data['replacements'] # TODO remove when field is deprecated + if data.get("replacements"): + data["replaces"] = data[ + "replacements" + ] # TODO remove when field is deprecated return data @post_load @@ -155,10 +175,15 @@ class CertificateEditInputSchema(CertificateSchema): :param data: :return: """ - if data['owner']: - notification_name = "DEFAULT_{0}".format(data['owner'].split('@')[0].upper()) - data['notifications'] += notification_service.create_default_expiration_notifications(notification_name, - [data['owner']]) + if data["owner"]: + notification_name = "DEFAULT_{0}".format( + data["owner"].split("@")[0].upper() + ) + data[ + "notifications" + ] += notification_service.create_default_expiration_notifications( + notification_name, [data["owner"]] + ) return data @@ -184,13 +209,13 @@ class CertificateNestedOutputSchema(LemurOutputSchema): # Note aliasing is the first step in deprecating these fields. cn = fields.String() # deprecated - common_name = fields.String(attribute='cn') + common_name = fields.String(attribute="cn") not_after = fields.DateTime() # deprecated - validity_end = ArrowDateTime(attribute='not_after') + validity_end = ArrowDateTime(attribute="not_after") not_before = fields.DateTime() # deprecated - validity_start = ArrowDateTime(attribute='not_before') + validity_start = ArrowDateTime(attribute="not_before") issuer = fields.Nested(AuthorityNestedOutputSchema) @@ -221,22 +246,22 @@ class CertificateOutputSchema(LemurOutputSchema): # Note aliasing is the first step in deprecating these fields. notify = fields.Boolean() - active = fields.Boolean(attribute='notify') + active = fields.Boolean(attribute="notify") cn = fields.String() - common_name = fields.String(attribute='cn') + common_name = fields.String(attribute="cn") distinguished_name = fields.String() not_after = fields.DateTime() - validity_end = ArrowDateTime(attribute='not_after') + validity_end = ArrowDateTime(attribute="not_after") not_before = fields.DateTime() - validity_start = ArrowDateTime(attribute='not_before') + validity_start = ArrowDateTime(attribute="not_before") owner = fields.Email() san = fields.Boolean() serial = fields.String() - serial_hex = Hex(attribute='serial') + serial_hex = Hex(attribute="serial") signing_algorithm = fields.String() status = fields.String() @@ -253,7 +278,9 @@ class CertificateOutputSchema(LemurOutputSchema): dns_provider = fields.Nested(DnsProvidersNestedOutputSchema) roles = fields.Nested(RoleNestedOutputSchema, many=True) endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[]) - replaced_by = fields.Nested(CertificateNestedOutputSchema, many=True, attribute='replaced') + replaced_by = fields.Nested( + CertificateNestedOutputSchema, many=True, attribute="replaced" + ) rotation_policy = fields.Nested(RotationPolicyNestedOutputSchema) @@ -274,35 +301,41 @@ class CertificateUploadInputSchema(CertificateCreationSchema): @validates_schema def keys(self, data): - if data.get('destinations'): - if not data.get('private_key'): - raise ValidationError('Destinations require private key.') + if data.get("destinations"): + if not data.get("private_key"): + raise ValidationError("Destinations require private key.") @validates_schema def validate_cert_private_key_chain(self, data): cert = None key = None - if data.get('body'): + if data.get("body"): try: - cert = utils.parse_certificate(data['body']) + cert = utils.parse_certificate(data["body"]) except ValueError: - raise ValidationError("Public certificate presented is not valid.", field_names=['body']) + raise ValidationError( + "Public certificate presented is not valid.", field_names=["body"] + ) - if data.get('private_key'): + if data.get("private_key"): try: - key = utils.parse_private_key(data['private_key']) + key = utils.parse_private_key(data["private_key"]) except ValueError: - raise ValidationError("Private key presented is not valid.", field_names=['private_key']) + raise ValidationError( + "Private key presented is not valid.", field_names=["private_key"] + ) if cert and key: # Throws ValidationError validators.verify_private_key_match(key, cert) - if data.get('chain'): + if data.get("chain"): try: - chain = utils.parse_cert_chain(data['chain']) + chain = utils.parse_cert_chain(data["chain"]) except ValueError: - raise ValidationError("Invalid certificate in certificate chain.", field_names=['chain']) + raise ValidationError( + "Invalid certificate in certificate chain.", field_names=["chain"] + ) # Throws ValidationError validators.verify_cert_chain([cert] + chain) @@ -318,8 +351,10 @@ class CertificateNotificationOutputSchema(LemurOutputSchema): name = fields.String() owner = fields.Email() user = fields.Nested(UserNestedOutputSchema) - validity_end = ArrowDateTime(attribute='not_after') - replaced_by = fields.Nested(CertificateNestedOutputSchema, many=True, attribute='replaced') + validity_end = ArrowDateTime(attribute="not_after") + replaced_by = fields.Nested( + CertificateNestedOutputSchema, many=True, attribute="replaced" + ) endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[]) diff --git a/lemur/certificates/service.py b/lemur/certificates/service.py index 815349ff..51fede4f 100644 --- a/lemur/certificates/service.py +++ b/lemur/certificates/service.py @@ -26,10 +26,14 @@ from lemur.plugins.base import plugins from lemur.roles import service as role_service from lemur.roles.models import Role -csr_created = signals.signal('csr_created', "CSR generated") -csr_imported = signals.signal('csr_imported', "CSR imported from external source") -certificate_issued = signals.signal('certificate_issued', "Authority issued a certificate") -certificate_imported = signals.signal('certificate_imported', "Certificate imported from external source") +csr_created = signals.signal("csr_created", "CSR generated") +csr_imported = signals.signal("csr_imported", "CSR imported from external source") +certificate_issued = signals.signal( + "certificate_issued", "Authority issued a certificate" +) +certificate_imported = signals.signal( + "certificate_imported", "Certificate imported from external source" +) def get(cert_id): @@ -49,7 +53,7 @@ def get_by_name(name): :param name: :return: """ - return database.get(Certificate, name, field='name') + return database.get(Certificate, name, field="name") def get_by_serial(serial): @@ -105,8 +109,12 @@ def get_all_pending_cleaning(source): :param source: :return: """ - return Certificate.query.filter(Certificate.sources.any(id=source.id)) \ - .filter(not_(Certificate.endpoints.any())).filter(Certificate.expired).all() + return ( + Certificate.query.filter(Certificate.sources.any(id=source.id)) + .filter(not_(Certificate.endpoints.any())) + .filter(Certificate.expired) + .all() + ) def get_all_pending_reissue(): @@ -119,9 +127,12 @@ def get_all_pending_reissue(): :return: """ - return Certificate.query.filter(Certificate.rotation == True) \ - .filter(not_(Certificate.replaced.any())) \ - .filter(Certificate.in_rotation_window == True).all() # noqa + return ( + Certificate.query.filter(Certificate.rotation == True) + .filter(not_(Certificate.replaced.any())) + .filter(Certificate.in_rotation_window == True) + .all() + ) # noqa def find_duplicates(cert): @@ -133,10 +144,12 @@ def find_duplicates(cert): :param cert: :return: """ - if cert['chain']: - return Certificate.query.filter_by(body=cert['body'].strip(), chain=cert['chain'].strip()).all() + if cert["chain"]: + return Certificate.query.filter_by( + body=cert["body"].strip(), chain=cert["chain"].strip() + ).all() else: - return Certificate.query.filter_by(body=cert['body'].strip(), chain=None).all() + return Certificate.query.filter_by(body=cert["body"].strip(), chain=None).all() def export(cert, export_plugin): @@ -148,8 +161,10 @@ def export(cert, export_plugin): :param cert: :return: """ - plugin = plugins.get(export_plugin['slug']) - return plugin.export(cert.body, cert.chain, cert.private_key, export_plugin['pluginOptions']) + plugin = plugins.get(export_plugin["slug"]) + return plugin.export( + cert.body, cert.chain, cert.private_key, export_plugin["pluginOptions"] + ) def update(cert_id, **kwargs): @@ -168,17 +183,19 @@ def update(cert_id, **kwargs): def create_certificate_roles(**kwargs): # create an role for the owner and assign it - owner_role = role_service.get_by_name(kwargs['owner']) + owner_role = role_service.get_by_name(kwargs["owner"]) if not owner_role: owner_role = role_service.create( - kwargs['owner'], - description="Auto generated role based on owner: {0}".format(kwargs['owner']) + kwargs["owner"], + description="Auto generated role based on owner: {0}".format( + kwargs["owner"] + ), ) # ensure that the authority's owner is also associated with the certificate - if kwargs.get('authority'): - authority_owner_role = role_service.get_by_name(kwargs['authority'].owner) + if kwargs.get("authority"): + authority_owner_role = role_service.get_by_name(kwargs["authority"].owner) return [owner_role, authority_owner_role] return [owner_role] @@ -190,16 +207,16 @@ def mint(**kwargs): Support for multiple authorities is handled by individual plugins. """ - authority = kwargs['authority'] + authority = kwargs["authority"] issuer = plugins.get(authority.plugin_name) # allow the CSR to be specified by the user - if not kwargs.get('csr'): + if not kwargs.get("csr"): csr, private_key = create_csr(**kwargs) csr_created.send(authority=authority, csr=csr) else: - csr = str(kwargs.get('csr')) + csr = str(kwargs.get("csr")) private_key = None csr_imported.send(authority=authority, csr=csr) @@ -220,8 +237,8 @@ def import_certificate(**kwargs): :param kwargs: """ - if not kwargs.get('owner'): - kwargs['owner'] = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL')[0] + if not kwargs.get("owner"): + kwargs["owner"] = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL")[0] return upload(**kwargs) @@ -232,16 +249,16 @@ def upload(**kwargs): """ roles = create_certificate_roles(**kwargs) - if kwargs.get('roles'): - kwargs['roles'] += roles + if kwargs.get("roles"): + kwargs["roles"] += roles else: - kwargs['roles'] = roles + kwargs["roles"] = roles cert = Certificate(**kwargs) - cert.authority = kwargs.get('authority') + cert.authority = kwargs.get("authority") cert = database.create(cert) - kwargs['creator'].certificates.append(cert) + kwargs["creator"].certificates.append(cert) cert = database.update(cert) certificate_imported.send(certificate=cert, authority=cert.authority) @@ -258,39 +275,45 @@ def create(**kwargs): current_app.logger.error("Exception minting certificate", exc_info=True) sentry.captureException() raise - kwargs['body'] = cert_body - kwargs['private_key'] = private_key - kwargs['chain'] = cert_chain - kwargs['external_id'] = external_id - kwargs['csr'] = csr + kwargs["body"] = cert_body + kwargs["private_key"] = private_key + kwargs["chain"] = cert_chain + kwargs["external_id"] = external_id + kwargs["csr"] = csr roles = create_certificate_roles(**kwargs) - if kwargs.get('roles'): - kwargs['roles'] += roles + if kwargs.get("roles"): + kwargs["roles"] += roles else: - kwargs['roles'] = roles + kwargs["roles"] = roles if cert_body: cert = Certificate(**kwargs) - kwargs['creator'].certificates.append(cert) + kwargs["creator"].certificates.append(cert) else: cert = PendingCertificate(**kwargs) - kwargs['creator'].pending_certificates.append(cert) + kwargs["creator"].pending_certificates.append(cert) - cert.authority = kwargs['authority'] + cert.authority = kwargs["authority"] database.commit() if isinstance(cert, Certificate): certificate_issued.send(certificate=cert, authority=cert.authority) - metrics.send('certificate_issued', 'counter', 1, metric_tags=dict(owner=cert.owner, issuer=cert.issuer)) + metrics.send( + "certificate_issued", + "counter", + 1, + metric_tags=dict(owner=cert.owner, issuer=cert.issuer), + ) if isinstance(cert, PendingCertificate): # We need to refresh the pending certificate to avoid "Instance is not bound to a Session; " # "attribute refresh operation cannot proceed" pending_cert = database.session_query(PendingCertificate).get(cert.id) from lemur.common.celery import fetch_acme_cert + if not current_app.config.get("ACME_DISABLE_AUTORESOLVE", False): fetch_acme_cert.apply_async((pending_cert.id,), countdown=5) @@ -306,51 +329,55 @@ def render(args): """ query = database.session_query(Certificate) - time_range = args.pop('time_range') - destination_id = args.pop('destination_id') - notification_id = args.pop('notification_id', None) - show = args.pop('show') + time_range = args.pop("time_range") + destination_id = args.pop("destination_id") + notification_id = args.pop("notification_id", None) + show = args.pop("show") # owner = args.pop('owner') # creator = args.pop('creator') # TODO we should enabling filtering by owner - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') - term = '%{0}%'.format(terms[1]) + terms = filt.split(";") + term = "%{0}%".format(terms[1]) # Exact matches for quotes. Only applies to name, issuer, and cn if terms[1].startswith('"') and terms[1].endswith('"'): term = terms[1][1:-1] - if 'issuer' in terms: + if "issuer" in terms: # we can't rely on issuer being correct in the cert directly so we combine queries - sub_query = database.session_query(Authority.id) \ - .filter(Authority.name.ilike(term)) \ + sub_query = ( + database.session_query(Authority.id) + .filter(Authority.name.ilike(term)) .subquery() + ) query = query.filter( or_( Certificate.issuer.ilike(term), - Certificate.authority_id.in_(sub_query) + Certificate.authority_id.in_(sub_query), ) ) - elif 'destination' in terms: - query = query.filter(Certificate.destinations.any(Destination.id == terms[1])) - elif 'notify' in filt: + elif "destination" in terms: + query = query.filter( + Certificate.destinations.any(Destination.id == terms[1]) + ) + elif "notify" in filt: query = query.filter(Certificate.notify == truthiness(terms[1])) - elif 'active' in filt: + elif "active" in filt: query = query.filter(Certificate.active == truthiness(terms[1])) - elif 'cn' in terms: + elif "cn" in terms: query = query.filter( or_( Certificate.cn.ilike(term), - Certificate.domains.any(Domain.name.ilike(term)) + Certificate.domains.any(Domain.name.ilike(term)), ) ) - elif 'id' in terms: + elif "id" in terms: query = query.filter(Certificate.id == cast(terms[1], Integer)) - elif 'name' in terms: + elif "name" in terms: query = query.filter( or_( Certificate.name.ilike(term), @@ -362,26 +389,35 @@ def render(args): query = database.filter(query, Certificate, terms) if show: - sub_query = database.session_query(Role.name).filter(Role.user_id == args['user'].id).subquery() + sub_query = ( + database.session_query(Role.name) + .filter(Role.user_id == args["user"].id) + .subquery() + ) query = query.filter( or_( - Certificate.user_id == args['user'].id, - Certificate.owner.in_(sub_query) + Certificate.user_id == args["user"].id, Certificate.owner.in_(sub_query) ) ) if destination_id: - query = query.filter(Certificate.destinations.any(Destination.id == destination_id)) + query = query.filter( + Certificate.destinations.any(Destination.id == destination_id) + ) if notification_id: - query = query.filter(Certificate.notifications.any(Notification.id == notification_id)) + query = query.filter( + Certificate.notifications.any(Notification.id == notification_id) + ) if time_range: - to = arrow.now().replace(weeks=+time_range).format('YYYY-MM-DD') - now = arrow.now().format('YYYY-MM-DD') - query = query.filter(Certificate.not_after <= to).filter(Certificate.not_after >= now) + to = arrow.now().replace(weeks=+time_range).format("YYYY-MM-DD") + now = arrow.now().format("YYYY-MM-DD") + query = query.filter(Certificate.not_after <= to).filter( + Certificate.not_after >= now + ) - if current_app.config.get('ALLOW_CERT_DELETION', False): + if current_app.config.get("ALLOW_CERT_DELETION", False): query = query.filter(Certificate.deleted == False) # noqa result = database.sort_and_page(query, Certificate, args) @@ -409,18 +445,20 @@ def query_common_name(common_name, args): :param args: :return: """ - owner = args.pop('owner') + owner = args.pop("owner") if not owner: - owner = '%' + owner = "%" # only not expired certificates current_time = arrow.utcnow() - result = Certificate.query.filter(Certificate.cn.ilike(common_name)) \ - .filter(Certificate.owner.ilike(owner))\ - .filter(Certificate.not_after >= current_time.format('YYYY-MM-DD')) \ - .filter(Certificate.rotation.is_(True))\ + result = ( + Certificate.query.filter(Certificate.cn.ilike(common_name)) + .filter(Certificate.owner.ilike(owner)) + .filter(Certificate.not_after >= current_time.format("YYYY-MM-DD")) + .filter(Certificate.rotation.is_(True)) .all() + ) return result @@ -432,62 +470,77 @@ def create_csr(**csr_config): :param csr_config: """ - private_key = generate_private_key(csr_config.get('key_type')) + private_key = generate_private_key(csr_config.get("key_type")) builder = x509.CertificateSigningRequestBuilder() - name_list = [x509.NameAttribute(x509.OID_COMMON_NAME, csr_config['common_name'])] - if current_app.config.get('LEMUR_OWNER_EMAIL_IN_SUBJECT', True): - name_list.append(x509.NameAttribute(x509.OID_EMAIL_ADDRESS, csr_config['owner'])) - if 'organization' in csr_config and csr_config['organization'].strip(): - name_list.append(x509.NameAttribute(x509.OID_ORGANIZATION_NAME, csr_config['organization'])) - if 'organizational_unit' in csr_config and csr_config['organizational_unit'].strip(): - name_list.append(x509.NameAttribute(x509.OID_ORGANIZATIONAL_UNIT_NAME, csr_config['organizational_unit'])) - if 'country' in csr_config and csr_config['country'].strip(): - name_list.append(x509.NameAttribute(x509.OID_COUNTRY_NAME, csr_config['country'])) - if 'state' in csr_config and csr_config['state'].strip(): - name_list.append(x509.NameAttribute(x509.OID_STATE_OR_PROVINCE_NAME, csr_config['state'])) - if 'location' in csr_config and csr_config['location'].strip(): - name_list.append(x509.NameAttribute(x509.OID_LOCALITY_NAME, csr_config['location'])) + name_list = [x509.NameAttribute(x509.OID_COMMON_NAME, csr_config["common_name"])] + if current_app.config.get("LEMUR_OWNER_EMAIL_IN_SUBJECT", True): + name_list.append( + x509.NameAttribute(x509.OID_EMAIL_ADDRESS, csr_config["owner"]) + ) + if "organization" in csr_config and csr_config["organization"].strip(): + name_list.append( + x509.NameAttribute(x509.OID_ORGANIZATION_NAME, csr_config["organization"]) + ) + if ( + "organizational_unit" in csr_config + and csr_config["organizational_unit"].strip() + ): + name_list.append( + x509.NameAttribute( + x509.OID_ORGANIZATIONAL_UNIT_NAME, csr_config["organizational_unit"] + ) + ) + if "country" in csr_config and csr_config["country"].strip(): + name_list.append( + x509.NameAttribute(x509.OID_COUNTRY_NAME, csr_config["country"]) + ) + if "state" in csr_config and csr_config["state"].strip(): + name_list.append( + x509.NameAttribute(x509.OID_STATE_OR_PROVINCE_NAME, csr_config["state"]) + ) + if "location" in csr_config and csr_config["location"].strip(): + name_list.append( + x509.NameAttribute(x509.OID_LOCALITY_NAME, csr_config["location"]) + ) builder = builder.subject_name(x509.Name(name_list)) - extensions = csr_config.get('extensions', {}) - critical_extensions = ['basic_constraints', 'sub_alt_names', 'key_usage'] - noncritical_extensions = ['extended_key_usage'] + extensions = csr_config.get("extensions", {}) + critical_extensions = ["basic_constraints", "sub_alt_names", "key_usage"] + noncritical_extensions = ["extended_key_usage"] for k, v in extensions.items(): if v: if k in critical_extensions: - current_app.logger.debug('Adding Critical Extension: {0} {1}'.format(k, v)) - if k == 'sub_alt_names': - if v['names']: - builder = builder.add_extension(v['names'], critical=True) + current_app.logger.debug( + "Adding Critical Extension: {0} {1}".format(k, v) + ) + if k == "sub_alt_names": + if v["names"]: + builder = builder.add_extension(v["names"], critical=True) else: builder = builder.add_extension(v, critical=True) if k in noncritical_extensions: - current_app.logger.debug('Adding Extension: {0} {1}'.format(k, v)) + current_app.logger.debug("Adding Extension: {0} {1}".format(k, v)) builder = builder.add_extension(v, critical=False) - ski = extensions.get('subject_key_identifier', {}) - if ski.get('include_ski', False): + ski = extensions.get("subject_key_identifier", {}) + if ski.get("include_ski", False): builder = builder.add_extension( x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()), - critical=False + critical=False, ) - request = builder.sign( - private_key, hashes.SHA256(), default_backend() - ) + request = builder.sign(private_key, hashes.SHA256(), default_backend()) # serialize our private key and CSR private_key = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, # would like to use PKCS8 but AWS ELBs don't like it - encryption_algorithm=serialization.NoEncryption() - ).decode('utf-8') + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") - csr = request.public_bytes( - encoding=serialization.Encoding.PEM - ).decode('utf-8') + csr = request.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8") return csr, private_key @@ -499,16 +552,19 @@ def stats(**kwargs): :param kwargs: :return: """ - if kwargs.get('metric') == 'not_after': + if kwargs.get("metric") == "not_after": start = arrow.utcnow() end = start.replace(weeks=+32) - items = database.db.session.query(Certificate.issuer, func.count(Certificate.id)) \ - .group_by(Certificate.issuer) \ - .filter(Certificate.not_after <= end.format('YYYY-MM-DD')) \ - .filter(Certificate.not_after >= start.format('YYYY-MM-DD')).all() + items = ( + database.db.session.query(Certificate.issuer, func.count(Certificate.id)) + .group_by(Certificate.issuer) + .filter(Certificate.not_after <= end.format("YYYY-MM-DD")) + .filter(Certificate.not_after >= start.format("YYYY-MM-DD")) + .all() + ) else: - attr = getattr(Certificate, kwargs.get('metric')) + attr = getattr(Certificate, kwargs.get("metric")) query = database.db.session.query(attr, func.count(attr)) items = query.group_by(attr).all() @@ -519,7 +575,7 @@ def stats(**kwargs): keys.append(key) values.append(count) - return {'labels': keys, 'values': values} + return {"labels": keys, "values": values} def get_account_number(arn): @@ -566,22 +622,24 @@ def get_certificate_primitives(certificate): certificate via `create`. """ start, end = calculate_reissue_range(certificate.not_before, certificate.not_after) - ser = CertificateInputSchema().load(CertificateOutputSchema().dump(certificate).data) + ser = CertificateInputSchema().load( + CertificateOutputSchema().dump(certificate).data + ) assert not ser.errors, "Error re-serializing certificate: %s" % ser.errors data = ser.data # we can't quite tell if we are using a custom name, as this is an automated process (typically) # we will rely on the Lemur generated name - data.pop('name', None) + data.pop("name", None) # TODO this can be removed once we migrate away from cn - data['cn'] = data['common_name'] + data["cn"] = data["common_name"] # needed until we move off not_* - data['not_before'] = start - data['not_after'] = end - data['validity_start'] = start - data['validity_end'] = end + data["not_before"] = start + data["not_after"] = end + data["validity_start"] = start + data["validity_end"] = end return data @@ -599,13 +657,13 @@ def reissue_certificate(certificate, replace=None, user=None): # We do not want to re-use the CSR when creating a certificate because this defeats the purpose of rotation. del primitives["csr"] if not user: - primitives['creator'] = certificate.user + primitives["creator"] = certificate.user else: - primitives['creator'] = user + primitives["creator"] = user if replace: - primitives['replaces'] = [certificate] + primitives["replaces"] = [certificate] new_cert = create(**primitives) diff --git a/lemur/certificates/utils.py b/lemur/certificates/utils.py index 800e1201..4e6cc4f1 100644 --- a/lemur/certificates/utils.py +++ b/lemur/certificates/utils.py @@ -23,17 +23,18 @@ def get_sans_from_csr(data): """ sub_alt_names = [] try: - request = x509.load_pem_x509_csr(data.encode('utf-8'), default_backend()) + request = x509.load_pem_x509_csr(data.encode("utf-8"), default_backend()) except Exception: - raise ValidationError('CSR presented is not valid.') + raise ValidationError("CSR presented is not valid.") try: - alt_names = request.extensions.get_extension_for_class(x509.SubjectAlternativeName) + alt_names = request.extensions.get_extension_for_class( + x509.SubjectAlternativeName + ) for alt_name in alt_names.value: - sub_alt_names.append({ - 'nameType': type(alt_name).__name__, - 'value': alt_name.value - }) + sub_alt_names.append( + {"nameType": type(alt_name).__name__, "value": alt_name.value} + ) except x509.ExtensionNotFound: pass diff --git a/lemur/certificates/verify.py b/lemur/certificates/verify.py index d42e306c..76c6b521 100644 --- a/lemur/certificates/verify.py +++ b/lemur/certificates/verify.py @@ -29,31 +29,45 @@ def ocsp_verify(cert, cert_path, issuer_chain_path): :param issuer_chain_path: :return bool: True if certificate is valid, False otherwise """ - command = ['openssl', 'x509', '-noout', '-ocsp_uri', '-in', cert_path] + command = ["openssl", "x509", "-noout", "-ocsp_uri", "-in", cert_path] p1 = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) url, err = p1.communicate() if not url: - current_app.logger.debug("No OCSP URL in certificate {}".format(cert.serial_number)) + current_app.logger.debug( + "No OCSP URL in certificate {}".format(cert.serial_number) + ) return None - p2 = subprocess.Popen(['openssl', 'ocsp', '-issuer', issuer_chain_path, - '-cert', cert_path, "-url", url.strip()], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + p2 = subprocess.Popen( + [ + "openssl", + "ocsp", + "-issuer", + issuer_chain_path, + "-cert", + cert_path, + "-url", + url.strip(), + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) message, err = p2.communicate() - p_message = message.decode('utf-8') + p_message = message.decode("utf-8") - if 'error' in p_message or 'Error' in p_message: + if "error" in p_message or "Error" in p_message: raise Exception("Got error when parsing OCSP url") - elif 'revoked' in p_message: - current_app.logger.debug("OCSP reports certificate revoked: {}".format(cert.serial_number)) + elif "revoked" in p_message: + current_app.logger.debug( + "OCSP reports certificate revoked: {}".format(cert.serial_number) + ) return False - elif 'good' not in p_message: + elif "good" not in p_message: raise Exception("Did not receive a valid response") return True @@ -73,7 +87,9 @@ def crl_verify(cert, cert_path): x509.OID_CRL_DISTRIBUTION_POINTS ).value except x509.ExtensionNotFound: - current_app.logger.debug("No CRLDP extension in certificate {}".format(cert.serial_number)) + current_app.logger.debug( + "No CRLDP extension in certificate {}".format(cert.serial_number) + ) return None for p in distribution_points: @@ -92,8 +108,9 @@ def crl_verify(cert, cert_path): except ConnectionError: raise Exception("Unable to retrieve CRL: {0}".format(point)) - crl_cache[point] = x509.load_der_x509_crl(response.content, - backend=default_backend()) + crl_cache[point] = x509.load_der_x509_crl( + response.content, backend=default_backend() + ) else: current_app.logger.debug("CRL point is cached {}".format(point)) @@ -110,8 +127,9 @@ def crl_verify(cert, cert_path): except x509.ExtensionNotFound: pass - current_app.logger.debug("CRL reports certificate " - "revoked: {}".format(cert.serial_number)) + current_app.logger.debug( + "CRL reports certificate " "revoked: {}".format(cert.serial_number) + ) return False return True @@ -125,7 +143,7 @@ def verify(cert_path, issuer_chain_path): :param issuer_chain_path: :return: True if valid, False otherwise """ - with open(cert_path, 'rt') as c: + with open(cert_path, "rt") as c: try: cert = parse_certificate(c.read()) except ValueError as e: @@ -154,10 +172,10 @@ def verify_string(cert_string, issuer_string): :return: True if valid, False otherwise """ with mktempfile() as cert_tmp: - with open(cert_tmp, 'w') as f: + with open(cert_tmp, "w") as f: f.write(cert_string) with mktempfile() as issuer_tmp: - with open(issuer_tmp, 'w') as f: + with open(issuer_tmp, "w") as f: f.write(issuer_string) status = verify(cert_tmp, issuer_tmp) return status diff --git a/lemur/certificates/views.py b/lemur/certificates/views.py index 48f6d672..61a74a59 100644 --- a/lemur/certificates/views.py +++ b/lemur/certificates/views.py @@ -26,14 +26,14 @@ from lemur.certificates.schemas import ( certificate_upload_input_schema, certificates_output_schema, certificate_export_input_schema, - certificate_edit_input_schema + certificate_edit_input_schema, ) from lemur.roles import service as role_service from lemur.logs import service as log_service -mod = Blueprint('certificates', __name__) +mod = Blueprint("certificates", __name__) api = Api(mod) @@ -128,8 +128,8 @@ class CertificatesListValid(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['user'] = g.user - common_name = args['filter'].split(';')[1] + args["user"] = g.user + common_name = args["filter"].split(";")[1] return service.query_common_name(common_name, args) @@ -228,16 +228,18 @@ class CertificatesNameQuery(AuthenticatedResource): """ parser = paginated_parser.copy() - parser.add_argument('timeRange', type=int, dest='time_range', location='args') - parser.add_argument('owner', type=inputs.boolean, location='args') - parser.add_argument('id', type=str, location='args') - parser.add_argument('active', type=inputs.boolean, location='args') - parser.add_argument('destinationId', type=int, dest="destination_id", location='args') - parser.add_argument('creator', type=str, location='args') - parser.add_argument('show', type=str, location='args') + parser.add_argument("timeRange", type=int, dest="time_range", location="args") + parser.add_argument("owner", type=inputs.boolean, location="args") + parser.add_argument("id", type=str, location="args") + parser.add_argument("active", type=inputs.boolean, location="args") + parser.add_argument( + "destinationId", type=int, dest="destination_id", location="args" + ) + parser.add_argument("creator", type=str, location="args") + parser.add_argument("show", type=str, location="args") args = parser.parse_args() - args['user'] = g.user + args["user"] = g.user return service.query_name(certificate_name, args) @@ -336,16 +338,18 @@ class CertificatesList(AuthenticatedResource): """ parser = paginated_parser.copy() - parser.add_argument('timeRange', type=int, dest='time_range', location='args') - parser.add_argument('owner', type=inputs.boolean, location='args') - parser.add_argument('id', type=str, location='args') - parser.add_argument('active', type=inputs.boolean, location='args') - parser.add_argument('destinationId', type=int, dest="destination_id", location='args') - parser.add_argument('creator', type=str, location='args') - parser.add_argument('show', type=str, location='args') + parser.add_argument("timeRange", type=int, dest="time_range", location="args") + parser.add_argument("owner", type=inputs.boolean, location="args") + parser.add_argument("id", type=str, location="args") + parser.add_argument("active", type=inputs.boolean, location="args") + parser.add_argument( + "destinationId", type=int, dest="destination_id", location="args" + ) + parser.add_argument("creator", type=str, location="args") + parser.add_argument("show", type=str, location="args") args = parser.parse_args() - args['user'] = g.user + args["user"] = g.user return service.render(args) @validate_schema(certificate_input_schema, certificate_output_schema) @@ -463,24 +467,31 @@ class CertificatesList(AuthenticatedResource): :statuscode 403: unauthenticated """ - role = role_service.get_by_name(data['authority'].owner) + role = role_service.get_by_name(data["authority"].owner) # all the authority role members should be allowed - roles = [x.name for x in data['authority'].roles] + roles = [x.name for x in data["authority"].roles] # allow "owner" roles by team DL roles.append(role) - authority_permission = AuthorityPermission(data['authority'].id, roles) + authority_permission = AuthorityPermission(data["authority"].id, roles) if authority_permission.can(): - data['creator'] = g.user + data["creator"] = g.user cert = service.create(**data) if isinstance(cert, Certificate): # only log if created, not pending - log_service.create(g.user, 'create_cert', certificate=cert) + log_service.create(g.user, "create_cert", certificate=cert) return cert - return dict(message="You are not authorized to use the authority: {0}".format(data['authority'].name)), 403 + return ( + dict( + message="You are not authorized to use the authority: {0}".format( + data["authority"].name + ) + ), + 403, + ) class CertificatesUpload(AuthenticatedResource): @@ -583,12 +594,14 @@ class CertificatesUpload(AuthenticatedResource): :statuscode 200: no error """ - data['creator'] = g.user - if data.get('destinations'): - if data.get('private_key'): + data["creator"] = g.user + if data.get("destinations"): + if data.get("private_key"): return service.upload(**data) else: - raise Exception("Private key must be provided in order to upload certificate to AWS") + raise Exception( + "Private key must be provided in order to upload certificate to AWS" + ) return service.upload(**data) @@ -600,10 +613,12 @@ class CertificatesStats(AuthenticatedResource): super(CertificatesStats, self).__init__() def get(self): - self.reqparse.add_argument('metric', type=str, location='args') - self.reqparse.add_argument('range', default=32, type=int, location='args') - self.reqparse.add_argument('destinationId', dest='destination_id', location='args') - self.reqparse.add_argument('active', type=str, default='true', location='args') + self.reqparse.add_argument("metric", type=str, location="args") + self.reqparse.add_argument("range", default=32, type=int, location="args") + self.reqparse.add_argument( + "destinationId", dest="destination_id", location="args" + ) + self.reqparse.add_argument("active", type=str, default="true", location="args") args = self.reqparse.parse_args() @@ -655,12 +670,12 @@ class CertificatePrivateKey(AuthenticatedResource): permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) if not permission.can(): - return dict(message='You are not authorized to view this key'), 403 + return dict(message="You are not authorized to view this key"), 403 - log_service.create(g.current_user, 'key_view', certificate=cert) + log_service.create(g.current_user, "key_view", certificate=cert) response = make_response(jsonify(key=cert.private_key), 200) - response.headers['cache-control'] = 'private, max-age=0, no-cache, no-store' - response.headers['pragma'] = 'no-cache' + response.headers["cache-control"] = "private, max-age=0, no-cache, no-store" + response.headers["pragma"] = "no-cache" return response @@ -850,19 +865,25 @@ class Certificates(AuthenticatedResource): permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) if not permission.can(): - return dict(message='You are not authorized to update this certificate'), 403 + return ( + dict(message="You are not authorized to update this certificate"), + 403, + ) - for destination in data['destinations']: + for destination in data["destinations"]: if destination.plugin.requires_key: if not cert.private_key: - return dict( - message='Unable to add destination: {0}. Certificate does not have required private key.'.format( - destination.label - ) - ), 400 + return ( + dict( + message="Unable to add destination: {0}. Certificate does not have required private key.".format( + destination.label + ) + ), + 400, + ) cert = service.update(certificate_id, **data) - log_service.create(g.current_user, 'update_cert', certificate=cert) + log_service.create(g.current_user, "update_cert", certificate=cert) return cert def delete(self, certificate_id, data=None): @@ -891,7 +912,7 @@ class Certificates(AuthenticatedResource): :statuscode 405: certificate deletion is disabled """ - if not current_app.config.get('ALLOW_CERT_DELETION', False): + if not current_app.config.get("ALLOW_CERT_DELETION", False): return dict(message="Certificate deletion is disabled"), 405 cert = service.get(certificate_id) @@ -908,11 +929,14 @@ class Certificates(AuthenticatedResource): permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) if not permission.can(): - return dict(message='You are not authorized to delete this certificate'), 403 + return ( + dict(message="You are not authorized to delete this certificate"), + 403, + ) service.update(certificate_id, deleted=True) - log_service.create(g.current_user, 'delete_cert', certificate=cert) - return 'Certificate deleted', 204 + log_service.create(g.current_user, "delete_cert", certificate=cert) + return "Certificate deleted", 204 class NotificationCertificatesList(AuthenticatedResource): @@ -1012,17 +1036,19 @@ class NotificationCertificatesList(AuthenticatedResource): """ parser = paginated_parser.copy() - parser.add_argument('timeRange', type=int, dest='time_range', location='args') - parser.add_argument('owner', type=inputs.boolean, location='args') - parser.add_argument('id', type=str, location='args') - parser.add_argument('active', type=inputs.boolean, location='args') - parser.add_argument('destinationId', type=int, dest="destination_id", location='args') - parser.add_argument('creator', type=str, location='args') - parser.add_argument('show', type=str, location='args') + parser.add_argument("timeRange", type=int, dest="time_range", location="args") + parser.add_argument("owner", type=inputs.boolean, location="args") + parser.add_argument("id", type=str, location="args") + parser.add_argument("active", type=inputs.boolean, location="args") + parser.add_argument( + "destinationId", type=int, dest="destination_id", location="args" + ) + parser.add_argument("creator", type=str, location="args") + parser.add_argument("show", type=str, location="args") args = parser.parse_args() - args['notification_id'] = notification_id - args['user'] = g.current_user + args["notification_id"] = notification_id + args["user"] = g.current_user return service.render(args) @@ -1195,30 +1221,48 @@ class CertificateExport(AuthenticatedResource): if not cert: return dict(message="Cannot find specified certificate"), 404 - plugin = data['plugin']['plugin_object'] + plugin = data["plugin"]["plugin_object"] if plugin.requires_key: if not cert.private_key: - return dict( - message='Unable to export certificate, plugin: {0} requires a private key but no key was found.'.format( - plugin.slug)), 400 + return ( + dict( + message="Unable to export certificate, plugin: {0} requires a private key but no key was found.".format( + plugin.slug + ) + ), + 400, + ) else: # allow creators if g.current_user != cert.user: owner_role = role_service.get_by_name(cert.owner) - permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) + permission = CertificatePermission( + owner_role, [x.name for x in cert.roles] + ) if not permission.can(): - return dict(message='You are not authorized to export this certificate.'), 403 + return ( + dict( + message="You are not authorized to export this certificate." + ), + 403, + ) - options = data['plugin']['plugin_options'] + options = data["plugin"]["plugin_options"] - log_service.create(g.current_user, 'key_view', certificate=cert) - extension, passphrase, data = plugin.export(cert.body, cert.chain, cert.private_key, options) + log_service.create(g.current_user, "key_view", certificate=cert) + extension, passphrase, data = plugin.export( + cert.body, cert.chain, cert.private_key, options + ) # we take a hit in message size when b64 encoding - return dict(extension=extension, passphrase=passphrase, data=base64.b64encode(data).decode('utf-8')) + return dict( + extension=extension, + passphrase=passphrase, + data=base64.b64encode(data).decode("utf-8"), + ) class CertificateRevoke(AuthenticatedResource): @@ -1269,30 +1313,66 @@ class CertificateRevoke(AuthenticatedResource): permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) if not permission.can(): - return dict(message='You are not authorized to revoke this certificate.'), 403 + return ( + dict(message="You are not authorized to revoke this certificate."), + 403, + ) if not cert.external_id: - return dict(message='Cannot revoke certificate. No external id found.'), 400 + return dict(message="Cannot revoke certificate. No external id found."), 400 if cert.endpoints: - return dict(message='Cannot revoke certificate. Endpoints are deployed with the given certificate.'), 403 + return ( + dict( + message="Cannot revoke certificate. Endpoints are deployed with the given certificate." + ), + 403, + ) plugin = plugins.get(cert.authority.plugin_name) plugin.revoke_certificate(cert, data) - log_service.create(g.current_user, 'revoke_cert', certificate=cert) + log_service.create(g.current_user, "revoke_cert", certificate=cert) return dict(id=cert.id) -api.add_resource(CertificateRevoke, '/certificates//revoke', endpoint='revokeCertificate') -api.add_resource(CertificatesNameQuery, '/certificates/name/', endpoint='certificatesNameQuery') -api.add_resource(CertificatesList, '/certificates', endpoint='certificates') -api.add_resource(CertificatesListValid, '/certificates/valid', endpoint='certificatesListValid') -api.add_resource(Certificates, '/certificates/', endpoint='certificate') -api.add_resource(CertificatesStats, '/certificates/stats', endpoint='certificateStats') -api.add_resource(CertificatesUpload, '/certificates/upload', endpoint='certificateUpload') -api.add_resource(CertificatePrivateKey, '/certificates//key', endpoint='privateKeyCertificates') -api.add_resource(CertificateExport, '/certificates//export', endpoint='exportCertificate') -api.add_resource(NotificationCertificatesList, '/notifications//certificates', - endpoint='notificationCertificates') -api.add_resource(CertificatesReplacementsList, '/certificates//replacements', - endpoint='replacements') +api.add_resource( + CertificateRevoke, + "/certificates//revoke", + endpoint="revokeCertificate", +) +api.add_resource( + CertificatesNameQuery, + "/certificates/name/", + endpoint="certificatesNameQuery", +) +api.add_resource(CertificatesList, "/certificates", endpoint="certificates") +api.add_resource( + CertificatesListValid, "/certificates/valid", endpoint="certificatesListValid" +) +api.add_resource( + Certificates, "/certificates/", endpoint="certificate" +) +api.add_resource(CertificatesStats, "/certificates/stats", endpoint="certificateStats") +api.add_resource( + CertificatesUpload, "/certificates/upload", endpoint="certificateUpload" +) +api.add_resource( + CertificatePrivateKey, + "/certificates//key", + endpoint="privateKeyCertificates", +) +api.add_resource( + CertificateExport, + "/certificates//export", + endpoint="exportCertificate", +) +api.add_resource( + NotificationCertificatesList, + "/notifications//certificates", + endpoint="notificationCertificates", +) +api.add_resource( + CertificatesReplacementsList, + "/certificates//replacements", + endpoint="replacements", +) diff --git a/lemur/common/celery.py b/lemur/common/celery.py index 23eabddb..7eb1bb0d 100644 --- a/lemur/common/celery.py +++ b/lemur/common/celery.py @@ -32,8 +32,11 @@ else: def make_celery(app): - celery = Celery(app.import_name, backend=app.config.get('CELERY_RESULT_BACKEND'), - broker=app.config.get('CELERY_BROKER_URL')) + celery = Celery( + app.import_name, + backend=app.config.get("CELERY_RESULT_BACKEND"), + broker=app.config.get("CELERY_BROKER_URL"), + ) celery.conf.update(app.config) TaskBase = celery.Task @@ -53,6 +56,7 @@ celery = make_celery(flask_app) def is_task_active(fun, task_id, args): from celery.task.control import inspect + i = inspect() active_tasks = i.active() for _, tasks in active_tasks.items(): @@ -99,7 +103,7 @@ def fetch_acme_cert(id): # We only care about certs using the acme-issuer plugin for cert in pending_certs: cert_authority = get_authority(cert.authority_id) - if cert_authority.plugin_name == 'acme-issuer': + if cert_authority.plugin_name == "acme-issuer": acme_certs.append(cert) else: wrong_issuer += 1 @@ -112,20 +116,22 @@ def fetch_acme_cert(id): # It's necessary to reload the pending cert due to detached instance: http://sqlalche.me/e/bhk3 pending_cert = pending_certificate_service.get(cert.get("pending_cert").id) if not pending_cert: - log_data["message"] = "Pending certificate doesn't exist anymore. Was it resolved by another process?" + log_data[ + "message" + ] = "Pending certificate doesn't exist anymore. Was it resolved by another process?" current_app.logger.error(log_data) continue if real_cert: # If a real certificate was returned from issuer, then create it in Lemur and mark # the pending certificate as resolved - final_cert = pending_certificate_service.create_certificate(pending_cert, real_cert, pending_cert.user) - pending_certificate_service.update( - cert.get("pending_cert").id, - resolved_cert_id=final_cert.id + final_cert = pending_certificate_service.create_certificate( + pending_cert, real_cert, pending_cert.user ) pending_certificate_service.update( - cert.get("pending_cert").id, - resolved=True + cert.get("pending_cert").id, resolved_cert_id=final_cert.id + ) + pending_certificate_service.update( + cert.get("pending_cert").id, resolved=True ) # add metrics to metrics extension new += 1 @@ -139,17 +145,17 @@ def fetch_acme_cert(id): if pending_cert.number_attempts > 4: error_log["message"] = "Deleting pending certificate" - send_pending_failure_notification(pending_cert, notify_owner=pending_cert.notify) + send_pending_failure_notification( + pending_cert, notify_owner=pending_cert.notify + ) # Mark the pending cert as resolved pending_certificate_service.update( - cert.get("pending_cert").id, - resolved=True + cert.get("pending_cert").id, resolved=True ) else: pending_certificate_service.increment_attempt(pending_cert) pending_certificate_service.update( - cert.get("pending_cert").id, - status=str(cert.get("last_error")) + cert.get("pending_cert").id, status=str(cert.get("last_error")) ) # Add failed pending cert task back to queue fetch_acme_cert.delay(id) @@ -161,9 +167,7 @@ def fetch_acme_cert(id): current_app.logger.debug(log_data) print( "[+] Certificates: New: {new} Failed: {failed} Not using ACME: {wrong_issuer}".format( - new=new, - failed=failed, - wrong_issuer=wrong_issuer + new=new, failed=failed, wrong_issuer=wrong_issuer ) ) @@ -175,7 +179,7 @@ def fetch_all_pending_acme_certs(): log_data = { "function": "{}.{}".format(__name__, sys._getframe().f_code.co_name), - "message": "Starting job." + "message": "Starting job.", } current_app.logger.debug(log_data) @@ -183,7 +187,7 @@ def fetch_all_pending_acme_certs(): # We only care about certs using the acme-issuer plugin for cert in pending_certs: cert_authority = get_authority(cert.authority_id) - if cert_authority.plugin_name == 'acme-issuer': + if cert_authority.plugin_name == "acme-issuer": if datetime.now(timezone.utc) - cert.last_updated > timedelta(minutes=5): log_data["message"] = "Triggering job for cert {}".format(cert.name) log_data["cert_name"] = cert.name @@ -195,17 +199,15 @@ def fetch_all_pending_acme_certs(): @celery.task() def remove_old_acme_certs(): """Prune old pending acme certificates from the database""" - log_data = { - "function": "{}.{}".format(__name__, sys._getframe().f_code.co_name) - } - pending_certs = pending_certificate_service.get_pending_certs('all') + log_data = {"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name)} + pending_certs = pending_certificate_service.get_pending_certs("all") # Delete pending certs more than a week old for cert in pending_certs: if datetime.now(timezone.utc) - cert.last_updated > timedelta(days=7): - log_data['pending_cert_id'] = cert.id - log_data['pending_cert_name'] = cert.name - log_data['message'] = "Deleting pending certificate" + log_data["pending_cert_id"] = cert.id + log_data["pending_cert_name"] = cert.name + log_data["message"] = "Deleting pending certificate" current_app.logger.debug(log_data) pending_certificate_service.delete(cert) @@ -218,7 +220,9 @@ def clean_all_sources(): """ sources = validate_sources("all") for source in sources: - current_app.logger.debug("Creating celery task to clean source {}".format(source.label)) + current_app.logger.debug( + "Creating celery task to clean source {}".format(source.label) + ) clean_source.delay(source.label) @@ -242,7 +246,9 @@ def sync_all_sources(): """ sources = validate_sources("all") for source in sources: - current_app.logger.debug("Creating celery task to sync source {}".format(source.label)) + current_app.logger.debug( + "Creating celery task to sync source {}".format(source.label) + ) sync_source.delay(source.label) @@ -277,7 +283,9 @@ def sync_source(source): log_data["message"] = "Error syncing source: Time limit exceeded." current_app.logger.error(log_data) sentry.captureException() - metrics.send('sync_source_timeout', 'counter', 1, metric_tags={'source': source}) + metrics.send( + "sync_source_timeout", "counter", 1, metric_tags={"source": source} + ) return log_data["message"] = "Done syncing source" diff --git a/lemur/common/defaults.py b/lemur/common/defaults.py index 6b259f6b..d563dbd0 100644 --- a/lemur/common/defaults.py +++ b/lemur/common/defaults.py @@ -9,18 +9,20 @@ from lemur.extensions import sentry from lemur.constants import SAN_NAMING_TEMPLATE, DEFAULT_NAMING_TEMPLATE -def text_to_slug(value, joiner='-'): +def text_to_slug(value, joiner="-"): """ Normalize a string to a "slug" value, stripping character accents and removing non-alphanum characters. A series of non-alphanumeric characters is replaced with the joiner character. """ # Strip all character accents: decompose Unicode characters and then drop combining chars. - value = ''.join(c for c in unicodedata.normalize('NFKD', value) if not unicodedata.combining(c)) + value = "".join( + c for c in unicodedata.normalize("NFKD", value) if not unicodedata.combining(c) + ) # Replace all remaining non-alphanumeric characters with joiner string. Multiple characters get collapsed into a # single joiner. Except, keep 'xn--' used in IDNA domain names as is. - value = re.sub(r'[^A-Za-z0-9.]+(?' + return "" # Try Common Name or fall back to Organization name - attrs = (cert.issuer.get_attributes_for_oid(x509.OID_COMMON_NAME) or - cert.issuer.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME)) + attrs = cert.issuer.get_attributes_for_oid( + x509.OID_COMMON_NAME + ) or cert.issuer.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME) if not attrs: - current_app.logger.error("Unable to get issuer! Cert serial {:x}".format(cert.serial_number)) - return '' + current_app.logger.error( + "Unable to get issuer! Cert serial {:x}".format(cert.serial_number) + ) + return "" - return text_to_slug(attrs[0].value, '') + return text_to_slug(attrs[0].value, "") def not_before(cert): diff --git a/lemur/common/fields.py b/lemur/common/fields.py index 5ab0c6f0..15631832 100644 --- a/lemur/common/fields.py +++ b/lemur/common/fields.py @@ -25,6 +25,7 @@ class Hex(Field): """ A hex formatted string. """ + def _serialize(self, value, attr, obj): if value: value = hex(int(value))[2:].upper() @@ -48,25 +49,25 @@ class ArrowDateTime(Field): """ DATEFORMAT_SERIALIZATION_FUNCS = { - 'iso': utils.isoformat, - 'iso8601': utils.isoformat, - 'rfc': utils.rfcformat, - 'rfc822': utils.rfcformat, + "iso": utils.isoformat, + "iso8601": utils.isoformat, + "rfc": utils.rfcformat, + "rfc822": utils.rfcformat, } DATEFORMAT_DESERIALIZATION_FUNCS = { - 'iso': utils.from_iso, - 'iso8601': utils.from_iso, - 'rfc': utils.from_rfc, - 'rfc822': utils.from_rfc, + "iso": utils.from_iso, + "iso8601": utils.from_iso, + "rfc": utils.from_rfc, + "rfc822": utils.from_rfc, } - DEFAULT_FORMAT = 'iso' + DEFAULT_FORMAT = "iso" localtime = False default_error_messages = { - 'invalid': 'Not a valid datetime.', - 'format': '"{input}" cannot be formatted as a datetime.', + "invalid": "Not a valid datetime.", + "format": '"{input}" cannot be formatted as a datetime.', } def __init__(self, format=None, **kwargs): @@ -89,34 +90,36 @@ class ArrowDateTime(Field): try: return format_func(value, localtime=self.localtime) except (AttributeError, ValueError) as err: - self.fail('format', input=value) + self.fail("format", input=value) else: return value.strftime(self.dateformat) def _deserialize(self, value, attr, data): if not value: # Falsy values, e.g. '', None, [] are not valid - raise self.fail('invalid') + raise self.fail("invalid") self.dateformat = self.dateformat or self.DEFAULT_FORMAT func = self.DATEFORMAT_DESERIALIZATION_FUNCS.get(self.dateformat) if func: try: return arrow.get(func(value)) except (TypeError, AttributeError, ValueError): - raise self.fail('invalid') + raise self.fail("invalid") elif self.dateformat: try: return dt.datetime.strptime(value, self.dateformat) except (TypeError, AttributeError, ValueError): - raise self.fail('invalid') + raise self.fail("invalid") elif utils.dateutil_available: try: return arrow.get(utils.from_datestring(value)) except TypeError: - raise self.fail('invalid') + raise self.fail("invalid") else: - warnings.warn('It is recommended that you install python-dateutil ' - 'for improved datetime deserialization.') - raise self.fail('invalid') + warnings.warn( + "It is recommended that you install python-dateutil " + "for improved datetime deserialization." + ) + raise self.fail("invalid") class KeyUsageExtension(Field): @@ -131,73 +134,75 @@ class KeyUsageExtension(Field): def _serialize(self, value, attr, obj): return { - 'useDigitalSignature': value.digital_signature, - 'useNonRepudiation': value.content_commitment, - 'useKeyEncipherment': value.key_encipherment, - 'useDataEncipherment': value.data_encipherment, - 'useKeyAgreement': value.key_agreement, - 'useKeyCertSign': value.key_cert_sign, - 'useCRLSign': value.crl_sign, - 'useEncipherOnly': value._encipher_only, - 'useDecipherOnly': value._decipher_only + "useDigitalSignature": value.digital_signature, + "useNonRepudiation": value.content_commitment, + "useKeyEncipherment": value.key_encipherment, + "useDataEncipherment": value.data_encipherment, + "useKeyAgreement": value.key_agreement, + "useKeyCertSign": value.key_cert_sign, + "useCRLSign": value.crl_sign, + "useEncipherOnly": value._encipher_only, + "useDecipherOnly": value._decipher_only, } def _deserialize(self, value, attr, data): keyusages = { - 'digital_signature': False, - 'content_commitment': False, - 'key_encipherment': False, - 'data_encipherment': False, - 'key_agreement': False, - 'key_cert_sign': False, - 'crl_sign': False, - 'encipher_only': False, - 'decipher_only': False + "digital_signature": False, + "content_commitment": False, + "key_encipherment": False, + "data_encipherment": False, + "key_agreement": False, + "key_cert_sign": False, + "crl_sign": False, + "encipher_only": False, + "decipher_only": False, } for k, v in value.items(): - if k == 'useDigitalSignature': - keyusages['digital_signature'] = v + if k == "useDigitalSignature": + keyusages["digital_signature"] = v - elif k == 'useNonRepudiation': - keyusages['content_commitment'] = v + elif k == "useNonRepudiation": + keyusages["content_commitment"] = v - elif k == 'useKeyEncipherment': - keyusages['key_encipherment'] = v + elif k == "useKeyEncipherment": + keyusages["key_encipherment"] = v - elif k == 'useDataEncipherment': - keyusages['data_encipherment'] = v + elif k == "useDataEncipherment": + keyusages["data_encipherment"] = v - elif k == 'useKeyCertSign': - keyusages['key_cert_sign'] = v + elif k == "useKeyCertSign": + keyusages["key_cert_sign"] = v - elif k == 'useCRLSign': - keyusages['crl_sign'] = v + elif k == "useCRLSign": + keyusages["crl_sign"] = v - elif k == 'useKeyAgreement': - keyusages['key_agreement'] = v + elif k == "useKeyAgreement": + keyusages["key_agreement"] = v - elif k == 'useEncipherOnly' and v: - keyusages['encipher_only'] = True - keyusages['key_agreement'] = True + elif k == "useEncipherOnly" and v: + keyusages["encipher_only"] = True + keyusages["key_agreement"] = True - elif k == 'useDecipherOnly' and v: - keyusages['decipher_only'] = True - keyusages['key_agreement'] = True + elif k == "useDecipherOnly" and v: + keyusages["decipher_only"] = True + keyusages["key_agreement"] = True - if keyusages['encipher_only'] and keyusages['decipher_only']: - raise ValidationError('A certificate cannot have both Encipher Only and Decipher Only Extended Key Usages.') + if keyusages["encipher_only"] and keyusages["decipher_only"]: + raise ValidationError( + "A certificate cannot have both Encipher Only and Decipher Only Extended Key Usages." + ) return x509.KeyUsage( - digital_signature=keyusages['digital_signature'], - content_commitment=keyusages['content_commitment'], - key_encipherment=keyusages['key_encipherment'], - data_encipherment=keyusages['data_encipherment'], - key_agreement=keyusages['key_agreement'], - key_cert_sign=keyusages['key_cert_sign'], - crl_sign=keyusages['crl_sign'], - encipher_only=keyusages['encipher_only'], - decipher_only=keyusages['decipher_only'] + digital_signature=keyusages["digital_signature"], + content_commitment=keyusages["content_commitment"], + key_encipherment=keyusages["key_encipherment"], + data_encipherment=keyusages["data_encipherment"], + key_agreement=keyusages["key_agreement"], + key_cert_sign=keyusages["key_cert_sign"], + crl_sign=keyusages["crl_sign"], + encipher_only=keyusages["encipher_only"], + decipher_only=keyusages["decipher_only"], ) @@ -216,69 +221,77 @@ class ExtendedKeyUsageExtension(Field): usage_list = {} for usage in usages: if usage == x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH: - usage_list['useClientAuthentication'] = True + usage_list["useClientAuthentication"] = True elif usage == x509.oid.ExtendedKeyUsageOID.SERVER_AUTH: - usage_list['useServerAuthentication'] = True + usage_list["useServerAuthentication"] = True elif usage == x509.oid.ExtendedKeyUsageOID.CODE_SIGNING: - usage_list['useCodeSigning'] = True + usage_list["useCodeSigning"] = True elif usage == x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION: - usage_list['useEmailProtection'] = True + usage_list["useEmailProtection"] = True elif usage == x509.oid.ExtendedKeyUsageOID.TIME_STAMPING: - usage_list['useTimestamping'] = True + usage_list["useTimestamping"] = True elif usage == x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING: - usage_list['useOCSPSigning'] = True + usage_list["useOCSPSigning"] = True - elif usage.dotted_string == '1.3.6.1.5.5.7.3.14': - usage_list['useEapOverLAN'] = True + elif usage.dotted_string == "1.3.6.1.5.5.7.3.14": + usage_list["useEapOverLAN"] = True - elif usage.dotted_string == '1.3.6.1.5.5.7.3.13': - usage_list['useEapOverPPP'] = True + elif usage.dotted_string == "1.3.6.1.5.5.7.3.13": + usage_list["useEapOverPPP"] = True - elif usage.dotted_string == '1.3.6.1.4.1.311.20.2.2': - usage_list['useSmartCardLogon'] = True + elif usage.dotted_string == "1.3.6.1.4.1.311.20.2.2": + usage_list["useSmartCardLogon"] = True else: - current_app.logger.warning('Unable to serialize ExtendedKeyUsage with OID: {usage}'.format(usage=usage.dotted_string)) + current_app.logger.warning( + "Unable to serialize ExtendedKeyUsage with OID: {usage}".format( + usage=usage.dotted_string + ) + ) return usage_list def _deserialize(self, value, attr, data): usage_oids = [] for k, v in value.items(): - if k == 'useClientAuthentication' and v: + if k == "useClientAuthentication" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH) - elif k == 'useServerAuthentication' and v: + elif k == "useServerAuthentication" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.SERVER_AUTH) - elif k == 'useCodeSigning' and v: + elif k == "useCodeSigning" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.CODE_SIGNING) - elif k == 'useEmailProtection' and v: + elif k == "useEmailProtection" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION) - elif k == 'useTimestamping' and v: + elif k == "useTimestamping" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.TIME_STAMPING) - elif k == 'useOCSPSigning' and v: + elif k == "useOCSPSigning" and v: usage_oids.append(x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING) - elif k == 'useEapOverLAN' and v: + elif k == "useEapOverLAN" and v: usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.14")) - elif k == 'useEapOverPPP' and v: + elif k == "useEapOverPPP" and v: usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.13")) - elif k == 'useSmartCardLogon' and v: + elif k == "useSmartCardLogon" and v: usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.4.1.311.20.2.2")) else: - current_app.logger.warning('Unable to deserialize ExtendedKeyUsage with name: {key}'.format(key=k)) + current_app.logger.warning( + "Unable to deserialize ExtendedKeyUsage with name: {key}".format( + key=k + ) + ) return x509.ExtendedKeyUsage(usage_oids) @@ -294,15 +307,17 @@ class BasicConstraintsExtension(Field): """ def _serialize(self, value, attr, obj): - return {'ca': value.ca, 'path_length': value.path_length} + return {"ca": value.ca, "path_length": value.path_length} def _deserialize(self, value, attr, data): - ca = value.get('ca', False) - path_length = value.get('path_length', None) + ca = value.get("ca", False) + path_length = value.get("path_length", None) if ca: if not isinstance(path_length, (type(None), int)): - raise ValidationError('A CA certificate path_length (for BasicConstraints) must be None or an integer.') + raise ValidationError( + "A CA certificate path_length (for BasicConstraints) must be None or an integer." + ) return x509.BasicConstraints(ca=True, path_length=path_length) else: return x509.BasicConstraints(ca=False, path_length=None) @@ -317,6 +332,7 @@ class SubjectAlternativeNameExtension(Field): :param kwargs: The same keyword arguments that :class:`Field` receives. """ + def _serialize(self, value, attr, obj): general_names = [] name_type = None @@ -326,53 +342,59 @@ class SubjectAlternativeNameExtension(Field): value = name.value if isinstance(name, x509.DNSName): - name_type = 'DNSName' + name_type = "DNSName" elif isinstance(name, x509.IPAddress): if isinstance(value, ipaddress.IPv4Network): - name_type = 'IPNetwork' + name_type = "IPNetwork" else: - name_type = 'IPAddress' + name_type = "IPAddress" value = str(value) elif isinstance(name, x509.UniformResourceIdentifier): - name_type = 'uniformResourceIdentifier' + name_type = "uniformResourceIdentifier" elif isinstance(name, x509.DirectoryName): - name_type = 'directoryName' + name_type = "directoryName" elif isinstance(name, x509.RFC822Name): - name_type = 'rfc822Name' + name_type = "rfc822Name" elif isinstance(name, x509.RegisteredID): - name_type = 'registeredID' + name_type = "registeredID" value = value.dotted_string else: - current_app.logger.warning('Unknown SubAltName type: {name}'.format(name=name)) + current_app.logger.warning( + "Unknown SubAltName type: {name}".format(name=name) + ) continue - general_names.append({'nameType': name_type, 'value': value}) + general_names.append({"nameType": name_type, "value": value}) return general_names def _deserialize(self, value, attr, data): general_names = [] for name in value: - if name['nameType'] == 'DNSName': - validators.sensitive_domain(name['value']) - general_names.append(x509.DNSName(name['value'])) + if name["nameType"] == "DNSName": + validators.sensitive_domain(name["value"]) + general_names.append(x509.DNSName(name["value"])) - elif name['nameType'] == 'IPAddress': - general_names.append(x509.IPAddress(ipaddress.ip_address(name['value']))) + elif name["nameType"] == "IPAddress": + general_names.append( + x509.IPAddress(ipaddress.ip_address(name["value"])) + ) - elif name['nameType'] == 'IPNetwork': - general_names.append(x509.IPAddress(ipaddress.ip_network(name['value']))) + elif name["nameType"] == "IPNetwork": + general_names.append( + x509.IPAddress(ipaddress.ip_network(name["value"])) + ) - elif name['nameType'] == 'uniformResourceIdentifier': - general_names.append(x509.UniformResourceIdentifier(name['value'])) + elif name["nameType"] == "uniformResourceIdentifier": + general_names.append(x509.UniformResourceIdentifier(name["value"])) - elif name['nameType'] == 'directoryName': + elif name["nameType"] == "directoryName": # TODO: Need to parse a string in name['value'] like: # 'CN=Common Name, O=Org Name, OU=OrgUnit Name, C=US, ST=ST, L=City/emailAddress=person@example.com' # or @@ -390,26 +412,32 @@ class SubjectAlternativeNameExtension(Field): # general_names.append(x509.DirectoryName(x509.Name(BLAH)))) pass - elif name['nameType'] == 'rfc822Name': - general_names.append(x509.RFC822Name(name['value'])) + elif name["nameType"] == "rfc822Name": + general_names.append(x509.RFC822Name(name["value"])) - elif name['nameType'] == 'registeredID': - general_names.append(x509.RegisteredID(x509.ObjectIdentifier(name['value']))) + elif name["nameType"] == "registeredID": + general_names.append( + x509.RegisteredID(x509.ObjectIdentifier(name["value"])) + ) - elif name['nameType'] == 'otherName': + elif name["nameType"] == "otherName": # This has two inputs (type and value), so it doesn't fit the mold of the rest of these GeneralName entities. # general_names.append(x509.OtherName(name['type'], bytes(name['value']), 'utf-8')) pass - elif name['nameType'] == 'x400Address': + elif name["nameType"] == "x400Address": # The Python Cryptography library doesn't support x400Address types (yet?) pass - elif name['nameType'] == 'EDIPartyName': + elif name["nameType"] == "EDIPartyName": # The Python Cryptography library doesn't support EDIPartyName types (yet?) pass else: - current_app.logger.warning('Unable to deserialize SubAltName with type: {name_type}'.format(name_type=name['nameType'])) + current_app.logger.warning( + "Unable to deserialize SubAltName with type: {name_type}".format( + name_type=name["nameType"] + ) + ) return x509.SubjectAlternativeName(general_names) diff --git a/lemur/common/health.py b/lemur/common/health.py index 69df3f0c..7e0a17ff 100644 --- a/lemur/common/health.py +++ b/lemur/common/health.py @@ -10,20 +10,20 @@ from flask import Blueprint from lemur.database import db from lemur.extensions import sentry -mod = Blueprint('healthCheck', __name__) +mod = Blueprint("healthCheck", __name__) -@mod.route('/healthcheck') +@mod.route("/healthcheck") def health(): try: if healthcheck(db): - return 'ok' + return "ok" except Exception: sentry.captureException() - return 'db check failed' + return "db check failed" def healthcheck(db): with db.engine.connect() as connection: - connection.execute('SELECT 1;') + connection.execute("SELECT 1;") return True diff --git a/lemur/common/managers.py b/lemur/common/managers.py index 9f30f216..6ce2608f 100644 --- a/lemur/common/managers.py +++ b/lemur/common/managers.py @@ -52,7 +52,7 @@ class InstanceManager(object): results = [] for cls_path in class_list: - module_name, class_name = cls_path.rsplit('.', 1) + module_name, class_name = cls_path.rsplit(".", 1) try: module = __import__(module_name, {}, {}, class_name) cls = getattr(module, class_name) @@ -62,10 +62,14 @@ class InstanceManager(object): results.append(cls) except InvalidConfiguration as e: - current_app.logger.warning("Plugin '{0}' may not work correctly. {1}".format(class_name, e)) + current_app.logger.warning( + "Plugin '{0}' may not work correctly. {1}".format(class_name, e) + ) except Exception as e: - current_app.logger.exception("Unable to import {0}. Reason: {1}".format(cls_path, e)) + current_app.logger.exception( + "Unable to import {0}. Reason: {1}".format(cls_path, e) + ) continue self.cache = results diff --git a/lemur/common/missing.py b/lemur/common/missing.py index 5c7dffac..2f5156df 100644 --- a/lemur/common/missing.py +++ b/lemur/common/missing.py @@ -11,15 +11,15 @@ def convert_validity_years(data): :param data: :return: """ - if data.get('validity_years'): + if data.get("validity_years"): now = arrow.utcnow() - data['validity_start'] = now.isoformat() + data["validity_start"] = now.isoformat() - end = now.replace(years=+int(data['validity_years'])) + end = now.replace(years=+int(data["validity_years"])) - if not current_app.config.get('LEMUR_ALLOW_WEEKEND_EXPIRATION', True): + if not current_app.config.get("LEMUR_ALLOW_WEEKEND_EXPIRATION", True): if is_weekend(end): end = end.replace(days=-2) - data['validity_end'] = end.isoformat() + data["validity_end"] = end.isoformat() return data diff --git a/lemur/common/schema.py b/lemur/common/schema.py index ee765dc4..bfa0a091 100644 --- a/lemur/common/schema.py +++ b/lemur/common/schema.py @@ -22,27 +22,26 @@ class LemurSchema(Schema): """ Base schema from which all grouper schema's inherit """ + __envelope__ = True def under(self, data, many=None): items = [] if many: for i in data: - items.append( - {underscore(key): value for key, value in i.items()} - ) + items.append({underscore(key): value for key, value in i.items()}) return items - return { - underscore(key): value - for key, value in data.items() - } + return {underscore(key): value for key, value in data.items()} def camel(self, data, many=None): items = [] if many: for i in data: items.append( - {camelize(key, uppercase_first_letter=False): value for key, value in i.items()} + { + camelize(key, uppercase_first_letter=False): value + for key, value in i.items() + } ) return items return { @@ -52,16 +51,16 @@ class LemurSchema(Schema): def wrap_with_envelope(self, data, many): if many: - if 'total' in self.context.keys(): - return dict(total=self.context['total'], items=data) + if "total" in self.context.keys(): + return dict(total=self.context["total"], items=data) return data class LemurInputSchema(LemurSchema): @pre_load(pass_many=True) def preprocess(self, data, many): - if isinstance(data, dict) and data.get('owner'): - data['owner'] = data['owner'].lower() + if isinstance(data, dict) and data.get("owner"): + data["owner"] = data["owner"].lower() return self.under(data, many=many) @@ -74,17 +73,17 @@ class LemurOutputSchema(LemurSchema): def unwrap_envelope(self, data, many): if many: - if data['items']: + if data["items"]: if isinstance(data, InstrumentedList) or isinstance(data, list): - self.context['total'] = len(data) + self.context["total"] = len(data) return data else: - self.context['total'] = data['total'] + self.context["total"] = data["total"] else: - self.context['total'] = 0 - data = {'items': []} + self.context["total"] = 0 + data = {"items": []} - return data['items'] + return data["items"] return data @@ -110,11 +109,11 @@ def format_errors(messages): def wrap_errors(messages): - errors = dict(message='Validation Error.') - if messages.get('_schema'): - errors['reasons'] = {'Schema': {'rule': messages['_schema']}} + errors = dict(message="Validation Error.") + if messages.get("_schema"): + errors["reasons"] = {"Schema": {"rule": messages["_schema"]}} else: - errors['reasons'] = format_errors(messages) + errors["reasons"] = format_errors(messages) return errors @@ -123,19 +122,19 @@ def unwrap_pagination(data, output_schema): return data if isinstance(data, dict): - if 'total' in data.keys(): - if data.get('total') == 0: + if "total" in data.keys(): + if data.get("total") == 0: return data - marshaled_data = {'total': data['total']} - marshaled_data['items'] = output_schema.dump(data['items'], many=True).data + marshaled_data = {"total": data["total"]} + marshaled_data["items"] = output_schema.dump(data["items"], many=True).data return marshaled_data return output_schema.dump(data).data elif isinstance(data, list): - marshaled_data = {'total': len(data)} - marshaled_data['items'] = output_schema.dump(data, many=True).data + marshaled_data = {"total": len(data)} + marshaled_data["items"] = output_schema.dump(data, many=True).data return marshaled_data return output_schema.dump(data).data @@ -155,7 +154,7 @@ def validate_schema(input_schema, output_schema): if errors: return wrap_errors(errors), 400 - kwargs['data'] = data + kwargs["data"] = data try: resp = f(*args, **kwargs) @@ -173,4 +172,5 @@ def validate_schema(input_schema, output_schema): return unwrap_pagination(resp, output_schema), 200 return decorated_function + return decorator diff --git a/lemur/common/utils.py b/lemur/common/utils.py index 40f828f3..c33722b2 100644 --- a/lemur/common/utils.py +++ b/lemur/common/utils.py @@ -25,22 +25,22 @@ from lemur.exceptions import InvalidConfiguration paginated_parser = RequestParser() -paginated_parser.add_argument('count', type=int, default=10, location='args') -paginated_parser.add_argument('page', type=int, default=1, location='args') -paginated_parser.add_argument('sortDir', type=str, dest='sort_dir', location='args') -paginated_parser.add_argument('sortBy', type=str, dest='sort_by', location='args') -paginated_parser.add_argument('filter', type=str, location='args') -paginated_parser.add_argument('owner', type=str, location='args') +paginated_parser.add_argument("count", type=int, default=10, location="args") +paginated_parser.add_argument("page", type=int, default=1, location="args") +paginated_parser.add_argument("sortDir", type=str, dest="sort_dir", location="args") +paginated_parser.add_argument("sortBy", type=str, dest="sort_by", location="args") +paginated_parser.add_argument("filter", type=str, location="args") +paginated_parser.add_argument("owner", type=str, location="args") def get_psuedo_random_string(): """ Create a random and strongish challenge. """ - challenge = ''.join(random.choice(string.ascii_uppercase) for x in range(6)) # noqa - challenge += ''.join(random.choice("~!@#$%^&*()_+") for x in range(6)) # noqa - challenge += ''.join(random.choice(string.ascii_lowercase) for x in range(6)) - challenge += ''.join(random.choice(string.digits) for x in range(6)) # noqa + challenge = "".join(random.choice(string.ascii_uppercase) for x in range(6)) # noqa + challenge += "".join(random.choice("~!@#$%^&*()_+") for x in range(6)) # noqa + challenge += "".join(random.choice(string.ascii_lowercase) for x in range(6)) + challenge += "".join(random.choice(string.digits) for x in range(6)) # noqa return challenge @@ -53,7 +53,7 @@ def parse_certificate(body): """ assert isinstance(body, str) - return x509.load_pem_x509_certificate(body.encode('utf-8'), default_backend()) + return x509.load_pem_x509_certificate(body.encode("utf-8"), default_backend()) def parse_private_key(private_key): @@ -66,7 +66,9 @@ def parse_private_key(private_key): """ assert isinstance(private_key, str) - return load_pem_private_key(private_key.encode('utf8'), password=None, backend=default_backend()) + return load_pem_private_key( + private_key.encode("utf8"), password=None, backend=default_backend() + ) def split_pem(data): @@ -100,14 +102,15 @@ def parse_csr(csr): """ assert isinstance(csr, str) - return x509.load_pem_x509_csr(csr.encode('utf-8'), default_backend()) + return x509.load_pem_x509_csr(csr.encode("utf-8"), default_backend()) def get_authority_key(body): """Returns the authority key for a given certificate in hex format""" parsed_cert = parse_certificate(body) authority_key = parsed_cert.extensions.get_extension_for_class( - x509.AuthorityKeyIdentifier).value.key_identifier + x509.AuthorityKeyIdentifier + ).value.key_identifier return authority_key.hex() @@ -127,20 +130,17 @@ def generate_private_key(key_type): _CURVE_TYPES = { "ECCPRIME192V1": ec.SECP192R1(), "ECCPRIME256V1": ec.SECP256R1(), - "ECCSECP192R1": ec.SECP192R1(), "ECCSECP224R1": ec.SECP224R1(), "ECCSECP256R1": ec.SECP256R1(), "ECCSECP384R1": ec.SECP384R1(), "ECCSECP521R1": ec.SECP521R1(), "ECCSECP256K1": ec.SECP256K1(), - "ECCSECT163K1": ec.SECT163K1(), "ECCSECT233K1": ec.SECT233K1(), "ECCSECT283K1": ec.SECT283K1(), "ECCSECT409K1": ec.SECT409K1(), "ECCSECT571K1": ec.SECT571K1(), - "ECCSECT163R2": ec.SECT163R2(), "ECCSECT233R1": ec.SECT233R1(), "ECCSECT283R1": ec.SECT283R1(), @@ -149,22 +149,20 @@ def generate_private_key(key_type): } if key_type not in CERTIFICATE_KEY_TYPES: - raise Exception("Invalid key type: {key_type}. Supported key types: {choices}".format( - key_type=key_type, - choices=",".join(CERTIFICATE_KEY_TYPES) - )) + raise Exception( + "Invalid key type: {key_type}. Supported key types: {choices}".format( + key_type=key_type, choices=",".join(CERTIFICATE_KEY_TYPES) + ) + ) - if 'RSA' in key_type: + if "RSA" in key_type: key_size = int(key_type[3:]) return rsa.generate_private_key( - public_exponent=65537, - key_size=key_size, - backend=default_backend() + public_exponent=65537, key_size=key_size, backend=default_backend() ) - elif 'ECC' in key_type: + elif "ECC" in key_type: return ec.generate_private_key( - curve=_CURVE_TYPES[key_type], - backend=default_backend() + curve=_CURVE_TYPES[key_type], backend=default_backend() ) @@ -184,11 +182,26 @@ def check_cert_signature(cert, issuer_public_key): raise UnsupportedAlgorithm("RSASSA-PSS not supported") else: padder = padding.PKCS1v15() - issuer_public_key.verify(cert.signature, cert.tbs_certificate_bytes, padder, cert.signature_hash_algorithm) - elif isinstance(issuer_public_key, ec.EllipticCurvePublicKey) and isinstance(ec.ECDSA(cert.signature_hash_algorithm), ec.ECDSA): - issuer_public_key.verify(cert.signature, cert.tbs_certificate_bytes, ec.ECDSA(cert.signature_hash_algorithm)) + issuer_public_key.verify( + cert.signature, + cert.tbs_certificate_bytes, + padder, + cert.signature_hash_algorithm, + ) + elif isinstance(issuer_public_key, ec.EllipticCurvePublicKey) and isinstance( + ec.ECDSA(cert.signature_hash_algorithm), ec.ECDSA + ): + issuer_public_key.verify( + cert.signature, + cert.tbs_certificate_bytes, + ec.ECDSA(cert.signature_hash_algorithm), + ) else: - raise UnsupportedAlgorithm("Unsupported Algorithm '{var}'.".format(var=cert.signature_algorithm_oid._name)) + raise UnsupportedAlgorithm( + "Unsupported Algorithm '{var}'.".format( + var=cert.signature_algorithm_oid._name + ) + ) def is_selfsigned(cert): @@ -224,7 +237,9 @@ def validate_conf(app, required_vars): """ for var in required_vars: if var not in app.config: - raise InvalidConfiguration("Required variable '{var}' is not set in Lemur's conf.".format(var=var)) + raise InvalidConfiguration( + "Required variable '{var}' is not set in Lemur's conf.".format(var=var) + ) # https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/WindowedRangeQuery @@ -243,18 +258,15 @@ def column_windows(session, column, windowsize): be computed. """ + def int_for_range(start_id, end_id): if end_id: - return and_( - column >= start_id, - column < end_id - ) + return and_(column >= start_id, column < end_id) else: return column >= start_id q = session.query( - column, - func.row_number().over(order_by=column).label('rownum') + column, func.row_number().over(order_by=column).label("rownum") ).from_self(column) if windowsize > 1: @@ -274,9 +286,7 @@ def column_windows(session, column, windowsize): def windowed_query(q, column, windowsize): """"Break a Query into windows on a given column.""" - for whereclause in column_windows( - q.session, - column, windowsize): + for whereclause in column_windows(q.session, column, windowsize): for row in q.filter(whereclause).order_by(column): yield row @@ -284,7 +294,7 @@ def windowed_query(q, column, windowsize): def truthiness(s): """If input string resembles something truthy then return True, else False.""" - return s.lower() in ('true', 'yes', 'on', 't', '1') + return s.lower() in ("true", "yes", "on", "t", "1") def find_matching_certificates_by_hash(cert, matching_certs): @@ -292,6 +302,8 @@ def find_matching_certificates_by_hash(cert, matching_certs): determine if any of the certificate hashes match and return the matches.""" matching = [] for c in matching_certs: - if parse_certificate(c.body).fingerprint(hashes.SHA256()) == cert.fingerprint(hashes.SHA256()): + if parse_certificate(c.body).fingerprint(hashes.SHA256()) == cert.fingerprint( + hashes.SHA256() + ): matching.append(c) return matching diff --git a/lemur/common/validators.py b/lemur/common/validators.py index 91b831ba..3e6ebcf9 100644 --- a/lemur/common/validators.py +++ b/lemur/common/validators.py @@ -16,7 +16,7 @@ def common_name(value): # Common name could be a domain name, or a human-readable name of the subject (often used in CA names or client # certificates). As a simple heuristic, we assume that human-readable names always include a space. # However, to avoid confusion for humans, we also don't count spaces at the beginning or end of the string. - if ' ' not in value.strip(): + if " " not in value.strip(): return sensitive_domain(value) @@ -30,17 +30,21 @@ def sensitive_domain(domain): # User has permission, no need to check anything return - whitelist = current_app.config.get('LEMUR_WHITELISTED_DOMAINS', []) + whitelist = current_app.config.get("LEMUR_WHITELISTED_DOMAINS", []) if whitelist and not any(re.match(pattern, domain) for pattern in whitelist): - raise ValidationError('Domain {0} does not match whitelisted domain patterns. ' - 'Contact an administrator to issue the certificate.'.format(domain)) + raise ValidationError( + "Domain {0} does not match whitelisted domain patterns. " + "Contact an administrator to issue the certificate.".format(domain) + ) # Avoid circular import. from lemur.domains import service as domain_service if any(d.sensitive for d in domain_service.get_by_name(domain)): - raise ValidationError('Domain {0} has been marked as sensitive. ' - 'Contact an administrator to issue the certificate.'.format(domain)) + raise ValidationError( + "Domain {0} has been marked as sensitive. " + "Contact an administrator to issue the certificate.".format(domain) + ) def encoding(oid_encoding): @@ -49,9 +53,13 @@ def encoding(oid_encoding): :param oid_encoding: :return: """ - valid_types = ['b64asn1', 'string', 'ia5string'] + valid_types = ["b64asn1", "string", "ia5string"] if oid_encoding.lower() not in [o_type.lower() for o_type in valid_types]: - raise ValidationError('Invalid Oid Encoding: {0} choose from {1}'.format(oid_encoding, ",".join(valid_types))) + raise ValidationError( + "Invalid Oid Encoding: {0} choose from {1}".format( + oid_encoding, ",".join(valid_types) + ) + ) def sub_alt_type(alt_type): @@ -60,10 +68,23 @@ def sub_alt_type(alt_type): :param alt_type: :return: """ - valid_types = ['DNSName', 'IPAddress', 'uniFormResourceIdentifier', 'directoryName', 'rfc822Name', 'registrationID', - 'otherName', 'x400Address', 'EDIPartyName'] + valid_types = [ + "DNSName", + "IPAddress", + "uniFormResourceIdentifier", + "directoryName", + "rfc822Name", + "registrationID", + "otherName", + "x400Address", + "EDIPartyName", + ] if alt_type.lower() not in [a_type.lower() for a_type in valid_types]: - raise ValidationError('Invalid SubAltName Type: {0} choose from {1}'.format(type, ",".join(valid_types))) + raise ValidationError( + "Invalid SubAltName Type: {0} choose from {1}".format( + type, ",".join(valid_types) + ) + ) def csr(data): @@ -73,16 +94,18 @@ def csr(data): :return: """ try: - request = x509.load_pem_x509_csr(data.encode('utf-8'), default_backend()) + request = x509.load_pem_x509_csr(data.encode("utf-8"), default_backend()) except Exception: - raise ValidationError('CSR presented is not valid.') + raise ValidationError("CSR presented is not valid.") # Validate common name and SubjectAltNames for name in request.subject.get_attributes_for_oid(NameOID.COMMON_NAME): common_name(name.value) try: - alt_names = request.extensions.get_extension_for_class(x509.SubjectAlternativeName) + alt_names = request.extensions.get_extension_for_class( + x509.SubjectAlternativeName + ) for name in alt_names.value.get_values_for_type(x509.DNSName): sensitive_domain(name) @@ -91,26 +114,40 @@ def csr(data): def dates(data): - if not data.get('validity_start') and data.get('validity_end'): - raise ValidationError('If validity start is specified so must validity end.') + if not data.get("validity_start") and data.get("validity_end"): + raise ValidationError("If validity start is specified so must validity end.") - if not data.get('validity_end') and data.get('validity_start'): - raise ValidationError('If validity end is specified so must validity start.') + if not data.get("validity_end") and data.get("validity_start"): + raise ValidationError("If validity end is specified so must validity start.") - if data.get('validity_start') and data.get('validity_end'): - if not current_app.config.get('LEMUR_ALLOW_WEEKEND_EXPIRATION', True): - if is_weekend(data.get('validity_end')): - raise ValidationError('Validity end must not land on a weekend.') + if data.get("validity_start") and data.get("validity_end"): + if not current_app.config.get("LEMUR_ALLOW_WEEKEND_EXPIRATION", True): + if is_weekend(data.get("validity_end")): + raise ValidationError("Validity end must not land on a weekend.") - if not data['validity_start'] < data['validity_end']: - raise ValidationError('Validity start must be before validity end.') + if not data["validity_start"] < data["validity_end"]: + raise ValidationError("Validity start must be before validity end.") - if data.get('authority'): - if data.get('validity_start').date() < data['authority'].authority_certificate.not_before.date(): - raise ValidationError('Validity start must not be before {0}'.format(data['authority'].authority_certificate.not_before)) + if data.get("authority"): + if ( + data.get("validity_start").date() + < data["authority"].authority_certificate.not_before.date() + ): + raise ValidationError( + "Validity start must not be before {0}".format( + data["authority"].authority_certificate.not_before + ) + ) - if data.get('validity_end').date() > data['authority'].authority_certificate.not_after.date(): - raise ValidationError('Validity end must not be after {0}'.format(data['authority'].authority_certificate.not_after)) + if ( + data.get("validity_end").date() + > data["authority"].authority_certificate.not_after.date() + ): + raise ValidationError( + "Validity end must not be after {0}".format( + data["authority"].authority_certificate.not_after + ) + ) return data @@ -148,8 +185,13 @@ def verify_cert_chain(certs, error_class=ValidationError): # Avoid circular import. from lemur.common import defaults - raise error_class("Incorrect chain certificate(s) provided: '%s' is not signed by '%s'" - % (defaults.common_name(cert) or 'Unknown', defaults.common_name(issuer))) + raise error_class( + "Incorrect chain certificate(s) provided: '%s' is not signed by '%s'" + % ( + defaults.common_name(cert) or "Unknown", + defaults.common_name(issuer), + ) + ) except UnsupportedAlgorithm as err: current_app.logger.warning("Skipping chain validation: %s", err) diff --git a/lemur/constants.py b/lemur/constants.py index 060ecfed..cc1653cb 100644 --- a/lemur/constants.py +++ b/lemur/constants.py @@ -7,28 +7,28 @@ SAN_NAMING_TEMPLATE = "SAN-{subject}-{issuer}-{not_before}-{not_after}" DEFAULT_NAMING_TEMPLATE = "{subject}-{issuer}-{not_before}-{not_after}" NONSTANDARD_NAMING_TEMPLATE = "{issuer}-{not_before}-{not_after}" -SUCCESS_METRIC_STATUS = 'success' -FAILURE_METRIC_STATUS = 'failure' +SUCCESS_METRIC_STATUS = "success" +FAILURE_METRIC_STATUS = "failure" CERTIFICATE_KEY_TYPES = [ - 'RSA2048', - 'RSA4096', - 'ECCPRIME192V1', - 'ECCPRIME256V1', - 'ECCSECP192R1', - 'ECCSECP224R1', - 'ECCSECP256R1', - 'ECCSECP384R1', - 'ECCSECP521R1', - 'ECCSECP256K1', - 'ECCSECT163K1', - 'ECCSECT233K1', - 'ECCSECT283K1', - 'ECCSECT409K1', - 'ECCSECT571K1', - 'ECCSECT163R2', - 'ECCSECT233R1', - 'ECCSECT283R1', - 'ECCSECT409R1', - 'ECCSECT571R2' + "RSA2048", + "RSA4096", + "ECCPRIME192V1", + "ECCPRIME256V1", + "ECCSECP192R1", + "ECCSECP224R1", + "ECCSECP256R1", + "ECCSECP384R1", + "ECCSECP521R1", + "ECCSECP256K1", + "ECCSECT163K1", + "ECCSECT233K1", + "ECCSECT283K1", + "ECCSECT409K1", + "ECCSECT571K1", + "ECCSECT163R2", + "ECCSECT233R1", + "ECCSECT283R1", + "ECCSECT409R1", + "ECCSECT571R2", ] diff --git a/lemur/database.py b/lemur/database.py index 82fb0423..a9610325 100644 --- a/lemur/database.py +++ b/lemur/database.py @@ -43,7 +43,7 @@ def session_query(model): :param model: sqlalchemy model :return: query object for model """ - return model.query if hasattr(model, 'query') else db.session.query(model) + return model.query if hasattr(model, "query") else db.session.query(model) def create_query(model, kwargs): @@ -77,7 +77,7 @@ def add(model): def get_model_column(model, field): - if field in getattr(model, 'sensitive_fields', ()): + if field in getattr(model, "sensitive_fields", ()): raise AttrNotFound(field) column = model.__table__.columns._data.get(field, None) if column is None: @@ -100,7 +100,7 @@ def find_all(query, model, kwargs): kwargs = filter_none(kwargs) for attr, value in kwargs.items(): if not isinstance(value, list): - value = value.split(',') + value = value.split(",") conditions.append(get_model_column(model, attr).in_(value)) @@ -200,7 +200,7 @@ def filter(query, model, terms): :return: """ column = get_model_column(model, underscore(terms[0])) - return query.filter(column.ilike('%{}%'.format(terms[1]))) + return query.filter(column.ilike("%{}%".format(terms[1]))) def sort(query, model, field, direction): @@ -214,7 +214,7 @@ def sort(query, model, field, direction): :param direction: """ column = get_model_column(model, underscore(field)) - return query.order_by(column.desc() if direction == 'desc' else column.asc()) + return query.order_by(column.desc() if direction == "desc" else column.asc()) def paginate(query, page, count): @@ -247,10 +247,10 @@ def update_list(model, model_attr, item_model, items): for i in items: for item in getattr(model, model_attr): - if item.id == i['id']: + if item.id == i["id"]: break else: - getattr(model, model_attr).append(get(item_model, i['id'])) + getattr(model, model_attr).append(get(item_model, i["id"])) return model @@ -276,9 +276,9 @@ def get_count(q): disable_group_by = False if len(q._entities) > 1: # currently support only one entity - raise Exception('only one entity is supported for get_count, got: %s' % q) + raise Exception("only one entity is supported for get_count, got: %s" % q) entity = q._entities[0] - if hasattr(entity, 'column'): + if hasattr(entity, "column"): # _ColumnEntity has column attr - on case: query(Model.column)... col = entity.column if q._group_by and q._distinct: @@ -295,7 +295,11 @@ def get_count(q): count_func = func.count() if q._group_by and not disable_group_by: count_func = count_func.over(None) - count_q = q.options(lazyload('*')).statement.with_only_columns([count_func]).order_by(None) + count_q = ( + q.options(lazyload("*")) + .statement.with_only_columns([count_func]) + .order_by(None) + ) if disable_group_by: count_q = count_q.group_by(None) count = q.session.execute(count_q).scalar() @@ -311,13 +315,13 @@ def sort_and_page(query, model, args): :param args: :return: """ - sort_by = args.pop('sort_by') - sort_dir = args.pop('sort_dir') - page = args.pop('page') - count = args.pop('count') + sort_by = args.pop("sort_by") + sort_dir = args.pop("sort_dir") + page = args.pop("page") + count = args.pop("count") - if args.get('user'): - user = args.pop('user') + if args.get("user"): + user = args.pop("user") query = find_all(query, model, args) diff --git a/lemur/default.conf.py b/lemur/default.conf.py index 217d8371..bd67bf7a 100644 --- a/lemur/default.conf.py +++ b/lemur/default.conf.py @@ -1,6 +1,7 @@ # This is just Python which means you can inherit and tweak settings import os + _basedir = os.path.abspath(os.path.dirname(__file__)) THREADS_PER_PAGE = 8 diff --git a/lemur/defaults/views.py b/lemur/defaults/views.py index 5a573829..b3741b15 100644 --- a/lemur/defaults/views.py +++ b/lemur/defaults/views.py @@ -13,12 +13,13 @@ from lemur.auth.service import AuthenticatedResource from lemur.defaults.schemas import default_output_schema -mod = Blueprint('default', __name__) +mod = Blueprint("default", __name__) api = Api(mod) class LemurDefaults(AuthenticatedResource): """ Defines the 'defaults' endpoint """ + def __init__(self): super(LemurDefaults) @@ -59,17 +60,21 @@ class LemurDefaults(AuthenticatedResource): :statuscode 403: unauthenticated """ - default_authority = get_by_name(current_app.config.get('LEMUR_DEFAULT_AUTHORITY')) + default_authority = get_by_name( + current_app.config.get("LEMUR_DEFAULT_AUTHORITY") + ) return dict( - country=current_app.config.get('LEMUR_DEFAULT_COUNTRY'), - state=current_app.config.get('LEMUR_DEFAULT_STATE'), - location=current_app.config.get('LEMUR_DEFAULT_LOCATION'), - organization=current_app.config.get('LEMUR_DEFAULT_ORGANIZATION'), - organizational_unit=current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT'), - issuer_plugin=current_app.config.get('LEMUR_DEFAULT_ISSUER_PLUGIN'), + country=current_app.config.get("LEMUR_DEFAULT_COUNTRY"), + state=current_app.config.get("LEMUR_DEFAULT_STATE"), + location=current_app.config.get("LEMUR_DEFAULT_LOCATION"), + organization=current_app.config.get("LEMUR_DEFAULT_ORGANIZATION"), + organizational_unit=current_app.config.get( + "LEMUR_DEFAULT_ORGANIZATIONAL_UNIT" + ), + issuer_plugin=current_app.config.get("LEMUR_DEFAULT_ISSUER_PLUGIN"), authority=default_authority, ) -api.add_resource(LemurDefaults, '/defaults', endpoint='default') +api.add_resource(LemurDefaults, "/defaults", endpoint="default") diff --git a/lemur/destinations/models.py b/lemur/destinations/models.py index 192a5f5d..a2575378 100644 --- a/lemur/destinations/models.py +++ b/lemur/destinations/models.py @@ -13,7 +13,7 @@ from lemur.plugins.base import plugins class Destination(db.Model): - __tablename__ = 'destinations' + __tablename__ = "destinations" id = Column(Integer, primary_key=True) label = Column(String(32)) options = Column(JSONType) diff --git a/lemur/destinations/schemas.py b/lemur/destinations/schemas.py index 279889b4..cc46ecd4 100644 --- a/lemur/destinations/schemas.py +++ b/lemur/destinations/schemas.py @@ -30,7 +30,7 @@ class DestinationOutputSchema(LemurOutputSchema): @post_dump def fill_object(self, data): if data: - data['plugin']['pluginOptions'] = data['options'] + data["plugin"]["pluginOptions"] = data["options"] return data diff --git a/lemur/destinations/service.py b/lemur/destinations/service.py index 8e505fce..92162f4b 100644 --- a/lemur/destinations/service.py +++ b/lemur/destinations/service.py @@ -26,10 +26,12 @@ def create(label, plugin_name, options, description=None): """ # remove any sub-plugin objects before try to save the json options for option in options: - if 'plugin' in option['type']: - del option['value']['plugin_object'] + if "plugin" in option["type"]: + del option["value"]["plugin_object"] - destination = Destination(label=label, options=options, plugin_name=plugin_name, description=description) + destination = Destination( + label=label, options=options, plugin_name=plugin_name, description=description + ) current_app.logger.info("Destination: %s created", label) # add the destination as source, to avoid new destinations that are not in source, as long as an AWS destination @@ -85,7 +87,7 @@ def get_by_label(label): :param label: :return: """ - return database.get(Destination, label, field='label') + return database.get(Destination, label, field="label") def get_all(): @@ -99,17 +101,19 @@ def get_all(): def render(args): - filt = args.pop('filter') - certificate_id = args.pop('certificate_id', None) + filt = args.pop("filter") + certificate_id = args.pop("certificate_id", None) if certificate_id: - query = database.session_query(Destination).join(Certificate, Destination.certificate) + query = database.session_query(Destination).join( + Certificate, Destination.certificate + ) query = query.filter(Certificate.id == certificate_id) else: query = database.session_query(Destination) if filt: - terms = filt.split(';') + terms = filt.split(";") query = database.filter(query, Destination, terms) return database.sort_and_page(query, Destination, args) @@ -122,9 +126,15 @@ def stats(**kwargs): :param kwargs: :return: """ - items = database.db.session.query(Destination.label, func.count(certificate_destination_associations.c.certificate_id))\ - .join(certificate_destination_associations)\ - .group_by(Destination.label).all() + items = ( + database.db.session.query( + Destination.label, + func.count(certificate_destination_associations.c.certificate_id), + ) + .join(certificate_destination_associations) + .group_by(Destination.label) + .all() + ) keys = [] values = [] @@ -132,4 +142,4 @@ def stats(**kwargs): keys.append(key) values.append(count) - return {'labels': keys, 'values': values} + return {"labels": keys, "values": values} diff --git a/lemur/destinations/views.py b/lemur/destinations/views.py index 7084e8e9..0b0559fe 100644 --- a/lemur/destinations/views.py +++ b/lemur/destinations/views.py @@ -15,15 +15,20 @@ from lemur.auth.permissions import admin_permission from lemur.common.utils import paginated_parser from lemur.common.schema import validate_schema -from lemur.destinations.schemas import destinations_output_schema, destination_input_schema, destination_output_schema +from lemur.destinations.schemas import ( + destinations_output_schema, + destination_input_schema, + destination_output_schema, +) -mod = Blueprint('destinations', __name__) +mod = Blueprint("destinations", __name__) api = Api(mod) class DestinationsList(AuthenticatedResource): """ Defines the 'destinations' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(DestinationsList, self).__init__() @@ -176,7 +181,12 @@ class DestinationsList(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.create(data['label'], data['plugin']['slug'], data['plugin']['plugin_options'], data['description']) + return service.create( + data["label"], + data["plugin"]["slug"], + data["plugin"]["plugin_options"], + data["description"], + ) class Destinations(AuthenticatedResource): @@ -325,16 +335,22 @@ class Destinations(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.update(destination_id, data['label'], data['plugin']['plugin_options'], data['description']) + return service.update( + destination_id, + data["label"], + data["plugin"]["plugin_options"], + data["description"], + ) @admin_permission.require(http_exception=403) def delete(self, destination_id): service.delete(destination_id) - return {'result': True} + return {"result": True} class CertificateDestinations(AuthenticatedResource): """ Defines the 'certificate/', endpoint='destination') -api.add_resource(CertificateDestinations, '/certificates//destinations', - endpoint='certificateDestinations') -api.add_resource(DestinationsStats, '/destinations/stats', endpoint='destinationStats') +api.add_resource(DestinationsList, "/destinations", endpoint="destinations") +api.add_resource( + Destinations, "/destinations/", endpoint="destination" +) +api.add_resource( + CertificateDestinations, + "/certificates//destinations", + endpoint="certificateDestinations", +) +api.add_resource(DestinationsStats, "/destinations/stats", endpoint="destinationStats") diff --git a/lemur/dns_providers/cli.py b/lemur/dns_providers/cli.py index 159bdaa0..72f9c874 100644 --- a/lemur/dns_providers/cli.py +++ b/lemur/dns_providers/cli.py @@ -5,7 +5,9 @@ from lemur.dns_providers.service import get_all_dns_providers, set_domains from lemur.extensions import metrics from lemur.plugins.base import plugins -manager = Manager(usage="Iterates through all DNS providers and sets DNS zones in the database.") +manager = Manager( + usage="Iterates through all DNS providers and sets DNS zones in the database." +) @manager.command @@ -27,5 +29,5 @@ def get_all_zones(): status = SUCCESS_METRIC_STATUS - metrics.send('get_all_zones', 'counter', 1, metric_tags={'status': status}) + metrics.send("get_all_zones", "counter", 1, metric_tags={"status": status}) print("[+] Done with dns provider zone lookup and configuration.") diff --git a/lemur/dns_providers/models.py b/lemur/dns_providers/models.py index 435a2398..eb8cdff9 100644 --- a/lemur/dns_providers/models.py +++ b/lemur/dns_providers/models.py @@ -9,22 +9,23 @@ from lemur.utils import Vault class DnsProvider(db.Model): - __tablename__ = 'dns_providers' - id = Column( - Integer(), - primary_key=True, - ) + __tablename__ = "dns_providers" + id = Column(Integer(), primary_key=True) name = Column(String(length=256), unique=True, nullable=True) description = Column(Text(), nullable=True) provider_type = Column(String(length=256), nullable=True) credentials = Column(Vault, nullable=True) api_endpoint = Column(String(length=256), nullable=True) - date_created = Column(ArrowType(), server_default=text('now()'), nullable=False) + date_created = Column(ArrowType(), server_default=text("now()"), nullable=False) status = Column(String(length=128), nullable=True) options = Column(JSON, nullable=True) domains = Column(JSON, nullable=True) - certificates = relationship("Certificate", backref='dns_provider', foreign_keys='Certificate.dns_provider_id', - lazy='dynamic') + certificates = relationship( + "Certificate", + backref="dns_provider", + foreign_keys="Certificate.dns_provider_id", + lazy="dynamic", + ) def __init__(self, name, description, provider_type, credentials): self.name = name diff --git a/lemur/dns_providers/service.py b/lemur/dns_providers/service.py index bf50bba1..ec9fa0de 100644 --- a/lemur/dns_providers/service.py +++ b/lemur/dns_providers/service.py @@ -49,7 +49,9 @@ def get_friendly(dns_provider_id): } if dns_provider.provider_type == "route53": - dns_provider_friendly["account_id"] = json.loads(dns_provider.credentials).get("account_id") + dns_provider_friendly["account_id"] = json.loads(dns_provider.credentials).get( + "account_id" + ) return dns_provider_friendly @@ -64,40 +66,40 @@ def delete(dns_provider_id): def get_types(): provider_config = current_app.config.get( - 'ACME_DNS_PROVIDER_TYPES', - {"items": [ - { - 'name': 'route53', - 'requirements': [ - { - 'name': 'account_id', - 'type': 'int', - 'required': True, - 'helpMessage': 'AWS Account number' - }, - ] - }, - { - 'name': 'cloudflare', - 'requirements': [ - { - 'name': 'email', - 'type': 'str', - 'required': True, - 'helpMessage': 'Cloudflare Email' - }, - { - 'name': 'key', - 'type': 'str', - 'required': True, - 'helpMessage': 'Cloudflare Key' - }, - ] - }, - { - 'name': 'dyn', - }, - ]} + "ACME_DNS_PROVIDER_TYPES", + { + "items": [ + { + "name": "route53", + "requirements": [ + { + "name": "account_id", + "type": "int", + "required": True, + "helpMessage": "AWS Account number", + } + ], + }, + { + "name": "cloudflare", + "requirements": [ + { + "name": "email", + "type": "str", + "required": True, + "helpMessage": "Cloudflare Email", + }, + { + "name": "key", + "type": "str", + "required": True, + "helpMessage": "Cloudflare Key", + }, + ], + }, + {"name": "dyn"}, + ] + }, ) if not provider_config: raise Exception("No DNS Provider configuration specified.") diff --git a/lemur/dns_providers/views.py b/lemur/dns_providers/views.py index 1f5b3164..d470aa2f 100644 --- a/lemur/dns_providers/views.py +++ b/lemur/dns_providers/views.py @@ -13,9 +13,12 @@ from lemur.auth.service import AuthenticatedResource from lemur.common.schema import validate_schema from lemur.common.utils import paginated_parser from lemur.dns_providers import service -from lemur.dns_providers.schemas import dns_provider_output_schema, dns_provider_input_schema +from lemur.dns_providers.schemas import ( + dns_provider_output_schema, + dns_provider_input_schema, +) -mod = Blueprint('dns_providers', __name__) +mod = Blueprint("dns_providers", __name__) api = Api(mod) @@ -71,12 +74,12 @@ class DnsProvidersList(AuthenticatedResource): """ parser = paginated_parser.copy() - parser.add_argument('dns_provider_id', type=int, location='args') - parser.add_argument('name', type=str, location='args') - parser.add_argument('type', type=str, location='args') + parser.add_argument("dns_provider_id", type=int, location="args") + parser.add_argument("name", type=str, location="args") + parser.add_argument("type", type=str, location="args") args = parser.parse_args() - args['user'] = g.user + args["user"] = g.user return service.render(args) @validate_schema(dns_provider_input_schema, None) @@ -152,7 +155,7 @@ class DnsProviders(AuthenticatedResource): @admin_permission.require(http_exception=403) def delete(self, dns_provider_id): service.delete(dns_provider_id) - return {'result': True} + return {"result": True} class DnsProviderOptions(AuthenticatedResource): @@ -166,6 +169,10 @@ class DnsProviderOptions(AuthenticatedResource): return service.get_types() -api.add_resource(DnsProvidersList, '/dns_providers', endpoint='dns_providers') -api.add_resource(DnsProviders, '/dns_providers/', endpoint='dns_provider') -api.add_resource(DnsProviderOptions, '/dns_provider_options', endpoint='dns_provider_options') +api.add_resource(DnsProvidersList, "/dns_providers", endpoint="dns_providers") +api.add_resource( + DnsProviders, "/dns_providers/", endpoint="dns_provider" +) +api.add_resource( + DnsProviderOptions, "/dns_provider_options", endpoint="dns_provider_options" +) diff --git a/lemur/domains/models.py b/lemur/domains/models.py index 05fccd9c..791e74de 100644 --- a/lemur/domains/models.py +++ b/lemur/domains/models.py @@ -13,11 +13,14 @@ from lemur.database import db class Domain(db.Model): - __tablename__ = 'domains' + __tablename__ = "domains" __table_args__ = ( - Index('ix_domains_name_gin', "name", - postgresql_ops={"name": "gin_trgm_ops"}, - postgresql_using='gin'), + Index( + "ix_domains_name_gin", + "name", + postgresql_ops={"name": "gin_trgm_ops"}, + postgresql_using="gin", + ), ) id = Column(Integer, primary_key=True) name = Column(String(256), index=True) diff --git a/lemur/domains/service.py b/lemur/domains/service.py index c9b8f759..8a581bfd 100644 --- a/lemur/domains/service.py +++ b/lemur/domains/service.py @@ -77,11 +77,11 @@ def render(args): :return: """ query = database.session_query(Domain) - filt = args.pop('filter') - certificate_id = args.pop('certificate_id', None) + filt = args.pop("filter") + certificate_id = args.pop("certificate_id", None) if filt: - terms = filt.split(';') + terms = filt.split(";") query = database.filter(query, Domain, terms) if certificate_id: diff --git a/lemur/domains/views.py b/lemur/domains/views.py index db73f5cd..a3e0cdff 100644 --- a/lemur/domains/views.py +++ b/lemur/domains/views.py @@ -17,14 +17,19 @@ from lemur.auth.permissions import SensitiveDomainPermission from lemur.common.schema import validate_schema from lemur.common.utils import paginated_parser -from lemur.domains.schemas import domain_input_schema, domain_output_schema, domains_output_schema +from lemur.domains.schemas import ( + domain_input_schema, + domain_output_schema, + domains_output_schema, +) -mod = Blueprint('domains', __name__) +mod = Blueprint("domains", __name__) api = Api(mod) class DomainsList(AuthenticatedResource): """ Defines the 'domains' endpoint """ + def __init__(self): super(DomainsList, self).__init__() @@ -123,7 +128,7 @@ class DomainsList(AuthenticatedResource): :statuscode 200: no error :statuscode 403: unauthenticated """ - return service.create(data['name'], data['sensitive']) + return service.create(data["name"], data["sensitive"]) class Domains(AuthenticatedResource): @@ -205,13 +210,14 @@ class Domains(AuthenticatedResource): :statuscode 403: unauthenticated """ if SensitiveDomainPermission().can(): - return service.update(domain_id, data['name'], data['sensitive']) + return service.update(domain_id, data["name"], data["sensitive"]) - return dict(message='You are not authorized to modify this domain'), 403 + return dict(message="You are not authorized to modify this domain"), 403 class CertificateDomains(AuthenticatedResource): """ Defines the 'domains' endpoint """ + def __init__(self): super(CertificateDomains, self).__init__() @@ -265,10 +271,14 @@ class CertificateDomains(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['certificate_id'] = certificate_id + args["certificate_id"] = certificate_id return service.render(args) -api.add_resource(DomainsList, '/domains', endpoint='domains') -api.add_resource(Domains, '/domains/', endpoint='domain') -api.add_resource(CertificateDomains, '/certificates//domains', endpoint='certificateDomains') +api.add_resource(DomainsList, "/domains", endpoint="domains") +api.add_resource(Domains, "/domains/", endpoint="domain") +api.add_resource( + CertificateDomains, + "/certificates//domains", + endpoint="certificateDomains", +) diff --git a/lemur/endpoints/cli.py b/lemur/endpoints/cli.py index 59496930..99f8c342 100644 --- a/lemur/endpoints/cli.py +++ b/lemur/endpoints/cli.py @@ -21,7 +21,14 @@ from lemur.endpoints.models import Endpoint manager = Manager(usage="Handles all endpoint related tasks.") -@manager.option('-ttl', '--time-to-live', type=int, dest='ttl', default=2, help='Time in hours, which endpoint has not been refreshed to remove the endpoint.') +@manager.option( + "-ttl", + "--time-to-live", + type=int, + dest="ttl", + default=2, + help="Time in hours, which endpoint has not been refreshed to remove the endpoint.", +) def expire(ttl): """ Removed all endpoints that have not been recently updated. @@ -31,12 +38,18 @@ def expire(ttl): try: now = arrow.utcnow() expiration = now - timedelta(hours=ttl) - endpoints = database.session_query(Endpoint).filter(cast(Endpoint.last_updated, ArrowType) <= expiration) + endpoints = database.session_query(Endpoint).filter( + cast(Endpoint.last_updated, ArrowType) <= expiration + ) for endpoint in endpoints: - print("[!] Expiring endpoint: {name} Last Updated: {last_updated}".format(name=endpoint.name, last_updated=endpoint.last_updated)) + print( + "[!] Expiring endpoint: {name} Last Updated: {last_updated}".format( + name=endpoint.name, last_updated=endpoint.last_updated + ) + ) database.delete(endpoint) - metrics.send('endpoint_expired', 'counter', 1) + metrics.send("endpoint_expired", "counter", 1) print("[+] Finished expiration.") except Exception as e: diff --git a/lemur/endpoints/models.py b/lemur/endpoints/models.py index b5823327..6e44fe71 100644 --- a/lemur/endpoints/models.py +++ b/lemur/endpoints/models.py @@ -20,15 +20,11 @@ from lemur.database import db from lemur.models import policies_ciphers -BAD_CIPHERS = [ - 'Protocol-SSLv3', - 'Protocol-SSLv2', - 'Protocol-TLSv1' -] +BAD_CIPHERS = ["Protocol-SSLv3", "Protocol-SSLv2", "Protocol-TLSv1"] class Cipher(db.Model): - __tablename__ = 'ciphers' + __tablename__ = "ciphers" id = Column(Integer, primary_key=True) name = Column(String(128), nullable=False) @@ -38,23 +34,18 @@ class Cipher(db.Model): @deprecated.expression def deprecated(cls): - return case( - [ - (cls.name in BAD_CIPHERS, True) - ], - else_=False - ) + return case([(cls.name in BAD_CIPHERS, True)], else_=False) class Policy(db.Model): - ___tablename__ = 'policies' + ___tablename__ = "policies" id = Column(Integer, primary_key=True) name = Column(String(128), nullable=True) - ciphers = relationship('Cipher', secondary=policies_ciphers, backref='policy') + ciphers = relationship("Cipher", secondary=policies_ciphers, backref="policy") class Endpoint(db.Model): - __tablename__ = 'endpoints' + __tablename__ = "endpoints" id = Column(Integer, primary_key=True) owner = Column(String(128)) name = Column(String(128)) @@ -62,16 +53,18 @@ class Endpoint(db.Model): type = Column(String(128)) active = Column(Boolean, default=True) port = Column(Integer) - policy_id = Column(Integer, ForeignKey('policy.id')) - policy = relationship('Policy', backref='endpoint') - certificate_id = Column(Integer, ForeignKey('certificates.id')) - source_id = Column(Integer, ForeignKey('sources.id')) + policy_id = Column(Integer, ForeignKey("policy.id")) + policy = relationship("Policy", backref="endpoint") + certificate_id = Column(Integer, ForeignKey("certificates.id")) + source_id = Column(Integer, ForeignKey("sources.id")) sensitive = Column(Boolean, default=False) - source = relationship('Source', back_populates='endpoints') + source = relationship("Source", back_populates="endpoints") last_updated = Column(ArrowType, default=arrow.utcnow, nullable=False) - date_created = Column(ArrowType, default=arrow.utcnow, onupdate=arrow.utcnow, nullable=False) + date_created = Column( + ArrowType, default=arrow.utcnow, onupdate=arrow.utcnow, nullable=False + ) - replaced = association_proxy('certificate', 'replaced') + replaced = association_proxy("certificate", "replaced") @property def issues(self): @@ -79,13 +72,30 @@ class Endpoint(db.Model): for cipher in self.policy.ciphers: if cipher.deprecated: - issues.append({'name': 'deprecated cipher', 'value': '{0} has been deprecated consider removing it.'.format(cipher.name)}) + issues.append( + { + "name": "deprecated cipher", + "value": "{0} has been deprecated consider removing it.".format( + cipher.name + ), + } + ) if self.certificate.expired: - issues.append({'name': 'expired certificate', 'value': 'There is an expired certificate attached to this endpoint consider replacing it.'}) + issues.append( + { + "name": "expired certificate", + "value": "There is an expired certificate attached to this endpoint consider replacing it.", + } + ) if self.certificate.revoked: - issues.append({'name': 'revoked', 'value': 'There is a revoked certificate attached to this endpoint consider replacing it.'}) + issues.append( + { + "name": "revoked", + "value": "There is a revoked certificate attached to this endpoint consider replacing it.", + } + ) return issues diff --git a/lemur/endpoints/service.py b/lemur/endpoints/service.py index d14174df..2a737858 100644 --- a/lemur/endpoints/service.py +++ b/lemur/endpoints/service.py @@ -46,7 +46,7 @@ def get_by_name(name): :param name: :return: """ - return database.get(Endpoint, name, field='name') + return database.get(Endpoint, name, field="name") def get_by_dnsname(dnsname): @@ -56,7 +56,7 @@ def get_by_dnsname(dnsname): :param dnsname: :return: """ - return database.get(Endpoint, dnsname, field='dnsname') + return database.get(Endpoint, dnsname, field="dnsname") def get_by_dnsname_and_port(dnsname, port): @@ -66,7 +66,11 @@ def get_by_dnsname_and_port(dnsname, port): :param port: :return: """ - return Endpoint.query.filter(Endpoint.dnsname == dnsname).filter(Endpoint.port == port).scalar() + return ( + Endpoint.query.filter(Endpoint.dnsname == dnsname) + .filter(Endpoint.port == port) + .scalar() + ) def get_by_source(source_label): @@ -95,12 +99,14 @@ def create(**kwargs): """ endpoint = Endpoint(**kwargs) database.create(endpoint) - metrics.send('endpoint_added', 'counter', 1, metric_tags={'source': endpoint.source.label}) + metrics.send( + "endpoint_added", "counter", 1, metric_tags={"source": endpoint.source.label} + ) return endpoint def get_or_create_policy(**kwargs): - policy = database.get(Policy, kwargs['name'], field='name') + policy = database.get(Policy, kwargs["name"], field="name") if not policy: policy = Policy(**kwargs) @@ -110,7 +116,7 @@ def get_or_create_policy(**kwargs): def get_or_create_cipher(**kwargs): - cipher = database.get(Cipher, kwargs['name'], field='name') + cipher = database.get(Cipher, kwargs["name"], field="name") if not cipher: cipher = Cipher(**kwargs) @@ -122,11 +128,13 @@ def get_or_create_cipher(**kwargs): def update(endpoint_id, **kwargs): endpoint = database.get(Endpoint, endpoint_id) - endpoint.policy = kwargs['policy'] - endpoint.certificate = kwargs['certificate'] - endpoint.source = kwargs['source'] + endpoint.policy = kwargs["policy"] + endpoint.certificate = kwargs["certificate"] + endpoint.source = kwargs["source"] endpoint.last_updated = arrow.utcnow() - metrics.send('endpoint_updated', 'counter', 1, metric_tags={'source': endpoint.source.label}) + metrics.send( + "endpoint_updated", "counter", 1, metric_tags={"source": endpoint.source.label} + ) database.update(endpoint) return endpoint @@ -138,19 +146,17 @@ def render(args): :return: """ query = database.session_query(Endpoint) - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') - if 'active' in filt: # this is really weird but strcmp seems to not work here?? + terms = filt.split(";") + if "active" in filt: # this is really weird but strcmp seems to not work here?? query = query.filter(Endpoint.active == truthiness(terms[1])) - elif 'port' in filt: - if terms[1] != 'null': # ng-table adds 'null' if a number is removed + elif "port" in filt: + if terms[1] != "null": # ng-table adds 'null' if a number is removed query = query.filter(Endpoint.port == terms[1]) - elif 'ciphers' in filt: - query = query.filter( - Cipher.name == terms[1] - ) + elif "ciphers" in filt: + query = query.filter(Cipher.name == terms[1]) else: query = database.filter(query, Endpoint, terms) @@ -164,7 +170,7 @@ def stats(**kwargs): :param kwargs: :return: """ - attr = getattr(Endpoint, kwargs.get('metric')) + attr = getattr(Endpoint, kwargs.get("metric")) query = database.db.session.query(attr, func.count(attr)) items = query.group_by(attr).all() @@ -175,4 +181,4 @@ def stats(**kwargs): keys.append(key) values.append(count) - return {'labels': keys, 'values': values} + return {"labels": keys, "values": values} diff --git a/lemur/endpoints/views.py b/lemur/endpoints/views.py index 6509f056..9f469a6b 100644 --- a/lemur/endpoints/views.py +++ b/lemur/endpoints/views.py @@ -16,12 +16,13 @@ from lemur.endpoints import service from lemur.endpoints.schemas import endpoint_output_schema, endpoints_output_schema -mod = Blueprint('endpoints', __name__) +mod = Blueprint("endpoints", __name__) api = Api(mod) class EndpointsList(AuthenticatedResource): """ Defines the 'endpoints' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(EndpointsList, self).__init__() @@ -63,7 +64,7 @@ class EndpointsList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['user'] = g.current_user + args["user"] = g.current_user return service.render(args) @@ -103,5 +104,5 @@ class Endpoints(AuthenticatedResource): return service.get(endpoint_id) -api.add_resource(EndpointsList, '/endpoints', endpoint='endpoints') -api.add_resource(Endpoints, '/endpoints/', endpoint='endpoint') +api.add_resource(EndpointsList, "/endpoints", endpoint="endpoints") +api.add_resource(Endpoints, "/endpoints/", endpoint="endpoint") diff --git a/lemur/exceptions.py b/lemur/exceptions.py index d392fe5d..98e216bb 100644 --- a/lemur/exceptions.py +++ b/lemur/exceptions.py @@ -21,7 +21,9 @@ class DuplicateError(LemurException): class InvalidListener(LemurException): def __str__(self): - return repr("Invalid listener, ensure you select a certificate if you are using a secure protocol") + return repr( + "Invalid listener, ensure you select a certificate if you are using a secure protocol" + ) class AttrNotFound(LemurException): diff --git a/lemur/extensions.py b/lemur/extensions.py index a54df6c7..24c4c814 100644 --- a/lemur/extensions.py +++ b/lemur/extensions.py @@ -15,25 +15,33 @@ class SQLAlchemy(SA): db = SQLAlchemy() from flask_migrate import Migrate + migrate = Migrate() from flask_bcrypt import Bcrypt + bcrypt = Bcrypt() from flask_principal import Principal + principal = Principal(use_sessions=False) from flask_mail import Mail + smtp_mail = Mail() from lemur.metrics import Metrics + metrics = Metrics() from raven.contrib.flask import Sentry + sentry = Sentry() from blinker import Namespace + signals = Namespace() from flask_cors import CORS + cors = CORS() diff --git a/lemur/factory.py b/lemur/factory.py index c2719e9b..b4066e78 100644 --- a/lemur/factory.py +++ b/lemur/factory.py @@ -24,9 +24,7 @@ from lemur.common.health import mod as health from lemur.extensions import db, migrate, principal, smtp_mail, metrics, sentry, cors -DEFAULT_BLUEPRINTS = ( - health, -) +DEFAULT_BLUEPRINTS = (health,) API_VERSION = 1 @@ -71,16 +69,20 @@ def from_file(file_path, silent=False): :param file_path: :param silent: """ - d = imp.new_module('config') + d = imp.new_module("config") d.__file__ = file_path try: with open(file_path) as config_file: - exec(compile(config_file.read(), # nosec: config file safe - file_path, 'exec'), d.__dict__) + exec( + compile( + config_file.read(), file_path, "exec" # nosec: config file safe + ), + d.__dict__, + ) except IOError as e: if silent and e.errno in (errno.ENOENT, errno.EISDIR): return False - e.strerror = 'Unable to load configuration file (%s)' % e.strerror + e.strerror = "Unable to load configuration file (%s)" % e.strerror raise return d @@ -94,8 +96,8 @@ def configure_app(app, config=None): :return: """ # respect the config first - if config and config != 'None': - app.config['CONFIG_PATH'] = config + if config and config != "None": + app.config["CONFIG_PATH"] = config app.config.from_object(from_file(config)) else: try: @@ -103,12 +105,21 @@ def configure_app(app, config=None): except RuntimeError: # look in default paths if os.path.isfile(os.path.expanduser("~/.lemur/lemur.conf.py")): - app.config.from_object(from_file(os.path.expanduser("~/.lemur/lemur.conf.py"))) + app.config.from_object( + from_file(os.path.expanduser("~/.lemur/lemur.conf.py")) + ) else: - app.config.from_object(from_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'default.conf.py'))) + app.config.from_object( + from_file( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "default.conf.py", + ) + ) + ) # we don't use this - app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False + app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False def configure_extensions(app): @@ -125,9 +136,15 @@ def configure_extensions(app): metrics.init_app(app) sentry.init_app(app) - if app.config['CORS']: - app.config['CORS_HEADERS'] = 'Content-Type' - cors.init_app(app, resources=r'/api/*', headers='Content-Type', origin='*', supports_credentials=True) + if app.config["CORS"]: + app.config["CORS_HEADERS"] = "Content-Type" + cors.init_app( + app, + resources=r"/api/*", + headers="Content-Type", + origin="*", + supports_credentials=True, + ) def configure_blueprints(app, blueprints): @@ -148,22 +165,25 @@ def configure_logging(app): :param app: """ - handler = RotatingFileHandler(app.config.get('LOG_FILE', 'lemur.log'), maxBytes=10000000, backupCount=100) + handler = RotatingFileHandler( + app.config.get("LOG_FILE", "lemur.log"), maxBytes=10000000, backupCount=100 + ) - handler.setFormatter(Formatter( - '%(asctime)s %(levelname)s: %(message)s ' - '[in %(pathname)s:%(lineno)d]' - )) + handler.setFormatter( + Formatter( + "%(asctime)s %(levelname)s: %(message)s " "[in %(pathname)s:%(lineno)d]" + ) + ) - handler.setLevel(app.config.get('LOG_LEVEL', 'DEBUG')) - app.logger.setLevel(app.config.get('LOG_LEVEL', 'DEBUG')) + handler.setLevel(app.config.get("LOG_LEVEL", "DEBUG")) + app.logger.setLevel(app.config.get("LOG_LEVEL", "DEBUG")) app.logger.addHandler(handler) stream_handler = StreamHandler() - stream_handler.setLevel(app.config.get('LOG_LEVEL', 'DEBUG')) + stream_handler.setLevel(app.config.get("LOG_LEVEL", "DEBUG")) app.logger.addHandler(stream_handler) - if app.config.get('DEBUG_DUMP', False): + if app.config.get("DEBUG_DUMP", False): activate_debug_dump() @@ -176,17 +196,21 @@ def install_plugins(app): """ from lemur.plugins import plugins from lemur.plugins.base import register + # entry_points={ # 'lemur.plugins': [ # 'verisign = lemur_verisign.plugin:VerisignPlugin' # ], # }, - for ep in pkg_resources.iter_entry_points('lemur.plugins'): + for ep in pkg_resources.iter_entry_points("lemur.plugins"): try: plugin = ep.load() except Exception: import traceback - app.logger.error("Failed to load plugin %r:\n%s\n" % (ep.name, traceback.format_exc())) + + app.logger.error( + "Failed to load plugin %r:\n%s\n" % (ep.name, traceback.format_exc()) + ) else: register(plugin) @@ -196,6 +220,9 @@ def install_plugins(app): try: plugins.get(slug) except KeyError: - raise Exception("Unable to location notification plugin: {slug}. Ensure that " - "LEMUR_DEFAULT_NOTIFICATION_PLUGIN is set to a valid and installed notification plugin." - .format(slug=slug)) + raise Exception( + "Unable to location notification plugin: {slug}. Ensure that " + "LEMUR_DEFAULT_NOTIFICATION_PLUGIN is set to a valid and installed notification plugin.".format( + slug=slug + ) + ) diff --git a/lemur/logs/models.py b/lemur/logs/models.py index 9f982c24..07a2ded3 100644 --- a/lemur/logs/models.py +++ b/lemur/logs/models.py @@ -15,9 +15,19 @@ from lemur.database import db class Log(db.Model): - __tablename__ = 'logs' + __tablename__ = "logs" id = Column(Integer, primary_key=True) - certificate_id = Column(Integer, ForeignKey('certificates.id')) - log_type = Column(Enum('key_view', 'create_cert', 'update_cert', 'revoke_cert', 'delete_cert', name='log_type'), nullable=False) + certificate_id = Column(Integer, ForeignKey("certificates.id")) + log_type = Column( + Enum( + "key_view", + "create_cert", + "update_cert", + "revoke_cert", + "delete_cert", + name="log_type", + ), + nullable=False, + ) logged_at = Column(ArrowType(), PassiveDefault(func.now()), nullable=False) - user_id = Column(Integer, ForeignKey('users.id'), nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) diff --git a/lemur/logs/service.py b/lemur/logs/service.py index 04355938..f4949911 100644 --- a/lemur/logs/service.py +++ b/lemur/logs/service.py @@ -24,7 +24,11 @@ def create(user, type, certificate=None): :param certificate: :return: """ - current_app.logger.info("[lemur-audit] action: {0}, user: {1}, certificate: {2}.".format(type, user.email, certificate.name)) + current_app.logger.info( + "[lemur-audit] action: {0}, user: {1}, certificate: {2}.".format( + type, user.email, certificate.name + ) + ) view = Log(user_id=user.id, log_type=type, certificate_id=certificate.id) database.add(view) database.commit() @@ -50,20 +54,22 @@ def render(args): """ query = database.session_query(Log) - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') + terms = filt.split(";") - if 'certificate.name' in terms: - sub_query = database.session_query(Certificate.id)\ - .filter(Certificate.name.ilike('%{0}%'.format(terms[1]))) + if "certificate.name" in terms: + sub_query = database.session_query(Certificate.id).filter( + Certificate.name.ilike("%{0}%".format(terms[1])) + ) query = query.filter(Log.certificate_id.in_(sub_query)) - elif 'user.email' in terms: - sub_query = database.session_query(User.id)\ - .filter(User.email.ilike('%{0}%'.format(terms[1]))) + elif "user.email" in terms: + sub_query = database.session_query(User.id).filter( + User.email.ilike("%{0}%".format(terms[1])) + ) query = query.filter(Log.user_id.in_(sub_query)) diff --git a/lemur/logs/views.py b/lemur/logs/views.py index 1e0bd184..57c588ed 100644 --- a/lemur/logs/views.py +++ b/lemur/logs/views.py @@ -17,12 +17,13 @@ from lemur.logs.schemas import logs_output_schema from lemur.logs import service -mod = Blueprint('logs', __name__) +mod = Blueprint("logs", __name__) api = Api(mod) class LogsList(AuthenticatedResource): """ Defines the 'logs' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(LogsList, self).__init__() @@ -65,10 +66,10 @@ class LogsList(AuthenticatedResource): :statuscode 200: no error """ parser = paginated_parser.copy() - parser.add_argument('owner', type=str, location='args') - parser.add_argument('id', type=str, location='args') + parser.add_argument("owner", type=str, location="args") + parser.add_argument("id", type=str, location="args") args = parser.parse_args() return service.render(args) -api.add_resource(LogsList, '/logs', endpoint='logs') +api.add_resource(LogsList, "/logs", endpoint="logs") diff --git a/lemur/manage.py b/lemur/manage.py index c9ce4240..e6e85a9d 100755 --- a/lemur/manage.py +++ b/lemur/manage.py @@ -1,4 +1,4 @@ -from __future__ import unicode_literals # at top of module +from __future__ import unicode_literals # at top of module import os import sys @@ -52,24 +52,24 @@ from lemur.dns_providers.models import DnsProvider # noqa from sqlalchemy.sql import text manager = Manager(create_app) -manager.add_option('-c', '--config', dest='config_path', required=False) +manager.add_option("-c", "--config", dest="config_path", required=False) migrate = Migrate(create_app) REQUIRED_VARIABLES = [ - 'LEMUR_SECURITY_TEAM_EMAIL', - 'LEMUR_DEFAULT_ORGANIZATIONAL_UNIT', - 'LEMUR_DEFAULT_ORGANIZATION', - 'LEMUR_DEFAULT_LOCATION', - 'LEMUR_DEFAULT_COUNTRY', - 'LEMUR_DEFAULT_STATE', - 'SQLALCHEMY_DATABASE_URI' + "LEMUR_SECURITY_TEAM_EMAIL", + "LEMUR_DEFAULT_ORGANIZATIONAL_UNIT", + "LEMUR_DEFAULT_ORGANIZATION", + "LEMUR_DEFAULT_LOCATION", + "LEMUR_DEFAULT_COUNTRY", + "LEMUR_DEFAULT_STATE", + "SQLALCHEMY_DATABASE_URI", ] KEY_LENGTH = 40 -DEFAULT_CONFIG_PATH = '~/.lemur/lemur.conf.py' -DEFAULT_SETTINGS = 'lemur.conf.server' -SETTINGS_ENVVAR = 'LEMUR_CONF' +DEFAULT_CONFIG_PATH = "~/.lemur/lemur.conf.py" +DEFAULT_SETTINGS = "lemur.conf.server" +SETTINGS_ENVVAR = "LEMUR_CONF" CONFIG_TEMPLATE = """ # This is just Python which means you can inherit and tweak settings @@ -144,9 +144,9 @@ SQLALCHEMY_DATABASE_URI = 'postgresql://lemur:lemur@localhost:5432/lemur' @MigrateCommand.command def create(): - database.db.engine.execute(text('CREATE EXTENSION IF NOT EXISTS pg_trgm')) + database.db.engine.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm")) database.db.create_all() - stamp(revision='head') + stamp(revision="head") @MigrateCommand.command @@ -174,9 +174,9 @@ def generate_settings(): output = CONFIG_TEMPLATE.format( # we use Fernet.generate_key to make sure that the key length is # compatible with Fernet - encryption_key=Fernet.generate_key().decode('utf-8'), - secret_token=base64.b64encode(os.urandom(KEY_LENGTH)).decode('utf-8'), - flask_secret_key=base64.b64encode(os.urandom(KEY_LENGTH)).decode('utf-8'), + encryption_key=Fernet.generate_key().decode("utf-8"), + secret_token=base64.b64encode(os.urandom(KEY_LENGTH)).decode("utf-8"), + flask_secret_key=base64.b64encode(os.urandom(KEY_LENGTH)).decode("utf-8"), ) return output @@ -190,39 +190,44 @@ class InitializeApp(Command): Additionally a Lemur user will be created as a default user and be used when certificates are discovered by Lemur. """ - option_list = ( - Option('-p', '--password', dest='password'), - ) + + option_list = (Option("-p", "--password", dest="password"),) def run(self, password): create() user = user_service.get_by_username("lemur") - admin_role = role_service.get_by_name('admin') + admin_role = role_service.get_by_name("admin") if admin_role: sys.stdout.write("[-] Admin role already created, skipping...!\n") else: # we create an admin role - admin_role = role_service.create('admin', description='This is the Lemur administrator role.') + admin_role = role_service.create( + "admin", description="This is the Lemur administrator role." + ) sys.stdout.write("[+] Created 'admin' role\n") - operator_role = role_service.get_by_name('operator') + operator_role = role_service.get_by_name("operator") if operator_role: sys.stdout.write("[-] Operator role already created, skipping...!\n") else: # we create an operator role - operator_role = role_service.create('operator', description='This is the Lemur operator role.') + operator_role = role_service.create( + "operator", description="This is the Lemur operator role." + ) sys.stdout.write("[+] Created 'operator' role\n") - read_only_role = role_service.get_by_name('read-only') + read_only_role = role_service.get_by_name("read-only") if read_only_role: sys.stdout.write("[-] Read only role already created, skipping...!\n") else: # we create an read only role - read_only_role = role_service.create('read-only', description='This is the Lemur read only role.') + read_only_role = role_service.create( + "read-only", description="This is the Lemur read only role." + ) sys.stdout.write("[+] Created 'read-only' role\n") if not user: @@ -235,34 +240,54 @@ class InitializeApp(Command): sys.stderr.write("[!] Passwords do not match!\n") sys.exit(1) - user_service.create("lemur", password, 'lemur@nobody.com', True, None, [admin_role]) - sys.stdout.write("[+] Created the user 'lemur' and granted it the 'admin' role!\n") + user_service.create( + "lemur", password, "lemur@nobody.com", True, None, [admin_role] + ) + sys.stdout.write( + "[+] Created the user 'lemur' and granted it the 'admin' role!\n" + ) else: - sys.stdout.write("[-] Default user has already been created, skipping...!\n") + sys.stdout.write( + "[-] Default user has already been created, skipping...!\n" + ) - intervals = current_app.config.get("LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", []) + intervals = current_app.config.get( + "LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", [] + ) sys.stdout.write( "[!] Creating {num} notifications for {intervals} days as specified by LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS\n".format( - num=len(intervals), - intervals=",".join([str(x) for x in intervals]) + num=len(intervals), intervals=",".join([str(x) for x in intervals]) ) ) - recipients = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL') + recipients = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL") sys.stdout.write("[+] Creating expiration email notifications!\n") - sys.stdout.write("[!] Using {0} as specified by LEMUR_SECURITY_TEAM_EMAIL for notifications\n".format(recipients)) - notification_service.create_default_expiration_notifications("DEFAULT_SECURITY", recipients=recipients) + sys.stdout.write( + "[!] Using {0} as specified by LEMUR_SECURITY_TEAM_EMAIL for notifications\n".format( + recipients + ) + ) + notification_service.create_default_expiration_notifications( + "DEFAULT_SECURITY", recipients=recipients + ) - _DEFAULT_ROTATION_INTERVAL = 'default' - default_rotation_interval = policy_service.get_by_name(_DEFAULT_ROTATION_INTERVAL) + _DEFAULT_ROTATION_INTERVAL = "default" + default_rotation_interval = policy_service.get_by_name( + _DEFAULT_ROTATION_INTERVAL + ) if default_rotation_interval: - sys.stdout.write("[-] Default rotation interval policy already created, skipping...!\n") + sys.stdout.write( + "[-] Default rotation interval policy already created, skipping...!\n" + ) else: days = current_app.config.get("LEMUR_DEFAULT_ROTATION_INTERVAL", 30) - sys.stdout.write("[+] Creating default certificate rotation policy of {days} days before issuance.\n".format( - days=days)) + sys.stdout.write( + "[+] Creating default certificate rotation policy of {days} days before issuance.\n".format( + days=days + ) + ) policy_service.create(days=days, name=_DEFAULT_ROTATION_INTERVAL) sys.stdout.write("[/] Done!\n") @@ -272,12 +297,13 @@ class CreateUser(Command): """ This command allows for the creation of a new user within Lemur. """ + option_list = ( - Option('-u', '--username', dest='username', required=True), - Option('-e', '--email', dest='email', required=True), - Option('-a', '--active', dest='active', default=True), - Option('-r', '--roles', dest='roles', action='append', default=[]), - Option('-p', '--password', dest='password', default=None) + Option("-u", "--username", dest="username", required=True), + Option("-e", "--email", dest="email", required=True), + Option("-a", "--active", dest="active", default=True), + Option("-r", "--roles", dest="roles", action="append", default=[]), + Option("-p", "--password", dest="password", default=None), ) def run(self, username, email, active, roles, password): @@ -307,9 +333,8 @@ class ResetPassword(Command): """ This command allows you to reset a user's password. """ - option_list = ( - Option('-u', '--username', dest='username', required=True), - ) + + option_list = (Option("-u", "--username", dest="username", required=True),) def run(self, username): user = user_service.get_by_username(username) @@ -335,10 +360,11 @@ class CreateRole(Command): """ This command allows for the creation of a new role within Lemur """ + option_list = ( - Option('-n', '--name', dest='name', required=True), - Option('-u', '--users', dest='users', default=[]), - Option('-d', '--description', dest='description', required=True) + Option("-n", "--name", dest="name", required=True), + Option("-u", "--users", dest="users", default=[]), + Option("-d", "--description", dest="description", required=True), ) def run(self, name, users, description): @@ -369,7 +395,8 @@ class LemurServer(Command): Will start gunicorn with 4 workers bound to 127.0.0.0:8002 """ - description = 'Run the app within Gunicorn' + + description = "Run the app within Gunicorn" def get_options(self): settings = make_settings() @@ -377,8 +404,10 @@ class LemurServer(Command): for setting, klass in settings.items(): if klass.cli: if klass.action: - if klass.action == 'store_const': - options.append(Option(*klass.cli, const=klass.const, action=klass.action)) + if klass.action == "store_const": + options.append( + Option(*klass.cli, const=klass.const, action=klass.action) + ) else: options.append(Option(*klass.cli, action=klass.action)) else: @@ -394,7 +423,9 @@ class LemurServer(Command): # run startup tasks on an app like object validate_conf(current_app, REQUIRED_VARIABLES) - app.app_uri = 'lemur:create_app(config_path="{0}")'.format(current_app.config.get('CONFIG_PATH')) + app.app_uri = 'lemur:create_app(config_path="{0}")'.format( + current_app.config.get("CONFIG_PATH") + ) return app.run() @@ -414,7 +445,7 @@ def create_config(config_path=None): os.makedirs(dir) config = generate_settings() - with open(config_path, 'w') as f: + with open(config_path, "w") as f: f.write(config) sys.stdout.write("[+] Created a new configuration file {0}\n".format(config_path)) @@ -436,7 +467,7 @@ def lock(path=None): :param: path """ if not path: - path = os.path.expanduser('~/.lemur/keys') + path = os.path.expanduser("~/.lemur/keys") dest_dir = os.path.join(path, "encrypted") sys.stdout.write("[!] Generating a new key...\n") @@ -447,15 +478,17 @@ def lock(path=None): sys.stdout.write("[+] Creating encryption directory: {0}\n".format(dest_dir)) os.makedirs(dest_dir) - for root, dirs, files in os.walk(os.path.join(path, 'decrypted')): + for root, dirs, files in os.walk(os.path.join(path, "decrypted")): for f in files: source = os.path.join(root, f) dest = os.path.join(dest_dir, f + ".enc") - with open(source, 'rb') as in_file, open(dest, 'wb') as out_file: + with open(source, "rb") as in_file, open(dest, "wb") as out_file: f = Fernet(key) data = f.encrypt(in_file.read()) out_file.write(data) - sys.stdout.write("[+] Writing file: {0} Source: {1}\n".format(dest, source)) + sys.stdout.write( + "[+] Writing file: {0} Source: {1}\n".format(dest, source) + ) sys.stdout.write("[+] Keys have been encrypted with key {0}\n".format(key)) @@ -475,7 +508,7 @@ def unlock(path=None): key = prompt_pass("[!] Please enter the encryption password") if not path: - path = os.path.expanduser('~/.lemur/keys') + path = os.path.expanduser("~/.lemur/keys") dest_dir = os.path.join(path, "decrypted") source_dir = os.path.join(path, "encrypted") @@ -488,11 +521,13 @@ def unlock(path=None): for f in files: source = os.path.join(source_dir, f) dest = os.path.join(dest_dir, ".".join(f.split(".")[:-1])) - with open(source, 'rb') as in_file, open(dest, 'wb') as out_file: + with open(source, "rb") as in_file, open(dest, "wb") as out_file: f = Fernet(key) data = f.decrypt(in_file.read()) out_file.write(data) - sys.stdout.write("[+] Writing file: {0} Source: {1}\n".format(dest, source)) + sys.stdout.write( + "[+] Writing file: {0} Source: {1}\n".format(dest, source) + ) sys.stdout.write("[+] Keys have been unencrypted!\n") @@ -505,15 +540,16 @@ def publish_verisign_units(): :return: """ from lemur.plugins import plugins - v = plugins.get('verisign-issuer') + + v = plugins.get("verisign-issuer") units = v.get_available_units() metrics = {} for item in units: - if item['@type'] in metrics.keys(): - metrics[item['@type']] += int(item['@remaining']) + if item["@type"] in metrics.keys(): + metrics[item["@type"]] += int(item["@remaining"]) else: - metrics.update({item['@type']: int(item['@remaining'])}) + metrics.update({item["@type"]: int(item["@remaining"])}) for name, value in metrics.items(): metric = [ @@ -522,16 +558,16 @@ def publish_verisign_units(): "type": "GAUGE", "name": "Symantec {0} Unit Count".format(name), "tags": {}, - "value": value + "value": value, } ] - requests.post('http://localhost:8078/metrics', data=json.dumps(metric)) + requests.post("http://localhost:8078/metrics", data=json.dumps(metric)) def main(): manager.add_command("start", LemurServer()) - manager.add_command("runserver", Server(host='127.0.0.1', threaded=True)) + manager.add_command("runserver", Server(host="127.0.0.1", threaded=True)) manager.add_command("clean", Clean()) manager.add_command("show_urls", ShowUrls()) manager.add_command("db", MigrateCommand) diff --git a/lemur/metrics.py b/lemur/metrics.py index 381dc605..52f8c25b 100644 --- a/lemur/metrics.py +++ b/lemur/metrics.py @@ -11,6 +11,7 @@ class Metrics(object): """ :param app: The Flask application object. Defaults to None. """ + _providers = [] def __init__(self, app=None): @@ -22,11 +23,14 @@ class Metrics(object): :param app: The Flask application object. """ - self._providers = app.config.get('METRIC_PROVIDERS', []) + self._providers = app.config.get("METRIC_PROVIDERS", []) def send(self, metric_name, metric_type, metric_value, *args, **kwargs): for provider in self._providers: current_app.logger.debug( - "Sending metric '{metric}' to the {provider} provider.".format(metric=metric_name, provider=provider)) + "Sending metric '{metric}' to the {provider} provider.".format( + metric=metric_name, provider=provider + ) + ) p = plugins.get(provider) p.submit(metric_name, metric_type, metric_value, *args, **kwargs) diff --git a/lemur/migrations/env.py b/lemur/migrations/env.py index 63425041..008a9952 100644 --- a/lemur/migrations/env.py +++ b/lemur/migrations/env.py @@ -19,8 +19,11 @@ fileConfig(config.config_file_name) # from myapp import mymodel # target_metadata = mymodel.Base.metadata from flask import current_app -config.set_main_option('sqlalchemy.url', current_app.config.get('SQLALCHEMY_DATABASE_URI')) -target_metadata = current_app.extensions['migrate'].db.metadata + +config.set_main_option( + "sqlalchemy.url", current_app.config.get("SQLALCHEMY_DATABASE_URI") +) +target_metadata = current_app.extensions["migrate"].db.metadata # other values from the config, defined by the needs of env.py, # can be acquired: @@ -54,14 +57,18 @@ def run_migrations_online(): and associate a connection with the context. """ - engine = engine_from_config(config.get_section(config.config_ini_section), - prefix='sqlalchemy.', - poolclass=pool.NullPool) + engine = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) connection = engine.connect() - context.configure(connection=connection, - target_metadata=target_metadata, - **current_app.extensions['migrate'].configure_args) + context.configure( + connection=connection, + target_metadata=target_metadata, + **current_app.extensions["migrate"].configure_args + ) try: with context.begin_transaction(): @@ -69,8 +76,8 @@ def run_migrations_online(): finally: connection.close() + if context.is_offline_mode(): run_migrations_offline() else: run_migrations_online() - diff --git a/lemur/migrations/versions/131ec6accff5_.py b/lemur/migrations/versions/131ec6accff5_.py index bddc5fe2..d5b42462 100644 --- a/lemur/migrations/versions/131ec6accff5_.py +++ b/lemur/migrations/versions/131ec6accff5_.py @@ -7,8 +7,8 @@ Create Date: 2016-12-07 17:29:42.049986 """ # revision identifiers, used by Alembic. -revision = '131ec6accff5' -down_revision = 'e3691fc396e9' +revision = "131ec6accff5" +down_revision = "e3691fc396e9" from alembic import op import sqlalchemy as sa @@ -16,13 +16,24 @@ import sqlalchemy as sa def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column('certificates', sa.Column('rotation', sa.Boolean(), nullable=False, server_default=sa.false())) - op.add_column('endpoints', sa.Column('last_updated', sa.DateTime(), server_default=sa.text('now()'), nullable=False)) + op.add_column( + "certificates", + sa.Column("rotation", sa.Boolean(), nullable=False, server_default=sa.false()), + ) + op.add_column( + "endpoints", + sa.Column( + "last_updated", + sa.DateTime(), + server_default=sa.text("now()"), + nullable=False, + ), + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('endpoints', 'last_updated') - op.drop_column('certificates', 'rotation') + op.drop_column("endpoints", "last_updated") + op.drop_column("certificates", "rotation") # ### end Alembic commands ### diff --git a/lemur/migrations/versions/1ae8e3104db8_.py b/lemur/migrations/versions/1ae8e3104db8_.py index 3cb3bb9e..9e19f0e7 100644 --- a/lemur/migrations/versions/1ae8e3104db8_.py +++ b/lemur/migrations/versions/1ae8e3104db8_.py @@ -7,15 +7,19 @@ Create Date: 2017-07-13 12:32:09.162800 """ # revision identifiers, used by Alembic. -revision = '1ae8e3104db8' -down_revision = 'a02a678ddc25' +revision = "1ae8e3104db8" +down_revision = "a02a678ddc25" from alembic import op def upgrade(): - op.sync_enum_values('public', 'log_type', ['key_view'], ['create_cert', 'key_view', 'update_cert']) + op.sync_enum_values( + "public", "log_type", ["key_view"], ["create_cert", "key_view", "update_cert"] + ) def downgrade(): - op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'update_cert'], ['key_view']) + op.sync_enum_values( + "public", "log_type", ["create_cert", "key_view", "update_cert"], ["key_view"] + ) diff --git a/lemur/migrations/versions/1db4f82bc780_.py b/lemur/migrations/versions/1db4f82bc780_.py index 2d917e2e..e6fb47f0 100644 --- a/lemur/migrations/versions/1db4f82bc780_.py +++ b/lemur/migrations/versions/1db4f82bc780_.py @@ -7,8 +7,8 @@ Create Date: 2018-08-03 12:56:44.565230 """ # revision identifiers, used by Alembic. -revision = '1db4f82bc780' -down_revision = '3adfdd6598df' +revision = "1db4f82bc780" +down_revision = "3adfdd6598df" import logging @@ -20,12 +20,14 @@ log = logging.getLogger(__name__) def upgrade(): connection = op.get_bind() - result = connection.execute("""\ + result = connection.execute( + """\ UPDATE certificates SET rotation_policy_id=(SELECT id FROM rotation_policies WHERE name='default') WHERE rotation_policy_id IS NULL RETURNING id - """) + """ + ) log.info("Filled rotation_policy for %d certificates" % result.rowcount) diff --git a/lemur/migrations/versions/29d8c8455c86_.py b/lemur/migrations/versions/29d8c8455c86_.py index f0b4749f..3a0e8717 100644 --- a/lemur/migrations/versions/29d8c8455c86_.py +++ b/lemur/migrations/versions/29d8c8455c86_.py @@ -7,8 +7,8 @@ Create Date: 2016-06-28 16:05:25.720213 """ # revision identifiers, used by Alembic. -revision = '29d8c8455c86' -down_revision = '3307381f3b88' +revision = "29d8c8455c86" +down_revision = "3307381f3b88" from alembic import op import sqlalchemy as sa @@ -17,46 +17,60 @@ from sqlalchemy.dialects import postgresql def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.create_table('ciphers', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(length=128), nullable=False), - sa.PrimaryKeyConstraint('id') + op.create_table( + "ciphers", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=128), nullable=False), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('policy', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(length=128), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "policy", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=128), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('policies_ciphers', - sa.Column('cipher_id', sa.Integer(), nullable=True), - sa.Column('policy_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['cipher_id'], ['ciphers.id'], ), - sa.ForeignKeyConstraint(['policy_id'], ['policy.id'], ) + op.create_table( + "policies_ciphers", + sa.Column("cipher_id", sa.Integer(), nullable=True), + sa.Column("policy_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["cipher_id"], ["ciphers.id"]), + sa.ForeignKeyConstraint(["policy_id"], ["policy.id"]), ) - op.create_index('policies_ciphers_ix', 'policies_ciphers', ['cipher_id', 'policy_id'], unique=False) - op.create_table('endpoints', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('owner', sa.String(length=128), nullable=True), - sa.Column('name', sa.String(length=128), nullable=True), - sa.Column('dnsname', sa.String(length=256), nullable=True), - sa.Column('type', sa.String(length=128), nullable=True), - sa.Column('active', sa.Boolean(), nullable=True), - sa.Column('port', sa.Integer(), nullable=True), - sa.Column('date_created', sa.DateTime(), server_default=sa.text(u'now()'), nullable=False), - sa.Column('policy_id', sa.Integer(), nullable=True), - sa.Column('certificate_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ), - sa.ForeignKeyConstraint(['policy_id'], ['policy.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + "policies_ciphers_ix", + "policies_ciphers", + ["cipher_id", "policy_id"], + unique=False, + ) + op.create_table( + "endpoints", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("owner", sa.String(length=128), nullable=True), + sa.Column("name", sa.String(length=128), nullable=True), + sa.Column("dnsname", sa.String(length=256), nullable=True), + sa.Column("type", sa.String(length=128), nullable=True), + sa.Column("active", sa.Boolean(), nullable=True), + sa.Column("port", sa.Integer(), nullable=True), + sa.Column( + "date_created", + sa.DateTime(), + server_default=sa.text(u"now()"), + nullable=False, + ), + sa.Column("policy_id", sa.Integer(), nullable=True), + sa.Column("certificate_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]), + sa.ForeignKeyConstraint(["policy_id"], ["policy.id"]), + sa.PrimaryKeyConstraint("id"), ) ### end Alembic commands ### def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_table('endpoints') - op.drop_index('policies_ciphers_ix', table_name='policies_ciphers') - op.drop_table('policies_ciphers') - op.drop_table('policy') - op.drop_table('ciphers') + op.drop_table("endpoints") + op.drop_index("policies_ciphers_ix", table_name="policies_ciphers") + op.drop_table("policies_ciphers") + op.drop_table("policy") + op.drop_table("ciphers") ### end Alembic commands ### diff --git a/lemur/migrations/versions/318b66568358_.py b/lemur/migrations/versions/318b66568358_.py index 9d4aa48d..8578cd78 100644 --- a/lemur/migrations/versions/318b66568358_.py +++ b/lemur/migrations/versions/318b66568358_.py @@ -7,8 +7,8 @@ Create Date: 2019-02-05 15:42:25.477587 """ # revision identifiers, used by Alembic. -revision = '318b66568358' -down_revision = '9f79024fe67b' +revision = "318b66568358" +down_revision = "9f79024fe67b" from alembic import op @@ -16,7 +16,7 @@ from alembic import op def upgrade(): connection = op.get_bind() # Delete duplicate entries - connection.execute('UPDATE certificates SET deleted = false WHERE deleted IS NULL') + connection.execute("UPDATE certificates SET deleted = false WHERE deleted IS NULL") def downgrade(): diff --git a/lemur/migrations/versions/3307381f3b88_.py b/lemur/migrations/versions/3307381f3b88_.py index e4da96a6..2af0448b 100644 --- a/lemur/migrations/versions/3307381f3b88_.py +++ b/lemur/migrations/versions/3307381f3b88_.py @@ -12,8 +12,8 @@ Create Date: 2016-05-20 17:33:04.360687 """ # revision identifiers, used by Alembic. -revision = '3307381f3b88' -down_revision = '412b22cb656a' +revision = "3307381f3b88" +down_revision = "412b22cb656a" from alembic import op import sqlalchemy as sa @@ -23,109 +23,165 @@ from sqlalchemy.dialects import postgresql def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.alter_column('authorities', 'owner', - existing_type=sa.VARCHAR(length=128), - nullable=True) - op.drop_column('authorities', 'not_after') - op.drop_column('authorities', 'bits') - op.drop_column('authorities', 'cn') - op.drop_column('authorities', 'not_before') - op.add_column('certificates', sa.Column('root_authority_id', sa.Integer(), nullable=True)) - op.alter_column('certificates', 'body', - existing_type=sa.TEXT(), - nullable=False) - op.alter_column('certificates', 'owner', - existing_type=sa.VARCHAR(length=128), - nullable=True) - op.drop_constraint(u'certificates_authority_id_fkey', 'certificates', type_='foreignkey') - op.create_foreign_key(None, 'certificates', 'authorities', ['authority_id'], ['id'], ondelete='CASCADE') - op.create_foreign_key(None, 'certificates', 'authorities', ['root_authority_id'], ['id'], ondelete='CASCADE') + op.alter_column( + "authorities", "owner", existing_type=sa.VARCHAR(length=128), nullable=True + ) + op.drop_column("authorities", "not_after") + op.drop_column("authorities", "bits") + op.drop_column("authorities", "cn") + op.drop_column("authorities", "not_before") + op.add_column( + "certificates", sa.Column("root_authority_id", sa.Integer(), nullable=True) + ) + op.alter_column("certificates", "body", existing_type=sa.TEXT(), nullable=False) + op.alter_column( + "certificates", "owner", existing_type=sa.VARCHAR(length=128), nullable=True + ) + op.drop_constraint( + u"certificates_authority_id_fkey", "certificates", type_="foreignkey" + ) + op.create_foreign_key( + None, + "certificates", + "authorities", + ["authority_id"], + ["id"], + ondelete="CASCADE", + ) + op.create_foreign_key( + None, + "certificates", + "authorities", + ["root_authority_id"], + ["id"], + ondelete="CASCADE", + ) ### end Alembic commands ### # link existing certificate to their authority certificates conn = op.get_bind() - for id, body, owner in conn.execute(text('select id, body, owner from authorities')): + for id, body, owner in conn.execute( + text("select id, body, owner from authorities") + ): if not owner: owner = "lemur@nobody" # look up certificate by body, if duplications are found, pick one - stmt = text('select id from certificates where body=:body') + stmt = text("select id from certificates where body=:body") stmt = stmt.bindparams(body=body) root_certificate = conn.execute(stmt).fetchone() if root_certificate: - stmt = text('update certificates set root_authority_id=:root_authority_id where id=:id') + stmt = text( + "update certificates set root_authority_id=:root_authority_id where id=:id" + ) stmt = stmt.bindparams(root_authority_id=id, id=root_certificate[0]) op.execute(stmt) # link owner roles to their authorities - stmt = text('select id from roles where name=:name') + stmt = text("select id from roles where name=:name") stmt = stmt.bindparams(name=owner) owner_role = conn.execute(stmt).fetchone() if not owner_role: - stmt = text('insert into roles (name, description) values (:name, :description)') - stmt = stmt.bindparams(name=owner, description='Lemur generated role or existing owner.') + stmt = text( + "insert into roles (name, description) values (:name, :description)" + ) + stmt = stmt.bindparams( + name=owner, description="Lemur generated role or existing owner." + ) op.execute(stmt) - stmt = text('select id from roles where name=:name') + stmt = text("select id from roles where name=:name") stmt = stmt.bindparams(name=owner) owner_role = conn.execute(stmt).fetchone() - stmt = text('select * from roles_authorities where role_id=:role_id and authority_id=:authority_id') + stmt = text( + "select * from roles_authorities where role_id=:role_id and authority_id=:authority_id" + ) stmt = stmt.bindparams(role_id=owner_role[0], authority_id=id) exists = conn.execute(stmt).fetchone() if not exists: - stmt = text('insert into roles_authorities (role_id, authority_id) values (:role_id, :authority_id)') + stmt = text( + "insert into roles_authorities (role_id, authority_id) values (:role_id, :authority_id)" + ) stmt = stmt.bindparams(role_id=owner_role[0], authority_id=id) op.execute(stmt) # link owner roles to their certificates - for id, owner in conn.execute(text('select id, owner from certificates')): + for id, owner in conn.execute(text("select id, owner from certificates")): if not owner: owner = "lemur@nobody" - stmt = text('select id from roles where name=:name') + stmt = text("select id from roles where name=:name") stmt = stmt.bindparams(name=owner) owner_role = conn.execute(stmt).fetchone() if not owner_role: - stmt = text('insert into roles (name, description) values (:name, :description)') - stmt = stmt.bindparams(name=owner, description='Lemur generated role or existing owner.') + stmt = text( + "insert into roles (name, description) values (:name, :description)" + ) + stmt = stmt.bindparams( + name=owner, description="Lemur generated role or existing owner." + ) op.execute(stmt) # link owner roles to their authorities - stmt = text('select id from roles where name=:name') + stmt = text("select id from roles where name=:name") stmt = stmt.bindparams(name=owner) owner_role = conn.execute(stmt).fetchone() - stmt = text('select * from roles_certificates where role_id=:role_id and certificate_id=:certificate_id') + stmt = text( + "select * from roles_certificates where role_id=:role_id and certificate_id=:certificate_id" + ) stmt = stmt.bindparams(role_id=owner_role[0], certificate_id=id) exists = conn.execute(stmt).fetchone() if not exists: - stmt = text('insert into roles_certificates (role_id, certificate_id) values (:role_id, :certificate_id)') + stmt = text( + "insert into roles_certificates (role_id, certificate_id) values (:role_id, :certificate_id)" + ) stmt = stmt.bindparams(role_id=owner_role[0], certificate_id=id) op.execute(stmt) def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, 'certificates', type_='foreignkey') - op.drop_constraint(None, 'certificates', type_='foreignkey') - op.create_foreign_key(u'certificates_authority_id_fkey', 'certificates', 'authorities', ['authority_id'], ['id']) - op.alter_column('certificates', 'owner', - existing_type=sa.VARCHAR(length=128), - nullable=True) - op.alter_column('certificates', 'body', - existing_type=sa.TEXT(), - nullable=True) - op.drop_column('certificates', 'root_authority_id') - op.add_column('authorities', sa.Column('not_before', postgresql.TIMESTAMP(), autoincrement=False, nullable=True)) - op.add_column('authorities', sa.Column('cn', sa.VARCHAR(length=128), autoincrement=False, nullable=True)) - op.add_column('authorities', sa.Column('bits', sa.INTEGER(), autoincrement=False, nullable=True)) - op.add_column('authorities', sa.Column('not_after', postgresql.TIMESTAMP(), autoincrement=False, nullable=True)) - op.alter_column('authorities', 'owner', - existing_type=sa.VARCHAR(length=128), - nullable=True) + op.drop_constraint(None, "certificates", type_="foreignkey") + op.drop_constraint(None, "certificates", type_="foreignkey") + op.create_foreign_key( + u"certificates_authority_id_fkey", + "certificates", + "authorities", + ["authority_id"], + ["id"], + ) + op.alter_column( + "certificates", "owner", existing_type=sa.VARCHAR(length=128), nullable=True + ) + op.alter_column("certificates", "body", existing_type=sa.TEXT(), nullable=True) + op.drop_column("certificates", "root_authority_id") + op.add_column( + "authorities", + sa.Column( + "not_before", postgresql.TIMESTAMP(), autoincrement=False, nullable=True + ), + ) + op.add_column( + "authorities", + sa.Column("cn", sa.VARCHAR(length=128), autoincrement=False, nullable=True), + ) + op.add_column( + "authorities", + sa.Column("bits", sa.INTEGER(), autoincrement=False, nullable=True), + ) + op.add_column( + "authorities", + sa.Column( + "not_after", postgresql.TIMESTAMP(), autoincrement=False, nullable=True + ), + ) + op.alter_column( + "authorities", "owner", existing_type=sa.VARCHAR(length=128), nullable=True + ) ### end Alembic commands ### diff --git a/lemur/migrations/versions/33de094da890_.py b/lemur/migrations/versions/33de094da890_.py index 76624e96..718e908f 100644 --- a/lemur/migrations/versions/33de094da890_.py +++ b/lemur/migrations/versions/33de094da890_.py @@ -7,25 +7,31 @@ Create Date: 2015-11-30 15:40:19.827272 """ # revision identifiers, used by Alembic. -revision = '33de094da890' +revision = "33de094da890" down_revision = None from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql + def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.create_table('certificate_replacement_associations', - sa.Column('replaced_certificate_id', sa.Integer(), nullable=True), - sa.Column('certificate_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ondelete='cascade'), - sa.ForeignKeyConstraint(['replaced_certificate_id'], ['certificates.id'], ondelete='cascade') + op.create_table( + "certificate_replacement_associations", + sa.Column("replaced_certificate_id", sa.Integer(), nullable=True), + sa.Column("certificate_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["certificate_id"], ["certificates.id"], ondelete="cascade" + ), + sa.ForeignKeyConstraint( + ["replaced_certificate_id"], ["certificates.id"], ondelete="cascade" + ), ) ### end Alembic commands ### def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_table('certificate_replacement_associations') + op.drop_table("certificate_replacement_associations") ### end Alembic commands ### diff --git a/lemur/migrations/versions/3adfdd6598df_.py b/lemur/migrations/versions/3adfdd6598df_.py index 1f290153..7f587f49 100644 --- a/lemur/migrations/versions/3adfdd6598df_.py +++ b/lemur/migrations/versions/3adfdd6598df_.py @@ -7,8 +7,8 @@ Create Date: 2018-04-10 13:25:47.007556 """ # revision identifiers, used by Alembic. -revision = '3adfdd6598df' -down_revision = '556ceb3e3c3e' +revision = "3adfdd6598df" +down_revision = "556ceb3e3c3e" import sqlalchemy as sa from alembic import op @@ -22,84 +22,90 @@ def upgrade(): # create provider table print("Creating dns_providers table") op.create_table( - 'dns_providers', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(length=256), nullable=True), - sa.Column('description', sa.String(length=1024), nullable=True), - sa.Column('provider_type', sa.String(length=256), nullable=True), - sa.Column('credentials', Vault(), nullable=True), - sa.Column('api_endpoint', sa.String(length=256), nullable=True), - sa.Column('date_created', ArrowType(), server_default=sa.text('now()'), nullable=False), - sa.Column('status', sa.String(length=128), nullable=True), - sa.Column('options', JSON), - sa.Column('domains', sa.JSON(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('name') + "dns_providers", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=256), nullable=True), + sa.Column("description", sa.String(length=1024), nullable=True), + sa.Column("provider_type", sa.String(length=256), nullable=True), + sa.Column("credentials", Vault(), nullable=True), + sa.Column("api_endpoint", sa.String(length=256), nullable=True), + sa.Column( + "date_created", ArrowType(), server_default=sa.text("now()"), nullable=False + ), + sa.Column("status", sa.String(length=128), nullable=True), + sa.Column("options", JSON), + sa.Column("domains", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), ) print("Adding dns_provider_id column to certificates") - op.add_column('certificates', sa.Column('dns_provider_id', sa.Integer(), nullable=True)) + op.add_column( + "certificates", sa.Column("dns_provider_id", sa.Integer(), nullable=True) + ) print("Adding dns_provider_id column to pending_certs") - op.add_column('pending_certs', sa.Column('dns_provider_id', sa.Integer(), nullable=True)) + op.add_column( + "pending_certs", sa.Column("dns_provider_id", sa.Integer(), nullable=True) + ) print("Adding options column to pending_certs") - op.add_column('pending_certs', sa.Column('options', JSON)) + op.add_column("pending_certs", sa.Column("options", JSON)) print("Creating pending_dns_authorizations table") op.create_table( - 'pending_dns_authorizations', - sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True), - sa.Column('account_number', sa.String(length=128), nullable=True), - sa.Column('domains', JSON, nullable=True), - sa.Column('dns_provider_type', sa.String(length=128), nullable=True), - sa.Column('options', JSON, nullable=True), + "pending_dns_authorizations", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("account_number", sa.String(length=128), nullable=True), + sa.Column("domains", JSON, nullable=True), + sa.Column("dns_provider_type", sa.String(length=128), nullable=True), + sa.Column("options", JSON, nullable=True), ) print("Creating certificates_dns_providers_fk foreign key") - op.create_foreign_key('certificates_dns_providers_fk', 'certificates', 'dns_providers', ['dns_provider_id'], ['id'], - ondelete='cascade') + op.create_foreign_key( + "certificates_dns_providers_fk", + "certificates", + "dns_providers", + ["dns_provider_id"], + ["id"], + ondelete="cascade", + ) print("Altering column types in the api_keys table") - op.alter_column('api_keys', 'issued_at', - existing_type=sa.BIGINT(), - nullable=True) - op.alter_column('api_keys', 'revoked', - existing_type=sa.BOOLEAN(), - nullable=True) - op.alter_column('api_keys', 'ttl', - existing_type=sa.BIGINT(), - nullable=True) - op.alter_column('api_keys', 'user_id', - existing_type=sa.INTEGER(), - nullable=True) + op.alter_column("api_keys", "issued_at", existing_type=sa.BIGINT(), nullable=True) + op.alter_column("api_keys", "revoked", existing_type=sa.BOOLEAN(), nullable=True) + op.alter_column("api_keys", "ttl", existing_type=sa.BIGINT(), nullable=True) + op.alter_column("api_keys", "user_id", existing_type=sa.INTEGER(), nullable=True) print("Creating dns_providers_id foreign key on pending_certs table") - op.create_foreign_key(None, 'pending_certs', 'dns_providers', ['dns_provider_id'], ['id'], ondelete='CASCADE') + op.create_foreign_key( + None, + "pending_certs", + "dns_providers", + ["dns_provider_id"], + ["id"], + ondelete="CASCADE", + ) + def downgrade(): print("Removing dns_providers_id foreign key on pending_certs table") - op.drop_constraint(None, 'pending_certs', type_='foreignkey') + op.drop_constraint(None, "pending_certs", type_="foreignkey") print("Reverting column types in the api_keys table") - op.alter_column('api_keys', 'user_id', - existing_type=sa.INTEGER(), - nullable=False) - op.alter_column('api_keys', 'ttl', - existing_type=sa.BIGINT(), - nullable=False) - op.alter_column('api_keys', 'revoked', - existing_type=sa.BOOLEAN(), - nullable=False) - op.alter_column('api_keys', 'issued_at', - existing_type=sa.BIGINT(), - nullable=False) + op.alter_column("api_keys", "user_id", existing_type=sa.INTEGER(), nullable=False) + op.alter_column("api_keys", "ttl", existing_type=sa.BIGINT(), nullable=False) + op.alter_column("api_keys", "revoked", existing_type=sa.BOOLEAN(), nullable=False) + op.alter_column("api_keys", "issued_at", existing_type=sa.BIGINT(), nullable=False) print("Reverting certificates_dns_providers_fk foreign key") - op.drop_constraint('certificates_dns_providers_fk', 'certificates', type_='foreignkey') + op.drop_constraint( + "certificates_dns_providers_fk", "certificates", type_="foreignkey" + ) print("Dropping pending_dns_authorizations table") - op.drop_table('pending_dns_authorizations') + op.drop_table("pending_dns_authorizations") print("Undoing modifications to pending_certs table") - op.drop_column('pending_certs', 'options') - op.drop_column('pending_certs', 'dns_provider_id') + op.drop_column("pending_certs", "options") + op.drop_column("pending_certs", "dns_provider_id") print("Undoing modifications to certificates table") - op.drop_column('certificates', 'dns_provider_id') + op.drop_column("certificates", "dns_provider_id") print("Deleting dns_providers table") - op.drop_table('dns_providers') + op.drop_table("dns_providers") diff --git a/lemur/migrations/versions/412b22cb656a_.py b/lemur/migrations/versions/412b22cb656a_.py index d95ec701..c24ddfba 100644 --- a/lemur/migrations/versions/412b22cb656a_.py +++ b/lemur/migrations/versions/412b22cb656a_.py @@ -7,8 +7,8 @@ Create Date: 2016-05-17 17:37:41.210232 """ # revision identifiers, used by Alembic. -revision = '412b22cb656a' -down_revision = '4c50b903d1ae' +revision = "412b22cb656a" +down_revision = "4c50b903d1ae" from alembic import op import sqlalchemy as sa @@ -17,47 +17,102 @@ from sqlalchemy.sql import text def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.create_table('roles_authorities', - sa.Column('authority_id', sa.Integer(), nullable=True), - sa.Column('role_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['authority_id'], ['authorities.id'], ), - sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ) + op.create_table( + "roles_authorities", + sa.Column("authority_id", sa.Integer(), nullable=True), + sa.Column("role_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["authority_id"], ["authorities.id"]), + sa.ForeignKeyConstraint(["role_id"], ["roles.id"]), ) - op.create_index('roles_authorities_ix', 'roles_authorities', ['authority_id', 'role_id'], unique=True) - op.create_table('roles_certificates', - sa.Column('certificate_id', sa.Integer(), nullable=True), - sa.Column('role_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ), - sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ) + op.create_index( + "roles_authorities_ix", + "roles_authorities", + ["authority_id", "role_id"], + unique=True, + ) + op.create_table( + "roles_certificates", + sa.Column("certificate_id", sa.Integer(), nullable=True), + sa.Column("role_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]), + sa.ForeignKeyConstraint(["role_id"], ["roles.id"]), + ) + op.create_index( + "roles_certificates_ix", + "roles_certificates", + ["certificate_id", "role_id"], + unique=True, + ) + op.create_index( + "certificate_associations_ix", + "certificate_associations", + ["domain_id", "certificate_id"], + unique=True, + ) + op.create_index( + "certificate_destination_associations_ix", + "certificate_destination_associations", + ["destination_id", "certificate_id"], + unique=True, + ) + op.create_index( + "certificate_notification_associations_ix", + "certificate_notification_associations", + ["notification_id", "certificate_id"], + unique=True, + ) + op.create_index( + "certificate_replacement_associations_ix", + "certificate_replacement_associations", + ["certificate_id", "certificate_id"], + unique=True, + ) + op.create_index( + "certificate_source_associations_ix", + "certificate_source_associations", + ["source_id", "certificate_id"], + unique=True, + ) + op.create_index( + "roles_users_ix", "roles_users", ["user_id", "role_id"], unique=True ) - op.create_index('roles_certificates_ix', 'roles_certificates', ['certificate_id', 'role_id'], unique=True) - op.create_index('certificate_associations_ix', 'certificate_associations', ['domain_id', 'certificate_id'], unique=True) - op.create_index('certificate_destination_associations_ix', 'certificate_destination_associations', ['destination_id', 'certificate_id'], unique=True) - op.create_index('certificate_notification_associations_ix', 'certificate_notification_associations', ['notification_id', 'certificate_id'], unique=True) - op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['certificate_id', 'certificate_id'], unique=True) - op.create_index('certificate_source_associations_ix', 'certificate_source_associations', ['source_id', 'certificate_id'], unique=True) - op.create_index('roles_users_ix', 'roles_users', ['user_id', 'role_id'], unique=True) ### end Alembic commands ### # migrate existing authority_id relationship to many_to_many conn = op.get_bind() - for id, authority_id in conn.execute(text('select id, authority_id from roles where authority_id is not null')): - stmt = text('insert into roles_authoritties (role_id, authority_id) values (:role_id, :authority_id)') + for id, authority_id in conn.execute( + text("select id, authority_id from roles where authority_id is not null") + ): + stmt = text( + "insert into roles_authoritties (role_id, authority_id) values (:role_id, :authority_id)" + ) stmt = stmt.bindparams(role_id=id, authority_id=authority_id) op.execute(stmt) def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_index('roles_users_ix', table_name='roles_users') - op.drop_index('certificate_source_associations_ix', table_name='certificate_source_associations') - op.drop_index('certificate_replacement_associations_ix', table_name='certificate_replacement_associations') - op.drop_index('certificate_notification_associations_ix', table_name='certificate_notification_associations') - op.drop_index('certificate_destination_associations_ix', table_name='certificate_destination_associations') - op.drop_index('certificate_associations_ix', table_name='certificate_associations') - op.drop_index('roles_certificates_ix', table_name='roles_certificates') - op.drop_table('roles_certificates') - op.drop_index('roles_authorities_ix', table_name='roles_authorities') - op.drop_table('roles_authorities') + op.drop_index("roles_users_ix", table_name="roles_users") + op.drop_index( + "certificate_source_associations_ix", + table_name="certificate_source_associations", + ) + op.drop_index( + "certificate_replacement_associations_ix", + table_name="certificate_replacement_associations", + ) + op.drop_index( + "certificate_notification_associations_ix", + table_name="certificate_notification_associations", + ) + op.drop_index( + "certificate_destination_associations_ix", + table_name="certificate_destination_associations", + ) + op.drop_index("certificate_associations_ix", table_name="certificate_associations") + op.drop_index("roles_certificates_ix", table_name="roles_certificates") + op.drop_table("roles_certificates") + op.drop_index("roles_authorities_ix", table_name="roles_authorities") + op.drop_table("roles_authorities") ### end Alembic commands ### diff --git a/lemur/migrations/versions/449c3d5c7299_.py b/lemur/migrations/versions/449c3d5c7299_.py index 0bc30db1..f33548da 100644 --- a/lemur/migrations/versions/449c3d5c7299_.py +++ b/lemur/migrations/versions/449c3d5c7299_.py @@ -7,8 +7,8 @@ Create Date: 2018-02-24 22:51:35.369229 """ # revision identifiers, used by Alembic. -revision = '449c3d5c7299' -down_revision = '5770674184de' +revision = "449c3d5c7299" +down_revision = "5770674184de" from alembic import op from flask_sqlalchemy import SQLAlchemy @@ -23,12 +23,14 @@ COLUMNS = ["notification_id", "certificate_id"] def upgrade(): connection = op.get_bind() # Delete duplicate entries - connection.execute("""\ + connection.execute( + """\ DELETE FROM certificate_notification_associations WHERE ctid NOT IN ( -- Select the first tuple ID for each (notification_id, certificate_id) combination and keep that SELECT min(ctid) FROM certificate_notification_associations GROUP BY notification_id, certificate_id ) - """) + """ + ) op.create_unique_constraint(CONSTRAINT_NAME, TABLE, COLUMNS) diff --git a/lemur/migrations/versions/4c50b903d1ae_.py b/lemur/migrations/versions/4c50b903d1ae_.py index 7b0515d4..93d4a312 100644 --- a/lemur/migrations/versions/4c50b903d1ae_.py +++ b/lemur/migrations/versions/4c50b903d1ae_.py @@ -7,20 +7,21 @@ Create Date: 2015-12-30 10:19:30.057791 """ # revision identifiers, used by Alembic. -revision = '4c50b903d1ae' -down_revision = '33de094da890' +revision = "4c50b903d1ae" +down_revision = "33de094da890" from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql + def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.add_column('domains', sa.Column('sensitive', sa.Boolean(), nullable=True)) + op.add_column("domains", sa.Column("sensitive", sa.Boolean(), nullable=True)) ### end Alembic commands ### def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_column('domains', 'sensitive') + op.drop_column("domains", "sensitive") ### end Alembic commands ### diff --git a/lemur/migrations/versions/556ceb3e3c3e_.py b/lemur/migrations/versions/556ceb3e3c3e_.py index 2916c0eb..60304138 100644 --- a/lemur/migrations/versions/556ceb3e3c3e_.py +++ b/lemur/migrations/versions/556ceb3e3c3e_.py @@ -7,8 +7,8 @@ Create Date: 2018-01-05 01:18:45.571595 """ # revision identifiers, used by Alembic. -revision = '556ceb3e3c3e' -down_revision = '449c3d5c7299' +revision = "556ceb3e3c3e" +down_revision = "449c3d5c7299" from alembic import op import sqlalchemy as sa @@ -16,84 +16,150 @@ from lemur.utils import Vault from sqlalchemy.dialects import postgresql from sqlalchemy_utils import ArrowType + def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.create_table('pending_certs', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('external_id', sa.String(length=128), nullable=True), - sa.Column('owner', sa.String(length=128), nullable=False), - sa.Column('name', sa.String(length=256), nullable=True), - sa.Column('description', sa.String(length=1024), nullable=True), - sa.Column('notify', sa.Boolean(), nullable=True), - sa.Column('number_attempts', sa.Integer(), nullable=True), - sa.Column('rename', sa.Boolean(), nullable=True), - sa.Column('cn', sa.String(length=128), nullable=True), - sa.Column('csr', sa.Text(), nullable=False), - sa.Column('chain', sa.Text(), nullable=True), - sa.Column('private_key', Vault(), nullable=True), - sa.Column('date_created', ArrowType(), server_default=sa.text('now()'), nullable=False), - sa.Column('status', sa.String(length=128), nullable=True), - sa.Column('rotation', sa.Boolean(), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=True), - sa.Column('authority_id', sa.Integer(), nullable=True), - sa.Column('root_authority_id', sa.Integer(), nullable=True), - sa.Column('rotation_policy_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['authority_id'], ['authorities.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['root_authority_id'], ['authorities.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['rotation_policy_id'], ['rotation_policies.id'], ), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('name') + op.create_table( + "pending_certs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("external_id", sa.String(length=128), nullable=True), + sa.Column("owner", sa.String(length=128), nullable=False), + sa.Column("name", sa.String(length=256), nullable=True), + sa.Column("description", sa.String(length=1024), nullable=True), + sa.Column("notify", sa.Boolean(), nullable=True), + sa.Column("number_attempts", sa.Integer(), nullable=True), + sa.Column("rename", sa.Boolean(), nullable=True), + sa.Column("cn", sa.String(length=128), nullable=True), + sa.Column("csr", sa.Text(), nullable=False), + sa.Column("chain", sa.Text(), nullable=True), + sa.Column("private_key", Vault(), nullable=True), + sa.Column( + "date_created", ArrowType(), server_default=sa.text("now()"), nullable=False + ), + sa.Column("status", sa.String(length=128), nullable=True), + sa.Column("rotation", sa.Boolean(), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("authority_id", sa.Integer(), nullable=True), + sa.Column("root_authority_id", sa.Integer(), nullable=True), + sa.Column("rotation_policy_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["authority_id"], ["authorities.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["root_authority_id"], ["authorities.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["rotation_policy_id"], ["rotation_policies.id"]), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), ) - op.create_table('pending_cert_destination_associations', - sa.Column('destination_id', sa.Integer(), nullable=True), - sa.Column('pending_cert_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['destination_id'], ['destinations.id'], ondelete='cascade'), - sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade') + op.create_table( + "pending_cert_destination_associations", + sa.Column("destination_id", sa.Integer(), nullable=True), + sa.Column("pending_cert_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["destination_id"], ["destinations.id"], ondelete="cascade" + ), + sa.ForeignKeyConstraint( + ["pending_cert_id"], ["pending_certs.id"], ondelete="cascade" + ), ) - op.create_index('pending_cert_destination_associations_ix', 'pending_cert_destination_associations', ['destination_id', 'pending_cert_id'], unique=False) - op.create_table('pending_cert_notification_associations', - sa.Column('notification_id', sa.Integer(), nullable=True), - sa.Column('pending_cert_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['notification_id'], ['notifications.id'], ondelete='cascade'), - sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade') + op.create_index( + "pending_cert_destination_associations_ix", + "pending_cert_destination_associations", + ["destination_id", "pending_cert_id"], + unique=False, ) - op.create_index('pending_cert_notification_associations_ix', 'pending_cert_notification_associations', ['notification_id', 'pending_cert_id'], unique=False) - op.create_table('pending_cert_replacement_associations', - sa.Column('replaced_certificate_id', sa.Integer(), nullable=True), - sa.Column('pending_cert_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade'), - sa.ForeignKeyConstraint(['replaced_certificate_id'], ['certificates.id'], ondelete='cascade') + op.create_table( + "pending_cert_notification_associations", + sa.Column("notification_id", sa.Integer(), nullable=True), + sa.Column("pending_cert_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["notification_id"], ["notifications.id"], ondelete="cascade" + ), + sa.ForeignKeyConstraint( + ["pending_cert_id"], ["pending_certs.id"], ondelete="cascade" + ), ) - op.create_index('pending_cert_replacement_associations_ix', 'pending_cert_replacement_associations', ['replaced_certificate_id', 'pending_cert_id'], unique=False) - op.create_table('pending_cert_role_associations', - sa.Column('pending_cert_id', sa.Integer(), nullable=True), - sa.Column('role_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ), - sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ) + op.create_index( + "pending_cert_notification_associations_ix", + "pending_cert_notification_associations", + ["notification_id", "pending_cert_id"], + unique=False, ) - op.create_index('pending_cert_role_associations_ix', 'pending_cert_role_associations', ['pending_cert_id', 'role_id'], unique=False) - op.create_table('pending_cert_source_associations', - sa.Column('source_id', sa.Integer(), nullable=True), - sa.Column('pending_cert_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade'), - sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='cascade') + op.create_table( + "pending_cert_replacement_associations", + sa.Column("replaced_certificate_id", sa.Integer(), nullable=True), + sa.Column("pending_cert_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["pending_cert_id"], ["pending_certs.id"], ondelete="cascade" + ), + sa.ForeignKeyConstraint( + ["replaced_certificate_id"], ["certificates.id"], ondelete="cascade" + ), + ) + op.create_index( + "pending_cert_replacement_associations_ix", + "pending_cert_replacement_associations", + ["replaced_certificate_id", "pending_cert_id"], + unique=False, + ) + op.create_table( + "pending_cert_role_associations", + sa.Column("pending_cert_id", sa.Integer(), nullable=True), + sa.Column("role_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["pending_cert_id"], ["pending_certs.id"]), + sa.ForeignKeyConstraint(["role_id"], ["roles.id"]), + ) + op.create_index( + "pending_cert_role_associations_ix", + "pending_cert_role_associations", + ["pending_cert_id", "role_id"], + unique=False, + ) + op.create_table( + "pending_cert_source_associations", + sa.Column("source_id", sa.Integer(), nullable=True), + sa.Column("pending_cert_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["pending_cert_id"], ["pending_certs.id"], ondelete="cascade" + ), + sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="cascade"), + ) + op.create_index( + "pending_cert_source_associations_ix", + "pending_cert_source_associations", + ["source_id", "pending_cert_id"], + unique=False, ) - op.create_index('pending_cert_source_associations_ix', 'pending_cert_source_associations', ['source_id', 'pending_cert_id'], unique=False) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_index('pending_cert_source_associations_ix', table_name='pending_cert_source_associations') - op.drop_table('pending_cert_source_associations') - op.drop_index('pending_cert_role_associations_ix', table_name='pending_cert_role_associations') - op.drop_table('pending_cert_role_associations') - op.drop_index('pending_cert_replacement_associations_ix', table_name='pending_cert_replacement_associations') - op.drop_table('pending_cert_replacement_associations') - op.drop_index('pending_cert_notification_associations_ix', table_name='pending_cert_notification_associations') - op.drop_table('pending_cert_notification_associations') - op.drop_index('pending_cert_destination_associations_ix', table_name='pending_cert_destination_associations') - op.drop_table('pending_cert_destination_associations') - op.drop_table('pending_certs') + op.drop_index( + "pending_cert_source_associations_ix", + table_name="pending_cert_source_associations", + ) + op.drop_table("pending_cert_source_associations") + op.drop_index( + "pending_cert_role_associations_ix", table_name="pending_cert_role_associations" + ) + op.drop_table("pending_cert_role_associations") + op.drop_index( + "pending_cert_replacement_associations_ix", + table_name="pending_cert_replacement_associations", + ) + op.drop_table("pending_cert_replacement_associations") + op.drop_index( + "pending_cert_notification_associations_ix", + table_name="pending_cert_notification_associations", + ) + op.drop_table("pending_cert_notification_associations") + op.drop_index( + "pending_cert_destination_associations_ix", + table_name="pending_cert_destination_associations", + ) + op.drop_table("pending_cert_destination_associations") + op.drop_table("pending_certs") # ### end Alembic commands ### diff --git a/lemur/migrations/versions/5770674184de_.py b/lemur/migrations/versions/5770674184de_.py index 88262a84..49d89367 100644 --- a/lemur/migrations/versions/5770674184de_.py +++ b/lemur/migrations/versions/5770674184de_.py @@ -7,8 +7,8 @@ Create Date: 2018-02-23 15:27:30.335435 """ # revision identifiers, used by Alembic. -revision = '5770674184de' -down_revision = 'ce547319f7be' +revision = "5770674184de" +down_revision = "ce547319f7be" from flask_sqlalchemy import SQLAlchemy from lemur.models import certificate_notification_associations @@ -32,7 +32,9 @@ def upgrade(): # If we've seen a pair already, delete the duplicates if seen.get("{}-{}".format(x.certificate_id, x.notification_id)): print("Deleting duplicate: {}".format(x)) - d = session.query(certificate_notification_associations).filter(certificate_notification_associations.c.id==x.id) + d = session.query(certificate_notification_associations).filter( + certificate_notification_associations.c.id == x.id + ) d.delete(synchronize_session=False) seen["{}-{}".format(x.certificate_id, x.notification_id)] = True db.session.commit() diff --git a/lemur/migrations/versions/5ae0ecefb01f_.py b/lemur/migrations/versions/5ae0ecefb01f_.py index a471c4bf..7b0d5ae0 100644 --- a/lemur/migrations/versions/5ae0ecefb01f_.py +++ b/lemur/migrations/versions/5ae0ecefb01f_.py @@ -7,8 +7,8 @@ Create Date: 2018-08-14 08:16:43.329316 """ # revision identifiers, used by Alembic. -revision = '5ae0ecefb01f' -down_revision = '1db4f82bc780' +revision = "5ae0ecefb01f" +down_revision = "1db4f82bc780" from alembic import op import sqlalchemy as sa @@ -16,17 +16,14 @@ import sqlalchemy as sa def upgrade(): op.alter_column( - table_name='pending_certs', - column_name='status', - nullable=True, - type_=sa.TEXT() + table_name="pending_certs", column_name="status", nullable=True, type_=sa.TEXT() ) def downgrade(): op.alter_column( - table_name='pending_certs', - column_name='status', + table_name="pending_certs", + column_name="status", nullable=True, - type_=sa.VARCHAR(128) + type_=sa.VARCHAR(128), ) diff --git a/lemur/migrations/versions/5bc47fa7cac4_.py b/lemur/migrations/versions/5bc47fa7cac4_.py index f4a145c8..f786c527 100644 --- a/lemur/migrations/versions/5bc47fa7cac4_.py +++ b/lemur/migrations/versions/5bc47fa7cac4_.py @@ -7,16 +7,18 @@ Create Date: 2017-12-08 14:19:11.903864 """ # revision identifiers, used by Alembic. -revision = '5bc47fa7cac4' -down_revision = 'c05a8998b371' +revision = "5bc47fa7cac4" +down_revision = "c05a8998b371" from alembic import op import sqlalchemy as sa def upgrade(): - op.add_column('roles', sa.Column('third_party', sa.Boolean(), nullable=True, default=False)) + op.add_column( + "roles", sa.Column("third_party", sa.Boolean(), nullable=True, default=False) + ) def downgrade(): - op.drop_column('roles', 'third_party') + op.drop_column("roles", "third_party") diff --git a/lemur/migrations/versions/5e680529b666_.py b/lemur/migrations/versions/5e680529b666_.py index d59d996f..4cca4521 100644 --- a/lemur/migrations/versions/5e680529b666_.py +++ b/lemur/migrations/versions/5e680529b666_.py @@ -7,20 +7,20 @@ Create Date: 2017-01-26 05:05:25.168125 """ # revision identifiers, used by Alembic. -revision = '5e680529b666' -down_revision = '131ec6accff5' +revision = "5e680529b666" +down_revision = "131ec6accff5" from alembic import op import sqlalchemy as sa def upgrade(): - op.add_column('endpoints', sa.Column('sensitive', sa.Boolean(), nullable=True)) - op.add_column('endpoints', sa.Column('source_id', sa.Integer(), nullable=True)) - op.create_foreign_key(None, 'endpoints', 'sources', ['source_id'], ['id']) + op.add_column("endpoints", sa.Column("sensitive", sa.Boolean(), nullable=True)) + op.add_column("endpoints", sa.Column("source_id", sa.Integer(), nullable=True)) + op.create_foreign_key(None, "endpoints", "sources", ["source_id"], ["id"]) def downgrade(): - op.drop_constraint(None, 'endpoints', type_='foreignkey') - op.drop_column('endpoints', 'source_id') - op.drop_column('endpoints', 'sensitive') + op.drop_constraint(None, "endpoints", type_="foreignkey") + op.drop_column("endpoints", "source_id") + op.drop_column("endpoints", "sensitive") diff --git a/lemur/migrations/versions/6006c79b6011_.py b/lemur/migrations/versions/6006c79b6011_.py index c41b1d25..86727716 100644 --- a/lemur/migrations/versions/6006c79b6011_.py +++ b/lemur/migrations/versions/6006c79b6011_.py @@ -7,15 +7,15 @@ Create Date: 2018-10-19 15:23:06.750510 """ # revision identifiers, used by Alembic. -revision = '6006c79b6011' -down_revision = '984178255c83' +revision = "6006c79b6011" +down_revision = "984178255c83" from alembic import op def upgrade(): - op.create_unique_constraint("uq_label", 'sources', ['label']) + op.create_unique_constraint("uq_label", "sources", ["label"]) def downgrade(): - op.drop_constraint("uq_label", 'sources', type_='unique') + op.drop_constraint("uq_label", "sources", type_="unique") diff --git a/lemur/migrations/versions/7ead443ba911_.py b/lemur/migrations/versions/7ead443ba911_.py index 62be01aa..10b8e576 100644 --- a/lemur/migrations/versions/7ead443ba911_.py +++ b/lemur/migrations/versions/7ead443ba911_.py @@ -7,15 +7,16 @@ Create Date: 2018-10-21 22:06:23.056906 """ # revision identifiers, used by Alembic. -revision = '7ead443ba911' -down_revision = '6006c79b6011' +revision = "7ead443ba911" +down_revision = "6006c79b6011" from alembic import op import sqlalchemy as sa def upgrade(): - op.add_column('certificates', sa.Column('csr', sa.TEXT(), nullable=True)) + op.add_column("certificates", sa.Column("csr", sa.TEXT(), nullable=True)) + def downgrade(): - op.drop_column('certificates', 'csr') + op.drop_column("certificates", "csr") diff --git a/lemur/migrations/versions/7f71c0cea31a_.py b/lemur/migrations/versions/7f71c0cea31a_.py index 04bb02ea..5e90cbb1 100644 --- a/lemur/migrations/versions/7f71c0cea31a_.py +++ b/lemur/migrations/versions/7f71c0cea31a_.py @@ -9,8 +9,8 @@ Create Date: 2016-07-28 09:39:12.736506 """ # revision identifiers, used by Alembic. -revision = '7f71c0cea31a' -down_revision = '29d8c8455c86' +revision = "7f71c0cea31a" +down_revision = "29d8c8455c86" from alembic import op import sqlalchemy as sa @@ -19,17 +19,25 @@ from sqlalchemy.sql import text def upgrade(): conn = op.get_bind() - for name in conn.execute(text('select name from certificates group by name having count(*) > 1')): - for idx, id in enumerate(conn.execute(text("select id from certificates where certificates.name like :name order by id ASC").bindparams(name=name[0]))): + for name in conn.execute( + text("select name from certificates group by name having count(*) > 1") + ): + for idx, id in enumerate( + conn.execute( + text( + "select id from certificates where certificates.name like :name order by id ASC" + ).bindparams(name=name[0]) + ) + ): if not idx: continue - new_name = name[0] + '-' + str(idx) - stmt = text('update certificates set name=:name where id=:id') + new_name = name[0] + "-" + str(idx) + stmt = text("update certificates set name=:name where id=:id") stmt = stmt.bindparams(name=new_name, id=id[0]) op.execute(stmt) - op.create_unique_constraint(None, 'certificates', ['name']) + op.create_unique_constraint(None, "certificates", ["name"]) def downgrade(): - op.drop_constraint(None, 'certificates', type_='unique') + op.drop_constraint(None, "certificates", type_="unique") diff --git a/lemur/migrations/versions/8ae67285ff14_.py b/lemur/migrations/versions/8ae67285ff14_.py index f45be70d..e8f6a217 100644 --- a/lemur/migrations/versions/8ae67285ff14_.py +++ b/lemur/migrations/versions/8ae67285ff14_.py @@ -7,18 +7,28 @@ Create Date: 2017-05-10 11:56:13.999332 """ # revision identifiers, used by Alembic. -revision = '8ae67285ff14' -down_revision = '5e680529b666' +revision = "8ae67285ff14" +down_revision = "5e680529b666" from alembic import op import sqlalchemy as sa def upgrade(): - op.drop_index('certificate_replacement_associations_ix') - op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['replaced_certificate_id', 'certificate_id'], unique=True) + op.drop_index("certificate_replacement_associations_ix") + op.create_index( + "certificate_replacement_associations_ix", + "certificate_replacement_associations", + ["replaced_certificate_id", "certificate_id"], + unique=True, + ) def downgrade(): - op.drop_index('certificate_replacement_associations_ix') - op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['certificate_id', 'certificate_id'], unique=True) + op.drop_index("certificate_replacement_associations_ix") + op.create_index( + "certificate_replacement_associations_ix", + "certificate_replacement_associations", + ["certificate_id", "certificate_id"], + unique=True, + ) diff --git a/lemur/migrations/versions/932525b82f1a_.py b/lemur/migrations/versions/932525b82f1a_.py index 2ee95d07..8ff36d1c 100644 --- a/lemur/migrations/versions/932525b82f1a_.py +++ b/lemur/migrations/versions/932525b82f1a_.py @@ -7,15 +7,15 @@ Create Date: 2016-10-13 20:14:33.928029 """ # revision identifiers, used by Alembic. -revision = '932525b82f1a' -down_revision = '7f71c0cea31a' +revision = "932525b82f1a" +down_revision = "7f71c0cea31a" from alembic import op def upgrade(): - op.alter_column('certificates', 'active', new_column_name='notify') + op.alter_column("certificates", "active", new_column_name="notify") def downgrade(): - op.alter_column('certificates', 'notify', new_column_name='active') + op.alter_column("certificates", "notify", new_column_name="active") diff --git a/lemur/migrations/versions/9392b9f9a805_.py b/lemur/migrations/versions/9392b9f9a805_.py index d6ca734b..8ff09333 100644 --- a/lemur/migrations/versions/9392b9f9a805_.py +++ b/lemur/migrations/versions/9392b9f9a805_.py @@ -6,8 +6,8 @@ Create Date: 2018-09-17 08:33:37.087488 """ # revision identifiers, used by Alembic. -revision = '9392b9f9a805' -down_revision = '5ae0ecefb01f' +revision = "9392b9f9a805" +down_revision = "5ae0ecefb01f" from alembic import op from sqlalchemy_utils import ArrowType @@ -15,10 +15,17 @@ import sqlalchemy as sa def upgrade(): - op.add_column('pending_certs', sa.Column('last_updated', ArrowType, server_default=sa.text('now()'), onupdate=sa.text('now()'), - nullable=False)) + op.add_column( + "pending_certs", + sa.Column( + "last_updated", + ArrowType, + server_default=sa.text("now()"), + onupdate=sa.text("now()"), + nullable=False, + ), + ) def downgrade(): - op.drop_column('pending_certs', 'last_updated') - + op.drop_column("pending_certs", "last_updated") diff --git a/lemur/migrations/versions/984178255c83_.py b/lemur/migrations/versions/984178255c83_.py index 40d2ce31..88cab183 100644 --- a/lemur/migrations/versions/984178255c83_.py +++ b/lemur/migrations/versions/984178255c83_.py @@ -7,18 +7,20 @@ Create Date: 2018-10-11 20:49:12.704563 """ # revision identifiers, used by Alembic. -revision = '984178255c83' -down_revision = 'f2383bf08fbc' +revision = "984178255c83" +down_revision = "f2383bf08fbc" from alembic import op import sqlalchemy as sa def upgrade(): - op.add_column('pending_certs', sa.Column('resolved', sa.Boolean(), nullable=True)) - op.add_column('pending_certs', sa.Column('resolved_cert_id', sa.Integer(), nullable=True)) + op.add_column("pending_certs", sa.Column("resolved", sa.Boolean(), nullable=True)) + op.add_column( + "pending_certs", sa.Column("resolved_cert_id", sa.Integer(), nullable=True) + ) def downgrade(): - op.drop_column('pending_certs', 'resolved_cert_id') - op.drop_column('pending_certs', 'resolved') + op.drop_column("pending_certs", "resolved_cert_id") + op.drop_column("pending_certs", "resolved") diff --git a/lemur/migrations/versions/9f79024fe67b_.py b/lemur/migrations/versions/9f79024fe67b_.py index ad22d5f3..cb7db296 100644 --- a/lemur/migrations/versions/9f79024fe67b_.py +++ b/lemur/migrations/versions/9f79024fe67b_.py @@ -7,16 +7,26 @@ Create Date: 2019-01-03 15:36:59.181911 """ # revision identifiers, used by Alembic. -revision = '9f79024fe67b' -down_revision = 'ee827d1e1974' +revision = "9f79024fe67b" +down_revision = "ee827d1e1974" from alembic import op import sqlalchemy as sa def upgrade(): - op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'revoke_cert', 'update_cert'], ['create_cert', 'delete_cert', 'key_view', 'revoke_cert', 'update_cert']) + op.sync_enum_values( + "public", + "log_type", + ["create_cert", "key_view", "revoke_cert", "update_cert"], + ["create_cert", "delete_cert", "key_view", "revoke_cert", "update_cert"], + ) def downgrade(): - op.sync_enum_values('public', 'log_type', ['create_cert', 'delete_cert', 'key_view', 'revoke_cert', 'update_cert'], ['create_cert', 'key_view', 'revoke_cert', 'update_cert']) + op.sync_enum_values( + "public", + "log_type", + ["create_cert", "delete_cert", "key_view", "revoke_cert", "update_cert"], + ["create_cert", "key_view", "revoke_cert", "update_cert"], + ) diff --git a/lemur/migrations/versions/a02a678ddc25_.py b/lemur/migrations/versions/a02a678ddc25_.py index 603bc06a..f8fa09bb 100644 --- a/lemur/migrations/versions/a02a678ddc25_.py +++ b/lemur/migrations/versions/a02a678ddc25_.py @@ -10,8 +10,8 @@ Create Date: 2017-07-12 11:45:49.257927 """ # revision identifiers, used by Alembic. -revision = 'a02a678ddc25' -down_revision = '8ae67285ff14' +revision = "a02a678ddc25" +down_revision = "8ae67285ff14" from alembic import op import sqlalchemy as sa @@ -20,25 +20,30 @@ from sqlalchemy.sql import text def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.create_table('rotation_policies', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('days', sa.Integer(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "rotation_policies", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("days", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.add_column( + "certificates", sa.Column("rotation_policy_id", sa.Integer(), nullable=True) + ) + op.create_foreign_key( + None, "certificates", "rotation_policies", ["rotation_policy_id"], ["id"] ) - op.add_column('certificates', sa.Column('rotation_policy_id', sa.Integer(), nullable=True)) - op.create_foreign_key(None, 'certificates', 'rotation_policies', ['rotation_policy_id'], ['id']) conn = op.get_bind() - stmt = text('insert into rotation_policies (days, name) values (:days, :name)') - stmt = stmt.bindparams(days=30, name='default') + stmt = text("insert into rotation_policies (days, name) values (:days, :name)") + stmt = stmt.bindparams(days=30, name="default") conn.execute(stmt) - stmt = text('select id from rotation_policies where name=:name') - stmt = stmt.bindparams(name='default') + stmt = text("select id from rotation_policies where name=:name") + stmt = stmt.bindparams(name="default") rotation_policy_id = conn.execute(stmt).fetchone()[0] - stmt = text('update certificates set rotation_policy_id=:rotation_policy_id') + stmt = text("update certificates set rotation_policy_id=:rotation_policy_id") stmt = stmt.bindparams(rotation_policy_id=rotation_policy_id) conn.execute(stmt) # ### end Alembic commands ### @@ -46,9 +51,17 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, 'certificates', type_='foreignkey') - op.drop_column('certificates', 'rotation_policy_id') - op.drop_index('certificate_replacement_associations_ix', table_name='certificate_replacement_associations') - op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['replaced_certificate_id', 'certificate_id'], unique=True) - op.drop_table('rotation_policies') + op.drop_constraint(None, "certificates", type_="foreignkey") + op.drop_column("certificates", "rotation_policy_id") + op.drop_index( + "certificate_replacement_associations_ix", + table_name="certificate_replacement_associations", + ) + op.create_index( + "certificate_replacement_associations_ix", + "certificate_replacement_associations", + ["replaced_certificate_id", "certificate_id"], + unique=True, + ) + op.drop_table("rotation_policies") # ### end Alembic commands ### diff --git a/lemur/migrations/versions/ac483cfeb230_.py b/lemur/migrations/versions/ac483cfeb230_.py index d28a2599..d1e2361d 100644 --- a/lemur/migrations/versions/ac483cfeb230_.py +++ b/lemur/migrations/versions/ac483cfeb230_.py @@ -7,8 +7,8 @@ Create Date: 2017-10-11 10:16:39.682591 """ # revision identifiers, used by Alembic. -revision = 'ac483cfeb230' -down_revision = 'b29e2c4bf8c9' +revision = "ac483cfeb230" +down_revision = "b29e2c4bf8c9" from alembic import op import sqlalchemy as sa @@ -16,12 +16,18 @@ from sqlalchemy.dialects import postgresql def upgrade(): - op.alter_column('certificates', 'name', - existing_type=sa.VARCHAR(length=128), - type_=sa.String(length=256)) + op.alter_column( + "certificates", + "name", + existing_type=sa.VARCHAR(length=128), + type_=sa.String(length=256), + ) def downgrade(): - op.alter_column('certificates', 'name', - existing_type=sa.VARCHAR(length=256), - type_=sa.String(length=128)) + op.alter_column( + "certificates", + "name", + existing_type=sa.VARCHAR(length=256), + type_=sa.String(length=128), + ) diff --git a/lemur/migrations/versions/b29e2c4bf8c9_.py b/lemur/migrations/versions/b29e2c4bf8c9_.py index 19835e09..6f9dc526 100644 --- a/lemur/migrations/versions/b29e2c4bf8c9_.py +++ b/lemur/migrations/versions/b29e2c4bf8c9_.py @@ -7,8 +7,8 @@ Create Date: 2017-09-26 10:50:35.740367 """ # revision identifiers, used by Alembic. -revision = 'b29e2c4bf8c9' -down_revision = '1ae8e3104db8' +revision = "b29e2c4bf8c9" +down_revision = "1ae8e3104db8" from alembic import op import sqlalchemy as sa @@ -16,13 +16,25 @@ import sqlalchemy as sa def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column('certificates', sa.Column('external_id', sa.String(128), nullable=True)) - op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'update_cert'], ['create_cert', 'key_view', 'revoke_cert', 'update_cert']) + op.add_column( + "certificates", sa.Column("external_id", sa.String(128), nullable=True) + ) + op.sync_enum_values( + "public", + "log_type", + ["create_cert", "key_view", "update_cert"], + ["create_cert", "key_view", "revoke_cert", "update_cert"], + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'revoke_cert', 'update_cert'], ['create_cert', 'key_view', 'update_cert']) - op.drop_column('certificates', 'external_id') + op.sync_enum_values( + "public", + "log_type", + ["create_cert", "key_view", "revoke_cert", "update_cert"], + ["create_cert", "key_view", "update_cert"], + ) + op.drop_column("certificates", "external_id") # ### end Alembic commands ### diff --git a/lemur/migrations/versions/c05a8998b371_.py b/lemur/migrations/versions/c05a8998b371_.py index cf600043..a5c9abff 100644 --- a/lemur/migrations/versions/c05a8998b371_.py +++ b/lemur/migrations/versions/c05a8998b371_.py @@ -7,25 +7,27 @@ Create Date: 2017-11-10 14:51:28.975927 """ # revision identifiers, used by Alembic. -revision = 'c05a8998b371' -down_revision = 'ac483cfeb230' +revision = "c05a8998b371" +down_revision = "ac483cfeb230" from alembic import op import sqlalchemy as sa import sqlalchemy_utils + def upgrade(): - op.create_table('api_keys', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sa.String(length=128), nullable=True), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.Column('ttl', sa.BigInteger(), nullable=False), - sa.Column('issued_at', sa.BigInteger(), nullable=False), - sa.Column('revoked', sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "api_keys", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=128), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("ttl", sa.BigInteger(), nullable=False), + sa.Column("issued_at", sa.BigInteger(), nullable=False), + sa.Column("revoked", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.PrimaryKeyConstraint("id"), ) def downgrade(): - op.drop_table('api_keys') + op.drop_table("api_keys") diff --git a/lemur/migrations/versions/c87cb989af04_.py b/lemur/migrations/versions/c87cb989af04_.py index 4959e727..69f53bf4 100644 --- a/lemur/migrations/versions/c87cb989af04_.py +++ b/lemur/migrations/versions/c87cb989af04_.py @@ -5,15 +5,15 @@ Create Date: 2018-10-11 09:44:57.099854 """ -revision = 'c87cb989af04' -down_revision = '9392b9f9a805' +revision = "c87cb989af04" +down_revision = "9392b9f9a805" from alembic import op def upgrade(): - op.create_index(op.f('ix_domains_name'), 'domains', ['name'], unique=False) + op.create_index(op.f("ix_domains_name"), "domains", ["name"], unique=False) def downgrade(): - op.drop_index(op.f('ix_domains_name'), table_name='domains') + op.drop_index(op.f("ix_domains_name"), table_name="domains") diff --git a/lemur/migrations/versions/ce547319f7be_.py b/lemur/migrations/versions/ce547319f7be_.py index 41ef1fa8..d139c6fb 100644 --- a/lemur/migrations/versions/ce547319f7be_.py +++ b/lemur/migrations/versions/ce547319f7be_.py @@ -7,8 +7,8 @@ Create Date: 2018-02-23 11:00:02.150561 """ # revision identifiers, used by Alembic. -revision = 'ce547319f7be' -down_revision = '5bc47fa7cac4' +revision = "ce547319f7be" +down_revision = "5bc47fa7cac4" import sqlalchemy as sa @@ -24,12 +24,12 @@ TABLE = "certificate_notification_associations" def upgrade(): print("Adding id column") op.add_column( - TABLE, - sa.Column('id', sa.Integer, primary_key=True, autoincrement=True) + TABLE, sa.Column("id", sa.Integer, primary_key=True, autoincrement=True) ) db.session.commit() db.session.flush() + def downgrade(): op.drop_column(TABLE, "id") db.session.commit() diff --git a/lemur/migrations/versions/e3691fc396e9_.py b/lemur/migrations/versions/e3691fc396e9_.py index 1c5c2f15..0007b804 100644 --- a/lemur/migrations/versions/e3691fc396e9_.py +++ b/lemur/migrations/versions/e3691fc396e9_.py @@ -7,29 +7,36 @@ Create Date: 2016-11-28 13:15:46.995219 """ # revision identifiers, used by Alembic. -revision = 'e3691fc396e9' -down_revision = '932525b82f1a' +revision = "e3691fc396e9" +down_revision = "932525b82f1a" from alembic import op import sqlalchemy as sa import sqlalchemy_utils + def upgrade(): ### commands auto generated by Alembic - please adjust! ### - op.create_table('logs', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('certificate_id', sa.Integer(), nullable=True), - sa.Column('log_type', sa.Enum('key_view', name='log_type'), nullable=False), - sa.Column('logged_at', sqlalchemy_utils.types.arrow.ArrowType(), server_default=sa.text('now()'), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "logs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("certificate_id", sa.Integer(), nullable=True), + sa.Column("log_type", sa.Enum("key_view", name="log_type"), nullable=False), + sa.Column( + "logged_at", + sqlalchemy_utils.types.arrow.ArrowType(), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.PrimaryKeyConstraint("id"), ) ### end Alembic commands ### def downgrade(): ### commands auto generated by Alembic - please adjust! ### - op.drop_table('logs') + op.drop_table("logs") ### end Alembic commands ### diff --git a/lemur/migrations/versions/ee827d1e1974_.py b/lemur/migrations/versions/ee827d1e1974_.py index 62ac6222..56696fe3 100644 --- a/lemur/migrations/versions/ee827d1e1974_.py +++ b/lemur/migrations/versions/ee827d1e1974_.py @@ -7,25 +7,44 @@ Create Date: 2018-11-05 09:49:40.226368 """ # revision identifiers, used by Alembic. -revision = 'ee827d1e1974' -down_revision = '7ead443ba911' +revision = "ee827d1e1974" +down_revision = "7ead443ba911" from alembic import op from sqlalchemy.exc import ProgrammingError + def upgrade(): connection = op.get_bind() connection.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm") - op.create_index('ix_certificates_cn', 'certificates', ['cn'], unique=False, postgresql_ops={'cn': 'gin_trgm_ops'}, - postgresql_using='gin') - op.create_index('ix_certificates_name', 'certificates', ['name'], unique=False, - postgresql_ops={'name': 'gin_trgm_ops'}, postgresql_using='gin') - op.create_index('ix_domains_name_gin', 'domains', ['name'], unique=False, postgresql_ops={'name': 'gin_trgm_ops'}, - postgresql_using='gin') + op.create_index( + "ix_certificates_cn", + "certificates", + ["cn"], + unique=False, + postgresql_ops={"cn": "gin_trgm_ops"}, + postgresql_using="gin", + ) + op.create_index( + "ix_certificates_name", + "certificates", + ["name"], + unique=False, + postgresql_ops={"name": "gin_trgm_ops"}, + postgresql_using="gin", + ) + op.create_index( + "ix_domains_name_gin", + "domains", + ["name"], + unique=False, + postgresql_ops={"name": "gin_trgm_ops"}, + postgresql_using="gin", + ) def downgrade(): - op.drop_index('ix_domains_name', table_name='domains') - op.drop_index('ix_certificates_name', table_name='certificates') - op.drop_index('ix_certificates_cn', table_name='certificates') + op.drop_index("ix_domains_name", table_name="domains") + op.drop_index("ix_certificates_name", table_name="certificates") + op.drop_index("ix_certificates_cn", table_name="certificates") diff --git a/lemur/migrations/versions/f2383bf08fbc_.py b/lemur/migrations/versions/f2383bf08fbc_.py index 1fa36960..a54aa5d2 100644 --- a/lemur/migrations/versions/f2383bf08fbc_.py +++ b/lemur/migrations/versions/f2383bf08fbc_.py @@ -7,17 +7,22 @@ Create Date: 2018-10-11 11:23:31.195471 """ -revision = 'f2383bf08fbc' -down_revision = 'c87cb989af04' +revision = "f2383bf08fbc" +down_revision = "c87cb989af04" import sqlalchemy as sa from alembic import op def upgrade(): - op.create_index('ix_certificates_id_desc', 'certificates', [sa.text('id DESC')], unique=True, - postgresql_using='btree') + op.create_index( + "ix_certificates_id_desc", + "certificates", + [sa.text("id DESC")], + unique=True, + postgresql_using="btree", + ) def downgrade(): - op.drop_index('ix_certificates_id_desc', table_name='certificates') + op.drop_index("ix_certificates_id_desc", table_name="certificates") diff --git a/lemur/models.py b/lemur/models.py index 69f82360..163d156f 100644 --- a/lemur/models.py +++ b/lemur/models.py @@ -12,121 +12,201 @@ from sqlalchemy import Column, Integer, ForeignKey, Index, UniqueConstraint from lemur.database import db -certificate_associations = db.Table('certificate_associations', - Column('domain_id', Integer, ForeignKey('domains.id')), - Column('certificate_id', Integer, ForeignKey('certificates.id')) - ) +certificate_associations = db.Table( + "certificate_associations", + Column("domain_id", Integer, ForeignKey("domains.id")), + Column("certificate_id", Integer, ForeignKey("certificates.id")), +) -Index('certificate_associations_ix', certificate_associations.c.domain_id, certificate_associations.c.certificate_id) +Index( + "certificate_associations_ix", + certificate_associations.c.domain_id, + certificate_associations.c.certificate_id, +) -certificate_destination_associations = db.Table('certificate_destination_associations', - Column('destination_id', Integer, - ForeignKey('destinations.id', ondelete='cascade')), - Column('certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')) - ) +certificate_destination_associations = db.Table( + "certificate_destination_associations", + Column( + "destination_id", Integer, ForeignKey("destinations.id", ondelete="cascade") + ), + Column( + "certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade") + ), +) -Index('certificate_destination_associations_ix', certificate_destination_associations.c.destination_id, certificate_destination_associations.c.certificate_id) +Index( + "certificate_destination_associations_ix", + certificate_destination_associations.c.destination_id, + certificate_destination_associations.c.certificate_id, +) -certificate_source_associations = db.Table('certificate_source_associations', - Column('source_id', Integer, - ForeignKey('sources.id', ondelete='cascade')), - Column('certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')) - ) +certificate_source_associations = db.Table( + "certificate_source_associations", + Column("source_id", Integer, ForeignKey("sources.id", ondelete="cascade")), + Column( + "certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade") + ), +) -Index('certificate_source_associations_ix', certificate_source_associations.c.source_id, certificate_source_associations.c.certificate_id) +Index( + "certificate_source_associations_ix", + certificate_source_associations.c.source_id, + certificate_source_associations.c.certificate_id, +) -certificate_notification_associations = db.Table('certificate_notification_associations', - Column('notification_id', Integer, - ForeignKey('notifications.id', ondelete='cascade')), - Column('certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')), - Column('id', Integer, primary_key=True, autoincrement=True), - UniqueConstraint('notification_id', 'certificate_id', name='uq_dest_not_ids') - ) +certificate_notification_associations = db.Table( + "certificate_notification_associations", + Column( + "notification_id", Integer, ForeignKey("notifications.id", ondelete="cascade") + ), + Column( + "certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade") + ), + Column("id", Integer, primary_key=True, autoincrement=True), + UniqueConstraint("notification_id", "certificate_id", name="uq_dest_not_ids"), +) -Index('certificate_notification_associations_ix', certificate_notification_associations.c.notification_id, certificate_notification_associations.c.certificate_id) +Index( + "certificate_notification_associations_ix", + certificate_notification_associations.c.notification_id, + certificate_notification_associations.c.certificate_id, +) -certificate_replacement_associations = db.Table('certificate_replacement_associations', - Column('replaced_certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')), - Column('certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')) - ) +certificate_replacement_associations = db.Table( + "certificate_replacement_associations", + Column( + "replaced_certificate_id", + Integer, + ForeignKey("certificates.id", ondelete="cascade"), + ), + Column( + "certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade") + ), +) -Index('certificate_replacement_associations_ix', certificate_replacement_associations.c.replaced_certificate_id, certificate_replacement_associations.c.certificate_id, unique=True) +Index( + "certificate_replacement_associations_ix", + certificate_replacement_associations.c.replaced_certificate_id, + certificate_replacement_associations.c.certificate_id, + unique=True, +) -roles_authorities = db.Table('roles_authorities', - Column('authority_id', Integer, ForeignKey('authorities.id')), - Column('role_id', Integer, ForeignKey('roles.id')) - ) +roles_authorities = db.Table( + "roles_authorities", + Column("authority_id", Integer, ForeignKey("authorities.id")), + Column("role_id", Integer, ForeignKey("roles.id")), +) -Index('roles_authorities_ix', roles_authorities.c.authority_id, roles_authorities.c.role_id) +Index( + "roles_authorities_ix", + roles_authorities.c.authority_id, + roles_authorities.c.role_id, +) -roles_certificates = db.Table('roles_certificates', - Column('certificate_id', Integer, ForeignKey('certificates.id')), - Column('role_id', Integer, ForeignKey('roles.id')) - ) +roles_certificates = db.Table( + "roles_certificates", + Column("certificate_id", Integer, ForeignKey("certificates.id")), + Column("role_id", Integer, ForeignKey("roles.id")), +) -Index('roles_certificates_ix', roles_certificates.c.certificate_id, roles_certificates.c.role_id) +Index( + "roles_certificates_ix", + roles_certificates.c.certificate_id, + roles_certificates.c.role_id, +) -roles_users = db.Table('roles_users', - Column('user_id', Integer, ForeignKey('users.id')), - Column('role_id', Integer, ForeignKey('roles.id')) - ) +roles_users = db.Table( + "roles_users", + Column("user_id", Integer, ForeignKey("users.id")), + Column("role_id", Integer, ForeignKey("roles.id")), +) -Index('roles_users_ix', roles_users.c.user_id, roles_users.c.role_id) +Index("roles_users_ix", roles_users.c.user_id, roles_users.c.role_id) -policies_ciphers = db.Table('policies_ciphers', - Column('cipher_id', Integer, ForeignKey('ciphers.id')), - Column('policy_id', Integer, ForeignKey('policy.id'))) +policies_ciphers = db.Table( + "policies_ciphers", + Column("cipher_id", Integer, ForeignKey("ciphers.id")), + Column("policy_id", Integer, ForeignKey("policy.id")), +) -Index('policies_ciphers_ix', policies_ciphers.c.cipher_id, policies_ciphers.c.policy_id) +Index("policies_ciphers_ix", policies_ciphers.c.cipher_id, policies_ciphers.c.policy_id) -pending_cert_destination_associations = db.Table('pending_cert_destination_associations', - Column('destination_id', Integer, - ForeignKey('destinations.id', ondelete='cascade')), - Column('pending_cert_id', Integer, - ForeignKey('pending_certs.id', ondelete='cascade')) - ) +pending_cert_destination_associations = db.Table( + "pending_cert_destination_associations", + Column( + "destination_id", Integer, ForeignKey("destinations.id", ondelete="cascade") + ), + Column( + "pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade") + ), +) -Index('pending_cert_destination_associations_ix', pending_cert_destination_associations.c.destination_id, pending_cert_destination_associations.c.pending_cert_id) +Index( + "pending_cert_destination_associations_ix", + pending_cert_destination_associations.c.destination_id, + pending_cert_destination_associations.c.pending_cert_id, +) -pending_cert_notification_associations = db.Table('pending_cert_notification_associations', - Column('notification_id', Integer, - ForeignKey('notifications.id', ondelete='cascade')), - Column('pending_cert_id', Integer, - ForeignKey('pending_certs.id', ondelete='cascade')) - ) +pending_cert_notification_associations = db.Table( + "pending_cert_notification_associations", + Column( + "notification_id", Integer, ForeignKey("notifications.id", ondelete="cascade") + ), + Column( + "pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade") + ), +) -Index('pending_cert_notification_associations_ix', pending_cert_notification_associations.c.notification_id, pending_cert_notification_associations.c.pending_cert_id) +Index( + "pending_cert_notification_associations_ix", + pending_cert_notification_associations.c.notification_id, + pending_cert_notification_associations.c.pending_cert_id, +) -pending_cert_source_associations = db.Table('pending_cert_source_associations', - Column('source_id', Integer, - ForeignKey('sources.id', ondelete='cascade')), - Column('pending_cert_id', Integer, - ForeignKey('pending_certs.id', ondelete='cascade')) - ) +pending_cert_source_associations = db.Table( + "pending_cert_source_associations", + Column("source_id", Integer, ForeignKey("sources.id", ondelete="cascade")), + Column( + "pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade") + ), +) -Index('pending_cert_source_associations_ix', pending_cert_source_associations.c.source_id, pending_cert_source_associations.c.pending_cert_id) +Index( + "pending_cert_source_associations_ix", + pending_cert_source_associations.c.source_id, + pending_cert_source_associations.c.pending_cert_id, +) -pending_cert_replacement_associations = db.Table('pending_cert_replacement_associations', - Column('replaced_certificate_id', Integer, - ForeignKey('certificates.id', ondelete='cascade')), - Column('pending_cert_id', Integer, - ForeignKey('pending_certs.id', ondelete='cascade')) - ) +pending_cert_replacement_associations = db.Table( + "pending_cert_replacement_associations", + Column( + "replaced_certificate_id", + Integer, + ForeignKey("certificates.id", ondelete="cascade"), + ), + Column( + "pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade") + ), +) -Index('pending_cert_replacement_associations_ix', pending_cert_replacement_associations.c.replaced_certificate_id, pending_cert_replacement_associations.c.pending_cert_id) +Index( + "pending_cert_replacement_associations_ix", + pending_cert_replacement_associations.c.replaced_certificate_id, + pending_cert_replacement_associations.c.pending_cert_id, +) -pending_cert_role_associations = db.Table('pending_cert_role_associations', - Column('pending_cert_id', Integer, ForeignKey('pending_certs.id')), - Column('role_id', Integer, ForeignKey('roles.id')) - ) +pending_cert_role_associations = db.Table( + "pending_cert_role_associations", + Column("pending_cert_id", Integer, ForeignKey("pending_certs.id")), + Column("role_id", Integer, ForeignKey("roles.id")), +) -Index('pending_cert_role_associations_ix', pending_cert_role_associations.c.pending_cert_id, pending_cert_role_associations.c.role_id) +Index( + "pending_cert_role_associations_ix", + pending_cert_role_associations.c.pending_cert_id, + pending_cert_role_associations.c.role_id, +) diff --git a/lemur/notifications/cli.py b/lemur/notifications/cli.py index e3bf431e..a2848117 100644 --- a/lemur/notifications/cli.py +++ b/lemur/notifications/cli.py @@ -14,7 +14,14 @@ from lemur.notifications.messaging import send_expiration_notifications manager = Manager(usage="Handles notification related tasks.") -@manager.option('-e', '--exclude', dest='exclude', action='append', default=[], help='Common name matching of certificates that should be excluded from notification') +@manager.option( + "-e", + "--exclude", + dest="exclude", + action="append", + default=[], + help="Common name matching of certificates that should be excluded from notification", +) def expirations(exclude): """ Runs Lemur's notification engine, that looks for expired certificates and sends @@ -33,12 +40,13 @@ def expirations(exclude): success, failed = send_expiration_notifications(exclude) print( "Finished notifying subscribers about expiring certificates! Sent: {success} Failed: {failed}".format( - success=success, - failed=failed + success=success, failed=failed ) ) status = SUCCESS_METRIC_STATUS except Exception as e: sentry.captureException() - metrics.send('expiration_notification_job', 'counter', 1, metric_tags={'status': status}) + metrics.send( + "expiration_notification_job", "counter", 1, metric_tags={"status": status} + ) diff --git a/lemur/notifications/messaging.py b/lemur/notifications/messaging.py index cd88ebc8..919b73db 100644 --- a/lemur/notifications/messaging.py +++ b/lemur/notifications/messaging.py @@ -36,15 +36,17 @@ def get_certificates(exclude=None): now = arrow.utcnow() max = now + timedelta(days=90) - q = database.db.session.query(Certificate) \ - .filter(Certificate.not_after <= max) \ - .filter(Certificate.notify == True) \ - .filter(Certificate.expired == False) # noqa + q = ( + database.db.session.query(Certificate) + .filter(Certificate.not_after <= max) + .filter(Certificate.notify == True) + .filter(Certificate.expired == False) + ) # noqa exclude_conditions = [] if exclude: for e in exclude: - exclude_conditions.append(~Certificate.name.ilike('%{}%'.format(e))) + exclude_conditions.append(~Certificate.name.ilike("%{}%".format(e))) q = q.filter(and_(*exclude_conditions)) @@ -101,7 +103,12 @@ def send_notification(event_type, data, targets, notification): except Exception as e: sentry.captureException() - metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': event_type}) + metrics.send( + "notification", + "counter", + 1, + metric_tags={"status": status, "event_type": event_type}, + ) if status == SUCCESS_METRIC_STATUS: return True @@ -115,7 +122,7 @@ def send_expiration_notifications(exclude): success = failure = 0 # security team gets all - security_email = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL') + security_email = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL") security_data = [] for owner, notification_group in get_eligible_certificates(exclude=exclude).items(): @@ -127,26 +134,43 @@ def send_expiration_notifications(exclude): for data in certificates: n, certificate = data - cert_data = certificate_notification_output_schema.dump(certificate).data + cert_data = certificate_notification_output_schema.dump( + certificate + ).data notification_data.append(cert_data) security_data.append(cert_data) - notification_recipient = get_plugin_option('recipients', notification.options) + notification_recipient = get_plugin_option( + "recipients", notification.options + ) if notification_recipient: notification_recipient = notification_recipient.split(",") - if send_notification('expiration', notification_data, [owner], notification): + if send_notification( + "expiration", notification_data, [owner], notification + ): success += 1 else: failure += 1 - if notification_recipient and owner != notification_recipient and security_email != notification_recipient: - if send_notification('expiration', notification_data, notification_recipient, notification): + if ( + notification_recipient + and owner != notification_recipient + and security_email != notification_recipient + ): + if send_notification( + "expiration", + notification_data, + notification_recipient, + notification, + ): success += 1 else: failure += 1 - if send_notification('expiration', security_data, security_email, notification): + if send_notification( + "expiration", security_data, security_email, notification + ): success += 1 else: failure += 1 @@ -165,24 +189,35 @@ def send_rotation_notification(certificate, notification_plugin=None): """ status = FAILURE_METRIC_STATUS if not notification_plugin: - notification_plugin = plugins.get(current_app.config.get('LEMUR_DEFAULT_NOTIFICATION_PLUGIN')) + notification_plugin = plugins.get( + current_app.config.get("LEMUR_DEFAULT_NOTIFICATION_PLUGIN") + ) data = certificate_notification_output_schema.dump(certificate).data try: - notification_plugin.send('rotation', data, [data['owner']]) + notification_plugin.send("rotation", data, [data["owner"]]) status = SUCCESS_METRIC_STATUS except Exception as e: - current_app.logger.error('Unable to send notification to {}.'.format(data['owner']), exc_info=True) + current_app.logger.error( + "Unable to send notification to {}.".format(data["owner"]), exc_info=True + ) sentry.captureException() - metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': 'rotation'}) + metrics.send( + "notification", + "counter", + 1, + metric_tags={"status": status, "event_type": "rotation"}, + ) if status == SUCCESS_METRIC_STATUS: return True -def send_pending_failure_notification(pending_cert, notify_owner=True, notify_security=True, notification_plugin=None): +def send_pending_failure_notification( + pending_cert, notify_owner=True, notify_security=True, notification_plugin=None +): """ Sends a report to certificate owners when their pending certificate failed to be created. @@ -194,32 +229,47 @@ def send_pending_failure_notification(pending_cert, notify_owner=True, notify_se if not notification_plugin: notification_plugin = plugins.get( - current_app.config.get('LEMUR_DEFAULT_NOTIFICATION_PLUGIN', 'email-notification') + current_app.config.get( + "LEMUR_DEFAULT_NOTIFICATION_PLUGIN", "email-notification" + ) ) data = pending_certificate_output_schema.dump(pending_cert).data - data["security_email"] = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL') + data["security_email"] = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL") if notify_owner: try: - notification_plugin.send('failed', data, [data['owner']], pending_cert) + notification_plugin.send("failed", data, [data["owner"]], pending_cert) status = SUCCESS_METRIC_STATUS except Exception as e: - current_app.logger.error('Unable to send pending failure notification to {}.'.format(data['owner']), - exc_info=True) + current_app.logger.error( + "Unable to send pending failure notification to {}.".format( + data["owner"] + ), + exc_info=True, + ) sentry.captureException() if notify_security: try: - notification_plugin.send('failed', data, data["security_email"], pending_cert) + notification_plugin.send( + "failed", data, data["security_email"], pending_cert + ) status = SUCCESS_METRIC_STATUS except Exception as e: - current_app.logger.error('Unable to send pending failure notification to ' - '{}.'.format(data['security_email']), - exc_info=True) + current_app.logger.error( + "Unable to send pending failure notification to " + "{}.".format(data["security_email"]), + exc_info=True, + ) sentry.captureException() - metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': 'rotation'}) + metrics.send( + "notification", + "counter", + 1, + metric_tags={"status": status, "event_type": "rotation"}, + ) if status == SUCCESS_METRIC_STATUS: return True @@ -242,20 +292,22 @@ def needs_notification(certificate): if not notification.active or not notification.options: return - interval = get_plugin_option('interval', notification.options) - unit = get_plugin_option('unit', notification.options) + interval = get_plugin_option("interval", notification.options) + unit = get_plugin_option("unit", notification.options) - if unit == 'weeks': + if unit == "weeks": interval *= 7 - elif unit == 'months': + elif unit == "months": interval *= 30 - elif unit == 'days': # it's nice to be explicit about the base unit + elif unit == "days": # it's nice to be explicit about the base unit pass else: - raise Exception("Invalid base unit for expiration interval: {0}".format(unit)) + raise Exception( + "Invalid base unit for expiration interval: {0}".format(unit) + ) if days == interval: notifications.append(notification) diff --git a/lemur/notifications/models.py b/lemur/notifications/models.py index 87646b4c..7053b8d7 100644 --- a/lemur/notifications/models.py +++ b/lemur/notifications/models.py @@ -11,12 +11,14 @@ from sqlalchemy_utils import JSONType from lemur.database import db from lemur.plugins.base import plugins -from lemur.models import certificate_notification_associations, \ - pending_cert_notification_associations +from lemur.models import ( + certificate_notification_associations, + pending_cert_notification_associations, +) class Notification(db.Model): - __tablename__ = 'notifications' + __tablename__ = "notifications" id = Column(Integer, primary_key=True) label = Column(String(128), unique=True) description = Column(Text()) @@ -28,14 +30,14 @@ class Notification(db.Model): secondary=certificate_notification_associations, passive_deletes=True, backref="notification", - cascade='all,delete' + cascade="all,delete", ) pending_certificates = relationship( "PendingCertificate", secondary=pending_cert_notification_associations, passive_deletes=True, backref="notification", - cascade='all,delete' + cascade="all,delete", ) @property diff --git a/lemur/notifications/schemas.py b/lemur/notifications/schemas.py index b5d4e1e6..a3ff4c99 100644 --- a/lemur/notifications/schemas.py +++ b/lemur/notifications/schemas.py @@ -7,7 +7,11 @@ """ from marshmallow import fields, post_dump from lemur.common.schema import LemurInputSchema, LemurOutputSchema -from lemur.schemas import PluginInputSchema, PluginOutputSchema, AssociatedCertificateSchema +from lemur.schemas import ( + PluginInputSchema, + PluginOutputSchema, + AssociatedCertificateSchema, +) class NotificationInputSchema(LemurInputSchema): @@ -30,7 +34,7 @@ class NotificationOutputSchema(LemurOutputSchema): @post_dump def fill_object(self, data): if data: - data['plugin']['pluginOptions'] = data['options'] + data["plugin"]["pluginOptions"] = data["options"] return data diff --git a/lemur/notifications/service.py b/lemur/notifications/service.py index 957757bd..ac624d1c 100644 --- a/lemur/notifications/service.py +++ b/lemur/notifications/service.py @@ -31,26 +31,28 @@ def create_default_expiration_notifications(name, recipients, intervals=None): options = [ { - 'name': 'unit', - 'type': 'select', - 'required': True, - 'validation': '', - 'available': ['days', 'weeks', 'months'], - 'helpMessage': 'Interval unit', - 'value': 'days', + "name": "unit", + "type": "select", + "required": True, + "validation": "", + "available": ["days", "weeks", "months"], + "helpMessage": "Interval unit", + "value": "days", }, { - 'name': 'recipients', - 'type': 'str', - 'required': True, - 'validation': '^([\w+-.%]+@[\w-.]+\.[A-Za-z]{2,4},?)+$', - 'helpMessage': 'Comma delimited list of email addresses', - 'value': ','.join(recipients) + "name": "recipients", + "type": "str", + "required": True, + "validation": "^([\w+-.%]+@[\w-.]+\.[A-Za-z]{2,4},?)+$", + "helpMessage": "Comma delimited list of email addresses", + "value": ",".join(recipients), }, ] if intervals is None: - intervals = current_app.config.get("LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", [30, 15, 2]) + intervals = current_app.config.get( + "LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", [30, 15, 2] + ) notifications = [] for i in intervals: @@ -58,21 +60,25 @@ def create_default_expiration_notifications(name, recipients, intervals=None): if not n: inter = [ { - 'name': 'interval', - 'type': 'int', - 'required': True, - 'validation': '^\d+$', - 'helpMessage': 'Number of days to be alert before expiration.', - 'value': i, + "name": "interval", + "type": "int", + "required": True, + "validation": "^\d+$", + "helpMessage": "Number of days to be alert before expiration.", + "value": i, } ] inter.extend(options) n = create( label="{name}_{interval}_DAY".format(name=name, interval=i), - plugin_name=current_app.config.get("LEMUR_DEFAULT_NOTIFICATION_PLUGIN", "email-notification"), + plugin_name=current_app.config.get( + "LEMUR_DEFAULT_NOTIFICATION_PLUGIN", "email-notification" + ), options=list(inter), - description="Default {interval} day expiration notification".format(interval=i), - certificates=[] + description="Default {interval} day expiration notification".format( + interval=i + ), + certificates=[], ) notifications.append(n) @@ -91,7 +97,9 @@ def create(label, plugin_name, options, description, certificates): :rtype : Notification :return: """ - notification = Notification(label=label, options=options, plugin_name=plugin_name, description=description) + notification = Notification( + label=label, options=options, plugin_name=plugin_name, description=description + ) notification.certificates = certificates return database.create(notification) @@ -147,7 +155,7 @@ def get_by_label(label): :param label: :return: """ - return database.get(Notification, label, field='label') + return database.get(Notification, label, field="label") def get_all(): @@ -161,18 +169,20 @@ def get_all(): def render(args): - filt = args.pop('filter') - certificate_id = args.pop('certificate_id', None) + filt = args.pop("filter") + certificate_id = args.pop("certificate_id", None) if certificate_id: - query = database.session_query(Notification).join(Certificate, Notification.certificate) + query = database.session_query(Notification).join( + Certificate, Notification.certificate + ) query = query.filter(Certificate.id == certificate_id) else: query = database.session_query(Notification) if filt: - terms = filt.split(';') - if terms[0] == 'active': + terms = filt.split(";") + if terms[0] == "active": query = query.filter(Notification.active == truthiness(terms[1])) else: query = database.filter(query, Notification, terms) diff --git a/lemur/notifications/views.py b/lemur/notifications/views.py index 4a2d82a8..cdabb4d4 100644 --- a/lemur/notifications/views.py +++ b/lemur/notifications/views.py @@ -9,7 +9,11 @@ from flask import Blueprint from flask_restful import Api, reqparse, inputs from lemur.notifications import service -from lemur.notifications.schemas import notification_input_schema, notification_output_schema, notifications_output_schema +from lemur.notifications.schemas import ( + notification_input_schema, + notification_output_schema, + notifications_output_schema, +) from lemur.auth.service import AuthenticatedResource from lemur.common.utils import paginated_parser @@ -17,12 +21,13 @@ from lemur.common.utils import paginated_parser from lemur.common.schema import validate_schema -mod = Blueprint('notifications', __name__) +mod = Blueprint("notifications", __name__) api = Api(mod) class NotificationsList(AuthenticatedResource): """ Defines the 'notifications' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(NotificationsList, self).__init__() @@ -103,7 +108,7 @@ class NotificationsList(AuthenticatedResource): :statuscode 200: no error """ parser = paginated_parser.copy() - parser.add_argument('active', type=inputs.boolean, location='args') + parser.add_argument("active", type=inputs.boolean, location="args") args = parser.parse_args() return service.render(args) @@ -215,11 +220,11 @@ class NotificationsList(AuthenticatedResource): :statuscode 200: no error """ return service.create( - data['label'], - data['plugin']['slug'], - data['plugin']['plugin_options'], - data['description'], - data['certificates'] + data["label"], + data["plugin"]["slug"], + data["plugin"]["plugin_options"], + data["description"], + data["certificates"], ) @@ -334,20 +339,21 @@ class Notifications(AuthenticatedResource): """ return service.update( notification_id, - data['label'], - data['plugin']['plugin_options'], - data['description'], - data['active'], - data['certificates'] + data["label"], + data["plugin"]["plugin_options"], + data["description"], + data["active"], + data["certificates"], ) def delete(self, notification_id): service.delete(notification_id) - return {'result': True} + return {"result": True} class CertificateNotifications(AuthenticatedResource): """ Defines the 'certificate/', endpoint='notification') -api.add_resource(CertificateNotifications, '/certificates//notifications', - endpoint='certificateNotifications') +api.add_resource(NotificationsList, "/notifications", endpoint="notifications") +api.add_resource( + Notifications, "/notifications/", endpoint="notification" +) +api.add_resource( + CertificateNotifications, + "/certificates//notifications", + endpoint="certificateNotifications", +) diff --git a/lemur/pending_certificates/cli.py b/lemur/pending_certificates/cli.py index 65e2e19a..2ff29f10 100644 --- a/lemur/pending_certificates/cli.py +++ b/lemur/pending_certificates/cli.py @@ -19,7 +19,9 @@ from lemur.plugins.base import plugins manager = Manager(usage="Handles pending certificate related tasks.") -@manager.option('-i', dest='ids', action='append', help='IDs of pending certificates to fetch') +@manager.option( + "-i", dest="ids", action="append", help="IDs of pending certificates to fetch" +) def fetch(ids): """ Attempt to get full certificate for each pending certificate listed. @@ -39,25 +41,18 @@ def fetch(ids): if real_cert: # If a real certificate was returned from issuer, then create it in Lemur and mark # the pending certificate as resolved - final_cert = pending_certificate_service.create_certificate(cert, real_cert, cert.user) - pending_certificate_service.update( - cert.id, - resolved_cert_id=final_cert.id - ) - pending_certificate_service.update( - cert.id, - resolved=True + final_cert = pending_certificate_service.create_certificate( + cert, real_cert, cert.user ) + pending_certificate_service.update(cert.id, resolved_cert_id=final_cert.id) + pending_certificate_service.update(cert.id, resolved=True) # add metrics to metrics extension new += 1 else: pending_certificate_service.increment_attempt(cert) failed += 1 print( - "[+] Certificates: New: {new} Failed: {failed}".format( - new=new, - failed=failed, - ) + "[+] Certificates: New: {new} Failed: {failed}".format(new=new, failed=failed) ) @@ -69,9 +64,7 @@ def fetch_all_acme(): certificates. """ - log_data = { - "function": "{}.{}".format(__name__, sys._getframe().f_code.co_name) - } + log_data = {"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name)} pending_certs = pending_certificate_service.get_unresolved_pending_certs() new = 0 failed = 0 @@ -81,7 +74,7 @@ def fetch_all_acme(): # We only care about certs using the acme-issuer plugin for cert in pending_certs: cert_authority = get_authority(cert.authority_id) - if cert_authority.plugin_name == 'acme-issuer': + if cert_authority.plugin_name == "acme-issuer": acme_certs.append(cert) else: wrong_issuer += 1 @@ -97,15 +90,13 @@ def fetch_all_acme(): if real_cert: # If a real certificate was returned from issuer, then create it in Lemur and mark # the pending certificate as resolved - final_cert = pending_certificate_service.create_certificate(pending_cert, real_cert, pending_cert.user) - pending_certificate_service.update( - pending_cert.id, - resolved_cert_id=final_cert.id + final_cert = pending_certificate_service.create_certificate( + pending_cert, real_cert, pending_cert.user ) pending_certificate_service.update( - pending_cert.id, - resolved=True + pending_cert.id, resolved_cert_id=final_cert.id ) + pending_certificate_service.update(pending_cert.id, resolved=True) # add metrics to metrics extension new += 1 else: @@ -118,17 +109,15 @@ def fetch_all_acme(): if pending_cert.number_attempts > 4: error_log["message"] = "Marking pending certificate as resolved" - send_pending_failure_notification(pending_cert, notify_owner=pending_cert.notify) - # Mark "resolved" as True - pending_certificate_service.update( - cert.id, - resolved=True + send_pending_failure_notification( + pending_cert, notify_owner=pending_cert.notify ) + # Mark "resolved" as True + pending_certificate_service.update(cert.id, resolved=True) else: pending_certificate_service.increment_attempt(pending_cert) pending_certificate_service.update( - cert.get("pending_cert").id, - status=str(cert.get("last_error")) + cert.get("pending_cert").id, status=str(cert.get("last_error")) ) current_app.logger.error(error_log) log_data["message"] = "Complete" @@ -138,8 +127,6 @@ def fetch_all_acme(): current_app.logger.debug(log_data) print( "[+] Certificates: New: {new} Failed: {failed} Not using ACME: {wrong_issuer}".format( - new=new, - failed=failed, - wrong_issuer=wrong_issuer + new=new, failed=failed, wrong_issuer=wrong_issuer ) ) diff --git a/lemur/pending_certificates/models.py b/lemur/pending_certificates/models.py index 7dc8e602..fa6be073 100644 --- a/lemur/pending_certificates/models.py +++ b/lemur/pending_certificates/models.py @@ -5,7 +5,16 @@ """ from datetime import datetime as dt -from sqlalchemy import Integer, ForeignKey, String, PassiveDefault, func, Column, Text, Boolean +from sqlalchemy import ( + Integer, + ForeignKey, + String, + PassiveDefault, + func, + Column, + Text, + Boolean, +) from sqlalchemy.orm import relationship from sqlalchemy_utils import JSONType from sqlalchemy_utils.types.arrow import ArrowType @@ -13,20 +22,28 @@ from sqlalchemy_utils.types.arrow import ArrowType from lemur.certificates.models import get_sequence from lemur.common import defaults, utils from lemur.database import db -from lemur.models import pending_cert_source_associations, \ - pending_cert_destination_associations, pending_cert_notification_associations, \ - pending_cert_replacement_associations, pending_cert_role_associations +from lemur.models import ( + pending_cert_source_associations, + pending_cert_destination_associations, + pending_cert_notification_associations, + pending_cert_replacement_associations, + pending_cert_role_associations, +) from lemur.utils import Vault def get_or_increase_name(name, serial): - certificates = PendingCertificate.query.filter(PendingCertificate.name.ilike('{0}%'.format(name))).all() + certificates = PendingCertificate.query.filter( + PendingCertificate.name.ilike("{0}%".format(name)) + ).all() if not certificates: return name - serial_name = '{0}-{1}'.format(name, hex(int(serial))[2:].upper()) - certificates = PendingCertificate.query.filter(PendingCertificate.name.ilike('{0}%'.format(serial_name))).all() + serial_name = "{0}-{1}".format(name, hex(int(serial))[2:].upper()) + certificates = PendingCertificate.query.filter( + PendingCertificate.name.ilike("{0}%".format(serial_name)) + ).all() if not certificates: return serial_name @@ -38,11 +55,11 @@ def get_or_increase_name(name, serial): if end: ends.append(end) - return '{0}-{1}'.format(root, max(ends) + 1) + return "{0}-{1}".format(root, max(ends) + 1) class PendingCertificate(db.Model): - __tablename__ = 'pending_certs' + __tablename__ = "pending_certs" id = Column(Integer, primary_key=True) external_id = Column(String(128)) owner = Column(String(128), nullable=False) @@ -60,69 +77,101 @@ class PendingCertificate(db.Model): private_key = Column(Vault, nullable=True) date_created = Column(ArrowType, PassiveDefault(func.now()), nullable=False) - dns_provider_id = Column(Integer, ForeignKey('dns_providers.id', ondelete="CASCADE")) + dns_provider_id = Column( + Integer, ForeignKey("dns_providers.id", ondelete="CASCADE") + ) status = Column(Text(), nullable=True) - last_updated = Column(ArrowType, PassiveDefault(func.now()), onupdate=func.now(), nullable=False) + last_updated = Column( + ArrowType, PassiveDefault(func.now()), onupdate=func.now(), nullable=False + ) rotation = Column(Boolean, default=False) - user_id = Column(Integer, ForeignKey('users.id')) - authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE")) - root_authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE")) - rotation_policy_id = Column(Integer, ForeignKey('rotation_policies.id')) + user_id = Column(Integer, ForeignKey("users.id")) + authority_id = Column(Integer, ForeignKey("authorities.id", ondelete="CASCADE")) + root_authority_id = Column( + Integer, ForeignKey("authorities.id", ondelete="CASCADE") + ) + rotation_policy_id = Column(Integer, ForeignKey("rotation_policies.id")) - notifications = relationship('Notification', secondary=pending_cert_notification_associations, - backref='pending_cert', passive_deletes=True) - destinations = relationship('Destination', secondary=pending_cert_destination_associations, backref='pending_cert', - passive_deletes=True) - sources = relationship('Source', secondary=pending_cert_source_associations, backref='pending_cert', - passive_deletes=True) - roles = relationship('Role', secondary=pending_cert_role_associations, backref='pending_cert', passive_deletes=True) - replaces = relationship('Certificate', - secondary=pending_cert_replacement_associations, - backref='pending_cert', - passive_deletes=True) + notifications = relationship( + "Notification", + secondary=pending_cert_notification_associations, + backref="pending_cert", + passive_deletes=True, + ) + destinations = relationship( + "Destination", + secondary=pending_cert_destination_associations, + backref="pending_cert", + passive_deletes=True, + ) + sources = relationship( + "Source", + secondary=pending_cert_source_associations, + backref="pending_cert", + passive_deletes=True, + ) + roles = relationship( + "Role", + secondary=pending_cert_role_associations, + backref="pending_cert", + passive_deletes=True, + ) + replaces = relationship( + "Certificate", + secondary=pending_cert_replacement_associations, + backref="pending_cert", + passive_deletes=True, + ) options = Column(JSONType) rotation_policy = relationship("RotationPolicy") - sensitive_fields = ('private_key',) + sensitive_fields = ("private_key",) def __init__(self, **kwargs): - self.csr = kwargs.get('csr') - self.private_key = kwargs.get('private_key', "") + self.csr = kwargs.get("csr") + self.private_key = kwargs.get("private_key", "") if self.private_key: # If the request does not send private key, the key exists but the value is None self.private_key = self.private_key.strip() - self.external_id = kwargs.get('external_id') + self.external_id = kwargs.get("external_id") # when destinations are appended they require a valid name. - if kwargs.get('name'): - self.name = get_or_increase_name(defaults.text_to_slug(kwargs['name']), 0) + if kwargs.get("name"): + self.name = get_or_increase_name(defaults.text_to_slug(kwargs["name"]), 0) self.rename = False else: # TODO: Fix auto-generated name, it should be renamed on creation self.name = get_or_increase_name( - defaults.certificate_name(kwargs['common_name'], kwargs['authority'].name, - dt.now(), dt.now(), False), self.external_id) + defaults.certificate_name( + kwargs["common_name"], + kwargs["authority"].name, + dt.now(), + dt.now(), + False, + ), + self.external_id, + ) self.rename = True self.cn = defaults.common_name(utils.parse_csr(self.csr)) - self.owner = kwargs['owner'] + self.owner = kwargs["owner"] self.number_attempts = 0 - if kwargs.get('chain'): - self.chain = kwargs['chain'].strip() + if kwargs.get("chain"): + self.chain = kwargs["chain"].strip() - self.notify = kwargs.get('notify', True) - self.destinations = kwargs.get('destinations', []) - self.notifications = kwargs.get('notifications', []) - self.description = kwargs.get('description') - self.roles = list(set(kwargs.get('roles', []))) - self.replaces = kwargs.get('replaces', []) - self.rotation = kwargs.get('rotation') - self.rotation_policy = kwargs.get('rotation_policy') + self.notify = kwargs.get("notify", True) + self.destinations = kwargs.get("destinations", []) + self.notifications = kwargs.get("notifications", []) + self.description = kwargs.get("description") + self.roles = list(set(kwargs.get("roles", []))) + self.replaces = kwargs.get("replaces", []) + self.rotation = kwargs.get("rotation") + self.rotation_policy = kwargs.get("rotation_policy") try: - self.dns_provider_id = kwargs.get('dns_provider').id + self.dns_provider_id = kwargs.get("dns_provider").id except (AttributeError, KeyError, TypeError, Exception): pass diff --git a/lemur/pending_certificates/schemas.py b/lemur/pending_certificates/schemas.py index 3dd70b16..68f22b4a 100644 --- a/lemur/pending_certificates/schemas.py +++ b/lemur/pending_certificates/schemas.py @@ -17,14 +17,14 @@ from lemur.schemas import ( AssociatedNotificationSchema, AssociatedRoleSchema, EndpointNestedOutputSchema, - ExtensionSchema + ExtensionSchema, ) from lemur.users.schemas import UserNestedOutputSchema class PendingCertificateSchema(LemurInputSchema): owner = fields.Email(required=True) - description = fields.String(missing='', allow_none=True) + description = fields.String(missing="", allow_none=True) class PendingCertificateOutputSchema(LemurOutputSchema): @@ -46,10 +46,10 @@ class PendingCertificateOutputSchema(LemurOutputSchema): # Note aliasing is the first step in deprecating these fields. notify = fields.Boolean() - active = fields.Boolean(attribute='notify') + active = fields.Boolean(attribute="notify") cn = fields.String() - common_name = fields.String(attribute='cn') + common_name = fields.String(attribute="cn") owner = fields.Email() @@ -66,7 +66,9 @@ class PendingCertificateOutputSchema(LemurOutputSchema): authority = fields.Nested(AuthorityNestedOutputSchema) roles = fields.Nested(RoleNestedOutputSchema, many=True) endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[]) - replaced_by = fields.Nested(CertificateNestedOutputSchema, many=True, attribute='replaced') + replaced_by = fields.Nested( + CertificateNestedOutputSchema, many=True, attribute="replaced" + ) rotation_policy = fields.Nested(RotationPolicyNestedOutputSchema) @@ -89,10 +91,15 @@ class PendingCertificateEditInputSchema(PendingCertificateSchema): :param data: :return: """ - if data['owner']: - notification_name = "DEFAULT_{0}".format(data['owner'].split('@')[0].upper()) - data['notifications'] += notification_service.create_default_expiration_notifications(notification_name, - [data['owner']]) + if data["owner"]: + notification_name = "DEFAULT_{0}".format( + data["owner"].split("@")[0].upper() + ) + data[ + "notifications" + ] += notification_service.create_default_expiration_notifications( + notification_name, [data["owner"]] + ) return data @@ -108,17 +115,21 @@ class PendingCertificateUploadInputSchema(LemurInputSchema): @validates_schema def validate_cert_chain(self, data): cert = None - if data.get('body'): + if data.get("body"): try: - cert = utils.parse_certificate(data['body']) + cert = utils.parse_certificate(data["body"]) except ValueError: - raise ValidationError("Public certificate presented is not valid.", field_names=['body']) + raise ValidationError( + "Public certificate presented is not valid.", field_names=["body"] + ) - if data.get('chain'): + if data.get("chain"): try: - chain = utils.parse_cert_chain(data['chain']) + chain = utils.parse_cert_chain(data["chain"]) except ValueError: - raise ValidationError("Invalid certificate in certificate chain.", field_names=['chain']) + raise ValidationError( + "Invalid certificate in certificate chain.", field_names=["chain"] + ) # Throws ValidationError validators.verify_cert_chain([cert] + chain) diff --git a/lemur/pending_certificates/service.py b/lemur/pending_certificates/service.py index 287bd42b..935ea689 100644 --- a/lemur/pending_certificates/service.py +++ b/lemur/pending_certificates/service.py @@ -40,17 +40,18 @@ def get_by_external_id(issuer, external_id): """ if isinstance(external_id, int): external_id = str(external_id) - return PendingCertificate.query \ - .filter(PendingCertificate.authority_id == issuer.id) \ - .filter(PendingCertificate.external_id == external_id) \ + return ( + PendingCertificate.query.filter(PendingCertificate.authority_id == issuer.id) + .filter(PendingCertificate.external_id == external_id) .one_or_none() + ) def get_by_name(pending_cert_name): """ Retrieve pending certificate by name """ - return database.get(PendingCertificate, pending_cert_name, field='name') + return database.get(PendingCertificate, pending_cert_name, field="name") def delete(pending_certificate): @@ -66,7 +67,9 @@ def get_unresolved_pending_certs(): Retrieve a list of unresolved pending certs given a list of ids Filters out non-existing pending certs """ - query = database.session_query(PendingCertificate).filter(PendingCertificate.resolved.is_(False)) + query = database.session_query(PendingCertificate).filter( + PendingCertificate.resolved.is_(False) + ) return database.find_all(query, PendingCertificate, {}).all() @@ -76,7 +79,7 @@ def get_pending_certs(pending_ids): Filters out non-existing pending certs """ pending_certs = [] - if 'all' in pending_ids: + if "all" in pending_ids: query = database.session_query(PendingCertificate) return database.find_all(query, PendingCertificate, {}).all() else: @@ -96,23 +99,25 @@ def create_certificate(pending_certificate, certificate, user): user: User that called this function, used as 'creator' of the certificate if it does not have an owner """ - certificate['owner'] = pending_certificate.owner + certificate["owner"] = pending_certificate.owner data, errors = CertificateUploadInputSchema().load(certificate) if errors: - raise Exception("Unable to create certificate: {reasons}".format(reasons=errors)) + raise Exception( + "Unable to create certificate: {reasons}".format(reasons=errors) + ) data.update(vars(pending_certificate)) # Copy relationships, vars doesn't copy this without explicit fields - data['notifications'] = list(pending_certificate.notifications) - data['destinations'] = list(pending_certificate.destinations) - data['sources'] = list(pending_certificate.sources) - data['roles'] = list(pending_certificate.roles) - data['replaces'] = list(pending_certificate.replaces) - data['rotation_policy'] = pending_certificate.rotation_policy + data["notifications"] = list(pending_certificate.notifications) + data["destinations"] = list(pending_certificate.destinations) + data["sources"] = list(pending_certificate.sources) + data["roles"] = list(pending_certificate.roles) + data["replaces"] = list(pending_certificate.replaces) + data["rotation_policy"] = pending_certificate.rotation_policy # Replace external id and chain with the one fetched from source - data['external_id'] = certificate['external_id'] - data['chain'] = certificate['chain'] + data["external_id"] = certificate["external_id"] + data["chain"] = certificate["chain"] creator = user_service.get_by_email(pending_certificate.owner) if not creator: # Owner of the pending certificate is not the creator, so use the current user who called @@ -121,8 +126,8 @@ def create_certificate(pending_certificate, certificate, user): if pending_certificate.rename: # If generating name from certificate, remove the one from pending certificate - del data['name'] - data['creator'] = creator + del data["name"] + data["creator"] = creator cert = certificate_service.import_certificate(**data) database.update(cert) @@ -159,75 +164,91 @@ def cancel(pending_certificate, **kwargs): """ plugin = plugins.get(pending_certificate.authority.plugin_name) plugin.cancel_ordered_certificate(pending_certificate, **kwargs) - pending_certificate.status = 'Cancelled' + pending_certificate.status = "Cancelled" database.update(pending_certificate) return pending_certificate def render(args): query = database.session_query(PendingCertificate) - time_range = args.pop('time_range') - destination_id = args.pop('destination_id') - notification_id = args.pop('notification_id', None) - show = args.pop('show') + time_range = args.pop("time_range") + destination_id = args.pop("destination_id") + notification_id = args.pop("notification_id", None) + show = args.pop("show") # owner = args.pop('owner') # creator = args.pop('creator') # TODO we should enabling filtering by owner - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') + terms = filt.split(";") - if 'issuer' in terms: + if "issuer" in terms: # we can't rely on issuer being correct in the cert directly so we combine queries - sub_query = database.session_query(Authority.id) \ - .filter(Authority.name.ilike('%{0}%'.format(terms[1]))) \ + sub_query = ( + database.session_query(Authority.id) + .filter(Authority.name.ilike("%{0}%".format(terms[1]))) .subquery() + ) query = query.filter( or_( - PendingCertificate.issuer.ilike('%{0}%'.format(terms[1])), - PendingCertificate.authority_id.in_(sub_query) + PendingCertificate.issuer.ilike("%{0}%".format(terms[1])), + PendingCertificate.authority_id.in_(sub_query), ) ) - elif 'destination' in terms: - query = query.filter(PendingCertificate.destinations.any(Destination.id == terms[1])) - elif 'notify' in filt: + elif "destination" in terms: + query = query.filter( + PendingCertificate.destinations.any(Destination.id == terms[1]) + ) + elif "notify" in filt: query = query.filter(PendingCertificate.notify == truthiness(terms[1])) - elif 'active' in filt: + elif "active" in filt: query = query.filter(PendingCertificate.active == truthiness(terms[1])) - elif 'cn' in terms: + elif "cn" in terms: query = query.filter( or_( - PendingCertificate.cn.ilike('%{0}%'.format(terms[1])), - PendingCertificate.domains.any(Domain.name.ilike('%{0}%'.format(terms[1]))) + PendingCertificate.cn.ilike("%{0}%".format(terms[1])), + PendingCertificate.domains.any( + Domain.name.ilike("%{0}%".format(terms[1])) + ), ) ) - elif 'id' in terms: + elif "id" in terms: query = query.filter(PendingCertificate.id == cast(terms[1], Integer)) else: query = database.filter(query, PendingCertificate, terms) if show: - sub_query = database.session_query(Role.name).filter(Role.user_id == args['user'].id).subquery() + sub_query = ( + database.session_query(Role.name) + .filter(Role.user_id == args["user"].id) + .subquery() + ) query = query.filter( or_( - PendingCertificate.user_id == args['user'].id, - PendingCertificate.owner.in_(sub_query) + PendingCertificate.user_id == args["user"].id, + PendingCertificate.owner.in_(sub_query), ) ) if destination_id: - query = query.filter(PendingCertificate.destinations.any(Destination.id == destination_id)) + query = query.filter( + PendingCertificate.destinations.any(Destination.id == destination_id) + ) if notification_id: - query = query.filter(PendingCertificate.notifications.any(Notification.id == notification_id)) + query = query.filter( + PendingCertificate.notifications.any(Notification.id == notification_id) + ) if time_range: - to = arrow.now().replace(weeks=+time_range).format('YYYY-MM-DD') - now = arrow.now().format('YYYY-MM-DD') - query = query.filter(PendingCertificate.not_after <= to).filter(PendingCertificate.not_after >= now) + to = arrow.now().replace(weeks=+time_range).format("YYYY-MM-DD") + now = arrow.now().format("YYYY-MM-DD") + query = query.filter(PendingCertificate.not_after <= to).filter( + PendingCertificate.not_after >= now + ) # Only show unresolved certificates in the UI query = query.filter(PendingCertificate.resolved.is_(False)) @@ -242,30 +263,26 @@ def upload(pending_certificate_id, **kwargs): """ pending_cert = get(pending_certificate_id) partial_cert = kwargs - uploaded_chain = partial_cert['chain'] + uploaded_chain = partial_cert["chain"] authority = authorities_service.get(pending_cert.authority.id) # Construct the chain for cert validation if uploaded_chain: - chain = uploaded_chain + '\n' + authority.authority_certificate.body + chain = uploaded_chain + "\n" + authority.authority_certificate.body else: chain = authority.authority_certificate.body parsed_chain = parse_cert_chain(chain) # Check that the certificate is actually signed by the CA to avoid incorrect cert pasting - validators.verify_cert_chain([parse_certificate(partial_cert['body'])] + parsed_chain) + validators.verify_cert_chain( + [parse_certificate(partial_cert["body"])] + parsed_chain + ) final_cert = create_certificate(pending_cert, partial_cert, pending_cert.user) - pending_cert_final_result = update( - pending_cert.id, - resolved_cert_id=final_cert.id - ) - update( - pending_cert.id, - resolved=True - ) + pending_cert_final_result = update(pending_cert.id, resolved_cert_id=final_cert.id) + update(pending_cert.id, resolved=True) return pending_cert_final_result diff --git a/lemur/pending_certificates/views.py b/lemur/pending_certificates/views.py index 935f00c1..4651aed7 100644 --- a/lemur/pending_certificates/views.py +++ b/lemur/pending_certificates/views.py @@ -23,7 +23,7 @@ from lemur.pending_certificates.schemas import ( pending_certificate_upload_input_schema, ) -mod = Blueprint('pending_certificates', __name__) +mod = Blueprint("pending_certificates", __name__) api = Api(mod) @@ -110,15 +110,17 @@ class PendingCertificatesList(AuthenticatedResource): """ parser = paginated_parser.copy() - parser.add_argument('timeRange', type=int, dest='time_range', location='args') - parser.add_argument('owner', type=inputs.boolean, location='args') - parser.add_argument('id', type=str, location='args') - parser.add_argument('active', type=inputs.boolean, location='args') - parser.add_argument('destinationId', type=int, dest="destination_id", location='args') - parser.add_argument('creator', type=str, location='args') - parser.add_argument('show', type=str, location='args') + parser.add_argument("timeRange", type=int, dest="time_range", location="args") + parser.add_argument("owner", type=inputs.boolean, location="args") + parser.add_argument("id", type=str, location="args") + parser.add_argument("active", type=inputs.boolean, location="args") + parser.add_argument( + "destinationId", type=int, dest="destination_id", location="args" + ) + parser.add_argument("creator", type=str, location="args") + parser.add_argument("show", type=str, location="args") args = parser.parse_args() - args['user'] = g.user + args["user"] = g.user return service.render(args) @@ -206,7 +208,9 @@ class PendingCertificates(AuthenticatedResource): """ return service.get(pending_certificate_id) - @validate_schema(pending_certificate_edit_input_schema, pending_certificate_output_schema) + @validate_schema( + pending_certificate_edit_input_schema, pending_certificate_output_schema + ) def put(self, pending_certificate_id, data=None): """ .. http:put:: /pending_certificates/1 @@ -297,19 +301,27 @@ class PendingCertificates(AuthenticatedResource): # allow creators if g.current_user != pending_cert.user: owner_role = role_service.get_by_name(pending_cert.owner) - permission = CertificatePermission(owner_role, [x.name for x in pending_cert.roles]) + permission = CertificatePermission( + owner_role, [x.name for x in pending_cert.roles] + ) if not permission.can(): - return dict(message='You are not authorized to update this certificate'), 403 + return ( + dict(message="You are not authorized to update this certificate"), + 403, + ) - for destination in data['destinations']: + for destination in data["destinations"]: if destination.plugin.requires_key: if not pending_cert.private_key: - return dict( - message='Unable to add destination: {0}. Certificate does not have required private key.'.format( - destination.label - ) - ), 400 + return ( + dict( + message="Unable to add destination: {0}. Certificate does not have required private key.".format( + destination.label + ) + ), + 400, + ) pending_cert = service.update(pending_certificate_id, **data) return pending_cert @@ -354,18 +366,28 @@ class PendingCertificates(AuthenticatedResource): # allow creators if g.current_user != pending_cert.user: owner_role = role_service.get_by_name(pending_cert.owner) - permission = CertificatePermission(owner_role, [x.name for x in pending_cert.roles]) + permission = CertificatePermission( + owner_role, [x.name for x in pending_cert.roles] + ) if not permission.can(): - return dict(message='You are not authorized to update this certificate'), 403 + return ( + dict(message="You are not authorized to update this certificate"), + 403, + ) if service.cancel(pending_cert, **data): service.delete(pending_cert) - return('', 204) + return ("", 204) else: # service.cancel raises exception if there was an issue, but this will ensure something # is relayed to user in case of something unexpected (unsuccessful update somehow). - return dict(message="Unexpected error occurred while trying to cancel this certificate"), 500 + return ( + dict( + message="Unexpected error occurred while trying to cancel this certificate" + ), + 500, + ) class PendingCertificatePrivateKey(AuthenticatedResource): @@ -412,11 +434,11 @@ class PendingCertificatePrivateKey(AuthenticatedResource): permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) if not permission.can(): - return dict(message='You are not authorized to view this key'), 403 + return dict(message="You are not authorized to view this key"), 403 response = make_response(jsonify(key=cert.private_key), 200) - response.headers['cache-control'] = 'private, max-age=0, no-cache, no-store' - response.headers['pragma'] = 'no-cache' + response.headers["cache-control"] = "private, max-age=0, no-cache, no-store" + response.headers["pragma"] = "no-cache" return response @@ -427,7 +449,9 @@ class PendingCertificatesUpload(AuthenticatedResource): self.reqparse = reqparse.RequestParser() super(PendingCertificatesUpload, self).__init__() - @validate_schema(pending_certificate_upload_input_schema, pending_certificate_output_schema) + @validate_schema( + pending_certificate_upload_input_schema, pending_certificate_output_schema + ) def post(self, pending_certificate_id, data=None): """ .. http:post:: /pending_certificates/1/upload @@ -514,7 +538,21 @@ class PendingCertificatesUpload(AuthenticatedResource): return service.upload(pending_certificate_id, **data) -api.add_resource(PendingCertificatesList, '/pending_certificates', endpoint='pending_certificates') -api.add_resource(PendingCertificates, '/pending_certificates/', endpoint='pending_certificate') -api.add_resource(PendingCertificatesUpload, '/pending_certificates//upload', endpoint='pendingCertificateUpload') -api.add_resource(PendingCertificatePrivateKey, '/pending_certificates//key', endpoint='privateKeyPendingCertificates') +api.add_resource( + PendingCertificatesList, "/pending_certificates", endpoint="pending_certificates" +) +api.add_resource( + PendingCertificates, + "/pending_certificates/", + endpoint="pending_certificate", +) +api.add_resource( + PendingCertificatesUpload, + "/pending_certificates//upload", + endpoint="pendingCertificateUpload", +) +api.add_resource( + PendingCertificatePrivateKey, + "/pending_certificates//key", + endpoint="privateKeyPendingCertificates", +) diff --git a/lemur/plugins/base/manager.py b/lemur/plugins/base/manager.py index a2306445..117700a6 100644 --- a/lemur/plugins/base/manager.py +++ b/lemur/plugins/base/manager.py @@ -18,7 +18,9 @@ class PluginManager(InstanceManager): return sum(1 for i in self.all()) def all(self, version=1, plugin_type=None): - for plugin in sorted(super(PluginManager, self).all(), key=lambda x: x.get_title()): + for plugin in sorted( + super(PluginManager, self).all(), key=lambda x: x.get_title() + ): if not plugin.type == plugin_type and plugin_type: continue if not plugin.is_enabled(): @@ -36,29 +38,34 @@ class PluginManager(InstanceManager): return plugin current_app.logger.error( "Unable to find slug: {} in self.all version 1: {} or version 2: {}".format( - slug, self.all(version=1), self.all(version=2)) + slug, self.all(version=1), self.all(version=2) + ) ) raise KeyError(slug) def first(self, func_name, *args, **kwargs): - version = kwargs.pop('version', 1) + version = kwargs.pop("version", 1) for plugin in self.all(version=version): try: result = getattr(plugin, func_name)(*args, **kwargs) except Exception as e: - current_app.logger.error('Error processing %s() on %r: %s', func_name, plugin.__class__, e, extra={ - 'func_arg': args, - 'func_kwargs': kwargs, - }, exc_info=True) + current_app.logger.error( + "Error processing %s() on %r: %s", + func_name, + plugin.__class__, + e, + extra={"func_arg": args, "func_kwargs": kwargs}, + exc_info=True, + ) continue if result is not None: return result def register(self, cls): - self.add('%s.%s' % (cls.__module__, cls.__name__)) + self.add("%s.%s" % (cls.__module__, cls.__name__)) return cls def unregister(self, cls): - self.remove('%s.%s' % (cls.__module__, cls.__name__)) + self.remove("%s.%s" % (cls.__module__, cls.__name__)) return cls diff --git a/lemur/plugins/base/v1.py b/lemur/plugins/base/v1.py index fb688c73..664385b3 100644 --- a/lemur/plugins/base/v1.py +++ b/lemur/plugins/base/v1.py @@ -18,7 +18,7 @@ class PluginMount(type): if new_cls.title is None: new_cls.title = new_cls.__name__ if not new_cls.slug: - new_cls.slug = new_cls.title.replace(' ', '-').lower() + new_cls.slug = new_cls.title.replace(" ", "-").lower() return new_cls @@ -36,6 +36,7 @@ class IPlugin(local): As a general rule all inherited methods should allow ``**kwargs`` to ensure ease of future compatibility. """ + # Generic plugin information title = None slug = None @@ -72,7 +73,7 @@ class IPlugin(local): Returns a string representing the configuration keyspace prefix for this plugin. """ if not self.conf_key: - self.conf_key = self.get_conf_title().lower().replace(' ', '_') + self.conf_key = self.get_conf_title().lower().replace(" ", "_") return self.conf_key def get_conf_title(self): @@ -111,8 +112,8 @@ class IPlugin(local): @staticmethod def get_option(name, options): for o in options: - if o.get('name') == name: - return o.get('value', o.get('default')) + if o.get("name") == name: + return o.get("value", o.get("default")) class Plugin(IPlugin): @@ -121,5 +122,6 @@ class Plugin(IPlugin): control when or how the plugin gets instantiated, nor is it guaranteed that it will happen, or happen more than once. """ + __version__ = 1 __metaclass__ = PluginMount diff --git a/lemur/plugins/bases/destination.py b/lemur/plugins/bases/destination.py index fc73ebcb..e00c5090 100644 --- a/lemur/plugins/bases/destination.py +++ b/lemur/plugins/bases/destination.py @@ -10,10 +10,10 @@ from lemur.plugins.base import Plugin, plugins class DestinationPlugin(Plugin): - type = 'destination' + type = "destination" requires_key = True sync_as_source = False - sync_as_source_name = '' + sync_as_source_name = "" def upload(self, name, body, private_key, cert_chain, options, **kwargs): raise NotImplementedError @@ -22,10 +22,10 @@ class DestinationPlugin(Plugin): class ExportDestinationPlugin(DestinationPlugin): default_options = [ { - 'name': 'exportPlugin', - 'type': 'export-plugin', - 'required': True, - 'helpMessage': 'Export plugin to use before sending data to destination.' + "name": "exportPlugin", + "type": "export-plugin", + "required": True, + "helpMessage": "Export plugin to use before sending data to destination.", } ] @@ -34,15 +34,17 @@ class ExportDestinationPlugin(DestinationPlugin): return self.default_options + self.additional_options def export(self, body, private_key, cert_chain, options): - export_plugin = self.get_option('exportPlugin', options) + export_plugin = self.get_option("exportPlugin", options) if export_plugin: - plugin = plugins.get(export_plugin['slug']) - extension, passphrase, data = plugin.export(body, cert_chain, private_key, export_plugin['plugin_options']) + plugin = plugins.get(export_plugin["slug"]) + extension, passphrase, data = plugin.export( + body, cert_chain, private_key, export_plugin["plugin_options"] + ) return [(extension, passphrase, data)] - data = body + '\n' + cert_chain + '\n' + private_key - return [('.pem', '', data)] + data = body + "\n" + cert_chain + "\n" + private_key + return [(".pem", "", data)] def upload(self, name, body, private_key, cert_chain, options, **kwargs): raise NotImplementedError diff --git a/lemur/plugins/bases/export.py b/lemur/plugins/bases/export.py index 1466c1ab..6d078906 100644 --- a/lemur/plugins/bases/export.py +++ b/lemur/plugins/bases/export.py @@ -14,7 +14,8 @@ class ExportPlugin(Plugin): This is the base class from which all supported exporters will inherit from. """ - type = 'export' + + type = "export" requires_key = True def export(self, body, chain, key, options, **kwargs): diff --git a/lemur/plugins/bases/issuer.py b/lemur/plugins/bases/issuer.py index 5eb0964c..f1e6aa0e 100644 --- a/lemur/plugins/bases/issuer.py +++ b/lemur/plugins/bases/issuer.py @@ -14,7 +14,8 @@ class IssuerPlugin(Plugin): This is the base class from which all of the supported issuers will inherit from. """ - type = 'issuer' + + type = "issuer" def create_certificate(self, csr, issuer_options): raise NotImplementedError diff --git a/lemur/plugins/bases/metric.py b/lemur/plugins/bases/metric.py index 259af235..2e4ce69b 100644 --- a/lemur/plugins/bases/metric.py +++ b/lemur/plugins/bases/metric.py @@ -10,7 +10,9 @@ from lemur.plugins.base import Plugin class MetricPlugin(Plugin): - type = 'metric' + type = "metric" - def submit(self, metric_name, metric_type, metric_value, metric_tags=None, options=None): + def submit( + self, metric_name, metric_type, metric_value, metric_tags=None, options=None + ): raise NotImplementedError diff --git a/lemur/plugins/bases/notification.py b/lemur/plugins/bases/notification.py index a7ba4e0d..730f68be 100644 --- a/lemur/plugins/bases/notification.py +++ b/lemur/plugins/bases/notification.py @@ -14,7 +14,8 @@ class NotificationPlugin(Plugin): This is the base class from which all of the supported issuers will inherit from. """ - type = 'notification' + + type = "notification" def send(self, notification_type, message, targets, options, **kwargs): raise NotImplementedError @@ -26,22 +27,23 @@ class ExpirationNotificationPlugin(NotificationPlugin): It contains some default options that are needed for all expiration notification plugins. """ + default_options = [ { - 'name': 'interval', - 'type': 'int', - 'required': True, - 'validation': '^\d+$', - 'helpMessage': 'Number of days to be alert before expiration.', + "name": "interval", + "type": "int", + "required": True, + "validation": "^\d+$", + "helpMessage": "Number of days to be alert before expiration.", }, { - 'name': 'unit', - 'type': 'select', - 'required': True, - 'validation': '', - 'available': ['days', 'weeks', 'months'], - 'helpMessage': 'Interval unit', - } + "name": "unit", + "type": "select", + "required": True, + "validation": "", + "available": ["days", "weeks", "months"], + "helpMessage": "Interval unit", + }, ] @property diff --git a/lemur/plugins/bases/source.py b/lemur/plugins/bases/source.py index ff3492fe..6f521e40 100644 --- a/lemur/plugins/bases/source.py +++ b/lemur/plugins/bases/source.py @@ -10,15 +10,15 @@ from lemur.plugins.base import Plugin class SourcePlugin(Plugin): - type = 'source' + type = "source" default_options = [ { - 'name': 'pollRate', - 'type': 'int', - 'required': False, - 'helpMessage': 'Rate in seconds to poll source for new information.', - 'default': '60', + "name": "pollRate", + "type": "int", + "required": False, + "helpMessage": "Rate in seconds to poll source for new information.", + "default": "60", } ] diff --git a/lemur/plugins/lemur_acme/__init__.py b/lemur/plugins/lemur_acme/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_acme/__init__.py +++ b/lemur/plugins/lemur_acme/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_acme/cloudflare.py b/lemur/plugins/lemur_acme/cloudflare.py index a6308025..a19495f8 100644 --- a/lemur/plugins/lemur_acme/cloudflare.py +++ b/lemur/plugins/lemur_acme/cloudflare.py @@ -5,24 +5,24 @@ from flask import current_app def cf_api_call(): - cf_key = current_app.config.get('ACME_CLOUDFLARE_KEY', '') - cf_email = current_app.config.get('ACME_CLOUDFLARE_EMAIL', '') + cf_key = current_app.config.get("ACME_CLOUDFLARE_KEY", "") + cf_email = current_app.config.get("ACME_CLOUDFLARE_EMAIL", "") return CloudFlare.CloudFlare(email=cf_email, token=cf_key) def find_zone_id(host): - elements = host.split('.') + elements = host.split(".") cf = cf_api_call() n = 1 while n < 5: n = n + 1 - domain = '.'.join(elements[-n:]) + domain = ".".join(elements[-n:]) current_app.logger.debug("Trying to get ID for zone {0}".format(domain)) try: - zone = cf.zones.get(params={'name': domain, 'per_page': 1}) + zone = cf.zones.get(params={"name": domain, "per_page": 1}) except Exception as e: current_app.logger.error("Cloudflare API error: %s" % e) pass @@ -31,10 +31,10 @@ def find_zone_id(host): break if len(zone) == 0: - current_app.logger.error('No zone found') + current_app.logger.error("No zone found") return else: - return zone[0]['id'] + return zone[0]["id"] def wait_for_dns_change(change_id, account_number=None): @@ -42,8 +42,8 @@ def wait_for_dns_change(change_id, account_number=None): zone_id, record_id = change_id while True: r = cf.zones.get(zone_id, record_id) - current_app.logger.debug("Record status: %s" % r['status']) - if r['status'] == 'active': + current_app.logger.debug("Record status: %s" % r["status"]) + if r["status"] == "active": break time.sleep(1) return @@ -55,15 +55,19 @@ def create_txt_record(host, value, account_number): if not zone_id: return - txt_record = {'name': host, 'type': 'TXT', 'content': value} + txt_record = {"name": host, "type": "TXT", "content": value} - current_app.logger.debug("Creating TXT record {0} with value {1}".format(host, value)) + current_app.logger.debug( + "Creating TXT record {0} with value {1}".format(host, value) + ) try: r = cf.zones.dns_records.post(zone_id, data=txt_record) except Exception as e: - current_app.logger.error('/zones.dns_records.post %s: %s' % (txt_record['name'], e)) - return zone_id, r['id'] + current_app.logger.error( + "/zones.dns_records.post %s: %s" % (txt_record["name"], e) + ) + return zone_id, r["id"] def delete_txt_record(change_ids, account_number, host, value): @@ -74,4 +78,4 @@ def delete_txt_record(change_ids, account_number, host, value): try: cf.zones.dns_records.delete(zone_id, record_id) except Exception as e: - current_app.logger.error('/zones.dns_records.post: %s' % e) + current_app.logger.error("/zones.dns_records.post: %s" % e) diff --git a/lemur/plugins/lemur_acme/dyn.py b/lemur/plugins/lemur_acme/dyn.py index db33caf0..00a48eb6 100644 --- a/lemur/plugins/lemur_acme/dyn.py +++ b/lemur/plugins/lemur_acme/dyn.py @@ -5,7 +5,12 @@ import dns.exception import dns.name import dns.query import dns.resolver -from dyn.tm.errors import DynectCreateError, DynectDeleteError, DynectGetError, DynectUpdateError +from dyn.tm.errors import ( + DynectCreateError, + DynectDeleteError, + DynectGetError, + DynectUpdateError, +) from dyn.tm.session import DynectSession from dyn.tm.zones import Node, Zone, get_all_zones from flask import current_app @@ -16,13 +21,13 @@ from lemur.extensions import metrics, sentry def get_dynect_session(): try: dynect_session = DynectSession( - current_app.config.get('ACME_DYN_CUSTOMER_NAME', ''), - current_app.config.get('ACME_DYN_USERNAME', ''), - current_app.config.get('ACME_DYN_PASSWORD', ''), + current_app.config.get("ACME_DYN_CUSTOMER_NAME", ""), + current_app.config.get("ACME_DYN_USERNAME", ""), + current_app.config.get("ACME_DYN_PASSWORD", ""), ) except Exception as e: sentry.captureException() - metrics.send('get_dynect_session_fail', 'counter', 1) + metrics.send("get_dynect_session_fail", "counter", 1) current_app.logger.debug("Unable to establish connection to Dyn", exc_info=True) raise return dynect_session @@ -33,17 +38,17 @@ def _has_dns_propagated(name, token): try: dns_resolver = dns.resolver.Resolver() dns_resolver.nameservers = [get_authoritative_nameserver(name)] - dns_response = dns_resolver.query(name, 'TXT') + dns_response = dns_resolver.query(name, "TXT") for rdata in dns_response: for txt_record in rdata.strings: txt_records.append(txt_record.decode("utf-8")) except dns.exception.DNSException: - metrics.send('has_dns_propagated_fail', 'counter', 1) + metrics.send("has_dns_propagated_fail", "counter", 1) return False for txt_record in txt_records: if txt_record == token: - metrics.send('has_dns_propagated_success', 'counter', 1) + metrics.send("has_dns_propagated_success", "counter", 1) return True return False @@ -56,18 +61,19 @@ def wait_for_dns_change(change_id, account_number=None): status = _has_dns_propagated(fqdn, token) current_app.logger.debug("Record status for fqdn: {}: {}".format(fqdn, status)) if status: - metrics.send('wait_for_dns_change_success', 'counter', 1) + metrics.send("wait_for_dns_change_success", "counter", 1) break time.sleep(10) if not status: # TODO: Delete associated DNS text record here - metrics.send('wait_for_dns_change_fail', 'counter', 1) - sentry.captureException( - extra={ - "fqdn": str(fqdn), "txt_record": str(token)} + metrics.send("wait_for_dns_change_fail", "counter", 1) + sentry.captureException(extra={"fqdn": str(fqdn), "txt_record": str(token)}) + metrics.send( + "wait_for_dns_change_error", + "counter", + 1, + metric_tags={"fqdn": fqdn, "txt_record": token}, ) - metrics.send('wait_for_dns_change_error', 'counter', 1, - metric_tags={'fqdn': fqdn, 'txt_record': token}) return @@ -84,7 +90,7 @@ def get_zone_name(domain): if z.name.count(".") > zone_name.count("."): zone_name = z.name if not zone_name: - metrics.send('dyn_no_zone_name', 'counter', 1) + metrics.send("dyn_no_zone_name", "counter", 1) raise Exception("No Dyn zone found for domain: {}".format(domain)) return zone_name @@ -101,23 +107,28 @@ def get_zones(account_number): def create_txt_record(domain, token, account_number): get_dynect_session() zone_name = get_zone_name(domain) - zone_parts = len(zone_name.split('.')) - node_name = '.'.join(domain.split('.')[:-zone_parts]) + zone_parts = len(zone_name.split(".")) + node_name = ".".join(domain.split(".")[:-zone_parts]) fqdn = "{0}.{1}".format(node_name, zone_name) zone = Zone(zone_name) try: - zone.add_record(node_name, record_type='TXT', txtdata="\"{}\"".format(token), ttl=5) + zone.add_record( + node_name, record_type="TXT", txtdata='"{}"'.format(token), ttl=5 + ) zone.publish() - current_app.logger.debug("TXT record created: {0}, token: {1}".format(fqdn, token)) + current_app.logger.debug( + "TXT record created: {0}, token: {1}".format(fqdn, token) + ) except (DynectCreateError, DynectUpdateError) as e: if "Cannot duplicate existing record data" in e.message: current_app.logger.debug( "Unable to add record. Domain: {}. Token: {}. " - "Record already exists: {}".format(domain, token, e), exc_info=True + "Record already exists: {}".format(domain, token, e), + exc_info=True, ) else: - metrics.send('create_txt_record_error', 'counter', 1) + metrics.send("create_txt_record_error", "counter", 1) sentry.captureException() raise @@ -132,17 +143,17 @@ def delete_txt_record(change_id, account_number, domain, token): return zone_name = get_zone_name(domain) - zone_parts = len(zone_name.split('.')) - node_name = '.'.join(domain.split('.')[:-zone_parts]) + zone_parts = len(zone_name.split(".")) + node_name = ".".join(domain.split(".")[:-zone_parts]) fqdn = "{0}.{1}".format(node_name, zone_name) zone = Zone(zone_name) node = Node(zone_name, fqdn) try: - all_txt_records = node.get_all_records_by_type('TXT') + all_txt_records = node.get_all_records_by_type("TXT") except DynectGetError: - metrics.send('delete_txt_record_geterror', 'counter', 1) + metrics.send("delete_txt_record_geterror", "counter", 1) # No Text Records remain or host is not in the zone anymore because all records have been deleted. return for txt_record in all_txt_records: @@ -153,22 +164,36 @@ def delete_txt_record(change_id, account_number, domain, token): except DynectDeleteError: sentry.captureException( extra={ - "fqdn": str(fqdn), "zone_name": str(zone_name), "node_name": str(node_name), - "txt_record": str(txt_record.txtdata)} + "fqdn": str(fqdn), + "zone_name": str(zone_name), + "node_name": str(node_name), + "txt_record": str(txt_record.txtdata), + } + ) + metrics.send( + "delete_txt_record_deleteerror", + "counter", + 1, + metric_tags={"fqdn": fqdn, "txt_record": txt_record.txtdata}, ) - metrics.send('delete_txt_record_deleteerror', 'counter', 1, - metric_tags={'fqdn': fqdn, 'txt_record': txt_record.txtdata}) try: zone.publish() except DynectUpdateError: sentry.captureException( extra={ - "fqdn": str(fqdn), "zone_name": str(zone_name), "node_name": str(node_name), - "txt_record": str(txt_record.txtdata)} + "fqdn": str(fqdn), + "zone_name": str(zone_name), + "node_name": str(node_name), + "txt_record": str(txt_record.txtdata), + } + ) + metrics.send( + "delete_txt_record_publish_error", + "counter", + 1, + metric_tags={"fqdn": str(fqdn), "txt_record": str(txt_record.txtdata)}, ) - metrics.send('delete_txt_record_publish_error', 'counter', 1, - metric_tags={'fqdn': str(fqdn), 'txt_record': str(txt_record.txtdata)}) def delete_acme_txt_records(domain): @@ -180,18 +205,21 @@ def delete_acme_txt_records(domain): if not domain.startswith(acme_challenge_string): current_app.logger.debug( "delete_acme_txt_records: Domain {} doesn't start with string {}. " - "Cowardly refusing to delete TXT records".format(domain, acme_challenge_string)) + "Cowardly refusing to delete TXT records".format( + domain, acme_challenge_string + ) + ) return zone_name = get_zone_name(domain) - zone_parts = len(zone_name.split('.')) - node_name = '.'.join(domain.split('.')[:-zone_parts]) + zone_parts = len(zone_name.split(".")) + node_name = ".".join(domain.split(".")[:-zone_parts]) fqdn = "{0}.{1}".format(node_name, zone_name) zone = Zone(zone_name) node = Node(zone_name, fqdn) - all_txt_records = node.get_all_records_by_type('TXT') + all_txt_records = node.get_all_records_by_type("TXT") for txt_record in all_txt_records: current_app.logger.debug("Deleting TXT record name: {0}".format(fqdn)) try: @@ -199,16 +227,23 @@ def delete_acme_txt_records(domain): except DynectDeleteError: sentry.captureException( extra={ - "fqdn": str(fqdn), "zone_name": str(zone_name), "node_name": str(node_name), - "txt_record": str(txt_record.txtdata)} + "fqdn": str(fqdn), + "zone_name": str(zone_name), + "node_name": str(node_name), + "txt_record": str(txt_record.txtdata), + } + ) + metrics.send( + "delete_txt_record_deleteerror", + "counter", + 1, + metric_tags={"fqdn": fqdn, "txt_record": txt_record.txtdata}, ) - metrics.send('delete_txt_record_deleteerror', 'counter', 1, - metric_tags={'fqdn': fqdn, 'txt_record': txt_record.txtdata}) zone.publish() def get_authoritative_nameserver(domain): - if current_app.config.get('ACME_DYN_GET_AUTHORATATIVE_NAMESERVER'): + if current_app.config.get("ACME_DYN_GET_AUTHORATATIVE_NAMESERVER"): n = dns.name.from_text(domain) depth = 2 @@ -219,7 +254,7 @@ def get_authoritative_nameserver(domain): while not last: s = n.split(depth) - last = s[0].to_unicode() == u'@' + last = s[0].to_unicode() == u"@" sub = s[1] query = dns.message.make_query(sub, dns.rdatatype.NS) @@ -227,11 +262,11 @@ def get_authoritative_nameserver(domain): rcode = response.rcode() if rcode != dns.rcode.NOERROR: - metrics.send('get_authoritative_nameserver_error', 'counter', 1) + metrics.send("get_authoritative_nameserver_error", "counter", 1) if rcode == dns.rcode.NXDOMAIN: - raise Exception('%s does not exist.' % sub) + raise Exception("%s does not exist." % sub) else: - raise Exception('Error %s' % dns.rcode.to_text(rcode)) + raise Exception("Error %s" % dns.rcode.to_text(rcode)) if len(response.authority) > 0: rrset = response.authority[0] diff --git a/lemur/plugins/lemur_acme/plugin.py b/lemur/plugins/lemur_acme/plugin.py index d9c41968..c734923a 100644 --- a/lemur/plugins/lemur_acme/plugin.py +++ b/lemur/plugins/lemur_acme/plugin.py @@ -48,7 +48,7 @@ class AcmeHandler(object): try: self.all_dns_providers = dns_provider_service.get_all_dns_providers() except Exception as e: - metrics.send('AcmeHandler_init_error', 'counter', 1) + metrics.send("AcmeHandler_init_error", "counter", 1) sentry.captureException() current_app.logger.error(f"Unable to fetch DNS Providers: {e}") self.all_dns_providers = [] @@ -67,45 +67,60 @@ class AcmeHandler(object): return host.replace("*.", "") def maybe_add_extension(self, host, dns_provider_options): - if dns_provider_options and dns_provider_options.get("acme_challenge_extension"): + if dns_provider_options and dns_provider_options.get( + "acme_challenge_extension" + ): host = host + dns_provider_options.get("acme_challenge_extension") return host - def start_dns_challenge(self, acme_client, account_number, host, dns_provider, order, dns_provider_options): + def start_dns_challenge( + self, + acme_client, + account_number, + host, + dns_provider, + order, + dns_provider_options, + ): current_app.logger.debug("Starting DNS challenge for {0}".format(host)) change_ids = [] host_to_validate = self.maybe_remove_wildcard(host) dns_challenges = self.find_dns_challenge(host_to_validate, order.authorizations) - host_to_validate = self.maybe_add_extension(host_to_validate, dns_provider_options) + host_to_validate = self.maybe_add_extension( + host_to_validate, dns_provider_options + ) if not dns_challenges: sentry.captureException() - metrics.send('start_dns_challenge_error_no_dns_challenges', 'counter', 1) + metrics.send("start_dns_challenge_error_no_dns_challenges", "counter", 1) raise Exception("Unable to determine DNS challenges from authorizations") for dns_challenge in dns_challenges: change_id = dns_provider.create_txt_record( dns_challenge.validation_domain_name(host_to_validate), dns_challenge.validation(acme_client.client.net.key), - account_number + account_number, ) change_ids.append(change_id) return AuthorizationRecord( - host, - order.authorizations, - dns_challenges, - change_ids + host, order.authorizations, dns_challenges, change_ids ) def complete_dns_challenge(self, acme_client, authz_record): - current_app.logger.debug("Finalizing DNS challenge for {0}".format(authz_record.authz[0].body.identifier.value)) + current_app.logger.debug( + "Finalizing DNS challenge for {0}".format( + authz_record.authz[0].body.identifier.value + ) + ) dns_providers = self.dns_providers_for_domain.get(authz_record.host) if not dns_providers: - metrics.send('complete_dns_challenge_error_no_dnsproviders', 'counter', 1) - raise Exception("No DNS providers found for domain: {}".format(authz_record.host)) + metrics.send("complete_dns_challenge_error_no_dnsproviders", "counter", 1) + raise Exception( + "No DNS providers found for domain: {}".format(authz_record.host) + ) for dns_provider in dns_providers: # Grab account number (For Route53) @@ -114,13 +129,17 @@ class AcmeHandler(object): dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type) for change_id in authz_record.change_id: try: - dns_provider_plugin.wait_for_dns_change(change_id, account_number=account_number) + dns_provider_plugin.wait_for_dns_change( + change_id, account_number=account_number + ) except Exception: - metrics.send('complete_dns_challenge_error', 'counter', 1) + metrics.send("complete_dns_challenge_error", "counter", 1) sentry.captureException() current_app.logger.debug( f"Unable to resolve DNS challenge for change_id: {change_id}, account_id: " - f"{account_number}", exc_info=True) + f"{account_number}", + exc_info=True, + ) raise for dns_challenge in authz_record.dns_challenge: @@ -129,11 +148,11 @@ class AcmeHandler(object): verified = response.simple_verify( dns_challenge.chall, authz_record.host, - acme_client.client.net.key.public_key() + acme_client.client.net.key.public_key(), ) if not verified: - metrics.send('complete_dns_challenge_verification_error', 'counter', 1) + metrics.send("complete_dns_challenge_verification_error", "counter", 1) raise ValueError("Failed verification") time.sleep(5) @@ -152,8 +171,10 @@ class AcmeHandler(object): except (AcmeError, TimeoutError): sentry.captureException(extra={"order_url": str(order.uri)}) - metrics.send('request_certificate_error', 'counter', 1) - current_app.logger.error(f"Unable to resolve Acme order: {order.uri}", exc_info=True) + metrics.send("request_certificate_error", "counter", 1) + current_app.logger.error( + f"Unable to resolve Acme order: {order.uri}", exc_info=True + ) raise except errors.ValidationError: if order.fullchain_pem: @@ -161,12 +182,19 @@ class AcmeHandler(object): else: raise - pem_certificate = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, - OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, - orderr.fullchain_pem)).decode() - pem_certificate_chain = orderr.fullchain_pem[len(pem_certificate):].lstrip() + pem_certificate = OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, orderr.fullchain_pem + ), + ).decode() + pem_certificate_chain = orderr.fullchain_pem[ + len(pem_certificate) : # noqa + ].lstrip() - current_app.logger.debug("{0} {1}".format(type(pem_certificate), type(pem_certificate_chain))) + current_app.logger.debug( + "{0} {1}".format(type(pem_certificate), type(pem_certificate_chain)) + ) return pem_certificate, pem_certificate_chain def setup_acme_client(self, authority): @@ -176,30 +204,40 @@ class AcmeHandler(object): for option in json.loads(authority.options): options[option["name"]] = option.get("value") - email = options.get('email', current_app.config.get('ACME_EMAIL')) - tel = options.get('telephone', current_app.config.get('ACME_TEL')) - directory_url = options.get('acme_url', current_app.config.get('ACME_DIRECTORY_URL')) + email = options.get("email", current_app.config.get("ACME_EMAIL")) + tel = options.get("telephone", current_app.config.get("ACME_TEL")) + directory_url = options.get( + "acme_url", current_app.config.get("ACME_DIRECTORY_URL") + ) - existing_key = options.get('acme_private_key', current_app.config.get('ACME_PRIVATE_KEY')) - existing_regr = options.get('acme_regr', current_app.config.get('ACME_REGR')) + existing_key = options.get( + "acme_private_key", current_app.config.get("ACME_PRIVATE_KEY") + ) + existing_regr = options.get("acme_regr", current_app.config.get("ACME_REGR")) if existing_key and existing_regr: # Reuse the same account for each certificate issuance key = jose.JWK.json_loads(existing_key) regr = messages.RegistrationResource.json_loads(existing_regr) - current_app.logger.debug("Connecting with directory at {0}".format(directory_url)) + current_app.logger.debug( + "Connecting with directory at {0}".format(directory_url) + ) net = ClientNetwork(key, account=regr) client = BackwardsCompatibleClientV2(net, key, directory_url) return client, {} else: # Create an account for each certificate issuance - key = jose.JWKRSA(key=generate_private_key('RSA2048')) + key = jose.JWKRSA(key=generate_private_key("RSA2048")) - current_app.logger.debug("Connecting with directory at {0}".format(directory_url)) + current_app.logger.debug( + "Connecting with directory at {0}".format(directory_url) + ) net = ClientNetwork(key, account=None, timeout=3600) client = BackwardsCompatibleClientV2(net, key, directory_url) - registration = client.new_account_and_tos(messages.NewRegistration.from_data(email=email)) + registration = client.new_account_and_tos( + messages.NewRegistration.from_data(email=email) + ) current_app.logger.debug("Connected: {0}".format(registration.uri)) return client, registration @@ -212,9 +250,9 @@ class AcmeHandler(object): """ current_app.logger.debug("Fetching domains") - domains = [options['common_name']] - if options.get('extensions'): - for name in options['extensions']['sub_alt_names']['names']: + domains = [options["common_name"]] + if options.get("extensions"): + for name in options["extensions"]["sub_alt_names"]["names"]: domains.append(name) current_app.logger.debug("Got these domains: {0}".format(domains)) @@ -225,16 +263,22 @@ class AcmeHandler(object): for domain in order_info.domains: if not self.dns_providers_for_domain.get(domain): - metrics.send('get_authorizations_no_dns_provider_for_domain', 'counter', 1) + metrics.send( + "get_authorizations_no_dns_provider_for_domain", "counter", 1 + ) raise Exception("No DNS providers found for domain: {}".format(domain)) for dns_provider in self.dns_providers_for_domain[domain]: dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type) dns_provider_options = json.loads(dns_provider.credentials) account_number = dns_provider_options.get("account_id") - authz_record = self.start_dns_challenge(acme_client, account_number, domain, - dns_provider_plugin, - order, - dns_provider.options) + authz_record = self.start_dns_challenge( + acme_client, + account_number, + domain, + dns_provider_plugin, + order, + dns_provider.options, + ) authorizations.append(authz_record) return authorizations @@ -268,16 +312,20 @@ class AcmeHandler(object): dns_providers = self.dns_providers_for_domain.get(authz_record.host) for dns_provider in dns_providers: # Grab account number (For Route53) - dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type) + dns_provider_plugin = self.get_dns_provider( + dns_provider.provider_type + ) dns_provider_options = json.loads(dns_provider.credentials) account_number = dns_provider_options.get("account_id") host_to_validate = self.maybe_remove_wildcard(authz_record.host) - host_to_validate = self.maybe_add_extension(host_to_validate, dns_provider_options) + host_to_validate = self.maybe_add_extension( + host_to_validate, dns_provider_options + ) dns_provider_plugin.delete_txt_record( authz_record.change_id, account_number, dns_challenge.validation_domain_name(host_to_validate), - dns_challenge.validation(acme_client.client.net.key) + dns_challenge.validation(acme_client.client.net.key), ) return authorizations @@ -302,7 +350,9 @@ class AcmeHandler(object): account_number = dns_provider_options.get("account_id") dns_challenges = authz_record.dns_challenge host_to_validate = self.maybe_remove_wildcard(authz_record.host) - host_to_validate = self.maybe_add_extension(host_to_validate, dns_provider_options) + host_to_validate = self.maybe_add_extension( + host_to_validate, dns_provider_options + ) dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type) for dns_challenge in dns_challenges: try: @@ -310,21 +360,17 @@ class AcmeHandler(object): authz_record.change_id, account_number, dns_challenge.validation_domain_name(host_to_validate), - dns_challenge.validation(acme_client.client.net.key) + dns_challenge.validation(acme_client.client.net.key), ) except Exception as e: # If this fails, it's most likely because the record doesn't exist (It was already cleaned up) # or we're not authorized to modify it. - metrics.send('cleanup_dns_challenges_error', 'counter', 1) + metrics.send("cleanup_dns_challenges_error", "counter", 1) sentry.captureException() pass def get_dns_provider(self, type): - provider_types = { - 'cloudflare': cloudflare, - 'dyn': dyn, - 'route53': route53, - } + provider_types = {"cloudflare": cloudflare, "dyn": dyn, "route53": route53} provider = provider_types.get(type) if not provider: raise UnknownProvider("No such DNS provider: {}".format(type)) @@ -332,41 +378,43 @@ class AcmeHandler(object): class ACMEIssuerPlugin(IssuerPlugin): - title = 'Acme' - slug = 'acme-issuer' - description = 'Enables the creation of certificates via ACME CAs (including Let\'s Encrypt)' + title = "Acme" + slug = "acme-issuer" + description = ( + "Enables the creation of certificates via ACME CAs (including Let's Encrypt)" + ) version = acme.VERSION - author = 'Netflix' - author_url = 'https://github.com/netflix/lemur.git' + author = "Netflix" + author_url = "https://github.com/netflix/lemur.git" options = [ { - 'name': 'acme_url', - 'type': 'str', - 'required': True, - 'validation': '/^http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+$/', - 'helpMessage': 'Must be a valid web url starting with http[s]://', + "name": "acme_url", + "type": "str", + "required": True, + "validation": "/^http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+$/", + "helpMessage": "Must be a valid web url starting with http[s]://", }, { - 'name': 'telephone', - 'type': 'str', - 'default': '', - 'helpMessage': 'Telephone to use' + "name": "telephone", + "type": "str", + "default": "", + "helpMessage": "Telephone to use", }, { - 'name': 'email', - 'type': 'str', - 'default': '', - 'validation': '/^?([-a-zA-Z0-9.`?{}]+@\w+\.\w+)$/', - 'helpMessage': 'Email to use' + "name": "email", + "type": "str", + "default": "", + "validation": "/^?([-a-zA-Z0-9.`?{}]+@\w+\.\w+)$/", + "helpMessage": "Email to use", }, { - 'name': 'certificate', - 'type': 'textarea', - 'default': '', - 'validation': '/^-----BEGIN CERTIFICATE-----/', - 'helpMessage': 'Certificate to use' + "name": "certificate", + "type": "textarea", + "default": "", + "validation": "/^-----BEGIN CERTIFICATE-----/", + "helpMessage": "Certificate to use", }, ] @@ -376,11 +424,7 @@ class ACMEIssuerPlugin(IssuerPlugin): def get_dns_provider(self, type): self.acme = AcmeHandler() - provider_types = { - 'cloudflare': cloudflare, - 'dyn': dyn, - 'route53': route53, - } + provider_types = {"cloudflare": cloudflare, "dyn": dyn, "route53": route53} provider = provider_types.get(type) if not provider: raise UnknownProvider("No such DNS provider: {}".format(type)) @@ -411,24 +455,31 @@ class ACMEIssuerPlugin(IssuerPlugin): try: order = acme_client.new_order(pending_cert.csr) except WildcardUnsupportedError: - metrics.send('get_ordered_certificate_wildcard_unsupported', 'counter', 1) - raise Exception("The currently selected ACME CA endpoint does" - " not support issuing wildcard certificates.") + metrics.send("get_ordered_certificate_wildcard_unsupported", "counter", 1) + raise Exception( + "The currently selected ACME CA endpoint does" + " not support issuing wildcard certificates." + ) try: - authorizations = self.acme.get_authorizations(acme_client, order, order_info) + authorizations = self.acme.get_authorizations( + acme_client, order, order_info + ) except ClientError: sentry.captureException() - metrics.send('get_ordered_certificate_error', 'counter', 1) - current_app.logger.error(f"Unable to resolve pending cert: {pending_cert.name}", exc_info=True) + metrics.send("get_ordered_certificate_error", "counter", 1) + current_app.logger.error( + f"Unable to resolve pending cert: {pending_cert.name}", exc_info=True + ) return False authorizations = self.acme.finalize_authorizations(acme_client, authorizations) pem_certificate, pem_certificate_chain = self.acme.request_certificate( - acme_client, authorizations, order) + acme_client, authorizations, order + ) cert = { - 'body': "\n".join(str(pem_certificate).splitlines()), - 'chain': "\n".join(str(pem_certificate_chain).splitlines()), - 'external_id': str(pending_cert.external_id) + "body": "\n".join(str(pem_certificate).splitlines()), + "chain": "\n".join(str(pem_certificate_chain).splitlines()), + "external_id": str(pending_cert.external_id), } return cert @@ -438,10 +489,14 @@ class ACMEIssuerPlugin(IssuerPlugin): certs = [] for pending_cert in pending_certs: try: - acme_client, registration = self.acme.setup_acme_client(pending_cert.authority) + acme_client, registration = self.acme.setup_acme_client( + pending_cert.authority + ) order_info = authorization_service.get(pending_cert.external_id) if pending_cert.dns_provider_id: - dns_provider = dns_provider_service.get(pending_cert.dns_provider_id) + dns_provider = dns_provider_service.get( + pending_cert.dns_provider_id + ) for domain in order_info.domains: # Currently, we only support specifying one DNS provider per certificate, even if that @@ -455,70 +510,79 @@ class ACMEIssuerPlugin(IssuerPlugin): order = acme_client.new_order(pending_cert.csr) except WildcardUnsupportedError: sentry.captureException() - metrics.send('get_ordered_certificates_wildcard_unsupported_error', 'counter', 1) - raise Exception("The currently selected ACME CA endpoint does" - " not support issuing wildcard certificates.") + metrics.send( + "get_ordered_certificates_wildcard_unsupported_error", + "counter", + 1, + ) + raise Exception( + "The currently selected ACME CA endpoint does" + " not support issuing wildcard certificates." + ) - authorizations = self.acme.get_authorizations(acme_client, order, order_info) + authorizations = self.acme.get_authorizations( + acme_client, order, order_info + ) - pending.append({ - "acme_client": acme_client, - "authorizations": authorizations, - "pending_cert": pending_cert, - "order": order, - }) + pending.append( + { + "acme_client": acme_client, + "authorizations": authorizations, + "pending_cert": pending_cert, + "order": order, + } + ) except (ClientError, ValueError, Exception) as e: sentry.captureException() - metrics.send('get_ordered_certificates_pending_creation_error', 'counter', 1) - current_app.logger.error(f"Unable to resolve pending cert: {pending_cert}", exc_info=True) + metrics.send( + "get_ordered_certificates_pending_creation_error", "counter", 1 + ) + current_app.logger.error( + f"Unable to resolve pending cert: {pending_cert}", exc_info=True + ) error = e if globals().get("order") and order: error += f" Order uri: {order.uri}" - certs.append({ - "cert": False, - "pending_cert": pending_cert, - "last_error": e, - }) + certs.append( + {"cert": False, "pending_cert": pending_cert, "last_error": e} + ) for entry in pending: try: entry["authorizations"] = self.acme.finalize_authorizations( - entry["acme_client"], - entry["authorizations"], + entry["acme_client"], entry["authorizations"] ) pem_certificate, pem_certificate_chain = self.acme.request_certificate( - entry["acme_client"], - entry["authorizations"], - entry["order"] + entry["acme_client"], entry["authorizations"], entry["order"] ) cert = { - 'body': "\n".join(str(pem_certificate).splitlines()), - 'chain': "\n".join(str(pem_certificate_chain).splitlines()), - 'external_id': str(entry["pending_cert"].external_id) + "body": "\n".join(str(pem_certificate).splitlines()), + "chain": "\n".join(str(pem_certificate_chain).splitlines()), + "external_id": str(entry["pending_cert"].external_id), } - certs.append({ - "cert": cert, - "pending_cert": entry["pending_cert"], - }) + certs.append({"cert": cert, "pending_cert": entry["pending_cert"]}) except (PollError, AcmeError, Exception) as e: sentry.captureException() - metrics.send('get_ordered_certificates_resolution_error', 'counter', 1) + metrics.send("get_ordered_certificates_resolution_error", "counter", 1) order_url = order.uri error = f"{e}. Order URI: {order_url}" current_app.logger.error( f"Unable to resolve pending cert: {pending_cert}. " - f"Check out {order_url} for more information.", exc_info=True) - certs.append({ - "cert": False, - "pending_cert": entry["pending_cert"], - "last_error": error, - }) + f"Check out {order_url} for more information.", + exc_info=True, + ) + certs.append( + { + "cert": False, + "pending_cert": entry["pending_cert"], + "last_error": error, + } + ) # Ensure DNS records get deleted self.acme.cleanup_dns_challenges( - entry["acme_client"], - entry["authorizations"], + entry["acme_client"], entry["authorizations"] ) return certs @@ -531,20 +595,26 @@ class ACMEIssuerPlugin(IssuerPlugin): :return: :raise Exception: """ self.acme = AcmeHandler() - authority = issuer_options.get('authority') - create_immediately = issuer_options.get('create_immediately', False) + authority = issuer_options.get("authority") + create_immediately = issuer_options.get("create_immediately", False) acme_client, registration = self.acme.setup_acme_client(authority) - dns_provider = issuer_options.get('dns_provider', {}) + dns_provider = issuer_options.get("dns_provider", {}) if dns_provider: dns_provider_options = dns_provider.options credentials = json.loads(dns_provider.credentials) - current_app.logger.debug("Using DNS provider: {0}".format(dns_provider.provider_type)) - dns_provider_plugin = __import__(dns_provider.provider_type, globals(), locals(), [], 1) + current_app.logger.debug( + "Using DNS provider: {0}".format(dns_provider.provider_type) + ) + dns_provider_plugin = __import__( + dns_provider.provider_type, globals(), locals(), [], 1 + ) account_number = credentials.get("account_id") provider_type = dns_provider.provider_type if provider_type == "route53" and not account_number: - error = "Route53 DNS Provider {} does not have an account number configured.".format(dns_provider.name) + error = "Route53 DNS Provider {} does not have an account number configured.".format( + dns_provider.name + ) current_app.logger.error(error) raise InvalidConfiguration(error) else: @@ -563,16 +633,29 @@ class ACMEIssuerPlugin(IssuerPlugin): else: authz_domains.append(d.value) - dns_authorization = authorization_service.create(account_number, authz_domains, - provider_type) + dns_authorization = authorization_service.create( + account_number, authz_domains, provider_type + ) # Return id of the DNS Authorization return None, None, dns_authorization.id - authorizations = self.acme.get_authorizations(acme_client, account_number, domains, dns_provider_plugin, - dns_provider_options) - self.acme.finalize_authorizations(acme_client, account_number, dns_provider_plugin, authorizations, - dns_provider_options) - pem_certificate, pem_certificate_chain = self.acme.request_certificate(acme_client, authorizations, csr) + authorizations = self.acme.get_authorizations( + acme_client, + account_number, + domains, + dns_provider_plugin, + dns_provider_options, + ) + self.acme.finalize_authorizations( + acme_client, + account_number, + dns_provider_plugin, + authorizations, + dns_provider_options, + ) + pem_certificate, pem_certificate_chain = self.acme.request_certificate( + acme_client, authorizations, csr + ) # TODO add external ID (if possible) return pem_certificate, pem_certificate_chain, None @@ -585,18 +668,18 @@ class ACMEIssuerPlugin(IssuerPlugin): :param options: :return: """ - role = {'username': '', 'password': '', 'name': 'acme'} - plugin_options = options.get('plugin', {}).get('plugin_options') + role = {"username": "", "password": "", "name": "acme"} + plugin_options = options.get("plugin", {}).get("plugin_options") if not plugin_options: error = "Invalid options for lemur_acme plugin: {}".format(options) current_app.logger.error(error) raise InvalidConfiguration(error) # Define static acme_root based off configuration variable by default. However, if user has passed a # certificate, use this certificate as the root. - acme_root = current_app.config.get('ACME_ROOT') + acme_root = current_app.config.get("ACME_ROOT") for option in plugin_options: - if option.get('name') == 'certificate': - acme_root = option.get('value') + if option.get("name") == "certificate": + acme_root = option.get("value") return acme_root, "", [role] def cancel_ordered_certificate(self, pending_cert, **kwargs): diff --git a/lemur/plugins/lemur_acme/route53.py b/lemur/plugins/lemur_acme/route53.py index 3b6c5b32..55da5161 100644 --- a/lemur/plugins/lemur_acme/route53.py +++ b/lemur/plugins/lemur_acme/route53.py @@ -3,7 +3,7 @@ import time from lemur.plugins.lemur_aws.sts import sts_client -@sts_client('route53') +@sts_client("route53") def wait_for_dns_change(change_id, client=None): _, change_id = change_id @@ -14,7 +14,7 @@ def wait_for_dns_change(change_id, client=None): time.sleep(5) -@sts_client('route53') +@sts_client("route53") def find_zone_id(domain, client=None): paginator = client.get_paginator("list_hosted_zones") zones = [] @@ -25,34 +25,35 @@ def find_zone_id(domain, client=None): zones.append((zone["Name"], zone["Id"])) if not zones: - raise ValueError( - "Unable to find a Route53 hosted zone for {}".format(domain) - ) + raise ValueError("Unable to find a Route53 hosted zone for {}".format(domain)) return zones[0][1] -@sts_client('route53') +@sts_client("route53") def get_zones(client=None): paginator = client.get_paginator("list_hosted_zones") zones = [] for page in paginator.paginate(): for zone in page["HostedZones"]: - zones.append(zone["Name"][:-1]) # We need [:-1] to strip out the trailing dot. + zones.append( + zone["Name"][:-1] + ) # We need [:-1] to strip out the trailing dot. return zones -@sts_client('route53') +@sts_client("route53") def change_txt_record(action, zone_id, domain, value, client=None): current_txt_records = [] try: current_records = client.list_resource_record_sets( HostedZoneId=zone_id, StartRecordName=domain, - StartRecordType='TXT', - MaxItems="1")["ResourceRecordSets"] + StartRecordType="TXT", + MaxItems="1", + )["ResourceRecordSets"] for record in current_records: - if record.get('Type') == 'TXT': + if record.get("Type") == "TXT": current_txt_records.extend(record.get("ResourceRecords", [])) except Exception as e: # Current Resource Record does not exist @@ -72,7 +73,9 @@ def change_txt_record(action, zone_id, domain, value, client=None): # If we want to delete one record out of many, we'll update the record to not include the deleted value instead. # This allows us to support concurrent issuance. current_txt_records = [ - record for record in current_txt_records if not (record.get('Value') == '"{}"'.format(value)) + record + for record in current_txt_records + if not (record.get("Value") == '"{}"'.format(value)) ] action = "UPSERT" @@ -87,10 +90,10 @@ def change_txt_record(action, zone_id, domain, value, client=None): "Type": "TXT", "TTL": 300, "ResourceRecords": current_txt_records, - } + }, } ] - } + }, ) return response["ChangeInfo"]["Id"] @@ -98,11 +101,7 @@ def change_txt_record(action, zone_id, domain, value, client=None): def create_txt_record(host, value, account_number): zone_id = find_zone_id(host, account_number=account_number) change_id = change_txt_record( - "UPSERT", - zone_id, - host, - value, - account_number=account_number + "UPSERT", zone_id, host, value, account_number=account_number ) return zone_id, change_id @@ -113,11 +112,7 @@ def delete_txt_record(change_ids, account_number, host, value): zone_id, _ = change_id try: change_txt_record( - "DELETE", - zone_id, - host, - value, - account_number=account_number + "DELETE", zone_id, host, value, account_number=account_number ) except Exception as e: if "but it was not found" in e.response.get("Error", {}).get("Message"): diff --git a/lemur/plugins/lemur_acme/tests/test_acme.py b/lemur/plugins/lemur_acme/tests/test_acme.py index 0c406627..3bf1d05c 100644 --- a/lemur/plugins/lemur_acme/tests/test_acme.py +++ b/lemur/plugins/lemur_acme/tests/test_acme.py @@ -6,8 +6,7 @@ from lemur.plugins.lemur_acme import plugin class TestAcme(unittest.TestCase): - - @patch('lemur.plugins.lemur_acme.plugin.dns_provider_service') + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") def setUp(self, mock_dns_provider_service): self.ACMEIssuerPlugin = plugin.ACMEIssuerPlugin() self.acme = plugin.AcmeHandler() @@ -15,14 +14,17 @@ class TestAcme(unittest.TestCase): mock_dns_provider.name = "cloudflare" mock_dns_provider.credentials = "{}" mock_dns_provider.provider_type = "cloudflare" - self.acme.dns_providers_for_domain = {"www.test.com": [mock_dns_provider], - "test.fakedomain.net": [mock_dns_provider]} + self.acme.dns_providers_for_domain = { + "www.test.com": [mock_dns_provider], + "test.fakedomain.net": [mock_dns_provider], + } - @patch('lemur.plugins.lemur_acme.plugin.len', return_value=1) + @patch("lemur.plugins.lemur_acme.plugin.len", return_value=1) def test_find_dns_challenge(self, mock_len): assert mock_len from acme import challenges + c = challenges.DNS01() mock_authz = Mock() @@ -37,11 +39,13 @@ class TestAcme(unittest.TestCase): a = plugin.AuthorizationRecord("host", "authz", "challenge", "id") self.assertEqual(type(a), plugin.AuthorizationRecord) - @patch('acme.client.Client') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.plugin.len', return_value=1) - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.find_dns_challenge') - def test_start_dns_challenge(self, mock_find_dns_challenge, mock_len, mock_app, mock_acme): + @patch("acme.client.Client") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.plugin.len", return_value=1) + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.find_dns_challenge") + def test_start_dns_challenge( + self, mock_find_dns_challenge, mock_len, mock_app, mock_acme + ): assert mock_len mock_order = Mock() mock_app.logger.debug = Mock() @@ -49,6 +53,7 @@ class TestAcme(unittest.TestCase): mock_authz.body.resolved_combinations = [] mock_entry = MagicMock() from acme import challenges + c = challenges.DNS01() mock_entry.chall = TestAcme.test_complete_dns_challenge_fail mock_authz.body.resolved_combinations.append(mock_entry) @@ -60,13 +65,17 @@ class TestAcme(unittest.TestCase): iterable = mock_find_dns_challenge.return_value iterator = iter(values) iterable.__iter__.return_value = iterator - result = self.acme.start_dns_challenge(mock_acme, "accountid", "host", mock_dns_provider, mock_order, {}) + result = self.acme.start_dns_challenge( + mock_acme, "accountid", "host", mock_dns_provider, mock_order, {} + ) self.assertEqual(type(result), plugin.AuthorizationRecord) - @patch('acme.client.Client') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.cloudflare.wait_for_dns_change') - def test_complete_dns_challenge_success(self, mock_wait_for_dns_change, mock_current_app, mock_acme): + @patch("acme.client.Client") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.cloudflare.wait_for_dns_change") + def test_complete_dns_challenge_success( + self, mock_wait_for_dns_change, mock_current_app, mock_acme + ): mock_dns_provider = Mock() mock_dns_provider.wait_for_dns_change = Mock(return_value=True) mock_authz = Mock() @@ -84,10 +93,12 @@ class TestAcme(unittest.TestCase): mock_authz.dns_challenge.append(dns_challenge) self.acme.complete_dns_challenge(mock_acme, mock_authz) - @patch('acme.client.Client') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.cloudflare.wait_for_dns_change') - def test_complete_dns_challenge_fail(self, mock_wait_for_dns_change, mock_current_app, mock_acme): + @patch("acme.client.Client") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.cloudflare.wait_for_dns_change") + def test_complete_dns_challenge_fail( + self, mock_wait_for_dns_change, mock_current_app, mock_acme + ): mock_dns_provider = Mock() mock_dns_provider.wait_for_dns_change = Mock(return_value=True) @@ -105,16 +116,22 @@ class TestAcme(unittest.TestCase): dns_challenge = Mock() mock_authz.dns_challenge.append(dns_challenge) self.assertRaises( - ValueError, - self.acme.complete_dns_challenge(mock_acme, mock_authz) + ValueError, self.acme.complete_dns_challenge(mock_acme, mock_authz) ) - @patch('acme.client.Client') - @patch('OpenSSL.crypto', return_value="mock_cert") - @patch('josepy.util.ComparableX509') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.find_dns_challenge') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - def test_request_certificate(self, mock_current_app, mock_find_dns_challenge, mock_jose, mock_crypto, mock_acme): + @patch("acme.client.Client") + @patch("OpenSSL.crypto", return_value="mock_cert") + @patch("josepy.util.ComparableX509") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.find_dns_challenge") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + def test_request_certificate( + self, + mock_current_app, + mock_find_dns_challenge, + mock_jose, + mock_crypto, + mock_acme, + ): mock_cert_response = Mock() mock_cert_response.body = "123" mock_cert_response_full = [mock_cert_response, True] @@ -124,7 +141,7 @@ class TestAcme(unittest.TestCase): mock_authz_record.authz = Mock() mock_authz.append(mock_authz_record) mock_acme.fetch_chain = Mock(return_value="mock_chain") - mock_crypto.dump_certificate = Mock(return_value=b'chain') + mock_crypto.dump_certificate = Mock(return_value=b"chain") mock_order = Mock() self.acme.request_certificate(mock_acme, [], mock_order) @@ -134,8 +151,8 @@ class TestAcme(unittest.TestCase): with self.assertRaises(Exception): self.acme.setup_acme_client(mock_authority) - @patch('lemur.plugins.lemur_acme.plugin.BackwardsCompatibleClientV2') - @patch('lemur.plugins.lemur_acme.plugin.current_app') + @patch("lemur.plugins.lemur_acme.plugin.BackwardsCompatibleClientV2") + @patch("lemur.plugins.lemur_acme.plugin.current_app") def test_setup_acme_client_success(self, mock_current_app, mock_acme): mock_authority = Mock() mock_authority.options = '[{"name": "mock_name", "value": "mock_value"}]' @@ -150,31 +167,29 @@ class TestAcme(unittest.TestCase): assert result_client assert result_registration - @patch('lemur.plugins.lemur_acme.plugin.current_app') + @patch("lemur.plugins.lemur_acme.plugin.current_app") def test_get_domains_single(self, mock_current_app): - options = { - "common_name": "test.netflix.net" - } + options = {"common_name": "test.netflix.net"} result = self.acme.get_domains(options) self.assertEqual(result, [options["common_name"]]) - @patch('lemur.plugins.lemur_acme.plugin.current_app') + @patch("lemur.plugins.lemur_acme.plugin.current_app") def test_get_domains_multiple(self, mock_current_app): options = { "common_name": "test.netflix.net", "extensions": { - "sub_alt_names": { - "names": [ - "test2.netflix.net", - "test3.netflix.net" - ] - } - } + "sub_alt_names": {"names": ["test2.netflix.net", "test3.netflix.net"]} + }, } result = self.acme.get_domains(options) - self.assertEqual(result, [options["common_name"], "test2.netflix.net", "test3.netflix.net"]) + self.assertEqual( + result, [options["common_name"], "test2.netflix.net", "test3.netflix.net"] + ) - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.start_dns_challenge', return_value="test") + @patch( + "lemur.plugins.lemur_acme.plugin.AcmeHandler.start_dns_challenge", + return_value="test", + ) def test_get_authorizations(self, mock_start_dns_challenge): mock_order = Mock() mock_order.body.identifiers = [] @@ -183,10 +198,15 @@ class TestAcme(unittest.TestCase): mock_order_info = Mock() mock_order_info.account_number = 1 mock_order_info.domains = ["test.fakedomain.net"] - result = self.acme.get_authorizations("acme_client", mock_order, mock_order_info) + result = self.acme.get_authorizations( + "acme_client", mock_order, mock_order_info + ) self.assertEqual(result, ["test"]) - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.complete_dns_challenge', return_value="test") + @patch( + "lemur.plugins.lemur_acme.plugin.AcmeHandler.complete_dns_challenge", + return_value="test", + ) def test_finalize_authorizations(self, mock_complete_dns_challenge): mock_authz = [] mock_authz_record = MagicMock() @@ -202,28 +222,28 @@ class TestAcme(unittest.TestCase): result = self.acme.finalize_authorizations(mock_acme_client, mock_authz) self.assertEqual(result, mock_authz) - @patch('lemur.plugins.lemur_acme.plugin.current_app') + @patch("lemur.plugins.lemur_acme.plugin.current_app") def test_create_authority(self, mock_current_app): mock_current_app.config = Mock() options = { - "plugin": { - "plugin_options": [{ - "name": "certificate", - "value": "123" - }] - } + "plugin": {"plugin_options": [{"name": "certificate", "value": "123"}]} } acme_root, b, role = self.ACMEIssuerPlugin.create_authority(options) self.assertEqual(acme_root, "123") self.assertEqual(b, "") - self.assertEqual(role, [{'username': '', 'password': '', 'name': 'acme'}]) + self.assertEqual(role, [{"username": "", "password": "", "name": "acme"}]) - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.dyn.current_app') - @patch('lemur.plugins.lemur_acme.cloudflare.current_app') - @patch('lemur.plugins.lemur_acme.plugin.dns_provider_service') - def test_get_dns_provider(self, mock_dns_provider_service, mock_current_app_cloudflare, mock_current_app_dyn, - mock_current_app): + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.dyn.current_app") + @patch("lemur.plugins.lemur_acme.cloudflare.current_app") + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") + def test_get_dns_provider( + self, + mock_dns_provider_service, + mock_current_app_cloudflare, + mock_current_app_dyn, + mock_current_app, + ): provider = plugin.ACMEIssuerPlugin() route53 = provider.get_dns_provider("route53") assert route53 @@ -232,16 +252,23 @@ class TestAcme(unittest.TestCase): dyn = provider.get_dns_provider("dyn") assert dyn - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.plugin.authorization_service') - @patch('lemur.plugins.lemur_acme.plugin.dns_provider_service') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate') + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.plugin.authorization_service") + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate") def test_get_ordered_certificate( - self, mock_request_certificate, mock_finalize_authorizations, mock_get_authorizations, - mock_dns_provider_service, mock_authorization_service, mock_current_app, mock_acme): + self, + mock_request_certificate, + mock_finalize_authorizations, + mock_get_authorizations, + mock_dns_provider_service, + mock_authorization_service, + mock_current_app, + mock_acme, + ): mock_client = Mock() mock_acme.return_value = (mock_client, "") mock_request_certificate.return_value = ("pem_certificate", "chain") @@ -253,24 +280,26 @@ class TestAcme(unittest.TestCase): provider.get_dns_provider = Mock() result = provider.get_ordered_certificate(mock_cert) self.assertEqual( - result, - { - 'body': "pem_certificate", - 'chain': "chain", - 'external_id': "1" - } + result, {"body": "pem_certificate", "chain": "chain", "external_id": "1"} ) - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.plugin.authorization_service') - @patch('lemur.plugins.lemur_acme.plugin.dns_provider_service') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate') + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.plugin.authorization_service") + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate") def test_get_ordered_certificates( - self, mock_request_certificate, mock_finalize_authorizations, mock_get_authorizations, - mock_dns_provider_service, mock_authorization_service, mock_current_app, mock_acme): + self, + mock_request_certificate, + mock_finalize_authorizations, + mock_get_authorizations, + mock_dns_provider_service, + mock_authorization_service, + mock_current_app, + mock_acme, + ): mock_client = Mock() mock_acme.return_value = (mock_client, "") mock_request_certificate.return_value = ("pem_certificate", "chain") @@ -285,19 +314,32 @@ class TestAcme(unittest.TestCase): provider.get_dns_provider = Mock() result = provider.get_ordered_certificates([mock_cert, mock_cert2]) self.assertEqual(len(result), 2) - self.assertEqual(result[0]['cert'], {'body': 'pem_certificate', 'chain': 'chain', 'external_id': '1'}) - self.assertEqual(result[1]['cert'], {'body': 'pem_certificate', 'chain': 'chain', 'external_id': '2'}) + self.assertEqual( + result[0]["cert"], + {"body": "pem_certificate", "chain": "chain", "external_id": "1"}, + ) + self.assertEqual( + result[1]["cert"], + {"body": "pem_certificate", "chain": "chain", "external_id": "2"}, + ) - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client') - @patch('lemur.plugins.lemur_acme.plugin.dns_provider_service') - @patch('lemur.plugins.lemur_acme.plugin.current_app') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations') - @patch('lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate') - @patch('lemur.plugins.lemur_acme.plugin.authorization_service') - def test_create_certificate(self, mock_authorization_service, mock_request_certificate, - mock_finalize_authorizations, mock_get_authorizations, - mock_current_app, mock_dns_provider_service, mock_acme): + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.setup_acme_client") + @patch("lemur.plugins.lemur_acme.plugin.dns_provider_service") + @patch("lemur.plugins.lemur_acme.plugin.current_app") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.get_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.finalize_authorizations") + @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.request_certificate") + @patch("lemur.plugins.lemur_acme.plugin.authorization_service") + def test_create_certificate( + self, + mock_authorization_service, + mock_request_certificate, + mock_finalize_authorizations, + mock_get_authorizations, + mock_current_app, + mock_dns_provider_service, + mock_acme, + ): provider = plugin.ACMEIssuerPlugin() mock_authority = Mock() @@ -310,9 +352,9 @@ class TestAcme(unittest.TestCase): mock_dns_provider_service.get.return_value = mock_dns_provider issuer_options = { - 'authority': mock_authority, - 'dns_provider': mock_dns_provider, - "common_name": "test.netflix.net" + "authority": mock_authority, + "dns_provider": mock_dns_provider, + "common_name": "test.netflix.net", } csr = "123" mock_request_certificate.return_value = ("pem_certificate", "chain") diff --git a/lemur/plugins/lemur_adcs/__init__.py b/lemur/plugins/lemur_adcs/__init__.py index 6b61e936..b902ed7a 100644 --- a/lemur/plugins/lemur_adcs/__init__.py +++ b/lemur/plugins/lemur_adcs/__init__.py @@ -1,6 +1,5 @@ """Set the version information.""" try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_adcs/plugin.py b/lemur/plugins/lemur_adcs/plugin.py index b7698474..bc07ede3 100644 --- a/lemur/plugins/lemur_adcs/plugin.py +++ b/lemur/plugins/lemur_adcs/plugin.py @@ -7,13 +7,13 @@ from flask import current_app class ADCSIssuerPlugin(IssuerPlugin): - title = 'ADCS' - slug = 'adcs-issuer' - description = 'Enables the creation of certificates by ADCS (Active Directory Certificate Services)' + title = "ADCS" + slug = "adcs-issuer" + description = "Enables the creation of certificates by ADCS (Active Directory Certificate Services)" version = ADCS.VERSION - author = 'sirferl' - author_url = 'https://github.com/sirferl/lemur' + author = "sirferl" + author_url = "https://github.com/sirferl/lemur" def __init__(self, *args, **kwargs): """Initialize the issuer with the appropriate details.""" @@ -30,66 +30,80 @@ class ADCSIssuerPlugin(IssuerPlugin): :param options: :return: """ - adcs_root = current_app.config.get('ADCS_ROOT') - adcs_issuing = current_app.config.get('ADCS_ISSUING') - role = {'username': '', 'password': '', 'name': 'adcs'} + adcs_root = current_app.config.get("ADCS_ROOT") + adcs_issuing = current_app.config.get("ADCS_ISSUING") + role = {"username": "", "password": "", "name": "adcs"} return adcs_root, adcs_issuing, [role] def create_certificate(self, csr, issuer_options): - adcs_server = current_app.config.get('ADCS_SERVER') - adcs_user = current_app.config.get('ADCS_USER') - adcs_pwd = current_app.config.get('ADCS_PWD') - adcs_auth_method = current_app.config.get('ADCS_AUTH_METHOD') - adcs_template = current_app.config.get('ADCS_TEMPLATE') - ca_server = Certsrv(adcs_server, adcs_user, adcs_pwd, auth_method=adcs_auth_method) + adcs_server = current_app.config.get("ADCS_SERVER") + adcs_user = current_app.config.get("ADCS_USER") + adcs_pwd = current_app.config.get("ADCS_PWD") + adcs_auth_method = current_app.config.get("ADCS_AUTH_METHOD") + adcs_template = current_app.config.get("ADCS_TEMPLATE") + ca_server = Certsrv( + adcs_server, adcs_user, adcs_pwd, auth_method=adcs_auth_method + ) current_app.logger.info("Requesting CSR: {0}".format(csr)) current_app.logger.info("Issuer options: {0}".format(issuer_options)) - cert, req_id = ca_server.get_cert(csr, adcs_template, encoding='b64').decode('utf-8').replace('\r\n', '\n') - chain = ca_server.get_ca_cert(encoding='b64').decode('utf-8').replace('\r\n', '\n') + cert, req_id = ( + ca_server.get_cert(csr, adcs_template, encoding="b64") + .decode("utf-8") + .replace("\r\n", "\n") + ) + chain = ( + ca_server.get_ca_cert(encoding="b64").decode("utf-8").replace("\r\n", "\n") + ) return cert, chain, req_id def revoke_certificate(self, certificate, comments): - raise NotImplementedError('Not implemented\n', self, certificate, comments) + raise NotImplementedError("Not implemented\n", self, certificate, comments) def get_ordered_certificate(self, order_id): - raise NotImplementedError('Not implemented\n', self, order_id) + raise NotImplementedError("Not implemented\n", self, order_id) def canceled_ordered_certificate(self, pending_cert, **kwargs): - raise NotImplementedError('Not implemented\n', self, pending_cert, **kwargs) + raise NotImplementedError("Not implemented\n", self, pending_cert, **kwargs) class ADCSSourcePlugin(SourcePlugin): - title = 'ADCS' - slug = 'adcs-source' - description = 'Enables the collecion of certificates' + title = "ADCS" + slug = "adcs-source" + description = "Enables the collecion of certificates" version = ADCS.VERSION - author = 'sirferl' - author_url = 'https://github.com/sirferl/lemur' + author = "sirferl" + author_url = "https://github.com/sirferl/lemur" options = [ { - 'name': 'dummy', - 'type': 'str', - 'required': False, - 'validation': '/^[0-9]{12,12}$/', - 'helpMessage': 'Just to prevent error' + "name": "dummy", + "type": "str", + "required": False, + "validation": "/^[0-9]{12,12}$/", + "helpMessage": "Just to prevent error", } ] def get_certificates(self, options, **kwargs): - adcs_server = current_app.config.get('ADCS_SERVER') - adcs_user = current_app.config.get('ADCS_USER') - adcs_pwd = current_app.config.get('ADCS_PWD') - adcs_auth_method = current_app.config.get('ADCS_AUTH_METHOD') - adcs_start = current_app.config.get('ADCS_START') - adcs_stop = current_app.config.get('ADCS_STOP') - ca_server = Certsrv(adcs_server, adcs_user, adcs_pwd, auth_method=adcs_auth_method) + adcs_server = current_app.config.get("ADCS_SERVER") + adcs_user = current_app.config.get("ADCS_USER") + adcs_pwd = current_app.config.get("ADCS_PWD") + adcs_auth_method = current_app.config.get("ADCS_AUTH_METHOD") + adcs_start = current_app.config.get("ADCS_START") + adcs_stop = current_app.config.get("ADCS_STOP") + ca_server = Certsrv( + adcs_server, adcs_user, adcs_pwd, auth_method=adcs_auth_method + ) out_certlist = [] for id in range(adcs_start, adcs_stop): try: - cert = ca_server.get_existing_cert(id, encoding='b64').decode('utf-8').replace('\r\n', '\n') + cert = ( + ca_server.get_existing_cert(id, encoding="b64") + .decode("utf-8") + .replace("\r\n", "\n") + ) except Exception as err: - if '{0}'.format(err).find("CERTSRV_E_PROPERTY_EMPTY"): + if "{0}".format(err).find("CERTSRV_E_PROPERTY_EMPTY"): # this error indicates end of certificate list(?), so we stop break else: @@ -101,16 +115,16 @@ class ADCSSourcePlugin(SourcePlugin): # loop through extensions to see if we find "TLS Web Server Authentication" for e_id in range(0, pubkey.get_extension_count() - 1): try: - extension = '{0}'.format(pubkey.get_extension(e_id)) + extension = "{0}".format(pubkey.get_extension(e_id)) except Exception: - extensionn = '' + extensionn = "" if extension.find("TLS Web Server Authentication") != -1: - out_certlist.append({ - 'name': format(pubkey.get_subject().CN), - 'body': cert}) + out_certlist.append( + {"name": format(pubkey.get_subject().CN), "body": cert} + ) break return out_certlist def get_endpoints(self, options, **kwargs): # There are no endpoints in the ADCS - raise NotImplementedError('Not implemented\n', self, options, **kwargs) + raise NotImplementedError("Not implemented\n", self, options, **kwargs) diff --git a/lemur/plugins/lemur_atlas/__init__.py b/lemur/plugins/lemur_atlas/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_atlas/__init__.py +++ b/lemur/plugins/lemur_atlas/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_atlas/plugin.py b/lemur/plugins/lemur_atlas/plugin.py index 09d4c9f9..7cf78ed2 100644 --- a/lemur/plugins/lemur_atlas/plugin.py +++ b/lemur/plugins/lemur_atlas/plugin.py @@ -26,44 +26,41 @@ def millis_since_epoch(): class AtlasMetricPlugin(MetricPlugin): - title = 'Atlas' - slug = 'atlas-metric' - description = 'Adds support for sending key metrics to Atlas' + title = "Atlas" + slug = "atlas-metric" + description = "Adds support for sending key metrics to Atlas" version = atlas.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur" options = [ { - 'name': 'sidecar_host', - 'type': 'str', - 'required': False, - 'help_message': 'If no host is provided localhost is assumed', - 'default': 'localhost' + "name": "sidecar_host", + "type": "str", + "required": False, + "help_message": "If no host is provided localhost is assumed", + "default": "localhost", }, - { - 'name': 'sidecar_port', - 'type': 'int', - 'required': False, - 'default': 8078 - } + {"name": "sidecar_port", "type": "int", "required": False, "default": 8078}, ] metric_data = {} sidecar_host = None sidecar_port = None - def submit(self, metric_name, metric_type, metric_value, metric_tags=None, options=None): + def submit( + self, metric_name, metric_type, metric_value, metric_tags=None, options=None + ): if not options: options = self.options # TODO marshmallow schema? - valid_types = ['COUNTER', 'GAUGE', 'TIMER'] + valid_types = ["COUNTER", "GAUGE", "TIMER"] if metric_type.upper() not in valid_types: raise Exception( "Invalid Metric Type for Atlas: '{metric}' choose from: {options}".format( - metric=metric_type, options=','.join(valid_types) + metric=metric_type, options=",".join(valid_types) ) ) @@ -73,31 +70,35 @@ class AtlasMetricPlugin(MetricPlugin): "Invalid Metric Tags for Atlas: Tags must be in dict format" ) - if metric_value == "NaN" or isinstance(metric_value, int) or isinstance(metric_value, float): - self.metric_data['value'] = metric_value + if ( + metric_value == "NaN" + or isinstance(metric_value, int) + or isinstance(metric_value, float) + ): + self.metric_data["value"] = metric_value else: - raise Exception( - "Invalid Metric Value for Atlas: Metric must be a number" - ) + raise Exception("Invalid Metric Value for Atlas: Metric must be a number") - self.metric_data['type'] = metric_type.upper() - self.metric_data['name'] = str(metric_name) - self.metric_data['tags'] = metric_tags - self.metric_data['timestamp'] = millis_since_epoch() + self.metric_data["type"] = metric_type.upper() + self.metric_data["name"] = str(metric_name) + self.metric_data["tags"] = metric_tags + self.metric_data["timestamp"] = millis_since_epoch() - self.sidecar_host = self.get_option('sidecar_host', options) - self.sidecar_port = self.get_option('sidecar_port', options) + self.sidecar_host = self.get_option("sidecar_host", options) + self.sidecar_port = self.get_option("sidecar_port", options) try: res = requests.post( - 'http://{host}:{port}/metrics'.format( - host=self.sidecar_host, - port=self.sidecar_port), - data=json.dumps([self.metric_data]) + "http://{host}:{port}/metrics".format( + host=self.sidecar_host, port=self.sidecar_port + ), + data=json.dumps([self.metric_data]), ) if res.status_code != 200: - current_app.logger.warning("Failed to publish altas metric. {0}".format(res.content)) + current_app.logger.warning( + "Failed to publish altas metric. {0}".format(res.content) + ) except ConnectionError: current_app.logger.warning( diff --git a/lemur/plugins/lemur_aws/__init__.py b/lemur/plugins/lemur_aws/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_aws/__init__.py +++ b/lemur/plugins/lemur_aws/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_aws/ec2.py b/lemur/plugins/lemur_aws/ec2.py index 3bd20e60..04b42140 100644 --- a/lemur/plugins/lemur_aws/ec2.py +++ b/lemur/plugins/lemur_aws/ec2.py @@ -8,16 +8,16 @@ from lemur.plugins.lemur_aws.sts import sts_client -@sts_client('ec2') +@sts_client("ec2") def get_regions(**kwargs): - regions = kwargs['client'].describe_regions() - return [x['RegionName'] for x in regions['Regions']] + regions = kwargs["client"].describe_regions() + return [x["RegionName"] for x in regions["Regions"]] -@sts_client('ec2') +@sts_client("ec2") def get_all_instances(**kwargs): """ Fetches all instance objects for a given account and region. """ - paginator = kwargs['client'].get_paginator('describe_instances') + paginator = kwargs["client"].get_paginator("describe_instances") return paginator.paginate() diff --git a/lemur/plugins/lemur_aws/elb.py b/lemur/plugins/lemur_aws/elb.py index 618f75e8..1ab71b65 100644 --- a/lemur/plugins/lemur_aws/elb.py +++ b/lemur/plugins/lemur_aws/elb.py @@ -27,15 +27,14 @@ def retry_throttled(exception): raise exception except Exception as e: current_app.logger.error("ELB retry_throttled triggered", exc_info=True) - metrics.send('elb_retry', 'counter', 1, - metric_tags={"exception": e}) + metrics.send("elb_retry", "counter", 1, metric_tags={"exception": e}) sentry.captureException() if isinstance(exception, botocore.exceptions.ClientError): - if exception.response['Error']['Code'] == 'LoadBalancerNotFound': + if exception.response["Error"]["Code"] == "LoadBalancerNotFound": return False - if exception.response['Error']['Code'] == 'CertificateNotFound': + if exception.response["Error"]["Code"] == "CertificateNotFound": return False return True @@ -56,7 +55,7 @@ def is_valid(listener_tuple): :param listener_tuple: """ lb_port, i_port, lb_protocol, arn = listener_tuple - if lb_protocol.lower() in ['ssl', 'https']: + if lb_protocol.lower() in ["ssl", "https"]: if not arn: raise InvalidListener @@ -75,14 +74,14 @@ def get_all_elbs(**kwargs): while True: response = get_elbs(**kwargs) - elbs += response['LoadBalancerDescriptions'] + elbs += response["LoadBalancerDescriptions"] - if not response.get('NextMarker'): + if not response.get("NextMarker"): return elbs else: - kwargs.update(dict(Marker=response['NextMarker'])) + kwargs.update(dict(Marker=response["NextMarker"])) except Exception as e: # noqa - metrics.send('get_all_elbs_error', 'counter', 1) + metrics.send("get_all_elbs_error", "counter", 1) sentry.captureException() raise @@ -99,19 +98,19 @@ def get_all_elbs_v2(**kwargs): try: while True: response = get_elbs_v2(**kwargs) - elbs += response['LoadBalancers'] + elbs += response["LoadBalancers"] - if not response.get('NextMarker'): + if not response.get("NextMarker"): return elbs else: - kwargs.update(dict(Marker=response['NextMarker'])) + kwargs.update(dict(Marker=response["NextMarker"])) except Exception as e: # noqa - metrics.send('get_all_elbs_v2_error', 'counter', 1) + metrics.send("get_all_elbs_v2_error", "counter", 1) sentry.captureException() raise -@sts_client('elbv2') +@sts_client("elbv2") @retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def get_listener_arn_from_endpoint(endpoint_name, endpoint_port, **kwargs): """ @@ -121,38 +120,51 @@ def get_listener_arn_from_endpoint(endpoint_name, endpoint_port, **kwargs): :return: """ try: - client = kwargs.pop('client') + client = kwargs.pop("client") elbs = client.describe_load_balancers(Names=[endpoint_name]) - for elb in elbs['LoadBalancers']: - listeners = client.describe_listeners(LoadBalancerArn=elb['LoadBalancerArn']) - for listener in listeners['Listeners']: - if listener['Port'] == endpoint_port: - return listener['ListenerArn'] + for elb in elbs["LoadBalancers"]: + listeners = client.describe_listeners( + LoadBalancerArn=elb["LoadBalancerArn"] + ) + for listener in listeners["Listeners"]: + if listener["Port"] == endpoint_port: + return listener["ListenerArn"] except Exception as e: # noqa - metrics.send('get_listener_arn_from_endpoint_error', 'counter', 1, - metric_tags={"error": e, "endpoint_name": endpoint_name, "endpoint_port": endpoint_port}) - sentry.captureException(extra={"endpoint_name": str(endpoint_name), - "endpoint_port": str(endpoint_port)}) + metrics.send( + "get_listener_arn_from_endpoint_error", + "counter", + 1, + metric_tags={ + "error": e, + "endpoint_name": endpoint_name, + "endpoint_port": endpoint_port, + }, + ) + sentry.captureException( + extra={ + "endpoint_name": str(endpoint_name), + "endpoint_port": str(endpoint_port), + } + ) raise -@sts_client('elb') +@sts_client("elb") @retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def get_elbs(**kwargs): """ Fetches one page elb objects for a given account and region. """ try: - client = kwargs.pop('client') + client = kwargs.pop("client") return client.describe_load_balancers(**kwargs) except Exception as e: # noqa - metrics.send('get_elbs_error', 'counter', 1, - metric_tags={"error": e}) + metrics.send("get_elbs_error", "counter", 1, metric_tags={"error": e}) sentry.captureException() raise -@sts_client('elbv2') +@sts_client("elbv2") @retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def get_elbs_v2(**kwargs): """ @@ -162,16 +174,15 @@ def get_elbs_v2(**kwargs): :return: """ try: - client = kwargs.pop('client') + client = kwargs.pop("client") return client.describe_load_balancers(**kwargs) except Exception as e: # noqa - metrics.send('get_elbs_v2_error', 'counter', 1, - metric_tags={"error": e}) + metrics.send("get_elbs_v2_error", "counter", 1, metric_tags={"error": e}) sentry.captureException() raise -@sts_client('elbv2') +@sts_client("elbv2") @retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def describe_listeners_v2(**kwargs): """ @@ -181,16 +192,17 @@ def describe_listeners_v2(**kwargs): :return: """ try: - client = kwargs.pop('client') + client = kwargs.pop("client") return client.describe_listeners(**kwargs) except Exception as e: # noqa - metrics.send('describe_listeners_v2_error', 'counter', 1, - metric_tags={"error": e}) + metrics.send( + "describe_listeners_v2_error", "counter", 1, metric_tags={"error": e} + ) sentry.captureException() raise -@sts_client('elb') +@sts_client("elb") @retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def describe_load_balancer_policies(load_balancer_name, policy_names, **kwargs): """ @@ -201,17 +213,30 @@ def describe_load_balancer_policies(load_balancer_name, policy_names, **kwargs): """ try: - return kwargs['client'].describe_load_balancer_policies(LoadBalancerName=load_balancer_name, - PolicyNames=policy_names) + return kwargs["client"].describe_load_balancer_policies( + LoadBalancerName=load_balancer_name, PolicyNames=policy_names + ) except Exception as e: # noqa - metrics.send('describe_load_balancer_policies_error', 'counter', 1, - metric_tags={"load_balancer_name": load_balancer_name, "policy_names": policy_names, "error": e}) - sentry.captureException(extra={"load_balancer_name": str(load_balancer_name), - "policy_names": str(policy_names)}) + metrics.send( + "describe_load_balancer_policies_error", + "counter", + 1, + metric_tags={ + "load_balancer_name": load_balancer_name, + "policy_names": policy_names, + "error": e, + }, + ) + sentry.captureException( + extra={ + "load_balancer_name": str(load_balancer_name), + "policy_names": str(policy_names), + } + ) raise -@sts_client('elbv2') +@sts_client("elbv2") @retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def describe_ssl_policies_v2(policy_names, **kwargs): """ @@ -221,15 +246,19 @@ def describe_ssl_policies_v2(policy_names, **kwargs): :return: """ try: - return kwargs['client'].describe_ssl_policies(Names=policy_names) + return kwargs["client"].describe_ssl_policies(Names=policy_names) except Exception as e: # noqa - metrics.send('describe_ssl_policies_v2_error', 'counter', 1, - metric_tags={"policy_names": policy_names, "error": e}) + metrics.send( + "describe_ssl_policies_v2_error", + "counter", + 1, + metric_tags={"policy_names": policy_names, "error": e}, + ) sentry.captureException(extra={"policy_names": str(policy_names)}) raise -@sts_client('elb') +@sts_client("elb") @retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def describe_load_balancer_types(policies, **kwargs): """ @@ -238,10 +267,12 @@ def describe_load_balancer_types(policies, **kwargs): :param policies: :return: """ - return kwargs['client'].describe_load_balancer_policy_types(PolicyTypeNames=policies) + return kwargs["client"].describe_load_balancer_policy_types( + PolicyTypeNames=policies + ) -@sts_client('elb') +@sts_client("elb") @retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def attach_certificate(name, port, certificate_id, **kwargs): """ @@ -253,15 +284,19 @@ def attach_certificate(name, port, certificate_id, **kwargs): :param certificate_id: """ try: - return kwargs['client'].set_load_balancer_listener_ssl_certificate(LoadBalancerName=name, LoadBalancerPort=port, SSLCertificateId=certificate_id) + return kwargs["client"].set_load_balancer_listener_ssl_certificate( + LoadBalancerName=name, + LoadBalancerPort=port, + SSLCertificateId=certificate_id, + ) except botocore.exceptions.ClientError as e: - if e.response['Error']['Code'] == 'LoadBalancerNotFound': + if e.response["Error"]["Code"] == "LoadBalancerNotFound": current_app.logger.warning("Loadbalancer does not exist.") else: raise e -@sts_client('elbv2') +@sts_client("elbv2") @retry(retry_on_exception=retry_throttled, wait_fixed=2000, stop_max_attempt_number=20) def attach_certificate_v2(listener_arn, port, certificates, **kwargs): """ @@ -273,9 +308,11 @@ def attach_certificate_v2(listener_arn, port, certificates, **kwargs): :param certificates: """ try: - return kwargs['client'].modify_listener(ListenerArn=listener_arn, Port=port, Certificates=certificates) + return kwargs["client"].modify_listener( + ListenerArn=listener_arn, Port=port, Certificates=certificates + ) except botocore.exceptions.ClientError as e: - if e.response['Error']['Code'] == 'LoadBalancerNotFound': + if e.response["Error"]["Code"] == "LoadBalancerNotFound": current_app.logger.warning("Loadbalancer does not exist.") else: raise e diff --git a/lemur/plugins/lemur_aws/iam.py b/lemur/plugins/lemur_aws/iam.py index 49816c2b..5a6b753d 100644 --- a/lemur/plugins/lemur_aws/iam.py +++ b/lemur/plugins/lemur_aws/iam.py @@ -21,10 +21,10 @@ def retry_throttled(exception): :return: """ if isinstance(exception, botocore.exceptions.ClientError): - if exception.response['Error']['Code'] == 'NoSuchEntity': + if exception.response["Error"]["Code"] == "NoSuchEntity": return False - metrics.send('iam_retry', 'counter', 1) + metrics.send("iam_retry", "counter", 1) return True @@ -47,11 +47,11 @@ def create_arn_from_cert(account_number, region, certificate_name): :return: """ return "arn:aws:iam::{account_number}:server-certificate/{certificate_name}".format( - account_number=account_number, - certificate_name=certificate_name) + account_number=account_number, certificate_name=certificate_name + ) -@sts_client('iam') +@sts_client("iam") @retry(retry_on_exception=retry_throttled, wait_fixed=2000) def upload_cert(name, body, private_key, path, cert_chain=None, **kwargs): """ @@ -65,12 +65,12 @@ def upload_cert(name, body, private_key, path, cert_chain=None, **kwargs): :return: """ assert isinstance(private_key, str) - client = kwargs.pop('client') + client = kwargs.pop("client") - if not path or path == '/': - path = '/' + if not path or path == "/": + path = "/" else: - name = name + '-' + path.strip('/') + name = name + "-" + path.strip("/") try: if cert_chain: @@ -79,21 +79,21 @@ def upload_cert(name, body, private_key, path, cert_chain=None, **kwargs): ServerCertificateName=name, CertificateBody=str(body), PrivateKey=str(private_key), - CertificateChain=str(cert_chain) + CertificateChain=str(cert_chain), ) else: return client.upload_server_certificate( Path=path, ServerCertificateName=name, CertificateBody=str(body), - PrivateKey=str(private_key) + PrivateKey=str(private_key), ) except botocore.exceptions.ClientError as e: - if e.response['Error']['Code'] != 'EntityAlreadyExists': + if e.response["Error"]["Code"] != "EntityAlreadyExists": raise e -@sts_client('iam') +@sts_client("iam") @retry(retry_on_exception=retry_throttled, wait_fixed=2000) def delete_cert(cert_name, **kwargs): """ @@ -102,15 +102,15 @@ def delete_cert(cert_name, **kwargs): :param cert_name: :return: """ - client = kwargs.pop('client') + client = kwargs.pop("client") try: client.delete_server_certificate(ServerCertificateName=cert_name) except botocore.exceptions.ClientError as e: - if e.response['Error']['Code'] != 'NoSuchEntity': + if e.response["Error"]["Code"] != "NoSuchEntity": raise e -@sts_client('iam') +@sts_client("iam") @retry(retry_on_exception=retry_throttled, wait_fixed=2000) def get_certificate(name, **kwargs): """ @@ -118,13 +118,13 @@ def get_certificate(name, **kwargs): :return: """ - client = kwargs.pop('client') - return client.get_server_certificate( - ServerCertificateName=name - )['ServerCertificate'] + client = kwargs.pop("client") + return client.get_server_certificate(ServerCertificateName=name)[ + "ServerCertificate" + ] -@sts_client('iam') +@sts_client("iam") @retry(retry_on_exception=retry_throttled, wait_fixed=2000) def get_certificates(**kwargs): """ @@ -132,7 +132,7 @@ def get_certificates(**kwargs): :param kwargs: :return: """ - client = kwargs.pop('client') + client = kwargs.pop("client") return client.list_server_certificates(**kwargs) @@ -141,16 +141,20 @@ def get_all_certificates(**kwargs): Use STS to fetch all of the SSL certificates from a given account """ certificates = [] - account_number = kwargs.get('account_number') + account_number = kwargs.get("account_number") while True: response = get_certificates(**kwargs) - metadata = response['ServerCertificateMetadataList'] + metadata = response["ServerCertificateMetadataList"] for m in metadata: - certificates.append(get_certificate(m['ServerCertificateName'], account_number=account_number)) + certificates.append( + get_certificate( + m["ServerCertificateName"], account_number=account_number + ) + ) - if not response.get('Marker'): + if not response.get("Marker"): return certificates else: - kwargs.update(dict(Marker=response['Marker'])) + kwargs.update(dict(Marker=response["Marker"])) diff --git a/lemur/plugins/lemur_aws/plugin.py b/lemur/plugins/lemur_aws/plugin.py index 41bec31c..4414a62c 100644 --- a/lemur/plugins/lemur_aws/plugin.py +++ b/lemur/plugins/lemur_aws/plugin.py @@ -40,7 +40,7 @@ from lemur.plugins.lemur_aws import iam, s3, elb, ec2 def get_region_from_dns(dns): - return dns.split('.')[-4] + return dns.split(".")[-4] def format_elb_cipher_policy_v2(policy): @@ -52,10 +52,10 @@ def format_elb_cipher_policy_v2(policy): ciphers = [] name = None - for descr in policy['SslPolicies']: - name = descr['Name'] - for cipher in descr['Ciphers']: - ciphers.append(cipher['Name']) + for descr in policy["SslPolicies"]: + name = descr["Name"] + for cipher in descr["Ciphers"]: + ciphers.append(cipher["Name"]) return dict(name=name, ciphers=ciphers) @@ -68,14 +68,14 @@ def format_elb_cipher_policy(policy): """ ciphers = [] name = None - for descr in policy['PolicyDescriptions']: - for attr in descr['PolicyAttributeDescriptions']: - if attr['AttributeName'] == 'Reference-Security-Policy': - name = attr['AttributeValue'] + for descr in policy["PolicyDescriptions"]: + for attr in descr["PolicyAttributeDescriptions"]: + if attr["AttributeName"] == "Reference-Security-Policy": + name = attr["AttributeValue"] continue - if attr['AttributeValue'] == 'true': - ciphers.append(attr['AttributeName']) + if attr["AttributeValue"] == "true": + ciphers.append(attr["AttributeName"]) return dict(name=name, ciphers=ciphers) @@ -89,25 +89,31 @@ def get_elb_endpoints(account_number, region, elb_dict): :return: """ endpoints = [] - for listener in elb_dict['ListenerDescriptions']: - if not listener['Listener'].get('SSLCertificateId'): + for listener in elb_dict["ListenerDescriptions"]: + if not listener["Listener"].get("SSLCertificateId"): continue - if listener['Listener']['SSLCertificateId'] == 'Invalid-Certificate': + if listener["Listener"]["SSLCertificateId"] == "Invalid-Certificate": continue endpoint = dict( - name=elb_dict['LoadBalancerName'], - dnsname=elb_dict['DNSName'], - type='elb', - port=listener['Listener']['LoadBalancerPort'], - certificate_name=iam.get_name_from_arn(listener['Listener']['SSLCertificateId']) + name=elb_dict["LoadBalancerName"], + dnsname=elb_dict["DNSName"], + type="elb", + port=listener["Listener"]["LoadBalancerPort"], + certificate_name=iam.get_name_from_arn( + listener["Listener"]["SSLCertificateId"] + ), ) - if listener['PolicyNames']: - policy = elb.describe_load_balancer_policies(elb_dict['LoadBalancerName'], listener['PolicyNames'], - account_number=account_number, region=region) - endpoint['policy'] = format_elb_cipher_policy(policy) + if listener["PolicyNames"]: + policy = elb.describe_load_balancer_policies( + elb_dict["LoadBalancerName"], + listener["PolicyNames"], + account_number=account_number, + region=region, + ) + endpoint["policy"] = format_elb_cipher_policy(policy) current_app.logger.debug("Found new endpoint. Endpoint: {}".format(endpoint)) @@ -125,24 +131,29 @@ def get_elb_endpoints_v2(account_number, region, elb_dict): :return: """ endpoints = [] - listeners = elb.describe_listeners_v2(account_number=account_number, region=region, - LoadBalancerArn=elb_dict['LoadBalancerArn']) - for listener in listeners['Listeners']: - if not listener.get('Certificates'): + listeners = elb.describe_listeners_v2( + account_number=account_number, + region=region, + LoadBalancerArn=elb_dict["LoadBalancerArn"], + ) + for listener in listeners["Listeners"]: + if not listener.get("Certificates"): continue - for certificate in listener['Certificates']: + for certificate in listener["Certificates"]: endpoint = dict( - name=elb_dict['LoadBalancerName'], - dnsname=elb_dict['DNSName'], - type='elbv2', - port=listener['Port'], - certificate_name=iam.get_name_from_arn(certificate['CertificateArn']) + name=elb_dict["LoadBalancerName"], + dnsname=elb_dict["DNSName"], + type="elbv2", + port=listener["Port"], + certificate_name=iam.get_name_from_arn(certificate["CertificateArn"]), ) - if listener['SslPolicy']: - policy = elb.describe_ssl_policies_v2([listener['SslPolicy']], account_number=account_number, region=region) - endpoint['policy'] = format_elb_cipher_policy_v2(policy) + if listener["SslPolicy"]: + policy = elb.describe_ssl_policies_v2( + [listener["SslPolicy"]], account_number=account_number, region=region + ) + endpoint["policy"] = format_elb_cipher_policy_v2(policy) endpoints.append(endpoint) @@ -150,54 +161,70 @@ def get_elb_endpoints_v2(account_number, region, elb_dict): class AWSSourcePlugin(SourcePlugin): - title = 'AWS' - slug = 'aws-source' - description = 'Discovers all SSL certificates and ELB endpoints in an AWS account' + title = "AWS" + slug = "aws-source" + description = "Discovers all SSL certificates and ELB endpoints in an AWS account" version = aws.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur" options = [ { - 'name': 'accountNumber', - 'type': 'str', - 'required': True, - 'validation': '/^[0-9]{12,12}$/', - 'helpMessage': 'Must be a valid AWS account number!', + "name": "accountNumber", + "type": "str", + "required": True, + "validation": "/^[0-9]{12,12}$/", + "helpMessage": "Must be a valid AWS account number!", }, { - 'name': 'regions', - 'type': 'str', - 'helpMessage': 'Comma separated list of regions to search in, if no region is specified we look in all regions.' + "name": "regions", + "type": "str", + "helpMessage": "Comma separated list of regions to search in, if no region is specified we look in all regions.", }, ] def get_certificates(self, options, **kwargs): - cert_data = iam.get_all_certificates(account_number=self.get_option('accountNumber', options)) - return [dict(body=c['CertificateBody'], chain=c.get('CertificateChain'), - name=c['ServerCertificateMetadata']['ServerCertificateName']) for c in cert_data] + cert_data = iam.get_all_certificates( + account_number=self.get_option("accountNumber", options) + ) + return [ + dict( + body=c["CertificateBody"], + chain=c.get("CertificateChain"), + name=c["ServerCertificateMetadata"]["ServerCertificateName"], + ) + for c in cert_data + ] def get_endpoints(self, options, **kwargs): endpoints = [] - account_number = self.get_option('accountNumber', options) - regions = self.get_option('regions', options) + account_number = self.get_option("accountNumber", options) + regions = self.get_option("regions", options) if not regions: regions = ec2.get_regions(account_number=account_number) else: - regions = regions.split(',') + regions = regions.split(",") for region in regions: elbs = elb.get_all_elbs(account_number=account_number, region=region) - current_app.logger.info("Describing classic load balancers in {0}-{1}".format(account_number, region)) + current_app.logger.info( + "Describing classic load balancers in {0}-{1}".format( + account_number, region + ) + ) for e in elbs: endpoints.extend(get_elb_endpoints(account_number, region, e)) # fetch advanced ELBs elbs_v2 = elb.get_all_elbs_v2(account_number=account_number, region=region) - current_app.logger.info("Describing advanced load balancers in {0}-{1}".format(account_number, region)) + current_app.logger.info( + "Describing advanced load balancers in {0}-{1}".format( + account_number, region + ) + ) for e in elbs_v2: endpoints.extend(get_elb_endpoints_v2(account_number, region, e)) @@ -206,106 +233,125 @@ class AWSSourcePlugin(SourcePlugin): def update_endpoint(self, endpoint, certificate): options = endpoint.source.options - account_number = self.get_option('accountNumber', options) + account_number = self.get_option("accountNumber", options) # relies on the fact that region is included in DNS name region = get_region_from_dns(endpoint.dnsname) arn = iam.create_arn_from_cert(account_number, region, certificate.name) - if endpoint.type == 'elbv2': - listener_arn = elb.get_listener_arn_from_endpoint(endpoint.name, endpoint.port, - account_number=account_number, region=region) - elb.attach_certificate_v2(listener_arn, endpoint.port, [{'CertificateArn': arn}], - account_number=account_number, region=region) + if endpoint.type == "elbv2": + listener_arn = elb.get_listener_arn_from_endpoint( + endpoint.name, + endpoint.port, + account_number=account_number, + region=region, + ) + elb.attach_certificate_v2( + listener_arn, + endpoint.port, + [{"CertificateArn": arn}], + account_number=account_number, + region=region, + ) else: - elb.attach_certificate(endpoint.name, endpoint.port, arn, account_number=account_number, region=region) + elb.attach_certificate( + endpoint.name, + endpoint.port, + arn, + account_number=account_number, + region=region, + ) def clean(self, certificate, options, **kwargs): - account_number = self.get_option('accountNumber', options) + account_number = self.get_option("accountNumber", options) iam.delete_cert(certificate.name, account_number=account_number) class AWSDestinationPlugin(DestinationPlugin): - title = 'AWS' - slug = 'aws-destination' - description = 'Allow the uploading of certificates to AWS IAM' + title = "AWS" + slug = "aws-destination" + description = "Allow the uploading of certificates to AWS IAM" version = aws.VERSION sync_as_source = True sync_as_source_name = AWSSourcePlugin.slug - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur" options = [ { - 'name': 'accountNumber', - 'type': 'str', - 'required': True, - 'validation': '[0-9]{12}', - 'helpMessage': 'Must be a valid AWS account number!', + "name": "accountNumber", + "type": "str", + "required": True, + "validation": "[0-9]{12}", + "helpMessage": "Must be a valid AWS account number!", }, { - 'name': 'path', - 'type': 'str', - 'default': '/', - 'helpMessage': 'Path to upload certificate.' - } + "name": "path", + "type": "str", + "default": "/", + "helpMessage": "Path to upload certificate.", + }, ] def upload(self, name, body, private_key, cert_chain, options, **kwargs): - iam.upload_cert(name, body, private_key, - self.get_option('path', options), - cert_chain=cert_chain, - account_number=self.get_option('accountNumber', options)) + iam.upload_cert( + name, + body, + private_key, + self.get_option("path", options), + cert_chain=cert_chain, + account_number=self.get_option("accountNumber", options), + ) def deploy(self, elb_name, account, region, certificate): pass class S3DestinationPlugin(ExportDestinationPlugin): - title = 'AWS-S3' - slug = 'aws-s3' - description = 'Allow the uploading of certificates to Amazon S3' + title = "AWS-S3" + slug = "aws-s3" + description = "Allow the uploading of certificates to Amazon S3" - author = 'Mikhail Khodorovskiy, Harm Weites ' - author_url = 'https://github.com/Netflix/lemur' + author = "Mikhail Khodorovskiy, Harm Weites " + author_url = "https://github.com/Netflix/lemur" additional_options = [ { - 'name': 'bucket', - 'type': 'str', - 'required': True, - 'validation': '[0-9a-z.-]{3,63}', - 'helpMessage': 'Must be a valid S3 bucket name!', + "name": "bucket", + "type": "str", + "required": True, + "validation": "[0-9a-z.-]{3,63}", + "helpMessage": "Must be a valid S3 bucket name!", }, { - 'name': 'accountNumber', - 'type': 'str', - 'required': True, - 'validation': '[0-9]{12}', - 'helpMessage': 'A valid AWS account number with permission to access S3', + "name": "accountNumber", + "type": "str", + "required": True, + "validation": "[0-9]{12}", + "helpMessage": "A valid AWS account number with permission to access S3", }, { - 'name': 'region', - 'type': 'str', - 'default': 'us-east-1', - 'required': False, - 'helpMessage': 'Region bucket exists', - 'available': ['us-east-1', 'us-west-2', 'eu-west-1'] + "name": "region", + "type": "str", + "default": "us-east-1", + "required": False, + "helpMessage": "Region bucket exists", + "available": ["us-east-1", "us-west-2", "eu-west-1"], }, { - 'name': 'encrypt', - 'type': 'bool', - 'required': False, - 'helpMessage': 'Enable server side encryption', - 'default': True + "name": "encrypt", + "type": "bool", + "required": False, + "helpMessage": "Enable server side encryption", + "default": True, }, { - 'name': 'prefix', - 'type': 'str', - 'required': False, - 'helpMessage': 'Must be a valid S3 object prefix!', - } + "name": "prefix", + "type": "str", + "required": False, + "helpMessage": "Must be a valid S3 object prefix!", + }, ] def __init__(self, *args, **kwargs): @@ -316,13 +362,12 @@ class S3DestinationPlugin(ExportDestinationPlugin): for ext, passphrase, data in files: s3.put( - self.get_option('bucket', options), - self.get_option('region', options), - '{prefix}/{name}.{extension}'.format( - prefix=self.get_option('prefix', options), - name=name, - extension=ext), + self.get_option("bucket", options), + self.get_option("region", options), + "{prefix}/{name}.{extension}".format( + prefix=self.get_option("prefix", options), name=name, extension=ext + ), data, - self.get_option('encrypt', options), - account_number=self.get_option('accountNumber', options) + self.get_option("encrypt", options), + account_number=self.get_option("accountNumber", options), ) diff --git a/lemur/plugins/lemur_aws/s3.py b/lemur/plugins/lemur_aws/s3.py index 2f8983e5..43faa28f 100644 --- a/lemur/plugins/lemur_aws/s3.py +++ b/lemur/plugins/lemur_aws/s3.py @@ -10,28 +10,26 @@ from flask import current_app from .sts import sts_client -@sts_client('s3', service_type='resource') +@sts_client("s3", service_type="resource") def put(bucket_name, region, prefix, data, encrypt, **kwargs): """ Use STS to write to an S3 bucket """ - bucket = kwargs['resource'].Bucket(bucket_name) - current_app.logger.debug('Persisting data to S3. Bucket: {0} Prefix: {1}'.format(bucket_name, prefix)) + bucket = kwargs["resource"].Bucket(bucket_name) + current_app.logger.debug( + "Persisting data to S3. Bucket: {0} Prefix: {1}".format(bucket_name, prefix) + ) # get data ready for writing if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") if encrypt: bucket.put_object( Key=prefix, Body=data, - ACL='bucket-owner-full-control', - ServerSideEncryption='AES256' + ACL="bucket-owner-full-control", + ServerSideEncryption="AES256", ) else: - bucket.put_object( - Key=prefix, - Body=data, - ACL='bucket-owner-full-control' - ) + bucket.put_object(Key=prefix, Body=data, ACL="bucket-owner-full-control") diff --git a/lemur/plugins/lemur_aws/sts.py b/lemur/plugins/lemur_aws/sts.py index 6253ad7a..c1bd562c 100644 --- a/lemur/plugins/lemur_aws/sts.py +++ b/lemur/plugins/lemur_aws/sts.py @@ -13,46 +13,42 @@ from botocore.config import Config from flask import current_app -config = Config( - retries=dict( - max_attempts=20 - ) -) +config = Config(retries=dict(max_attempts=20)) -def sts_client(service, service_type='client'): +def sts_client(service, service_type="client"): def decorator(f): @wraps(f) def decorated_function(*args, **kwargs): - sts = boto3.client('sts', config=config) - arn = 'arn:aws:iam::{0}:role/{1}'.format( - kwargs.pop('account_number'), - current_app.config.get('LEMUR_INSTANCE_PROFILE', 'Lemur') + sts = boto3.client("sts", config=config) + arn = "arn:aws:iam::{0}:role/{1}".format( + kwargs.pop("account_number"), + current_app.config.get("LEMUR_INSTANCE_PROFILE", "Lemur"), ) # TODO add user specific information to RoleSessionName - role = sts.assume_role(RoleArn=arn, RoleSessionName='lemur') + role = sts.assume_role(RoleArn=arn, RoleSessionName="lemur") - if service_type == 'client': + if service_type == "client": client = boto3.client( service, - region_name=kwargs.pop('region', 'us-east-1'), - aws_access_key_id=role['Credentials']['AccessKeyId'], - aws_secret_access_key=role['Credentials']['SecretAccessKey'], - aws_session_token=role['Credentials']['SessionToken'], - config=config + region_name=kwargs.pop("region", "us-east-1"), + aws_access_key_id=role["Credentials"]["AccessKeyId"], + aws_secret_access_key=role["Credentials"]["SecretAccessKey"], + aws_session_token=role["Credentials"]["SessionToken"], + config=config, ) - kwargs['client'] = client - elif service_type == 'resource': + kwargs["client"] = client + elif service_type == "resource": resource = boto3.resource( service, - region_name=kwargs.pop('region', 'us-east-1'), - aws_access_key_id=role['Credentials']['AccessKeyId'], - aws_secret_access_key=role['Credentials']['SecretAccessKey'], - aws_session_token=role['Credentials']['SessionToken'], - config=config + region_name=kwargs.pop("region", "us-east-1"), + aws_access_key_id=role["Credentials"]["AccessKeyId"], + aws_secret_access_key=role["Credentials"]["SecretAccessKey"], + aws_session_token=role["Credentials"]["SessionToken"], + config=config, ) - kwargs['resource'] = resource + kwargs["resource"] = resource return f(*args, **kwargs) return decorated_function diff --git a/lemur/plugins/lemur_aws/tests/test_elb.py b/lemur/plugins/lemur_aws/tests/test_elb.py index 7facc4dd..4571b87a 100644 --- a/lemur/plugins/lemur_aws/tests/test_elb.py +++ b/lemur/plugins/lemur_aws/tests/test_elb.py @@ -6,23 +6,24 @@ from moto import mock_sts, mock_elb @mock_elb() def test_get_all_elbs(app, aws_credentials): from lemur.plugins.lemur_aws.elb import get_all_elbs - client = boto3.client('elb', region_name='us-east-1') - elbs = get_all_elbs(account_number='123456789012', region='us-east-1') + client = boto3.client("elb", region_name="us-east-1") + + elbs = get_all_elbs(account_number="123456789012", region="us-east-1") assert not elbs client.create_load_balancer( - LoadBalancerName='example-lb', + LoadBalancerName="example-lb", Listeners=[ { - 'Protocol': 'string', - 'LoadBalancerPort': 443, - 'InstanceProtocol': 'tcp', - 'InstancePort': 5443, - 'SSLCertificateId': 'tcp' + "Protocol": "string", + "LoadBalancerPort": 443, + "InstanceProtocol": "tcp", + "InstancePort": 5443, + "SSLCertificateId": "tcp", } - ] + ], ) - elbs = get_all_elbs(account_number='123456789012', region='us-east-1') + elbs = get_all_elbs(account_number="123456789012", region="us-east-1") assert elbs diff --git a/lemur/plugins/lemur_aws/tests/test_iam.py b/lemur/plugins/lemur_aws/tests/test_iam.py index deec221e..5932d52d 100644 --- a/lemur/plugins/lemur_aws/tests/test_iam.py +++ b/lemur/plugins/lemur_aws/tests/test_iam.py @@ -6,15 +6,21 @@ from lemur.tests.vectors import EXTERNAL_VALID_STR, SAN_CERT_KEY def test_get_name_from_arn(): from lemur.plugins.lemur_aws.iam import get_name_from_arn - arn = 'arn:aws:iam::123456789012:server-certificate/tttt2.netflixtest.net-NetflixInc-20150624-20150625' - assert get_name_from_arn(arn) == 'tttt2.netflixtest.net-NetflixInc-20150624-20150625' + + arn = "arn:aws:iam::123456789012:server-certificate/tttt2.netflixtest.net-NetflixInc-20150624-20150625" + assert ( + get_name_from_arn(arn) == "tttt2.netflixtest.net-NetflixInc-20150624-20150625" + ) -@pytest.mark.skipif(True, reason="this fails because moto is not currently returning what boto does") +@pytest.mark.skipif( + True, reason="this fails because moto is not currently returning what boto does" +) @mock_sts() @mock_iam() def test_get_all_server_certs(app): from lemur.plugins.lemur_aws.iam import upload_cert, get_all_certificates - upload_cert('123456789012', 'testCert', EXTERNAL_VALID_STR, SAN_CERT_KEY) - certs = get_all_certificates('123456789012') + + upload_cert("123456789012", "testCert", EXTERNAL_VALID_STR, SAN_CERT_KEY) + certs = get_all_certificates("123456789012") assert len(certs) == 1 diff --git a/lemur/plugins/lemur_aws/tests/test_plugin.py b/lemur/plugins/lemur_aws/tests/test_plugin.py index 95e4c9a4..dbad7b02 100644 --- a/lemur/plugins/lemur_aws/tests/test_plugin.py +++ b/lemur/plugins/lemur_aws/tests/test_plugin.py @@ -1,6 +1,5 @@ - def test_get_certificates(app): from lemur.plugins.base import plugins - p = plugins.get('aws-s3') + p = plugins.get("aws-s3") assert p diff --git a/lemur/plugins/lemur_cfssl/__init__.py b/lemur/plugins/lemur_cfssl/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_cfssl/__init__.py +++ b/lemur/plugins/lemur_cfssl/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_cfssl/plugin.py b/lemur/plugins/lemur_cfssl/plugin.py index 4bfefc85..ae16d168 100644 --- a/lemur/plugins/lemur_cfssl/plugin.py +++ b/lemur/plugins/lemur_cfssl/plugin.py @@ -24,13 +24,13 @@ from lemur.extensions import metrics class CfsslIssuerPlugin(IssuerPlugin): - title = 'CFSSL' - slug = 'cfssl-issuer' - description = 'Enables the creation of certificates by CFSSL private CA' + title = "CFSSL" + slug = "cfssl-issuer" + description = "Enables the creation of certificates by CFSSL private CA" version = cfssl.VERSION - author = 'Charles Hendrie' - author_url = 'https://github.com/netflix/lemur.git' + author = "Charles Hendrie" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): self.session = requests.Session() @@ -44,15 +44,17 @@ class CfsslIssuerPlugin(IssuerPlugin): :param issuer_options: :return: """ - current_app.logger.info("Requesting a new cfssl certificate with csr: {0}".format(csr)) + current_app.logger.info( + "Requesting a new cfssl certificate with csr: {0}".format(csr) + ) - url = "{0}{1}".format(current_app.config.get('CFSSL_URL'), '/api/v1/cfssl/sign') + url = "{0}{1}".format(current_app.config.get("CFSSL_URL"), "/api/v1/cfssl/sign") - data = {'certificate_request': csr} + data = {"certificate_request": csr} data = json.dumps(data) try: - hex_key = current_app.config.get('CFSSL_KEY') + hex_key = current_app.config.get("CFSSL_KEY") key = bytes.fromhex(hex_key) except (ValueError, NameError): # unable to find CFSSL_KEY in config, continue using normal sign method @@ -60,22 +62,33 @@ class CfsslIssuerPlugin(IssuerPlugin): else: data = data.encode() - token = base64.b64encode(hmac.new(key, data, digestmod=hashlib.sha256).digest()) + token = base64.b64encode( + hmac.new(key, data, digestmod=hashlib.sha256).digest() + ) data = base64.b64encode(data) - data = json.dumps({'token': token.decode('utf-8'), 'request': data.decode('utf-8')}) + data = json.dumps( + {"token": token.decode("utf-8"), "request": data.decode("utf-8")} + ) - url = "{0}{1}".format(current_app.config.get('CFSSL_URL'), '/api/v1/cfssl/authsign') - response = self.session.post(url, data=data.encode(encoding='utf_8', errors='strict')) + url = "{0}{1}".format( + current_app.config.get("CFSSL_URL"), "/api/v1/cfssl/authsign" + ) + response = self.session.post( + url, data=data.encode(encoding="utf_8", errors="strict") + ) if response.status_code > 399: - metrics.send('cfssl_create_certificate_failure', 'counter', 1) - raise Exception( - "Error creating cert. Please check your CFSSL API server") - response_json = json.loads(response.content.decode('utf_8')) - cert = response_json['result']['certificate'] + metrics.send("cfssl_create_certificate_failure", "counter", 1) + raise Exception("Error creating cert. Please check your CFSSL API server") + response_json = json.loads(response.content.decode("utf_8")) + cert = response_json["result"]["certificate"] parsed_cert = parse_certificate(cert) - metrics.send('cfssl_create_certificate_success', 'counter', 1) - return cert, current_app.config.get('CFSSL_INTERMEDIATE'), parsed_cert.serial_number + metrics.send("cfssl_create_certificate_success", "counter", 1) + return ( + cert, + current_app.config.get("CFSSL_INTERMEDIATE"), + parsed_cert.serial_number, + ) @staticmethod def create_authority(options): @@ -86,22 +99,26 @@ class CfsslIssuerPlugin(IssuerPlugin): :param options: :return: """ - role = {'username': '', 'password': '', 'name': 'cfssl'} - return current_app.config.get('CFSSL_ROOT'), "", [role] + role = {"username": "", "password": "", "name": "cfssl"} + return current_app.config.get("CFSSL_ROOT"), "", [role] def revoke_certificate(self, certificate, comments): """Revoke a CFSSL certificate.""" - base_url = current_app.config.get('CFSSL_URL') - create_url = '{0}/api/v1/cfssl/revoke'.format(base_url) - data = '{"serial": "' + certificate.external_id + '","authority_key_id": "' + \ - get_authority_key(certificate.body) + \ - '", "reason": "superseded"}' + base_url = current_app.config.get("CFSSL_URL") + create_url = "{0}/api/v1/cfssl/revoke".format(base_url) + data = ( + '{"serial": "' + + certificate.external_id + + '","authority_key_id": "' + + get_authority_key(certificate.body) + + '", "reason": "superseded"}' + ) current_app.logger.debug("Revoking cert: {0}".format(data)) response = self.session.post( - create_url, data=data.encode(encoding='utf_8', errors='strict')) + create_url, data=data.encode(encoding="utf_8", errors="strict") + ) if response.status_code > 399: - metrics.send('cfssl_revoke_certificate_failure', 'counter', 1) - raise Exception( - "Error revoking cert. Please check your CFSSL API server") - metrics.send('cfssl_revoke_certificate_success', 'counter', 1) + metrics.send("cfssl_revoke_certificate_failure", "counter", 1) + raise Exception("Error revoking cert. Please check your CFSSL API server") + metrics.send("cfssl_revoke_certificate_success", "counter", 1) return response.json() diff --git a/lemur/plugins/lemur_cfssl/tests/test_cfssl.py b/lemur/plugins/lemur_cfssl/tests/test_cfssl.py index ea8f0856..10fb9963 100644 --- a/lemur/plugins/lemur_cfssl/tests/test_cfssl.py +++ b/lemur/plugins/lemur_cfssl/tests/test_cfssl.py @@ -1,6 +1,5 @@ - def test_get_certificates(app): from lemur.plugins.base import plugins - p = plugins.get('cfssl-issuer') + p = plugins.get("cfssl-issuer") assert p diff --git a/lemur/plugins/lemur_cryptography/__init__.py b/lemur/plugins/lemur_cryptography/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_cryptography/__init__.py +++ b/lemur/plugins/lemur_cryptography/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_cryptography/plugin.py b/lemur/plugins/lemur_cryptography/plugin.py index 97060391..005f36f9 100644 --- a/lemur/plugins/lemur_cryptography/plugin.py +++ b/lemur/plugins/lemur_cryptography/plugin.py @@ -22,7 +22,7 @@ from lemur.certificates.service import create_csr def build_certificate_authority(options): - options['certificate_authority'] = True + options["certificate_authority"] = True csr, private_key = create_csr(**options) cert_pem, chain_cert_pem = issue_certificate(csr, options, private_key) @@ -30,24 +30,32 @@ def build_certificate_authority(options): def issue_certificate(csr, options, private_key=None): - csr = x509.load_pem_x509_csr(csr.encode('utf-8'), default_backend()) + csr = x509.load_pem_x509_csr(csr.encode("utf-8"), default_backend()) if options.get("parent"): # creating intermediate authorities will have options['parent'] to specify the issuer # creating certificates will have options['authority'] to specify the issuer # This works around that by making sure options['authority'] can be referenced for either - options['authority'] = options['parent'] + options["authority"] = options["parent"] if options.get("authority"): # Issue certificate signed by an existing lemur_certificates authority - issuer_subject = options['authority'].authority_certificate.subject - assert private_key is None, "Private would be ignored, authority key used instead" - private_key = options['authority'].authority_certificate.private_key - chain_cert_pem = options['authority'].authority_certificate.body - authority_key_identifier_public = options['authority'].authority_certificate.public_key - authority_key_identifier_subject = x509.SubjectKeyIdentifier.from_public_key(authority_key_identifier_public) + issuer_subject = options["authority"].authority_certificate.subject + assert ( + private_key is None + ), "Private would be ignored, authority key used instead" + private_key = options["authority"].authority_certificate.private_key + chain_cert_pem = options["authority"].authority_certificate.body + authority_key_identifier_public = options[ + "authority" + ].authority_certificate.public_key + authority_key_identifier_subject = x509.SubjectKeyIdentifier.from_public_key( + authority_key_identifier_public + ) authority_key_identifier_issuer = issuer_subject - authority_key_identifier_serial = int(options['authority'].authority_certificate.serial) + authority_key_identifier_serial = int( + options["authority"].authority_certificate.serial + ) # TODO figure out a better way to increment serial # New authorities have a value at options['serial_number'] that is being ignored here. serial = int(uuid.uuid4()) @@ -58,7 +66,7 @@ def issue_certificate(csr, options, private_key=None): authority_key_identifier_public = csr.public_key() authority_key_identifier_subject = None authority_key_identifier_issuer = csr.subject - authority_key_identifier_serial = options['serial_number'] + authority_key_identifier_serial = options["serial_number"] # TODO figure out a better way to increment serial serial = int(uuid.uuid4()) @@ -68,19 +76,20 @@ def issue_certificate(csr, options, private_key=None): issuer_name=issuer_subject, subject_name=csr.subject, public_key=csr.public_key(), - not_valid_before=options['validity_start'], - not_valid_after=options['validity_end'], + not_valid_before=options["validity_start"], + not_valid_after=options["validity_end"], serial_number=serial, - extensions=extensions) + extensions=extensions, + ) - for k, v in options.get('extensions', {}).items(): - if k == 'authority_key_identifier': + for k, v in options.get("extensions", {}).items(): + if k == "authority_key_identifier": # One or both of these options may be present inside the aki extension (authority_key_identifier, authority_identifier) = (False, False) for k2, v2 in v.items(): - if k2 == 'use_key_identifier' and v2: + if k2 == "use_key_identifier" and v2: authority_key_identifier = True - if k2 == 'use_authority_cert' and v2: + if k2 == "use_authority_cert" and v2: authority_identifier = True if authority_key_identifier: if authority_key_identifier_subject: @@ -89,13 +98,21 @@ def issue_certificate(csr, options, private_key=None): # but the digest of the ski is at just ski.digest. Until that library is fixed, # this function won't work. The second line has the same result. # aki = x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(authority_key_identifier_subject) - aki = x509.AuthorityKeyIdentifier(authority_key_identifier_subject.digest, None, None) + aki = x509.AuthorityKeyIdentifier( + authority_key_identifier_subject.digest, None, None + ) else: - aki = x509.AuthorityKeyIdentifier.from_issuer_public_key(authority_key_identifier_public) + aki = x509.AuthorityKeyIdentifier.from_issuer_public_key( + authority_key_identifier_public + ) elif authority_identifier: - aki = x509.AuthorityKeyIdentifier(None, [x509.DirectoryName(authority_key_identifier_issuer)], authority_key_identifier_serial) + aki = x509.AuthorityKeyIdentifier( + None, + [x509.DirectoryName(authority_key_identifier_issuer)], + authority_key_identifier_serial, + ) builder = builder.add_extension(aki, critical=False) - if k == 'certificate_info_access': + if k == "certificate_info_access": # FIXME: Implement the AuthorityInformationAccess extension # descriptions = [ # x509.AccessDescription(x509.oid.AuthorityInformationAccessOID.OCSP, x509.UniformResourceIdentifier(u"http://FIXME")), @@ -108,7 +125,7 @@ def issue_certificate(csr, options, private_key=None): # critical=False # ) pass - if k == 'crl_distribution_points': + if k == "crl_distribution_points": # FIXME: Implement the CRLDistributionPoints extension # FIXME: Not implemented in lemur/schemas.py yet https://github.com/Netflix/lemur/issues/662 pass @@ -116,20 +133,24 @@ def issue_certificate(csr, options, private_key=None): private_key = parse_private_key(private_key) cert = builder.sign(private_key, hashes.SHA256(), default_backend()) - cert_pem = cert.public_bytes( - encoding=serialization.Encoding.PEM - ).decode('utf-8') + cert_pem = cert.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8") return cert_pem, chain_cert_pem def normalize_extensions(csr): try: - san_extension = csr.extensions.get_extension_for_oid(x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME) + san_extension = csr.extensions.get_extension_for_oid( + x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ) san_dnsnames = san_extension.value.get_values_for_type(x509.DNSName) except x509.extensions.ExtensionNotFound: san_dnsnames = [] - san_extension = x509.Extension(x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME, True, x509.SubjectAlternativeName(san_dnsnames)) + san_extension = x509.Extension( + x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + True, + x509.SubjectAlternativeName(san_dnsnames), + ) common_name = csr.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME) common_name = common_name[0].value @@ -149,7 +170,11 @@ def normalize_extensions(csr): for san in san_extension.value: general_names.append(san) - san_extension = x509.Extension(x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME, True, x509.SubjectAlternativeName(general_names)) + san_extension = x509.Extension( + x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + True, + x509.SubjectAlternativeName(general_names), + ) # Remove original san extension from CSR and add new SAN extension extensions = list(filter(filter_san_extensions, csr.extensions._extensions)) @@ -166,13 +191,13 @@ def filter_san_extensions(ext): class CryptographyIssuerPlugin(IssuerPlugin): - title = 'Cryptography' - slug = 'cryptography-issuer' - description = 'Enables the creation and signing of self-signed certificates' + title = "Cryptography" + slug = "cryptography-issuer" + description = "Enables the creation and signing of self-signed certificates" version = cryptography_issuer.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def create_certificate(self, csr, options): """ @@ -182,7 +207,9 @@ class CryptographyIssuerPlugin(IssuerPlugin): :param options: :return: :raise Exception: """ - current_app.logger.debug("Issuing new cryptography certificate with options: {0}".format(options)) + current_app.logger.debug( + "Issuing new cryptography certificate with options: {0}".format(options) + ) cert_pem, chain_cert_pem = issue_certificate(csr, options) return cert_pem, chain_cert_pem, None @@ -195,10 +222,12 @@ class CryptographyIssuerPlugin(IssuerPlugin): :param options: :return: """ - current_app.logger.debug("Issuing new cryptography authority with options: {0}".format(options)) + current_app.logger.debug( + "Issuing new cryptography authority with options: {0}".format(options) + ) cert_pem, private_key, chain_cert_pem = build_certificate_authority(options) roles = [ - {'username': '', 'password': '', 'name': options['name'] + '_admin'}, - {'username': '', 'password': '', 'name': options['name'] + '_operator'} + {"username": "", "password": "", "name": options["name"] + "_admin"}, + {"username": "", "password": "", "name": options["name"] + "_operator"}, ] return cert_pem, private_key, chain_cert_pem, roles diff --git a/lemur/plugins/lemur_cryptography/tests/test_cryptography.py b/lemur/plugins/lemur_cryptography/tests/test_cryptography.py index 8a81bf6c..7f1777fc 100644 --- a/lemur/plugins/lemur_cryptography/tests/test_cryptography.py +++ b/lemur/plugins/lemur_cryptography/tests/test_cryptography.py @@ -5,24 +5,24 @@ def test_build_certificate_authority(): from lemur.plugins.lemur_cryptography.plugin import build_certificate_authority options = { - 'key_type': 'RSA2048', - 'country': 'US', - 'state': 'CA', - 'location': 'Example place', - 'organization': 'Example, Inc.', - 'organizational_unit': 'Example Unit', - 'common_name': 'Example ROOT', - 'validity_start': arrow.get('2016-12-01').datetime, - 'validity_end': arrow.get('2016-12-02').datetime, - 'first_serial': 1, - 'serial_number': 1, - 'owner': 'owner@example.com' + "key_type": "RSA2048", + "country": "US", + "state": "CA", + "location": "Example place", + "organization": "Example, Inc.", + "organizational_unit": "Example Unit", + "common_name": "Example ROOT", + "validity_start": arrow.get("2016-12-01").datetime, + "validity_end": arrow.get("2016-12-02").datetime, + "first_serial": 1, + "serial_number": 1, + "owner": "owner@example.com", } cert_pem, private_key_pem, chain_cert_pem = build_certificate_authority(options) assert cert_pem assert private_key_pem - assert chain_cert_pem == '' + assert chain_cert_pem == "" def test_issue_certificate(authority): @@ -30,10 +30,10 @@ def test_issue_certificate(authority): from lemur.plugins.lemur_cryptography.plugin import issue_certificate options = { - 'common_name': 'Example.com', - 'authority': authority, - 'validity_start': arrow.get('2016-12-01').datetime, - 'validity_end': arrow.get('2016-12-02').datetime + "common_name": "Example.com", + "authority": authority, + "validity_start": arrow.get("2016-12-01").datetime, + "validity_end": arrow.get("2016-12-02").datetime, } cert_pem, chain_cert_pem = issue_certificate(CSR_STR, options) assert cert_pem diff --git a/lemur/plugins/lemur_csr/__init__.py b/lemur/plugins/lemur_csr/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_csr/__init__.py +++ b/lemur/plugins/lemur_csr/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_csr/plugin.py b/lemur/plugins/lemur_csr/plugin.py index 13f42084..776dfce5 100644 --- a/lemur/plugins/lemur_csr/plugin.py +++ b/lemur/plugins/lemur_csr/plugin.py @@ -43,38 +43,30 @@ def create_csr(cert, chain, csr_tmp, key): assert isinstance(key, str) with mktempfile() as key_tmp: - with open(key_tmp, 'w') as f: + with open(key_tmp, "w") as f: f.write(key) with mktempfile() as cert_tmp: - with open(cert_tmp, 'w') as f: + with open(cert_tmp, "w") as f: if chain: f.writelines([cert.strip() + "\n", chain.strip() + "\n"]) else: f.writelines([cert.strip() + "\n"]) - output = subprocess.check_output([ - "openssl", - "x509", - "-x509toreq", - "-in", cert_tmp, - "-signkey", key_tmp, - ]) - subprocess.run([ - "openssl", - "req", - "-out", csr_tmp - ], input=output) + output = subprocess.check_output( + ["openssl", "x509", "-x509toreq", "-in", cert_tmp, "-signkey", key_tmp] + ) + subprocess.run(["openssl", "req", "-out", csr_tmp], input=output) class CSRExportPlugin(ExportPlugin): - title = 'CSR' - slug = 'openssl-csr' - description = 'Exports a CSR' + title = "CSR" + slug = "openssl-csr" + description = "Exports a CSR" version = csr.VERSION - author = 'jchuong' - author_url = 'https://github.com/jchuong' + author = "jchuong" + author_url = "https://github.com/jchuong" def export(self, body, chain, key, options, **kwargs): """ @@ -93,7 +85,7 @@ class CSRExportPlugin(ExportPlugin): create_csr(body, chain, output_tmp, key) extension = "csr" - with open(output_tmp, 'rb') as f: + with open(output_tmp, "rb") as f: raw = f.read() # passphrase is None return extension, None, raw diff --git a/lemur/plugins/lemur_csr/tests/test_csr_export.py b/lemur/plugins/lemur_csr/tests/test_csr_export.py index 9b233a4e..0b55aefe 100644 --- a/lemur/plugins/lemur_csr/tests/test_csr_export.py +++ b/lemur/plugins/lemur_csr/tests/test_csr_export.py @@ -4,7 +4,8 @@ from lemur.tests.vectors import INTERNAL_PRIVATE_KEY_A_STR, INTERNAL_CERTIFICATE def test_export_certificate_to_csr(app): from lemur.plugins.base import plugins - p = plugins.get('openssl-csr') + + p = plugins.get("openssl-csr") options = [] with pytest.raises(Exception): p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) diff --git a/lemur/plugins/lemur_digicert/__init__.py b/lemur/plugins/lemur_digicert/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_digicert/__init__.py +++ b/lemur/plugins/lemur_digicert/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_digicert/plugin.py b/lemur/plugins/lemur_digicert/plugin.py index a65c02ff..c5b01cc4 100644 --- a/lemur/plugins/lemur_digicert/plugin.py +++ b/lemur/plugins/lemur_digicert/plugin.py @@ -40,7 +40,7 @@ def log_status_code(r, *args, **kwargs): :param kwargs: :return: """ - metrics.send('digicert_status_code_{}'.format(r.status_code), 'counter', 1) + metrics.send("digicert_status_code_{}".format(r.status_code), "counter", 1) def signature_hash(signing_algorithm): @@ -50,18 +50,18 @@ def signature_hash(signing_algorithm): :return: str digicert specific algorithm string """ if not signing_algorithm: - return current_app.config.get('DIGICERT_DEFAULT_SIGNING_ALGORITHM', 'sha256') + return current_app.config.get("DIGICERT_DEFAULT_SIGNING_ALGORITHM", "sha256") - if signing_algorithm == 'sha256WithRSA': - return 'sha256' + if signing_algorithm == "sha256WithRSA": + return "sha256" - elif signing_algorithm == 'sha384WithRSA': - return 'sha384' + elif signing_algorithm == "sha384WithRSA": + return "sha384" - elif signing_algorithm == 'sha512WithRSA': - return 'sha512' + elif signing_algorithm == "sha512WithRSA": + return "sha512" - raise Exception('Unsupported signing algorithm.') + raise Exception("Unsupported signing algorithm.") def determine_validity_years(end_date): @@ -79,8 +79,9 @@ def determine_validity_years(end_date): elif end_date < now.replace(years=+3): return 3 - raise Exception("DigiCert issued certificates cannot exceed three" - " years in validity") + raise Exception( + "DigiCert issued certificates cannot exceed three" " years in validity" + ) def get_additional_names(options): @@ -92,8 +93,8 @@ def get_additional_names(options): """ names = [] # add SANs if present - if options.get('extensions'): - for san in options['extensions']['sub_alt_names']['names']: + if options.get("extensions"): + for san in options["extensions"]["sub_alt_names"]["names"]: if isinstance(san, x509.DNSName): names.append(san.value) return names @@ -106,31 +107,33 @@ def map_fields(options, csr): :param csr: :return: dict or valid DigiCert options """ - if not options.get('validity_years'): - if not options.get('validity_end'): - options['validity_years'] = current_app.config.get('DIGICERT_DEFAULT_VALIDITY', 1) + if not options.get("validity_years"): + if not options.get("validity_end"): + options["validity_years"] = current_app.config.get( + "DIGICERT_DEFAULT_VALIDITY", 1 + ) - data = dict(certificate={ - "common_name": options['common_name'], - "csr": csr, - "signature_hash": - signature_hash(options.get('signing_algorithm')), - }, organization={ - "id": current_app.config.get("DIGICERT_ORG_ID") - }) + data = dict( + certificate={ + "common_name": options["common_name"], + "csr": csr, + "signature_hash": signature_hash(options.get("signing_algorithm")), + }, + organization={"id": current_app.config.get("DIGICERT_ORG_ID")}, + ) - data['certificate']['dns_names'] = get_additional_names(options) + data["certificate"]["dns_names"] = get_additional_names(options) - if options.get('validity_years'): - data['validity_years'] = options['validity_years'] + if options.get("validity_years"): + data["validity_years"] = options["validity_years"] else: - data['custom_expiration_date'] = options['validity_end'].format('YYYY-MM-DD') + data["custom_expiration_date"] = options["validity_end"].format("YYYY-MM-DD") - if current_app.config.get('DIGICERT_PRIVATE', False): - if 'product' in data: - data['product']['type_hint'] = 'private' + if current_app.config.get("DIGICERT_PRIVATE", False): + if "product" in data: + data["product"]["type_hint"] = "private" else: - data['product'] = dict(type_hint='private') + data["product"] = dict(type_hint="private") return data @@ -143,26 +146,30 @@ def map_cis_fields(options, csr): :param csr: :return: """ - if not options.get('validity_years'): - if not options.get('validity_end'): - options['validity_end'] = arrow.utcnow().replace(years=current_app.config.get('DIGICERT_DEFAULT_VALIDITY', 1)) - options['validity_years'] = determine_validity_years(options['validity_end']) + if not options.get("validity_years"): + if not options.get("validity_end"): + options["validity_end"] = arrow.utcnow().replace( + years=current_app.config.get("DIGICERT_DEFAULT_VALIDITY", 1) + ) + options["validity_years"] = determine_validity_years(options["validity_end"]) else: - options['validity_end'] = arrow.utcnow().replace(years=options['validity_years']) + options["validity_end"] = arrow.utcnow().replace( + years=options["validity_years"] + ) data = { - "profile_name": current_app.config.get('DIGICERT_CIS_PROFILE_NAME'), - "common_name": options['common_name'], + "profile_name": current_app.config.get("DIGICERT_CIS_PROFILE_NAME"), + "common_name": options["common_name"], "additional_dns_names": get_additional_names(options), "csr": csr, - "signature_hash": signature_hash(options.get('signing_algorithm')), + "signature_hash": signature_hash(options.get("signing_algorithm")), "validity": { - "valid_to": options['validity_end'].format('YYYY-MM-DDTHH:MM') + 'Z' + "valid_to": options["validity_end"].format("YYYY-MM-DDTHH:MM") + "Z" }, "organization": { - "name": options['organization'], - "units": [options['organizational_unit']] - } + "name": options["organization"], + "units": [options["organizational_unit"]], + }, } return data @@ -175,7 +182,7 @@ def handle_response(response): :return: """ if response.status_code > 399: - raise Exception(response.json()['errors'][0]['message']) + raise Exception(response.json()["errors"][0]["message"]) return response.json() @@ -197,19 +204,17 @@ def get_certificate_id(session, base_url, order_id): """Retrieve certificate order id from Digicert API.""" order_url = "{0}/services/v2/order/certificate/{1}".format(base_url, order_id) response_data = handle_response(session.get(order_url)) - if response_data['status'] != 'issued': + if response_data["status"] != "issued": raise Exception("Order not in issued state.") - return response_data['certificate']['id'] + return response_data["certificate"]["id"] @retry(stop_max_attempt_number=10, wait_fixed=10000) def get_cis_certificate(session, base_url, order_id): """Retrieve certificate order id from Digicert API.""" - certificate_url = '{0}/platform/cis/certificate/{1}'.format(base_url, order_id) - session.headers.update( - {'Accept': 'application/x-pem-file'} - ) + certificate_url = "{0}/platform/cis/certificate/{1}".format(base_url, order_id) + session.headers.update({"Accept": "application/x-pem-file"}) response = session.get(certificate_url) if response.status_code == 404: @@ -220,29 +225,30 @@ def get_cis_certificate(session, base_url, order_id): class DigiCertSourcePlugin(SourcePlugin): """Wrap the Digicert Certifcate API.""" - title = 'DigiCert' - slug = 'digicert-source' + + title = "DigiCert" + slug = "digicert-source" description = "Enables the use of Digicert as a source of existing certificates." version = digicert.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): """Initialize source with appropriate details.""" required_vars = [ - 'DIGICERT_API_KEY', - 'DIGICERT_URL', - 'DIGICERT_ORG_ID', - 'DIGICERT_ROOT', + "DIGICERT_API_KEY", + "DIGICERT_URL", + "DIGICERT_ORG_ID", + "DIGICERT_ROOT", ] validate_conf(current_app, required_vars) self.session = requests.Session() self.session.headers.update( { - 'X-DC-DEVKEY': current_app.config['DIGICERT_API_KEY'], - 'Content-Type': 'application/json' + "X-DC-DEVKEY": current_app.config["DIGICERT_API_KEY"], + "Content-Type": "application/json", } ) @@ -256,22 +262,23 @@ class DigiCertSourcePlugin(SourcePlugin): class DigiCertIssuerPlugin(IssuerPlugin): """Wrap the Digicert Issuer API.""" - title = 'DigiCert' - slug = 'digicert-issuer' + + title = "DigiCert" + slug = "digicert-issuer" description = "Enables the creation of certificates by the DigiCert REST API." version = digicert.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): """Initialize the issuer with the appropriate details.""" required_vars = [ - 'DIGICERT_API_KEY', - 'DIGICERT_URL', - 'DIGICERT_ORG_ID', - 'DIGICERT_ORDER_TYPE', - 'DIGICERT_ROOT', + "DIGICERT_API_KEY", + "DIGICERT_URL", + "DIGICERT_ORG_ID", + "DIGICERT_ORDER_TYPE", + "DIGICERT_ROOT", ] validate_conf(current_app, required_vars) @@ -279,8 +286,8 @@ class DigiCertIssuerPlugin(IssuerPlugin): self.session = requests.Session() self.session.headers.update( { - 'X-DC-DEVKEY': current_app.config['DIGICERT_API_KEY'], - 'Content-Type': 'application/json' + "X-DC-DEVKEY": current_app.config["DIGICERT_API_KEY"], + "Content-Type": "application/json", } ) @@ -295,69 +302,93 @@ class DigiCertIssuerPlugin(IssuerPlugin): :param issuer_options: :return: :raise Exception: """ - base_url = current_app.config.get('DIGICERT_URL') - cert_type = current_app.config.get('DIGICERT_ORDER_TYPE') + base_url = current_app.config.get("DIGICERT_URL") + cert_type = current_app.config.get("DIGICERT_ORDER_TYPE") # make certificate request - determinator_url = "{0}/services/v2/order/certificate/{1}".format(base_url, cert_type) + determinator_url = "{0}/services/v2/order/certificate/{1}".format( + base_url, cert_type + ) data = map_fields(issuer_options, csr) response = self.session.post(determinator_url, data=json.dumps(data)) if response.status_code > 399: - raise Exception(response.json()['errors'][0]['message']) + raise Exception(response.json()["errors"][0]["message"]) - order_id = response.json()['id'] + order_id = response.json()["id"] certificate_id = get_certificate_id(self.session, base_url, order_id) # retrieve certificate - certificate_url = "{0}/services/v2/certificate/{1}/download/format/pem_all".format(base_url, certificate_id) - end_entity, intermediate, root = pem.parse(self.session.get(certificate_url).content) - return "\n".join(str(end_entity).splitlines()), "\n".join(str(intermediate).splitlines()), certificate_id + certificate_url = "{0}/services/v2/certificate/{1}/download/format/pem_all".format( + base_url, certificate_id + ) + end_entity, intermediate, root = pem.parse( + self.session.get(certificate_url).content + ) + return ( + "\n".join(str(end_entity).splitlines()), + "\n".join(str(intermediate).splitlines()), + certificate_id, + ) def revoke_certificate(self, certificate, comments): """Revoke a Digicert certificate.""" - base_url = current_app.config.get('DIGICERT_URL') + base_url = current_app.config.get("DIGICERT_URL") # make certificate revoke request - create_url = '{0}/services/v2/certificate/{1}/revoke'.format(base_url, certificate.external_id) - metrics.send('digicert_revoke_certificate', 'counter', 1) - response = self.session.put(create_url, data=json.dumps({'comments': comments})) + create_url = "{0}/services/v2/certificate/{1}/revoke".format( + base_url, certificate.external_id + ) + metrics.send("digicert_revoke_certificate", "counter", 1) + response = self.session.put(create_url, data=json.dumps({"comments": comments})) return handle_response(response) def get_ordered_certificate(self, pending_cert): """ Retrieve a certificate via order id """ order_id = pending_cert.external_id - base_url = current_app.config.get('DIGICERT_URL') + base_url = current_app.config.get("DIGICERT_URL") try: certificate_id = get_certificate_id(self.session, base_url, order_id) except Exception as ex: return None - certificate_url = "{0}/services/v2/certificate/{1}/download/format/pem_all".format(base_url, certificate_id) - end_entity, intermediate, root = pem.parse(self.session.get(certificate_url).content) - cert = {'body': "\n".join(str(end_entity).splitlines()), - 'chain': "\n".join(str(intermediate).splitlines()), - 'external_id': str(certificate_id)} + certificate_url = "{0}/services/v2/certificate/{1}/download/format/pem_all".format( + base_url, certificate_id + ) + end_entity, intermediate, root = pem.parse( + self.session.get(certificate_url).content + ) + cert = { + "body": "\n".join(str(end_entity).splitlines()), + "chain": "\n".join(str(intermediate).splitlines()), + "external_id": str(certificate_id), + } return cert def cancel_ordered_certificate(self, pending_cert, **kwargs): """ Set the certificate order to canceled """ - base_url = current_app.config.get('DIGICERT_URL') - api_url = "{0}/services/v2/order/certificate/{1}/status".format(base_url, pending_cert.external_id) - payload = { - 'status': 'CANCELED', - 'note': kwargs.get('note') - } + base_url = current_app.config.get("DIGICERT_URL") + api_url = "{0}/services/v2/order/certificate/{1}/status".format( + base_url, pending_cert.external_id + ) + payload = {"status": "CANCELED", "note": kwargs.get("note")} response = self.session.put(api_url, data=json.dumps(payload)) if response.status_code == 404: # not well documented by Digicert, but either the certificate does not exist or we # don't own that order (someone else's order id!). Either way, we can just ignore it # and have it removed from Lemur current_app.logger.warning( - "Digicert Plugin tried to cancel pending certificate {0} but it does not exist!".format(pending_cert.name)) + "Digicert Plugin tried to cancel pending certificate {0} but it does not exist!".format( + pending_cert.name + ) + ) elif response.status_code != 204: - current_app.logger.debug("{0} code {1}".format(response.status_code, response.content)) - raise Exception("Failed to cancel pending certificate {0}".format(pending_cert.name)) + current_app.logger.debug( + "{0} code {1}".format(response.status_code, response.content) + ) + raise Exception( + "Failed to cancel pending certificate {0}".format(pending_cert.name) + ) @staticmethod def create_authority(options): @@ -370,72 +401,81 @@ class DigiCertIssuerPlugin(IssuerPlugin): :param options: :return: """ - role = {'username': '', 'password': '', 'name': 'digicert'} - return current_app.config.get('DIGICERT_ROOT'), "", [role] + role = {"username": "", "password": "", "name": "digicert"} + return current_app.config.get("DIGICERT_ROOT"), "", [role] class DigiCertCISSourcePlugin(SourcePlugin): """Wrap the Digicert CIS Certifcate API.""" - title = 'DigiCert' - slug = 'digicert-cis-source' + + title = "DigiCert" + slug = "digicert-cis-source" description = "Enables the use of Digicert as a source of existing certificates." version = digicert.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" additional_options = [] def __init__(self, *args, **kwargs): """Initialize source with appropriate details.""" required_vars = [ - 'DIGICERT_CIS_API_KEY', - 'DIGICERT_CIS_URL', - 'DIGICERT_CIS_ROOT', - 'DIGICERT_CIS_INTERMEDIATE', - 'DIGICERT_CIS_PROFILE_NAME' + "DIGICERT_CIS_API_KEY", + "DIGICERT_CIS_URL", + "DIGICERT_CIS_ROOT", + "DIGICERT_CIS_INTERMEDIATE", + "DIGICERT_CIS_PROFILE_NAME", ] validate_conf(current_app, required_vars) self.session = requests.Session() self.session.headers.update( { - 'X-DC-DEVKEY': current_app.config['DIGICERT_CIS_API_KEY'], - 'Content-Type': 'application/json' + "X-DC-DEVKEY": current_app.config["DIGICERT_CIS_API_KEY"], + "Content-Type": "application/json", } ) self.session.hooks = dict(response=log_status_code) a = requests.adapters.HTTPAdapter(max_retries=3) - self.session.mount('https://', a) + self.session.mount("https://", a) super(DigiCertCISSourcePlugin, self).__init__(*args, **kwargs) def get_certificates(self, options, **kwargs): """Fetch all Digicert certificates.""" - base_url = current_app.config.get('DIGICERT_CIS_URL') + base_url = current_app.config.get("DIGICERT_CIS_URL") # make request - search_url = '{0}/platform/cis/certificate/search'.format(base_url) + search_url = "{0}/platform/cis/certificate/search".format(base_url) certs = [] page = 1 while True: - response = self.session.get(search_url, params={'status': ['issued'], 'page': page}) + response = self.session.get( + search_url, params={"status": ["issued"], "page": page} + ) data = handle_cis_response(response) - for c in data['certificates']: - download_url = '{0}/platform/cis/certificate/{1}'.format(base_url, c['id']) + for c in data["certificates"]: + download_url = "{0}/platform/cis/certificate/{1}".format( + base_url, c["id"] + ) certificate = self.session.get(download_url) # normalize serial - serial = str(int(c['serial_number'], 16)) - cert = {'body': certificate.content, 'serial': serial, 'external_id': c['id']} + serial = str(int(c["serial_number"], 16)) + cert = { + "body": certificate.content, + "serial": serial, + "external_id": c["id"], + } certs.append(cert) - if page == data['total_pages']: + if page == data["total_pages"]: break page += 1 @@ -444,22 +484,23 @@ class DigiCertCISSourcePlugin(SourcePlugin): class DigiCertCISIssuerPlugin(IssuerPlugin): """Wrap the Digicert Certificate Issuing API.""" - title = 'DigiCert CIS' - slug = 'digicert-cis-issuer' + + title = "DigiCert CIS" + slug = "digicert-cis-issuer" description = "Enables the creation of certificates by the DigiCert CIS REST API." version = digicert.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): """Initialize the issuer with the appropriate details.""" required_vars = [ - 'DIGICERT_CIS_API_KEY', - 'DIGICERT_CIS_URL', - 'DIGICERT_CIS_ROOT', - 'DIGICERT_CIS_INTERMEDIATE', - 'DIGICERT_CIS_PROFILE_NAME' + "DIGICERT_CIS_API_KEY", + "DIGICERT_CIS_URL", + "DIGICERT_CIS_ROOT", + "DIGICERT_CIS_INTERMEDIATE", + "DIGICERT_CIS_PROFILE_NAME", ] validate_conf(current_app, required_vars) @@ -467,8 +508,8 @@ class DigiCertCISIssuerPlugin(IssuerPlugin): self.session = requests.Session() self.session.headers.update( { - 'X-DC-DEVKEY': current_app.config['DIGICERT_CIS_API_KEY'], - 'Content-Type': 'application/json' + "X-DC-DEVKEY": current_app.config["DIGICERT_CIS_API_KEY"], + "Content-Type": "application/json", } ) @@ -478,41 +519,51 @@ class DigiCertCISIssuerPlugin(IssuerPlugin): def create_certificate(self, csr, issuer_options): """Create a DigiCert certificate.""" - base_url = current_app.config.get('DIGICERT_CIS_URL') + base_url = current_app.config.get("DIGICERT_CIS_URL") # make certificate request - create_url = '{0}/platform/cis/certificate'.format(base_url) + create_url = "{0}/platform/cis/certificate".format(base_url) data = map_cis_fields(issuer_options, csr) response = self.session.post(create_url, data=json.dumps(data)) data = handle_cis_response(response) # retrieve certificate - certificate_pem = get_cis_certificate(self.session, base_url, data['id']) + certificate_pem = get_cis_certificate(self.session, base_url, data["id"]) - self.session.headers.pop('Accept') + self.session.headers.pop("Accept") end_entity = pem.parse(certificate_pem)[0] - if 'ECC' in issuer_options['key_type']: - return "\n".join(str(end_entity).splitlines()), current_app.config.get('DIGICERT_ECC_CIS_INTERMEDIATE'), data['id'] + if "ECC" in issuer_options["key_type"]: + return ( + "\n".join(str(end_entity).splitlines()), + current_app.config.get("DIGICERT_ECC_CIS_INTERMEDIATE"), + data["id"], + ) # By default return RSA - return "\n".join(str(end_entity).splitlines()), current_app.config.get('DIGICERT_CIS_INTERMEDIATE'), data['id'] + return ( + "\n".join(str(end_entity).splitlines()), + current_app.config.get("DIGICERT_CIS_INTERMEDIATE"), + data["id"], + ) def revoke_certificate(self, certificate, comments): """Revoke a Digicert certificate.""" - base_url = current_app.config.get('DIGICERT_CIS_URL') + base_url = current_app.config.get("DIGICERT_CIS_URL") # make certificate revoke request - revoke_url = '{0}/platform/cis/certificate/{1}/revoke'.format(base_url, certificate.external_id) - metrics.send('digicert_revoke_certificate_success', 'counter', 1) - response = self.session.put(revoke_url, data=json.dumps({'comments': comments})) + revoke_url = "{0}/platform/cis/certificate/{1}/revoke".format( + base_url, certificate.external_id + ) + metrics.send("digicert_revoke_certificate_success", "counter", 1) + response = self.session.put(revoke_url, data=json.dumps({"comments": comments})) if response.status_code != 204: - metrics.send('digicert_revoke_certificate_failure', 'counter', 1) - raise Exception('Failed to revoke certificate.') + metrics.send("digicert_revoke_certificate_failure", "counter", 1) + raise Exception("Failed to revoke certificate.") - metrics.send('digicert_revoke_certificate_success', 'counter', 1) + metrics.send("digicert_revoke_certificate_success", "counter", 1) @staticmethod def create_authority(options): @@ -525,5 +576,5 @@ class DigiCertCISIssuerPlugin(IssuerPlugin): :param options: :return: """ - role = {'username': '', 'password': '', 'name': 'digicert'} - return current_app.config.get('DIGICERT_CIS_ROOT'), "", [role] + role = {"username": "", "password": "", "name": "digicert"} + return current_app.config.get("DIGICERT_CIS_ROOT"), "", [role] diff --git a/lemur/plugins/lemur_digicert/tests/test_digicert.py b/lemur/plugins/lemur_digicert/tests/test_digicert.py index d8d1519d..71efbad4 100644 --- a/lemur/plugins/lemur_digicert/tests/test_digicert.py +++ b/lemur/plugins/lemur_digicert/tests/test_digicert.py @@ -13,144 +13,129 @@ from cryptography import x509 def test_map_fields_with_validity_end_and_start(app): from lemur.plugins.lemur_digicert.plugin import map_fields - names = [u'one.example.com', u'two.example.com', u'three.example.com'] + names = [u"one.example.com", u"two.example.com", u"three.example.com"] options = { - 'common_name': 'example.com', - 'owner': 'bob@example.com', - 'description': 'test certificate', - 'extensions': { - 'sub_alt_names': { - 'names': [x509.DNSName(x) for x in names] - } - }, - 'validity_end': arrow.get(2017, 5, 7), - 'validity_start': arrow.get(2016, 10, 30) + "common_name": "example.com", + "owner": "bob@example.com", + "description": "test certificate", + "extensions": {"sub_alt_names": {"names": [x509.DNSName(x) for x in names]}}, + "validity_end": arrow.get(2017, 5, 7), + "validity_start": arrow.get(2016, 10, 30), } data = map_fields(options, CSR_STR) assert data == { - 'certificate': { - 'csr': CSR_STR, - 'common_name': 'example.com', - 'dns_names': names, - 'signature_hash': 'sha256' + "certificate": { + "csr": CSR_STR, + "common_name": "example.com", + "dns_names": names, + "signature_hash": "sha256", }, - 'organization': {'id': 111111}, - 'custom_expiration_date': arrow.get(2017, 5, 7).format('YYYY-MM-DD') + "organization": {"id": 111111}, + "custom_expiration_date": arrow.get(2017, 5, 7).format("YYYY-MM-DD"), } def test_map_fields_with_validity_years(app): from lemur.plugins.lemur_digicert.plugin import map_fields - names = [u'one.example.com', u'two.example.com', u'three.example.com'] + names = [u"one.example.com", u"two.example.com", u"three.example.com"] options = { - 'common_name': 'example.com', - 'owner': 'bob@example.com', - 'description': 'test certificate', - 'extensions': { - 'sub_alt_names': { - 'names': [x509.DNSName(x) for x in names] - } - }, - 'validity_years': 2, - 'validity_end': arrow.get(2017, 10, 30) + "common_name": "example.com", + "owner": "bob@example.com", + "description": "test certificate", + "extensions": {"sub_alt_names": {"names": [x509.DNSName(x) for x in names]}}, + "validity_years": 2, + "validity_end": arrow.get(2017, 10, 30), } data = map_fields(options, CSR_STR) assert data == { - 'certificate': { - 'csr': CSR_STR, - 'common_name': 'example.com', - 'dns_names': names, - 'signature_hash': 'sha256' + "certificate": { + "csr": CSR_STR, + "common_name": "example.com", + "dns_names": names, + "signature_hash": "sha256", }, - 'organization': {'id': 111111}, - 'validity_years': 2 + "organization": {"id": 111111}, + "validity_years": 2, } def test_map_cis_fields(app): from lemur.plugins.lemur_digicert.plugin import map_cis_fields - names = [u'one.example.com', u'two.example.com', u'three.example.com'] + names = [u"one.example.com", u"two.example.com", u"three.example.com"] options = { - 'common_name': 'example.com', - 'owner': 'bob@example.com', - 'description': 'test certificate', - 'extensions': { - 'sub_alt_names': { - 'names': [x509.DNSName(x) for x in names] - } - }, - 'organization': 'Example, Inc.', - 'organizational_unit': 'Example Org', - 'validity_end': arrow.get(2017, 5, 7), - 'validity_start': arrow.get(2016, 10, 30) + "common_name": "example.com", + "owner": "bob@example.com", + "description": "test certificate", + "extensions": {"sub_alt_names": {"names": [x509.DNSName(x) for x in names]}}, + "organization": "Example, Inc.", + "organizational_unit": "Example Org", + "validity_end": arrow.get(2017, 5, 7), + "validity_start": arrow.get(2016, 10, 30), } data = map_cis_fields(options, CSR_STR) assert data == { - 'common_name': 'example.com', - 'csr': CSR_STR, - 'additional_dns_names': names, - 'signature_hash': 'sha256', - 'organization': {'name': 'Example, Inc.', 'units': ['Example Org']}, - 'validity': { - 'valid_to': arrow.get(2017, 5, 7).format('YYYY-MM-DDTHH:MM') + 'Z' + "common_name": "example.com", + "csr": CSR_STR, + "additional_dns_names": names, + "signature_hash": "sha256", + "organization": {"name": "Example, Inc.", "units": ["Example Org"]}, + "validity": { + "valid_to": arrow.get(2017, 5, 7).format("YYYY-MM-DDTHH:MM") + "Z" }, - 'profile_name': None + "profile_name": None, } options = { - 'common_name': 'example.com', - 'owner': 'bob@example.com', - 'description': 'test certificate', - 'extensions': { - 'sub_alt_names': { - 'names': [x509.DNSName(x) for x in names] - } - }, - 'organization': 'Example, Inc.', - 'organizational_unit': 'Example Org', - 'validity_years': 2 + "common_name": "example.com", + "owner": "bob@example.com", + "description": "test certificate", + "extensions": {"sub_alt_names": {"names": [x509.DNSName(x) for x in names]}}, + "organization": "Example, Inc.", + "organizational_unit": "Example Org", + "validity_years": 2, } with freeze_time(time_to_freeze=arrow.get(2016, 11, 3).datetime): data = map_cis_fields(options, CSR_STR) assert data == { - 'common_name': 'example.com', - 'csr': CSR_STR, - 'additional_dns_names': names, - 'signature_hash': 'sha256', - 'organization': {'name': 'Example, Inc.', 'units': ['Example Org']}, - 'validity': { - 'valid_to': arrow.get(2018, 11, 3).format('YYYY-MM-DDTHH:MM') + 'Z' + "common_name": "example.com", + "csr": CSR_STR, + "additional_dns_names": names, + "signature_hash": "sha256", + "organization": {"name": "Example, Inc.", "units": ["Example Org"]}, + "validity": { + "valid_to": arrow.get(2018, 11, 3).format("YYYY-MM-DDTHH:MM") + "Z" }, - 'profile_name': None + "profile_name": None, } def test_signature_hash(app): from lemur.plugins.lemur_digicert.plugin import signature_hash - assert signature_hash(None) == 'sha256' - assert signature_hash('sha256WithRSA') == 'sha256' - assert signature_hash('sha384WithRSA') == 'sha384' - assert signature_hash('sha512WithRSA') == 'sha512' + assert signature_hash(None) == "sha256" + assert signature_hash("sha256WithRSA") == "sha256" + assert signature_hash("sha384WithRSA") == "sha384" + assert signature_hash("sha512WithRSA") == "sha512" with pytest.raises(Exception): - signature_hash('sdfdsf') + signature_hash("sdfdsf") -def test_issuer_plugin_create_certificate(certificate_="""\ +def test_issuer_plugin_create_certificate( + certificate_="""\ -----BEGIN CERTIFICATE----- abc -----END CERTIFICATE----- @@ -160,7 +145,8 @@ def -----BEGIN CERTIFICATE----- ghi -----END CERTIFICATE----- -"""): +""" +): import requests_mock from lemur.plugins.lemur_digicert.plugin import DigiCertIssuerPlugin @@ -168,12 +154,26 @@ ghi subject = DigiCertIssuerPlugin() adapter = requests_mock.Adapter() - adapter.register_uri('POST', 'mock://www.digicert.com/services/v2/order/certificate/ssl_plus', text=json.dumps({'id': 'id123'})) - adapter.register_uri('GET', 'mock://www.digicert.com/services/v2/order/certificate/id123', text=json.dumps({'status': 'issued', 'certificate': {'id': 'cert123'}})) - adapter.register_uri('GET', 'mock://www.digicert.com/services/v2/certificate/cert123/download/format/pem_all', text=pem_fixture) - subject.session.mount('mock', adapter) + adapter.register_uri( + "POST", + "mock://www.digicert.com/services/v2/order/certificate/ssl_plus", + text=json.dumps({"id": "id123"}), + ) + adapter.register_uri( + "GET", + "mock://www.digicert.com/services/v2/order/certificate/id123", + text=json.dumps({"status": "issued", "certificate": {"id": "cert123"}}), + ) + adapter.register_uri( + "GET", + "mock://www.digicert.com/services/v2/certificate/cert123/download/format/pem_all", + text=pem_fixture, + ) + subject.session.mount("mock", adapter) - cert, intermediate, external_id = subject.create_certificate("", {'common_name': 'test.com'}) + cert, intermediate, external_id = subject.create_certificate( + "", {"common_name": "test.com"} + ) assert cert == "-----BEGIN CERTIFICATE-----\nabc\n-----END CERTIFICATE-----" assert intermediate == "-----BEGIN CERTIFICATE-----\ndef\n-----END CERTIFICATE-----" @@ -187,10 +187,18 @@ def test_cancel_ordered_certificate(mock_pending_cert): mock_pending_cert.external_id = 1234 subject = DigiCertIssuerPlugin() adapter = requests_mock.Adapter() - adapter.register_uri('PUT', 'mock://www.digicert.com/services/v2/order/certificate/1234/status', status_code=204) - adapter.register_uri('PUT', 'mock://www.digicert.com/services/v2/order/certificate/111/status', status_code=404) - subject.session.mount('mock', adapter) - data = {'note': 'Test'} + adapter.register_uri( + "PUT", + "mock://www.digicert.com/services/v2/order/certificate/1234/status", + status_code=204, + ) + adapter.register_uri( + "PUT", + "mock://www.digicert.com/services/v2/order/certificate/111/status", + status_code=404, + ) + subject.session.mount("mock", adapter) + data = {"note": "Test"} subject.cancel_ordered_certificate(mock_pending_cert, **data) # A non-existing order id, does not raise exception because if it doesn't exist, then it doesn't matter diff --git a/lemur/plugins/lemur_email/__init__.py b/lemur/plugins/lemur_email/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_email/__init__.py +++ b/lemur/plugins/lemur_email/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_email/plugin.py b/lemur/plugins/lemur_email/plugin.py index 18007b99..241aa1b0 100644 --- a/lemur/plugins/lemur_email/plugin.py +++ b/lemur/plugins/lemur_email/plugin.py @@ -27,8 +27,10 @@ def render_html(template_name, message): :param message: :return: """ - template = env.get_template('{}.html'.format(template_name)) - return template.render(dict(message=message, hostname=current_app.config.get('LEMUR_HOSTNAME'))) + template = env.get_template("{}.html".format(template_name)) + return template.render( + dict(message=message, hostname=current_app.config.get("LEMUR_HOSTNAME")) + ) def send_via_smtp(subject, body, targets): @@ -40,7 +42,9 @@ def send_via_smtp(subject, body, targets): :param targets: :return: """ - msg = Message(subject, recipients=targets, sender=current_app.config.get("LEMUR_EMAIL")) + msg = Message( + subject, recipients=targets, sender=current_app.config.get("LEMUR_EMAIL") + ) msg.body = "" # kinda a weird api for sending html emails msg.html = body smtp_mail.send(msg) @@ -54,65 +58,55 @@ def send_via_ses(subject, body, targets): :param targets: :return: """ - client = boto3.client('ses', region_name='us-east-1') + client = boto3.client("ses", region_name="us-east-1") client.send_email( - Source=current_app.config.get('LEMUR_EMAIL'), - Destination={ - 'ToAddresses': targets - }, + Source=current_app.config.get("LEMUR_EMAIL"), + Destination={"ToAddresses": targets}, Message={ - 'Subject': { - 'Data': subject, - 'Charset': 'UTF-8' - }, - 'Body': { - 'Html': { - 'Data': body, - 'Charset': 'UTF-8' - } - } - } + "Subject": {"Data": subject, "Charset": "UTF-8"}, + "Body": {"Html": {"Data": body, "Charset": "UTF-8"}}, + }, ) class EmailNotificationPlugin(ExpirationNotificationPlugin): - title = 'Email' - slug = 'email-notification' - description = 'Sends expiration email notifications' + title = "Email" + slug = "email-notification" + description = "Sends expiration email notifications" version = email.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur" additional_options = [ { - 'name': 'recipients', - 'type': 'str', - 'required': True, - 'validation': '^([\w+-.%]+@[\w-.]+\.[A-Za-z]{2,4},?)+$', - 'helpMessage': 'Comma delimited list of email addresses', - }, + "name": "recipients", + "type": "str", + "required": True, + "validation": "^([\w+-.%]+@[\w-.]+\.[A-Za-z]{2,4},?)+$", + "helpMessage": "Comma delimited list of email addresses", + } ] def __init__(self, *args, **kwargs): """Initialize the plugin with the appropriate details.""" - sender = current_app.config.get('LEMUR_EMAIL_SENDER', 'ses').lower() + sender = current_app.config.get("LEMUR_EMAIL_SENDER", "ses").lower() - if sender not in ['ses', 'smtp']: - raise InvalidConfiguration('Email sender type {0} is not recognized.') + if sender not in ["ses", "smtp"]: + raise InvalidConfiguration("Email sender type {0} is not recognized.") @staticmethod def send(notification_type, message, targets, options, **kwargs): - subject = 'Lemur: {0} Notification'.format(notification_type.capitalize()) + subject = "Lemur: {0} Notification".format(notification_type.capitalize()) - data = {'options': options, 'certificates': message} + data = {"options": options, "certificates": message} body = render_html(notification_type, data) - s_type = current_app.config.get("LEMUR_EMAIL_SENDER", 'ses').lower() + s_type = current_app.config.get("LEMUR_EMAIL_SENDER", "ses").lower() - if s_type == 'ses': + if s_type == "ses": send_via_ses(subject, body, targets) - elif s_type == 'smtp': + elif s_type == "smtp": send_via_smtp(subject, body, targets) diff --git a/lemur/plugins/lemur_email/templates/config.py b/lemur/plugins/lemur_email/templates/config.py index 2ec8a6c2..3d877fe0 100644 --- a/lemur/plugins/lemur_email/templates/config.py +++ b/lemur/plugins/lemur_email/templates/config.py @@ -5,22 +5,24 @@ from jinja2 import Environment, FileSystemLoader, select_autoescape from lemur.plugins.utils import get_plugin_option loader = FileSystemLoader(searchpath=os.path.dirname(os.path.realpath(__file__))) -env = Environment(loader=loader, # nosec: potentially dangerous types esc. - autoescape=select_autoescape(['html', 'xml'])) +env = Environment( + loader=loader, # nosec: potentially dangerous types esc. + autoescape=select_autoescape(["html", "xml"]), +) def human_time(time): - return arrow.get(time).format('dddd, MMMM D, YYYY') + return arrow.get(time).format("dddd, MMMM D, YYYY") def interval(options): - return get_plugin_option('interval', options) + return get_plugin_option("interval", options) def unit(options): - return get_plugin_option('unit', options) + return get_plugin_option("unit", options) -env.filters['time'] = human_time -env.filters['interval'] = interval -env.filters['unit'] = unit +env.filters["time"] = human_time +env.filters["interval"] = interval +env.filters["unit"] = unit diff --git a/lemur/plugins/lemur_email/tests/test_email.py b/lemur/plugins/lemur_email/tests/test_email.py index 9d58402f..43168cab 100644 --- a/lemur/plugins/lemur_email/tests/test_email.py +++ b/lemur/plugins/lemur_email/tests/test_email.py @@ -13,21 +13,24 @@ def test_render(certificate, endpoint): new_cert.replaces.append(certificate) data = { - 'certificates': [certificate_notification_output_schema.dump(certificate).data], - 'options': [{'name': 'interval', 'value': 10}, {'name': 'unit', 'value': 'days'}] + "certificates": [certificate_notification_output_schema.dump(certificate).data], + "options": [ + {"name": "interval", "value": 10}, + {"name": "unit", "value": "days"}, + ], } - template = env.get_template('{}.html'.format('expiration')) + template = env.get_template("{}.html".format("expiration")) - body = template.render(dict(message=data, hostname='lemur.test.example.com')) + body = template.render(dict(message=data, hostname="lemur.test.example.com")) - template = env.get_template('{}.html'.format('rotation')) + template = env.get_template("{}.html".format("rotation")) certificate.endpoints.append(endpoint) body = template.render( dict( certificate=certificate_notification_output_schema.dump(certificate).data, - hostname='lemur.test.example.com' + hostname="lemur.test.example.com", ) ) diff --git a/lemur/plugins/lemur_jks/__init__.py b/lemur/plugins/lemur_jks/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_jks/__init__.py +++ b/lemur/plugins/lemur_jks/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_jks/plugin.py b/lemur/plugins/lemur_jks/plugin.py index 3d456f1c..7134faeb 100644 --- a/lemur/plugins/lemur_jks/plugin.py +++ b/lemur/plugins/lemur_jks/plugin.py @@ -31,10 +31,10 @@ def create_truststore(cert, chain, alias, passphrase): entries = [] for idx, cert_bytes in enumerate(cert_chain_as_der(cert, chain)): # The original cert gets name _cert, first chain element is _cert_1, etc. - cert_alias = alias + '_cert' + ('_{}'.format(idx) if idx else '') + cert_alias = alias + "_cert" + ("_{}".format(idx) if idx else "") entries.append(TrustedCertEntry.new(cert_alias, cert_bytes)) - return KeyStore.new('jks', entries).saves(passphrase) + return KeyStore.new("jks", entries).saves(passphrase) def create_keystore(cert, chain, key, alias, passphrase): @@ -42,36 +42,36 @@ def create_keystore(cert, chain, key, alias, passphrase): key_bytes = parse_private_key(key).private_bytes( encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) entry = PrivateKeyEntry.new(alias, certs_bytes, key_bytes) - return KeyStore.new('jks', [entry]).saves(passphrase) + return KeyStore.new("jks", [entry]).saves(passphrase) class JavaTruststoreExportPlugin(ExportPlugin): - title = 'Java Truststore (JKS)' - slug = 'java-truststore-jks' - description = 'Generates a JKS truststore' + title = "Java Truststore (JKS)" + slug = "java-truststore-jks" + description = "Generates a JKS truststore" requires_key = False version = jks.VERSION - author = 'Marti Raudsepp' - author_url = 'https://github.com/intgr' + author = "Marti Raudsepp" + author_url = "https://github.com/intgr" options = [ { - 'name': 'alias', - 'type': 'str', - 'required': False, - 'helpMessage': 'Enter the alias you wish to use for the truststore.', + "name": "alias", + "type": "str", + "required": False, + "helpMessage": "Enter the alias you wish to use for the truststore.", }, { - 'name': 'passphrase', - 'type': 'str', - 'required': False, - 'helpMessage': 'If no passphrase is given one will be generated for you, we highly recommend this.', - 'validation': '' + "name": "passphrase", + "type": "str", + "required": False, + "helpMessage": "If no passphrase is given one will be generated for you, we highly recommend this.", + "validation": "", }, ] @@ -80,44 +80,44 @@ class JavaTruststoreExportPlugin(ExportPlugin): Generates a Java Truststore """ - if self.get_option('alias', options): - alias = self.get_option('alias', options) + if self.get_option("alias", options): + alias = self.get_option("alias", options) else: alias = common_name(parse_certificate(body)) - if self.get_option('passphrase', options): - passphrase = self.get_option('passphrase', options) + if self.get_option("passphrase", options): + passphrase = self.get_option("passphrase", options) else: - passphrase = Fernet.generate_key().decode('utf-8') + passphrase = Fernet.generate_key().decode("utf-8") raw = create_truststore(body, chain, alias, passphrase) - return 'jks', passphrase, raw + return "jks", passphrase, raw class JavaKeystoreExportPlugin(ExportPlugin): - title = 'Java Keystore (JKS)' - slug = 'java-keystore-jks' - description = 'Generates a JKS keystore' + title = "Java Keystore (JKS)" + slug = "java-keystore-jks" + description = "Generates a JKS keystore" version = jks.VERSION - author = 'Marti Raudsepp' - author_url = 'https://github.com/intgr' + author = "Marti Raudsepp" + author_url = "https://github.com/intgr" options = [ { - 'name': 'passphrase', - 'type': 'str', - 'required': False, - 'helpMessage': 'If no passphrase is given one will be generated for you, we highly recommend this.', - 'validation': '' + "name": "passphrase", + "type": "str", + "required": False, + "helpMessage": "If no passphrase is given one will be generated for you, we highly recommend this.", + "validation": "", }, { - 'name': 'alias', - 'type': 'str', - 'required': False, - 'helpMessage': 'Enter the alias you wish to use for the keystore.', - } + "name": "alias", + "type": "str", + "required": False, + "helpMessage": "Enter the alias you wish to use for the keystore.", + }, ] def export(self, body, chain, key, options, **kwargs): @@ -125,16 +125,16 @@ class JavaKeystoreExportPlugin(ExportPlugin): Generates a Java Keystore """ - if self.get_option('passphrase', options): - passphrase = self.get_option('passphrase', options) + if self.get_option("passphrase", options): + passphrase = self.get_option("passphrase", options) else: - passphrase = Fernet.generate_key().decode('utf-8') + passphrase = Fernet.generate_key().decode("utf-8") - if self.get_option('alias', options): - alias = self.get_option('alias', options) + if self.get_option("alias", options): + alias = self.get_option("alias", options) else: alias = common_name(parse_certificate(body)) raw = create_keystore(body, chain, key, alias, passphrase) - return 'jks', passphrase, raw + return "jks", passphrase, raw diff --git a/lemur/plugins/lemur_jks/tests/test_jks.py b/lemur/plugins/lemur_jks/tests/test_jks.py index e4a5e64a..b9fe9b33 100644 --- a/lemur/plugins/lemur_jks/tests/test_jks.py +++ b/lemur/plugins/lemur_jks/tests/test_jks.py @@ -1,96 +1,105 @@ import pytest from jks import KeyStore, TrustedCertEntry, PrivateKeyEntry -from lemur.tests.vectors import INTERNAL_CERTIFICATE_A_STR, SAN_CERT_STR, INTERMEDIATE_CERT_STR, ROOTCA_CERT_STR, \ - SAN_CERT_KEY +from lemur.tests.vectors import ( + INTERNAL_CERTIFICATE_A_STR, + SAN_CERT_STR, + INTERMEDIATE_CERT_STR, + ROOTCA_CERT_STR, + SAN_CERT_KEY, +) def test_export_truststore(app): from lemur.plugins.base import plugins - p = plugins.get('java-truststore-jks') + p = plugins.get("java-truststore-jks") options = [ - {'name': 'passphrase', 'value': 'hunter2'}, - {'name': 'alias', 'value': 'AzureDiamond'}, + {"name": "passphrase", "value": "hunter2"}, + {"name": "alias", "value": "AzureDiamond"}, ] - chain = INTERMEDIATE_CERT_STR + '\n' + ROOTCA_CERT_STR + chain = INTERMEDIATE_CERT_STR + "\n" + ROOTCA_CERT_STR ext, password, raw = p.export(SAN_CERT_STR, chain, SAN_CERT_KEY, options) - assert ext == 'jks' - assert password == 'hunter2' + assert ext == "jks" + assert password == "hunter2" assert isinstance(raw, bytes) - ks = KeyStore.loads(raw, 'hunter2') - assert ks.store_type == 'jks' + ks = KeyStore.loads(raw, "hunter2") + assert ks.store_type == "jks" # JKS lower-cases alias strings - assert ks.entries.keys() == {'azurediamond_cert', 'azurediamond_cert_1', 'azurediamond_cert_2'} - assert isinstance(ks.entries['azurediamond_cert'], TrustedCertEntry) + assert ks.entries.keys() == { + "azurediamond_cert", + "azurediamond_cert_1", + "azurediamond_cert_2", + } + assert isinstance(ks.entries["azurediamond_cert"], TrustedCertEntry) def test_export_truststore_defaults(app): from lemur.plugins.base import plugins - p = plugins.get('java-truststore-jks') + p = plugins.get("java-truststore-jks") options = [] - ext, password, raw = p.export(INTERNAL_CERTIFICATE_A_STR, '', '', options) + ext, password, raw = p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) - assert ext == 'jks' + assert ext == "jks" assert isinstance(password, str) assert isinstance(raw, bytes) ks = KeyStore.loads(raw, password) - assert ks.store_type == 'jks' + assert ks.store_type == "jks" # JKS lower-cases alias strings - assert ks.entries.keys() == {'acommonname_cert'} - assert isinstance(ks.entries['acommonname_cert'], TrustedCertEntry) + assert ks.entries.keys() == {"acommonname_cert"} + assert isinstance(ks.entries["acommonname_cert"], TrustedCertEntry) def test_export_keystore(app): from lemur.plugins.base import plugins - p = plugins.get('java-keystore-jks') + p = plugins.get("java-keystore-jks") options = [ - {'name': 'passphrase', 'value': 'hunter2'}, - {'name': 'alias', 'value': 'AzureDiamond'}, + {"name": "passphrase", "value": "hunter2"}, + {"name": "alias", "value": "AzureDiamond"}, ] - chain = INTERMEDIATE_CERT_STR + '\n' + ROOTCA_CERT_STR + chain = INTERMEDIATE_CERT_STR + "\n" + ROOTCA_CERT_STR with pytest.raises(Exception): - p.export(INTERNAL_CERTIFICATE_A_STR, chain, '', options) + p.export(INTERNAL_CERTIFICATE_A_STR, chain, "", options) ext, password, raw = p.export(SAN_CERT_STR, chain, SAN_CERT_KEY, options) - assert ext == 'jks' - assert password == 'hunter2' + assert ext == "jks" + assert password == "hunter2" assert isinstance(raw, bytes) ks = KeyStore.loads(raw, password) - assert ks.store_type == 'jks' + assert ks.store_type == "jks" # JKS lower-cases alias strings - assert ks.entries.keys() == {'azurediamond'} - entry = ks.entries['azurediamond'] + assert ks.entries.keys() == {"azurediamond"} + entry = ks.entries["azurediamond"] assert isinstance(entry, PrivateKeyEntry) - assert len(entry.cert_chain) == 3 # Cert and chain were provided + assert len(entry.cert_chain) == 3 # Cert and chain were provided def test_export_keystore_defaults(app): from lemur.plugins.base import plugins - p = plugins.get('java-keystore-jks') + p = plugins.get("java-keystore-jks") options = [] with pytest.raises(Exception): - p.export(INTERNAL_CERTIFICATE_A_STR, '', '', options) + p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) - ext, password, raw = p.export(SAN_CERT_STR, '', SAN_CERT_KEY, options) + ext, password, raw = p.export(SAN_CERT_STR, "", SAN_CERT_KEY, options) - assert ext == 'jks' + assert ext == "jks" assert isinstance(password, str) assert isinstance(raw, bytes) ks = KeyStore.loads(raw, password) - assert ks.store_type == 'jks' - assert ks.entries.keys() == {'san.example.org'} - entry = ks.entries['san.example.org'] + assert ks.store_type == "jks" + assert ks.entries.keys() == {"san.example.org"} + entry = ks.entries["san.example.org"] assert isinstance(entry, PrivateKeyEntry) - assert len(entry.cert_chain) == 1 # Only cert itself, no chain was provided + assert len(entry.cert_chain) == 1 # Only cert itself, no chain was provided diff --git a/lemur/plugins/lemur_kubernetes/__init__.py b/lemur/plugins/lemur_kubernetes/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_kubernetes/__init__.py +++ b/lemur/plugins/lemur_kubernetes/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_kubernetes/plugin.py b/lemur/plugins/lemur_kubernetes/plugin.py index 30b864eb..62ffffda 100644 --- a/lemur/plugins/lemur_kubernetes/plugin.py +++ b/lemur/plugins/lemur_kubernetes/plugin.py @@ -21,7 +21,7 @@ from lemur.common.defaults import common_name from lemur.common.utils import parse_certificate from lemur.plugins.bases import DestinationPlugin -DEFAULT_API_VERSION = 'v1' +DEFAULT_API_VERSION = "v1" def ensure_resource(k8s_api, k8s_base_uri, namespace, kind, name, data): @@ -34,7 +34,7 @@ def ensure_resource(k8s_api, k8s_base_uri, namespace, kind, name, data): if 200 <= create_resp.status_code <= 299: return None - elif create_resp.json().get('reason', '') != 'AlreadyExists': + elif create_resp.json().get("reason", "") != "AlreadyExists": return create_resp.content url = _resolve_uri(k8s_base_uri, namespace, kind, name) @@ -50,22 +50,27 @@ def ensure_resource(k8s_api, k8s_base_uri, namespace, kind, name, data): def _resolve_ns(k8s_base_uri, namespace, api_ver=DEFAULT_API_VERSION): - api_group = 'api' - if '/' in api_ver: - api_group = 'apis' - return '{base}/{api_group}/{api_ver}/namespaces'.format(base=k8s_base_uri, api_group=api_group, api_ver=api_ver) + ( - '/' + namespace if namespace else '') + api_group = "api" + if "/" in api_ver: + api_group = "apis" + return "{base}/{api_group}/{api_ver}/namespaces".format( + base=k8s_base_uri, api_group=api_group, api_ver=api_ver + ) + ("/" + namespace if namespace else "") def _resolve_uri(k8s_base_uri, namespace, kind, name=None, api_ver=DEFAULT_API_VERSION): if not namespace: - namespace = 'default' + namespace = "default" - return "/".join(itertools.chain.from_iterable([ - (_resolve_ns(k8s_base_uri, namespace, api_ver=api_ver),), - ((kind + 's').lower(),), - (name,) if name else (), - ])) + return "/".join( + itertools.chain.from_iterable( + [ + (_resolve_ns(k8s_base_uri, namespace, api_ver=api_ver),), + ((kind + "s").lower(),), + (name,) if name else (), + ] + ) + ) # Performs Base64 encoding of string to string using the base64.b64encode() function @@ -76,117 +81,113 @@ def base64encode(string): def build_secret(secret_format, secret_name, body, private_key, cert_chain): secret = { - 'apiVersion': 'v1', - 'kind': 'Secret', - 'type': 'Opaque', - 'metadata': { - 'name': secret_name, - } + "apiVersion": "v1", + "kind": "Secret", + "type": "Opaque", + "metadata": {"name": secret_name}, } - if secret_format == 'Full': - secret['data'] = { - 'combined.pem': base64encode('%s\n%s' % (body, private_key)), - 'ca.crt': base64encode(cert_chain), - 'service.key': base64encode(private_key), - 'service.crt': base64encode(body), + if secret_format == "Full": + secret["data"] = { + "combined.pem": base64encode("%s\n%s" % (body, private_key)), + "ca.crt": base64encode(cert_chain), + "service.key": base64encode(private_key), + "service.crt": base64encode(body), } - if secret_format == 'TLS': - secret['type'] = 'kubernetes.io/tls' - secret['data'] = { - 'tls.crt': base64encode(cert_chain), - 'tls.key': base64encode(private_key) - } - if secret_format == 'Certificate': - secret['data'] = { - 'tls.crt': base64encode(cert_chain), + if secret_format == "TLS": + secret["type"] = "kubernetes.io/tls" + secret["data"] = { + "tls.crt": base64encode(cert_chain), + "tls.key": base64encode(private_key), } + if secret_format == "Certificate": + secret["data"] = {"tls.crt": base64encode(cert_chain)} return secret class KubernetesDestinationPlugin(DestinationPlugin): - title = 'Kubernetes' - slug = 'kubernetes-destination' - description = 'Allow the uploading of certificates to Kubernetes as secret' + title = "Kubernetes" + slug = "kubernetes-destination" + description = "Allow the uploading of certificates to Kubernetes as secret" - author = 'Mikhail Khodorovskiy' - author_url = 'https://github.com/mik373/lemur' + author = "Mikhail Khodorovskiy" + author_url = "https://github.com/mik373/lemur" options = [ { - 'name': 'secretNameFormat', - 'type': 'str', - 'required': False, + "name": "secretNameFormat", + "type": "str", + "required": False, # Validation is difficult. This regex is used by kubectl to validate secret names: # [a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)* # Allowing the insertion of "{common_name}" (or any other such placeholder} # at any point in the string proved very challenging and had a tendency to # cause my browser to hang. The specified expression will allow any valid string # but will also accept many invalid strings. - 'validation': '(?:[a-z0-9.-]|\\{common_name\\})+', - 'helpMessage': 'Must be a valid secret name, possibly including "{common_name}"', - 'default': '{common_name}' + "validation": "(?:[a-z0-9.-]|\\{common_name\\})+", + "helpMessage": 'Must be a valid secret name, possibly including "{common_name}"', + "default": "{common_name}", }, { - 'name': 'kubernetesURL', - 'type': 'str', - 'required': False, - 'validation': 'https?://[a-zA-Z0-9.-]+(?::[0-9]+)?', - 'helpMessage': 'Must be a valid Kubernetes server URL!', - 'default': 'https://kubernetes.default' + "name": "kubernetesURL", + "type": "str", + "required": False, + "validation": "https?://[a-zA-Z0-9.-]+(?::[0-9]+)?", + "helpMessage": "Must be a valid Kubernetes server URL!", + "default": "https://kubernetes.default", }, { - 'name': 'kubernetesAuthToken', - 'type': 'str', - 'required': False, - 'validation': '[0-9a-zA-Z-_.]+', - 'helpMessage': 'Must be a valid Kubernetes server Token!', + "name": "kubernetesAuthToken", + "type": "str", + "required": False, + "validation": "[0-9a-zA-Z-_.]+", + "helpMessage": "Must be a valid Kubernetes server Token!", }, { - 'name': 'kubernetesAuthTokenFile', - 'type': 'str', - 'required': False, - 'validation': '(/[^/]+)+', - 'helpMessage': 'Must be a valid file path!', - 'default': '/var/run/secrets/kubernetes.io/serviceaccount/token' + "name": "kubernetesAuthTokenFile", + "type": "str", + "required": False, + "validation": "(/[^/]+)+", + "helpMessage": "Must be a valid file path!", + "default": "/var/run/secrets/kubernetes.io/serviceaccount/token", }, { - 'name': 'kubernetesServerCertificate', - 'type': 'textarea', - 'required': False, - 'validation': '-----BEGIN CERTIFICATE-----[a-zA-Z0-9/+\\s\\r\\n]+-----END CERTIFICATE-----', - 'helpMessage': 'Must be a valid Kubernetes server Certificate!', + "name": "kubernetesServerCertificate", + "type": "textarea", + "required": False, + "validation": "-----BEGIN CERTIFICATE-----[a-zA-Z0-9/+\\s\\r\\n]+-----END CERTIFICATE-----", + "helpMessage": "Must be a valid Kubernetes server Certificate!", }, { - 'name': 'kubernetesServerCertificateFile', - 'type': 'str', - 'required': False, - 'validation': '(/[^/]+)+', - 'helpMessage': 'Must be a valid file path!', - 'default': '/var/run/secrets/kubernetes.io/serviceaccount/ca.crt' + "name": "kubernetesServerCertificateFile", + "type": "str", + "required": False, + "validation": "(/[^/]+)+", + "helpMessage": "Must be a valid file path!", + "default": "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt", }, { - 'name': 'kubernetesNamespace', - 'type': 'str', - 'required': False, - 'validation': '[a-z0-9]([-a-z0-9]*[a-z0-9])?', - 'helpMessage': 'Must be a valid Kubernetes Namespace!', + "name": "kubernetesNamespace", + "type": "str", + "required": False, + "validation": "[a-z0-9]([-a-z0-9]*[a-z0-9])?", + "helpMessage": "Must be a valid Kubernetes Namespace!", }, { - 'name': 'kubernetesNamespaceFile', - 'type': 'str', - 'required': False, - 'validation': '(/[^/]+)+', - 'helpMessage': 'Must be a valid file path!', - 'default': '/var/run/secrets/kubernetes.io/serviceaccount/namespace' + "name": "kubernetesNamespaceFile", + "type": "str", + "required": False, + "validation": "(/[^/]+)+", + "helpMessage": "Must be a valid file path!", + "default": "/var/run/secrets/kubernetes.io/serviceaccount/namespace", }, { - 'name': 'secretFormat', - 'type': 'select', - 'required': True, - 'available': ['Full', 'TLS', 'Certificate'], - 'helpMessage': 'The type of Secret to create.', - 'default': 'Full' - } + "name": "secretFormat", + "type": "select", + "required": True, + "available": ["Full", "TLS", "Certificate"], + "helpMessage": "The type of Secret to create.", + "default": "Full", + }, ] def __init__(self, *args, **kwargs): @@ -195,27 +196,28 @@ class KubernetesDestinationPlugin(DestinationPlugin): def upload(self, name, body, private_key, cert_chain, options, **kwargs): try: - k8_base_uri = self.get_option('kubernetesURL', options) - secret_format = self.get_option('secretFormat', options) - k8s_api = K8sSession( - self.k8s_bearer(options), - self.k8s_cert(options) - ) + k8_base_uri = self.get_option("kubernetesURL", options) + secret_format = self.get_option("secretFormat", options) + k8s_api = K8sSession(self.k8s_bearer(options), self.k8s_cert(options)) cn = common_name(parse_certificate(body)) - secret_name_format = self.get_option('secretNameFormat', options) + secret_name_format = self.get_option("secretNameFormat", options) secret_name = secret_name_format.format(common_name=cn) - secret = build_secret(secret_format, secret_name, body, private_key, cert_chain) + secret = build_secret( + secret_format, secret_name, body, private_key, cert_chain + ) err = ensure_resource( k8s_api, k8s_base_uri=k8_base_uri, namespace=self.k8s_namespace(options), kind="secret", name=secret_name, - data=secret + data=secret, ) except Exception as e: - current_app.logger.exception("Exception in upload: {}".format(e), exc_info=True) + current_app.logger.exception( + "Exception in upload: {}".format(e), exc_info=True + ) raise if err is not None: @@ -223,24 +225,28 @@ class KubernetesDestinationPlugin(DestinationPlugin): raise Exception("Error uploading secret: " + err) def k8s_bearer(self, options): - bearer = self.get_option('kubernetesAuthToken', options) + bearer = self.get_option("kubernetesAuthToken", options) if not bearer: - bearer_file = self.get_option('kubernetesAuthTokenFile', options) + bearer_file = self.get_option("kubernetesAuthTokenFile", options) with open(bearer_file, "r") as file: bearer = file.readline() if bearer: current_app.logger.debug("Using token read from %s", bearer_file) else: - raise Exception("Unable to locate token in options or from %s", bearer_file) + raise Exception( + "Unable to locate token in options or from %s", bearer_file + ) else: current_app.logger.debug("Using token from options") return bearer def k8s_cert(self, options): - cert_file = self.get_option('kubernetesServerCertificateFile', options) - cert = self.get_option('kubernetesServerCertificate', options) + cert_file = self.get_option("kubernetesServerCertificateFile", options) + cert = self.get_option("kubernetesServerCertificate", options) if cert: - cert_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'k8.cert') + cert_file = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "k8.cert" + ) with open(cert_file, "w") as text_file: text_file.write(cert) current_app.logger.debug("Using certificate from options") @@ -249,36 +255,69 @@ class KubernetesDestinationPlugin(DestinationPlugin): return cert_file def k8s_namespace(self, options): - namespace = self.get_option('kubernetesNamespace', options) + namespace = self.get_option("kubernetesNamespace", options) if not namespace: - namespace_file = self.get_option('kubernetesNamespaceFile', options) + namespace_file = self.get_option("kubernetesNamespaceFile", options) with open(namespace_file, "r") as file: namespace = file.readline() if namespace: - current_app.logger.debug("Using namespace %s from %s", namespace, namespace_file) + current_app.logger.debug( + "Using namespace %s from %s", namespace, namespace_file + ) else: - raise Exception("Unable to locate namespace in options or from %s", namespace_file) + raise Exception( + "Unable to locate namespace in options or from %s", namespace_file + ) else: current_app.logger.debug("Using namespace %s from options", namespace) return namespace class K8sSession(requests.Session): - def __init__(self, bearer, cert_file): super(K8sSession, self).__init__() - self.headers.update({ - 'Authorization': 'Bearer %s' % bearer - }) + self.headers.update({"Authorization": "Bearer %s" % bearer}) self.verify = cert_file - def request(self, method, url, params=None, data=None, headers=None, cookies=None, files=None, auth=None, - timeout=30, allow_redirects=True, proxies=None, hooks=None, stream=None, verify=None, cert=None, - json=None): + def request( + self, + method, + url, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + timeout=30, + allow_redirects=True, + proxies=None, + hooks=None, + stream=None, + verify=None, + cert=None, + json=None, + ): """ This method overrides the default timeout to be 10s. """ - return super(K8sSession, self).request(method, url, params, data, headers, cookies, files, auth, timeout, - allow_redirects, proxies, hooks, stream, verify, cert, json) + return super(K8sSession, self).request( + method, + url, + params, + data, + headers, + cookies, + files, + auth, + timeout, + allow_redirects, + proxies, + hooks, + stream, + verify, + cert, + json, + ) diff --git a/lemur/plugins/lemur_openssl/__init__.py b/lemur/plugins/lemur_openssl/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_openssl/__init__.py +++ b/lemur/plugins/lemur_openssl/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_openssl/plugin.py b/lemur/plugins/lemur_openssl/plugin.py index 6d6f89aa..02da311b 100644 --- a/lemur/plugins/lemur_openssl/plugin.py +++ b/lemur/plugins/lemur_openssl/plugin.py @@ -50,59 +50,66 @@ def create_pkcs12(cert, chain, p12_tmp, key, alias, passphrase): assert isinstance(key, str) with mktempfile() as key_tmp: - with open(key_tmp, 'w') as f: + with open(key_tmp, "w") as f: f.write(key) # Create PKCS12 keystore from private key and public certificate with mktempfile() as cert_tmp: - with open(cert_tmp, 'w') as f: + with open(cert_tmp, "w") as f: if chain: f.writelines([cert.strip() + "\n", chain.strip() + "\n"]) else: f.writelines([cert.strip() + "\n"]) - run_process([ - "openssl", - "pkcs12", - "-export", - "-name", alias, - "-in", cert_tmp, - "-inkey", key_tmp, - "-out", p12_tmp, - "-password", "pass:{}".format(passphrase) - ]) + run_process( + [ + "openssl", + "pkcs12", + "-export", + "-name", + alias, + "-in", + cert_tmp, + "-inkey", + key_tmp, + "-out", + p12_tmp, + "-password", + "pass:{}".format(passphrase), + ] + ) class OpenSSLExportPlugin(ExportPlugin): - title = 'OpenSSL' - slug = 'openssl-export' - description = 'Is a loose interface to openssl and support various formats' + title = "OpenSSL" + slug = "openssl-export" + description = "Is a loose interface to openssl and support various formats" version = openssl.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur" options = [ { - 'name': 'type', - 'type': 'select', - 'required': True, - 'available': ['PKCS12 (.p12)'], - 'helpMessage': 'Choose the format you wish to export', + "name": "type", + "type": "select", + "required": True, + "available": ["PKCS12 (.p12)"], + "helpMessage": "Choose the format you wish to export", }, { - 'name': 'passphrase', - 'type': 'str', - 'required': False, - 'helpMessage': 'If no passphrase is given one will be generated for you, we highly recommend this.', - 'validation': '' + "name": "passphrase", + "type": "str", + "required": False, + "helpMessage": "If no passphrase is given one will be generated for you, we highly recommend this.", + "validation": "", }, { - 'name': 'alias', - 'type': 'str', - 'required': False, - 'helpMessage': 'Enter the alias you wish to use for the keystore.', - } + "name": "alias", + "type": "str", + "required": False, + "helpMessage": "Enter the alias you wish to use for the keystore.", + }, ] def export(self, body, chain, key, options, **kwargs): @@ -115,20 +122,20 @@ class OpenSSLExportPlugin(ExportPlugin): :param options: :param kwargs: """ - if self.get_option('passphrase', options): - passphrase = self.get_option('passphrase', options) + if self.get_option("passphrase", options): + passphrase = self.get_option("passphrase", options) else: passphrase = get_psuedo_random_string() - if self.get_option('alias', options): - alias = self.get_option('alias', options) + if self.get_option("alias", options): + alias = self.get_option("alias", options) else: alias = common_name(parse_certificate(body)) - type = self.get_option('type', options) + type = self.get_option("type", options) with mktemppath() as output_tmp: - if type == 'PKCS12 (.p12)': + if type == "PKCS12 (.p12)": if not key: raise Exception("Private Key required by {0}".format(type)) @@ -137,7 +144,7 @@ class OpenSSLExportPlugin(ExportPlugin): else: raise Exception("Unable to export, unsupported type: {0}".format(type)) - with open(output_tmp, 'rb') as f: + with open(output_tmp, "rb") as f: raw = f.read() return extension, passphrase, raw diff --git a/lemur/plugins/lemur_openssl/tests/test_openssl.py b/lemur/plugins/lemur_openssl/tests/test_openssl.py index e24033e8..c332f941 100644 --- a/lemur/plugins/lemur_openssl/tests/test_openssl.py +++ b/lemur/plugins/lemur_openssl/tests/test_openssl.py @@ -4,8 +4,12 @@ from lemur.tests.vectors import INTERNAL_PRIVATE_KEY_A_STR, INTERNAL_CERTIFICATE def test_export_certificate_to_pkcs12(app): from lemur.plugins.base import plugins - p = plugins.get('openssl-export') - options = [{'name': 'passphrase', 'value': 'test1234'}, {'name': 'type', 'value': 'PKCS12 (.p12)'}] + + p = plugins.get("openssl-export") + options = [ + {"name": "passphrase", "value": "test1234"}, + {"name": "type", "value": "PKCS12 (.p12)"}, + ] with pytest.raises(Exception): p.export(INTERNAL_CERTIFICATE_A_STR, "", "", options) diff --git a/lemur/plugins/lemur_sftp/__init__.py b/lemur/plugins/lemur_sftp/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_sftp/__init__.py +++ b/lemur/plugins/lemur_sftp/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_sftp/plugin.py b/lemur/plugins/lemur_sftp/plugin.py index d74effc5..de8df427 100644 --- a/lemur/plugins/lemur_sftp/plugin.py +++ b/lemur/plugins/lemur_sftp/plugin.py @@ -27,107 +27,105 @@ from lemur.plugins.bases import DestinationPlugin class SFTPDestinationPlugin(DestinationPlugin): - title = 'SFTP' - slug = 'sftp-destination' - description = 'Allow the uploading of certificates to SFTP' + title = "SFTP" + slug = "sftp-destination" + description = "Allow the uploading of certificates to SFTP" version = lemur_sftp.VERSION - author = 'Dmitry Zykov' - author_url = 'https://github.com/DmitryZykov' + author = "Dmitry Zykov" + author_url = "https://github.com/DmitryZykov" options = [ { - 'name': 'host', - 'type': 'str', - 'required': True, - 'helpMessage': 'The SFTP host.' + "name": "host", + "type": "str", + "required": True, + "helpMessage": "The SFTP host.", }, { - 'name': 'port', - 'type': 'int', - 'required': True, - 'helpMessage': 'The SFTP port, default is 22.', - 'validation': '^(6553[0-5]|655[0-2][0-9]\d|65[0-4](\d){2}|6[0-4](\d){3}|[1-5](\d){4}|[1-9](\d){0,3})', - 'default': '22' + "name": "port", + "type": "int", + "required": True, + "helpMessage": "The SFTP port, default is 22.", + "validation": "^(6553[0-5]|655[0-2][0-9]\d|65[0-4](\d){2}|6[0-4](\d){3}|[1-5](\d){4}|[1-9](\d){0,3})", + "default": "22", }, { - 'name': 'user', - 'type': 'str', - 'required': True, - 'helpMessage': 'The SFTP user. Default is root.', - 'default': 'root' + "name": "user", + "type": "str", + "required": True, + "helpMessage": "The SFTP user. Default is root.", + "default": "root", }, { - 'name': 'password', - 'type': 'str', - 'required': False, - 'helpMessage': 'The SFTP password (optional when the private key is used).', - 'default': None + "name": "password", + "type": "str", + "required": False, + "helpMessage": "The SFTP password (optional when the private key is used).", + "default": None, }, { - 'name': 'privateKeyPath', - 'type': 'str', - 'required': False, - 'helpMessage': 'The path to the RSA private key on the Lemur server (optional).', - 'default': None + "name": "privateKeyPath", + "type": "str", + "required": False, + "helpMessage": "The path to the RSA private key on the Lemur server (optional).", + "default": None, }, { - 'name': 'privateKeyPass', - 'type': 'str', - 'required': False, - 'helpMessage': 'The password for the encrypted RSA private key (optional).', - 'default': None + "name": "privateKeyPass", + "type": "str", + "required": False, + "helpMessage": "The password for the encrypted RSA private key (optional).", + "default": None, }, { - 'name': 'destinationPath', - 'type': 'str', - 'required': True, - 'helpMessage': 'The SFTP path where certificates will be uploaded.', - 'default': '/etc/nginx/certs' + "name": "destinationPath", + "type": "str", + "required": True, + "helpMessage": "The SFTP path where certificates will be uploaded.", + "default": "/etc/nginx/certs", }, { - 'name': 'exportFormat', - 'required': True, - 'value': 'NGINX', - 'helpMessage': 'The export format for certificates.', - 'type': 'select', - 'available': [ - 'NGINX', - 'Apache' - ] - } + "name": "exportFormat", + "required": True, + "value": "NGINX", + "helpMessage": "The export format for certificates.", + "type": "select", + "available": ["NGINX", "Apache"], + }, ] def upload(self, name, body, private_key, cert_chain, options, **kwargs): - current_app.logger.debug('SFTP destination plugin is started') + current_app.logger.debug("SFTP destination plugin is started") cn = common_name(parse_certificate(body)) - host = self.get_option('host', options) - port = self.get_option('port', options) - user = self.get_option('user', options) - password = self.get_option('password', options) - ssh_priv_key = self.get_option('privateKeyPath', options) - ssh_priv_key_pass = self.get_option('privateKeyPass', options) - dst_path = self.get_option('destinationPath', options) - export_format = self.get_option('exportFormat', options) + host = self.get_option("host", options) + port = self.get_option("port", options) + user = self.get_option("user", options) + password = self.get_option("password", options) + ssh_priv_key = self.get_option("privateKeyPath", options) + ssh_priv_key_pass = self.get_option("privateKeyPass", options) + dst_path = self.get_option("destinationPath", options) + export_format = self.get_option("exportFormat", options) # prepare files for upload - files = {cn + '.key': private_key, - cn + '.pem': body} + files = {cn + ".key": private_key, cn + ".pem": body} if cert_chain: - if export_format == 'NGINX': + if export_format == "NGINX": # assemble body + chain in the single file - files[cn + '.pem'] += '\n' + cert_chain + files[cn + ".pem"] += "\n" + cert_chain - elif export_format == 'Apache': + elif export_format == "Apache": # store chain in the separate file - files[cn + '.ca.bundle.pem'] = cert_chain + files[cn + ".ca.bundle.pem"] = cert_chain # upload files try: - current_app.logger.debug('Connecting to {0}@{1}:{2}'.format(user, host, port)) + current_app.logger.debug( + "Connecting to {0}@{1}:{2}".format(user, host, port) + ) ssh = paramiko.SSHClient() # allow connection to the new unknown host @@ -135,14 +133,18 @@ class SFTPDestinationPlugin(DestinationPlugin): # open the ssh connection if password: - current_app.logger.debug('Using password') + current_app.logger.debug("Using password") ssh.connect(host, username=user, port=port, password=password) elif ssh_priv_key: - current_app.logger.debug('Using RSA private key') - pkey = paramiko.RSAKey.from_private_key_file(ssh_priv_key, ssh_priv_key_pass) + current_app.logger.debug("Using RSA private key") + pkey = paramiko.RSAKey.from_private_key_file( + ssh_priv_key, ssh_priv_key_pass + ) ssh.connect(host, username=user, port=port, pkey=pkey) else: - current_app.logger.error("No password or private key provided. Can't proceed") + current_app.logger.error( + "No password or private key provided. Can't proceed" + ) raise paramiko.ssh_exception.AuthenticationException # open the sftp session inside the ssh connection @@ -150,29 +152,33 @@ class SFTPDestinationPlugin(DestinationPlugin): # make sure that the destination path exist try: - current_app.logger.debug('Creating {0}'.format(dst_path)) + current_app.logger.debug("Creating {0}".format(dst_path)) sftp.mkdir(dst_path) except IOError: - current_app.logger.debug('{0} already exist, resuming'.format(dst_path)) + current_app.logger.debug("{0} already exist, resuming".format(dst_path)) try: - dst_path_cn = dst_path + '/' + cn - current_app.logger.debug('Creating {0}'.format(dst_path_cn)) + dst_path_cn = dst_path + "/" + cn + current_app.logger.debug("Creating {0}".format(dst_path_cn)) sftp.mkdir(dst_path_cn) except IOError: - current_app.logger.debug('{0} already exist, resuming'.format(dst_path_cn)) + current_app.logger.debug( + "{0} already exist, resuming".format(dst_path_cn) + ) # upload certificate files to the sftp destination for filename, data in files.items(): - current_app.logger.debug('Uploading {0} to {1}'.format(filename, dst_path_cn)) - with sftp.open(dst_path_cn + '/' + filename, 'w') as f: + current_app.logger.debug( + "Uploading {0} to {1}".format(filename, dst_path_cn) + ) + with sftp.open(dst_path_cn + "/" + filename, "w") as f: f.write(data) # read only for owner, -r-------- - sftp.chmod(dst_path_cn + '/' + filename, 0o400) + sftp.chmod(dst_path_cn + "/" + filename, 0o400) ssh.close() except Exception as e: - current_app.logger.error('ERROR in {0}: {1}'.format(e.__class__, e)) + current_app.logger.error("ERROR in {0}: {1}".format(e.__class__, e)) try: ssh.close() except BaseException: diff --git a/lemur/plugins/lemur_slack/__init__.py b/lemur/plugins/lemur_slack/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_slack/__init__.py +++ b/lemur/plugins/lemur_slack/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_slack/plugin.py b/lemur/plugins/lemur_slack/plugin.py index a986aa9a..7569d295 100644 --- a/lemur/plugins/lemur_slack/plugin.py +++ b/lemur/plugins/lemur_slack/plugin.py @@ -17,102 +17,101 @@ import requests def create_certificate_url(name): - return 'https://{hostname}/#/certificates/{name}'.format( - hostname=current_app.config.get('LEMUR_HOSTNAME'), - name=name + return "https://{hostname}/#/certificates/{name}".format( + hostname=current_app.config.get("LEMUR_HOSTNAME"), name=name ) def create_expiration_attachments(certificates): attachments = [] for certificate in certificates: - attachments.append({ - 'title': certificate['name'], - 'title_link': create_certificate_url(certificate['name']), - 'color': 'danger', - 'fallback': '', - 'fields': [ - { - 'title': 'Owner', - 'value': certificate['owner'], - 'short': True - }, - { - 'title': 'Expires', - 'value': arrow.get(certificate['validityEnd']).format('dddd, MMMM D, YYYY'), - 'short': True - }, - { - 'title': 'Endpoints Detected', - 'value': len(certificate['endpoints']), - 'short': True - } - ], - 'text': '', - 'mrkdwn_in': ['text'] - }) + attachments.append( + { + "title": certificate["name"], + "title_link": create_certificate_url(certificate["name"]), + "color": "danger", + "fallback": "", + "fields": [ + {"title": "Owner", "value": certificate["owner"], "short": True}, + { + "title": "Expires", + "value": arrow.get(certificate["validityEnd"]).format( + "dddd, MMMM D, YYYY" + ), + "short": True, + }, + { + "title": "Endpoints Detected", + "value": len(certificate["endpoints"]), + "short": True, + }, + ], + "text": "", + "mrkdwn_in": ["text"], + } + ) return attachments def create_rotation_attachments(certificate): return { - 'title': certificate['name'], - 'title_link': create_certificate_url(certificate['name']), - 'fields': [ + "title": certificate["name"], + "title_link": create_certificate_url(certificate["name"]), + "fields": [ { + {"title": "Owner", "value": certificate["owner"], "short": True}, { - 'title': 'Owner', - 'value': certificate['owner'], - 'short': True + "title": "Expires", + "value": arrow.get(certificate["validityEnd"]).format( + "dddd, MMMM D, YYYY" + ), + "short": True, }, { - 'title': 'Expires', - 'value': arrow.get(certificate['validityEnd']).format('dddd, MMMM D, YYYY'), - 'short': True + "title": "Replaced By", + "value": len(certificate["replaced"][0]["name"]), + "short": True, }, { - 'title': 'Replaced By', - 'value': len(certificate['replaced'][0]['name']), - 'short': True + "title": "Endpoints Rotated", + "value": len(certificate["endpoints"]), + "short": True, }, - { - 'title': 'Endpoints Rotated', - 'value': len(certificate['endpoints']), - 'short': True - } } - ] + ], } class SlackNotificationPlugin(ExpirationNotificationPlugin): - title = 'Slack' - slug = 'slack-notification' - description = 'Sends notifications to Slack' + title = "Slack" + slug = "slack-notification" + description = "Sends notifications to Slack" version = slack.VERSION - author = 'Harm Weites' - author_url = 'https://github.com/netflix/lemur' + author = "Harm Weites" + author_url = "https://github.com/netflix/lemur" additional_options = [ { - 'name': 'webhook', - 'type': 'str', - 'required': True, - 'validation': '^https:\/\/hooks\.slack\.com\/services\/.+$', - 'helpMessage': 'The url Slack told you to use for this integration', - }, { - 'name': 'username', - 'type': 'str', - 'validation': '^.+$', - 'helpMessage': 'The great storyteller', - 'default': 'Lemur' - }, { - 'name': 'recipients', - 'type': 'str', - 'required': True, - 'validation': '^(@|#).+$', - 'helpMessage': 'Where to send to, either @username or #channel', + "name": "webhook", + "type": "str", + "required": True, + "validation": "^https:\/\/hooks\.slack\.com\/services\/.+$", + "helpMessage": "The url Slack told you to use for this integration", + }, + { + "name": "username", + "type": "str", + "validation": "^.+$", + "helpMessage": "The great storyteller", + "default": "Lemur", + }, + { + "name": "recipients", + "type": "str", + "required": True, + "validation": "^(@|#).+$", + "helpMessage": "Where to send to, either @username or #channel", }, ] @@ -122,25 +121,27 @@ class SlackNotificationPlugin(ExpirationNotificationPlugin): `lemur notify` """ attachments = None - if notification_type == 'expiration': + if notification_type == "expiration": attachments = create_expiration_attachments(message) - elif notification_type == 'rotation': + elif notification_type == "rotation": attachments = create_rotation_attachments(message) if not attachments: - raise Exception('Unable to create message attachments') + raise Exception("Unable to create message attachments") body = { - 'text': 'Lemur {0} Notification'.format(notification_type.capitalize()), - 'attachments': attachments, - 'channel': self.get_option('recipients', options), - 'username': self.get_option('username', options) + "text": "Lemur {0} Notification".format(notification_type.capitalize()), + "attachments": attachments, + "channel": self.get_option("recipients", options), + "username": self.get_option("username", options), } - r = requests.post(self.get_option('webhook', options), json.dumps(body)) + r = requests.post(self.get_option("webhook", options), json.dumps(body)) if r.status_code not in [200]: - raise Exception('Failed to send message') + raise Exception("Failed to send message") - current_app.logger.error("Slack response: {0} Message Body: {1}".format(r.status_code, body)) + current_app.logger.error( + "Slack response: {0} Message Body: {1}".format(r.status_code, body) + ) diff --git a/lemur/plugins/lemur_slack/tests/test_slack.py b/lemur/plugins/lemur_slack/tests/test_slack.py index 701f69d9..86add25f 100644 --- a/lemur/plugins/lemur_slack/tests/test_slack.py +++ b/lemur/plugins/lemur_slack/tests/test_slack.py @@ -1,33 +1,23 @@ - - def test_formatting(certificate): from lemur.plugins.lemur_slack.plugin import create_expiration_attachments from lemur.certificates.schemas import certificate_notification_output_schema + data = [certificate_notification_output_schema.dump(certificate).data] attachment = { - 'title': certificate.name, - 'color': 'danger', - 'fields': [ - { - 'short': True, - 'value': 'joe@example.com', - 'title': 'Owner' - }, - { - 'short': True, - 'value': u'Tuesday, December 31, 2047', - 'title': 'Expires' - }, { - 'short': True, - 'value': 0, - 'title': 'Endpoints Detected' - } + "title": certificate.name, + "color": "danger", + "fields": [ + {"short": True, "value": "joe@example.com", "title": "Owner"}, + {"short": True, "value": u"Tuesday, December 31, 2047", "title": "Expires"}, + {"short": True, "value": 0, "title": "Endpoints Detected"}, ], - 'title_link': 'https://lemur.example.com/#/certificates/{name}'.format(name=certificate.name), - 'mrkdwn_in': ['text'], - 'text': '', - 'fallback': '' + "title_link": "https://lemur.example.com/#/certificates/{name}".format( + name=certificate.name + ), + "mrkdwn_in": ["text"], + "text": "", + "fallback": "", } assert attachment == create_expiration_attachments(data)[0] diff --git a/lemur/plugins/lemur_statsd/lemur_statsd/__init__.py b/lemur/plugins/lemur_statsd/lemur_statsd/__init__.py index 3a751848..b4d708ce 100644 --- a/lemur/plugins/lemur_statsd/lemur_statsd/__init__.py +++ b/lemur/plugins/lemur_statsd/lemur_statsd/__init__.py @@ -1,4 +1,4 @@ try: - VERSION = __import__('pkg_resources').get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'Unknown' + VERSION = "Unknown" diff --git a/lemur/plugins/lemur_statsd/lemur_statsd/plugin.py b/lemur/plugins/lemur_statsd/lemur_statsd/plugin.py index a6a87c66..293b4634 100644 --- a/lemur/plugins/lemur_statsd/lemur_statsd/plugin.py +++ b/lemur/plugins/lemur_statsd/lemur_statsd/plugin.py @@ -6,40 +6,44 @@ from datadog import DogStatsd class StatsdMetricPlugin(MetricPlugin): - title = 'Statsd' - slug = 'statsd-metrics' - description = 'Adds support for sending metrics to Statsd' + title = "Statsd" + slug = "statsd-metrics" + description = "Adds support for sending metrics to Statsd" version = plug.VERSION def __init__(self): - host = current_app.config.get('STATSD_HOST') - port = current_app.config.get('STATSD_PORT') - prefix = current_app.config.get('STATSD_PREFIX') + host = current_app.config.get("STATSD_HOST") + port = current_app.config.get("STATSD_PORT") + prefix = current_app.config.get("STATSD_PREFIX") self.statsd = DogStatsd(host=host, port=port, namespace=prefix) - def submit(self, metric_name, metric_type, metric_value, metric_tags=None, options=None): - valid_types = ['COUNTER', 'GAUGE', 'TIMER'] + def submit( + self, metric_name, metric_type, metric_value, metric_tags=None, options=None + ): + valid_types = ["COUNTER", "GAUGE", "TIMER"] tags = [] if metric_type.upper() not in valid_types: raise Exception( "Invalid Metric Type for Statsd, '{metric}' choose from: {options}".format( - metric=metric_type, options=','.join(valid_types) + metric=metric_type, options=",".join(valid_types) ) ) if metric_tags: if not isinstance(metric_tags, dict): - raise Exception("Invalid Metric Tags for Statsd: Tags must be in dict format") + raise Exception( + "Invalid Metric Tags for Statsd: Tags must be in dict format" + ) else: tags = map(lambda e: "{0}:{1}".format(*e), metric_tags.items()) - if metric_type.upper() == 'COUNTER': + if metric_type.upper() == "COUNTER": self.statsd.increment(metric_name, metric_value, tags) - elif metric_type.upper() == 'GAUGE': + elif metric_type.upper() == "GAUGE": self.statsd.gauge(metric_name, metric_value, tags) - elif metric_type.upper() == 'TIMER': + elif metric_type.upper() == "TIMER": self.statsd.timing(metric_name, metric_value, tags) return diff --git a/lemur/plugins/lemur_statsd/setup.py b/lemur/plugins/lemur_statsd/setup.py index 6c4c2dd6..9b3c5f52 100644 --- a/lemur/plugins/lemur_statsd/setup.py +++ b/lemur/plugins/lemur_statsd/setup.py @@ -2,23 +2,16 @@ from __future__ import absolute_import from setuptools import setup, find_packages -install_requires = [ - 'lemur', - 'datadog' -] +install_requires = ["lemur", "datadog"] setup( - name='lemur_statsd', - version='1.0.0', - author='Cloudflare Security Engineering', - author_email='', + name="lemur_statsd", + version="1.0.0", + author="Cloudflare Security Engineering", + author_email="", include_package_data=True, packages=find_packages(), zip_safe=False, install_requires=install_requires, - entry_points={ - 'lemur.plugins': [ - 'statsd = lemur_statsd.plugin:StatsdMetricPlugin', - ] - } + entry_points={"lemur.plugins": ["statsd = lemur_statsd.plugin:StatsdMetricPlugin"]}, ) diff --git a/lemur/plugins/lemur_vault_dest/__init__.py b/lemur/plugins/lemur_vault_dest/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_vault_dest/__init__.py +++ b/lemur/plugins/lemur_vault_dest/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_vault_dest/plugin.py b/lemur/plugins/lemur_vault_dest/plugin.py index 803b0a0c..c8843cf5 100644 --- a/lemur/plugins/lemur_vault_dest/plugin.py +++ b/lemur/plugins/lemur_vault_dest/plugin.py @@ -25,59 +25,57 @@ from cryptography.hazmat.backends import default_backend class VaultSourcePlugin(SourcePlugin): """ Class for importing certificates from Hashicorp Vault""" - title = 'Vault' - slug = 'vault-source' - description = 'Discovers all certificates in a given path' - author = 'Christopher Jolley' - author_url = 'https://github.com/alwaysjolley/lemur' + title = "Vault" + slug = "vault-source" + description = "Discovers all certificates in a given path" + + author = "Christopher Jolley" + author_url = "https://github.com/alwaysjolley/lemur" options = [ { - 'name': 'vaultUrl', - 'type': 'str', - 'required': True, - 'validation': '^https?://[a-zA-Z0-9.:-]+$', - 'helpMessage': 'Valid URL to Hashi Vault instance' + "name": "vaultUrl", + "type": "str", + "required": True, + "validation": "^https?://[a-zA-Z0-9.:-]+$", + "helpMessage": "Valid URL to Hashi Vault instance", }, { - 'name': 'vaultKvApiVersion', - 'type': 'select', - 'value': '2', - 'available': [ - '1', - '2' - ], - 'required': True, - 'helpMessage': 'Version of the Vault KV API to use' + "name": "vaultKvApiVersion", + "type": "select", + "value": "2", + "available": ["1", "2"], + "required": True, + "helpMessage": "Version of the Vault KV API to use", }, { - 'name': 'vaultAuthTokenFile', - 'type': 'str', - 'required': True, - 'validation': '(/[^/]+)+', - 'helpMessage': 'Must be a valid file path!' + "name": "vaultAuthTokenFile", + "type": "str", + "required": True, + "validation": "(/[^/]+)+", + "helpMessage": "Must be a valid file path!", }, { - 'name': 'vaultMount', - 'type': 'str', - 'required': True, - 'validation': r'^\S+$', - 'helpMessage': 'Must be a valid Vault secrets mount name!' + "name": "vaultMount", + "type": "str", + "required": True, + "validation": r"^\S+$", + "helpMessage": "Must be a valid Vault secrets mount name!", }, { - 'name': 'vaultPath', - 'type': 'str', - 'required': True, - 'validation': '^([a-zA-Z0-9_-]+/?)+$', - 'helpMessage': 'Must be a valid Vault secrets path' + "name": "vaultPath", + "type": "str", + "required": True, + "validation": "^([a-zA-Z0-9_-]+/?)+$", + "helpMessage": "Must be a valid Vault secrets path", }, { - 'name': 'objectName', - 'type': 'str', - 'required': True, - 'validation': '[0-9a-zA-Z:_-]+', - 'helpMessage': 'Object Name to search' + "name": "objectName", + "type": "str", + "required": True, + "validation": "[0-9a-zA-Z:_-]+", + "helpMessage": "Object Name to search", }, ] @@ -85,38 +83,38 @@ class VaultSourcePlugin(SourcePlugin): """Pull certificates from objects in Hashicorp Vault""" data = [] cert = [] - body = '' - url = self.get_option('vaultUrl', options) - token_file = self.get_option('vaultAuthTokenFile', options) - mount = self.get_option('vaultMount', options) - path = self.get_option('vaultPath', options) - obj_name = self.get_option('objectName', options) - api_version = self.get_option('vaultKvApiVersion', options) - cert_filter = '-----BEGIN CERTIFICATE-----' - cert_delimiter = '-----END CERTIFICATE-----' + body = "" + url = self.get_option("vaultUrl", options) + token_file = self.get_option("vaultAuthTokenFile", options) + mount = self.get_option("vaultMount", options) + path = self.get_option("vaultPath", options) + obj_name = self.get_option("objectName", options) + api_version = self.get_option("vaultKvApiVersion", options) + cert_filter = "-----BEGIN CERTIFICATE-----" + cert_delimiter = "-----END CERTIFICATE-----" - with open(token_file, 'r') as tfile: - token = tfile.readline().rstrip('\n') + with open(token_file, "r") as tfile: + token = tfile.readline().rstrip("\n") client = hvac.Client(url=url, token=token) client.secrets.kv.default_kv_version = api_version - path = '{0}/{1}'.format(path, obj_name) + path = "{0}/{1}".format(path, obj_name) secret = get_secret(client, mount, path) - for cname in secret['data']: - if 'crt' in secret['data'][cname]: - cert = secret['data'][cname]['crt'].split(cert_delimiter + '\n') - elif 'pem' in secret['data'][cname]: - cert = secret['data'][cname]['pem'].split(cert_delimiter + '\n') + for cname in secret["data"]: + if "crt" in secret["data"][cname]: + cert = secret["data"][cname]["crt"].split(cert_delimiter + "\n") + elif "pem" in secret["data"][cname]: + cert = secret["data"][cname]["pem"].split(cert_delimiter + "\n") else: - for key in secret['data'][cname]: - if secret['data'][cname][key].startswith(cert_filter): - cert = secret['data'][cname][key].split(cert_delimiter + '\n') + for key in secret["data"][cname]: + if secret["data"][cname][key].startswith(cert_filter): + cert = secret["data"][cname][key].split(cert_delimiter + "\n") break body = cert[0] + cert_delimiter - if 'chain' in secret['data'][cname]: - chain = secret['data'][cname]['chain'] + if "chain" in secret["data"][cname]: + chain = secret["data"][cname]["chain"] elif len(cert) > 1: if cert[1].startswith(cert_filter): chain = cert[1] + cert_delimiter @@ -124,8 +122,10 @@ class VaultSourcePlugin(SourcePlugin): chain = None else: chain = None - data.append({'body': body, 'chain': chain, 'name': cname}) - return [dict(body=c['body'], chain=c.get('chain'), name=c['name']) for c in data] + data.append({"body": body, "chain": chain, "name": cname}) + return [ + dict(body=c["body"], chain=c.get("chain"), name=c["name"]) for c in data + ] def get_endpoints(self, options, **kwargs): """ Not implemented yet """ @@ -135,81 +135,74 @@ class VaultSourcePlugin(SourcePlugin): class VaultDestinationPlugin(DestinationPlugin): """Hashicorp Vault Destination plugin for Lemur""" - title = 'Vault' - slug = 'hashi-vault-destination' - description = 'Allow the uploading of certificates to Hashi Vault as secret' - author = 'Christopher Jolley' - author_url = 'https://github.com/alwaysjolley/lemur' + title = "Vault" + slug = "hashi-vault-destination" + description = "Allow the uploading of certificates to Hashi Vault as secret" + + author = "Christopher Jolley" + author_url = "https://github.com/alwaysjolley/lemur" options = [ { - 'name': 'vaultUrl', - 'type': 'str', - 'required': True, - 'validation': '^https?://[a-zA-Z0-9.:-]+$', - 'helpMessage': 'Valid URL to Hashi Vault instance' + "name": "vaultUrl", + "type": "str", + "required": True, + "validation": "^https?://[a-zA-Z0-9.:-]+$", + "helpMessage": "Valid URL to Hashi Vault instance", }, { - 'name': 'vaultKvApiVersion', - 'type': 'select', - 'value': '2', - 'available': [ - '1', - '2' - ], - 'required': True, - 'helpMessage': 'Version of the Vault KV API to use' + "name": "vaultKvApiVersion", + "type": "select", + "value": "2", + "available": ["1", "2"], + "required": True, + "helpMessage": "Version of the Vault KV API to use", }, { - 'name': 'vaultAuthTokenFile', - 'type': 'str', - 'required': True, - 'validation': '(/[^/]+)+', - 'helpMessage': 'Must be a valid file path!' + "name": "vaultAuthTokenFile", + "type": "str", + "required": True, + "validation": "(/[^/]+)+", + "helpMessage": "Must be a valid file path!", }, { - 'name': 'vaultMount', - 'type': 'str', - 'required': True, - 'validation': r'^\S+$', - 'helpMessage': 'Must be a valid Vault secrets mount name!' + "name": "vaultMount", + "type": "str", + "required": True, + "validation": r"^\S+$", + "helpMessage": "Must be a valid Vault secrets mount name!", }, { - 'name': 'vaultPath', - 'type': 'str', - 'required': True, - 'validation': '^([a-zA-Z0-9_-]+/?)+$', - 'helpMessage': 'Must be a valid Vault secrets path' + "name": "vaultPath", + "type": "str", + "required": True, + "validation": "^([a-zA-Z0-9_-]+/?)+$", + "helpMessage": "Must be a valid Vault secrets path", }, { - 'name': 'objectName', - 'type': 'str', - 'required': False, - 'validation': '[0-9a-zA-Z:_-]+', - 'helpMessage': 'Name to bundle certs under, if blank use cn' + "name": "objectName", + "type": "str", + "required": False, + "validation": "[0-9a-zA-Z:_-]+", + "helpMessage": "Name to bundle certs under, if blank use cn", }, { - 'name': 'bundleChain', - 'type': 'select', - 'value': 'cert only', - 'available': [ - 'Nginx', - 'Apache', - 'PEM', - 'no chain' - ], - 'required': True, - 'helpMessage': 'Bundle the chain into the certificate' + "name": "bundleChain", + "type": "select", + "value": "cert only", + "available": ["Nginx", "Apache", "PEM", "no chain"], + "required": True, + "helpMessage": "Bundle the chain into the certificate", }, { - 'name': 'sanFilter', - 'type': 'str', - 'value': '.*', - 'required': False, - 'validation': '.*', - 'helpMessage': 'Valid regex filter' - } + "name": "sanFilter", + "type": "str", + "value": ".*", + "required": False, + "validation": ".*", + "helpMessage": "Valid regex filter", + }, ] def __init__(self, *args, **kwargs): @@ -225,14 +218,14 @@ class VaultDestinationPlugin(DestinationPlugin): """ cname = common_name(parse_certificate(body)) - url = self.get_option('vaultUrl', options) - token_file = self.get_option('vaultAuthTokenFile', options) - mount = self.get_option('vaultMount', options) - path = self.get_option('vaultPath', options) - bundle = self.get_option('bundleChain', options) - obj_name = self.get_option('objectName', options) - api_version = self.get_option('vaultKvApiVersion', options) - san_filter = self.get_option('sanFilter', options) + url = self.get_option("vaultUrl", options) + token_file = self.get_option("vaultAuthTokenFile", options) + mount = self.get_option("vaultMount", options) + path = self.get_option("vaultPath", options) + bundle = self.get_option("bundleChain", options) + obj_name = self.get_option("objectName", options) + api_version = self.get_option("vaultKvApiVersion", options) + san_filter = self.get_option("sanFilter", options) san_list = get_san_list(body) if san_filter: @@ -240,58 +233,67 @@ class VaultDestinationPlugin(DestinationPlugin): try: if not re.match(san_filter, san, flags=re.IGNORECASE): current_app.logger.exception( - "Exception uploading secret to vault: invalid SAN: {}".format(san), - exc_info=True) + "Exception uploading secret to vault: invalid SAN: {}".format( + san + ), + exc_info=True, + ) os._exit(1) except re.error: current_app.logger.exception( "Exception compiling regex filter: invalid filter", - exc_info=True) + exc_info=True, + ) - with open(token_file, 'r') as tfile: - token = tfile.readline().rstrip('\n') + with open(token_file, "r") as tfile: + token = tfile.readline().rstrip("\n") client = hvac.Client(url=url, token=token) client.secrets.kv.default_kv_version = api_version if obj_name: - path = '{0}/{1}'.format(path, obj_name) + path = "{0}/{1}".format(path, obj_name) else: - path = '{0}/{1}'.format(path, cname) + path = "{0}/{1}".format(path, cname) secret = get_secret(client, mount, path) - secret['data'][cname] = {} + secret["data"][cname] = {} - if bundle == 'Nginx': - secret['data'][cname]['crt'] = '{0}\n{1}'.format(body, cert_chain) - secret['data'][cname]['key'] = private_key - elif bundle == 'Apache': - secret['data'][cname]['crt'] = body - secret['data'][cname]['chain'] = cert_chain - secret['data'][cname]['key'] = private_key - elif bundle == 'PEM': - secret['data'][cname]['pem'] = '{0}\n{1}\n{2}'.format(body, cert_chain, private_key) + if bundle == "Nginx": + secret["data"][cname]["crt"] = "{0}\n{1}".format(body, cert_chain) + secret["data"][cname]["key"] = private_key + elif bundle == "Apache": + secret["data"][cname]["crt"] = body + secret["data"][cname]["chain"] = cert_chain + secret["data"][cname]["key"] = private_key + elif bundle == "PEM": + secret["data"][cname]["pem"] = "{0}\n{1}\n{2}".format( + body, cert_chain, private_key + ) else: - secret['data'][cname]['crt'] = body - secret['data'][cname]['key'] = private_key + secret["data"][cname]["crt"] = body + secret["data"][cname]["key"] = private_key if isinstance(san_list, list): - secret['data'][cname]['san'] = san_list + secret["data"][cname]["san"] = san_list try: client.secrets.kv.create_or_update_secret( - path=path, mount_point=mount, secret=secret['data'] + path=path, mount_point=mount, secret=secret["data"] ) except ConnectionError as err: current_app.logger.exception( - "Exception uploading secret to vault: {0}".format(err), exc_info=True) + "Exception uploading secret to vault: {0}".format(err), exc_info=True + ) def get_san_list(body): """ parse certificate for SAN names and return list, return empty list on error """ san_list = [] try: - byte_body = body.encode('utf-8') + byte_body = body.encode("utf-8") cert = x509.load_pem_x509_certificate(byte_body, default_backend()) - ext = cert.extensions.get_extension_for_oid(x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME) + ext = cert.extensions.get_extension_for_oid( + x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ) san_list = ext.value.get_values_for_type(x509.DNSName) except x509.extensions.ExtensionNotFound: pass @@ -301,12 +303,14 @@ def get_san_list(body): def get_secret(client, mount, path): """ retreive existing data from mount path and return dictionary """ - result = {'data': {}} + result = {"data": {}} try: - if client.secrets.kv.default_kv_version == '1': + if client.secrets.kv.default_kv_version == "1": result = client.secrets.kv.v1.read_secret(path=path, mount_point=mount) else: - result = client.secrets.kv.v2.read_secret_version(path=path, mount_point=mount) + result = client.secrets.kv.v2.read_secret_version( + path=path, mount_point=mount + ) except ConnectionError: pass finally: diff --git a/lemur/plugins/lemur_verisign/__init__.py b/lemur/plugins/lemur_verisign/__init__.py index 8ce5a7f3..f8afd7e3 100644 --- a/lemur/plugins/lemur_verisign/__init__.py +++ b/lemur/plugins/lemur_verisign/__init__.py @@ -1,5 +1,4 @@ try: - VERSION = __import__('pkg_resources') \ - .get_distribution(__name__).version + VERSION = __import__("pkg_resources").get_distribution(__name__).version except Exception as e: - VERSION = 'unknown' + VERSION = "unknown" diff --git a/lemur/plugins/lemur_verisign/plugin.py b/lemur/plugins/lemur_verisign/plugin.py index e5207def..65bd1cac 100644 --- a/lemur/plugins/lemur_verisign/plugin.py +++ b/lemur/plugins/lemur_verisign/plugin.py @@ -58,7 +58,7 @@ VERISIGN_ERRORS = { "0x300a": "Domain/SubjectAltName Mismatched -- make sure that the SANs have the proper domain suffix", "0x950e": "Invalid Common Name -- make sure the CN has a proper domain suffix", "0xa00e": "Pending. (Insufficient number of tokens.)", - "0x8134": "Pending. (Domain failed CAA validation.)" + "0x8134": "Pending. (Domain failed CAA validation.)", } @@ -71,7 +71,7 @@ def log_status_code(r, *args, **kwargs): :param kwargs: :return: """ - metrics.send('symantec_status_code_{}'.format(r.status_code), 'counter', 1) + metrics.send("symantec_status_code_{}".format(r.status_code), "counter", 1) def get_additional_names(options): @@ -83,8 +83,8 @@ def get_additional_names(options): """ names = [] # add SANs if present - if options.get('extensions'): - for san in options['extensions']['sub_alt_names']: + if options.get("extensions"): + for san in options["extensions"]["sub_alt_names"]: if isinstance(san, x509.DNSName): names.append(san.value) return names @@ -99,37 +99,43 @@ def process_options(options): :return: dict or valid verisign options """ data = { - 'challenge': get_psuedo_random_string(), - 'serverType': 'Apache', - 'certProductType': 'Server', - 'firstName': current_app.config.get("VERISIGN_FIRST_NAME"), - 'lastName': current_app.config.get("VERISIGN_LAST_NAME"), - 'signatureAlgorithm': 'sha256WithRSAEncryption', - 'email': current_app.config.get("VERISIGN_EMAIL"), - 'ctLogOption': current_app.config.get("VERISIGN_CS_LOG_OPTION", "public"), + "challenge": get_psuedo_random_string(), + "serverType": "Apache", + "certProductType": "Server", + "firstName": current_app.config.get("VERISIGN_FIRST_NAME"), + "lastName": current_app.config.get("VERISIGN_LAST_NAME"), + "signatureAlgorithm": "sha256WithRSAEncryption", + "email": current_app.config.get("VERISIGN_EMAIL"), + "ctLogOption": current_app.config.get("VERISIGN_CS_LOG_OPTION", "public"), } - data['subject_alt_names'] = ",".join(get_additional_names(options)) + data["subject_alt_names"] = ",".join(get_additional_names(options)) - if options.get('validity_end') > arrow.utcnow().replace(years=2): - raise Exception("Verisign issued certificates cannot exceed two years in validity") + if options.get("validity_end") > arrow.utcnow().replace(years=2): + raise Exception( + "Verisign issued certificates cannot exceed two years in validity" + ) - if options.get('validity_end'): + if options.get("validity_end"): # VeriSign (Symantec) only accepts strictly smaller than 2 year end date - if options.get('validity_end') < arrow.utcnow().replace(years=2).replace(days=-1): + if options.get("validity_end") < arrow.utcnow().replace(years=2).replace( + days=-1 + ): period = get_default_issuance(options) - data['specificEndDate'] = options['validity_end'].format("MM/DD/YYYY") - data['validityPeriod'] = period + data["specificEndDate"] = options["validity_end"].format("MM/DD/YYYY") + data["validityPeriod"] = period else: # allowing Symantec website setting the end date, given the validity period - data['validityPeriod'] = str(get_default_issuance(options)) - options.pop('validity_end', None) + data["validityPeriod"] = str(get_default_issuance(options)) + options.pop("validity_end", None) - elif options.get('validity_years'): - if options['validity_years'] in [1, 2]: - data['validityPeriod'] = str(options['validity_years']) + 'Y' + elif options.get("validity_years"): + if options["validity_years"] in [1, 2]: + data["validityPeriod"] = str(options["validity_years"]) + "Y" else: - raise Exception("Verisign issued certificates cannot exceed two years in validity") + raise Exception( + "Verisign issued certificates cannot exceed two years in validity" + ) return data @@ -143,12 +149,14 @@ def get_default_issuance(options): """ now = arrow.utcnow() - if options['validity_end'] < now.replace(years=+1): - validity_period = '1Y' - elif options['validity_end'] < now.replace(years=+2): - validity_period = '2Y' + if options["validity_end"] < now.replace(years=+1): + validity_period = "1Y" + elif options["validity_end"] < now.replace(years=+2): + validity_period = "2Y" else: - raise Exception("Verisign issued certificates cannot exceed two years in validity") + raise Exception( + "Verisign issued certificates cannot exceed two years in validity" + ) return validity_period @@ -161,27 +169,27 @@ def handle_response(content): """ d = xmltodict.parse(content) global VERISIGN_ERRORS - if d.get('Error'): - status_code = d['Error']['StatusCode'] - elif d.get('Response'): - status_code = d['Response']['StatusCode'] + if d.get("Error"): + status_code = d["Error"]["StatusCode"] + elif d.get("Response"): + status_code = d["Response"]["StatusCode"] if status_code in VERISIGN_ERRORS.keys(): raise Exception(VERISIGN_ERRORS[status_code]) return d class VerisignIssuerPlugin(IssuerPlugin): - title = 'Verisign' - slug = 'verisign-issuer' - description = 'Enables the creation of certificates by the VICE2.0 verisign API.' + title = "Verisign" + slug = "verisign-issuer" + description = "Enables the creation of certificates by the VICE2.0 verisign API." version = verisign.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): self.session = requests.Session() - self.session.cert = current_app.config.get('VERISIGN_PEM_PATH') + self.session.cert = current_app.config.get("VERISIGN_PEM_PATH") self.session.hooks = dict(response=log_status_code) super(VerisignIssuerPlugin, self).__init__(*args, **kwargs) @@ -193,23 +201,31 @@ class VerisignIssuerPlugin(IssuerPlugin): :param issuer_options: :return: :raise Exception: """ - url = current_app.config.get("VERISIGN_URL") + '/rest/services/enroll' + url = current_app.config.get("VERISIGN_URL") + "/rest/services/enroll" data = process_options(issuer_options) - data['csr'] = csr + data["csr"] = csr - current_app.logger.info("Requesting a new verisign certificate: {0}".format(data)) + current_app.logger.info( + "Requesting a new verisign certificate: {0}".format(data) + ) response = self.session.post(url, data=data) try: - cert = handle_response(response.content)['Response']['Certificate'] + cert = handle_response(response.content)["Response"]["Certificate"] except KeyError: - metrics.send('verisign_create_certificate_error', 'counter', 1, - metric_tags={"common_name": issuer_options.get("common_name", "")}) - sentry.captureException(extra={"common_name": issuer_options.get("common_name", "")}) + metrics.send( + "verisign_create_certificate_error", + "counter", + 1, + metric_tags={"common_name": issuer_options.get("common_name", "")}, + ) + sentry.captureException( + extra={"common_name": issuer_options.get("common_name", "")} + ) raise Exception(f"Error with Verisign: {response.content}") # TODO add external id - return cert, current_app.config.get('VERISIGN_INTERMEDIATE'), None + return cert, current_app.config.get("VERISIGN_INTERMEDIATE"), None @staticmethod def create_authority(options): @@ -220,8 +236,8 @@ class VerisignIssuerPlugin(IssuerPlugin): :param options: :return: """ - role = {'username': '', 'password': '', 'name': 'verisign'} - return current_app.config.get('VERISIGN_ROOT'), "", [role] + role = {"username": "", "password": "", "name": "verisign"} + return current_app.config.get("VERISIGN_ROOT"), "", [role] def get_available_units(self): """ @@ -230,9 +246,11 @@ class VerisignIssuerPlugin(IssuerPlugin): :return: """ - url = current_app.config.get("VERISIGN_URL") + '/rest/services/getTokens' - response = self.session.post(url, headers={'content-type': 'application/x-www-form-urlencoded'}) - return handle_response(response.content)['Response']['Order'] + url = current_app.config.get("VERISIGN_URL") + "/rest/services/getTokens" + response = self.session.post( + url, headers={"content-type": "application/x-www-form-urlencoded"} + ) + return handle_response(response.content)["Response"]["Order"] def clear_pending_certificates(self): """ @@ -240,52 +258,54 @@ class VerisignIssuerPlugin(IssuerPlugin): :return: """ - url = current_app.config.get('VERISIGN_URL') + '/reportingws' + url = current_app.config.get("VERISIGN_URL") + "/reportingws" end = arrow.now() start = end.replace(days=-7) data = { - 'reportType': 'detail', - 'certProductType': 'Server', - 'certStatus': 'Pending', - 'startDate': start.format("MM/DD/YYYY"), - 'endDate': end.format("MM/DD/YYYY") + "reportType": "detail", + "certProductType": "Server", + "certStatus": "Pending", + "startDate": start.format("MM/DD/YYYY"), + "endDate": end.format("MM/DD/YYYY"), } response = self.session.post(url, data=data) - url = current_app.config.get('VERISIGN_URL') + '/rest/services/reject' - for order_id in response.json()['orderNumber']: - response = self.session.get(url, params={'transaction_id': order_id}) + url = current_app.config.get("VERISIGN_URL") + "/rest/services/reject" + for order_id in response.json()["orderNumber"]: + response = self.session.get(url, params={"transaction_id": order_id}) if response.status_code == 200: print("Rejecting certificate. TransactionId: {}".format(order_id)) class VerisignSourcePlugin(SourcePlugin): - title = 'Verisign' - slug = 'verisign-source' - description = 'Allows for the polling of issued certificates from the VICE2.0 verisign API.' + title = "Verisign" + slug = "verisign-source" + description = ( + "Allows for the polling of issued certificates from the VICE2.0 verisign API." + ) version = verisign.VERSION - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): self.session = requests.Session() - self.session.cert = current_app.config.get('VERISIGN_PEM_PATH') + self.session.cert = current_app.config.get("VERISIGN_PEM_PATH") super(VerisignSourcePlugin, self).__init__(*args, **kwargs) def get_certificates(self): - url = current_app.config.get('VERISIGN_URL') + '/reportingws' + url = current_app.config.get("VERISIGN_URL") + "/reportingws" end = arrow.now() start = end.replace(years=-5) data = { - 'reportType': 'detail', - 'startDate': start.format("MM/DD/YYYY"), - 'endDate': end.format("MM/DD/YYYY"), - 'structuredRecord': 'Y', - 'certStatus': 'Valid', + "reportType": "detail", + "startDate": start.format("MM/DD/YYYY"), + "endDate": end.format("MM/DD/YYYY"), + "structuredRecord": "Y", + "certStatus": "Valid", } current_app.logger.debug(data) response = self.session.post(url, data=data) diff --git a/lemur/plugins/lemur_verisign/tests/test_verisign.py b/lemur/plugins/lemur_verisign/tests/test_verisign.py index 8c4f1d81..42c528e8 100644 --- a/lemur/plugins/lemur_verisign/tests/test_verisign.py +++ b/lemur/plugins/lemur_verisign/tests/test_verisign.py @@ -1,4 +1,4 @@ - def test_get_certificates(app): from lemur.plugins.base import plugins - p = plugins.get('verisign-issuer') + + p = plugins.get("verisign-issuer") diff --git a/lemur/plugins/utils.py b/lemur/plugins/utils.py index e057d071..19655519 100644 --- a/lemur/plugins/utils.py +++ b/lemur/plugins/utils.py @@ -17,8 +17,8 @@ def get_plugin_option(name, options): :return: """ for o in options: - if o.get('name') == name: - return o.get('value', o.get('default')) + if o.get("name") == name: + return o.get("value", o.get("default")) def set_plugin_option(name, value, options): @@ -27,5 +27,5 @@ def set_plugin_option(name, value, options): :param options: """ for o in options: - if o.get('name') == name: - o.update({'value': value}) + if o.get("name") == name: + o.update({"value": value}) diff --git a/lemur/plugins/views.py b/lemur/plugins/views.py index dbdfccab..605b234a 100644 --- a/lemur/plugins/views.py +++ b/lemur/plugins/views.py @@ -15,12 +15,13 @@ from lemur.schemas import plugins_output_schema, plugin_output_schema from lemur.common.schema import validate_schema from lemur.plugins.base import plugins -mod = Blueprint('plugins', __name__) +mod = Blueprint("plugins", __name__) api = Api(mod) class PluginsList(AuthenticatedResource): """ Defines the 'plugins' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(PluginsList, self).__init__() @@ -69,17 +70,18 @@ class PluginsList(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - self.reqparse.add_argument('type', type=str, location='args') + self.reqparse.add_argument("type", type=str, location="args") args = self.reqparse.parse_args() - if args['type']: - return list(plugins.all(plugin_type=args['type'])) + if args["type"]: + return list(plugins.all(plugin_type=args["type"])) return list(plugins.all()) class Plugins(AuthenticatedResource): """ Defines the 'plugins' endpoint """ + def __init__(self): super(Plugins, self).__init__() @@ -118,5 +120,5 @@ class Plugins(AuthenticatedResource): return plugins.get(name) -api.add_resource(PluginsList, '/plugins', endpoint='plugins') -api.add_resource(Plugins, '/plugins/', endpoint='pluginName') +api.add_resource(PluginsList, "/plugins", endpoint="plugins") +api.add_resource(Plugins, "/plugins/", endpoint="pluginName") diff --git a/lemur/policies/cli.py b/lemur/policies/cli.py index 725c1583..317f3414 100644 --- a/lemur/policies/cli.py +++ b/lemur/policies/cli.py @@ -12,8 +12,8 @@ from lemur.policies import service as policy_service manager = Manager(usage="Handles all policy related tasks.") -@manager.option('-d', '--days', dest='days', help='Number of days before expiration.') -@manager.option('-n', '--name', dest='name', help='Policy name.') +@manager.option("-d", "--days", dest="days", help="Number of days before expiration.") +@manager.option("-n", "--name", dest="name", help="Policy name.") def create(days, name): """ Create a new certificate rotation policy diff --git a/lemur/policies/models.py b/lemur/policies/models.py index 2329a347..a17d3ca1 100644 --- a/lemur/policies/models.py +++ b/lemur/policies/models.py @@ -12,10 +12,12 @@ from lemur.database import db class RotationPolicy(db.Model): - __tablename__ = 'rotation_policies' + __tablename__ = "rotation_policies" id = Column(Integer, primary_key=True) name = Column(String) days = Column(Integer) def __repr__(self): - return "RotationPolicy(days={days}, name={name})".format(days=self.days, name=self.name) + return "RotationPolicy(days={days}, name={name})".format( + days=self.days, name=self.name + ) diff --git a/lemur/policies/service.py b/lemur/policies/service.py index 10e9053b..cb43d52e 100644 --- a/lemur/policies/service.py +++ b/lemur/policies/service.py @@ -24,7 +24,7 @@ def get_by_name(policy_name): :param policy_name: :return: """ - return database.get_all(RotationPolicy, policy_name, field='name').all() + return database.get_all(RotationPolicy, policy_name, field="name").all() def delete(policy_id): diff --git a/lemur/reporting/cli.py b/lemur/reporting/cli.py index 8f797c33..c92b79cd 100644 --- a/lemur/reporting/cli.py +++ b/lemur/reporting/cli.py @@ -13,49 +13,73 @@ from lemur.reporting.service import fqdns, expiring_certificates manager = Manager(usage="Reporting related tasks.") -@manager.option('-v', '--validity', dest='validity', choices=['all', 'expired', 'valid'], default='all', help='Filter certificates by validity.') -@manager.option('-d', '--deployment', dest='deployment', choices=['all', 'deployed', 'ready'], default='all', help='Filter by deployment status.') +@manager.option( + "-v", + "--validity", + dest="validity", + choices=["all", "expired", "valid"], + default="all", + help="Filter certificates by validity.", +) +@manager.option( + "-d", + "--deployment", + dest="deployment", + choices=["all", "deployed", "ready"], + default="all", + help="Filter by deployment status.", +) def fqdn(deployment, validity): """ Generates a report in order to determine the number of FQDNs covered by Lemur issued certificates. """ - headers = ['FQDN', 'Root Domain', 'Issuer', 'Owner', 'Validity End', 'Total Length (days), Time Until Expiration (days)'] + headers = [ + "FQDN", + "Root Domain", + "Issuer", + "Owner", + "Validity End", + "Total Length (days), Time Until Expiration (days)", + ] rows = [] for cert in fqdns(validity=validity, deployment=deployment).all(): for domain in cert.domains: - rows.append([ - domain.name, - '.'.join(domain.name.split('.')[1:]), - cert.issuer, - cert.owner, - cert.not_after, - cert.validity_range.days, - cert.validity_remaining.days - ]) + rows.append( + [ + domain.name, + ".".join(domain.name.split(".")[1:]), + cert.issuer, + cert.owner, + cert.not_after, + cert.validity_range.days, + cert.validity_remaining.days, + ] + ) print(tabulate(rows, headers=headers)) -@manager.option('-ttl', '--ttl', dest='ttl', default=30, help='Days til expiration.') -@manager.option('-d', '--deployment', dest='deployment', choices=['all', 'deployed', 'ready'], default='all', help='Filter by deployment status.') +@manager.option("-ttl", "--ttl", dest="ttl", default=30, help="Days til expiration.") +@manager.option( + "-d", + "--deployment", + dest="deployment", + choices=["all", "deployed", "ready"], + default="all", + help="Filter by deployment status.", +) def expiring(ttl, deployment): """ Returns certificates expiring in the next n days. """ - headers = ['Common Name', 'Owner', 'Issuer', 'Validity End', 'Endpoint'] + headers = ["Common Name", "Owner", "Issuer", "Validity End", "Endpoint"] rows = [] for cert in expiring_certificates(ttl=ttl, deployment=deployment).all(): for endpoint in cert.endpoints: rows.append( - [ - cert.cn, - cert.owner, - cert.issuer, - cert.not_after, - endpoint.dnsname - ] + [cert.cn, cert.owner, cert.issuer, cert.not_after, endpoint.dnsname] ) print(tabulate(rows, headers=headers)) diff --git a/lemur/reporting/service.py b/lemur/reporting/service.py index 348cf2f4..77eb7b3e 100644 --- a/lemur/reporting/service.py +++ b/lemur/reporting/service.py @@ -9,10 +9,10 @@ from lemur.certificates.models import Certificate def filter_by_validity(query, validity=None): - if validity == 'expired': + if validity == "expired": query = query.filter(Certificate.expired == True) # noqa - elif validity == 'valid': + elif validity == "valid": query = query.filter(Certificate.expired == False) # noqa return query @@ -33,10 +33,10 @@ def filter_by_issuer(query, issuer=None): def filter_by_deployment(query, deployment=None): - if deployment == 'deployed': + if deployment == "deployed": query = query.filter(Certificate.endpoints.any()) - elif deployment == 'ready': + elif deployment == "ready": query = query.filter(not_(Certificate.endpoints.any())) return query @@ -55,8 +55,8 @@ def fqdns(**kwargs): :return: """ query = database.session_query(Certificate) - query = filter_by_deployment(query, deployment=kwargs.get('deployed')) - query = filter_by_validity(query, validity=kwargs.get('validity')) + query = filter_by_deployment(query, deployment=kwargs.get("deployed")) + query = filter_by_validity(query, validity=kwargs.get("validity")) return query @@ -65,13 +65,13 @@ def expiring_certificates(**kwargs): Returns an Expiring report. :return: """ - ttl = kwargs.get('ttl', 30) + ttl = kwargs.get("ttl", 30) now = arrow.utcnow() validity_end = now + timedelta(days=ttl) query = database.session_query(Certificate) - query = filter_by_deployment(query, deployment=kwargs.get('deployed')) - query = filter_by_validity(query, validity='valid') + query = filter_by_deployment(query, deployment=kwargs.get("deployed")) + query = filter_by_validity(query, validity="valid") query = filter_by_validity_end(query, validity_end=validity_end) return query diff --git a/lemur/roles/models.py b/lemur/roles/models.py index 85bf1bf1..91b5d58c 100644 --- a/lemur/roles/models.py +++ b/lemur/roles/models.py @@ -14,26 +14,42 @@ from sqlalchemy import Boolean, Column, Integer, String, Text, ForeignKey from lemur.database import db from lemur.utils import Vault -from lemur.models import roles_users, roles_authorities, roles_certificates, \ - pending_cert_role_associations +from lemur.models import ( + roles_users, + roles_authorities, + roles_certificates, + pending_cert_role_associations, +) class Role(db.Model): - __tablename__ = 'roles' + __tablename__ = "roles" id = Column(Integer, primary_key=True) name = Column(String(128), unique=True) username = Column(String(128)) password = Column(Vault) description = Column(Text) - authority_id = Column(Integer, ForeignKey('authorities.id')) - authorities = relationship("Authority", secondary=roles_authorities, passive_deletes=True, backref="role", cascade='all,delete') - user_id = Column(Integer, ForeignKey('users.id')) + authority_id = Column(Integer, ForeignKey("authorities.id")) + authorities = relationship( + "Authority", + secondary=roles_authorities, + passive_deletes=True, + backref="role", + cascade="all,delete", + ) + user_id = Column(Integer, ForeignKey("users.id")) third_party = Column(Boolean) - users = relationship("User", secondary=roles_users, passive_deletes=True, backref="role") - certificates = relationship("Certificate", secondary=roles_certificates, backref="role") - pending_certificates = relationship("PendingCertificate", secondary=pending_cert_role_associations, backref="role") + users = relationship( + "User", secondary=roles_users, passive_deletes=True, backref="role" + ) + certificates = relationship( + "Certificate", secondary=roles_certificates, backref="role" + ) + pending_certificates = relationship( + "PendingCertificate", secondary=pending_cert_role_associations, backref="role" + ) - sensitive_fields = ('password',) + sensitive_fields = ("password",) def __repr__(self): return "Role(name={name})".format(name=self.name) diff --git a/lemur/roles/service.py b/lemur/roles/service.py index bbeef1ce..51597d6e 100644 --- a/lemur/roles/service.py +++ b/lemur/roles/service.py @@ -47,7 +47,9 @@ def set_third_party(role_id, third_party_status=False): return role -def create(name, password=None, description=None, username=None, users=None, third_party=False): +def create( + name, password=None, description=None, username=None, users=None, third_party=False +): """ Create a new role @@ -58,7 +60,13 @@ def create(name, password=None, description=None, username=None, users=None, thi :param password: :return: """ - role = Role(name=name, description=description, username=username, password=password, third_party=third_party) + role = Role( + name=name, + description=description, + username=username, + password=password, + third_party=third_party, + ) if users: role.users = users @@ -83,7 +91,7 @@ def get_by_name(role_name): :param role_name: :return: """ - return database.get(Role, role_name, field='name') + return database.get(Role, role_name, field="name") def delete(role_id): @@ -105,9 +113,9 @@ def render(args): :return: """ query = database.session_query(Role) - filt = args.pop('filter') - user_id = args.pop('user_id', None) - authority_id = args.pop('authority_id', None) + filt = args.pop("filter") + user_id = args.pop("user_id", None) + authority_id = args.pop("authority_id", None) if user_id: query = query.filter(Role.users.any(User.id == user_id)) @@ -116,7 +124,7 @@ def render(args): query = query.filter(Role.authority_id == authority_id) if filt: - terms = filt.split(';') + terms = filt.split(";") query = database.filter(query, Role, terms) return database.sort_and_page(query, Role, args) diff --git a/lemur/roles/views.py b/lemur/roles/views.py index a635fdba..1e12f24b 100644 --- a/lemur/roles/views.py +++ b/lemur/roles/views.py @@ -17,15 +17,20 @@ from lemur.auth.permissions import RoleMemberPermission, admin_permission from lemur.common.utils import paginated_parser from lemur.common.schema import validate_schema -from lemur.roles.schemas import role_input_schema, role_output_schema, roles_output_schema +from lemur.roles.schemas import ( + role_input_schema, + role_output_schema, + roles_output_schema, +) -mod = Blueprint('roles', __name__) +mod = Blueprint("roles", __name__) api = Api(mod) class RolesList(AuthenticatedResource): """ Defines the 'roles' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(RolesList, self).__init__() @@ -79,11 +84,11 @@ class RolesList(AuthenticatedResource): :statuscode 403: unauthenticated """ parser = paginated_parser.copy() - parser.add_argument('owner', type=str, location='args') - parser.add_argument('id', type=str, location='args') + parser.add_argument("owner", type=str, location="args") + parser.add_argument("id", type=str, location="args") args = parser.parse_args() - args['user'] = g.current_user + args["user"] = g.current_user return service.render(args) @admin_permission.require(http_exception=403) @@ -135,8 +140,13 @@ class RolesList(AuthenticatedResource): :statuscode 200: no error :statuscode 403: unauthenticated """ - return service.create(data['name'], data.get('password'), data.get('description'), data.get('username'), - data.get('users')) + return service.create( + data["name"], + data.get("password"), + data.get("description"), + data.get("username"), + data.get("users"), + ) class RoleViewCredentials(AuthenticatedResource): @@ -177,11 +187,18 @@ class RoleViewCredentials(AuthenticatedResource): permission = RoleMemberPermission(role_id) if permission.can(): role = service.get(role_id) - response = make_response(jsonify(username=role.username, password=role.password), 200) - response.headers['cache-control'] = 'private, max-age=0, no-cache, no-store' - response.headers['pragma'] = 'no-cache' + response = make_response( + jsonify(username=role.username, password=role.password), 200 + ) + response.headers["cache-control"] = "private, max-age=0, no-cache, no-store" + response.headers["pragma"] = "no-cache" return response - return dict(message='You are not authorized to view the credentials for this role.'), 403 + return ( + dict( + message="You are not authorized to view the credentials for this role." + ), + 403, + ) class Roles(AuthenticatedResource): @@ -227,7 +244,12 @@ class Roles(AuthenticatedResource): if permission.can(): return service.get(role_id) - return dict(message="You are not allowed to view a role which you are not a member of."), 403 + return ( + dict( + message="You are not allowed to view a role which you are not a member of." + ), + 403, + ) @validate_schema(role_input_schema, role_output_schema) def put(self, role_id, data=None): @@ -269,8 +291,10 @@ class Roles(AuthenticatedResource): """ permission = RoleMemberPermission(role_id) if permission.can(): - return service.update(role_id, data['name'], data.get('description'), data.get('users')) - return dict(message='You are not authorized to modify this role.'), 403 + return service.update( + role_id, data["name"], data.get("description"), data.get("users") + ) + return dict(message="You are not authorized to modify this role."), 403 @admin_permission.require(http_exception=403) def delete(self, role_id): @@ -304,11 +328,12 @@ class Roles(AuthenticatedResource): :statuscode 403: unauthenticated """ service.delete(role_id) - return {'message': 'ok'} + return {"message": "ok"} class UserRolesList(AuthenticatedResource): """ Defines the 'roles' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(UserRolesList, self).__init__() @@ -362,12 +387,13 @@ class UserRolesList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['user_id'] = user_id + args["user_id"] = user_id return service.render(args) class AuthorityRolesList(AuthenticatedResource): """ Defines the 'roles' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(AuthorityRolesList, self).__init__() @@ -421,12 +447,18 @@ class AuthorityRolesList(AuthenticatedResource): """ parser = paginated_parser.copy() args = parser.parse_args() - args['authority_id'] = authority_id + args["authority_id"] = authority_id return service.render(args) -api.add_resource(RolesList, '/roles', endpoint='roles') -api.add_resource(Roles, '/roles/', endpoint='role') -api.add_resource(RoleViewCredentials, '/roles//credentials', endpoint='roleCredentials`') -api.add_resource(AuthorityRolesList, '/authorities//roles', endpoint='authorityRoles') -api.add_resource(UserRolesList, '/users//roles', endpoint='userRoles') +api.add_resource(RolesList, "/roles", endpoint="roles") +api.add_resource(Roles, "/roles/", endpoint="role") +api.add_resource( + RoleViewCredentials, "/roles//credentials", endpoint="roleCredentials`" +) +api.add_resource( + AuthorityRolesList, + "/authorities//roles", + endpoint="authorityRoles", +) +api.add_resource(UserRolesList, "/users//roles", endpoint="userRoles") diff --git a/lemur/schemas.py b/lemur/schemas.py index ffdfe66f..e7b0fd64 100644 --- a/lemur/schemas.py +++ b/lemur/schemas.py @@ -14,7 +14,12 @@ from marshmallow.exceptions import ValidationError from lemur.common import validators from lemur.common.schema import LemurSchema, LemurInputSchema, LemurOutputSchema -from lemur.common.fields import KeyUsageExtension, ExtendedKeyUsageExtension, BasicConstraintsExtension, SubjectAlternativeNameExtension +from lemur.common.fields import ( + KeyUsageExtension, + ExtendedKeyUsageExtension, + BasicConstraintsExtension, + SubjectAlternativeNameExtension, +) from lemur.plugins import plugins from lemur.plugins.utils import get_plugin_option @@ -34,40 +39,42 @@ def validate_options(options): :param options: :return: """ - interval = get_plugin_option('interval', options) - unit = get_plugin_option('unit', options) + interval = get_plugin_option("interval", options) + unit = get_plugin_option("unit", options) if not interval and not unit: return - if unit == 'month': + if unit == "month": interval *= 30 - elif unit == 'week': + elif unit == "week": interval *= 7 if interval > 90: - raise ValidationError('Notification cannot be more than 90 days into the future.') + raise ValidationError( + "Notification cannot be more than 90 days into the future." + ) def get_object_attribute(data, many=False): if many: - ids = [d.get('id') for d in data] - names = [d.get('name') for d in data] + ids = [d.get("id") for d in data] + names = [d.get("name") for d in data] if None in ids: if None in names: - raise ValidationError('Associated object require a name or id.') + raise ValidationError("Associated object require a name or id.") else: - return 'name' - return 'id' + return "name" + return "id" else: - if data.get('id'): - return 'id' - elif data.get('name'): - return 'name' + if data.get("id"): + return "id" + elif data.get("name"): + return "name" else: - raise ValidationError('Associated object require a name or id.') + raise ValidationError("Associated object require a name or id.") def fetch_objects(model, data, many=False): @@ -80,10 +87,11 @@ def fetch_objects(model, data, many=False): diff = set(values).symmetric_difference(set(found)) if diff: - raise ValidationError('Unable to locate {model} with {attr} {diff}'.format( - model=model, - attr=attr, - diff=",".join(list(diff)))) + raise ValidationError( + "Unable to locate {model} with {attr} {diff}".format( + model=model, attr=attr, diff=",".join(list(diff)) + ) + ) return items @@ -91,10 +99,11 @@ def fetch_objects(model, data, many=False): try: return model.query.filter(getattr(model, attr) == data[attr]).one() except NoResultFound: - raise ValidationError('Unable to find {model} with {attr}: {data}'.format( - model=model, - attr=attr, - data=data[attr])) + raise ValidationError( + "Unable to find {model} with {attr}: {data}".format( + model=model, attr=attr, data=data[attr] + ) + ) class AssociatedAuthoritySchema(LemurInputSchema): @@ -178,17 +187,19 @@ class PluginInputSchema(LemurInputSchema): @post_load def get_object(self, data, many=False): try: - data['plugin_object'] = plugins.get(data['slug']) + data["plugin_object"] = plugins.get(data["slug"]) # parse any sub-plugins - for option in data.get('plugin_options', []): - if 'plugin' in option.get('type', []): - sub_data, errors = PluginInputSchema().load(option['value']) - option['value'] = sub_data + for option in data.get("plugin_options", []): + if "plugin" in option.get("type", []): + sub_data, errors = PluginInputSchema().load(option["value"]) + option["value"] = sub_data return data except Exception as e: - raise ValidationError('Unable to find plugin. Slug: {0} Reason: {1}'.format(data['slug'], e)) + raise ValidationError( + "Unable to find plugin. Slug: {0} Reason: {1}".format(data["slug"], e) + ) class PluginOutputSchema(LemurOutputSchema): @@ -196,7 +207,7 @@ class PluginOutputSchema(LemurOutputSchema): label = fields.String() description = fields.String() active = fields.Boolean() - options = fields.List(fields.Dict(), dump_to='pluginOptions') + options = fields.List(fields.Dict(), dump_to="pluginOptions") slug = fields.String() title = fields.String() @@ -227,7 +238,7 @@ class CertificateInfoAccessSchema(BaseExtensionSchema): @post_dump def handle_keys(self, data): - return {'includeAIA': data['include_aia']} + return {"includeAIA": data["include_aia"]} class CRLDistributionPointsSchema(BaseExtensionSchema): @@ -235,7 +246,7 @@ class CRLDistributionPointsSchema(BaseExtensionSchema): @post_dump def handle_keys(self, data): - return {'includeCRLDP': data['include_crl_dp']} + return {"includeCRLDP": data["include_crl_dp"]} class SubjectKeyIdentifierSchema(BaseExtensionSchema): @@ -243,7 +254,7 @@ class SubjectKeyIdentifierSchema(BaseExtensionSchema): @post_dump def handle_keys(self, data): - return {'includeSKI': data['include_ski']} + return {"includeSKI": data["include_ski"]} class CustomOIDSchema(BaseExtensionSchema): @@ -258,14 +269,18 @@ class NamesSchema(BaseExtensionSchema): class ExtensionSchema(BaseExtensionSchema): - basic_constraints = BasicConstraintsExtension() # some devices balk on default basic constraints + basic_constraints = ( + BasicConstraintsExtension() + ) # some devices balk on default basic constraints key_usage = KeyUsageExtension() extended_key_usage = ExtendedKeyUsageExtension() subject_key_identifier = fields.Nested(SubjectKeyIdentifierSchema) sub_alt_names = fields.Nested(NamesSchema) authority_key_identifier = fields.Nested(AuthorityKeyIdentifierSchema) certificate_info_access = fields.Nested(CertificateInfoAccessSchema) - crl_distribution_points = fields.Nested(CRLDistributionPointsSchema, dump_to='cRL_distribution_points') + crl_distribution_points = fields.Nested( + CRLDistributionPointsSchema, dump_to="cRL_distribution_points" + ) # FIXME: Convert custom OIDs to a custom field in fields.py like other Extensions # FIXME: Remove support in UI for Critical custom extensions https://github.com/Netflix/lemur/issues/665 custom = fields.List(fields.Nested(CustomOIDSchema)) diff --git a/lemur/sources/cli.py b/lemur/sources/cli.py index 0ab8c9f8..c41a1cf7 100644 --- a/lemur/sources/cli.py +++ b/lemur/sources/cli.py @@ -35,24 +35,32 @@ def validate_sources(source_strings): table.append([source.label, source.active, source.description]) print("No source specified choose from below:") - print(tabulate(table, headers=['Label', 'Active', 'Description'])) + print(tabulate(table, headers=["Label", "Active", "Description"])) sys.exit(1) - if 'all' in source_strings: + if "all" in source_strings: sources = source_service.get_all() else: for source_str in source_strings: source = source_service.get_by_label(source_str) if not source: - print("Unable to find specified source with label: {0}".format(source_str)) + print( + "Unable to find specified source with label: {0}".format(source_str) + ) sys.exit(1) sources.append(source) return sources -@manager.option('-s', '--sources', dest='source_strings', action='append', help='Sources to operate on.') +@manager.option( + "-s", + "--sources", + dest="source_strings", + action="append", + help="Sources to operate on.", +) def sync(source_strings): sources = validate_sources(source_strings) for source in sources: @@ -61,26 +69,23 @@ def sync(source_strings): start_time = time.time() print("[+] Staring to sync source: {label}!\n".format(label=source.label)) - user = user_service.get_by_username('lemur') + user = user_service.get_by_username("lemur") try: data = source_service.sync(source, user) print( "[+] Certificates: New: {new} Updated: {updated}".format( - new=data['certificates'][0], - updated=data['certificates'][1] + new=data["certificates"][0], updated=data["certificates"][1] ) ) print( "[+] Endpoints: New: {new} Updated: {updated}".format( - new=data['endpoints'][0], - updated=data['endpoints'][1] + new=data["endpoints"][0], updated=data["endpoints"][1] ) ) print( "[+] Finished syncing source: {label}. Run Time: {time}".format( - label=source.label, - time=(time.time() - start_time) + label=source.label, time=(time.time() - start_time) ) ) status = SUCCESS_METRIC_STATUS @@ -88,27 +93,50 @@ def sync(source_strings): except Exception as e: current_app.logger.exception(e) - print( - "[X] Failed syncing source {label}!\n".format(label=source.label) - ) + print("[X] Failed syncing source {label}!\n".format(label=source.label)) sentry.captureException() - metrics.send('source_sync_fail', 'counter', 1, metric_tags={'source': source.label, 'status': status}) + metrics.send( + "source_sync_fail", + "counter", + 1, + metric_tags={"source": source.label, "status": status}, + ) - metrics.send('source_sync', 'counter', 1, metric_tags={'source': source.label, 'status': status}) + metrics.send( + "source_sync", + "counter", + 1, + metric_tags={"source": source.label, "status": status}, + ) -@manager.option('-s', '--sources', dest='source_strings', action='append', help='Sources to operate on.') -@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') +@manager.option( + "-s", + "--sources", + dest="source_strings", + action="append", + help="Sources to operate on.", +) +@manager.option( + "-c", + "--commit", + dest="commit", + action="store_true", + default=False, + help="Persist changes.", +) def clean(source_strings, commit): sources = validate_sources(source_strings) for source in sources: s = plugins.get(source.plugin_name) - if not hasattr(s, 'clean'): - print("Cannot clean source: {0}, source plugin does not implement 'clean()'".format( - source.label - )) + if not hasattr(s, "clean"): + print( + "Cannot clean source: {0}, source plugin does not implement 'clean()'".format( + source.label + ) + ) continue start_time = time.time() @@ -128,19 +156,23 @@ def clean(source_strings, commit): current_app.logger.exception(e) sentry.captureException() - metrics.send('clean', 'counter', 1, metric_tags={'source': source.label, 'status': status}) + metrics.send( + "clean", + "counter", + 1, + metric_tags={"source": source.label, "status": status}, + ) - current_app.logger.warning("Removed {0} from source {1} during cleaning".format( - certificate.name, - source.label - )) + current_app.logger.warning( + "Removed {0} from source {1} during cleaning".format( + certificate.name, source.label + ) + ) cleaned += 1 print( "[+] Finished cleaning source: {label}. Removed {cleaned} certificates from source. Run Time: {time}\n".format( - label=source.label, - time=(time.time() - start_time), - cleaned=cleaned + label=source.label, time=(time.time() - start_time), cleaned=cleaned ) ) diff --git a/lemur/sources/models.py b/lemur/sources/models.py index 071688d1..78dbb213 100644 --- a/lemur/sources/models.py +++ b/lemur/sources/models.py @@ -15,7 +15,7 @@ from sqlalchemy_utils import ArrowType class Source(db.Model): - __tablename__ = 'sources' + __tablename__ = "sources" id = Column(Integer, primary_key=True) label = Column(String(32), unique=True) options = Column(JSONType) diff --git a/lemur/sources/schemas.py b/lemur/sources/schemas.py index 028fdb32..5531293f 100644 --- a/lemur/sources/schemas.py +++ b/lemur/sources/schemas.py @@ -30,7 +30,7 @@ class SourceOutputSchema(LemurOutputSchema): @post_dump def fill_object(self, data): if data: - data['plugin']['pluginOptions'] = data['options'] + data["plugin"]["pluginOptions"] = data["options"] return data diff --git a/lemur/sources/service.py b/lemur/sources/service.py index a4d373ab..ec988623 100644 --- a/lemur/sources/service.py +++ b/lemur/sources/service.py @@ -29,9 +29,11 @@ def certificate_create(certificate, source): data, errors = CertificateUploadInputSchema().load(certificate) if errors: - raise Exception("Unable to import certificate: {reasons}".format(reasons=errors)) + raise Exception( + "Unable to import certificate: {reasons}".format(reasons=errors) + ) - data['creator'] = certificate['creator'] + data["creator"] = certificate["creator"] cert = certificate_service.import_certificate(**data) cert.description = "This certificate was automatically discovered by Lemur" @@ -70,33 +72,44 @@ def sync_endpoints(source): try: endpoints = s.get_endpoints(source.options) except NotImplementedError: - current_app.logger.warning("Unable to sync endpoints for source {0} plugin has not implemented 'get_endpoints'".format(source.label)) + current_app.logger.warning( + "Unable to sync endpoints for source {0} plugin has not implemented 'get_endpoints'".format( + source.label + ) + ) return new, updated for endpoint in endpoints: - exists = endpoint_service.get_by_dnsname_and_port(endpoint['dnsname'], endpoint['port']) + exists = endpoint_service.get_by_dnsname_and_port( + endpoint["dnsname"], endpoint["port"] + ) - certificate_name = endpoint.pop('certificate_name') + certificate_name = endpoint.pop("certificate_name") - endpoint['certificate'] = certificate_service.get_by_name(certificate_name) + endpoint["certificate"] = certificate_service.get_by_name(certificate_name) - if not endpoint['certificate']: + if not endpoint["certificate"]: current_app.logger.error( - "Certificate Not Found. Name: {0} Endpoint: {1}".format(certificate_name, endpoint['name'])) + "Certificate Not Found. Name: {0} Endpoint: {1}".format( + certificate_name, endpoint["name"] + ) + ) continue - policy = endpoint.pop('policy') + policy = endpoint.pop("policy") policy_ciphers = [] - for nc in policy['ciphers']: + for nc in policy["ciphers"]: policy_ciphers.append(endpoint_service.get_or_create_cipher(name=nc)) - policy['ciphers'] = policy_ciphers - endpoint['policy'] = endpoint_service.get_or_create_policy(**policy) - endpoint['source'] = source + policy["ciphers"] = policy_ciphers + endpoint["policy"] = endpoint_service.get_or_create_policy(**policy) + endpoint["source"] = source if not exists: - current_app.logger.debug("Endpoint Created: Name: {name}".format(name=endpoint['name'])) + current_app.logger.debug( + "Endpoint Created: Name: {name}".format(name=endpoint["name"]) + ) endpoint_service.create(**endpoint) new += 1 @@ -119,27 +132,27 @@ def sync_certificates(source, user): for certificate in certificates: exists = False - if certificate.get('search', None): - conditions = certificate.pop('search') + if certificate.get("search", None): + conditions = certificate.pop("search") exists = certificate_service.get_by_attributes(conditions) - if not exists and certificate.get('name'): - result = certificate_service.get_by_name(certificate['name']) + if not exists and certificate.get("name"): + result = certificate_service.get_by_name(certificate["name"]) if result: exists = [result] - if not exists and certificate.get('serial'): - exists = certificate_service.get_by_serial(certificate['serial']) + if not exists and certificate.get("serial"): + exists = certificate_service.get_by_serial(certificate["serial"]) if not exists: - cert = parse_certificate(certificate['body']) + cert = parse_certificate(certificate["body"]) matching_serials = certificate_service.get_by_serial(serial(cert)) exists = find_matching_certificates_by_hash(cert, matching_serials) - if not certificate.get('owner'): - certificate['owner'] = user.email + if not certificate.get("owner"): + certificate["owner"] = user.email - certificate['creator'] = user + certificate["creator"] = user exists = [x for x in exists if x] if not exists: @@ -148,10 +161,10 @@ def sync_certificates(source, user): else: for e in exists: - if certificate.get('external_id'): - e.external_id = certificate['external_id'] - if certificate.get('authority_id'): - e.authority_id = certificate['authority_id'] + if certificate.get("external_id"): + e.external_id = certificate["external_id"] + if certificate.get("authority_id"): + e.authority_id = certificate["authority_id"] certificate_update(e, source) updated += 1 @@ -165,7 +178,10 @@ def sync(source, user): source.last_run = arrow.utcnow() database.update(source) - return {'endpoints': (new_endpoints, updated_endpoints), 'certificates': (new_certs, updated_certs)} + return { + "endpoints": (new_endpoints, updated_endpoints), + "certificates": (new_certs, updated_certs), + } def create(label, plugin_name, options, description=None): @@ -179,7 +195,9 @@ def create(label, plugin_name, options, description=None): :rtype : Source :return: New source """ - source = Source(label=label, options=options, plugin_name=plugin_name, description=description) + source = Source( + label=label, options=options, plugin_name=plugin_name, description=description + ) return database.create(source) @@ -230,7 +248,7 @@ def get_by_label(label): :param label: :return: """ - return database.get(Source, label, field='label') + return database.get(Source, label, field="label") def get_all(): @@ -244,8 +262,8 @@ def get_all(): def render(args): - filt = args.pop('filter') - certificate_id = args.pop('certificate_id', None) + filt = args.pop("filter") + certificate_id = args.pop("certificate_id", None) if certificate_id: query = database.session_query(Source).join(Certificate, Source.certificate) @@ -254,7 +272,7 @@ def render(args): query = database.session_query(Source) if filt: - terms = filt.split(';') + terms = filt.split(";") query = database.filter(query, Source, terms) return database.sort_and_page(query, Source, args) @@ -272,21 +290,27 @@ def add_aws_destination_to_sources(dst): src_accounts = set() sources = get_all() for src in sources: - src_accounts.add(get_plugin_option('accountNumber', src.options)) + src_accounts.add(get_plugin_option("accountNumber", src.options)) # check destination_plugin = plugins.get(dst.plugin_name) - account_number = get_plugin_option('accountNumber', dst.options) - if account_number is not None and \ - destination_plugin.sync_as_source is not None and \ - destination_plugin.sync_as_source and \ - (account_number not in src_accounts): - src_options = copy.deepcopy(plugins.get(destination_plugin.sync_as_source_name).options) - set_plugin_option('accountNumber', account_number, src_options) - create(label=dst.label, - plugin_name=destination_plugin.sync_as_source_name, - options=src_options, - description=dst.description) + account_number = get_plugin_option("accountNumber", dst.options) + if ( + account_number is not None + and destination_plugin.sync_as_source is not None + and destination_plugin.sync_as_source + and (account_number not in src_accounts) + ): + src_options = copy.deepcopy( + plugins.get(destination_plugin.sync_as_source_name).options + ) + set_plugin_option("accountNumber", account_number, src_options) + create( + label=dst.label, + plugin_name=destination_plugin.sync_as_source_name, + options=src_options, + description=dst.description, + ) return True return False diff --git a/lemur/sources/views.py b/lemur/sources/views.py index abf68109..b74c4d80 100644 --- a/lemur/sources/views.py +++ b/lemur/sources/views.py @@ -11,19 +11,24 @@ from flask_restful import Api, reqparse from lemur.sources import service from lemur.common.schema import validate_schema -from lemur.sources.schemas import source_input_schema, source_output_schema, sources_output_schema +from lemur.sources.schemas import ( + source_input_schema, + source_output_schema, + sources_output_schema, +) from lemur.auth.service import AuthenticatedResource from lemur.auth.permissions import admin_permission from lemur.common.utils import paginated_parser -mod = Blueprint('sources', __name__) +mod = Blueprint("sources", __name__) api = Api(mod) class SourcesList(AuthenticatedResource): """ Defines the 'sources' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(SourcesList, self).__init__() @@ -151,7 +156,12 @@ class SourcesList(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.create(data['label'], data['plugin']['slug'], data['plugin']['plugin_options'], data['description']) + return service.create( + data["label"], + data["plugin"]["slug"], + data["plugin"]["plugin_options"], + data["description"], + ) class Sources(AuthenticatedResource): @@ -271,16 +281,22 @@ class Sources(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.update(source_id, data['label'], data['plugin']['plugin_options'], data['description']) + return service.update( + source_id, + data["label"], + data["plugin"]["plugin_options"], + data["description"], + ) @admin_permission.require(http_exception=403) def delete(self, source_id): service.delete(source_id) - return {'result': True} + return {"result": True} class CertificateSources(AuthenticatedResource): """ Defines the 'certificate/', endpoint='account') -api.add_resource(CertificateSources, '/certificates//sources', - endpoint='certificateSources') +api.add_resource(SourcesList, "/sources", endpoint="sources") +api.add_resource(Sources, "/sources/", endpoint="account") +api.add_resource( + CertificateSources, + "/certificates//sources", + endpoint="certificateSources", +) diff --git a/lemur/tests/conf.py b/lemur/tests/conf.py index 525200cf..6d0d6967 100644 --- a/lemur/tests/conf.py +++ b/lemur/tests/conf.py @@ -15,49 +15,51 @@ debug = False TESTING = True # this is the secret key used by flask session management -SECRET_KEY = 'I/dVhOZNSMZMqrFJa5tWli6VQccOGudKerq3eWPMSzQNmHHVhMAQfQ==' +SECRET_KEY = "I/dVhOZNSMZMqrFJa5tWli6VQccOGudKerq3eWPMSzQNmHHVhMAQfQ==" # You should consider storing these separately from your config -LEMUR_TOKEN_SECRET = 'test' -LEMUR_ENCRYPTION_KEYS = 'o61sBLNBSGtAckngtNrfVNd8xy8Hp9LBGDstTbMbqCY=' +LEMUR_TOKEN_SECRET = "test" +LEMUR_ENCRYPTION_KEYS = "o61sBLNBSGtAckngtNrfVNd8xy8Hp9LBGDstTbMbqCY=" # List of domain regular expressions that non-admin users can issue LEMUR_WHITELISTED_DOMAINS = [ - '^[a-zA-Z0-9-]+\.example\.com$', - '^[a-zA-Z0-9-]+\.example\.org$', - '^example\d+\.long\.com$', + "^[a-zA-Z0-9-]+\.example\.com$", + "^[a-zA-Z0-9-]+\.example\.org$", + "^example\d+\.long\.com$", ] # Mail Server # Lemur currently only supports SES for sending email, this address # needs to be verified -LEMUR_EMAIL = '' -LEMUR_SECURITY_TEAM_EMAIL = ['security@example.com'] +LEMUR_EMAIL = "" +LEMUR_SECURITY_TEAM_EMAIL = ["security@example.com"] -LEMUR_HOSTNAME = 'lemur.example.com' +LEMUR_HOSTNAME = "lemur.example.com" # Logging LOG_LEVEL = "DEBUG" LOG_FILE = "lemur.log" -LEMUR_DEFAULT_COUNTRY = 'US' -LEMUR_DEFAULT_STATE = 'California' -LEMUR_DEFAULT_LOCATION = 'Los Gatos' -LEMUR_DEFAULT_ORGANIZATION = 'Example, Inc.' -LEMUR_DEFAULT_ORGANIZATIONAL_UNIT = 'Example' +LEMUR_DEFAULT_COUNTRY = "US" +LEMUR_DEFAULT_STATE = "California" +LEMUR_DEFAULT_LOCATION = "Los Gatos" +LEMUR_DEFAULT_ORGANIZATION = "Example, Inc." +LEMUR_DEFAULT_ORGANIZATIONAL_UNIT = "Example" LEMUR_ALLOW_WEEKEND_EXPIRATION = False # Database # modify this if you are not using a local database -SQLALCHEMY_DATABASE_URI = os.getenv('SQLALCHEMY_DATABASE_URI', 'postgresql://lemur:lemur@localhost:5432/lemur') +SQLALCHEMY_DATABASE_URI = os.getenv( + "SQLALCHEMY_DATABASE_URI", "postgresql://lemur:lemur@localhost:5432/lemur" +) SQLALCHEMY_TRACK_MODIFICATIONS = False # AWS -LEMUR_INSTANCE_PROFILE = 'Lemur' +LEMUR_INSTANCE_PROFILE = "Lemur" # Issuers @@ -72,21 +74,21 @@ LEMUR_INSTANCE_PROFILE = 'Lemur' # CLOUDCA_DEFAULT_VALIDITY = 2 -DIGICERT_URL = 'mock://www.digicert.com' -DIGICERT_ORDER_TYPE = 'ssl_plus' -DIGICERT_API_KEY = 'api-key' +DIGICERT_URL = "mock://www.digicert.com" +DIGICERT_ORDER_TYPE = "ssl_plus" +DIGICERT_API_KEY = "api-key" DIGICERT_ORG_ID = 111111 DIGICERT_ROOT = "ROOT" -VERISIGN_URL = 'http://example.com' -VERISIGN_PEM_PATH = '~/' -VERISIGN_FIRST_NAME = 'Jim' -VERISIGN_LAST_NAME = 'Bob' -VERSIGN_EMAIL = 'jim@example.com' +VERISIGN_URL = "http://example.com" +VERISIGN_PEM_PATH = "~/" +VERISIGN_FIRST_NAME = "Jim" +VERISIGN_LAST_NAME = "Bob" +VERSIGN_EMAIL = "jim@example.com" -ACME_AWS_ACCOUNT_NUMBER = '11111111111' +ACME_AWS_ACCOUNT_NUMBER = "11111111111" -ACME_PRIVATE_KEY = ''' +ACME_PRIVATE_KEY = """ -----BEGIN RSA PRIVATE KEY----- MIIJJwIBAAKCAgEA0+jySNCc1i73LwDZEuIdSkZgRYQ4ZQVIioVf38RUhDElxy51 4gdWZwp8/TDpQ8cVXMj6QhdRpTVLluOz71hdvBAjxXTISRCRlItzizTgBD9CLXRh @@ -138,7 +140,7 @@ cRe4df5/EbRiUOyx/ZBepttB1meTnsH6cGPN0JnmTMQHQvanL3jjtjrC13408ONK omsEEjDt4qVqGvSyy+V/1EhqGPzm9ri3zapnorf69rscuXYYsMBZ8M6AtSio4ldB LjCRNS1lR6/mV8AqUNR9Kn2NLQyJ76yDoEVLulKZqGUsC9STN4oGJLUeFw== -----END RSA PRIVATE KEY----- -''' +""" ACME_ROOT = """ -----BEGIN CERTIFICATE----- @@ -174,17 +176,17 @@ PB0t6JzUA81mSqM3kxl5e+IZwhYAyO0OTg3/fs8HqGTNKd9BqoUwSRBzp06JMg5b rUCGwbCUDI0mxadJ3Bz4WxR6fyNpBK2yAinWEsikxqEt -----END CERTIFICATE----- """ -ACME_URL = 'https://acme-v01.api.letsencrypt.org' -ACME_EMAIL = 'jim@example.com' -ACME_TEL = '4088675309' -ACME_DIRECTORY_URL = 'https://acme-v01.api.letsencrypt.org' +ACME_URL = "https://acme-v01.api.letsencrypt.org" +ACME_EMAIL = "jim@example.com" +ACME_TEL = "4088675309" +ACME_DIRECTORY_URL = "https://acme-v01.api.letsencrypt.org" ACME_DISABLE_AUTORESOLVE = True LDAP_AUTH = True -LDAP_BIND_URI = 'ldap://localhost' -LDAP_BASE_DN = 'dc=example,dc=com' -LDAP_EMAIL_DOMAIN = 'example.com' -LDAP_REQUIRED_GROUP = 'Lemur Access' -LDAP_DEFAULT_ROLE = 'role1' +LDAP_BIND_URI = "ldap://localhost" +LDAP_BASE_DN = "dc=example,dc=com" +LDAP_EMAIL_DOMAIN = "example.com" +LDAP_REQUIRED_GROUP = "Lemur Access" +LDAP_DEFAULT_ROLE = "role1" ALLOW_CERT_DELETION = True diff --git a/lemur/tests/conftest.py b/lemur/tests/conftest.py index 809b9a6a..2efd65d9 100644 --- a/lemur/tests/conftest.py +++ b/lemur/tests/conftest.py @@ -13,16 +13,34 @@ from lemur import create_app from lemur.common.utils import parse_private_key from lemur.database import db as _db from lemur.auth.service import create_token -from lemur.tests.vectors import SAN_CERT_KEY, INTERMEDIATE_KEY, ROOTCA_CERT_STR, ROOTCA_KEY +from lemur.tests.vectors import ( + SAN_CERT_KEY, + INTERMEDIATE_KEY, + ROOTCA_CERT_STR, + ROOTCA_KEY, +) -from .factories import ApiKeyFactory, AuthorityFactory, NotificationFactory, DestinationFactory, \ - CertificateFactory, UserFactory, RoleFactory, SourceFactory, EndpointFactory, \ - RotationPolicyFactory, PendingCertificateFactory, AsyncAuthorityFactory, InvalidCertificateFactory, \ - CryptoAuthorityFactory, CACertificateFactory +from .factories import ( + ApiKeyFactory, + AuthorityFactory, + NotificationFactory, + DestinationFactory, + CertificateFactory, + UserFactory, + RoleFactory, + SourceFactory, + EndpointFactory, + RotationPolicyFactory, + PendingCertificateFactory, + AsyncAuthorityFactory, + InvalidCertificateFactory, + CryptoAuthorityFactory, + CACertificateFactory, +) def pytest_runtest_setup(item): - if 'slow' in item.keywords and not item.config.getoption("--runslow"): + if "slow" in item.keywords and not item.config.getoption("--runslow"): pytest.skip("need --runslow option to run") if "incremental" in item.keywords: @@ -44,7 +62,9 @@ def app(request): Creates a new Flask application for a test duration. Uses application factory `create_app`. """ - _app = create_app(config_path=os.path.dirname(os.path.realpath(__file__)) + '/conf.py') + _app = create_app( + config_path=os.path.dirname(os.path.realpath(__file__)) + "/conf.py" + ) ctx = _app.app_context() ctx.push() @@ -56,15 +76,15 @@ def app(request): @pytest.yield_fixture(scope="session") def db(app, request): _db.drop_all() - _db.engine.execute(text('CREATE EXTENSION IF NOT EXISTS pg_trgm')) + _db.engine.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm")) _db.create_all() _db.app = app UserFactory() - r = RoleFactory(name='admin') + r = RoleFactory(name="admin") u = UserFactory(roles=[r]) - rp = RotationPolicyFactory(name='default') + rp = RotationPolicyFactory(name="default") ApiKeyFactory(user=u) _db.session.commit() @@ -159,8 +179,8 @@ def user(session): u = UserFactory() session.commit() user_token = create_token(u) - token = {'Authorization': 'Basic ' + user_token} - return {'user': u, 'token': token} + token = {"Authorization": "Basic " + user_token} + return {"user": u, "token": token} @pytest.fixture @@ -203,18 +223,19 @@ def invalid_certificate(session): @pytest.fixture def admin_user(session): u = UserFactory() - admin_role = RoleFactory(name='admin') + admin_role = RoleFactory(name="admin") u.roles.append(admin_role) session.commit() user_token = create_token(u) - token = {'Authorization': 'Basic ' + user_token} - return {'user': u, 'token': token} + token = {"Authorization": "Basic " + user_token} + return {"user": u, "token": token} @pytest.fixture def async_issuer_plugin(): from lemur.plugins.base import register from .plugins.issuer_plugin import TestAsyncIssuerPlugin + register(TestAsyncIssuerPlugin) return TestAsyncIssuerPlugin @@ -223,6 +244,7 @@ def async_issuer_plugin(): def issuer_plugin(): from lemur.plugins.base import register from .plugins.issuer_plugin import TestIssuerPlugin + register(TestIssuerPlugin) return TestIssuerPlugin @@ -231,6 +253,7 @@ def issuer_plugin(): def notification_plugin(): from lemur.plugins.base import register from .plugins.notification_plugin import TestNotificationPlugin + register(TestNotificationPlugin) return TestNotificationPlugin @@ -239,6 +262,7 @@ def notification_plugin(): def destination_plugin(): from lemur.plugins.base import register from .plugins.destination_plugin import TestDestinationPlugin + register(TestDestinationPlugin) return TestDestinationPlugin @@ -247,6 +271,7 @@ def destination_plugin(): def source_plugin(): from lemur.plugins.base import register from .plugins.source_plugin import TestSourcePlugin + register(TestSourcePlugin) return TestSourcePlugin @@ -277,13 +302,19 @@ def issuer_private_key(): @pytest.fixture def cert_builder(private_key): - return (x509.CertificateBuilder() - .subject_name(x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, 'foo.com')])) - .issuer_name(x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, 'foo.com')])) - .serial_number(1) - .public_key(private_key.public_key()) - .not_valid_before(datetime.datetime(2017, 12, 22)) - .not_valid_after(datetime.datetime(2040, 1, 1))) + return ( + x509.CertificateBuilder() + .subject_name( + x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, "foo.com")]) + ) + .issuer_name( + x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, "foo.com")]) + ) + .serial_number(1) + .public_key(private_key.public_key()) + .not_valid_before(datetime.datetime(2017, 12, 22)) + .not_valid_after(datetime.datetime(2040, 1, 1)) + ) @pytest.fixture @@ -292,9 +323,9 @@ def selfsigned_cert(cert_builder, private_key): return cert_builder.sign(private_key, hashes.SHA256(), default_backend()) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def aws_credentials(): - os.environ['AWS_ACCESS_KEY_ID'] = 'testing' - os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing' - os.environ['AWS_SECURITY_TOKEN'] = 'testing' - os.environ['AWS_SESSION_TOKEN'] = 'testing' + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" diff --git a/lemur/tests/factories.py b/lemur/tests/factories.py index de78f8a3..fea4c59a 100644 --- a/lemur/tests/factories.py +++ b/lemur/tests/factories.py @@ -1,4 +1,3 @@ - from datetime import date from factory import Sequence, post_generation, SubFactory @@ -19,8 +18,16 @@ from lemur.endpoints.models import Policy, Endpoint from lemur.policies.models import RotationPolicy from lemur.api_keys.models import ApiKey -from .vectors import SAN_CERT_STR, SAN_CERT_KEY, CSR_STR, INTERMEDIATE_CERT_STR, ROOTCA_CERT_STR, INTERMEDIATE_KEY, \ - WILDCARD_CERT_KEY, INVALID_CERT_STR +from .vectors import ( + SAN_CERT_STR, + SAN_CERT_KEY, + CSR_STR, + INTERMEDIATE_CERT_STR, + ROOTCA_CERT_STR, + INTERMEDIATE_KEY, + WILDCARD_CERT_KEY, + INVALID_CERT_STR, +) class BaseFactory(SQLAlchemyModelFactory): @@ -28,28 +35,32 @@ class BaseFactory(SQLAlchemyModelFactory): class Meta: """Factory configuration.""" + abstract = True sqlalchemy_session = db.session class RotationPolicyFactory(BaseFactory): """Rotation Factory.""" - name = Sequence(lambda n: 'policy{0}'.format(n)) + + name = Sequence(lambda n: "policy{0}".format(n)) days = 30 class Meta: """Factory configuration.""" + model = RotationPolicy class CertificateFactory(BaseFactory): """Certificate factory.""" - name = Sequence(lambda n: 'certificate{0}'.format(n)) + + name = Sequence(lambda n: "certificate{0}".format(n)) chain = INTERMEDIATE_CERT_STR body = SAN_CERT_STR private_key = SAN_CERT_KEY - owner = 'joe@example.com' - status = FuzzyChoice(['valid', 'revoked', 'unknown']) + owner = "joe@example.com" + status = FuzzyChoice(["valid", "revoked", "unknown"]) deleted = False description = FuzzyText(length=128) active = True @@ -58,6 +69,7 @@ class CertificateFactory(BaseFactory): class Meta: """Factory Configuration.""" + model = Certificate @post_generation @@ -139,20 +151,22 @@ class CACertificateFactory(CertificateFactory): class InvalidCertificateFactory(CertificateFactory): body = INVALID_CERT_STR - private_key = '' - chain = '' + private_key = "" + chain = "" class AuthorityFactory(BaseFactory): """Authority factory.""" - name = Sequence(lambda n: 'authority{0}'.format(n)) - owner = 'joe@example.com' - plugin = {'slug': 'test-issuer'} + + name = Sequence(lambda n: "authority{0}".format(n)) + owner = "joe@example.com" + plugin = {"slug": "test-issuer"} description = FuzzyText(length=128) authority_certificate = SubFactory(CACertificateFactory) class Meta: """Factory configuration.""" + model = Authority @post_generation @@ -167,54 +181,64 @@ class AuthorityFactory(BaseFactory): class AsyncAuthorityFactory(AuthorityFactory): """Async Authority factory.""" - name = Sequence(lambda n: 'authority{0}'.format(n)) - owner = 'joe@example.com' - plugin = {'slug': 'test-issuer-async'} + + name = Sequence(lambda n: "authority{0}".format(n)) + owner = "joe@example.com" + plugin = {"slug": "test-issuer-async"} description = FuzzyText(length=128) authority_certificate = SubFactory(CertificateFactory) class CryptoAuthorityFactory(AuthorityFactory): """Authority factory based on 'cryptography' plugin.""" - plugin = {'slug': 'cryptography-issuer'} + + plugin = {"slug": "cryptography-issuer"} class DestinationFactory(BaseFactory): """Destination factory.""" - plugin_name = 'test-destination' - label = Sequence(lambda n: 'destination{0}'.format(n)) + + plugin_name = "test-destination" + label = Sequence(lambda n: "destination{0}".format(n)) class Meta: """Factory Configuration.""" + model = Destination class SourceFactory(BaseFactory): """Source factory.""" - plugin_name = 'test-source' - label = Sequence(lambda n: 'source{0}'.format(n)) + + plugin_name = "test-source" + label = Sequence(lambda n: "source{0}".format(n)) class Meta: """Factory Configuration.""" + model = Source class NotificationFactory(BaseFactory): """Notification factory.""" - plugin_name = 'test-notification' - label = Sequence(lambda n: 'notification{0}'.format(n)) + + plugin_name = "test-notification" + label = Sequence(lambda n: "notification{0}".format(n)) class Meta: """Factory Configuration.""" + model = Notification class RoleFactory(BaseFactory): """Role factory.""" - name = Sequence(lambda n: 'role{0}'.format(n)) + + name = Sequence(lambda n: "role{0}".format(n)) class Meta: """Factory Configuration.""" + model = Role @post_generation @@ -229,14 +253,16 @@ class RoleFactory(BaseFactory): class UserFactory(BaseFactory): """User Factory.""" - username = Sequence(lambda n: 'user{0}'.format(n)) - email = Sequence(lambda n: 'user{0}@example.com'.format(n)) + + username = Sequence(lambda n: "user{0}".format(n)) + email = Sequence(lambda n: "user{0}@example.com".format(n)) active = True password = FuzzyText(length=24) certificates = [] class Meta: """Factory Configuration.""" + model = User @post_generation @@ -269,39 +295,45 @@ class UserFactory(BaseFactory): class PolicyFactory(BaseFactory): """Policy Factory.""" - name = Sequence(lambda n: 'endpoint{0}'.format(n)) + + name = Sequence(lambda n: "endpoint{0}".format(n)) class Meta: """Factory Configuration.""" + model = Policy class EndpointFactory(BaseFactory): """Endpoint Factory.""" - owner = 'joe@example.com' - name = Sequence(lambda n: 'endpoint{0}'.format(n)) - type = FuzzyChoice(['elb']) + + owner = "joe@example.com" + name = Sequence(lambda n: "endpoint{0}".format(n)) + type = FuzzyChoice(["elb"]) active = True port = FuzzyInteger(0, high=65535) - dnsname = 'endpoint.example.com' + dnsname = "endpoint.example.com" policy = SubFactory(PolicyFactory) certificate = SubFactory(CertificateFactory) source = SubFactory(SourceFactory) class Meta: """Factory Configuration.""" + model = Endpoint class ApiKeyFactory(BaseFactory): """Api Key Factory.""" - name = Sequence(lambda n: 'api_key_{0}'.format(n)) + + name = Sequence(lambda n: "api_key_{0}".format(n)) revoked = False ttl = -1 issued_at = 1 class Meta: """Factory Configuration.""" + model = ApiKey @post_generation @@ -315,13 +347,14 @@ class ApiKeyFactory(BaseFactory): class PendingCertificateFactory(BaseFactory): """PendingCertificate factory.""" - name = Sequence(lambda n: 'pending_certificate{0}'.format(n)) + + name = Sequence(lambda n: "pending_certificate{0}".format(n)) external_id = 12345 csr = CSR_STR chain = INTERMEDIATE_CERT_STR private_key = WILDCARD_CERT_KEY - owner = 'joe@example.com' - status = FuzzyChoice(['valid', 'revoked', 'unknown']) + owner = "joe@example.com" + status = FuzzyChoice(["valid", "revoked", "unknown"]) deleted = False description = FuzzyText(length=128) date_created = FuzzyDate(date(2016, 1, 1), date(2020, 1, 1)) @@ -330,6 +363,7 @@ class PendingCertificateFactory(BaseFactory): class Meta: """Factory Configuration.""" + model = PendingCertificate @post_generation diff --git a/lemur/tests/plugins/destination_plugin.py b/lemur/tests/plugins/destination_plugin.py index f77085ec..d1eb6711 100644 --- a/lemur/tests/plugins/destination_plugin.py +++ b/lemur/tests/plugins/destination_plugin.py @@ -2,12 +2,12 @@ from lemur.plugins.bases import DestinationPlugin class TestDestinationPlugin(DestinationPlugin): - title = 'Test' - slug = 'test-destination' - description = 'Enables testing' + title = "Test" + slug = "test-destination" + description = "Enables testing" - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): super(TestDestinationPlugin, self).__init__(*args, **kwargs) diff --git a/lemur/tests/plugins/issuer_plugin.py b/lemur/tests/plugins/issuer_plugin.py index 3fda83ae..5f5c732b 100644 --- a/lemur/tests/plugins/issuer_plugin.py +++ b/lemur/tests/plugins/issuer_plugin.py @@ -4,12 +4,12 @@ from lemur.tests.vectors import SAN_CERT_STR, INTERMEDIATE_CERT_STR class TestIssuerPlugin(IssuerPlugin): - title = 'Test' - slug = 'test-issuer' - description = 'Enables testing' + title = "Test" + slug = "test-issuer" + description = "Enables testing" - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): super(TestIssuerPlugin, self).__init__(*args, **kwargs) @@ -20,17 +20,17 @@ class TestIssuerPlugin(IssuerPlugin): @staticmethod def create_authority(options): - role = {'username': '', 'password': '', 'name': 'test'} + role = {"username": "", "password": "", "name": "test"} return SAN_CERT_STR, "", [role] class TestAsyncIssuerPlugin(IssuerPlugin): - title = 'Test Async' - slug = 'test-issuer-async' - description = 'Enables testing with pending certificates' + title = "Test Async" + slug = "test-issuer-async" + description = "Enables testing with pending certificates" - author = 'James Chuong' - author_url = 'https://github.com/jchuong' + author = "James Chuong" + author_url = "https://github.com/jchuong" def __init__(self, *args, **kwargs): super(TestAsyncIssuerPlugin, self).__init__(*args, **kwargs) @@ -43,7 +43,7 @@ class TestAsyncIssuerPlugin(IssuerPlugin): @staticmethod def create_authority(options): - role = {'username': '', 'password': '', 'name': 'test'} + role = {"username": "", "password": "", "name": "test"} return SAN_CERT_STR, "", [role] def cancel_ordered_certificate(self, pending_certificate, **kwargs): diff --git a/lemur/tests/plugins/notification_plugin.py b/lemur/tests/plugins/notification_plugin.py index ad393d60..4ad79704 100644 --- a/lemur/tests/plugins/notification_plugin.py +++ b/lemur/tests/plugins/notification_plugin.py @@ -2,12 +2,12 @@ from lemur.plugins.bases import NotificationPlugin class TestNotificationPlugin(NotificationPlugin): - title = 'Test' - slug = 'test-notification' - description = 'Enables testing' + title = "Test" + slug = "test-notification" + description = "Enables testing" - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): super(TestNotificationPlugin, self).__init__(*args, **kwargs) diff --git a/lemur/tests/plugins/source_plugin.py b/lemur/tests/plugins/source_plugin.py index 10402576..21ce245d 100644 --- a/lemur/tests/plugins/source_plugin.py +++ b/lemur/tests/plugins/source_plugin.py @@ -2,12 +2,12 @@ from lemur.plugins.bases import SourcePlugin class TestSourcePlugin(SourcePlugin): - title = 'Test' - slug = 'test-source' - description = 'Enables testing' + title = "Test" + slug = "test-source" + description = "Enables testing" - author = 'Kevin Glisson' - author_url = 'https://github.com/netflix/lemur.git' + author = "Kevin Glisson" + author_url = "https://github.com/netflix/lemur.git" def __init__(self, *args, **kwargs): super(TestSourcePlugin, self).__init__(*args, **kwargs) diff --git a/lemur/tests/test_api_keys.py b/lemur/tests/test_api_keys.py index e60773bf..9e293be2 100644 --- a/lemur/tests/test_api_keys.py +++ b/lemur/tests/test_api_keys.py @@ -4,219 +4,398 @@ import pytest from lemur.api_keys.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_api_key_list_get(client, token, status): assert client.get(api.url_for(ApiKeyList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_api_key_list_post_invalid(client, token, status): - assert client.post(api.url_for(ApiKeyList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(ApiKeyList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,user_id,status", [ - (VALID_USER_HEADER_TOKEN, 1, 200), - (VALID_ADMIN_HEADER_TOKEN, 2, 200), - (VALID_ADMIN_API_TOKEN, 2, 200), - ('', 0, 401) -]) +@pytest.mark.parametrize( + "token,user_id,status", + [ + (VALID_USER_HEADER_TOKEN, 1, 200), + (VALID_ADMIN_HEADER_TOKEN, 2, 200), + (VALID_ADMIN_API_TOKEN, 2, 200), + ("", 0, 401), + ], +) def test_api_key_list_post_valid_self(client, user_id, token, status): - assert client.post(api.url_for(ApiKeyList), data=json.dumps({'name': 'a test token', 'user': {'id': user_id, 'username': 'example', 'email': 'example@test.net'}, 'ttl': -1}), headers=token).status_code == status + assert ( + client.post( + api.url_for(ApiKeyList), + data=json.dumps( + { + "name": "a test token", + "user": { + "id": user_id, + "username": "example", + "email": "example@test.net", + }, + "ttl": -1, + } + ), + headers=token, + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_api_key_list_post_valid_no_permission(client, token, status): - assert client.post(api.url_for(ApiKeyList), data=json.dumps({'name': 'a test token', 'user': {'id': 2, 'username': 'example', 'email': 'example@test.net'}, 'ttl': -1}), headers=token).status_code == status + assert ( + client.post( + api.url_for(ApiKeyList), + data=json.dumps( + { + "name": "a test token", + "user": { + "id": 2, + "username": "example", + "email": "example@test.net", + }, + "ttl": -1, + } + ), + headers=token, + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_api_key_list_patch(client, token, status): - assert client.patch(api.url_for(ApiKeyList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(ApiKeyList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_api_key_list_delete(client, token, status): assert client.delete(api.url_for(ApiKeyList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_user_api_key_list_get(client, token, status): - assert client.get(api.url_for(ApiKeyUserList, user_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(ApiKeyUserList, user_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_user_api_key_list_post_invalid(client, token, status): - assert client.post(api.url_for(ApiKeyUserList, user_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(ApiKeyUserList, user_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,user_id,status", [ - (VALID_USER_HEADER_TOKEN, 1, 200), - (VALID_ADMIN_HEADER_TOKEN, 2, 200), - (VALID_ADMIN_API_TOKEN, 2, 200), - ('', 0, 401) -]) +@pytest.mark.parametrize( + "token,user_id,status", + [ + (VALID_USER_HEADER_TOKEN, 1, 200), + (VALID_ADMIN_HEADER_TOKEN, 2, 200), + (VALID_ADMIN_API_TOKEN, 2, 200), + ("", 0, 401), + ], +) def test_user_api_key_list_post_valid_self(client, user_id, token, status): - assert client.post(api.url_for(ApiKeyUserList, user_id=1), data=json.dumps({'name': 'a test token', 'user': {'id': user_id}, 'ttl': -1}), headers=token).status_code == status + assert ( + client.post( + api.url_for(ApiKeyUserList, user_id=1), + data=json.dumps( + {"name": "a test token", "user": {"id": user_id}, "ttl": -1} + ), + headers=token, + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_user_api_key_list_post_valid_no_permission(client, token, status): - assert client.post(api.url_for(ApiKeyUserList, user_id=2), data=json.dumps({'name': 'a test token', 'user': {'id': 2}, 'ttl': -1}), headers=token).status_code == status + assert ( + client.post( + api.url_for(ApiKeyUserList, user_id=2), + data=json.dumps({"name": "a test token", "user": {"id": 2}, "ttl": -1}), + headers=token, + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_api_key_list_patch(client, token, status): - assert client.patch(api.url_for(ApiKeyUserList, user_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(ApiKeyUserList, user_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_api_key_list_delete(client, token, status): - assert client.delete(api.url_for(ApiKeyUserList, user_id=1), headers=token).status_code == status + assert ( + client.delete(api.url_for(ApiKeyUserList, user_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) -@pytest.mark.skip(reason="no way of getting an actual user onto the access key to generate a jwt") +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) +@pytest.mark.skip( + reason="no way of getting an actual user onto the access key to generate a jwt" +) def test_api_key_get(client, token, status): assert client.get(api.url_for(ApiKeys, aid=1), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_api_key_post(client, token, status): assert client.post(api.url_for(ApiKeys, aid=1), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_api_key_patch(client, token, status): - assert client.patch(api.url_for(ApiKeys, aid=1), headers=token).status_code == status + assert ( + client.patch(api.url_for(ApiKeys, aid=1), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) -@pytest.mark.skip(reason="no way of getting an actual user onto the access key to generate a jwt") +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) +@pytest.mark.skip( + reason="no way of getting an actual user onto the access key to generate a jwt" +) def test_api_key_put_permssions(client, token, status): - assert client.put(api.url_for(ApiKeys, aid=1), data=json.dumps({'name': 'Test', 'revoked': False, 'ttl': -1}), headers=token).status_code == status + assert ( + client.put( + api.url_for(ApiKeys, aid=1), + data=json.dumps({"name": "Test", "revoked": False, "ttl": -1}), + headers=token, + ).status_code + == status + ) # This test works while the other doesn't because the schema allows user id to be null. -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_api_key_described_get(client, token, status): - assert client.get(api.url_for(ApiKeysDescribed, aid=1), headers=token).status_code == status + assert ( + client.get(api.url_for(ApiKeysDescribed, aid=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) -@pytest.mark.skip(reason="no way of getting an actual user onto the access key to generate a jwt") +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) +@pytest.mark.skip( + reason="no way of getting an actual user onto the access key to generate a jwt" +) def test_user_api_key_get(client, token, status): - assert client.get(api.url_for(UserApiKeys, uid=1, aid=1), headers=token).status_code == status + assert ( + client.get(api.url_for(UserApiKeys, uid=1, aid=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_api_key_post(client, token, status): - assert client.post(api.url_for(UserApiKeys, uid=2, aid=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(UserApiKeys, uid=2, aid=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_api_key_patch(client, token, status): - assert client.patch(api.url_for(UserApiKeys, uid=2, aid=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(UserApiKeys, uid=2, aid=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) -@pytest.mark.skip(reason="no way of getting an actual user onto the access key to generate a jwt") +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) +@pytest.mark.skip( + reason="no way of getting an actual user onto the access key to generate a jwt" +) def test_user_api_key_put_permssions(client, token, status): - assert client.put(api.url_for(UserApiKeys, uid=2, aid=1), data=json.dumps({'name': 'Test', 'revoked': False, 'ttl': -1}), headers=token).status_code == status + assert ( + client.put( + api.url_for(UserApiKeys, uid=2, aid=1), + data=json.dumps({"name": "Test", "revoked": False, "ttl": -1}), + headers=token, + ).status_code + == status + ) diff --git a/lemur/tests/test_authorities.py b/lemur/tests/test_authorities.py index e865ab41..9649e949 100644 --- a/lemur/tests/test_authorities.py +++ b/lemur/tests/test_authorities.py @@ -4,22 +4,29 @@ import pytest from lemur.authorities.views import * # noqa from lemur.tests.factories import AuthorityFactory, RoleFactory -from lemur.tests.vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from lemur.tests.vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_authority_input_schema(client, role, issuer_plugin, logged_in_user): from lemur.authorities.schemas import AuthorityInputSchema input_data = { - 'name': 'Example Authority', - 'owner': 'jim@example.com', - 'description': 'An example authority.', - 'commonName': 'An Example Authority', - 'plugin': {'slug': 'test-issuer', 'plugin_options': [{'name': 'test', 'value': 'blah'}]}, - 'type': 'root', - 'signingAlgorithm': 'sha256WithRSA', - 'keyType': 'RSA2048', - 'sensitivity': 'medium' + "name": "Example Authority", + "owner": "jim@example.com", + "description": "An example authority.", + "commonName": "An Example Authority", + "plugin": { + "slug": "test-issuer", + "plugin_options": [{"name": "test", "value": "blah"}], + }, + "type": "root", + "signingAlgorithm": "sha256WithRSA", + "keyType": "RSA2048", + "sensitivity": "medium", } data, errors = AuthorityInputSchema().load(input_data) @@ -28,179 +35,286 @@ def test_authority_input_schema(client, role, issuer_plugin, logged_in_user): def test_user_authority(session, client, authority, role, user, issuer_plugin): - u = user['user'] + u = user["user"] u.roles.append(role) authority.roles.append(role) session.commit() - assert client.get(api.url_for(AuthoritiesList), headers=user['token']).json['total'] == 1 + assert ( + client.get(api.url_for(AuthoritiesList), headers=user["token"]).json["total"] + == 1 + ) u.roles.remove(role) session.commit() - assert client.get(api.url_for(AuthoritiesList), headers=user['token']).json['total'] == 0 + assert ( + client.get(api.url_for(AuthoritiesList), headers=user["token"]).json["total"] + == 0 + ) def test_create_authority(issuer_plugin, user): from lemur.authorities.service import create - authority = create(plugin={'plugin_object': issuer_plugin, 'slug': issuer_plugin.slug}, owner='jim@example.com', type='root', creator=user['user']) + + authority = create( + plugin={"plugin_object": issuer_plugin, "slug": issuer_plugin.slug}, + owner="jim@example.com", + type="root", + creator=user["user"], + ) assert authority.authority_certificate -@pytest.mark.parametrize("token, count", [ - (VALID_USER_HEADER_TOKEN, 0), - (VALID_ADMIN_HEADER_TOKEN, 3), - (VALID_ADMIN_API_TOKEN, 3), -]) +@pytest.mark.parametrize( + "token, count", + [ + (VALID_USER_HEADER_TOKEN, 0), + (VALID_ADMIN_HEADER_TOKEN, 3), + (VALID_ADMIN_API_TOKEN, 3), + ], +) def test_admin_authority(client, authority, issuer_plugin, token, count): - assert client.get(api.url_for(AuthoritiesList), headers=token).json['total'] == count + assert ( + client.get(api.url_for(AuthoritiesList), headers=token).json["total"] == count + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_authority_get(client, token, status): - assert client.get(api.url_for(Authorities, authority_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(Authorities, authority_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authority_post(client, token, status): - assert client.post(api.url_for(Authorities, authority_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Authorities, authority_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_authority_put(client, token, status): - assert client.put(api.url_for(Authorities, authority_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Authorities, authority_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authority_delete(client, token, status): - assert client.delete(api.url_for(Authorities, authority_id=1), headers=token).status_code == status + assert ( + client.delete( + api.url_for(Authorities, authority_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authority_patch(client, token, status): - assert client.patch(api.url_for(Authorities, authority_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Authorities, authority_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_authorities_get(client, token, status): assert client.get(api.url_for(AuthoritiesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_authorities_post(client, token, status): - assert client.post(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authorities_put(client, token, status): - assert client.put(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authorities_delete(client, token, status): - assert client.delete(api.url_for(AuthoritiesList), headers=token).status_code == status + assert ( + client.delete(api.url_for(AuthoritiesList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_authorities_patch(client, token, status): - assert client.patch(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_certificate_authorities_get(client, token, status): assert client.get(api.url_for(AuthoritiesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_certificate_authorities_post(client, token, status): - assert client.post(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_authorities_put(client, token, status): - assert client.put(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_authorities_delete(client, token, status): - assert client.delete(api.url_for(AuthoritiesList), headers=token).status_code == status + assert ( + client.delete(api.url_for(AuthoritiesList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_authorities_patch(client, token, status): - assert client.patch(api.url_for(AuthoritiesList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(AuthoritiesList), data={}, headers=token).status_code + == status + ) def test_authority_roles(client, session, issuer_plugin): @@ -209,23 +323,29 @@ def test_authority_roles(client, session, issuer_plugin): session.flush() data = { - 'owner': auth.owner, - 'name': auth.name, - 'description': auth.description, - 'active': True, - 'roles': [ - {'id': role.id}, - ], + "owner": auth.owner, + "name": auth.name, + "description": auth.description, + "active": True, + "roles": [{"id": role.id}], } # Add role - resp = client.put(api.url_for(Authorities, authority_id=auth.id), data=json.dumps(data), headers=VALID_ADMIN_HEADER_TOKEN) + resp = client.put( + api.url_for(Authorities, authority_id=auth.id), + data=json.dumps(data), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 - assert len(resp.json['roles']) == 1 + assert len(resp.json["roles"]) == 1 assert set(auth.roles) == {role} # Remove role - del data['roles'][0] - resp = client.put(api.url_for(Authorities, authority_id=auth.id), data=json.dumps(data), headers=VALID_ADMIN_HEADER_TOKEN) + del data["roles"][0] + resp = client.put( + api.url_for(Authorities, authority_id=auth.id), + data=json.dumps(data), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 - assert len(resp.json['roles']) == 0 + assert len(resp.json["roles"]) == 0 diff --git a/lemur/tests/test_certificates.py b/lemur/tests/test_certificates.py index cc8a5224..07b5ee4e 100644 --- a/lemur/tests/test_certificates.py +++ b/lemur/tests/test_certificates.py @@ -17,32 +17,53 @@ from lemur.common import utils from lemur.domains.models import Domain -from lemur.tests.vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN, CSR_STR, \ - INTERMEDIATE_CERT_STR, SAN_CERT_STR, SAN_CERT_CSR, SAN_CERT_KEY, ROOTCA_KEY, ROOTCA_CERT_STR +from lemur.tests.vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, + CSR_STR, + INTERMEDIATE_CERT_STR, + SAN_CERT_STR, + SAN_CERT_CSR, + SAN_CERT_KEY, + ROOTCA_KEY, + ROOTCA_CERT_STR, +) def test_get_or_increase_name(session, certificate): from lemur.certificates.models import get_or_increase_name from lemur.tests.factories import CertificateFactory - serial = 'AFF2DB4F8D2D4D8E80FA382AE27C2333' + serial = "AFF2DB4F8D2D4D8E80FA382AE27C2333" - assert get_or_increase_name(certificate.name, certificate.serial) == '{0}-{1}'.format(certificate.name, serial) + assert get_or_increase_name( + certificate.name, certificate.serial + ) == "{0}-{1}".format(certificate.name, serial) - certificate.name = 'test-cert-11111111' - assert get_or_increase_name(certificate.name, certificate.serial) == 'test-cert-11111111-' + serial + certificate.name = "test-cert-11111111" + assert ( + get_or_increase_name(certificate.name, certificate.serial) + == "test-cert-11111111-" + serial + ) - certificate.name = 'test-cert-11111111-1' - assert get_or_increase_name('test-cert-11111111-1', certificate.serial) == 'test-cert-11111111-1-' + serial + certificate.name = "test-cert-11111111-1" + assert ( + get_or_increase_name("test-cert-11111111-1", certificate.serial) + == "test-cert-11111111-1-" + serial + ) - cert2 = CertificateFactory(name='certificate1-' + serial) + cert2 = CertificateFactory(name="certificate1-" + serial) session.commit() - assert get_or_increase_name('certificate1', int(serial, 16)) == 'certificate1-{}-1'.format(serial) + assert get_or_increase_name( + "certificate1", int(serial, 16) + ) == "certificate1-{}-1".format(serial) def test_get_all_certs(session, certificate): from lemur.certificates.service import get_all_certs + assert len(get_all_certs()) > 1 @@ -66,7 +87,7 @@ def test_delete_cert(session): from lemur.certificates.service import delete, get from lemur.tests.factories import CertificateFactory - delete_this = CertificateFactory(name='DELETEME') + delete_this = CertificateFactory(name="DELETEME") session.commit() cert_exists = get(delete_this.id) @@ -85,21 +106,24 @@ def test_get_by_attributes(session, certificate): from lemur.certificates.service import get_by_attributes # Should get one cert - certificate1 = get_by_attributes({ - 'name': 'SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231' - }) + certificate1 = get_by_attributes( + { + "name": "SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231" + } + ) # Should get one cert using multiple attrs - certificate2 = get_by_attributes({ - 'name': 'test-cert-11111111-1', - 'cn': 'san.example.org' - }) + certificate2 = get_by_attributes( + {"name": "test-cert-11111111-1", "cn": "san.example.org"} + ) # Should get multiple certs - multiple = get_by_attributes({ - 'cn': 'LemurTrust Unittests Class 1 CA 2018', - 'issuer': 'LemurTrustUnittestsRootCA2018' - }) + multiple = get_by_attributes( + { + "cn": "LemurTrust Unittests Class 1 CA 2018", + "issuer": "LemurTrustUnittestsRootCA2018", + } + ) assert len(certificate1) == 1 assert len(certificate2) == 1 @@ -109,14 +133,11 @@ def test_get_by_attributes(session, certificate): def test_find_duplicates(session): from lemur.certificates.service import find_duplicates - cert = { - 'body': SAN_CERT_STR, - 'chain': INTERMEDIATE_CERT_STR - } + cert = {"body": SAN_CERT_STR, "chain": INTERMEDIATE_CERT_STR} dups1 = find_duplicates(cert) - cert['chain'] = '' + cert["chain"] = "" dups2 = find_duplicates(cert) @@ -138,13 +159,15 @@ def test_certificate_output_schema(session, certificate, issuer_plugin): from lemur.certificates.schemas import CertificateOutputSchema # Clear the cached attribute first - if 'parsed_cert' in certificate.__dict__: - del certificate.__dict__['parsed_cert'] + if "parsed_cert" in certificate.__dict__: + del certificate.__dict__["parsed_cert"] # Make sure serialization parses the cert only once (uses cached 'parsed_cert' attribute) - with patch('lemur.common.utils.parse_certificate', side_effect=utils.parse_certificate) as wrapper: + with patch( + "lemur.common.utils.parse_certificate", side_effect=utils.parse_certificate + ) as wrapper: data, errors = CertificateOutputSchema().dump(certificate) - assert data['issuer'] == 'LemurTrustUnittestsClass1CA2018' + assert data["issuer"] == "LemurTrustUnittestsClass1CA2018" assert wrapper.call_count == 1 @@ -152,24 +175,21 @@ def test_certificate_output_schema(session, certificate, issuer_plugin): def test_certificate_edit_schema(session): from lemur.certificates.schemas import CertificateEditInputSchema - input_data = {'owner': 'bob@example.com'} + input_data = {"owner": "bob@example.com"} data, errors = CertificateEditInputSchema().load(input_data) - assert len(data['notifications']) == 3 + assert len(data["notifications"]) == 3 def test_authority_key_identifier_schema(): from lemur.schemas import AuthorityKeyIdentifierSchema - input_data = { - 'useKeyIdentifier': True, - 'useAuthorityCert': True - } + + input_data = {"useKeyIdentifier": True, "useAuthorityCert": True} data, errors = AuthorityKeyIdentifierSchema().load(input_data) - assert sorted(data) == sorted({ - 'use_key_identifier': True, - 'use_authority_cert': True - }) + assert sorted(data) == sorted( + {"use_key_identifier": True, "use_authority_cert": True} + ) assert not errors data, errors = AuthorityKeyIdentifierSchema().dumps(data) @@ -179,11 +199,12 @@ def test_authority_key_identifier_schema(): def test_certificate_info_access_schema(): from lemur.schemas import CertificateInfoAccessSchema - input_data = {'includeAIA': True} + + input_data = {"includeAIA": True} data, errors = CertificateInfoAccessSchema().load(input_data) assert not errors - assert data == {'include_aia': True} + assert data == {"include_aia": True} data, errors = CertificateInfoAccessSchema().dump(data) assert not errors @@ -193,11 +214,11 @@ def test_certificate_info_access_schema(): def test_subject_key_identifier_schema(): from lemur.schemas import SubjectKeyIdentifierSchema - input_data = {'includeSKI': True} + input_data = {"includeSKI": True} data, errors = SubjectKeyIdentifierSchema().load(input_data) assert not errors - assert data == {'include_ski': True} + assert data == {"include_ski": True} data, errors = SubjectKeyIdentifierSchema().dump(data) assert not errors assert data == input_data @@ -207,16 +228,9 @@ def test_extension_schema(client): from lemur.certificates.schemas import ExtensionSchema input_data = { - 'keyUsage': { - 'useKeyEncipherment': True, - 'useDigitalSignature': True - }, - 'extendedKeyUsage': { - 'useServerAuthentication': True - }, - 'subjectKeyIdentifier': { - 'includeSKI': True - } + "keyUsage": {"useKeyEncipherment": True, "useDigitalSignature": True}, + "extendedKeyUsage": {"useServerAuthentication": True}, + "subjectKeyIdentifier": {"includeSKI": True}, } data, errors = ExtensionSchema().load(input_data) @@ -230,24 +244,24 @@ def test_certificate_input_schema(client, authority): from lemur.certificates.schemas import CertificateInputSchema input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': arrow.get(2018, 11, 9).isoformat(), - 'validityEnd': arrow.get(2019, 11, 9).isoformat(), - 'dnsProvider': None, + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": arrow.get(2018, 11, 9).isoformat(), + "validityEnd": arrow.get(2019, 11, 9).isoformat(), + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) assert not errors - assert data['authority'].id == authority.id + assert data["authority"].id == authority.id # make sure the defaults got set - assert data['common_name'] == 'test.example.com' - assert data['country'] == 'US' - assert data['location'] == 'Los Gatos' + assert data["common_name"] == "test.example.com" + assert data["country"] == "US" + assert data["location"] == "Los Gatos" assert len(data.keys()) == 19 @@ -256,28 +270,22 @@ def test_certificate_input_with_extensions(client, authority): from lemur.certificates.schemas import CertificateInputSchema input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'extensions': { - 'keyUsage': { - 'digital_signature': True + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "extensions": { + "keyUsage": {"digital_signature": True}, + "extendedKeyUsage": { + "useClientAuthentication": True, + "useServerAuthentication": True, }, - 'extendedKeyUsage': { - 'useClientAuthentication': True, - 'useServerAuthentication': True + "subjectKeyIdentifier": {"includeSKI": True}, + "subAltNames": { + "names": [{"nameType": "DNSName", "value": "test.example.com"}] }, - 'subjectKeyIdentifier': { - 'includeSKI': True - }, - 'subAltNames': { - 'names': [ - {'nameType': 'DNSName', 'value': 'test.example.com'} - ] - } }, - 'dnsProvider': None, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) @@ -287,48 +295,61 @@ def test_certificate_input_with_extensions(client, authority): def test_certificate_input_schema_parse_csr(authority): from lemur.certificates.schemas import CertificateInputSchema - test_san_dns = 'foobar.com' - extensions = {'sub_alt_names': {'names': x509.SubjectAlternativeName([x509.DNSName(test_san_dns)])}} - csr, private_key = create_csr(owner='joe@example.com', common_name='ACommonName', organization='test', - organizational_unit='Meters', country='NL', state='Noord-Holland', location='Amsterdam', - key_type='RSA2048', extensions=extensions) + test_san_dns = "foobar.com" + extensions = { + "sub_alt_names": { + "names": x509.SubjectAlternativeName([x509.DNSName(test_san_dns)]) + } + } + csr, private_key = create_csr( + owner="joe@example.com", + common_name="ACommonName", + organization="test", + organizational_unit="Meters", + country="NL", + state="Noord-Holland", + location="Amsterdam", + key_type="RSA2048", + extensions=extensions, + ) input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'csr': csr, - 'dnsProvider': None, + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "csr": csr, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) - for san in data['extensions']['sub_alt_names']['names']: + for san in data["extensions"]["sub_alt_names"]["names"]: assert san.value == test_san_dns assert not errors def test_certificate_out_of_range_date(client, authority): from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityYears': 100, - 'dnsProvider': None, + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityYears": 100, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) assert errors - input_data['validityStart'] = '2017-04-30T00:12:34.513631' + input_data["validityStart"] = "2017-04-30T00:12:34.513631" data, errors = CertificateInputSchema().load(input_data) assert errors - input_data['validityEnd'] = '2018-04-30T00:12:34.513631' + input_data["validityEnd"] = "2018-04-30T00:12:34.513631" data, errors = CertificateInputSchema().load(input_data) assert errors @@ -336,13 +357,14 @@ def test_certificate_out_of_range_date(client, authority): def test_certificate_valid_years(client, authority): from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityYears': 1, - 'dnsProvider': None, + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityYears": 1, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) @@ -351,14 +373,15 @@ def test_certificate_valid_years(client, authority): def test_certificate_valid_dates(client, authority): from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': 'test.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'dnsProvider': None, + "commonName": "test.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) @@ -368,14 +391,15 @@ def test_certificate_valid_dates(client, authority): def test_certificate_cn_admin(client, authority, logged_in_admin): """Admin is exempt from CN/SAN domain restrictions.""" from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': '*.admin-overrides-whitelist.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'dnsProvider': None, + "commonName": "*.admin-overrides-whitelist.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) @@ -385,22 +409,23 @@ def test_certificate_cn_admin(client, authority, logged_in_admin): def test_certificate_allowed_names(client, authority, session, logged_in_user): """Test for allowed CN and SAN values.""" from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': 'Names with spaces are not checked', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'extensions': { - 'subAltNames': { - 'names': [ - {'nameType': 'DNSName', 'value': 'allowed.example.com'}, - {'nameType': 'IPAddress', 'value': '127.0.0.1'}, + "commonName": "Names with spaces are not checked", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "extensions": { + "subAltNames": { + "names": [ + {"nameType": "DNSName", "value": "allowed.example.com"}, + {"nameType": "IPAddress", "value": "127.0.0.1"}, ] } }, - 'dnsProvider': None, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) @@ -415,74 +440,82 @@ def test_certificate_incative_authority(client, authority, session, logged_in_us session.add(authority) input_data = { - 'commonName': 'foo.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'dnsProvider': None, + "commonName": "foo.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) - assert errors['authority'][0] == "The authority is inactive." + assert errors["authority"][0] == "The authority is inactive." def test_certificate_disallowed_names(client, authority, session, logged_in_user): """The CN and SAN are disallowed by LEMUR_WHITELISTED_DOMAINS.""" from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': '*.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'extensions': { - 'subAltNames': { - 'names': [ - {'nameType': 'DNSName', 'value': 'allowed.example.com'}, - {'nameType': 'DNSName', 'value': 'evilhacker.org'}, + "commonName": "*.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "extensions": { + "subAltNames": { + "names": [ + {"nameType": "DNSName", "value": "allowed.example.com"}, + {"nameType": "DNSName", "value": "evilhacker.org"}, ] } }, - 'dnsProvider': None, + "dnsProvider": None, } data, errors = CertificateInputSchema().load(input_data) - assert errors['common_name'][0].startswith("Domain *.example.com does not match whitelisted domain patterns") - assert (errors['extensions']['sub_alt_names']['names'][0] - .startswith("Domain evilhacker.org does not match whitelisted domain patterns")) + assert errors["common_name"][0].startswith( + "Domain *.example.com does not match whitelisted domain patterns" + ) + assert errors["extensions"]["sub_alt_names"]["names"][0].startswith( + "Domain evilhacker.org does not match whitelisted domain patterns" + ) def test_certificate_sensitive_name(client, authority, session, logged_in_user): """The CN is disallowed by 'sensitive' flag on Domain model.""" from lemur.certificates.schemas import CertificateInputSchema + input_data = { - 'commonName': 'sensitive.example.com', - 'owner': 'jim@example.com', - 'authority': {'id': authority.id}, - 'description': 'testtestest', - 'validityStart': '2020-01-01T00:00:00', - 'validityEnd': '2020-01-01T00:00:01', - 'dnsProvider': None, + "commonName": "sensitive.example.com", + "owner": "jim@example.com", + "authority": {"id": authority.id}, + "description": "testtestest", + "validityStart": "2020-01-01T00:00:00", + "validityEnd": "2020-01-01T00:00:01", + "dnsProvider": None, } - session.add(Domain(name='sensitive.example.com', sensitive=True)) + session.add(Domain(name="sensitive.example.com", sensitive=True)) data, errors = CertificateInputSchema().load(input_data) - assert errors['common_name'][0].startswith("Domain sensitive.example.com has been marked as sensitive") + assert errors["common_name"][0].startswith( + "Domain sensitive.example.com has been marked as sensitive" + ) def test_certificate_upload_schema_ok(client): from lemur.certificates.schemas import CertificateUploadInputSchema + data = { - 'name': 'Jane', - 'owner': 'pwner@example.com', - 'body': SAN_CERT_STR, - 'privateKey': SAN_CERT_KEY, - 'chain': INTERMEDIATE_CERT_STR, - 'csr': SAN_CERT_CSR, - 'external_id': '1234', + "name": "Jane", + "owner": "pwner@example.com", + "body": SAN_CERT_STR, + "privateKey": SAN_CERT_KEY, + "chain": INTERMEDIATE_CERT_STR, + "csr": SAN_CERT_CSR, + "external_id": "1234", } data, errors = CertificateUploadInputSchema().load(data) assert not errors @@ -490,20 +523,19 @@ def test_certificate_upload_schema_ok(client): def test_certificate_upload_schema_minimal(client): from lemur.certificates.schemas import CertificateUploadInputSchema - data = { - 'owner': 'pwner@example.com', - 'body': SAN_CERT_STR, - } + + data = {"owner": "pwner@example.com", "body": SAN_CERT_STR} data, errors = CertificateUploadInputSchema().load(data) assert not errors def test_certificate_upload_schema_long_chain(client): from lemur.certificates.schemas import CertificateUploadInputSchema + data = { - 'owner': 'pwner@example.com', - 'body': SAN_CERT_STR, - 'chain': INTERMEDIATE_CERT_STR + '\n' + ROOTCA_CERT_STR + "owner": "pwner@example.com", + "body": SAN_CERT_STR, + "chain": INTERMEDIATE_CERT_STR + "\n" + ROOTCA_CERT_STR, } data, errors = CertificateUploadInputSchema().load(data) assert not errors @@ -511,87 +543,106 @@ def test_certificate_upload_schema_long_chain(client): def test_certificate_upload_schema_invalid_body(client): from lemur.certificates.schemas import CertificateUploadInputSchema + data = { - 'owner': 'pwner@example.com', - 'body': 'Hereby I certify that this is a valid body', + "owner": "pwner@example.com", + "body": "Hereby I certify that this is a valid body", } data, errors = CertificateUploadInputSchema().load(data) - assert errors == {'body': ['Public certificate presented is not valid.']} + assert errors == {"body": ["Public certificate presented is not valid."]} def test_certificate_upload_schema_invalid_pkey(client): from lemur.certificates.schemas import CertificateUploadInputSchema + data = { - 'owner': 'pwner@example.com', - 'body': SAN_CERT_STR, - 'privateKey': 'Look at me Im a private key!!111', + "owner": "pwner@example.com", + "body": SAN_CERT_STR, + "privateKey": "Look at me Im a private key!!111", } data, errors = CertificateUploadInputSchema().load(data) - assert errors == {'private_key': ['Private key presented is not valid.']} + assert errors == {"private_key": ["Private key presented is not valid."]} def test_certificate_upload_schema_invalid_chain(client): from lemur.certificates.schemas import CertificateUploadInputSchema - data = { - 'body': SAN_CERT_STR, - 'chain': 'CHAINSAW', - 'owner': 'pwner@example.com', - } + + data = {"body": SAN_CERT_STR, "chain": "CHAINSAW", "owner": "pwner@example.com"} data, errors = CertificateUploadInputSchema().load(data) - assert errors == {'chain': ['Invalid certificate in certificate chain.']} + assert errors == {"chain": ["Invalid certificate in certificate chain."]} def test_certificate_upload_schema_wrong_pkey(client): from lemur.certificates.schemas import CertificateUploadInputSchema + data = { - 'body': SAN_CERT_STR, - 'privateKey': ROOTCA_KEY, - 'chain': INTERMEDIATE_CERT_STR, - 'owner': 'pwner@example.com', + "body": SAN_CERT_STR, + "privateKey": ROOTCA_KEY, + "chain": INTERMEDIATE_CERT_STR, + "owner": "pwner@example.com", } data, errors = CertificateUploadInputSchema().load(data) - assert errors == {'_schema': ['Private key does not match certificate.']} + assert errors == {"_schema": ["Private key does not match certificate."]} def test_certificate_upload_schema_wrong_chain(client): from lemur.certificates.schemas import CertificateUploadInputSchema + data = { - 'owner': 'pwner@example.com', - 'body': SAN_CERT_STR, - 'chain': ROOTCA_CERT_STR, + "owner": "pwner@example.com", + "body": SAN_CERT_STR, + "chain": ROOTCA_CERT_STR, } data, errors = CertificateUploadInputSchema().load(data) - assert errors == {'_schema': ["Incorrect chain certificate(s) provided: 'san.example.org' is not signed by " - "'LemurTrust Unittests Root CA 2018'"]} + assert errors == { + "_schema": [ + "Incorrect chain certificate(s) provided: 'san.example.org' is not signed by " + "'LemurTrust Unittests Root CA 2018'" + ] + } def test_certificate_upload_schema_wrong_chain_2nd(client): from lemur.certificates.schemas import CertificateUploadInputSchema + data = { - 'owner': 'pwner@example.com', - 'body': SAN_CERT_STR, - 'chain': INTERMEDIATE_CERT_STR + '\n' + SAN_CERT_STR, + "owner": "pwner@example.com", + "body": SAN_CERT_STR, + "chain": INTERMEDIATE_CERT_STR + "\n" + SAN_CERT_STR, } data, errors = CertificateUploadInputSchema().load(data) - assert errors == {'_schema': ["Incorrect chain certificate(s) provided: 'LemurTrust Unittests Class 1 CA 2018' is " - "not signed by 'san.example.org'"]} + assert errors == { + "_schema": [ + "Incorrect chain certificate(s) provided: 'LemurTrust Unittests Class 1 CA 2018' is " + "not signed by 'san.example.org'" + ] + } def test_create_basic_csr(client): csr_config = dict( - common_name='example.com', - organization='Example, Inc.', - organizational_unit='Operations', - country='US', - state='CA', - location='A place', - owner='joe@example.com', - key_type='RSA2048', - extensions=dict(names=dict(sub_alt_names=x509.SubjectAlternativeName([x509.DNSName('test.example.com'), x509.DNSName('test2.example.com')]))) + common_name="example.com", + organization="Example, Inc.", + organizational_unit="Operations", + country="US", + state="CA", + location="A place", + owner="joe@example.com", + key_type="RSA2048", + extensions=dict( + names=dict( + sub_alt_names=x509.SubjectAlternativeName( + [ + x509.DNSName("test.example.com"), + x509.DNSName("test2.example.com"), + ] + ) + ) + ), ) csr, pem = create_csr(**csr_config) - csr = x509.load_pem_x509_csr(csr.encode('utf-8'), default_backend()) + csr = x509.load_pem_x509_csr(csr.encode("utf-8"), default_backend()) for name in csr.subject: assert name.value in csr_config.values() @@ -603,13 +654,13 @@ def test_csr_empty_san(client): """ csr_text, pkey = create_csr( - common_name='daniel-san.example.com', - owner='daniel-san@example.com', - key_type='RSA2048', - extensions={'sub_alt_names': {'names': x509.SubjectAlternativeName([])}} + common_name="daniel-san.example.com", + owner="daniel-san@example.com", + key_type="RSA2048", + extensions={"sub_alt_names": {"names": x509.SubjectAlternativeName([])}}, ) - csr = x509.load_pem_x509_csr(csr_text.encode('utf-8'), default_backend()) + csr = x509.load_pem_x509_csr(csr_text.encode("utf-8"), default_backend()) with pytest.raises(x509.ExtensionNotFound): csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) @@ -620,13 +671,13 @@ def test_csr_disallowed_cn(client, logged_in_user): from lemur.common import validators request, pkey = create_csr( - common_name='evilhacker.org', - owner='joe@example.com', - key_type='RSA2048', + common_name="evilhacker.org", owner="joe@example.com", key_type="RSA2048" ) with pytest.raises(ValidationError) as err: validators.csr(request) - assert str(err.value).startswith('Domain evilhacker.org does not match whitelisted domain patterns') + assert str(err.value).startswith( + "Domain evilhacker.org does not match whitelisted domain patterns" + ) def test_csr_disallowed_san(client, logged_in_user): @@ -635,46 +686,71 @@ def test_csr_disallowed_san(client, logged_in_user): request, pkey = create_csr( common_name="CN with spaces isn't a domain and is thus allowed", - owner='joe@example.com', - key_type='RSA2048', - extensions={'sub_alt_names': {'names': x509.SubjectAlternativeName([x509.DNSName('evilhacker.org')])}} + owner="joe@example.com", + key_type="RSA2048", + extensions={ + "sub_alt_names": { + "names": x509.SubjectAlternativeName([x509.DNSName("evilhacker.org")]) + } + }, ) with pytest.raises(ValidationError) as err: validators.csr(request) - assert str(err.value).startswith('Domain evilhacker.org does not match whitelisted domain patterns') + assert str(err.value).startswith( + "Domain evilhacker.org does not match whitelisted domain patterns" + ) def test_get_name_from_arn(client): from lemur.certificates.service import get_name_from_arn - arn = 'arn:aws:iam::11111111:server-certificate/mycertificate' - assert get_name_from_arn(arn) == 'mycertificate' + + arn = "arn:aws:iam::11111111:server-certificate/mycertificate" + assert get_name_from_arn(arn) == "mycertificate" def test_get_account_number(client): from lemur.certificates.service import get_account_number - arn = 'arn:aws:iam::11111111:server-certificate/mycertificate' - assert get_account_number(arn) == '11111111' + + arn = "arn:aws:iam::11111111:server-certificate/mycertificate" + assert get_account_number(arn) == "11111111" def test_mint_certificate(issuer_plugin, authority): from lemur.certificates.service import mint - cert_body, private_key, chain, external_id, csr = mint(authority=authority, csr=CSR_STR) + + cert_body, private_key, chain, external_id, csr = mint( + authority=authority, csr=CSR_STR + ) assert cert_body == SAN_CERT_STR def test_create_certificate(issuer_plugin, authority, user): from lemur.certificates.service import create - cert = create(authority=authority, csr=CSR_STR, owner='joe@example.com', creator=user['user']) - assert str(cert.not_after) == '2047-12-31T22:00:00+00:00' - assert str(cert.not_before) == '2017-12-31T22:00:00+00:00' - assert cert.issuer == 'LemurTrustUnittestsClass1CA2018' - assert cert.name == 'SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231-AFF2DB4F8D2D4D8E80FA382AE27C2333' - cert = create(authority=authority, csr=CSR_STR, owner='joe@example.com', name='ACustomName1', creator=user['user']) - assert cert.name == 'ACustomName1' + cert = create( + authority=authority, csr=CSR_STR, owner="joe@example.com", creator=user["user"] + ) + assert str(cert.not_after) == "2047-12-31T22:00:00+00:00" + assert str(cert.not_before) == "2017-12-31T22:00:00+00:00" + assert cert.issuer == "LemurTrustUnittestsClass1CA2018" + assert ( + cert.name + == "SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231-AFF2DB4F8D2D4D8E80FA382AE27C2333" + ) + + cert = create( + authority=authority, + csr=CSR_STR, + owner="joe@example.com", + name="ACustomName1", + creator=user["user"], + ) + assert cert.name == "ACustomName1" -def test_reissue_certificate(issuer_plugin, crypto_authority, certificate, logged_in_user): +def test_reissue_certificate( + issuer_plugin, crypto_authority, certificate, logged_in_user +): from lemur.certificates.service import reissue_certificate # test-authority would return a mismatching private key, so use 'cryptography-issuer' plugin instead. @@ -684,286 +760,511 @@ def test_reissue_certificate(issuer_plugin, crypto_authority, certificate, logge def test_create_csr(): - csr, private_key = create_csr(owner='joe@example.com', common_name='ACommonName', organization='test', organizational_unit='Meters', country='US', - state='CA', location='Here', key_type='RSA2048') + csr, private_key = create_csr( + owner="joe@example.com", + common_name="ACommonName", + organization="test", + organizational_unit="Meters", + country="US", + state="CA", + location="Here", + key_type="RSA2048", + ) assert csr assert private_key - extensions = {'sub_alt_names': {'names': x509.SubjectAlternativeName([x509.DNSName('AnotherCommonName')])}} - csr, private_key = create_csr(owner='joe@example.com', common_name='ACommonName', organization='test', organizational_unit='Meters', country='US', - state='CA', location='Here', extensions=extensions, key_type='RSA2048') + extensions = { + "sub_alt_names": { + "names": x509.SubjectAlternativeName([x509.DNSName("AnotherCommonName")]) + } + } + csr, private_key = create_csr( + owner="joe@example.com", + common_name="ACommonName", + organization="test", + organizational_unit="Meters", + country="US", + state="CA", + location="Here", + extensions=extensions, + key_type="RSA2048", + ) assert csr assert private_key def test_import(user): from lemur.certificates.service import import_certificate - cert = import_certificate(body=SAN_CERT_STR, chain=INTERMEDIATE_CERT_STR, private_key=SAN_CERT_KEY, creator=user['user']) - assert str(cert.not_after) == '2047-12-31T22:00:00+00:00' - assert str(cert.not_before) == '2017-12-31T22:00:00+00:00' - assert cert.issuer == 'LemurTrustUnittestsClass1CA2018' - assert cert.name.startswith('SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231') - cert = import_certificate(body=SAN_CERT_STR, chain=INTERMEDIATE_CERT_STR, private_key=SAN_CERT_KEY, owner='joe@example.com', name='ACustomName2', creator=user['user']) - assert cert.name == 'ACustomName2' + cert = import_certificate( + body=SAN_CERT_STR, + chain=INTERMEDIATE_CERT_STR, + private_key=SAN_CERT_KEY, + creator=user["user"], + ) + assert str(cert.not_after) == "2047-12-31T22:00:00+00:00" + assert str(cert.not_before) == "2017-12-31T22:00:00+00:00" + assert cert.issuer == "LemurTrustUnittestsClass1CA2018" + assert cert.name.startswith( + "SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231" + ) + + cert = import_certificate( + body=SAN_CERT_STR, + chain=INTERMEDIATE_CERT_STR, + private_key=SAN_CERT_KEY, + owner="joe@example.com", + name="ACustomName2", + creator=user["user"], + ) + assert cert.name == "ACustomName2" @pytest.mark.skip def test_upload(user): from lemur.certificates.service import upload - cert = upload(body=SAN_CERT_STR, chain=INTERMEDIATE_CERT_STR, private_key=SAN_CERT_KEY, owner='joe@example.com', creator=user['user']) - assert str(cert.not_after) == '2040-01-01T20:30:52+00:00' - assert str(cert.not_before) == '2015-06-26T20:30:52+00:00' - assert cert.issuer == 'Example' - assert cert.name == 'long.lived.com-Example-20150626-20400101-3' - cert = upload(body=SAN_CERT_STR, chain=INTERMEDIATE_CERT_STR, private_key=SAN_CERT_KEY, owner='joe@example.com', name='ACustomName', creator=user['user']) - assert 'ACustomName' in cert.name + cert = upload( + body=SAN_CERT_STR, + chain=INTERMEDIATE_CERT_STR, + private_key=SAN_CERT_KEY, + owner="joe@example.com", + creator=user["user"], + ) + assert str(cert.not_after) == "2040-01-01T20:30:52+00:00" + assert str(cert.not_before) == "2015-06-26T20:30:52+00:00" + assert cert.issuer == "Example" + assert cert.name == "long.lived.com-Example-20150626-20400101-3" + + cert = upload( + body=SAN_CERT_STR, + chain=INTERMEDIATE_CERT_STR, + private_key=SAN_CERT_KEY, + owner="joe@example.com", + name="ACustomName", + creator=user["user"], + ) + assert "ACustomName" in cert.name # verify upload with a private key as a str def test_upload_private_key_str(user): from lemur.certificates.service import upload - cert = upload(body=SAN_CERT_STR, chain=INTERMEDIATE_CERT_STR, private_key=SAN_CERT_KEY, owner='joe@example.com', name='ACustomName', creator=user['user']) + + cert = upload( + body=SAN_CERT_STR, + chain=INTERMEDIATE_CERT_STR, + private_key=SAN_CERT_KEY, + owner="joe@example.com", + name="ACustomName", + creator=user["user"], + ) assert cert -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_certificate_get_private_key(client, token, status): - assert client.get(api.url_for(Certificates, certificate_id=1), headers=token).status_code == status + assert ( + client.get( + api.url_for(Certificates, certificate_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_certificate_get(client, token, status): - assert client.get(api.url_for(Certificates, certificate_id=1), headers=token).status_code == status + assert ( + client.get( + api.url_for(Certificates, certificate_id=1), headers=token + ).status_code + == status + ) def test_certificate_get_body(client): - response_body = client.get(api.url_for(Certificates, certificate_id=1), headers=VALID_USER_HEADER_TOKEN).json - assert response_body['serial'] == '211983098819107449768450703123665283596' - assert response_body['serialHex'] == '9F7A75B39DAE4C3F9524C68B06DA6A0C' - assert response_body['distinguishedName'] == ('CN=LemurTrust Unittests Class 1 CA 2018,' - 'O=LemurTrust Enterprises Ltd,' - 'OU=Unittesting Operations Center,' - 'C=EE,' - 'ST=N/A,' - 'L=Earth') + response_body = client.get( + api.url_for(Certificates, certificate_id=1), headers=VALID_USER_HEADER_TOKEN + ).json + assert response_body["serial"] == "211983098819107449768450703123665283596" + assert response_body["serialHex"] == "9F7A75B39DAE4C3F9524C68B06DA6A0C" + assert response_body["distinguishedName"] == ( + "CN=LemurTrust Unittests Class 1 CA 2018," + "O=LemurTrust Enterprises Ltd," + "OU=Unittesting Operations Center," + "C=EE," + "ST=N/A," + "L=Earth" + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_post(client, token, status): - assert client.post(api.url_for(Certificates, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Certificates, certificate_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_certificate_put(client, token, status): - assert client.put(api.url_for(Certificates, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Certificates, certificate_id=1), data={}, headers=token + ).status_code + == status + ) def test_certificate_put_with_data(client, certificate, issuer_plugin): - resp = client.put(api.url_for(Certificates, certificate_id=certificate.id), data=json.dumps({'owner': 'bob@example.com', 'description': 'test', 'notify': True}), headers=VALID_ADMIN_HEADER_TOKEN) + resp = client.put( + api.url_for(Certificates, certificate_id=certificate.id), + data=json.dumps( + {"owner": "bob@example.com", "description": "test", "notify": True} + ), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 204), - (VALID_ADMIN_API_TOKEN, 412), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 204), + (VALID_ADMIN_API_TOKEN, 412), + ("", 401), + ], +) def test_certificate_delete(client, token, status): - assert client.delete(api.url_for(Certificates, certificate_id=1), headers=token).status_code == status + assert ( + client.delete( + api.url_for(Certificates, certificate_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 204), - (VALID_ADMIN_API_TOKEN, 204), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 204), + (VALID_ADMIN_API_TOKEN, 204), + ("", 401), + ], +) def test_invalid_certificate_delete(client, invalid_certificate, token, status): - assert client.delete( - api.url_for(Certificates, certificate_id=invalid_certificate.id), headers=token).status_code == status + assert ( + client.delete( + api.url_for(Certificates, certificate_id=invalid_certificate.id), + headers=token, + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_patch(client, token, status): - assert client.patch(api.url_for(Certificates, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Certificates, certificate_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_certificates_get(client, token, status): - assert client.get(api.url_for(CertificatesList), headers=token).status_code == status + assert ( + client.get(api.url_for(CertificatesList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_certificates_post(client, token, status): - assert client.post(api.url_for(CertificatesList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(CertificatesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_put(client, token, status): - assert client.put(api.url_for(CertificatesList), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(CertificatesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_delete(client, token, status): - assert client.delete(api.url_for(CertificatesList), headers=token).status_code == status + assert ( + client.delete(api.url_for(CertificatesList), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_patch(client, token, status): - assert client.patch(api.url_for(CertificatesList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(CertificatesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_credentials_post(client, token, status): - assert client.post(api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_credentials_put(client, token, status): - assert client.put(api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_credentials_delete(client, token, status): - assert client.delete(api.url_for(CertificatePrivateKey, certificate_id=1), headers=token).status_code == status + assert ( + client.delete( + api.url_for(CertificatePrivateKey, certificate_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificate_credentials_patch(client, token, status): - assert client.patch(api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(CertificatePrivateKey, certificate_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_upload_get(client, token, status): - assert client.get(api.url_for(CertificatesUpload), headers=token).status_code == status + assert ( + client.get(api.url_for(CertificatesUpload), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_certificates_upload_post(client, token, status): - assert client.post(api.url_for(CertificatesUpload), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(CertificatesUpload), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_upload_put(client, token, status): - assert client.put(api.url_for(CertificatesUpload), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(CertificatesUpload), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_upload_delete(client, token, status): - assert client.delete(api.url_for(CertificatesUpload), headers=token).status_code == status + assert ( + client.delete(api.url_for(CertificatesUpload), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_certificates_upload_patch(client, token, status): - assert client.patch(api.url_for(CertificatesUpload), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(CertificatesUpload), data={}, headers=token + ).status_code + == status + ) def test_sensitive_sort(client): - resp = client.get(api.url_for(CertificatesList) + '?sortBy=private_key&sortDir=asc', headers=VALID_ADMIN_HEADER_TOKEN) - assert "'private_key' is not sortable or filterable" in resp.json['message'] + resp = client.get( + api.url_for(CertificatesList) + "?sortBy=private_key&sortDir=asc", + headers=VALID_ADMIN_HEADER_TOKEN, + ) + assert "'private_key' is not sortable or filterable" in resp.json["message"] def test_boolean_filter(client): - resp = client.get(api.url_for(CertificatesList) + '?filter=notify;true', headers=VALID_ADMIN_HEADER_TOKEN) + resp = client.get( + api.url_for(CertificatesList) + "?filter=notify;true", + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 # Also don't crash with invalid input (we currently treat that as false) - resp = client.get(api.url_for(CertificatesList) + '?filter=notify;whatisthis', headers=VALID_ADMIN_HEADER_TOKEN) + resp = client.get( + api.url_for(CertificatesList) + "?filter=notify;whatisthis", + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 diff --git a/lemur/tests/test_defaults.py b/lemur/tests/test_defaults.py index da9d6c79..b8daa575 100644 --- a/lemur/tests/test_defaults.py +++ b/lemur/tests/test_defaults.py @@ -8,14 +8,18 @@ from .vectors import SAN_CERT, WILDCARD_CERT, INTERMEDIATE_CERT def test_cert_get_cn(client): from lemur.common.defaults import common_name - assert common_name(SAN_CERT) == 'san.example.org' + assert common_name(SAN_CERT) == "san.example.org" def test_cert_sub_alt_domains(client): from lemur.common.defaults import domains assert domains(INTERMEDIATE_CERT) == [] - assert domains(SAN_CERT) == ['san.example.org', 'san2.example.org', 'daniel-san.example.org'] + assert domains(SAN_CERT) == [ + "san.example.org", + "san2.example.org", + "daniel-san.example.org", + ] def test_cert_is_san(client): @@ -28,94 +32,119 @@ def test_cert_is_san(client): def test_cert_is_wildcard(client): from lemur.common.defaults import is_wildcard + assert is_wildcard(WILDCARD_CERT) assert not is_wildcard(INTERMEDIATE_CERT) def test_cert_bitstrength(client): from lemur.common.defaults import bitstrength + assert bitstrength(INTERMEDIATE_CERT) == 2048 def test_cert_issuer(client): from lemur.common.defaults import issuer - assert issuer(INTERMEDIATE_CERT) == 'LemurTrustUnittestsRootCA2018' + + assert issuer(INTERMEDIATE_CERT) == "LemurTrustUnittestsRootCA2018" def test_text_to_slug(client): from lemur.common.defaults import text_to_slug - assert text_to_slug('test - string') == 'test-string' - assert text_to_slug('test - string', '') == 'teststring' + + assert text_to_slug("test - string") == "test-string" + assert text_to_slug("test - string", "") == "teststring" # Accented characters are decomposed - assert text_to_slug('föö bär') == 'foo-bar' + assert text_to_slug("föö bär") == "foo-bar" # Melt away the Unicode Snowman - assert text_to_slug('\u2603') == '' - assert text_to_slug('\u2603test\u2603') == 'test' - assert text_to_slug('snow\u2603man') == 'snow-man' - assert text_to_slug('snow\u2603man', '') == 'snowman' + assert text_to_slug("\u2603") == "" + assert text_to_slug("\u2603test\u2603") == "test" + assert text_to_slug("snow\u2603man") == "snow-man" + assert text_to_slug("snow\u2603man", "") == "snowman" # IDNA-encoded domain names should be kept as-is - assert text_to_slug('xn--i1b6eqas.xn--xmpl-loa9b3671b.com') == 'xn--i1b6eqas.xn--xmpl-loa9b3671b.com' + assert ( + text_to_slug("xn--i1b6eqas.xn--xmpl-loa9b3671b.com") + == "xn--i1b6eqas.xn--xmpl-loa9b3671b.com" + ) def test_create_name(client): from lemur.common.defaults import certificate_name from datetime import datetime - assert certificate_name( - 'example.com', - 'Example Inc,', - datetime(2015, 5, 7, 0, 0, 0), - datetime(2015, 5, 12, 0, 0, 0), - False - ) == 'example.com-ExampleInc-20150507-20150512' - assert certificate_name( - 'example.com', - 'Example Inc,', - datetime(2015, 5, 7, 0, 0, 0), - datetime(2015, 5, 12, 0, 0, 0), - True - ) == 'SAN-example.com-ExampleInc-20150507-20150512' - assert certificate_name( - 'xn--mnchen-3ya.de', - 'Vertrauenswürdig Autorität', - datetime(2015, 5, 7, 0, 0, 0), - datetime(2015, 5, 12, 0, 0, 0), - False - ) == 'xn--mnchen-3ya.de-VertrauenswurdigAutoritat-20150507-20150512' - assert certificate_name( - 'selfie.example.org', - '', - datetime(2015, 5, 7, 0, 0, 0), - datetime(2025, 5, 12, 13, 37, 0), - False - ) == 'selfie.example.org-selfsigned-20150507-20250512' + + assert ( + certificate_name( + "example.com", + "Example Inc,", + datetime(2015, 5, 7, 0, 0, 0), + datetime(2015, 5, 12, 0, 0, 0), + False, + ) + == "example.com-ExampleInc-20150507-20150512" + ) + assert ( + certificate_name( + "example.com", + "Example Inc,", + datetime(2015, 5, 7, 0, 0, 0), + datetime(2015, 5, 12, 0, 0, 0), + True, + ) + == "SAN-example.com-ExampleInc-20150507-20150512" + ) + assert ( + certificate_name( + "xn--mnchen-3ya.de", + "Vertrauenswürdig Autorität", + datetime(2015, 5, 7, 0, 0, 0), + datetime(2015, 5, 12, 0, 0, 0), + False, + ) + == "xn--mnchen-3ya.de-VertrauenswurdigAutoritat-20150507-20150512" + ) + assert ( + certificate_name( + "selfie.example.org", + "", + datetime(2015, 5, 7, 0, 0, 0), + datetime(2025, 5, 12, 13, 37, 0), + False, + ) + == "selfie.example.org-selfsigned-20150507-20250512" + ) def test_issuer(client, cert_builder, issuer_private_key): from lemur.common.defaults import issuer - assert issuer(INTERMEDIATE_CERT) == 'LemurTrustUnittestsRootCA2018' + assert issuer(INTERMEDIATE_CERT) == "LemurTrustUnittestsRootCA2018" # We need to override builder's issuer name cert_builder._issuer_name = None # Unicode issuer name - cert = (cert_builder - .issuer_name(x509.Name([x509.NameAttribute(x509.NameOID.COMMON_NAME, 'Vertrauenswürdig Autorität')])) - .sign(issuer_private_key, hashes.SHA256(), default_backend())) - assert issuer(cert) == 'VertrauenswurdigAutoritat' + cert = cert_builder.issuer_name( + x509.Name( + [x509.NameAttribute(x509.NameOID.COMMON_NAME, "Vertrauenswürdig Autorität")] + ) + ).sign(issuer_private_key, hashes.SHA256(), default_backend()) + assert issuer(cert) == "VertrauenswurdigAutoritat" # Fallback to 'Organization' field when issuer CN is missing - cert = (cert_builder - .issuer_name(x509.Name([x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, 'No Such Organization')])) - .sign(issuer_private_key, hashes.SHA256(), default_backend())) - assert issuer(cert) == 'NoSuchOrganization' + cert = cert_builder.issuer_name( + x509.Name( + [x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "No Such Organization")] + ) + ).sign(issuer_private_key, hashes.SHA256(), default_backend()) + assert issuer(cert) == "NoSuchOrganization" # Missing issuer name - cert = (cert_builder - .issuer_name(x509.Name([])) - .sign(issuer_private_key, hashes.SHA256(), default_backend())) - assert issuer(cert) == '' + cert = cert_builder.issuer_name(x509.Name([])).sign( + issuer_private_key, hashes.SHA256(), default_backend() + ) + assert issuer(cert) == "" def test_issuer_selfsigned(selfsigned_cert): from lemur.common.defaults import issuer - assert issuer(selfsigned_cert) == '' + + assert issuer(selfsigned_cert) == "" diff --git a/lemur/tests/test_destinations.py b/lemur/tests/test_destinations.py index 11f03d9e..d17c703b 100644 --- a/lemur/tests/test_destinations.py +++ b/lemur/tests/test_destinations.py @@ -3,20 +3,22 @@ import pytest from lemur.destinations.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_destination_input_schema(client, destination_plugin, destination): from lemur.destinations.schemas import DestinationInputSchema input_data = { - 'label': 'destination1', - 'options': {}, - 'description': 'my destination', - 'active': True, - 'plugin': { - 'slug': 'test-destination' - } + "label": "destination1", + "options": {}, + "description": "my destination", + "active": True, + "plugin": {"slug": "test-destination"}, } data, errors = DestinationInputSchema().load(input_data) @@ -24,91 +26,154 @@ def test_destination_input_schema(client, destination_plugin, destination): assert not errors -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 404), - (VALID_ADMIN_HEADER_TOKEN, 404), - (VALID_ADMIN_API_TOKEN, 404), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 404), + (VALID_ADMIN_HEADER_TOKEN, 404), + (VALID_ADMIN_API_TOKEN, 404), + ("", 401), + ], +) def test_destination_get(client, token, status): - assert client.get(api.url_for(Destinations, destination_id=1), headers=token).status_code == status + assert ( + client.get( + api.url_for(Destinations, destination_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_destination_post_(client, token, status): - assert client.post(api.url_for(Destinations, destination_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Destinations, destination_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_destination_put(client, token, status): - assert client.put(api.url_for(Destinations, destination_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Destinations, destination_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_destination_delete(client, token, status): - assert client.delete(api.url_for(Destinations, destination_id=1), headers=token).status_code == status + assert ( + client.delete( + api.url_for(Destinations, destination_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_destination_patch(client, token, status): - assert client.patch(api.url_for(Destinations, destination_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Destinations, destination_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_destination_list_post_(client, token, status): - assert client.post(api.url_for(DestinationsList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(DestinationsList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_destination_list_get(client, token, status): - assert client.get(api.url_for(DestinationsList), headers=token).status_code == status + assert ( + client.get(api.url_for(DestinationsList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_destination_list_delete(client, token, status): - assert client.delete(api.url_for(DestinationsList), headers=token).status_code == status + assert ( + client.delete(api.url_for(DestinationsList), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_destination_list_patch(client, token, status): - assert client.patch(api.url_for(DestinationsList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(DestinationsList), data={}, headers=token).status_code + == status + ) diff --git a/lemur/tests/test_domains.py b/lemur/tests/test_domains.py index 873412b2..47023f8c 100644 --- a/lemur/tests/test_domains.py +++ b/lemur/tests/test_domains.py @@ -3,94 +3,152 @@ import pytest from lemur.domains.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_domain_get(client, token, status): - assert client.get(api.url_for(Domains, domain_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(Domains, domain_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_domain_post_(client, token, status): - assert client.post(api.url_for(Domains, domain_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Domains, domain_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_domain_put(client, token, status): - assert client.put(api.url_for(Domains, domain_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Domains, domain_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_domain_delete(client, token, status): - assert client.delete(api.url_for(Domains, domain_id=1), headers=token).status_code == status + assert ( + client.delete(api.url_for(Domains, domain_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_domain_patch(client, token, status): - assert client.patch(api.url_for(Domains, domain_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Domains, domain_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_domain_list_post_(client, token, status): - assert client.post(api.url_for(DomainsList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(DomainsList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_domain_list_get(client, token, status): assert client.get(api.url_for(DomainsList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_domain_list_delete(client, token, status): assert client.delete(api.url_for(DomainsList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_domain_list_patch(client, token, status): - assert client.patch(api.url_for(DomainsList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(DomainsList), data={}, headers=token).status_code + == status + ) diff --git a/lemur/tests/test_endpoints.py b/lemur/tests/test_endpoints.py index 4ea0a4aa..af073e53 100644 --- a/lemur/tests/test_endpoints.py +++ b/lemur/tests/test_endpoints.py @@ -4,11 +4,16 @@ from lemur.endpoints.views import * # noqa from lemur.tests.factories import EndpointFactory, CertificateFactory -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_rotate_certificate(client, source_plugin): from lemur.deployment.service import rotate_certificate + new_certificate = CertificateFactory() endpoint = EndpointFactory() @@ -16,91 +21,147 @@ def test_rotate_certificate(client, source_plugin): assert endpoint.certificate == new_certificate -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 404), - (VALID_ADMIN_HEADER_TOKEN, 404), - (VALID_ADMIN_API_TOKEN, 404), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 404), + (VALID_ADMIN_HEADER_TOKEN, 404), + (VALID_ADMIN_API_TOKEN, 404), + ("", 401), + ], +) def test_endpoint_get(client, token, status): - assert client.get(api.url_for(Endpoints, endpoint_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(Endpoints, endpoint_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_post_(client, token, status): - assert client.post(api.url_for(Endpoints, endpoint_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Endpoints, endpoint_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_put(client, token, status): - assert client.put(api.url_for(Endpoints, endpoint_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Endpoints, endpoint_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_delete(client, token, status): - assert client.delete(api.url_for(Endpoints, endpoint_id=1), headers=token).status_code == status + assert ( + client.delete(api.url_for(Endpoints, endpoint_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_patch(client, token, status): - assert client.patch(api.url_for(Endpoints, endpoint_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Endpoints, endpoint_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_list_post_(client, token, status): - assert client.post(api.url_for(EndpointsList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(EndpointsList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_endpoint_list_get(client, token, status): assert client.get(api.url_for(EndpointsList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_list_delete(client, token, status): - assert client.delete(api.url_for(EndpointsList), headers=token).status_code == status + assert ( + client.delete(api.url_for(EndpointsList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_endpoint_list_patch(client, token, status): - assert client.patch(api.url_for(EndpointsList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(EndpointsList), data={}, headers=token).status_code + == status + ) diff --git a/lemur/tests/test_ldap.py b/lemur/tests/test_ldap.py index a636afdc..8e4027a9 100644 --- a/lemur/tests/test_ldap.py +++ b/lemur/tests/test_ldap.py @@ -1,51 +1,69 @@ import pytest -from lemur.auth.ldap import * # noqa +from lemur.auth.ldap import * # noqa from mock import patch, MagicMock class LdapPrincipalTester(LdapPrincipal): - def __init__(self, args): super().__init__(args) - self.ldap_server = 'ldap://localhost' + self.ldap_server = "ldap://localhost" def bind_test(self): - groups = [('user', {'memberOf': ['CN=Lemur Access,OU=Groups,DC=example,DC=com'.encode('utf-8'), - 'CN=Pen Pushers,OU=Groups,DC=example,DC=com'.encode('utf-8')]})] + groups = [ + ( + "user", + { + "memberOf": [ + "CN=Lemur Access,OU=Groups,DC=example,DC=com".encode("utf-8"), + "CN=Pen Pushers,OU=Groups,DC=example,DC=com".encode("utf-8"), + ] + }, + ) + ] self.ldap_client = MagicMock() self.ldap_client.search_s.return_value = groups self._bind() def authorize_test_groups_to_roles_admin(self): - self.ldap_groups = ''.join(['CN=Pen Pushers,OU=Groups,DC=example,DC=com', - 'CN=Lemur Admins,OU=Groups,DC=example,DC=com', - 'CN=Lemur Read Only,OU=Groups,DC=example,DC=com']) + self.ldap_groups = "".join( + [ + "CN=Pen Pushers,OU=Groups,DC=example,DC=com", + "CN=Lemur Admins,OU=Groups,DC=example,DC=com", + "CN=Lemur Read Only,OU=Groups,DC=example,DC=com", + ] + ) self.ldap_required_group = None - self.ldap_groups_to_roles = {'Lemur Admins': 'admin', 'Lemur Read Only': 'read-only'} + self.ldap_groups_to_roles = { + "Lemur Admins": "admin", + "Lemur Read Only": "read-only", + } return self._authorize() def authorize_test_required_group(self, group): - self.ldap_groups = ''.join(['CN=Lemur Access,OU=Groups,DC=example,DC=com', - 'CN=Pen Pushers,OU=Groups,DC=example,DC=com']) + self.ldap_groups = "".join( + [ + "CN=Lemur Access,OU=Groups,DC=example,DC=com", + "CN=Pen Pushers,OU=Groups,DC=example,DC=com", + ] + ) self.ldap_required_group = group return self._authorize() @pytest.fixture() def principal(session): - args = {'username': 'user', 'password': 'p4ssw0rd'} + args = {"username": "user", "password": "p4ssw0rd"} yield LdapPrincipalTester(args) class TestLdapPrincipal: - - @patch('ldap.initialize') + @patch("ldap.initialize") def test_bind(self, app, principal): self.test_ldap_user = principal self.test_ldap_user.bind_test() - group = 'Pen Pushers' + group = "Pen Pushers" assert group in self.test_ldap_user.ldap_groups - assert self.test_ldap_user.ldap_principal == 'user@example.com' + assert self.test_ldap_user.ldap_principal == "user@example.com" def test_authorize_groups_to_roles_admin(self, app, principal): self.test_ldap_user = principal @@ -54,11 +72,11 @@ class TestLdapPrincipal: def test_authorize_required_group_missing(self, app, principal): self.test_ldap_user = principal - roles = self.test_ldap_user.authorize_test_required_group('Not Allowed') + roles = self.test_ldap_user.authorize_test_required_group("Not Allowed") assert not roles def test_authorize_required_group_access(self, session, principal): self.test_ldap_user = principal - roles = self.test_ldap_user.authorize_test_required_group('Lemur Access') + roles = self.test_ldap_user.authorize_test_required_group("Lemur Access") assert len(roles) >= 1 assert any(x.name == "user@example.com" for x in roles) diff --git a/lemur/tests/test_logs.py b/lemur/tests/test_logs.py index 516f5bb7..6705ffca 100644 --- a/lemur/tests/test_logs.py +++ b/lemur/tests/test_logs.py @@ -1,21 +1,32 @@ import pytest -from lemur.tests.vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from lemur.tests.vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) from lemur.logs.views import * # noqa def test_private_key_audit(client, certificate): from lemur.certificates.views import CertificatePrivateKey, api + assert len(certificate.logs) == 0 - client.get(api.url_for(CertificatePrivateKey, certificate_id=certificate.id), headers=VALID_ADMIN_HEADER_TOKEN) + client.get( + api.url_for(CertificatePrivateKey, certificate_id=certificate.id), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert len(certificate.logs) == 1 -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_get_logs(client, token, status): assert client.get(api.url_for(LogsList), headers=token).status_code == status diff --git a/lemur/tests/test_messaging.py b/lemur/tests/test_messaging.py index fc0e62da..98e9ebf3 100644 --- a/lemur/tests/test_messaging.py +++ b/lemur/tests/test_messaging.py @@ -8,14 +8,21 @@ from moto import mock_ses def test_needs_notification(app, certificate, notification): from lemur.notifications.messaging import needs_notification + assert not needs_notification(certificate) with pytest.raises(Exception): - notification.options = [{'name': 'interval', 'value': 10}, {'name': 'unit', 'value': 'min'}] + notification.options = [ + {"name": "interval", "value": 10}, + {"name": "unit", "value": "min"}, + ] certificate.notifications.append(notification) needs_notification(certificate) - certificate.notifications[0].options = [{'name': 'interval', 'value': 10}, {'name': 'unit', 'value': 'days'}] + certificate.notifications[0].options = [ + {"name": "interval", "value": 10}, + {"name": "unit", "value": "days"}, + ] assert not needs_notification(certificate) delta = certificate.not_after - timedelta(days=10) @@ -30,7 +37,8 @@ def test_get_certificates(app, certificate, notification): delta = certificate.not_after - timedelta(days=2) notification.options = [ - {'name': 'interval', 'value': 2}, {'name': 'unit', 'value': 'days'} + {"name": "interval", "value": 2}, + {"name": "unit", "value": "days"}, ] with freeze_time(delta.datetime): @@ -55,11 +63,16 @@ def test_get_eligible_certificates(app, certificate, notification): from lemur.notifications.messaging import get_eligible_certificates certificate.notifications.append(notification) - certificate.notifications[0].options = [{'name': 'interval', 'value': 10}, {'name': 'unit', 'value': 'days'}] + certificate.notifications[0].options = [ + {"name": "interval", "value": 10}, + {"name": "unit", "value": "days"}, + ] delta = certificate.not_after - timedelta(days=10) with freeze_time(delta.datetime): - assert get_eligible_certificates() == {certificate.owner: {notification.label: [(notification, certificate)]}} + assert get_eligible_certificates() == { + certificate.owner: {notification.label: [(notification, certificate)]} + } @mock_ses @@ -67,7 +80,10 @@ def test_send_expiration_notification(certificate, notification, notification_pl from lemur.notifications.messaging import send_expiration_notifications certificate.notifications.append(notification) - certificate.notifications[0].options = [{'name': 'interval', 'value': 10}, {'name': 'unit', 'value': 'days'}] + certificate.notifications[0].options = [ + {"name": "interval", "value": 10}, + {"name": "unit", "value": "days"}, + ] delta = certificate.not_after - timedelta(days=10) with freeze_time(delta.datetime): @@ -75,7 +91,9 @@ def test_send_expiration_notification(certificate, notification, notification_pl @mock_ses -def test_send_expiration_notification_with_no_notifications(certificate, notification, notification_plugin): +def test_send_expiration_notification_with_no_notifications( + certificate, notification, notification_plugin +): from lemur.notifications.messaging import send_expiration_notifications delta = certificate.not_after - timedelta(days=10) @@ -86,4 +104,5 @@ def test_send_expiration_notification_with_no_notifications(certificate, notific @mock_ses def test_send_rotation_notification(notification_plugin, certificate): from lemur.notifications.messaging import send_rotation_notification + send_rotation_notification(certificate, notification_plugin=notification_plugin) diff --git a/lemur/tests/test_missing.py b/lemur/tests/test_missing.py index 4f2c20c6..be615ced 100644 --- a/lemur/tests/test_missing.py +++ b/lemur/tests/test_missing.py @@ -9,9 +9,12 @@ def test_convert_validity_years(session): with freeze_time("2016-01-01"): data = convert_validity_years(dict(validity_years=2)) - assert data['validity_start'] == arrow.utcnow().isoformat() - assert data['validity_end'] == arrow.utcnow().replace(years=+2).isoformat() + assert data["validity_start"] == arrow.utcnow().isoformat() + assert data["validity_end"] == arrow.utcnow().replace(years=+2).isoformat() with freeze_time("2015-01-10"): data = convert_validity_years(dict(validity_years=1)) - assert data['validity_end'] == arrow.utcnow().replace(years=+1, days=-2).isoformat() + assert ( + data["validity_end"] + == arrow.utcnow().replace(years=+1, days=-2).isoformat() + ) diff --git a/lemur/tests/test_notifications.py b/lemur/tests/test_notifications.py index 6daee0a8..20079f97 100644 --- a/lemur/tests/test_notifications.py +++ b/lemur/tests/test_notifications.py @@ -3,20 +3,22 @@ import pytest from lemur.notifications.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_notification_input_schema(client, notification_plugin, notification): from lemur.notifications.schemas import NotificationInputSchema input_data = { - 'label': 'notification1', - 'options': {}, - 'description': 'my notification', - 'active': True, - 'plugin': { - 'slug': 'test-notification' - } + "label": "notification1", + "options": {}, + "description": "my notification", + "active": True, + "plugin": {"slug": "test-notification"}, } data, errors = NotificationInputSchema().load(input_data) @@ -24,91 +26,156 @@ def test_notification_input_schema(client, notification_plugin, notification): assert not errors -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_notification_get(client, notification_plugin, notification, token, status): - assert client.get(api.url_for(Notifications, notification_id=notification.id), headers=token).status_code == status + assert ( + client.get( + api.url_for(Notifications, notification_id=notification.id), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_notification_post_(client, token, status): - assert client.post(api.url_for(Notifications, notification_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Notifications, notification_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_notification_put(client, token, status): - assert client.put(api.url_for(Notifications, notification_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Notifications, notification_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_notification_delete(client, token, status): - assert client.delete(api.url_for(Notifications, notification_id=1), headers=token).status_code == status + assert ( + client.delete( + api.url_for(Notifications, notification_id=1), headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_notification_patch(client, token, status): - assert client.patch(api.url_for(Notifications, notification_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Notifications, notification_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_notification_list_post_(client, token, status): - assert client.post(api.url_for(NotificationsList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(NotificationsList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) -def test_notification_list_get(client, notification_plugin, notification, token, status): - assert client.get(api.url_for(NotificationsList), headers=token).status_code == status +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) +def test_notification_list_get( + client, notification_plugin, notification, token, status +): + assert ( + client.get(api.url_for(NotificationsList), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_notification_list_delete(client, token, status): - assert client.delete(api.url_for(NotificationsList), headers=token).status_code == status + assert ( + client.delete(api.url_for(NotificationsList), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_notification_list_patch(client, token, status): - assert client.patch(api.url_for(NotificationsList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(NotificationsList), data={}, headers=token).status_code + == status + ) diff --git a/lemur/tests/test_pending_certificates.py b/lemur/tests/test_pending_certificates.py index 043002d3..3e755574 100644 --- a/lemur/tests/test_pending_certificates.py +++ b/lemur/tests/test_pending_certificates.py @@ -4,12 +4,19 @@ import pytest from marshmallow import ValidationError from lemur.pending_certificates.views import * # noqa -from .vectors import CSR_STR, INTERMEDIATE_CERT_STR, VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, \ - VALID_USER_HEADER_TOKEN, WILDCARD_CERT_STR +from .vectors import ( + CSR_STR, + INTERMEDIATE_CERT_STR, + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, + WILDCARD_CERT_STR, +) def test_increment_attempt(pending_certificate): from lemur.pending_certificates.service import increment_attempt + initial_attempt = pending_certificate.number_attempts attempts = increment_attempt(pending_certificate) assert attempts == initial_attempt + 1 @@ -17,50 +24,66 @@ def test_increment_attempt(pending_certificate): def test_create_pending_certificate(async_issuer_plugin, async_authority, user): from lemur.certificates.service import create - pending_cert = create(authority=async_authority, csr=CSR_STR, owner='joe@example.com', creator=user['user'], - common_name='ACommonName') - assert pending_cert.external_id == '12345' + + pending_cert = create( + authority=async_authority, + csr=CSR_STR, + owner="joe@example.com", + creator=user["user"], + common_name="ACommonName", + ) + assert pending_cert.external_id == "12345" def test_create_pending(pending_certificate, user, session): import copy from lemur.pending_certificates.service import create_certificate, get - cert = {'body': WILDCARD_CERT_STR, - 'chain': INTERMEDIATE_CERT_STR, - 'external_id': '54321'} + + cert = { + "body": WILDCARD_CERT_STR, + "chain": INTERMEDIATE_CERT_STR, + "external_id": "54321", + } # Weird copy because the session behavior. pending_certificate is a valid object but the # return of vars(pending_certificate) is a sessionobject, and so nothing from the pending_cert # is used to create the certificate. Maybe a bug due to using vars(), and should copy every # field explicitly. pending_certificate = copy.copy(get(pending_certificate.id)) - real_cert = create_certificate(pending_certificate, cert, user['user']) + real_cert = create_certificate(pending_certificate, cert, user["user"]) assert real_cert.owner == pending_certificate.owner assert real_cert.notify == pending_certificate.notify assert real_cert.private_key == pending_certificate.private_key - assert real_cert.external_id == '54321' + assert real_cert.external_id == "54321" -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 204), - (VALID_ADMIN_API_TOKEN, 204), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 204), + (VALID_ADMIN_API_TOKEN, 204), + ("", 401), + ], +) def test_pending_cancel(client, pending_certificate, token, status): - assert client.delete(api.url_for(PendingCertificates, pending_certificate_id=pending_certificate.id), - data=json.dumps({'note': "unit test", 'send_email': False}), - headers=token).status_code == status + assert ( + client.delete( + api.url_for( + PendingCertificates, pending_certificate_id=pending_certificate.id + ), + data=json.dumps({"note": "unit test", "send_email": False}), + headers=token, + ).status_code + == status + ) def test_pending_upload(pending_certificate_from_full_chain_ca): from lemur.pending_certificates.service import upload from lemur.certificates.service import get - cert = {'body': WILDCARD_CERT_STR, - 'chain': None, - 'external_id': None - } + cert = {"body": WILDCARD_CERT_STR, "chain": None, "external_id": None} pending_cert = upload(pending_certificate_from_full_chain_ca.id, **cert) assert pending_cert.resolved @@ -71,9 +94,10 @@ def test_pending_upload_with_chain(pending_certificate_from_partial_chain_ca): from lemur.pending_certificates.service import upload from lemur.certificates.service import get - cert = {'body': WILDCARD_CERT_STR, - 'chain': INTERMEDIATE_CERT_STR, - 'external_id': None + cert = { + "body": WILDCARD_CERT_STR, + "chain": INTERMEDIATE_CERT_STR, + "external_id": None, } pending_cert = upload(pending_certificate_from_partial_chain_ca.id, **cert) @@ -84,11 +108,9 @@ def test_pending_upload_with_chain(pending_certificate_from_partial_chain_ca): def test_invalid_pending_upload_with_chain(pending_certificate_from_partial_chain_ca): from lemur.pending_certificates.service import upload - cert = {'body': WILDCARD_CERT_STR, - 'chain': None, - 'external_id': None - } + cert = {"body": WILDCARD_CERT_STR, "chain": None, "external_id": None} with pytest.raises(ValidationError) as err: upload(pending_certificate_from_partial_chain_ca.id, **cert) assert str(err.value).startswith( - 'Incorrect chain certificate(s) provided: \'*.wild.example.org\' is not signed by \'LemurTrust Unittests Root CA 2018') + "Incorrect chain certificate(s) provided: '*.wild.example.org' is not signed by 'LemurTrust Unittests Root CA 2018" + ) diff --git a/lemur/tests/test_roles.py b/lemur/tests/test_roles.py index e5483e00..6e612062 100644 --- a/lemur/tests/test_roles.py +++ b/lemur/tests/test_roles.py @@ -3,16 +3,23 @@ import json import pytest from lemur.roles.views import * # noqa -from lemur.tests.factories import RoleFactory, AuthorityFactory, CertificateFactory, UserFactory -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from lemur.tests.factories import ( + RoleFactory, + AuthorityFactory, + CertificateFactory, + UserFactory, +) +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_role_input_schema(client): from lemur.roles.schemas import RoleInputSchema - input_data = { - 'name': 'myRole' - } + input_data = {"name": "myRole"} data, errors = RoleInputSchema().load(input_data) @@ -38,60 +45,80 @@ def test_multiple_authority_certificate_association(session, client): assert role.certificates[1].name == certificate1.name -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_role_get(client, token, status): - assert client.get(api.url_for(Roles, role_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(Roles, role_id=1), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_role_post_(client, token, status): - assert client.post(api.url_for(Roles, role_id=1), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(Roles, role_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 400), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 400), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_role_put(client, token, status): - assert client.put(api.url_for(Roles, role_id=1), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(Roles, role_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_role_put_with_data(client, session, token, status): user = UserFactory() role = RoleFactory() session.commit() - data = { - 'users': [ - {'id': user.id} - ], - 'id': role.id, - 'name': role.name - } + data = {"users": [{"id": user.id}], "id": role.id, "name": role.name} - assert client.put(api.url_for(Roles, role_id=role.id), data=json.dumps(data), headers=token).status_code == status + assert ( + client.put( + api.url_for(Roles, role_id=role.id), data=json.dumps(data), headers=token + ).status_code + == status + ) def test_role_put_with_data_and_user(client, session): from lemur.auth.service import create_token + user = UserFactory() role = RoleFactory(users=[user]) role1 = RoleFactory() @@ -99,83 +126,119 @@ def test_role_put_with_data_and_user(client, session): session.commit() headers = { - 'Authorization': 'Basic ' + create_token(user), - 'Content-Type': 'application/json' + "Authorization": "Basic " + create_token(user), + "Content-Type": "application/json", } data = { - 'users': [ - {'id': user1.id}, - {'id': user.id} - ], - 'id': role.id, - 'name': role.name + "users": [{"id": user1.id}, {"id": user.id}], + "id": role.id, + "name": role.name, } - assert client.put(api.url_for(Roles, role_id=role.id), data=json.dumps(data), headers=headers).status_code == 200 - assert client.get(api.url_for(RolesList), data={}, headers=headers).json['total'] > 1 + assert ( + client.put( + api.url_for(Roles, role_id=role.id), data=json.dumps(data), headers=headers + ).status_code + == 200 + ) + assert ( + client.get(api.url_for(RolesList), data={}, headers=headers).json["total"] > 1 + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_role_delete(client, token, status, role): - assert client.delete(api.url_for(Roles, role_id=role.id), headers=token).status_code == status + assert ( + client.delete(api.url_for(Roles, role_id=role.id), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_role_patch(client, token, status): - assert client.patch(api.url_for(Roles, role_id=1), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(Roles, role_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_role_list_post_(client, token, status): - assert client.post(api.url_for(RolesList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(RolesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_role_list_get(client, token, status): assert client.get(api.url_for(RolesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_role_list_delete(client, token, status): assert client.delete(api.url_for(RolesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_role_list_patch(client, token, status): - assert client.patch(api.url_for(RolesList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(RolesList), data={}, headers=token).status_code + == status + ) def test_sensitive_filter(client): - resp = client.get(api.url_for(RolesList) + '?filter=password;a', headers=VALID_ADMIN_HEADER_TOKEN) - assert "'password' is not sortable or filterable" in resp.json['message'] + resp = client.get( + api.url_for(RolesList) + "?filter=password;a", headers=VALID_ADMIN_HEADER_TOKEN + ) + assert "'password' is not sortable or filterable" in resp.json["message"] diff --git a/lemur/tests/test_schemas.py b/lemur/tests/test_schemas.py index e2a05213..2c085849 100644 --- a/lemur/tests/test_schemas.py +++ b/lemur/tests/test_schemas.py @@ -14,15 +14,15 @@ def test_get_object_attribute(): get_object_attribute([{}], many=True) with pytest.raises(ValidationError): - get_object_attribute([{}, {'id': 1}], many=True) + get_object_attribute([{}, {"id": 1}], many=True) with pytest.raises(ValidationError): - get_object_attribute([{}, {'name': 'test'}], many=True) + get_object_attribute([{}, {"name": "test"}], many=True) - assert get_object_attribute({'name': 'test'}) == 'name' - assert get_object_attribute({'id': 1}) == 'id' - assert get_object_attribute([{'name': 'test'}], many=True) == 'name' - assert get_object_attribute([{'id': 1}], many=True) == 'id' + assert get_object_attribute({"name": "test"}) == "name" + assert get_object_attribute({"id": 1}) == "id" + assert get_object_attribute([{"name": "test"}], many=True) == "name" + assert get_object_attribute([{"id": 1}], many=True) == "id" def test_fetch_objects(session): @@ -33,26 +33,26 @@ def test_fetch_objects(session): role1 = RoleFactory() session.commit() - data = {'id': role.id} + data = {"id": role.id} found_role = fetch_objects(Role, data) assert found_role == role - data = {'name': role.name} + data = {"name": role.name} found_role = fetch_objects(Role, data) assert found_role == role - data = [{'id': role.id}, {'id': role1.id}] + data = [{"id": role.id}, {"id": role1.id}] found_roles = fetch_objects(Role, data, many=True) assert found_roles == [role, role1] - data = [{'name': role.name}, {'name': role1.name}] + data = [{"name": role.name}, {"name": role1.name}] found_roles = fetch_objects(Role, data, many=True) assert found_roles == [role, role1] with pytest.raises(ValidationError): - data = [{'name': 'blah'}, {'name': role1.name}] + data = [{"name": "blah"}, {"name": role1.name}] fetch_objects(Role, data, many=True) with pytest.raises(ValidationError): - data = {'name': 'nah'} + data = {"name": "nah"} fetch_objects(Role, data) diff --git a/lemur/tests/test_sources.py b/lemur/tests/test_sources.py index 1ce0d9ba..312c008f 100644 --- a/lemur/tests/test_sources.py +++ b/lemur/tests/test_sources.py @@ -2,17 +2,22 @@ import pytest from lemur.sources.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN, WILDCARD_CERT_STR, \ - WILDCARD_CERT_KEY +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, + WILDCARD_CERT_STR, + WILDCARD_CERT_KEY, +) def validate_source_schema(client): from lemur.sources.schemas import SourceInputSchema input_data = { - 'label': 'exampleSource', - 'options': {}, - 'plugin': {'slug': 'aws-source'} + "label": "exampleSource", + "options": {}, + "plugin": {"slug": "aws-source"}, } data, errors = SourceInputSchema().load(input_data) @@ -26,111 +31,171 @@ def test_create_certificate(user, source): certificate_create({}, source) data = { - 'body': WILDCARD_CERT_STR, - 'private_key': WILDCARD_CERT_KEY, - 'owner': 'bob@example.com', - 'creator': user['user'] + "body": WILDCARD_CERT_STR, + "private_key": WILDCARD_CERT_KEY, + "owner": "bob@example.com", + "creator": user["user"], } cert = certificate_create(data, source) assert cert.notifications -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 404), - (VALID_ADMIN_HEADER_TOKEN, 404), - (VALID_ADMIN_API_TOKEN, 404), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 404), + (VALID_ADMIN_HEADER_TOKEN, 404), + (VALID_ADMIN_API_TOKEN, 404), + ("", 401), + ], +) def test_source_get(client, source_plugin, token, status): - assert client.get(api.url_for(Sources, source_id=43543), headers=token).status_code == status + assert ( + client.get(api.url_for(Sources, source_id=43543), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_source_post_(client, token, status): - assert client.post(api.url_for(Sources, source_id=1), data={}, headers=token).status_code == status + assert ( + client.post( + api.url_for(Sources, source_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_source_put(client, token, status): - assert client.put(api.url_for(Sources, source_id=1), data={}, headers=token).status_code == status + assert ( + client.put( + api.url_for(Sources, source_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_source_delete(client, token, status): - assert client.delete(api.url_for(Sources, source_id=1), headers=token).status_code == status + assert ( + client.delete(api.url_for(Sources, source_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_source_patch(client, token, status): - assert client.patch(api.url_for(Sources, source_id=1), data={}, headers=token).status_code == status + assert ( + client.patch( + api.url_for(Sources, source_id=1), data={}, headers=token + ).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_sources_list_get(client, source_plugin, token, status): assert client.get(api.url_for(SourcesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_sources_list_post(client, token, status): - assert client.post(api.url_for(SourcesList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(SourcesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_sources_list_put(client, token, status): - assert client.put(api.url_for(SourcesList), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(SourcesList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_sources_list_delete(client, token, status): assert client.delete(api.url_for(SourcesList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_sources_list_patch(client, token, status): - assert client.patch(api.url_for(SourcesList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(SourcesList), data={}, headers=token).status_code + == status + ) diff --git a/lemur/tests/test_users.py b/lemur/tests/test_users.py index 61db93bf..9e67f868 100644 --- a/lemur/tests/test_users.py +++ b/lemur/tests/test_users.py @@ -4,16 +4,20 @@ import pytest from lemur.tests.factories import UserFactory, RoleFactory from lemur.users.views import * # noqa -from .vectors import VALID_ADMIN_API_TOKEN, VALID_ADMIN_HEADER_TOKEN, VALID_USER_HEADER_TOKEN +from .vectors import ( + VALID_ADMIN_API_TOKEN, + VALID_ADMIN_HEADER_TOKEN, + VALID_USER_HEADER_TOKEN, +) def test_user_input_schema(client): from lemur.users.schemas import UserInputSchema input_data = { - 'username': 'example', - 'password': '1233432', - 'email': 'example@example.com' + "username": "example", + "password": "1233432", + "email": "example@example.com", } data, errors = UserInputSchema().load(input_data) @@ -21,104 +25,156 @@ def test_user_input_schema(client): assert not errors -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_user_get(client, token, status): - assert client.get(api.url_for(Users, user_id=1), headers=token).status_code == status + assert ( + client.get(api.url_for(Users, user_id=1), headers=token).status_code == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_post_(client, token, status): - assert client.post(api.url_for(Users, user_id=1), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(Users, user_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_user_put(client, token, status): - assert client.put(api.url_for(Users, user_id=1), data={}, headers=token).status_code == status + assert ( + client.put(api.url_for(Users, user_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_delete(client, token, status): - assert client.delete(api.url_for(Users, user_id=1), headers=token).status_code == status + assert ( + client.delete(api.url_for(Users, user_id=1), headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_patch(client, token, status): - assert client.patch(api.url_for(Users, user_id=1), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(Users, user_id=1), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 403), - (VALID_ADMIN_HEADER_TOKEN, 400), - (VALID_ADMIN_API_TOKEN, 400), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 403), + (VALID_ADMIN_HEADER_TOKEN, 400), + (VALID_ADMIN_API_TOKEN, 400), + ("", 401), + ], +) def test_user_list_post_(client, token, status): - assert client.post(api.url_for(UsersList), data={}, headers=token).status_code == status + assert ( + client.post(api.url_for(UsersList), data={}, headers=token).status_code + == status + ) -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 200), - (VALID_ADMIN_HEADER_TOKEN, 200), - (VALID_ADMIN_API_TOKEN, 200), - ('', 401) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 200), + (VALID_ADMIN_HEADER_TOKEN, 200), + (VALID_ADMIN_API_TOKEN, 200), + ("", 401), + ], +) def test_user_list_get(client, token, status): assert client.get(api.url_for(UsersList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_list_delete(client, token, status): assert client.delete(api.url_for(UsersList), headers=token).status_code == status -@pytest.mark.parametrize("token,status", [ - (VALID_USER_HEADER_TOKEN, 405), - (VALID_ADMIN_HEADER_TOKEN, 405), - (VALID_ADMIN_API_TOKEN, 405), - ('', 405) -]) +@pytest.mark.parametrize( + "token,status", + [ + (VALID_USER_HEADER_TOKEN, 405), + (VALID_ADMIN_HEADER_TOKEN, 405), + (VALID_ADMIN_API_TOKEN, 405), + ("", 405), + ], +) def test_user_list_patch(client, token, status): - assert client.patch(api.url_for(UsersList), data={}, headers=token).status_code == status + assert ( + client.patch(api.url_for(UsersList), data={}, headers=token).status_code + == status + ) def test_sensitive_filter(client): - resp = client.get(api.url_for(UsersList) + '?filter=password;a', headers=VALID_ADMIN_HEADER_TOKEN) - assert "'password' is not sortable or filterable" in resp.json['message'] + resp = client.get( + api.url_for(UsersList) + "?filter=password;a", headers=VALID_ADMIN_HEADER_TOKEN + ) + assert "'password' is not sortable or filterable" in resp.json["message"] def test_sensitive_sort(client): - resp = client.get(api.url_for(UsersList) + '?sortBy=password&sortDir=asc', headers=VALID_ADMIN_HEADER_TOKEN) - assert "'password' is not sortable or filterable" in resp.json['message'] + resp = client.get( + api.url_for(UsersList) + "?sortBy=password&sortDir=asc", + headers=VALID_ADMIN_HEADER_TOKEN, + ) + assert "'password' is not sortable or filterable" in resp.json["message"] def test_user_role_changes(client, session): @@ -128,25 +184,30 @@ def test_user_role_changes(client, session): session.flush() data = { - 'active': True, - 'id': user.id, - 'username': user.username, - 'email': user.email, - 'roles': [ - {'id': role1.id}, - {'id': role2.id}, - ], + "active": True, + "id": user.id, + "username": user.username, + "email": user.email, + "roles": [{"id": role1.id}, {"id": role2.id}], } # PUT two roles - resp = client.put(api.url_for(Users, user_id=user.id), data=json.dumps(data), headers=VALID_ADMIN_HEADER_TOKEN) + resp = client.put( + api.url_for(Users, user_id=user.id), + data=json.dumps(data), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 - assert len(resp.json['roles']) == 2 + assert len(resp.json["roles"]) == 2 assert set(user.roles) == {role1, role2} # Remove one role and PUT again - del data['roles'][1] - resp = client.put(api.url_for(Users, user_id=user.id), data=json.dumps(data), headers=VALID_ADMIN_HEADER_TOKEN) + del data["roles"][1] + resp = client.put( + api.url_for(Users, user_id=user.id), + data=json.dumps(data), + headers=VALID_ADMIN_HEADER_TOKEN, + ) assert resp.status_code == 200 - assert len(resp.json['roles']) == 1 + assert len(resp.json["roles"]) == 1 assert set(user.roles) == {role1} diff --git a/lemur/tests/test_utils.py b/lemur/tests/test_utils.py index 74c11643..2e117d25 100644 --- a/lemur/tests/test_utils.py +++ b/lemur/tests/test_utils.py @@ -1,40 +1,49 @@ import pytest -from lemur.tests.vectors import SAN_CERT, INTERMEDIATE_CERT, ROOTCA_CERT, EC_CERT_EXAMPLE, ECDSA_PRIME256V1_CERT, ECDSA_SECP384r1_CERT, DSA_CERT +from lemur.tests.vectors import ( + SAN_CERT, + INTERMEDIATE_CERT, + ROOTCA_CERT, + EC_CERT_EXAMPLE, + ECDSA_PRIME256V1_CERT, + ECDSA_SECP384r1_CERT, + DSA_CERT, +) def test_generate_private_key(): from lemur.common.utils import generate_private_key - assert generate_private_key('RSA2048') - assert generate_private_key('RSA4096') - assert generate_private_key('ECCPRIME192V1') - assert generate_private_key('ECCPRIME256V1') - assert generate_private_key('ECCSECP192R1') - assert generate_private_key('ECCSECP224R1') - assert generate_private_key('ECCSECP256R1') - assert generate_private_key('ECCSECP384R1') - assert generate_private_key('ECCSECP521R1') - assert generate_private_key('ECCSECP256K1') - assert generate_private_key('ECCSECT163K1') - assert generate_private_key('ECCSECT233K1') - assert generate_private_key('ECCSECT283K1') - assert generate_private_key('ECCSECT409K1') - assert generate_private_key('ECCSECT571K1') - assert generate_private_key('ECCSECT163R2') - assert generate_private_key('ECCSECT233R1') - assert generate_private_key('ECCSECT283R1') - assert generate_private_key('ECCSECT409R1') - assert generate_private_key('ECCSECT571R2') + assert generate_private_key("RSA2048") + assert generate_private_key("RSA4096") + assert generate_private_key("ECCPRIME192V1") + assert generate_private_key("ECCPRIME256V1") + assert generate_private_key("ECCSECP192R1") + assert generate_private_key("ECCSECP224R1") + assert generate_private_key("ECCSECP256R1") + assert generate_private_key("ECCSECP384R1") + assert generate_private_key("ECCSECP521R1") + assert generate_private_key("ECCSECP256K1") + assert generate_private_key("ECCSECT163K1") + assert generate_private_key("ECCSECT233K1") + assert generate_private_key("ECCSECT283K1") + assert generate_private_key("ECCSECT409K1") + assert generate_private_key("ECCSECT571K1") + assert generate_private_key("ECCSECT163R2") + assert generate_private_key("ECCSECT233R1") + assert generate_private_key("ECCSECT283R1") + assert generate_private_key("ECCSECT409R1") + assert generate_private_key("ECCSECT571R2") with pytest.raises(Exception): - generate_private_key('LEMUR') + generate_private_key("LEMUR") def test_get_authority_key(): - '''test get authority key function''' + """test get authority key function""" from lemur.common.utils import get_authority_key - test_cert = '''-----BEGIN CERTIFICATE----- + + test_cert = """-----BEGIN CERTIFICATE----- MIIGYjCCBEqgAwIBAgIUVS7mn6LR5XlQyEGxQ4w9YAWL/XIwDQYJKoZIhvcNAQEN BQAweTELMAkGA1UEBhMCREUxDTALBgNVBAgTBEJvbm4xEDAOBgNVBAcTB0dlcm1h bnkxITAfBgNVBAoTGFRlbGVrb20gRGV1dHNjaGxhbmQgR21iSDELMAkGA1UECxMC @@ -70,9 +79,9 @@ zc75IDsn5wP6A3KflduWW7ri0bYUiKe5higMcbUM0aXzTEAVxsxPk8aEsR9dazF7 y4L/msew3UjFE3ovDHgStjWM1NBMxuIvJEbWOsiB2WA2l3FiT8HvFi0eX/0hbkGi 5LL+oz7nvm9Of7te/BV6Rq0rXWN4d6asO+QlLkTqbmAH6rwunmPCY7MbLXXtP/qM KFfxwrO1 ------END CERTIFICATE-----''' +-----END CERTIFICATE-----""" authority_key = get_authority_key(test_cert) - assert authority_key == 'feacb541be81771293affa412d8dc9f66a3ebb80' + assert authority_key == "feacb541be81771293affa412d8dc9f66a3ebb80" def test_is_selfsigned(selfsigned_cert): diff --git a/lemur/tests/test_validators.py b/lemur/tests/test_validators.py index c3d5357d..77148079 100644 --- a/lemur/tests/test_validators.py +++ b/lemur/tests/test_validators.py @@ -12,7 +12,7 @@ def test_private_key(session): parse_private_key(SAN_CERT_KEY) with pytest.raises(ValueError): - parse_private_key('invalid_private_key') + parse_private_key("invalid_private_key") def test_validate_private_key(session): @@ -29,7 +29,7 @@ def test_sub_alt_type(session): from lemur.common.validators import sub_alt_type with pytest.raises(ValidationError): - sub_alt_type('CNAME') + sub_alt_type("CNAME") def test_dates(session): @@ -44,7 +44,13 @@ def test_dates(session): dates(dict(validity_end=datetime(2016, 1, 1))) with pytest.raises(ValidationError): - dates(dict(validity_start=datetime(2016, 1, 5), validity_end=datetime(2016, 1, 1))) + dates( + dict(validity_start=datetime(2016, 1, 5), validity_end=datetime(2016, 1, 1)) + ) with pytest.raises(ValidationError): - dates(dict(validity_start=datetime(2016, 1, 1), validity_end=datetime(2016, 1, 10))) + dates( + dict( + validity_start=datetime(2016, 1, 1), validity_end=datetime(2016, 1, 10) + ) + ) diff --git a/lemur/tests/test_verify.py b/lemur/tests/test_verify.py index a1f0f5eb..348f6559 100644 --- a/lemur/tests/test_verify.py +++ b/lemur/tests/test_verify.py @@ -13,20 +13,24 @@ from .vectors import INTERMEDIATE_CERT_STR def test_verify_simple_cert(): """Simple certificate without CRL or OCSP.""" # Verification returns None if there are no means to verify a cert - assert verify_string(INTERMEDIATE_CERT_STR, '') is None + assert verify_string(INTERMEDIATE_CERT_STR, "") is None def test_verify_crl_unknown_scheme(cert_builder, private_key): """Unknown distribution point URI schemes should be ignored.""" - ldap_uri = 'ldap://ldap.example.org/cn=Example%20Certificate%20Authority?certificateRevocationList;binary' - crl_dp = x509.DistributionPoint([UniformResourceIdentifier(ldap_uri)], - relative_name=None, reasons=None, crl_issuer=None) - cert = (cert_builder - .add_extension(x509.CRLDistributionPoints([crl_dp]), critical=False) - .sign(private_key, hashes.SHA256(), default_backend())) + ldap_uri = "ldap://ldap.example.org/cn=Example%20Certificate%20Authority?certificateRevocationList;binary" + crl_dp = x509.DistributionPoint( + [UniformResourceIdentifier(ldap_uri)], + relative_name=None, + reasons=None, + crl_issuer=None, + ) + cert = cert_builder.add_extension( + x509.CRLDistributionPoints([crl_dp]), critical=False + ).sign(private_key, hashes.SHA256(), default_backend()) with mktempfile() as cert_tmp: - with open(cert_tmp, 'wb') as f: + with open(cert_tmp, "wb") as f: f.write(cert.public_bytes(serialization.Encoding.PEM)) # Must not raise exception @@ -35,15 +39,19 @@ def test_verify_crl_unknown_scheme(cert_builder, private_key): def test_verify_crl_unreachable(cert_builder, private_key): """Unreachable CRL distribution point results in error.""" - ldap_uri = 'http://invalid.example.org/crl/foobar.crl' - crl_dp = x509.DistributionPoint([UniformResourceIdentifier(ldap_uri)], - relative_name=None, reasons=None, crl_issuer=None) - cert = (cert_builder - .add_extension(x509.CRLDistributionPoints([crl_dp]), critical=False) - .sign(private_key, hashes.SHA256(), default_backend())) + ldap_uri = "http://invalid.example.org/crl/foobar.crl" + crl_dp = x509.DistributionPoint( + [UniformResourceIdentifier(ldap_uri)], + relative_name=None, + reasons=None, + crl_issuer=None, + ) + cert = cert_builder.add_extension( + x509.CRLDistributionPoints([crl_dp]), critical=False + ).sign(private_key, hashes.SHA256(), default_backend()) with mktempfile() as cert_tmp: - with open(cert_tmp, 'wb') as f: + with open(cert_tmp, "wb") as f: f.write(cert.public_bytes(serialization.Encoding.PEM)) with pytest.raises(Exception, match="Unable to retrieve CRL:"): diff --git a/lemur/tests/vectors.py b/lemur/tests/vectors.py index cb5800a1..0768cdac 100644 --- a/lemur/tests/vectors.py +++ b/lemur/tests/vectors.py @@ -1,20 +1,23 @@ from lemur.common.utils import parse_certificate VALID_USER_HEADER_TOKEN = { - 'Authorization': 'Basic ' + 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE1MjE2NTIwMjIsImV4cCI6MjM4NTY1MjAyMiwic3ViIjoxfQ.uK4PZjVAs0gt6_9h2EkYkKd64nFXdOq-rHsJZzeQicc', - 'Content-Type': 'application/json' + "Authorization": "Basic " + + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE1MjE2NTIwMjIsImV4cCI6MjM4NTY1MjAyMiwic3ViIjoxfQ.uK4PZjVAs0gt6_9h2EkYkKd64nFXdOq-rHsJZzeQicc", + "Content-Type": "application/json", } VALID_ADMIN_HEADER_TOKEN = { - 'Authorization': 'Basic ' + 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpYXQiOjE1MjE2NTE2NjMsInN1YiI6MiwiYWlkIjoxfQ.wyf5PkQNcggLrMFqxDfzjY-GWPw_XsuWvU2GmQaC5sg', - 'Content-Type': 'application/json' + "Authorization": "Basic " + + "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpYXQiOjE1MjE2NTE2NjMsInN1YiI6MiwiYWlkIjoxfQ.wyf5PkQNcggLrMFqxDfzjY-GWPw_XsuWvU2GmQaC5sg", + "Content-Type": "application/json", } VALID_ADMIN_API_TOKEN = { - 'Authorization': 'Basic ' + 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjIsImFpZCI6MSwiaWF0IjoxNDM1MjMzMzY5fQ.umW0I_oh4MVZ2qrClzj9SfYnQl6cd0HGzh9EwkDW60I', - 'Content-Type': 'application/json' + "Authorization": "Basic " + + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjIsImFpZCI6MSwiaWF0IjoxNDM1MjMzMzY5fQ.umW0I_oh4MVZ2qrClzj9SfYnQl6cd0HGzh9EwkDW60I", + "Content-Type": "application/json", } diff --git a/lemur/users/models.py b/lemur/users/models.py index 79125b9c..d7b900dc 100644 --- a/lemur/users/models.py +++ b/lemur/users/models.py @@ -33,7 +33,7 @@ def hash_password(mapper, connect, target): class User(db.Model): - __tablename__ = 'users' + __tablename__ = "users" id = Column(Integer, primary_key=True) password = Column(String(128)) active = Column(Boolean()) @@ -41,14 +41,24 @@ class User(db.Model): username = Column(String(255), nullable=False, unique=True) email = Column(String(128), unique=True) profile_picture = Column(String(255)) - roles = relationship('Role', secondary=roles_users, passive_deletes=True, backref=db.backref('user'), lazy='dynamic') - certificates = relationship('Certificate', backref=db.backref('user'), lazy='dynamic') - pending_certificates = relationship('PendingCertificate', backref=db.backref('user'), lazy='dynamic') - authorities = relationship('Authority', backref=db.backref('user'), lazy='dynamic') - keys = relationship('ApiKey', backref=db.backref('user'), lazy='dynamic') - logs = relationship('Log', backref=db.backref('user'), lazy='dynamic') + roles = relationship( + "Role", + secondary=roles_users, + passive_deletes=True, + backref=db.backref("user"), + lazy="dynamic", + ) + certificates = relationship( + "Certificate", backref=db.backref("user"), lazy="dynamic" + ) + pending_certificates = relationship( + "PendingCertificate", backref=db.backref("user"), lazy="dynamic" + ) + authorities = relationship("Authority", backref=db.backref("user"), lazy="dynamic") + keys = relationship("ApiKey", backref=db.backref("user"), lazy="dynamic") + logs = relationship("Log", backref=db.backref("user"), lazy="dynamic") - sensitive_fields = ('password',) + sensitive_fields = ("password",) def check_password(self, password): """ @@ -68,7 +78,7 @@ class User(db.Model): :return: """ if self.password: - self.password = bcrypt.generate_password_hash(self.password).decode('utf-8') + self.password = bcrypt.generate_password_hash(self.password).decode("utf-8") @property def is_admin(self): @@ -79,11 +89,11 @@ class User(db.Model): :return: """ for role in self.roles: - if role.name == 'admin': + if role.name == "admin": return True def __repr__(self): return "User(username={username})".format(username=self.username) -listen(User, 'before_insert', hash_password) +listen(User, "before_insert", hash_password) diff --git a/lemur/users/schemas.py b/lemur/users/schemas.py index b5a21127..74bd93e9 100644 --- a/lemur/users/schemas.py +++ b/lemur/users/schemas.py @@ -8,7 +8,11 @@ from marshmallow import fields from lemur.common.schema import LemurInputSchema, LemurOutputSchema -from lemur.schemas import AssociatedRoleSchema, AssociatedCertificateSchema, AssociatedAuthoritySchema +from lemur.schemas import ( + AssociatedRoleSchema, + AssociatedCertificateSchema, + AssociatedAuthoritySchema, +) class UserInputSchema(LemurInputSchema): diff --git a/lemur/users/service.py b/lemur/users/service.py index c6557cb9..8fb91aa3 100644 --- a/lemur/users/service.py +++ b/lemur/users/service.py @@ -96,7 +96,7 @@ def get_by_email(email): :param email: :return: """ - return database.get(User, email, field='email') + return database.get(User, email, field="email") def get_by_username(username): @@ -106,7 +106,7 @@ def get_by_username(username): :param username: :return: """ - return database.get(User, username, field='username') + return database.get(User, username, field="username") def get_all(): @@ -129,10 +129,10 @@ def render(args): """ query = database.session_query(User) - filt = args.pop('filter') + filt = args.pop("filter") if filt: - terms = filt.split(';') + terms = filt.split(";") query = database.filter(query, User, terms) return database.sort_and_page(query, User, args) diff --git a/lemur/users/views.py b/lemur/users/views.py index eb67f014..06729177 100644 --- a/lemur/users/views.py +++ b/lemur/users/views.py @@ -18,15 +18,20 @@ from lemur.users import service from lemur.certificates import service as certificate_service from lemur.roles import service as role_service -from lemur.users.schemas import user_input_schema, user_output_schema, users_output_schema +from lemur.users.schemas import ( + user_input_schema, + user_output_schema, + users_output_schema, +) -mod = Blueprint('users', __name__) +mod = Blueprint("users", __name__) api = Api(mod) class UsersList(AuthenticatedResource): """ Defines the 'users' endpoint """ + def __init__(self): self.reqparse = reqparse.RequestParser() super(UsersList, self).__init__() @@ -83,8 +88,8 @@ class UsersList(AuthenticatedResource): :statuscode 200: no error """ parser = paginated_parser.copy() - parser.add_argument('owner', type=str, location='args') - parser.add_argument('id', type=str, location='args') + parser.add_argument("owner", type=str, location="args") + parser.add_argument("id", type=str, location="args") args = parser.parse_args() return service.render(args) @@ -137,7 +142,14 @@ class UsersList(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.create(data['username'], data['password'], data['email'], data['active'], None, data['roles']) + return service.create( + data["username"], + data["password"], + data["email"], + data["active"], + None, + data["roles"], + ) class Users(AuthenticatedResource): @@ -225,7 +237,14 @@ class Users(AuthenticatedResource): :reqheader Authorization: OAuth token to authenticate :statuscode 200: no error """ - return service.update(user_id, data['username'], data['email'], data['active'], None, data['roles']) + return service.update( + user_id, + data["username"], + data["email"], + data["active"], + None, + data["roles"], + ) class CertificateUsers(AuthenticatedResource): @@ -365,8 +384,12 @@ class Me(AuthenticatedResource): return g.current_user -api.add_resource(Me, '/auth/me', endpoint='me') -api.add_resource(UsersList, '/users', endpoint='users') -api.add_resource(Users, '/users/', endpoint='user') -api.add_resource(CertificateUsers, '/certificates//creator', endpoint='certificateCreator') -api.add_resource(RoleUsers, '/roles//users', endpoint='roleUsers') +api.add_resource(Me, "/auth/me", endpoint="me") +api.add_resource(UsersList, "/users", endpoint="users") +api.add_resource(Users, "/users/", endpoint="user") +api.add_resource( + CertificateUsers, + "/certificates//creator", + endpoint="certificateCreator", +) +api.add_resource(RoleUsers, "/roles//users", endpoint="roleUsers") diff --git a/lemur/utils.py b/lemur/utils.py index 1661e3f7..909d959a 100644 --- a/lemur/utils.py +++ b/lemur/utils.py @@ -31,7 +31,9 @@ def mktempfile(): @contextmanager def mktemppath(): try: - path = os.path.join(tempfile._get_default_tempdir(), next(tempfile._get_candidate_names())) + path = os.path.join( + tempfile._get_default_tempdir(), next(tempfile._get_candidate_names()) + ) yield path finally: try: @@ -53,7 +55,7 @@ def get_keys(): # when running lemur create_config, this code needs to work despite # the fact that there is not a current_app with a config at that point - keys = current_app.config.get('LEMUR_ENCRYPTION_KEYS', []) + keys = current_app.config.get("LEMUR_ENCRYPTION_KEYS", []) # this function is expected to return a list of keys, but we want # to let people just specify a single key @@ -97,7 +99,7 @@ class Vault(types.TypeDecorator): # ensure bytes for fernet if isinstance(value, str): - value = value.encode('utf-8') + value = value.encode("utf-8") return MultiFernet(self.keys).encrypt(value) @@ -117,4 +119,4 @@ class Vault(types.TypeDecorator): if not value: return - return MultiFernet(self.keys).decrypt(value).decode('utf8') + return MultiFernet(self.keys).decrypt(value).decode("utf8") diff --git a/requirements-dev.txt b/requirements-dev.txt index 1a5b5f9d..bfbadc8a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -25,12 +25,12 @@ pygments==2.4.0 # via readme-renderer pyyaml==5.1 readme-renderer==24.0 # via twine requests-toolbelt==0.9.1 # via twine -requests==2.21.0 # via requests-toolbelt, twine +requests==2.22.0 # via requests-toolbelt, twine six==1.12.0 # via bleach, cfgv, pre-commit, readme-renderer toml==0.10.0 # via pre-commit tqdm==4.32.1 # via twine twine==1.13.0 -urllib3==1.24.3 # via requests -virtualenv==16.5.0 # via pre-commit +urllib3==1.25.2 # via requests +virtualenv==16.6.0 # via pre-commit webencodings==0.5.1 # via bleach zipp==0.5.0 # via importlib-metadata diff --git a/requirements-docs.txt b/requirements-docs.txt index f23de8f4..bf60d82f 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -17,8 +17,8 @@ babel==2.6.0 # via sphinx bcrypt==3.1.6 billiard==3.6.0.0 blinker==1.4 -boto3==1.9.147 -botocore==1.12.147 +boto3==1.9.149 +botocore==1.12.149 celery[redis]==4.3.0 certifi==2019.3.9 certsrv==2.1.1 @@ -102,5 +102,5 @@ tabulate==0.8.3 twofish==0.3.0 urllib3==1.24.3 vine==1.3.0 -werkzeug==0.15.2 +werkzeug==0.15.4 xmltodict==0.12.0 diff --git a/requirements-tests.in b/requirements-tests.in index dcd3d0c7..d624d4f7 100644 --- a/requirements-tests.in +++ b/requirements-tests.in @@ -1,5 +1,6 @@ # Run `make up-reqs` to update pinned dependencies in requirement text files +black coverage factory-boy Faker diff --git a/requirements-tests.txt b/requirements-tests.txt index 27837359..95ceb652 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -4,19 +4,21 @@ # # pip-compile --no-index --output-file=requirements-tests.txt requirements-tests.in # +appdirs==1.4.3 # via black asn1crypto==0.24.0 # via cryptography atomicwrites==1.3.0 # via pytest -attrs==19.1.0 # via pytest +attrs==19.1.0 # via black, pytest aws-sam-translator==1.11.0 # via cfn-lint aws-xray-sdk==2.4.2 # via moto -boto3==1.9.147 # via aws-sam-translator, moto +black==19.3b0 +boto3==1.9.149 # via aws-sam-translator, moto boto==2.49.0 # via moto -botocore==1.12.147 # via aws-xray-sdk, boto3, moto, s3transfer +botocore==1.12.149 # via aws-xray-sdk, boto3, moto, s3transfer certifi==2019.3.9 # via requests cffi==1.12.3 # via cryptography -cfn-lint==0.20.1 # via moto +cfn-lint==0.20.2 # via moto chardet==3.0.4 # via requests -click==7.0 # via flask +click==7.0 # via black, flask coverage==4.5.3 cryptography==2.6.1 # via moto docker-pycreds==0.4.0 # via docker @@ -55,15 +57,16 @@ python-jose==3.0.1 # via moto pytz==2019.1 # via moto pyyaml==5.1 requests-mock==1.6.0 -requests==2.21.0 # via cfn-lint, docker, moto, requests-mock, responses +requests==2.22.0 # via cfn-lint, docker, moto, requests-mock, responses responses==0.10.6 # via moto rsa==4.0 # via python-jose s3transfer==0.2.0 # via boto3 six==1.12.0 # via aws-sam-translator, cfn-lint, cryptography, docker, docker-pycreds, faker, freezegun, mock, moto, pytest, python-dateutil, python-jose, requests-mock, responses, websocket-client text-unidecode==1.2 # via faker +toml==0.10.0 # via black urllib3==1.24.3 # via botocore, requests wcwidth==0.1.7 # via pytest websocket-client==0.56.0 # via docker -werkzeug==0.15.2 # via flask, moto, pytest-flask +werkzeug==0.15.4 # via flask, moto, pytest-flask wrapt==1.11.1 # via aws-xray-sdk xmltodict==0.12.0 # via moto diff --git a/requirements.txt b/requirements.txt index 935e85ca..66f4fd40 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,8 +15,8 @@ asyncpool==1.0 bcrypt==3.1.6 # via flask-bcrypt, paramiko billiard==3.6.0.0 # via celery blinker==1.4 # via flask-mail, flask-principal, raven -boto3==1.9.147 -botocore==1.12.147 +boto3==1.9.149 +botocore==1.12.149 celery[redis]==4.3.0 certifi==2019.3.9 certsrv==2.1.1 @@ -87,5 +87,5 @@ tabulate==0.8.3 twofish==0.3.0 # via pyjks urllib3==1.24.3 # via botocore, requests vine==1.3.0 # via amqp, celery -werkzeug==0.15.2 # via flask +werkzeug==0.15.4 # via flask xmltodict==0.12.0