Merge branch 'master' of github.com:Netflix/lemur into improving-cert-lookup-time

This commit is contained in:
Hossein Shafagh 2019-05-30 08:55:49 -07:00
commit b4d9ab9f0c
227 changed files with 9420 additions and 5972 deletions

View File

@ -8,3 +8,17 @@
sha: v2.9.5 sha: v2.9.5
hooks: hooks:
- id: jshint - id: jshint
- repo: https://github.com/ambv/black
rev: stable
hooks:
- id: black
language_version: python3.7
- repo: local
hooks:
- id: python-bandit-vulnerability-check
name: bandit
entry: bandit
args: ['--ini', 'tox.ini', '-r', 'consoleme']
language: system
pass_filenames: false

View File

@ -1,12 +1,18 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
__all__ = [ __all__ = [
"__title__", "__summary__", "__uri__", "__version__", "__author__", "__title__",
"__email__", "__license__", "__copyright__", "__summary__",
"__uri__",
"__version__",
"__author__",
"__email__",
"__license__",
"__copyright__",
] ]
__title__ = "lemur" __title__ = "lemur"
__summary__ = ("Certificate management and orchestration service") __summary__ = "Certificate management and orchestration service"
__uri__ = "https://github.com/Netflix/lemur" __uri__ = "https://github.com/Netflix/lemur"
__version__ = "0.7.0" __version__ = "0.7.0"

View File

@ -32,14 +32,26 @@ from lemur.pending_certificates.views import mod as pending_certificates_bp
from lemur.dns_providers.views import mod as dns_providers_bp from lemur.dns_providers.views import mod as dns_providers_bp
from lemur.__about__ import ( from lemur.__about__ import (
__author__, __copyright__, __email__, __license__, __summary__, __title__, __author__,
__uri__, __version__ __copyright__,
__email__,
__license__,
__summary__,
__title__,
__uri__,
__version__,
) )
__all__ = [ __all__ = [
"__title__", "__summary__", "__uri__", "__version__", "__author__", "__title__",
"__email__", "__license__", "__copyright__", "__summary__",
"__uri__",
"__version__",
"__author__",
"__email__",
"__license__",
"__copyright__",
] ]
LEMUR_BLUEPRINTS = ( LEMUR_BLUEPRINTS = (
@ -63,7 +75,9 @@ LEMUR_BLUEPRINTS = (
def create_app(config_path=None): def create_app(config_path=None):
app = factory.create_app(app_name=__name__, blueprints=LEMUR_BLUEPRINTS, config=config_path) app = factory.create_app(
app_name=__name__, blueprints=LEMUR_BLUEPRINTS, config=config_path
)
configure_hook(app) configure_hook(app)
return app return app
@ -93,7 +107,7 @@ def configure_hook(app):
@app.after_request @app.after_request
def after_request(response): def after_request(response):
# Return early if we don't have the start time # Return early if we don't have the start time
if not hasattr(g, 'request_start_time'): if not hasattr(g, "request_start_time"):
return response return response
# Get elapsed time in milliseconds # Get elapsed time in milliseconds
@ -102,12 +116,12 @@ def configure_hook(app):
# Collect request/response tags # Collect request/response tags
tags = { tags = {
'endpoint': request.endpoint, "endpoint": request.endpoint,
'request_method': request.method.lower(), "request_method": request.method.lower(),
'status_code': response.status_code "status_code": response.status_code,
} }
# Record our response time metric # Record our response time metric
metrics.send('response_time', 'TIMER', elapsed, metric_tags=tags) metrics.send("response_time", "TIMER", elapsed, metric_tags=tags)
metrics.send('status_code_{}'.format(response.status_code), 'counter', 1) metrics.send("status_code_{}".format(response.status_code), "counter", 1)
return response return response

View File

@ -14,23 +14,32 @@ from datetime import datetime
manager = Manager(usage="Handles all api key related tasks.") manager = Manager(usage="Handles all api key related tasks.")
@manager.option('-u', '--user-id', dest='uid', help='The User ID this access key belongs too.') @manager.option(
@manager.option('-n', '--name', dest='name', help='The name of this API Key.') "-u", "--user-id", dest="uid", help="The User ID this access key belongs too."
@manager.option('-t', '--ttl', dest='ttl', help='The TTL of this API Key. -1 for forever.') )
@manager.option("-n", "--name", dest="name", help="The name of this API Key.")
@manager.option(
"-t", "--ttl", dest="ttl", help="The TTL of this API Key. -1 for forever."
)
def create(uid, name, ttl): def create(uid, name, ttl):
""" """
Create a new api key for a user. Create a new api key for a user.
:return: :return:
""" """
print("[+] Creating a new api key.") print("[+] Creating a new api key.")
key = api_key_service.create(user_id=uid, name=name, key = api_key_service.create(
ttl=ttl, issued_at=int(datetime.utcnow().timestamp()), revoked=False) user_id=uid,
name=name,
ttl=ttl,
issued_at=int(datetime.utcnow().timestamp()),
revoked=False,
)
print("[+] Successfully created a new api key. Generating a JWT...") print("[+] Successfully created a new api key. Generating a JWT...")
jwt = create_token(uid, key.id, key.ttl) jwt = create_token(uid, key.id, key.ttl)
print("[+] Your JWT is: {jwt}".format(jwt=jwt)) print("[+] Your JWT is: {jwt}".format(jwt=jwt))
@manager.option('-a', '--api-key-id', dest='aid', help='The API Key ID to revoke.') @manager.option("-a", "--api-key-id", dest="aid", help="The API Key ID to revoke.")
def revoke(aid): def revoke(aid):
""" """
Revokes an api key for a user. Revokes an api key for a user.

View File

@ -12,14 +12,19 @@ from lemur.database import db
class ApiKey(db.Model): class ApiKey(db.Model):
__tablename__ = 'api_keys' __tablename__ = "api_keys"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
name = Column(String) name = Column(String)
user_id = Column(Integer, ForeignKey('users.id')) user_id = Column(Integer, ForeignKey("users.id"))
ttl = Column(BigInteger) ttl = Column(BigInteger)
issued_at = Column(BigInteger) issued_at = Column(BigInteger)
revoked = Column(Boolean) revoked = Column(Boolean)
def __repr__(self): def __repr__(self):
return "ApiKey(name={name}, user_id={user_id}, ttl={ttl}, issued_at={iat}, revoked={revoked})".format( return "ApiKey(name={name}, user_id={user_id}, ttl={ttl}, issued_at={iat}, revoked={revoked})".format(
user_id=self.user_id, name=self.name, ttl=self.ttl, iat=self.issued_at, revoked=self.revoked) user_id=self.user_id,
name=self.name,
ttl=self.ttl,
iat=self.issued_at,
revoked=self.revoked,
)

View File

@ -13,12 +13,18 @@ from lemur.users.schemas import UserNestedOutputSchema, UserInputSchema
def current_user_id(): def current_user_id():
return {'id': g.current_user.id, 'email': g.current_user.email, 'username': g.current_user.username} return {
"id": g.current_user.id,
"email": g.current_user.email,
"username": g.current_user.username,
}
class ApiKeyInputSchema(LemurInputSchema): class ApiKeyInputSchema(LemurInputSchema):
name = fields.String(required=False) name = fields.String(required=False)
user = fields.Nested(UserInputSchema, missing=current_user_id, default=current_user_id) user = fields.Nested(
UserInputSchema, missing=current_user_id, default=current_user_id
)
ttl = fields.Integer() ttl = fields.Integer()

View File

@ -34,7 +34,7 @@ def revoke(aid):
:return: :return:
""" """
api_key = get(aid) api_key = get(aid)
setattr(api_key, 'revoked', False) setattr(api_key, "revoked", False)
return database.update(api_key) return database.update(api_key)
@ -80,10 +80,10 @@ def render(args):
:return: :return:
""" """
query = database.session_query(ApiKey) query = database.session_query(ApiKey)
user_id = args.pop('user_id', None) user_id = args.pop("user_id", None)
aid = args.pop('id', None) aid = args.pop("id", None)
has_permission = args.pop('has_permission', False) has_permission = args.pop("has_permission", False)
requesting_user_id = args.pop('requesting_user_id') requesting_user_id = args.pop("requesting_user_id")
if user_id: if user_id:
query = query.filter(ApiKey.user_id == user_id) query = query.filter(ApiKey.user_id == user_id)

View File

@ -19,10 +19,16 @@ from lemur.auth.permissions import ApiKeyCreatorPermission
from lemur.common.schema import validate_schema from lemur.common.schema import validate_schema
from lemur.common.utils import paginated_parser from lemur.common.utils import paginated_parser
from lemur.api_keys.schemas import api_key_input_schema, api_key_revoke_schema, api_key_output_schema, \ from lemur.api_keys.schemas import (
api_keys_output_schema, api_key_described_output_schema, user_api_key_input_schema api_key_input_schema,
api_key_revoke_schema,
api_key_output_schema,
api_keys_output_schema,
api_key_described_output_schema,
user_api_key_input_schema,
)
mod = Blueprint('api_keys', __name__) mod = Blueprint("api_keys", __name__)
api = Api(mod) api = Api(mod)
@ -81,8 +87,8 @@ class ApiKeyList(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
args = parser.parse_args() args = parser.parse_args()
args['has_permission'] = ApiKeyCreatorPermission().can() args["has_permission"] = ApiKeyCreatorPermission().can()
args['requesting_user_id'] = g.current_user.id args["requesting_user_id"] = g.current_user.id
return service.render(args) return service.render(args)
@validate_schema(api_key_input_schema, api_key_output_schema) @validate_schema(api_key_input_schema, api_key_output_schema)
@ -124,12 +130,26 @@ class ApiKeyList(AuthenticatedResource):
:statuscode 403: unauthenticated :statuscode 403: unauthenticated
""" """
if not ApiKeyCreatorPermission().can(): if not ApiKeyCreatorPermission().can():
if data['user']['id'] != g.current_user.id: if data["user"]["id"] != g.current_user.id:
return dict(message="You are not authorized to create tokens for: {0}".format(data['user']['username'])), 403 return (
dict(
message="You are not authorized to create tokens for: {0}".format(
data["user"]["username"]
)
),
403,
)
access_token = service.create(name=data['name'], user_id=data['user']['id'], ttl=data['ttl'], access_token = service.create(
revoked=False, issued_at=int(datetime.utcnow().timestamp())) name=data["name"],
return dict(jwt=create_token(access_token.user_id, access_token.id, access_token.ttl)) user_id=data["user"]["id"],
ttl=data["ttl"],
revoked=False,
issued_at=int(datetime.utcnow().timestamp()),
)
return dict(
jwt=create_token(access_token.user_id, access_token.id, access_token.ttl)
)
class ApiKeyUserList(AuthenticatedResource): class ApiKeyUserList(AuthenticatedResource):
@ -186,9 +206,9 @@ class ApiKeyUserList(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
args = parser.parse_args() args = parser.parse_args()
args['has_permission'] = ApiKeyCreatorPermission().can() args["has_permission"] = ApiKeyCreatorPermission().can()
args['requesting_user_id'] = g.current_user.id args["requesting_user_id"] = g.current_user.id
args['user_id'] = user_id args["user_id"] = user_id
return service.render(args) return service.render(args)
@validate_schema(user_api_key_input_schema, api_key_output_schema) @validate_schema(user_api_key_input_schema, api_key_output_schema)
@ -230,11 +250,25 @@ class ApiKeyUserList(AuthenticatedResource):
""" """
if not ApiKeyCreatorPermission().can(): if not ApiKeyCreatorPermission().can():
if user_id != g.current_user.id: if user_id != g.current_user.id:
return dict(message="You are not authorized to create tokens for: {0}".format(user_id)), 403 return (
dict(
message="You are not authorized to create tokens for: {0}".format(
user_id
)
),
403,
)
access_token = service.create(name=data['name'], user_id=user_id, ttl=data['ttl'], access_token = service.create(
revoked=False, issued_at=int(datetime.utcnow().timestamp())) name=data["name"],
return dict(jwt=create_token(access_token.user_id, access_token.id, access_token.ttl)) user_id=user_id,
ttl=data["ttl"],
revoked=False,
issued_at=int(datetime.utcnow().timestamp()),
)
return dict(
jwt=create_token(access_token.user_id, access_token.id, access_token.ttl)
)
class ApiKeys(AuthenticatedResource): class ApiKeys(AuthenticatedResource):
@ -329,7 +363,9 @@ class ApiKeys(AuthenticatedResource):
if not ApiKeyCreatorPermission().can(): if not ApiKeyCreatorPermission().can():
return dict(message="You are not authorized to update this token!"), 403 return dict(message="You are not authorized to update this token!"), 403
service.update(access_key, name=data['name'], revoked=data['revoked'], ttl=data['ttl']) service.update(
access_key, name=data["name"], revoked=data["revoked"], ttl=data["ttl"]
)
return dict(jwt=create_token(access_key.user_id, access_key.id, access_key.ttl)) return dict(jwt=create_token(access_key.user_id, access_key.id, access_key.ttl))
def delete(self, aid): def delete(self, aid):
@ -371,7 +407,7 @@ class ApiKeys(AuthenticatedResource):
return dict(message="You are not authorized to delete this token!"), 403 return dict(message="You are not authorized to delete this token!"), 403
service.delete(access_key) service.delete(access_key)
return {'result': True} return {"result": True}
class UserApiKeys(AuthenticatedResource): class UserApiKeys(AuthenticatedResource):
@ -472,7 +508,9 @@ class UserApiKeys(AuthenticatedResource):
if access_key.user_id != uid: if access_key.user_id != uid:
return dict(message="You are not authorized to update this token!"), 403 return dict(message="You are not authorized to update this token!"), 403
service.update(access_key, name=data['name'], revoked=data['revoked'], ttl=data['ttl']) service.update(
access_key, name=data["name"], revoked=data["revoked"], ttl=data["ttl"]
)
return dict(jwt=create_token(access_key.user_id, access_key.id, access_key.ttl)) return dict(jwt=create_token(access_key.user_id, access_key.id, access_key.ttl))
def delete(self, uid, aid): def delete(self, uid, aid):
@ -517,7 +555,7 @@ class UserApiKeys(AuthenticatedResource):
return dict(message="You are not authorized to delete this token!"), 403 return dict(message="You are not authorized to delete this token!"), 403
service.delete(access_key) service.delete(access_key)
return {'result': True} return {"result": True}
class ApiKeysDescribed(AuthenticatedResource): class ApiKeysDescribed(AuthenticatedResource):
@ -572,8 +610,12 @@ class ApiKeysDescribed(AuthenticatedResource):
return access_key return access_key
api.add_resource(ApiKeyList, '/keys', endpoint='api_keys') api.add_resource(ApiKeyList, "/keys", endpoint="api_keys")
api.add_resource(ApiKeys, '/keys/<int:aid>', endpoint='api_key') api.add_resource(ApiKeys, "/keys/<int:aid>", endpoint="api_key")
api.add_resource(ApiKeysDescribed, '/keys/<int:aid>/described', endpoint='api_key_described') api.add_resource(
api.add_resource(ApiKeyUserList, '/users/<int:user_id>/keys', endpoint='user_api_keys') ApiKeysDescribed, "/keys/<int:aid>/described", endpoint="api_key_described"
api.add_resource(UserApiKeys, '/users/<int:uid>/keys/<int:aid>', endpoint='user_api_key') )
api.add_resource(ApiKeyUserList, "/users/<int:user_id>/keys", endpoint="user_api_keys")
api.add_resource(
UserApiKeys, "/users/<int:uid>/keys/<int:aid>", endpoint="user_api_key"
)

View File

@ -14,35 +14,41 @@ from lemur.roles import service as role_service
from lemur.common.utils import validate_conf, get_psuedo_random_string from lemur.common.utils import validate_conf, get_psuedo_random_string
class LdapPrincipal(): class LdapPrincipal:
""" """
Provides methods for authenticating against an LDAP server. Provides methods for authenticating against an LDAP server.
""" """
def __init__(self, args): def __init__(self, args):
self._ldap_validate_conf() self._ldap_validate_conf()
# setup ldap config # setup ldap config
if not args['username']: if not args["username"]:
raise Exception("missing ldap username") raise Exception("missing ldap username")
if not args['password']: if not args["password"]:
self.error_message = "missing ldap password" self.error_message = "missing ldap password"
raise Exception("missing ldap password") raise Exception("missing ldap password")
self.ldap_principal = args['username'] self.ldap_principal = args["username"]
self.ldap_email_domain = current_app.config.get("LDAP_EMAIL_DOMAIN", None) self.ldap_email_domain = current_app.config.get("LDAP_EMAIL_DOMAIN", None)
if '@' not in self.ldap_principal: if "@" not in self.ldap_principal:
self.ldap_principal = '%s@%s' % (self.ldap_principal, self.ldap_email_domain) self.ldap_principal = "%s@%s" % (
self.ldap_username = args['username'] self.ldap_principal,
if '@' in self.ldap_username: self.ldap_email_domain,
self.ldap_username = args['username'].split("@")[0] )
self.ldap_password = args['password'] self.ldap_username = args["username"]
self.ldap_server = current_app.config.get('LDAP_BIND_URI', None) if "@" in self.ldap_username:
self.ldap_username = args["username"].split("@")[0]
self.ldap_password = args["password"]
self.ldap_server = current_app.config.get("LDAP_BIND_URI", None)
self.ldap_base_dn = current_app.config.get("LDAP_BASE_DN", None) self.ldap_base_dn = current_app.config.get("LDAP_BASE_DN", None)
self.ldap_use_tls = current_app.config.get("LDAP_USE_TLS", False) self.ldap_use_tls = current_app.config.get("LDAP_USE_TLS", False)
self.ldap_cacert_file = current_app.config.get("LDAP_CACERT_FILE", None) self.ldap_cacert_file = current_app.config.get("LDAP_CACERT_FILE", None)
self.ldap_default_role = current_app.config.get("LEMUR_DEFAULT_ROLE", None) self.ldap_default_role = current_app.config.get("LEMUR_DEFAULT_ROLE", None)
self.ldap_required_group = current_app.config.get("LDAP_REQUIRED_GROUP", None) self.ldap_required_group = current_app.config.get("LDAP_REQUIRED_GROUP", None)
self.ldap_groups_to_roles = current_app.config.get("LDAP_GROUPS_TO_ROLES", None) self.ldap_groups_to_roles = current_app.config.get("LDAP_GROUPS_TO_ROLES", None)
self.ldap_is_active_directory = current_app.config.get("LDAP_IS_ACTIVE_DIRECTORY", False) self.ldap_is_active_directory = current_app.config.get(
self.ldap_attrs = ['memberOf'] "LDAP_IS_ACTIVE_DIRECTORY", False
)
self.ldap_attrs = ["memberOf"]
self.ldap_client = None self.ldap_client = None
self.ldap_groups = None self.ldap_groups = None
@ -60,8 +66,8 @@ class LdapPrincipal():
get_psuedo_random_string(), get_psuedo_random_string(),
self.ldap_principal, self.ldap_principal,
True, True,
'', # thumbnailPhotoUrl "", # thumbnailPhotoUrl
list(roles) list(roles),
) )
else: else:
# we add 'lemur' specific roles, so they do not get marked as removed # we add 'lemur' specific roles, so they do not get marked as removed
@ -76,7 +82,7 @@ class LdapPrincipal():
self.ldap_principal, self.ldap_principal,
user.active, user.active,
user.profile_picture, user.profile_picture,
list(roles) list(roles),
) )
return user return user
@ -105,9 +111,12 @@ class LdapPrincipal():
# update their 'roles' # update their 'roles'
role = role_service.get_by_name(self.ldap_principal) role = role_service.get_by_name(self.ldap_principal)
if not role: if not role:
description = "auto generated role based on owner: {0}".format(self.ldap_principal) description = "auto generated role based on owner: {0}".format(
role = role_service.create(self.ldap_principal, description=description, self.ldap_principal
third_party=True) )
role = role_service.create(
self.ldap_principal, description=description, third_party=True
)
if not role.third_party: if not role.third_party:
role = role_service.set_third_party(role.id, third_party_status=True) role = role_service.set_third_party(role.id, third_party_status=True)
roles.add(role) roles.add(role)
@ -118,9 +127,15 @@ class LdapPrincipal():
role = role_service.get_by_name(role_name) role = role_service.get_by_name(role_name)
if role: if role:
if ldap_group_name in self.ldap_groups: if ldap_group_name in self.ldap_groups:
current_app.logger.debug("assigning role {0} to ldap user {1}".format(self.ldap_principal, role)) current_app.logger.debug(
"assigning role {0} to ldap user {1}".format(
self.ldap_principal, role
)
)
if not role.third_party: if not role.third_party:
role = role_service.set_third_party(role.id, third_party_status=True) role = role_service.set_third_party(
role.id, third_party_status=True
)
roles.add(role) roles.add(role)
return roles return roles
@ -132,7 +147,7 @@ class LdapPrincipal():
self._bind() self._bind()
roles = self._authorize() roles = self._authorize()
if not roles: if not roles:
raise Exception('ldap authorization failed') raise Exception("ldap authorization failed")
return self._update_user(roles) return self._update_user(roles)
def _bind(self): def _bind(self):
@ -141,9 +156,12 @@ class LdapPrincipal():
list groups for a user. list groups for a user.
raise an exception on error. raise an exception on error.
""" """
if '@' not in self.ldap_principal: if "@" not in self.ldap_principal:
self.ldap_principal = '%s@%s' % (self.ldap_principal, self.ldap_email_domain) self.ldap_principal = "%s@%s" % (
ldap_filter = 'userPrincipalName=%s' % self.ldap_principal self.ldap_principal,
self.ldap_email_domain,
)
ldap_filter = "userPrincipalName=%s" % self.ldap_principal
# query ldap for auth # query ldap for auth
try: try:
@ -159,37 +177,47 @@ class LdapPrincipal():
self.ldap_client.set_option(ldap.OPT_X_TLS_DEMAND, True) self.ldap_client.set_option(ldap.OPT_X_TLS_DEMAND, True)
self.ldap_client.set_option(ldap.OPT_DEBUG_LEVEL, 255) self.ldap_client.set_option(ldap.OPT_DEBUG_LEVEL, 255)
if self.ldap_cacert_file: if self.ldap_cacert_file:
self.ldap_client.set_option(ldap.OPT_X_TLS_CACERTFILE, self.ldap_cacert_file) self.ldap_client.set_option(
ldap.OPT_X_TLS_CACERTFILE, self.ldap_cacert_file
)
self.ldap_client.simple_bind_s(self.ldap_principal, self.ldap_password) self.ldap_client.simple_bind_s(self.ldap_principal, self.ldap_password)
except ldap.INVALID_CREDENTIALS: except ldap.INVALID_CREDENTIALS:
self.ldap_client.unbind() self.ldap_client.unbind()
raise Exception('The supplied ldap credentials are invalid') raise Exception("The supplied ldap credentials are invalid")
except ldap.SERVER_DOWN: except ldap.SERVER_DOWN:
raise Exception('ldap server unavailable') raise Exception("ldap server unavailable")
except ldap.LDAPError as e: except ldap.LDAPError as e:
raise Exception("ldap error: {0}".format(e)) raise Exception("ldap error: {0}".format(e))
if self.ldap_is_active_directory: if self.ldap_is_active_directory:
# Lookup user DN, needed to search for group membership # Lookup user DN, needed to search for group membership
userdn = self.ldap_client.search_s(self.ldap_base_dn, userdn = self.ldap_client.search_s(
ldap.SCOPE_SUBTREE, ldap_filter, self.ldap_base_dn,
['distinguishedName'])[0][1]['distinguishedName'][0] ldap.SCOPE_SUBTREE,
userdn = userdn.decode('utf-8') ldap_filter,
["distinguishedName"],
)[0][1]["distinguishedName"][0]
userdn = userdn.decode("utf-8")
# Search all groups that have the userDN as a member # Search all groups that have the userDN as a member
groupfilter = '(&(objectclass=group)(member:1.2.840.113556.1.4.1941:={0}))'.format(userdn) groupfilter = "(&(objectclass=group)(member:1.2.840.113556.1.4.1941:={0}))".format(
lgroups = self.ldap_client.search_s(self.ldap_base_dn, ldap.SCOPE_SUBTREE, groupfilter, ['cn']) userdn
)
lgroups = self.ldap_client.search_s(
self.ldap_base_dn, ldap.SCOPE_SUBTREE, groupfilter, ["cn"]
)
# Create a list of group CN's from the result # Create a list of group CN's from the result
self.ldap_groups = [] self.ldap_groups = []
for group in lgroups: for group in lgroups:
(dn, values) = group (dn, values) = group
self.ldap_groups.append(values['cn'][0].decode('ascii')) self.ldap_groups.append(values["cn"][0].decode("ascii"))
else: else:
lgroups = self.ldap_client.search_s(self.ldap_base_dn, lgroups = self.ldap_client.search_s(
ldap.SCOPE_SUBTREE, ldap_filter, self.ldap_attrs)[0][1]['memberOf'] self.ldap_base_dn, ldap.SCOPE_SUBTREE, ldap_filter, self.ldap_attrs
)[0][1]["memberOf"]
# lgroups is a list of utf-8 encoded strings # lgroups is a list of utf-8 encoded strings
# convert to a single string of groups to allow matching # convert to a single string of groups to allow matching
self.ldap_groups = b''.join(lgroups).decode('ascii') self.ldap_groups = b"".join(lgroups).decode("ascii")
self.ldap_client.unbind() self.ldap_client.unbind()
@ -197,9 +225,5 @@ class LdapPrincipal():
""" """
Confirms required ldap config settings exist. Confirms required ldap config settings exist.
""" """
required_vars = [ required_vars = ["LDAP_BIND_URI", "LDAP_BASE_DN", "LDAP_EMAIL_DOMAIN"]
'LDAP_BIND_URI',
'LDAP_BASE_DN',
'LDAP_EMAIL_DOMAIN',
]
validate_conf(current_app, required_vars) validate_conf(current_app, required_vars)

View File

@ -12,21 +12,21 @@ from collections import namedtuple
from flask_principal import Permission, RoleNeed from flask_principal import Permission, RoleNeed
# Permissions # Permissions
operator_permission = Permission(RoleNeed('operator')) operator_permission = Permission(RoleNeed("operator"))
admin_permission = Permission(RoleNeed('admin')) admin_permission = Permission(RoleNeed("admin"))
CertificateOwner = namedtuple('certificate', ['method', 'value']) CertificateOwner = namedtuple("certificate", ["method", "value"])
CertificateOwnerNeed = partial(CertificateOwner, 'role') CertificateOwnerNeed = partial(CertificateOwner, "role")
class SensitiveDomainPermission(Permission): class SensitiveDomainPermission(Permission):
def __init__(self): def __init__(self):
super(SensitiveDomainPermission, self).__init__(RoleNeed('admin')) super(SensitiveDomainPermission, self).__init__(RoleNeed("admin"))
class CertificatePermission(Permission): class CertificatePermission(Permission):
def __init__(self, owner, roles): def __init__(self, owner, roles):
needs = [RoleNeed('admin'), RoleNeed(owner), RoleNeed('creator')] needs = [RoleNeed("admin"), RoleNeed(owner), RoleNeed("creator")]
for r in roles: for r in roles:
needs.append(CertificateOwnerNeed(str(r))) needs.append(CertificateOwnerNeed(str(r)))
# Backwards compatibility with mixed-case role names # Backwards compatibility with mixed-case role names
@ -38,29 +38,29 @@ class CertificatePermission(Permission):
class ApiKeyCreatorPermission(Permission): class ApiKeyCreatorPermission(Permission):
def __init__(self): def __init__(self):
super(ApiKeyCreatorPermission, self).__init__(RoleNeed('admin')) super(ApiKeyCreatorPermission, self).__init__(RoleNeed("admin"))
RoleMember = namedtuple('role', ['method', 'value']) RoleMember = namedtuple("role", ["method", "value"])
RoleMemberNeed = partial(RoleMember, 'member') RoleMemberNeed = partial(RoleMember, "member")
class RoleMemberPermission(Permission): class RoleMemberPermission(Permission):
def __init__(self, role_id): def __init__(self, role_id):
needs = [RoleNeed('admin'), RoleMemberNeed(role_id)] needs = [RoleNeed("admin"), RoleMemberNeed(role_id)]
super(RoleMemberPermission, self).__init__(*needs) super(RoleMemberPermission, self).__init__(*needs)
AuthorityCreator = namedtuple('authority', ['method', 'value']) AuthorityCreator = namedtuple("authority", ["method", "value"])
AuthorityCreatorNeed = partial(AuthorityCreator, 'authorityUse') AuthorityCreatorNeed = partial(AuthorityCreator, "authorityUse")
AuthorityOwner = namedtuple('authority', ['method', 'value']) AuthorityOwner = namedtuple("authority", ["method", "value"])
AuthorityOwnerNeed = partial(AuthorityOwner, 'role') AuthorityOwnerNeed = partial(AuthorityOwner, "role")
class AuthorityPermission(Permission): class AuthorityPermission(Permission):
def __init__(self, authority_id, roles): def __init__(self, authority_id, roles):
needs = [RoleNeed('admin'), AuthorityCreatorNeed(str(authority_id))] needs = [RoleNeed("admin"), AuthorityCreatorNeed(str(authority_id))]
for r in roles: for r in roles:
needs.append(AuthorityOwnerNeed(str(r))) needs.append(AuthorityOwnerNeed(str(r)))

View File

@ -39,13 +39,13 @@ def get_rsa_public_key(n, e):
:param e: :param e:
:return: a RSA Public Key in PEM format :return: a RSA Public Key in PEM format
""" """
n = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(n, 'utf-8'))), 16) n = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(n, "utf-8"))), 16)
e = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(e, 'utf-8'))), 16) e = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(e, "utf-8"))), 16)
pub = RSAPublicNumbers(e, n).public_key(default_backend()) pub = RSAPublicNumbers(e, n).public_key(default_backend())
return pub.public_bytes( return pub.public_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo format=serialization.PublicFormat.SubjectPublicKeyInfo,
) )
@ -57,28 +57,27 @@ def create_token(user, aid=None, ttl=None):
:param user: :param user:
:return: :return:
""" """
expiration_delta = timedelta(days=int(current_app.config.get('LEMUR_TOKEN_EXPIRATION', 1))) expiration_delta = timedelta(
payload = { days=int(current_app.config.get("LEMUR_TOKEN_EXPIRATION", 1))
'iat': datetime.utcnow(), )
'exp': datetime.utcnow() + expiration_delta payload = {"iat": datetime.utcnow(), "exp": datetime.utcnow() + expiration_delta}
}
# Handle Just a User ID & User Object. # Handle Just a User ID & User Object.
if isinstance(user, int): if isinstance(user, int):
payload['sub'] = user payload["sub"] = user
else: else:
payload['sub'] = user.id payload["sub"] = user.id
if aid is not None: if aid is not None:
payload['aid'] = aid payload["aid"] = aid
# Custom TTLs are only supported on Access Keys. # Custom TTLs are only supported on Access Keys.
if ttl is not None and aid is not None: if ttl is not None and aid is not None:
# Tokens that are forever until revoked. # Tokens that are forever until revoked.
if ttl == -1: if ttl == -1:
del payload['exp'] del payload["exp"]
else: else:
payload['exp'] = ttl payload["exp"] = ttl
token = jwt.encode(payload, current_app.config['LEMUR_TOKEN_SECRET']) token = jwt.encode(payload, current_app.config["LEMUR_TOKEN_SECRET"])
return token.decode('unicode_escape') return token.decode("unicode_escape")
def login_required(f): def login_required(f):
@ -88,49 +87,54 @@ def login_required(f):
:param f: :param f:
:return: :return:
""" """
@wraps(f) @wraps(f)
def decorated_function(*args, **kwargs): def decorated_function(*args, **kwargs):
if not request.headers.get('Authorization'): if not request.headers.get("Authorization"):
response = jsonify(message='Missing authorization header') response = jsonify(message="Missing authorization header")
response.status_code = 401 response.status_code = 401
return response return response
try: try:
token = request.headers.get('Authorization').split()[1] token = request.headers.get("Authorization").split()[1]
except Exception as e: except Exception as e:
return dict(message='Token is invalid'), 403 return dict(message="Token is invalid"), 403
try: try:
payload = jwt.decode(token, current_app.config['LEMUR_TOKEN_SECRET']) payload = jwt.decode(token, current_app.config["LEMUR_TOKEN_SECRET"])
except jwt.DecodeError: except jwt.DecodeError:
return dict(message='Token is invalid'), 403 return dict(message="Token is invalid"), 403
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
return dict(message='Token has expired'), 403 return dict(message="Token has expired"), 403
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
return dict(message='Token is invalid'), 403 return dict(message="Token is invalid"), 403
if 'aid' in payload: if "aid" in payload:
access_key = api_key_service.get(payload['aid']) access_key = api_key_service.get(payload["aid"])
if access_key.revoked: if access_key.revoked:
return dict(message='Token has been revoked'), 403 return dict(message="Token has been revoked"), 403
if access_key.ttl != -1: if access_key.ttl != -1:
current_time = datetime.utcnow() current_time = datetime.utcnow()
expired_time = datetime.fromtimestamp(access_key.issued_at + access_key.ttl) expired_time = datetime.fromtimestamp(
access_key.issued_at + access_key.ttl
)
if current_time >= expired_time: if current_time >= expired_time:
return dict(message='Token has expired'), 403 return dict(message="Token has expired"), 403
user = user_service.get(payload['sub']) user = user_service.get(payload["sub"])
if not user.active: if not user.active:
return dict(message='User is not currently active'), 403 return dict(message="User is not currently active"), 403
g.current_user = user g.current_user = user
if not g.current_user: if not g.current_user:
return dict(message='You are not logged in'), 403 return dict(message="You are not logged in"), 403
# Tell Flask-Principal the identity changed # Tell Flask-Principal the identity changed
identity_changed.send(current_app._get_current_object(), identity=Identity(g.current_user.id)) identity_changed.send(
current_app._get_current_object(), identity=Identity(g.current_user.id)
)
return f(*args, **kwargs) return f(*args, **kwargs)
@ -144,18 +148,18 @@ def fetch_token_header(token):
:param token: :param token:
:return: :raise jwt.DecodeError: :return: :raise jwt.DecodeError:
""" """
token = token.encode('utf-8') token = token.encode("utf-8")
try: try:
signing_input, crypto_segment = token.rsplit(b'.', 1) signing_input, crypto_segment = token.rsplit(b".", 1)
header_segment, payload_segment = signing_input.split(b'.', 1) header_segment, payload_segment = signing_input.split(b".", 1)
except ValueError: except ValueError:
raise jwt.DecodeError('Not enough segments') raise jwt.DecodeError("Not enough segments")
try: try:
return json.loads(jwt.utils.base64url_decode(header_segment).decode('utf-8')) return json.loads(jwt.utils.base64url_decode(header_segment).decode("utf-8"))
except TypeError as e: except TypeError as e:
current_app.logger.exception(e) current_app.logger.exception(e)
raise jwt.DecodeError('Invalid header padding') raise jwt.DecodeError("Invalid header padding")
@identity_loaded.connect @identity_loaded.connect
@ -174,13 +178,13 @@ def on_identity_loaded(sender, identity):
identity.provides.add(UserNeed(identity.id)) identity.provides.add(UserNeed(identity.id))
# identity with the roles that the user provides # identity with the roles that the user provides
if hasattr(user, 'roles'): if hasattr(user, "roles"):
for role in user.roles: for role in user.roles:
identity.provides.add(RoleNeed(role.name)) identity.provides.add(RoleNeed(role.name))
identity.provides.add(RoleMemberNeed(role.id)) identity.provides.add(RoleMemberNeed(role.id))
# apply ownership for authorities # apply ownership for authorities
if hasattr(user, 'authorities'): if hasattr(user, "authorities"):
for authority in user.authorities: for authority in user.authorities:
identity.provides.add(AuthorityCreatorNeed(authority.id)) identity.provides.add(AuthorityCreatorNeed(authority.id))
@ -191,6 +195,7 @@ class AuthenticatedResource(Resource):
""" """
Inherited by all resources that need to be protected by authentication. Inherited by all resources that need to be protected by authentication.
""" """
method_decorators = [login_required] method_decorators = [login_required]
def __init__(self): def __init__(self):

View File

@ -24,11 +24,13 @@ from lemur.auth.service import create_token, fetch_token_header, get_rsa_public_
from lemur.auth import ldap from lemur.auth import ldap
mod = Blueprint('auth', __name__) mod = Blueprint("auth", __name__)
api = Api(mod) api = Api(mod)
def exchange_for_access_token(code, redirect_uri, client_id, secret, access_token_url=None, verify_cert=True): def exchange_for_access_token(
code, redirect_uri, client_id, secret, access_token_url=None, verify_cert=True
):
""" """
Exchanges authorization code for access token. Exchanges authorization code for access token.
@ -43,28 +45,32 @@ def exchange_for_access_token(code, redirect_uri, client_id, secret, access_toke
""" """
# take the information we have received from the provider to create a new request # take the information we have received from the provider to create a new request
params = { params = {
'grant_type': 'authorization_code', "grant_type": "authorization_code",
'scope': 'openid email profile address', "scope": "openid email profile address",
'code': code, "code": code,
'redirect_uri': redirect_uri, "redirect_uri": redirect_uri,
'client_id': client_id "client_id": client_id,
} }
# the secret and cliendId will be given to you when you signup for the provider # the secret and cliendId will be given to you when you signup for the provider
token = '{0}:{1}'.format(client_id, secret) token = "{0}:{1}".format(client_id, secret)
basic = base64.b64encode(bytes(token, 'utf-8')) basic = base64.b64encode(bytes(token, "utf-8"))
headers = { headers = {
'Content-Type': 'application/x-www-form-urlencoded', "Content-Type": "application/x-www-form-urlencoded",
'authorization': 'basic {0}'.format(basic.decode('utf-8')) "authorization": "basic {0}".format(basic.decode("utf-8")),
} }
# exchange authorization code for access token. # exchange authorization code for access token.
r = requests.post(access_token_url, headers=headers, params=params, verify=verify_cert) r = requests.post(
access_token_url, headers=headers, params=params, verify=verify_cert
)
if r.status_code == 400: if r.status_code == 400:
r = requests.post(access_token_url, headers=headers, data=params, verify=verify_cert) r = requests.post(
id_token = r.json()['id_token'] access_token_url, headers=headers, data=params, verify=verify_cert
access_token = r.json()['access_token'] )
id_token = r.json()["id_token"]
access_token = r.json()["access_token"]
return id_token, access_token return id_token, access_token
@ -83,23 +89,25 @@ def validate_id_token(id_token, client_id, jwks_url):
# retrieve the key material as specified by the token header # retrieve the key material as specified by the token header
r = requests.get(jwks_url) r = requests.get(jwks_url)
for key in r.json()['keys']: for key in r.json()["keys"]:
if key['kid'] == header_data['kid']: if key["kid"] == header_data["kid"]:
secret = get_rsa_public_key(key['n'], key['e']) secret = get_rsa_public_key(key["n"], key["e"])
algo = header_data['alg'] algo = header_data["alg"]
break break
else: else:
return dict(message='Key not found'), 401 return dict(message="Key not found"), 401
# validate your token based on the key it was signed with # validate your token based on the key it was signed with
try: try:
jwt.decode(id_token, secret.decode('utf-8'), algorithms=[algo], audience=client_id) jwt.decode(
id_token, secret.decode("utf-8"), algorithms=[algo], audience=client_id
)
except jwt.DecodeError: except jwt.DecodeError:
return dict(message='Token is invalid'), 401 return dict(message="Token is invalid"), 401
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
return dict(message='Token has expired'), 401 return dict(message="Token has expired"), 401
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
return dict(message='Token is invalid'), 401 return dict(message="Token is invalid"), 401
def retrieve_user(user_api_url, access_token): def retrieve_user(user_api_url, access_token):
@ -110,22 +118,18 @@ def retrieve_user(user_api_url, access_token):
:param access_token: :param access_token:
:return: :return:
""" """
user_params = dict(access_token=access_token, schema='profile') user_params = dict(access_token=access_token, schema="profile")
headers = {} headers = {}
if current_app.config.get('PING_INCLUDE_BEARER_TOKEN'): if current_app.config.get("PING_INCLUDE_BEARER_TOKEN"):
headers = {'Authorization': f'Bearer {access_token}'} headers = {"Authorization": f"Bearer {access_token}"}
# retrieve information about the current user. # retrieve information about the current user.
r = requests.get( r = requests.get(user_api_url, params=user_params, headers=headers)
user_api_url,
params=user_params,
headers=headers,
)
profile = r.json() profile = r.json()
user = user_service.get_by_email(profile['email']) user = user_service.get_by_email(profile["email"])
return user, profile return user, profile
@ -138,31 +142,44 @@ def create_user_roles(profile):
roles = [] roles = []
# update their google 'roles' # update their google 'roles'
if 'googleGroups' in profile: if "googleGroups" in profile:
for group in profile['googleGroups']: for group in profile["googleGroups"]:
role = role_service.get_by_name(group) role = role_service.get_by_name(group)
if not role: if not role:
role = role_service.create(group, description='This is a google group based role created by Lemur', third_party=True) role = role_service.create(
group,
description="This is a google group based role created by Lemur",
third_party=True,
)
if not role.third_party: if not role.third_party:
role = role_service.set_third_party(role.id, third_party_status=True) role = role_service.set_third_party(role.id, third_party_status=True)
roles.append(role) roles.append(role)
else: else:
current_app.logger.warning("'googleGroups' not sent by identity provider, no specific roles will assigned to the user.") current_app.logger.warning(
"'googleGroups' not sent by identity provider, no specific roles will assigned to the user."
)
role = role_service.get_by_name(profile['email']) role = role_service.get_by_name(profile["email"])
if not role: if not role:
role = role_service.create(profile['email'], description='This is a user specific role', third_party=True) role = role_service.create(
profile["email"],
description="This is a user specific role",
third_party=True,
)
if not role.third_party: if not role.third_party:
role = role_service.set_third_party(role.id, third_party_status=True) role = role_service.set_third_party(role.id, third_party_status=True)
roles.append(role) roles.append(role)
# every user is an operator (tied to a default role) # every user is an operator (tied to a default role)
if current_app.config.get('LEMUR_DEFAULT_ROLE'): if current_app.config.get("LEMUR_DEFAULT_ROLE"):
default = role_service.get_by_name(current_app.config['LEMUR_DEFAULT_ROLE']) default = role_service.get_by_name(current_app.config["LEMUR_DEFAULT_ROLE"])
if not default: if not default:
default = role_service.create(current_app.config['LEMUR_DEFAULT_ROLE'], description='This is the default Lemur role.') default = role_service.create(
current_app.config["LEMUR_DEFAULT_ROLE"],
description="This is the default Lemur role.",
)
if not default.third_party: if not default.third_party:
role_service.set_third_party(default.id, third_party_status=True) role_service.set_third_party(default.id, third_party_status=True)
roles.append(default) roles.append(default)
@ -181,12 +198,12 @@ def update_user(user, profile, roles):
# if we get an sso user create them an account # if we get an sso user create them an account
if not user: if not user:
user = user_service.create( user = user_service.create(
profile['email'], profile["email"],
get_psuedo_random_string(), get_psuedo_random_string(),
profile['email'], profile["email"],
True, True,
profile.get('thumbnailPhotoUrl'), profile.get("thumbnailPhotoUrl"),
roles roles,
) )
else: else:
@ -198,11 +215,11 @@ def update_user(user, profile, roles):
# update any changes to the user # update any changes to the user
user_service.update( user_service.update(
user.id, user.id,
profile['email'], profile["email"],
profile['email'], profile["email"],
True, True,
profile.get('thumbnailPhotoUrl'), # profile isn't google+ enabled profile.get("thumbnailPhotoUrl"), # profile isn't google+ enabled
roles roles,
) )
@ -223,6 +240,7 @@ class Login(Resource):
on your uses cases but. It is important to not that there is currently no build in method to revoke a users token \ on your uses cases but. It is important to not that there is currently no build in method to revoke a users token \
and force re-authentication. and force re-authentication.
""" """
def __init__(self): def __init__(self):
self.reqparse = reqparse.RequestParser() self.reqparse = reqparse.RequestParser()
super(Login, self).__init__() super(Login, self).__init__()
@ -263,23 +281,26 @@ class Login(Resource):
:statuscode 401: invalid credentials :statuscode 401: invalid credentials
:statuscode 200: no error :statuscode 200: no error
""" """
self.reqparse.add_argument('username', type=str, required=True, location='json') self.reqparse.add_argument("username", type=str, required=True, location="json")
self.reqparse.add_argument('password', type=str, required=True, location='json') self.reqparse.add_argument("password", type=str, required=True, location="json")
args = self.reqparse.parse_args() args = self.reqparse.parse_args()
if '@' in args['username']: if "@" in args["username"]:
user = user_service.get_by_email(args['username']) user = user_service.get_by_email(args["username"])
else: else:
user = user_service.get_by_username(args['username']) user = user_service.get_by_username(args["username"])
# default to local authentication # default to local authentication
if user and user.check_password(args['password']) and user.active: if user and user.check_password(args["password"]) and user.active:
# Tell Flask-Principal the identity changed # Tell Flask-Principal the identity changed
identity_changed.send(current_app._get_current_object(), identity_changed.send(
identity=Identity(user.id)) current_app._get_current_object(), identity=Identity(user.id)
)
metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) metrics.send(
"login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS}
)
return dict(token=create_token(user)) return dict(token=create_token(user))
# try ldap login # try ldap login
@ -289,19 +310,29 @@ class Login(Resource):
user = ldap_principal.authenticate() user = ldap_principal.authenticate()
if user and user.active: if user and user.active:
# Tell Flask-Principal the identity changed # Tell Flask-Principal the identity changed
identity_changed.send(current_app._get_current_object(), identity_changed.send(
identity=Identity(user.id)) current_app._get_current_object(), identity=Identity(user.id)
metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) )
metrics.send(
"login",
"counter",
1,
metric_tags={"status": SUCCESS_METRIC_STATUS},
)
return dict(token=create_token(user)) return dict(token=create_token(user))
except Exception as e: except Exception as e:
current_app.logger.error("ldap error: {0}".format(e)) current_app.logger.error("ldap error: {0}".format(e))
ldap_message = 'ldap error: %s' % e ldap_message = "ldap error: %s" % e
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) metrics.send(
"login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
return dict(message=ldap_message), 403 return dict(message=ldap_message), 403
# if not valid user - no certificates for you # if not valid user - no certificates for you
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) metrics.send(
return dict(message='The supplied credentials are invalid'), 403 "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
return dict(message="The supplied credentials are invalid"), 403
class Ping(Resource): class Ping(Resource):
@ -314,36 +345,39 @@ class Ping(Resource):
provider uses for its callbacks. provider uses for its callbacks.
2. Add or change the Lemur AngularJS Configuration to point to your new provider 2. Add or change the Lemur AngularJS Configuration to point to your new provider
""" """
def __init__(self): def __init__(self):
self.reqparse = reqparse.RequestParser() self.reqparse = reqparse.RequestParser()
super(Ping, self).__init__() super(Ping, self).__init__()
def get(self): def get(self):
return 'Redirecting...' return "Redirecting..."
def post(self): def post(self):
self.reqparse.add_argument('clientId', type=str, required=True, location='json') self.reqparse.add_argument("clientId", type=str, required=True, location="json")
self.reqparse.add_argument('redirectUri', type=str, required=True, location='json') self.reqparse.add_argument(
self.reqparse.add_argument('code', type=str, required=True, location='json') "redirectUri", type=str, required=True, location="json"
)
self.reqparse.add_argument("code", type=str, required=True, location="json")
args = self.reqparse.parse_args() args = self.reqparse.parse_args()
# you can either discover these dynamically or simply configure them # you can either discover these dynamically or simply configure them
access_token_url = current_app.config.get('PING_ACCESS_TOKEN_URL') access_token_url = current_app.config.get("PING_ACCESS_TOKEN_URL")
user_api_url = current_app.config.get('PING_USER_API_URL') user_api_url = current_app.config.get("PING_USER_API_URL")
secret = current_app.config.get('PING_SECRET') secret = current_app.config.get("PING_SECRET")
id_token, access_token = exchange_for_access_token( id_token, access_token = exchange_for_access_token(
args['code'], args["code"],
args['redirectUri'], args["redirectUri"],
args['clientId'], args["clientId"],
secret, secret,
access_token_url=access_token_url access_token_url=access_token_url,
) )
jwks_url = current_app.config.get('PING_JWKS_URL') jwks_url = current_app.config.get("PING_JWKS_URL")
error_code = validate_id_token(id_token, args['clientId'], jwks_url) error_code = validate_id_token(id_token, args["clientId"], jwks_url)
if error_code: if error_code:
return error_code return error_code
user, profile = retrieve_user(user_api_url, access_token) user, profile = retrieve_user(user_api_url, access_token)
@ -351,13 +385,19 @@ class Ping(Resource):
update_user(user, profile, roles) update_user(user, profile, roles)
if not user or not user.active: if not user or not user.active:
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) metrics.send(
return dict(message='The supplied credentials are invalid'), 403 "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
return dict(message="The supplied credentials are invalid"), 403
# Tell Flask-Principal the identity changed # Tell Flask-Principal the identity changed
identity_changed.send(current_app._get_current_object(), identity=Identity(user.id)) identity_changed.send(
current_app._get_current_object(), identity=Identity(user.id)
)
metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) metrics.send(
"login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS}
)
return dict(token=create_token(user)) return dict(token=create_token(user))
@ -367,33 +407,35 @@ class OAuth2(Resource):
super(OAuth2, self).__init__() super(OAuth2, self).__init__()
def get(self): def get(self):
return 'Redirecting...' return "Redirecting..."
def post(self): def post(self):
self.reqparse.add_argument('clientId', type=str, required=True, location='json') self.reqparse.add_argument("clientId", type=str, required=True, location="json")
self.reqparse.add_argument('redirectUri', type=str, required=True, location='json') self.reqparse.add_argument(
self.reqparse.add_argument('code', type=str, required=True, location='json') "redirectUri", type=str, required=True, location="json"
)
self.reqparse.add_argument("code", type=str, required=True, location="json")
args = self.reqparse.parse_args() args = self.reqparse.parse_args()
# you can either discover these dynamically or simply configure them # you can either discover these dynamically or simply configure them
access_token_url = current_app.config.get('OAUTH2_ACCESS_TOKEN_URL') access_token_url = current_app.config.get("OAUTH2_ACCESS_TOKEN_URL")
user_api_url = current_app.config.get('OAUTH2_USER_API_URL') user_api_url = current_app.config.get("OAUTH2_USER_API_URL")
verify_cert = current_app.config.get('OAUTH2_VERIFY_CERT') verify_cert = current_app.config.get("OAUTH2_VERIFY_CERT")
secret = current_app.config.get('OAUTH2_SECRET') secret = current_app.config.get("OAUTH2_SECRET")
id_token, access_token = exchange_for_access_token( id_token, access_token = exchange_for_access_token(
args['code'], args["code"],
args['redirectUri'], args["redirectUri"],
args['clientId'], args["clientId"],
secret, secret,
access_token_url=access_token_url, access_token_url=access_token_url,
verify_cert=verify_cert verify_cert=verify_cert,
) )
jwks_url = current_app.config.get('PING_JWKS_URL') jwks_url = current_app.config.get("PING_JWKS_URL")
error_code = validate_id_token(id_token, args['clientId'], jwks_url) error_code = validate_id_token(id_token, args["clientId"], jwks_url)
if error_code: if error_code:
return error_code return error_code
@ -402,13 +444,19 @@ class OAuth2(Resource):
update_user(user, profile, roles) update_user(user, profile, roles)
if not user.active: if not user.active:
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) metrics.send(
return dict(message='The supplied credentials are invalid'), 403 "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
return dict(message="The supplied credentials are invalid"), 403
# Tell Flask-Principal the identity changed # Tell Flask-Principal the identity changed
identity_changed.send(current_app._get_current_object(), identity=Identity(user.id)) identity_changed.send(
current_app._get_current_object(), identity=Identity(user.id)
)
metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) metrics.send(
"login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS}
)
return dict(token=create_token(user)) return dict(token=create_token(user))
@ -419,44 +467,52 @@ class Google(Resource):
super(Google, self).__init__() super(Google, self).__init__()
def post(self): def post(self):
access_token_url = 'https://accounts.google.com/o/oauth2/token' access_token_url = "https://accounts.google.com/o/oauth2/token"
people_api_url = 'https://www.googleapis.com/plus/v1/people/me/openIdConnect' people_api_url = "https://www.googleapis.com/plus/v1/people/me/openIdConnect"
self.reqparse.add_argument('clientId', type=str, required=True, location='json') self.reqparse.add_argument("clientId", type=str, required=True, location="json")
self.reqparse.add_argument('redirectUri', type=str, required=True, location='json') self.reqparse.add_argument(
self.reqparse.add_argument('code', type=str, required=True, location='json') "redirectUri", type=str, required=True, location="json"
)
self.reqparse.add_argument("code", type=str, required=True, location="json")
args = self.reqparse.parse_args() args = self.reqparse.parse_args()
# Step 1. Exchange authorization code for access token # Step 1. Exchange authorization code for access token
payload = { payload = {
'client_id': args['clientId'], "client_id": args["clientId"],
'grant_type': 'authorization_code', "grant_type": "authorization_code",
'redirect_uri': args['redirectUri'], "redirect_uri": args["redirectUri"],
'code': args['code'], "code": args["code"],
'client_secret': current_app.config.get('GOOGLE_SECRET') "client_secret": current_app.config.get("GOOGLE_SECRET"),
} }
r = requests.post(access_token_url, data=payload) r = requests.post(access_token_url, data=payload)
token = r.json() token = r.json()
# Step 2. Retrieve information about the current user # Step 2. Retrieve information about the current user
headers = {'Authorization': 'Bearer {0}'.format(token['access_token'])} headers = {"Authorization": "Bearer {0}".format(token["access_token"])}
r = requests.get(people_api_url, headers=headers) r = requests.get(people_api_url, headers=headers)
profile = r.json() profile = r.json()
user = user_service.get_by_email(profile['email']) user = user_service.get_by_email(profile["email"])
if not (user and user.active): if not (user and user.active):
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) metrics.send(
return dict(message='The supplied credentials are invalid.'), 403 "login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
return dict(message="The supplied credentials are invalid."), 403
if user: if user:
metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) metrics.send(
"login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS}
)
return dict(token=create_token(user)) return dict(token=create_token(user))
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) metrics.send(
"login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
class Providers(Resource): class Providers(Resource):
@ -467,47 +523,57 @@ class Providers(Resource):
provider = provider.lower() provider = provider.lower()
if provider == "google": if provider == "google":
active_providers.append({ active_providers.append(
'name': 'google', {
'clientId': current_app.config.get("GOOGLE_CLIENT_ID"), "name": "google",
'url': api.url_for(Google) "clientId": current_app.config.get("GOOGLE_CLIENT_ID"),
}) "url": api.url_for(Google),
}
)
elif provider == "ping": elif provider == "ping":
active_providers.append({ active_providers.append(
'name': current_app.config.get("PING_NAME"), {
'url': current_app.config.get('PING_REDIRECT_URI'), "name": current_app.config.get("PING_NAME"),
'redirectUri': current_app.config.get("PING_REDIRECT_URI"), "url": current_app.config.get("PING_REDIRECT_URI"),
'clientId': current_app.config.get("PING_CLIENT_ID"), "redirectUri": current_app.config.get("PING_REDIRECT_URI"),
'responseType': 'code', "clientId": current_app.config.get("PING_CLIENT_ID"),
'scope': ['openid', 'email', 'profile', 'address'], "responseType": "code",
'scopeDelimiter': ' ', "scope": ["openid", "email", "profile", "address"],
'authorizationEndpoint': current_app.config.get("PING_AUTH_ENDPOINT"), "scopeDelimiter": " ",
'requiredUrlParams': ['scope'], "authorizationEndpoint": current_app.config.get(
'type': '2.0' "PING_AUTH_ENDPOINT"
}) ),
"requiredUrlParams": ["scope"],
"type": "2.0",
}
)
elif provider == "oauth2": elif provider == "oauth2":
active_providers.append({ active_providers.append(
'name': current_app.config.get("OAUTH2_NAME"), {
'url': current_app.config.get('OAUTH2_REDIRECT_URI'), "name": current_app.config.get("OAUTH2_NAME"),
'redirectUri': current_app.config.get("OAUTH2_REDIRECT_URI"), "url": current_app.config.get("OAUTH2_REDIRECT_URI"),
'clientId': current_app.config.get("OAUTH2_CLIENT_ID"), "redirectUri": current_app.config.get("OAUTH2_REDIRECT_URI"),
'responseType': 'code', "clientId": current_app.config.get("OAUTH2_CLIENT_ID"),
'scope': ['openid', 'email', 'profile', 'groups'], "responseType": "code",
'scopeDelimiter': ' ', "scope": ["openid", "email", "profile", "groups"],
'authorizationEndpoint': current_app.config.get("OAUTH2_AUTH_ENDPOINT"), "scopeDelimiter": " ",
'requiredUrlParams': ['scope', 'state', 'nonce'], "authorizationEndpoint": current_app.config.get(
'state': 'STATE', "OAUTH2_AUTH_ENDPOINT"
'nonce': get_psuedo_random_string(), ),
'type': '2.0' "requiredUrlParams": ["scope", "state", "nonce"],
}) "state": "STATE",
"nonce": get_psuedo_random_string(),
"type": "2.0",
}
)
return active_providers return active_providers
api.add_resource(Login, '/auth/login', endpoint='login') api.add_resource(Login, "/auth/login", endpoint="login")
api.add_resource(Ping, '/auth/ping', endpoint='ping') api.add_resource(Ping, "/auth/ping", endpoint="ping")
api.add_resource(Google, '/auth/google', endpoint='google') api.add_resource(Google, "/auth/google", endpoint="google")
api.add_resource(OAuth2, '/auth/oauth2', endpoint='oauth2') api.add_resource(OAuth2, "/auth/oauth2", endpoint="oauth2")
api.add_resource(Providers, '/auth/providers', endpoint='providers') api.add_resource(Providers, "/auth/providers", endpoint="providers")

View File

@ -7,7 +7,17 @@
.. moduleauthor:: Kevin Glisson <kglisson@netflix.com> .. moduleauthor:: Kevin Glisson <kglisson@netflix.com>
""" """
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy import Column, Integer, String, Text, func, ForeignKey, DateTime, PassiveDefault, Boolean from sqlalchemy import (
Column,
Integer,
String,
Text,
func,
ForeignKey,
DateTime,
PassiveDefault,
Boolean,
)
from sqlalchemy.dialects.postgresql import JSON from sqlalchemy.dialects.postgresql import JSON
from lemur.database import db from lemur.database import db
@ -16,7 +26,7 @@ from lemur.models import roles_authorities
class Authority(db.Model): class Authority(db.Model):
__tablename__ = 'authorities' __tablename__ = "authorities"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
owner = Column(String(128), nullable=False) owner = Column(String(128), nullable=False)
name = Column(String(128), unique=True) name = Column(String(128), unique=True)
@ -27,22 +37,44 @@ class Authority(db.Model):
description = Column(Text) description = Column(Text)
options = Column(JSON) options = Column(JSON)
date_created = Column(DateTime, PassiveDefault(func.now()), nullable=False) date_created = Column(DateTime, PassiveDefault(func.now()), nullable=False)
roles = relationship('Role', secondary=roles_authorities, passive_deletes=True, backref=db.backref('authority'), lazy='dynamic') roles = relationship(
user_id = Column(Integer, ForeignKey('users.id')) "Role",
authority_certificate = relationship("Certificate", backref='root_authority', uselist=False, foreign_keys='Certificate.root_authority_id') secondary=roles_authorities,
certificates = relationship("Certificate", backref='authority', foreign_keys='Certificate.authority_id') passive_deletes=True,
backref=db.backref("authority"),
lazy="dynamic",
)
user_id = Column(Integer, ForeignKey("users.id"))
authority_certificate = relationship(
"Certificate",
backref="root_authority",
uselist=False,
foreign_keys="Certificate.root_authority_id",
)
certificates = relationship(
"Certificate", backref="authority", foreign_keys="Certificate.authority_id"
)
authority_pending_certificate = relationship("PendingCertificate", backref='root_authority', uselist=False, foreign_keys='PendingCertificate.root_authority_id') authority_pending_certificate = relationship(
pending_certificates = relationship('PendingCertificate', backref='authority', foreign_keys='PendingCertificate.authority_id') "PendingCertificate",
backref="root_authority",
uselist=False,
foreign_keys="PendingCertificate.root_authority_id",
)
pending_certificates = relationship(
"PendingCertificate",
backref="authority",
foreign_keys="PendingCertificate.authority_id",
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.owner = kwargs['owner'] self.owner = kwargs["owner"]
self.roles = kwargs.get('roles', []) self.roles = kwargs.get("roles", [])
self.name = kwargs.get('name') self.name = kwargs.get("name")
self.description = kwargs.get('description') self.description = kwargs.get("description")
self.authority_certificate = kwargs['authority_certificate'] self.authority_certificate = kwargs["authority_certificate"]
self.plugin_name = kwargs['plugin']['slug'] self.plugin_name = kwargs["plugin"]["slug"]
self.options = kwargs.get('options') self.options = kwargs.get("options")
@property @property
def plugin(self): def plugin(self):

View File

@ -11,7 +11,13 @@ from marshmallow import fields, validates_schema, pre_load
from marshmallow import validate from marshmallow import validate
from marshmallow.exceptions import ValidationError from marshmallow.exceptions import ValidationError
from lemur.schemas import PluginInputSchema, PluginOutputSchema, ExtensionSchema, AssociatedAuthoritySchema, AssociatedRoleSchema from lemur.schemas import (
PluginInputSchema,
PluginOutputSchema,
ExtensionSchema,
AssociatedAuthoritySchema,
AssociatedRoleSchema,
)
from lemur.users.schemas import UserNestedOutputSchema from lemur.users.schemas import UserNestedOutputSchema
from lemur.common.schema import LemurInputSchema, LemurOutputSchema from lemur.common.schema import LemurInputSchema, LemurOutputSchema
from lemur.common import validators, missing from lemur.common import validators, missing
@ -30,21 +36,36 @@ class AuthorityInputSchema(LemurInputSchema):
validity_years = fields.Integer() validity_years = fields.Integer()
# certificate body fields # certificate body fields
organizational_unit = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT')) organizational_unit = fields.String(
organization = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATION')) missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATIONAL_UNIT")
location = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_LOCATION')) )
country = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_COUNTRY')) organization = fields.String(
state = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_STATE')) missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATION")
)
location = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_LOCATION")
)
country = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_COUNTRY")
)
state = fields.String(missing=lambda: current_app.config.get("LEMUR_DEFAULT_STATE"))
plugin = fields.Nested(PluginInputSchema) plugin = fields.Nested(PluginInputSchema)
# signing related options # signing related options
type = fields.String(validate=validate.OneOf(['root', 'subca']), missing='root') type = fields.String(validate=validate.OneOf(["root", "subca"]), missing="root")
parent = fields.Nested(AssociatedAuthoritySchema) parent = fields.Nested(AssociatedAuthoritySchema)
signing_algorithm = fields.String(validate=validate.OneOf(['sha256WithRSA', 'sha1WithRSA']), missing='sha256WithRSA') signing_algorithm = fields.String(
key_type = fields.String(validate=validate.OneOf(['RSA2048', 'RSA4096']), missing='RSA2048') validate=validate.OneOf(["sha256WithRSA", "sha1WithRSA"]),
missing="sha256WithRSA",
)
key_type = fields.String(
validate=validate.OneOf(["RSA2048", "RSA4096"]), missing="RSA2048"
)
key_name = fields.String() key_name = fields.String()
sensitivity = fields.String(validate=validate.OneOf(['medium', 'high']), missing='medium') sensitivity = fields.String(
validate=validate.OneOf(["medium", "high"]), missing="medium"
)
serial_number = fields.Integer() serial_number = fields.Integer()
first_serial = fields.Integer(missing=1) first_serial = fields.Integer(missing=1)
@ -58,9 +79,11 @@ class AuthorityInputSchema(LemurInputSchema):
@validates_schema @validates_schema
def validate_subca(self, data): def validate_subca(self, data):
if data['type'] == 'subca': if data["type"] == "subca":
if not data.get('parent'): if not data.get("parent"):
raise ValidationError("If generating a subca, parent 'authority' must be specified.") raise ValidationError(
"If generating a subca, parent 'authority' must be specified."
)
@pre_load @pre_load
def ensure_dates(self, data): def ensure_dates(self, data):

View File

@ -43,7 +43,7 @@ def mint(**kwargs):
""" """
Creates the authority based on the plugin provided. Creates the authority based on the plugin provided.
""" """
issuer = kwargs['plugin']['plugin_object'] issuer = kwargs["plugin"]["plugin_object"]
values = issuer.create_authority(kwargs) values = issuer.create_authority(kwargs)
# support older plugins # support older plugins
@ -53,7 +53,12 @@ def mint(**kwargs):
elif len(values) == 4: elif len(values) == 4:
body, private_key, chain, roles = values body, private_key, chain, roles = values
roles = create_authority_roles(roles, kwargs['owner'], kwargs['plugin']['plugin_object'].title, kwargs['creator']) roles = create_authority_roles(
roles,
kwargs["owner"],
kwargs["plugin"]["plugin_object"].title,
kwargs["creator"],
)
return body, private_key, chain, roles return body, private_key, chain, roles
@ -66,16 +71,17 @@ def create_authority_roles(roles, owner, plugin_title, creator):
""" """
role_objs = [] role_objs = []
for r in roles: for r in roles:
role = role_service.get_by_name(r['name']) role = role_service.get_by_name(r["name"])
if not role: if not role:
role = role_service.create( role = role_service.create(
r['name'], r["name"],
password=r['password'], password=r["password"],
description="Auto generated role for {0}".format(plugin_title), description="Auto generated role for {0}".format(plugin_title),
username=r['username']) username=r["username"],
)
# the user creating the authority should be able to administer it # the user creating the authority should be able to administer it
if role.username == 'admin': if role.username == "admin":
creator.roles.append(role) creator.roles.append(role)
role_objs.append(role) role_objs.append(role)
@ -84,8 +90,7 @@ def create_authority_roles(roles, owner, plugin_title, creator):
owner_role = role_service.get_by_name(owner) owner_role = role_service.get_by_name(owner)
if not owner_role: if not owner_role:
owner_role = role_service.create( owner_role = role_service.create(
owner, owner, description="Auto generated role based on owner: {0}".format(owner)
description="Auto generated role based on owner: {0}".format(owner)
) )
role_objs.append(owner_role) role_objs.append(owner_role)
@ -98,27 +103,29 @@ def create(**kwargs):
""" """
body, private_key, chain, roles = mint(**kwargs) body, private_key, chain, roles = mint(**kwargs)
kwargs['creator'].roles = list(set(list(kwargs['creator'].roles) + roles)) kwargs["creator"].roles = list(set(list(kwargs["creator"].roles) + roles))
kwargs['body'] = body kwargs["body"] = body
kwargs['private_key'] = private_key kwargs["private_key"] = private_key
kwargs['chain'] = chain kwargs["chain"] = chain
if kwargs.get('roles'): if kwargs.get("roles"):
kwargs['roles'] += roles kwargs["roles"] += roles
else: else:
kwargs['roles'] = roles kwargs["roles"] = roles
cert = upload(**kwargs) cert = upload(**kwargs)
kwargs['authority_certificate'] = cert kwargs["authority_certificate"] = cert
if kwargs.get('plugin', {}).get('plugin_options', []): if kwargs.get("plugin", {}).get("plugin_options", []):
kwargs['options'] = json.dumps(kwargs['plugin']['plugin_options']) kwargs["options"] = json.dumps(kwargs["plugin"]["plugin_options"])
authority = Authority(**kwargs) authority = Authority(**kwargs)
authority = database.create(authority) authority = database.create(authority)
kwargs['creator'].authorities.append(authority) kwargs["creator"].authorities.append(authority)
metrics.send('authority_created', 'counter', 1, metric_tags=dict(owner=authority.owner)) metrics.send(
"authority_created", "counter", 1, metric_tags=dict(owner=authority.owner)
)
return authority return authority
@ -150,7 +157,7 @@ def get_by_name(authority_name):
:param authority_name: :param authority_name:
:return: :return:
""" """
return database.get(Authority, authority_name, field='name') return database.get(Authority, authority_name, field="name")
def get_authority_role(ca_name, creator=None): def get_authority_role(ca_name, creator=None):
@ -173,29 +180,31 @@ def render(args):
:return: :return:
""" """
query = database.session_query(Authority) query = database.session_query(Authority)
filt = args.pop('filter') filt = args.pop("filter")
if filt: if filt:
terms = filt.split(';') terms = filt.split(";")
if 'active' in filt: if "active" in filt:
query = query.filter(Authority.active == truthiness(terms[1])) query = query.filter(Authority.active == truthiness(terms[1]))
elif 'cn' in filt: elif "cn" in filt:
term = '%{0}%'.format(terms[1]) term = "%{0}%".format(terms[1])
sub_query = database.session_query(Certificate.root_authority_id) \ sub_query = (
.filter(Certificate.cn.ilike(term)) \ database.session_query(Certificate.root_authority_id)
.filter(Certificate.cn.ilike(term))
.subquery() .subquery()
)
query = query.filter(Authority.id.in_(sub_query)) query = query.filter(Authority.id.in_(sub_query))
else: else:
query = database.filter(query, Authority, terms) query = database.filter(query, Authority, terms)
# we make sure that a user can only use an authority they either own are a member of - admins can see all # we make sure that a user can only use an authority they either own are a member of - admins can see all
if not args['user'].is_admin: if not args["user"].is_admin:
authority_ids = [] authority_ids = []
for authority in args['user'].authorities: for authority in args["user"].authorities:
authority_ids.append(authority.id) authority_ids.append(authority.id)
for role in args['user'].roles: for role in args["user"].roles:
for authority in role.authorities: for authority in role.authorities:
authority_ids.append(authority.id) authority_ids.append(authority.id)
query = query.filter(Authority.id.in_(authority_ids)) query = query.filter(Authority.id.in_(authority_ids))

View File

@ -16,15 +16,21 @@ from lemur.auth.permissions import AuthorityPermission
from lemur.certificates import service as certificate_service from lemur.certificates import service as certificate_service
from lemur.authorities import service from lemur.authorities import service
from lemur.authorities.schemas import authority_input_schema, authority_output_schema, authorities_output_schema, authority_update_schema from lemur.authorities.schemas import (
authority_input_schema,
authority_output_schema,
authorities_output_schema,
authority_update_schema,
)
mod = Blueprint('authorities', __name__) mod = Blueprint("authorities", __name__)
api = Api(mod) api = Api(mod)
class AuthoritiesList(AuthenticatedResource): class AuthoritiesList(AuthenticatedResource):
""" Defines the 'authorities' endpoint """ """ Defines the 'authorities' endpoint """
def __init__(self): def __init__(self):
self.reqparse = reqparse.RequestParser() self.reqparse = reqparse.RequestParser()
super(AuthoritiesList, self).__init__() super(AuthoritiesList, self).__init__()
@ -107,7 +113,7 @@ class AuthoritiesList(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
args = parser.parse_args() args = parser.parse_args()
args['user'] = g.current_user args["user"] = g.current_user
return service.render(args) return service.render(args)
@validate_schema(authority_input_schema, authority_output_schema) @validate_schema(authority_input_schema, authority_output_schema)
@ -220,7 +226,7 @@ class AuthoritiesList(AuthenticatedResource):
:statuscode 403: unauthenticated :statuscode 403: unauthenticated
:statuscode 200: no error :statuscode 200: no error
""" """
data['creator'] = g.current_user data["creator"] = g.current_user
return service.create(**data) return service.create(**data)
@ -388,7 +394,7 @@ class Authorities(AuthenticatedResource):
authority = service.get(authority_id) authority = service.get(authority_id)
if not authority: if not authority:
return dict(message='Not Found'), 404 return dict(message="Not Found"), 404
# all the authority role members should be allowed # all the authority role members should be allowed
roles = [x.name for x in authority.roles] roles = [x.name for x in authority.roles]
@ -397,10 +403,10 @@ class Authorities(AuthenticatedResource):
if permission.can(): if permission.can():
return service.update( return service.update(
authority_id, authority_id,
owner=data['owner'], owner=data["owner"],
description=data['description'], description=data["description"],
active=data['active'], active=data["active"],
roles=data['roles'] roles=data["roles"],
) )
return dict(message="You are not authorized to update this authority."), 403 return dict(message="You are not authorized to update this authority."), 403
@ -505,10 +511,21 @@ class AuthorityVisualizations(AuthenticatedResource):
]} ]}
""" """
authority = service.get(authority_id) authority = service.get(authority_id)
return dict(name=authority.name, children=[{"name": c.name} for c in authority.certificates]) return dict(
name=authority.name,
children=[{"name": c.name} for c in authority.certificates],
)
api.add_resource(AuthoritiesList, '/authorities', endpoint='authorities') api.add_resource(AuthoritiesList, "/authorities", endpoint="authorities")
api.add_resource(Authorities, '/authorities/<int:authority_id>', endpoint='authority') api.add_resource(Authorities, "/authorities/<int:authority_id>", endpoint="authority")
api.add_resource(AuthorityVisualizations, '/authorities/<int:authority_id>/visualize', endpoint='authority_visualizations') api.add_resource(
api.add_resource(CertificateAuthority, '/certificates/<int:certificate_id>/authority', endpoint='certificateAuthority') AuthorityVisualizations,
"/authorities/<int:authority_id>/visualize",
endpoint="authority_visualizations",
)
api.add_resource(
CertificateAuthority,
"/certificates/<int:certificate_id>/authority",
endpoint="certificateAuthority",
)

View File

@ -13,7 +13,7 @@ from lemur.plugins.base import plugins
class Authorization(db.Model): class Authorization(db.Model):
__tablename__ = 'pending_dns_authorizations' __tablename__ = "pending_dns_authorizations"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
account_number = Column(String(128)) account_number = Column(String(128))
domains = Column(JSONType) domains = Column(JSONType)

View File

@ -34,7 +34,7 @@ from lemur.certificates.service import (
get_all_pending_reissue, get_all_pending_reissue,
get_by_name, get_by_name,
get_all_certs, get_all_certs,
get get,
) )
from lemur.certificates.verify import verify_string from lemur.certificates.verify import verify_string
@ -56,11 +56,14 @@ def print_certificate_details(details):
"\t[+] Authority: {authority_name}\n" "\t[+] Authority: {authority_name}\n"
"\t[+] Validity Start: {validity_start}\n" "\t[+] Validity Start: {validity_start}\n"
"\t[+] Validity End: {validity_end}\n".format( "\t[+] Validity End: {validity_end}\n".format(
common_name=details['commonName'], common_name=details["commonName"],
sans=",".join(x['value'] for x in details['extensions']['subAltNames']['names']) or None, sans=",".join(
authority_name=details['authority']['name'], x["value"] for x in details["extensions"]["subAltNames"]["names"]
validity_start=details['validityStart'], )
validity_end=details['validityEnd'] or None,
authority_name=details["authority"]["name"],
validity_start=details["validityStart"],
validity_end=details["validityEnd"],
) )
) )
@ -120,13 +123,11 @@ def request_rotation(endpoint, certificate, message, commit):
except Exception as e: except Exception as e:
print( print(
"[!] Failed to rotate endpoint {0} to certificate {1} reason: {2}".format( "[!] Failed to rotate endpoint {0} to certificate {1} reason: {2}".format(
endpoint.name, endpoint.name, certificate.name, e
certificate.name,
e
) )
) )
metrics.send('endpoint_rotation', 'counter', 1, metric_tags={'status': status}) metrics.send("endpoint_rotation", "counter", 1, metric_tags={"status": status})
def request_reissue(certificate, commit): def request_reissue(certificate, commit):
@ -154,17 +155,52 @@ def request_reissue(certificate, commit):
except Exception as e: except Exception as e:
sentry.captureException(extra={"certificate_name": str(certificate.name)}) sentry.captureException(extra={"certificate_name": str(certificate.name)})
current_app.logger.exception(f"Error reissuing certificate: {certificate.name}", exc_info=True) current_app.logger.exception(
f"Error reissuing certificate: {certificate.name}", exc_info=True
)
print(f"[!] Failed to reissue certificate: {certificate.name}. Reason: {e}") print(f"[!] Failed to reissue certificate: {certificate.name}. Reason: {e}")
metrics.send('certificate_reissue', 'counter', 1, metric_tags={'status': status, 'certificate': certificate.name}) metrics.send(
"certificate_reissue",
"counter",
1,
metric_tags={"status": status, "certificate": certificate.name},
)
@manager.option('-e', '--endpoint', dest='endpoint_name', help='Name of the endpoint you wish to rotate.') @manager.option(
@manager.option('-n', '--new-certificate', dest='new_certificate_name', help='Name of the certificate you wish to rotate to.') "-e",
@manager.option('-o', '--old-certificate', dest='old_certificate_name', help='Name of the certificate you wish to rotate.') "--endpoint",
@manager.option('-a', '--notify', dest='message', action='store_true', help='Send a rotation notification to the certificates owner.') dest="endpoint_name",
@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') help="Name of the endpoint you wish to rotate.",
)
@manager.option(
"-n",
"--new-certificate",
dest="new_certificate_name",
help="Name of the certificate you wish to rotate to.",
)
@manager.option(
"-o",
"--old-certificate",
dest="old_certificate_name",
help="Name of the certificate you wish to rotate.",
)
@manager.option(
"-a",
"--notify",
dest="message",
action="store_true",
help="Send a rotation notification to the certificates owner.",
)
@manager.option(
"-c",
"--commit",
dest="commit",
action="store_true",
default=False,
help="Persist changes.",
)
def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, commit): def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, commit):
""" """
Rotates an endpoint and reissues it if it has not already been replaced. If it has Rotates an endpoint and reissues it if it has not already been replaced. If it has
@ -183,7 +219,9 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c
endpoint = validate_endpoint(endpoint_name) endpoint = validate_endpoint(endpoint_name)
if endpoint and new_cert: if endpoint and new_cert:
print(f"[+] Rotating endpoint: {endpoint.name} to certificate {new_cert.name}") print(
f"[+] Rotating endpoint: {endpoint.name} to certificate {new_cert.name}"
)
request_rotation(endpoint, new_cert, message, commit) request_rotation(endpoint, new_cert, message, commit)
elif old_cert and new_cert: elif old_cert and new_cert:
@ -197,16 +235,27 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c
print("[+] Rotating all endpoints that have new certificates available") print("[+] Rotating all endpoints that have new certificates available")
for endpoint in endpoint_service.get_all_pending_rotation(): for endpoint in endpoint_service.get_all_pending_rotation():
if len(endpoint.certificate.replaced) == 1: if len(endpoint.certificate.replaced) == 1:
print(f"[+] Rotating {endpoint.name} to {endpoint.certificate.replaced[0].name}") print(
request_rotation(endpoint, endpoint.certificate.replaced[0], message, commit) f"[+] Rotating {endpoint.name} to {endpoint.certificate.replaced[0].name}"
)
request_rotation(
endpoint, endpoint.certificate.replaced[0], message, commit
)
else: else:
metrics.send('endpoint_rotation', 'counter', 1, metric_tags={ metrics.send(
'status': FAILURE_METRIC_STATUS, "endpoint_rotation",
"counter",
1,
metric_tags={
"status": FAILURE_METRIC_STATUS,
"old_certificate_name": str(old_cert), "old_certificate_name": str(old_cert),
"new_certificate_name": str(endpoint.certificate.replaced[0].name), "new_certificate_name": str(
endpoint.certificate.replaced[0].name
),
"endpoint_name": str(endpoint.name), "endpoint_name": str(endpoint.name),
"message": str(message), "message": str(message),
}) },
)
print( print(
f"[!] Failed to rotate endpoint {endpoint.name} reason: " f"[!] Failed to rotate endpoint {endpoint.name} reason: "
"Multiple replacement certificates found." "Multiple replacement certificates found."
@ -222,20 +271,38 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c
"new_certificate_name": str(new_certificate_name), "new_certificate_name": str(new_certificate_name),
"endpoint_name": str(endpoint_name), "endpoint_name": str(endpoint_name),
"message": str(message), "message": str(message),
}) }
)
metrics.send('endpoint_rotation_job', 'counter', 1, metric_tags={ metrics.send(
"endpoint_rotation_job",
"counter",
1,
metric_tags={
"status": status, "status": status,
"old_certificate_name": str(old_certificate_name), "old_certificate_name": str(old_certificate_name),
"new_certificate_name": str(new_certificate_name), "new_certificate_name": str(new_certificate_name),
"endpoint_name": str(endpoint_name), "endpoint_name": str(endpoint_name),
"message": str(message), "message": str(message),
"endpoint": str(globals().get("endpoint")) "endpoint": str(globals().get("endpoint")),
}) },
)
@manager.option('-o', '--old-certificate', dest='old_certificate_name', help='Name of the certificate you wish to reissue.') @manager.option(
@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') "-o",
"--old-certificate",
dest="old_certificate_name",
help="Name of the certificate you wish to reissue.",
)
@manager.option(
"-c",
"--commit",
dest="commit",
action="store_true",
default=False,
help="Persist changes.",
)
def reissue(old_certificate_name, commit): def reissue(old_certificate_name, commit):
""" """
Reissues certificate with the same parameters as it was originally issued with. Reissues certificate with the same parameters as it was originally issued with.
@ -263,76 +330,94 @@ def reissue(old_certificate_name, commit):
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
current_app.logger.exception("Error reissuing certificate.", exc_info=True) current_app.logger.exception("Error reissuing certificate.", exc_info=True)
print( print("[!] Failed to reissue certificates. Reason: {}".format(e))
"[!] Failed to reissue certificates. Reason: {}".format(
e metrics.send(
) "certificate_reissue_job", "counter", 1, metric_tags={"status": status}
) )
metrics.send('certificate_reissue_job', 'counter', 1, metric_tags={'status': status})
@manager.option(
@manager.option('-f', '--fqdns', dest='fqdns', help='FQDNs to query. Multiple fqdns specified via comma.') "-f",
@manager.option('-i', '--issuer', dest='issuer', help='Issuer to query for.') "--fqdns",
@manager.option('-o', '--owner', dest='owner', help='Owner to query for.') dest="fqdns",
@manager.option('-e', '--expired', dest='expired', type=bool, default=False, help='Include expired certificates.') help="FQDNs to query. Multiple fqdns specified via comma.",
)
@manager.option("-i", "--issuer", dest="issuer", help="Issuer to query for.")
@manager.option("-o", "--owner", dest="owner", help="Owner to query for.")
@manager.option(
"-e",
"--expired",
dest="expired",
type=bool,
default=False,
help="Include expired certificates.",
)
def query(fqdns, issuer, owner, expired): def query(fqdns, issuer, owner, expired):
"""Prints certificates that match the query params.""" """Prints certificates that match the query params."""
table = [] table = []
q = database.session_query(Certificate) q = database.session_query(Certificate)
if issuer: if issuer:
sub_query = database.session_query(Authority.id) \ sub_query = (
.filter(Authority.name.ilike('%{0}%'.format(issuer))) \ database.session_query(Authority.id)
.filter(Authority.name.ilike("%{0}%".format(issuer)))
.subquery() .subquery()
)
q = q.filter( q = q.filter(
or_( or_(
Certificate.issuer.ilike('%{0}%'.format(issuer)), Certificate.issuer.ilike("%{0}%".format(issuer)),
Certificate.authority_id.in_(sub_query) Certificate.authority_id.in_(sub_query),
) )
) )
if owner: if owner:
q = q.filter(Certificate.owner.ilike('%{0}%'.format(owner))) q = q.filter(Certificate.owner.ilike("%{0}%".format(owner)))
if not expired: if not expired:
q = q.filter(Certificate.expired == False) # noqa q = q.filter(Certificate.expired == False) # noqa
if fqdns: if fqdns:
for f in fqdns.split(','): for f in fqdns.split(","):
q = q.filter( q = q.filter(
or_( or_(
Certificate.cn.ilike('%{0}%'.format(f)), Certificate.cn.ilike("%{0}%".format(f)),
Certificate.domains.any(Domain.name.ilike('%{0}%'.format(f))) Certificate.domains.any(Domain.name.ilike("%{0}%".format(f))),
) )
) )
for c in q.all(): for c in q.all():
table.append([c.id, c.name, c.owner, c.issuer]) table.append([c.id, c.name, c.owner, c.issuer])
print(tabulate(table, headers=['Id', 'Name', 'Owner', 'Issuer'], tablefmt='csv')) print(tabulate(table, headers=["Id", "Name", "Owner", "Issuer"], tablefmt="csv"))
def worker(data, commit, reason): def worker(data, commit, reason):
parts = [x for x in data.split(' ') if x] parts = [x for x in data.split(" ") if x]
try: try:
cert = get(int(parts[0].strip())) cert = get(int(parts[0].strip()))
plugin = plugins.get(cert.authority.plugin_name) plugin = plugins.get(cert.authority.plugin_name)
print('[+] Revoking certificate. Id: {0} Name: {1}'.format(cert.id, cert.name)) print("[+] Revoking certificate. Id: {0} Name: {1}".format(cert.id, cert.name))
if commit: if commit:
plugin.revoke_certificate(cert, reason) plugin.revoke_certificate(cert, reason)
metrics.send('certificate_revoke', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS}) metrics.send(
"certificate_revoke",
"counter",
1,
metric_tags={"status": SUCCESS_METRIC_STATUS},
)
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
metrics.send('certificate_revoke', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS}) metrics.send(
print( "certificate_revoke",
"[!] Failed to revoke certificates. Reason: {}".format( "counter",
e 1,
) metric_tags={"status": FAILURE_METRIC_STATUS},
) )
print("[!] Failed to revoke certificates. Reason: {}".format(e))
@manager.command @manager.command
@ -341,13 +426,22 @@ def clear_pending():
Function clears all pending certificates. Function clears all pending certificates.
:return: :return:
""" """
v = plugins.get('verisign-issuer') v = plugins.get("verisign-issuer")
v.clear_pending_certificates() v.clear_pending_certificates()
@manager.option('-p', '--path', dest='path', help='Absolute file path to a Lemur query csv.') @manager.option(
@manager.option('-r', '--reason', dest='reason', help='Reason to revoke certificate.') "-p", "--path", dest="path", help="Absolute file path to a Lemur query csv."
@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.') )
@manager.option("-r", "--reason", dest="reason", help="Reason to revoke certificate.")
@manager.option(
"-c",
"--commit",
dest="commit",
action="store_true",
default=False,
help="Persist changes.",
)
def revoke(path, reason, commit): def revoke(path, reason, commit):
""" """
Revokes given certificate. Revokes given certificate.
@ -357,7 +451,7 @@ def revoke(path, reason, commit):
print("[+] Starting certificate revocation.") print("[+] Starting certificate revocation.")
with open(path, 'r') as f: with open(path, "r") as f:
args = [[x, commit, reason] for x in f.readlines()[2:]] args = [[x, commit, reason] for x in f.readlines()[2:]]
with multiprocessing.Pool(processes=3) as pool: with multiprocessing.Pool(processes=3) as pool:
@ -380,11 +474,11 @@ def check_revoked():
else: else:
status = verify_string(cert.body, "") status = verify_string(cert.body, "")
cert.status = 'valid' if status else 'revoked' cert.status = "valid" if status else "revoked"
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
current_app.logger.exception(e) current_app.logger.exception(e)
cert.status = 'unknown' cert.status = "unknown"
database.update(cert) database.update(cert)

View File

@ -12,21 +12,30 @@ import subprocess
from flask import current_app from flask import current_app
from lemur.certificates.service import csr_created, csr_imported, certificate_issued, certificate_imported from lemur.certificates.service import (
csr_created,
csr_imported,
certificate_issued,
certificate_imported,
)
def csr_dump_handler(sender, csr, **kwargs): def csr_dump_handler(sender, csr, **kwargs):
try: try:
subprocess.run(['openssl', 'req', '-text', '-noout', '-reqopt', 'no_sigdump,no_pubkey'], subprocess.run(
input=csr.encode('utf8')) ["openssl", "req", "-text", "-noout", "-reqopt", "no_sigdump,no_pubkey"],
input=csr.encode("utf8"),
)
except Exception as err: except Exception as err:
current_app.logger.warning("Error inspecting CSR: %s", err) current_app.logger.warning("Error inspecting CSR: %s", err)
def cert_dump_handler(sender, certificate, **kwargs): def cert_dump_handler(sender, certificate, **kwargs):
try: try:
subprocess.run(['openssl', 'x509', '-text', '-noout', '-certopt', 'no_sigdump,no_pubkey'], subprocess.run(
input=certificate.body.encode('utf8')) ["openssl", "x509", "-text", "-noout", "-certopt", "no_sigdump,no_pubkey"],
input=certificate.body.encode("utf8"),
)
except Exception as err: except Exception as err:
current_app.logger.warning("Error inspecting certificate: %s", err) current_app.logger.warning("Error inspecting certificate: %s", err)

View File

@ -12,7 +12,18 @@ from cryptography import x509
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from flask import current_app from flask import current_app
from idna.core import InvalidCodepoint from idna.core import InvalidCodepoint
from sqlalchemy import event, Integer, ForeignKey, String, PassiveDefault, func, Column, Text, Boolean, Index from sqlalchemy import (
event,
Integer,
ForeignKey,
String,
PassiveDefault,
func,
Column,
Text,
Boolean,
Index,
)
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.sql.expression import case, extract from sqlalchemy.sql.expression import case, extract
@ -25,19 +36,25 @@ from lemur.database import db
from lemur.domains.models import Domain from lemur.domains.models import Domain
from lemur.extensions import metrics from lemur.extensions import metrics
from lemur.extensions import sentry from lemur.extensions import sentry
from lemur.models import certificate_associations, certificate_source_associations, \ from lemur.models import (
certificate_destination_associations, certificate_notification_associations, \ certificate_associations,
certificate_replacement_associations, roles_certificates, pending_cert_replacement_associations certificate_source_associations,
certificate_destination_associations,
certificate_notification_associations,
certificate_replacement_associations,
roles_certificates,
pending_cert_replacement_associations,
)
from lemur.plugins.base import plugins from lemur.plugins.base import plugins
from lemur.policies.models import RotationPolicy from lemur.policies.models import RotationPolicy
from lemur.utils import Vault from lemur.utils import Vault
def get_sequence(name): def get_sequence(name):
if '-' not in name: if "-" not in name:
return name, None return name, None
parts = name.split('-') parts = name.split("-")
# see if we have an int at the end of our name # see if we have an int at the end of our name
try: try:
@ -49,22 +66,26 @@ def get_sequence(name):
if len(parts[-1]) == 8: if len(parts[-1]) == 8:
return name, None return name, None
root = '-'.join(parts[:-1]) root = "-".join(parts[:-1])
return root, seq return root, seq
def get_or_increase_name(name, serial): def get_or_increase_name(name, serial):
certificates = Certificate.query.filter(Certificate.name.ilike('{0}%'.format(name))).all() certificates = Certificate.query.filter(Certificate.name == name).all()
if not certificates: if not certificates:
return name return name
serial_name = '{0}-{1}'.format(name, hex(int(serial))[2:].upper()) serial_name = "{0}-{1}".format(name, hex(int(serial))[2:].upper())
certificates = Certificate.query.filter(Certificate.name.ilike('{0}%'.format(serial_name))).all() certificates = Certificate.query.filter(Certificate.name == serial_name).all()
if not certificates: if not certificates:
return serial_name return serial_name
certificates = Certificate.query.filter(
Certificate.name.ilike("{0}%".format(serial_name))
).all()
ends = [0] ends = [0]
root, end = get_sequence(serial_name) root, end = get_sequence(serial_name)
for cert in certificates: for cert in certificates:
@ -72,21 +93,29 @@ def get_or_increase_name(name, serial):
if end: if end:
ends.append(end) ends.append(end)
return '{0}-{1}'.format(root, max(ends) + 1) return "{0}-{1}".format(root, max(ends) + 1)
class Certificate(db.Model): class Certificate(db.Model):
__tablename__ = 'certificates' __tablename__ = "certificates"
__table_args__ = ( __table_args__ = (
Index('ix_certificates_cn', "cn", Index(
"ix_certificates_cn",
"cn",
postgresql_ops={"cn": "gin_trgm_ops"}, postgresql_ops={"cn": "gin_trgm_ops"},
postgresql_using='gin'), postgresql_using="gin",
Index('ix_certificates_name', "name", ),
Index(
"ix_certificates_name",
"name",
postgresql_ops={"name": "gin_trgm_ops"}, postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using='gin'), postgresql_using="gin",
),
) )
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
ix = Index('ix_certificates_id_desc', id.desc(), postgresql_using='btree', unique=True) ix = Index(
"ix_certificates_id_desc", id.desc(), postgresql_using="btree", unique=True
)
external_id = Column(String(128)) external_id = Column(String(128))
owner = Column(String(128), nullable=False) owner = Column(String(128), nullable=False)
name = Column(String(256), unique=True) name = Column(String(256), unique=True)
@ -102,7 +131,9 @@ class Certificate(db.Model):
serial = Column(String(128)) serial = Column(String(128))
cn = Column(String(128)) cn = Column(String(128))
deleted = Column(Boolean, index=True, default=False) deleted = Column(Boolean, index=True, default=False)
dns_provider_id = Column(Integer(), ForeignKey('dns_providers.id', ondelete='CASCADE'), nullable=True) dns_provider_id = Column(
Integer(), ForeignKey("dns_providers.id", ondelete="CASCADE"), nullable=True
)
not_before = Column(ArrowType) not_before = Column(ArrowType)
not_after = Column(ArrowType) not_after = Column(ArrowType)
@ -116,34 +147,53 @@ class Certificate(db.Model):
san = Column(String(1024)) # TODO this should be migrated to boolean san = Column(String(1024)) # TODO this should be migrated to boolean
rotation = Column(Boolean, default=False) rotation = Column(Boolean, default=False)
user_id = Column(Integer, ForeignKey('users.id')) user_id = Column(Integer, ForeignKey("users.id"))
authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE")) authority_id = Column(Integer, ForeignKey("authorities.id", ondelete="CASCADE"))
root_authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE")) root_authority_id = Column(
rotation_policy_id = Column(Integer, ForeignKey('rotation_policies.id')) Integer, ForeignKey("authorities.id", ondelete="CASCADE")
)
rotation_policy_id = Column(Integer, ForeignKey("rotation_policies.id"))
notifications = relationship('Notification', secondary=certificate_notification_associations, backref='certificate') notifications = relationship(
destinations = relationship('Destination', secondary=certificate_destination_associations, backref='certificate') "Notification",
sources = relationship('Source', secondary=certificate_source_associations, backref='certificate') secondary=certificate_notification_associations,
domains = relationship('Domain', secondary=certificate_associations, backref='certificate') backref="certificate",
roles = relationship('Role', secondary=roles_certificates, backref='certificate') )
replaces = relationship('Certificate', destinations = relationship(
"Destination",
secondary=certificate_destination_associations,
backref="certificate",
)
sources = relationship(
"Source", secondary=certificate_source_associations, backref="certificate"
)
domains = relationship(
"Domain", secondary=certificate_associations, backref="certificate"
)
roles = relationship("Role", secondary=roles_certificates, backref="certificate")
replaces = relationship(
"Certificate",
secondary=certificate_replacement_associations, secondary=certificate_replacement_associations,
primaryjoin=id == certificate_replacement_associations.c.certificate_id, # noqa primaryjoin=id == certificate_replacement_associations.c.certificate_id, # noqa
secondaryjoin=id == certificate_replacement_associations.c.replaced_certificate_id, # noqa secondaryjoin=id
backref='replaced') == certificate_replacement_associations.c.replaced_certificate_id, # noqa
backref="replaced",
)
replaced_by_pending = relationship('PendingCertificate', replaced_by_pending = relationship(
"PendingCertificate",
secondary=pending_cert_replacement_associations, secondary=pending_cert_replacement_associations,
backref='pending_replace', backref="pending_replace",
viewonly=True) viewonly=True,
)
logs = relationship('Log', backref='certificate') logs = relationship("Log", backref="certificate")
endpoints = relationship('Endpoint', backref='certificate') endpoints = relationship("Endpoint", backref="certificate")
rotation_policy = relationship("RotationPolicy") rotation_policy = relationship("RotationPolicy")
sensitive_fields = ('private_key',) sensitive_fields = ("private_key",)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.body = kwargs['body'].strip() self.body = kwargs["body"].strip()
cert = self.parsed_cert cert = self.parsed_cert
self.issuer = defaults.issuer(cert) self.issuer = defaults.issuer(cert)
@ -154,36 +204,42 @@ class Certificate(db.Model):
self.serial = defaults.serial(cert) self.serial = defaults.serial(cert)
# when destinations are appended they require a valid name. # when destinations are appended they require a valid name.
if kwargs.get('name'): if kwargs.get("name"):
self.name = get_or_increase_name(defaults.text_to_slug(kwargs['name']), self.serial) self.name = get_or_increase_name(
defaults.text_to_slug(kwargs["name"]), self.serial
)
else: else:
self.name = get_or_increase_name( self.name = get_or_increase_name(
defaults.certificate_name(self.cn, self.issuer, self.not_before, self.not_after, self.san), self.serial) defaults.certificate_name(
self.cn, self.issuer, self.not_before, self.not_after, self.san
),
self.serial,
)
self.owner = kwargs['owner'] self.owner = kwargs["owner"]
if kwargs.get('private_key'): if kwargs.get("private_key"):
self.private_key = kwargs['private_key'].strip() self.private_key = kwargs["private_key"].strip()
if kwargs.get('chain'): if kwargs.get("chain"):
self.chain = kwargs['chain'].strip() self.chain = kwargs["chain"].strip()
if kwargs.get('csr'): if kwargs.get("csr"):
self.csr = kwargs['csr'].strip() self.csr = kwargs["csr"].strip()
self.notify = kwargs.get('notify', True) self.notify = kwargs.get("notify", True)
self.destinations = kwargs.get('destinations', []) self.destinations = kwargs.get("destinations", [])
self.notifications = kwargs.get('notifications', []) self.notifications = kwargs.get("notifications", [])
self.description = kwargs.get('description') self.description = kwargs.get("description")
self.roles = list(set(kwargs.get('roles', []))) self.roles = list(set(kwargs.get("roles", [])))
self.replaces = kwargs.get('replaces', []) self.replaces = kwargs.get("replaces", [])
self.rotation = kwargs.get('rotation') self.rotation = kwargs.get("rotation")
self.rotation_policy = kwargs.get('rotation_policy') self.rotation_policy = kwargs.get("rotation_policy")
self.signing_algorithm = defaults.signing_algorithm(cert) self.signing_algorithm = defaults.signing_algorithm(cert)
self.bits = defaults.bitstrength(cert) self.bits = defaults.bitstrength(cert)
self.external_id = kwargs.get('external_id') self.external_id = kwargs.get("external_id")
self.authority_id = kwargs.get('authority_id') self.authority_id = kwargs.get("authority_id")
self.dns_provider_id = kwargs.get('dns_provider_id') self.dns_provider_id = kwargs.get("dns_provider_id")
for domain in defaults.domains(cert): for domain in defaults.domains(cert):
self.domains.append(Domain(name=domain)) self.domains.append(Domain(name=domain))
@ -197,8 +253,11 @@ class Certificate(db.Model):
Integrity checks: Does the cert have a valid chain and matching private key? Integrity checks: Does the cert have a valid chain and matching private key?
""" """
if self.private_key: if self.private_key:
validators.verify_private_key_match(utils.parse_private_key(self.private_key), self.parsed_cert, validators.verify_private_key_match(
error_class=AssertionError) utils.parse_private_key(self.private_key),
self.parsed_cert,
error_class=AssertionError,
)
if self.chain: if self.chain:
chain = [self.parsed_cert] + utils.parse_cert_chain(self.chain) chain = [self.parsed_cert] + utils.parse_cert_chain(self.chain)
@ -240,7 +299,9 @@ class Certificate(db.Model):
@property @property
def key_type(self): def key_type(self):
if isinstance(self.parsed_cert.public_key(), rsa.RSAPublicKey): if isinstance(self.parsed_cert.public_key(), rsa.RSAPublicKey):
return 'RSA{key_size}'.format(key_size=self.parsed_cert.public_key().key_size) return "RSA{key_size}".format(
key_size=self.parsed_cert.public_key().key_size
)
@property @property
def validity_remaining(self): def validity_remaining(self):
@ -265,26 +326,16 @@ class Certificate(db.Model):
@expired.expression @expired.expression
def expired(cls): def expired(cls):
return case( return case([(cls.not_after <= arrow.utcnow(), True)], else_=False)
[
(cls.not_after <= arrow.utcnow(), True)
],
else_=False
)
@hybrid_property @hybrid_property
def revoked(self): def revoked(self):
if 'revoked' == self.status: if "revoked" == self.status:
return True return True
@revoked.expression @revoked.expression
def revoked(cls): def revoked(cls):
return case( return case([(cls.status == "revoked", True)], else_=False)
[
(cls.status == 'revoked', True)
],
else_=False
)
@hybrid_property @hybrid_property
def in_rotation_window(self): def in_rotation_window(self):
@ -307,66 +358,65 @@ class Certificate(db.Model):
:return: :return:
""" """
return case( return case(
[ [(extract("day", cls.not_after - func.now()) <= RotationPolicy.days, True)],
(extract('day', cls.not_after - func.now()) <= RotationPolicy.days, True) else_=False,
],
else_=False
) )
@property @property
def extensions(self): def extensions(self):
# setup default values # setup default values
return_extensions = { return_extensions = {"sub_alt_names": {"names": []}}
'sub_alt_names': {'names': []}
}
try: try:
for extension in self.parsed_cert.extensions: for extension in self.parsed_cert.extensions:
value = extension.value value = extension.value
if isinstance(value, x509.BasicConstraints): if isinstance(value, x509.BasicConstraints):
return_extensions['basic_constraints'] = value return_extensions["basic_constraints"] = value
elif isinstance(value, x509.SubjectAlternativeName): elif isinstance(value, x509.SubjectAlternativeName):
return_extensions['sub_alt_names']['names'] = value return_extensions["sub_alt_names"]["names"] = value
elif isinstance(value, x509.ExtendedKeyUsage): elif isinstance(value, x509.ExtendedKeyUsage):
return_extensions['extended_key_usage'] = value return_extensions["extended_key_usage"] = value
elif isinstance(value, x509.KeyUsage): elif isinstance(value, x509.KeyUsage):
return_extensions['key_usage'] = value return_extensions["key_usage"] = value
elif isinstance(value, x509.SubjectKeyIdentifier): elif isinstance(value, x509.SubjectKeyIdentifier):
return_extensions['subject_key_identifier'] = {'include_ski': True} return_extensions["subject_key_identifier"] = {"include_ski": True}
elif isinstance(value, x509.AuthorityInformationAccess): elif isinstance(value, x509.AuthorityInformationAccess):
return_extensions['certificate_info_access'] = {'include_aia': True} return_extensions["certificate_info_access"] = {"include_aia": True}
elif isinstance(value, x509.AuthorityKeyIdentifier): elif isinstance(value, x509.AuthorityKeyIdentifier):
aki = { aki = {"use_key_identifier": False, "use_authority_cert": False}
'use_key_identifier': False,
'use_authority_cert': False
}
if value.key_identifier: if value.key_identifier:
aki['use_key_identifier'] = True aki["use_key_identifier"] = True
if value.authority_cert_issuer: if value.authority_cert_issuer:
aki['use_authority_cert'] = True aki["use_authority_cert"] = True
return_extensions['authority_key_identifier'] = aki return_extensions["authority_key_identifier"] = aki
elif isinstance(value, x509.CRLDistributionPoints): elif isinstance(value, x509.CRLDistributionPoints):
return_extensions['crl_distribution_points'] = {'include_crl_dp': value} return_extensions["crl_distribution_points"] = {
"include_crl_dp": value
}
# TODO: Not supporting custom OIDs yet. https://github.com/Netflix/lemur/issues/665 # TODO: Not supporting custom OIDs yet. https://github.com/Netflix/lemur/issues/665
else: else:
current_app.logger.warning('Custom OIDs not yet supported for clone operation.') current_app.logger.warning(
"Custom OIDs not yet supported for clone operation."
)
except InvalidCodepoint as e: except InvalidCodepoint as e:
sentry.captureException() sentry.captureException()
current_app.logger.warning('Unable to parse extensions due to underscore in dns name') current_app.logger.warning(
"Unable to parse extensions due to underscore in dns name"
)
except ValueError as e: except ValueError as e:
sentry.captureException() sentry.captureException()
current_app.logger.warning('Unable to parse') current_app.logger.warning("Unable to parse")
current_app.logger.exception(e) current_app.logger.exception(e)
return return_extensions return return_extensions
@ -375,7 +425,7 @@ class Certificate(db.Model):
return "Certificate(name={name})".format(name=self.name) return "Certificate(name={name})".format(name=self.name)
@event.listens_for(Certificate.destinations, 'append') @event.listens_for(Certificate.destinations, "append")
def update_destinations(target, value, initiator): def update_destinations(target, value, initiator):
""" """
Attempt to upload certificate to the new destination Attempt to upload certificate to the new destination
@ -389,17 +439,31 @@ def update_destinations(target, value, initiator):
status = FAILURE_METRIC_STATUS status = FAILURE_METRIC_STATUS
try: try:
if target.private_key or not destination_plugin.requires_key: if target.private_key or not destination_plugin.requires_key:
destination_plugin.upload(target.name, target.body, target.private_key, target.chain, value.options) destination_plugin.upload(
target.name,
target.body,
target.private_key,
target.chain,
value.options,
)
status = SUCCESS_METRIC_STATUS status = SUCCESS_METRIC_STATUS
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
raise raise
metrics.send('destination_upload', 'counter', 1, metrics.send(
metric_tags={'status': status, 'certificate': target.name, 'destination': value.label}) "destination_upload",
"counter",
1,
metric_tags={
"status": status,
"certificate": target.name,
"destination": value.label,
},
)
@event.listens_for(Certificate.replaces, 'append') @event.listens_for(Certificate.replaces, "append")
def update_replacement(target, value, initiator): def update_replacement(target, value, initiator):
""" """
When a certificate is marked as 'replaced' we should not notify. When a certificate is marked as 'replaced' we should not notify.

View File

@ -39,22 +39,26 @@ from lemur.users.schemas import UserNestedOutputSchema
class CertificateSchema(LemurInputSchema): class CertificateSchema(LemurInputSchema):
owner = fields.Email(required=True) owner = fields.Email(required=True)
description = fields.String(missing='', allow_none=True) description = fields.String(missing="", allow_none=True)
class CertificateCreationSchema(CertificateSchema): class CertificateCreationSchema(CertificateSchema):
@post_load @post_load
def default_notification(self, data): def default_notification(self, data):
if not data['notifications']: if not data["notifications"]:
data['notifications'] += notification_service.create_default_expiration_notifications( data[
"DEFAULT_{0}".format(data['owner'].split('@')[0].upper()), "notifications"
[data['owner']], ] += notification_service.create_default_expiration_notifications(
"DEFAULT_{0}".format(data["owner"].split("@")[0].upper()),
[data["owner"]],
) )
data['notifications'] += notification_service.create_default_expiration_notifications( data[
'DEFAULT_SECURITY', "notifications"
current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL'), ] += notification_service.create_default_expiration_notifications(
current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL_INTERVALS', None) "DEFAULT_SECURITY",
current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL"),
current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL_INTERVALS", None),
) )
return data return data
@ -71,37 +75,53 @@ class CertificateInputSchema(CertificateCreationSchema):
destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True) destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True)
notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True) notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True)
replaces = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) replaces = fields.Nested(AssociatedCertificateSchema, missing=[], many=True)
replacements = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) # deprecated replacements = fields.Nested(
AssociatedCertificateSchema, missing=[], many=True
) # deprecated
roles = fields.Nested(AssociatedRoleSchema, missing=[], many=True) roles = fields.Nested(AssociatedRoleSchema, missing=[], many=True)
dns_provider = fields.Nested(AssociatedDnsProviderSchema, missing=None, allow_none=True, required=False) dns_provider = fields.Nested(
AssociatedDnsProviderSchema, missing=None, allow_none=True, required=False
)
csr = fields.String(allow_none=True, validate=validators.csr) csr = fields.String(allow_none=True, validate=validators.csr)
key_type = fields.String( key_type = fields.String(
validate=validate.OneOf(CERTIFICATE_KEY_TYPES), validate=validate.OneOf(CERTIFICATE_KEY_TYPES), missing="RSA2048"
missing='RSA2048') )
notify = fields.Boolean(default=True) notify = fields.Boolean(default=True)
rotation = fields.Boolean() rotation = fields.Boolean()
rotation_policy = fields.Nested(AssociatedRotationPolicySchema, missing={'name': 'default'}, allow_none=True, rotation_policy = fields.Nested(
default={'name': 'default'}) AssociatedRotationPolicySchema,
missing={"name": "default"},
allow_none=True,
default={"name": "default"},
)
# certificate body fields # certificate body fields
organizational_unit = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT')) organizational_unit = fields.String(
organization = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATION')) missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATIONAL_UNIT")
location = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_LOCATION')) )
country = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_COUNTRY')) organization = fields.String(
state = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_STATE')) missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATION")
)
location = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_LOCATION")
)
country = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_COUNTRY")
)
state = fields.String(missing=lambda: current_app.config.get("LEMUR_DEFAULT_STATE"))
extensions = fields.Nested(ExtensionSchema) extensions = fields.Nested(ExtensionSchema)
@validates_schema @validates_schema
def validate_authority(self, data): def validate_authority(self, data):
if isinstance(data['authority'], str): if isinstance(data["authority"], str):
raise ValidationError("Authority not found.") raise ValidationError("Authority not found.")
if not data['authority'].active: if not data["authority"].active:
raise ValidationError("The authority is inactive.", ['authority']) raise ValidationError("The authority is inactive.", ["authority"])
@validates_schema @validates_schema
def validate_dates(self, data): def validate_dates(self, data):
@ -109,23 +129,19 @@ class CertificateInputSchema(CertificateCreationSchema):
@pre_load @pre_load
def load_data(self, data): def load_data(self, data):
if data.get('replacements'): if data.get("replacements"):
data['replaces'] = data['replacements'] # TODO remove when field is deprecated data["replaces"] = data[
if data.get('csr'): "replacements"
csr_sans = cert_utils.get_sans_from_csr(data['csr']) ] # TODO remove when field is deprecated
if not data.get('extensions'): if data.get("csr"):
data['extensions'] = { csr_sans = cert_utils.get_sans_from_csr(data["csr"])
'subAltNames': { if not data.get("extensions"):
'names': [] data["extensions"] = {"subAltNames": {"names": []}}
} elif not data["extensions"].get("subAltNames"):
} data["extensions"]["subAltNames"] = {"names": []}
elif not data['extensions'].get('subAltNames'): elif not data["extensions"]["subAltNames"].get("names"):
data['extensions']['subAltNames'] = { data["extensions"]["subAltNames"]["names"] = []
'names': [] data["extensions"]["subAltNames"]["names"] += csr_sans
}
elif not data['extensions']['subAltNames'].get('names'):
data['extensions']['subAltNames']['names'] = []
data['extensions']['subAltNames']['names'] += csr_sans
return missing.convert_validity_years(data) return missing.convert_validity_years(data)
@ -138,13 +154,17 @@ class CertificateEditInputSchema(CertificateSchema):
destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True) destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True)
notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True) notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True)
replaces = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) replaces = fields.Nested(AssociatedCertificateSchema, missing=[], many=True)
replacements = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) # deprecated replacements = fields.Nested(
AssociatedCertificateSchema, missing=[], many=True
) # deprecated
roles = fields.Nested(AssociatedRoleSchema, missing=[], many=True) roles = fields.Nested(AssociatedRoleSchema, missing=[], many=True)
@pre_load @pre_load
def load_data(self, data): def load_data(self, data):
if data.get('replacements'): if data.get("replacements"):
data['replaces'] = data['replacements'] # TODO remove when field is deprecated data["replaces"] = data[
"replacements"
] # TODO remove when field is deprecated
return data return data
@post_load @post_load
@ -155,10 +175,15 @@ class CertificateEditInputSchema(CertificateSchema):
:param data: :param data:
:return: :return:
""" """
if data['owner']: if data["owner"]:
notification_name = "DEFAULT_{0}".format(data['owner'].split('@')[0].upper()) notification_name = "DEFAULT_{0}".format(
data['notifications'] += notification_service.create_default_expiration_notifications(notification_name, data["owner"].split("@")[0].upper()
[data['owner']]) )
data[
"notifications"
] += notification_service.create_default_expiration_notifications(
notification_name, [data["owner"]]
)
return data return data
@ -184,13 +209,13 @@ class CertificateNestedOutputSchema(LemurOutputSchema):
# Note aliasing is the first step in deprecating these fields. # Note aliasing is the first step in deprecating these fields.
cn = fields.String() # deprecated cn = fields.String() # deprecated
common_name = fields.String(attribute='cn') common_name = fields.String(attribute="cn")
not_after = fields.DateTime() # deprecated not_after = fields.DateTime() # deprecated
validity_end = ArrowDateTime(attribute='not_after') validity_end = ArrowDateTime(attribute="not_after")
not_before = fields.DateTime() # deprecated not_before = fields.DateTime() # deprecated
validity_start = ArrowDateTime(attribute='not_before') validity_start = ArrowDateTime(attribute="not_before")
issuer = fields.Nested(AuthorityNestedOutputSchema) issuer = fields.Nested(AuthorityNestedOutputSchema)
@ -221,22 +246,22 @@ class CertificateOutputSchema(LemurOutputSchema):
# Note aliasing is the first step in deprecating these fields. # Note aliasing is the first step in deprecating these fields.
notify = fields.Boolean() notify = fields.Boolean()
active = fields.Boolean(attribute='notify') active = fields.Boolean(attribute="notify")
cn = fields.String() cn = fields.String()
common_name = fields.String(attribute='cn') common_name = fields.String(attribute="cn")
distinguished_name = fields.String() distinguished_name = fields.String()
not_after = fields.DateTime() not_after = fields.DateTime()
validity_end = ArrowDateTime(attribute='not_after') validity_end = ArrowDateTime(attribute="not_after")
not_before = fields.DateTime() not_before = fields.DateTime()
validity_start = ArrowDateTime(attribute='not_before') validity_start = ArrowDateTime(attribute="not_before")
owner = fields.Email() owner = fields.Email()
san = fields.Boolean() san = fields.Boolean()
serial = fields.String() serial = fields.String()
serial_hex = Hex(attribute='serial') serial_hex = Hex(attribute="serial")
signing_algorithm = fields.String() signing_algorithm = fields.String()
status = fields.String() status = fields.String()
@ -253,7 +278,9 @@ class CertificateOutputSchema(LemurOutputSchema):
dns_provider = fields.Nested(DnsProvidersNestedOutputSchema) dns_provider = fields.Nested(DnsProvidersNestedOutputSchema)
roles = fields.Nested(RoleNestedOutputSchema, many=True) roles = fields.Nested(RoleNestedOutputSchema, many=True)
endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[]) endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[])
replaced_by = fields.Nested(CertificateNestedOutputSchema, many=True, attribute='replaced') replaced_by = fields.Nested(
CertificateNestedOutputSchema, many=True, attribute="replaced"
)
rotation_policy = fields.Nested(RotationPolicyNestedOutputSchema) rotation_policy = fields.Nested(RotationPolicyNestedOutputSchema)
@ -274,35 +301,41 @@ class CertificateUploadInputSchema(CertificateCreationSchema):
@validates_schema @validates_schema
def keys(self, data): def keys(self, data):
if data.get('destinations'): if data.get("destinations"):
if not data.get('private_key'): if not data.get("private_key"):
raise ValidationError('Destinations require private key.') raise ValidationError("Destinations require private key.")
@validates_schema @validates_schema
def validate_cert_private_key_chain(self, data): def validate_cert_private_key_chain(self, data):
cert = None cert = None
key = None key = None
if data.get('body'): if data.get("body"):
try: try:
cert = utils.parse_certificate(data['body']) cert = utils.parse_certificate(data["body"])
except ValueError: except ValueError:
raise ValidationError("Public certificate presented is not valid.", field_names=['body']) raise ValidationError(
"Public certificate presented is not valid.", field_names=["body"]
)
if data.get('private_key'): if data.get("private_key"):
try: try:
key = utils.parse_private_key(data['private_key']) key = utils.parse_private_key(data["private_key"])
except ValueError: except ValueError:
raise ValidationError("Private key presented is not valid.", field_names=['private_key']) raise ValidationError(
"Private key presented is not valid.", field_names=["private_key"]
)
if cert and key: if cert and key:
# Throws ValidationError # Throws ValidationError
validators.verify_private_key_match(key, cert) validators.verify_private_key_match(key, cert)
if data.get('chain'): if data.get("chain"):
try: try:
chain = utils.parse_cert_chain(data['chain']) chain = utils.parse_cert_chain(data["chain"])
except ValueError: except ValueError:
raise ValidationError("Invalid certificate in certificate chain.", field_names=['chain']) raise ValidationError(
"Invalid certificate in certificate chain.", field_names=["chain"]
)
# Throws ValidationError # Throws ValidationError
validators.verify_cert_chain([cert] + chain) validators.verify_cert_chain([cert] + chain)
@ -318,8 +351,10 @@ class CertificateNotificationOutputSchema(LemurOutputSchema):
name = fields.String() name = fields.String()
owner = fields.Email() owner = fields.Email()
user = fields.Nested(UserNestedOutputSchema) user = fields.Nested(UserNestedOutputSchema)
validity_end = ArrowDateTime(attribute='not_after') validity_end = ArrowDateTime(attribute="not_after")
replaced_by = fields.Nested(CertificateNestedOutputSchema, many=True, attribute='replaced') replaced_by = fields.Nested(
CertificateNestedOutputSchema, many=True, attribute="replaced"
)
endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[]) endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[])

View File

@ -26,10 +26,14 @@ from lemur.plugins.base import plugins
from lemur.roles import service as role_service from lemur.roles import service as role_service
from lemur.roles.models import Role from lemur.roles.models import Role
csr_created = signals.signal('csr_created', "CSR generated") csr_created = signals.signal("csr_created", "CSR generated")
csr_imported = signals.signal('csr_imported', "CSR imported from external source") csr_imported = signals.signal("csr_imported", "CSR imported from external source")
certificate_issued = signals.signal('certificate_issued', "Authority issued a certificate") certificate_issued = signals.signal(
certificate_imported = signals.signal('certificate_imported', "Certificate imported from external source") "certificate_issued", "Authority issued a certificate"
)
certificate_imported = signals.signal(
"certificate_imported", "Certificate imported from external source"
)
def get(cert_id): def get(cert_id):
@ -49,7 +53,7 @@ def get_by_name(name):
:param name: :param name:
:return: :return:
""" """
return database.get(Certificate, name, field='name') return database.get(Certificate, name, field="name")
def get_by_serial(serial): def get_by_serial(serial):
@ -105,8 +109,12 @@ def get_all_pending_cleaning(source):
:param source: :param source:
:return: :return:
""" """
return Certificate.query.filter(Certificate.sources.any(id=source.id)) \ return (
.filter(not_(Certificate.endpoints.any())).filter(Certificate.expired).all() Certificate.query.filter(Certificate.sources.any(id=source.id))
.filter(not_(Certificate.endpoints.any()))
.filter(Certificate.expired)
.all()
)
def get_all_pending_reissue(): def get_all_pending_reissue():
@ -119,9 +127,12 @@ def get_all_pending_reissue():
:return: :return:
""" """
return Certificate.query.filter(Certificate.rotation == True) \ return (
.filter(not_(Certificate.replaced.any())) \ Certificate.query.filter(Certificate.rotation == True)
.filter(Certificate.in_rotation_window == True).all() # noqa .filter(not_(Certificate.replaced.any()))
.filter(Certificate.in_rotation_window == True)
.all()
) # noqa
def find_duplicates(cert): def find_duplicates(cert):
@ -133,10 +144,12 @@ def find_duplicates(cert):
:param cert: :param cert:
:return: :return:
""" """
if cert['chain']: if cert["chain"]:
return Certificate.query.filter_by(body=cert['body'].strip(), chain=cert['chain'].strip()).all() return Certificate.query.filter_by(
body=cert["body"].strip(), chain=cert["chain"].strip()
).all()
else: else:
return Certificate.query.filter_by(body=cert['body'].strip(), chain=None).all() return Certificate.query.filter_by(body=cert["body"].strip(), chain=None).all()
def export(cert, export_plugin): def export(cert, export_plugin):
@ -148,8 +161,10 @@ def export(cert, export_plugin):
:param cert: :param cert:
:return: :return:
""" """
plugin = plugins.get(export_plugin['slug']) plugin = plugins.get(export_plugin["slug"])
return plugin.export(cert.body, cert.chain, cert.private_key, export_plugin['pluginOptions']) return plugin.export(
cert.body, cert.chain, cert.private_key, export_plugin["pluginOptions"]
)
def update(cert_id, **kwargs): def update(cert_id, **kwargs):
@ -168,17 +183,19 @@ def update(cert_id, **kwargs):
def create_certificate_roles(**kwargs): def create_certificate_roles(**kwargs):
# create an role for the owner and assign it # create an role for the owner and assign it
owner_role = role_service.get_by_name(kwargs['owner']) owner_role = role_service.get_by_name(kwargs["owner"])
if not owner_role: if not owner_role:
owner_role = role_service.create( owner_role = role_service.create(
kwargs['owner'], kwargs["owner"],
description="Auto generated role based on owner: {0}".format(kwargs['owner']) description="Auto generated role based on owner: {0}".format(
kwargs["owner"]
),
) )
# ensure that the authority's owner is also associated with the certificate # ensure that the authority's owner is also associated with the certificate
if kwargs.get('authority'): if kwargs.get("authority"):
authority_owner_role = role_service.get_by_name(kwargs['authority'].owner) authority_owner_role = role_service.get_by_name(kwargs["authority"].owner)
return [owner_role, authority_owner_role] return [owner_role, authority_owner_role]
return [owner_role] return [owner_role]
@ -190,16 +207,16 @@ def mint(**kwargs):
Support for multiple authorities is handled by individual plugins. Support for multiple authorities is handled by individual plugins.
""" """
authority = kwargs['authority'] authority = kwargs["authority"]
issuer = plugins.get(authority.plugin_name) issuer = plugins.get(authority.plugin_name)
# allow the CSR to be specified by the user # allow the CSR to be specified by the user
if not kwargs.get('csr'): if not kwargs.get("csr"):
csr, private_key = create_csr(**kwargs) csr, private_key = create_csr(**kwargs)
csr_created.send(authority=authority, csr=csr) csr_created.send(authority=authority, csr=csr)
else: else:
csr = str(kwargs.get('csr')) csr = str(kwargs.get("csr"))
private_key = None private_key = None
csr_imported.send(authority=authority, csr=csr) csr_imported.send(authority=authority, csr=csr)
@ -220,8 +237,8 @@ def import_certificate(**kwargs):
:param kwargs: :param kwargs:
""" """
if not kwargs.get('owner'): if not kwargs.get("owner"):
kwargs['owner'] = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL')[0] kwargs["owner"] = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL")[0]
return upload(**kwargs) return upload(**kwargs)
@ -232,16 +249,16 @@ def upload(**kwargs):
""" """
roles = create_certificate_roles(**kwargs) roles = create_certificate_roles(**kwargs)
if kwargs.get('roles'): if kwargs.get("roles"):
kwargs['roles'] += roles kwargs["roles"] += roles
else: else:
kwargs['roles'] = roles kwargs["roles"] = roles
cert = Certificate(**kwargs) cert = Certificate(**kwargs)
cert.authority = kwargs.get('authority') cert.authority = kwargs.get("authority")
cert = database.create(cert) cert = database.create(cert)
kwargs['creator'].certificates.append(cert) kwargs["creator"].certificates.append(cert)
cert = database.update(cert) cert = database.update(cert)
certificate_imported.send(certificate=cert, authority=cert.authority) certificate_imported.send(certificate=cert, authority=cert.authority)
@ -258,39 +275,45 @@ def create(**kwargs):
current_app.logger.error("Exception minting certificate", exc_info=True) current_app.logger.error("Exception minting certificate", exc_info=True)
sentry.captureException() sentry.captureException()
raise raise
kwargs['body'] = cert_body kwargs["body"] = cert_body
kwargs['private_key'] = private_key kwargs["private_key"] = private_key
kwargs['chain'] = cert_chain kwargs["chain"] = cert_chain
kwargs['external_id'] = external_id kwargs["external_id"] = external_id
kwargs['csr'] = csr kwargs["csr"] = csr
roles = create_certificate_roles(**kwargs) roles = create_certificate_roles(**kwargs)
if kwargs.get('roles'): if kwargs.get("roles"):
kwargs['roles'] += roles kwargs["roles"] += roles
else: else:
kwargs['roles'] = roles kwargs["roles"] = roles
if cert_body: if cert_body:
cert = Certificate(**kwargs) cert = Certificate(**kwargs)
kwargs['creator'].certificates.append(cert) kwargs["creator"].certificates.append(cert)
else: else:
cert = PendingCertificate(**kwargs) cert = PendingCertificate(**kwargs)
kwargs['creator'].pending_certificates.append(cert) kwargs["creator"].pending_certificates.append(cert)
cert.authority = kwargs['authority'] cert.authority = kwargs["authority"]
database.commit() database.commit()
if isinstance(cert, Certificate): if isinstance(cert, Certificate):
certificate_issued.send(certificate=cert, authority=cert.authority) certificate_issued.send(certificate=cert, authority=cert.authority)
metrics.send('certificate_issued', 'counter', 1, metric_tags=dict(owner=cert.owner, issuer=cert.issuer)) metrics.send(
"certificate_issued",
"counter",
1,
metric_tags=dict(owner=cert.owner, issuer=cert.issuer),
)
if isinstance(cert, PendingCertificate): if isinstance(cert, PendingCertificate):
# We need to refresh the pending certificate to avoid "Instance is not bound to a Session; " # We need to refresh the pending certificate to avoid "Instance is not bound to a Session; "
# "attribute refresh operation cannot proceed" # "attribute refresh operation cannot proceed"
pending_cert = database.session_query(PendingCertificate).get(cert.id) pending_cert = database.session_query(PendingCertificate).get(cert.id)
from lemur.common.celery import fetch_acme_cert from lemur.common.celery import fetch_acme_cert
if not current_app.config.get("ACME_DISABLE_AUTORESOLVE", False): if not current_app.config.get("ACME_DISABLE_AUTORESOLVE", False):
fetch_acme_cert.apply_async((pending_cert.id,), countdown=5) fetch_acme_cert.apply_async((pending_cert.id,), countdown=5)
@ -306,51 +329,55 @@ def render(args):
""" """
query = database.session_query(Certificate) query = database.session_query(Certificate)
time_range = args.pop('time_range') time_range = args.pop("time_range")
destination_id = args.pop('destination_id') destination_id = args.pop("destination_id")
notification_id = args.pop('notification_id', None) notification_id = args.pop("notification_id", None)
show = args.pop('show') show = args.pop("show")
# owner = args.pop('owner') # owner = args.pop('owner')
# creator = args.pop('creator') # TODO we should enabling filtering by owner # creator = args.pop('creator') # TODO we should enabling filtering by owner
filt = args.pop('filter') filt = args.pop("filter")
if filt: if filt:
terms = filt.split(';') terms = filt.split(";")
term = '%{0}%'.format(terms[1]) term = "%{0}%".format(terms[1])
# Exact matches for quotes. Only applies to name, issuer, and cn # Exact matches for quotes. Only applies to name, issuer, and cn
if terms[1].startswith('"') and terms[1].endswith('"'): if terms[1].startswith('"') and terms[1].endswith('"'):
term = terms[1][1:-1] term = terms[1][1:-1]
if 'issuer' in terms: if "issuer" in terms:
# we can't rely on issuer being correct in the cert directly so we combine queries # we can't rely on issuer being correct in the cert directly so we combine queries
sub_query = database.session_query(Authority.id) \ sub_query = (
.filter(Authority.name.ilike(term)) \ database.session_query(Authority.id)
.filter(Authority.name.ilike(term))
.subquery() .subquery()
)
query = query.filter( query = query.filter(
or_( or_(
Certificate.issuer.ilike(term), Certificate.issuer.ilike(term),
Certificate.authority_id.in_(sub_query) Certificate.authority_id.in_(sub_query),
) )
) )
elif 'destination' in terms: elif "destination" in terms:
query = query.filter(Certificate.destinations.any(Destination.id == terms[1])) query = query.filter(
elif 'notify' in filt: Certificate.destinations.any(Destination.id == terms[1])
)
elif "notify" in filt:
query = query.filter(Certificate.notify == truthiness(terms[1])) query = query.filter(Certificate.notify == truthiness(terms[1]))
elif 'active' in filt: elif "active" in filt:
query = query.filter(Certificate.active == truthiness(terms[1])) query = query.filter(Certificate.active == truthiness(terms[1]))
elif 'cn' in terms: elif "cn" in terms:
query = query.filter( query = query.filter(
or_( or_(
Certificate.cn.ilike(term), Certificate.cn.ilike(term),
Certificate.domains.any(Domain.name.ilike(term)) Certificate.domains.any(Domain.name.ilike(term)),
) )
) )
elif 'id' in terms: elif "id" in terms:
query = query.filter(Certificate.id == cast(terms[1], Integer)) query = query.filter(Certificate.id == cast(terms[1], Integer))
elif 'name' in terms: elif "name" in terms:
query = query.filter( query = query.filter(
or_( or_(
Certificate.name.ilike(term), Certificate.name.ilike(term),
@ -362,26 +389,35 @@ def render(args):
query = database.filter(query, Certificate, terms) query = database.filter(query, Certificate, terms)
if show: if show:
sub_query = database.session_query(Role.name).filter(Role.user_id == args['user'].id).subquery() sub_query = (
database.session_query(Role.name)
.filter(Role.user_id == args["user"].id)
.subquery()
)
query = query.filter( query = query.filter(
or_( or_(
Certificate.user_id == args['user'].id, Certificate.user_id == args["user"].id, Certificate.owner.in_(sub_query)
Certificate.owner.in_(sub_query)
) )
) )
if destination_id: if destination_id:
query = query.filter(Certificate.destinations.any(Destination.id == destination_id)) query = query.filter(
Certificate.destinations.any(Destination.id == destination_id)
)
if notification_id: if notification_id:
query = query.filter(Certificate.notifications.any(Notification.id == notification_id)) query = query.filter(
Certificate.notifications.any(Notification.id == notification_id)
)
if time_range: if time_range:
to = arrow.now().replace(weeks=+time_range).format('YYYY-MM-DD') to = arrow.now().replace(weeks=+time_range).format("YYYY-MM-DD")
now = arrow.now().format('YYYY-MM-DD') now = arrow.now().format("YYYY-MM-DD")
query = query.filter(Certificate.not_after <= to).filter(Certificate.not_after >= now) query = query.filter(Certificate.not_after <= to).filter(
Certificate.not_after >= now
)
if current_app.config.get('ALLOW_CERT_DELETION', False): if current_app.config.get("ALLOW_CERT_DELETION", False):
query = query.filter(Certificate.deleted == False) # noqa query = query.filter(Certificate.deleted == False) # noqa
result = database.sort_and_page(query, Certificate, args) result = database.sort_and_page(query, Certificate, args)
@ -409,18 +445,20 @@ def query_common_name(common_name, args):
:param args: :param args:
:return: :return:
""" """
owner = args.pop('owner') owner = args.pop("owner")
if not owner: if not owner:
owner = '%' owner = "%"
# only not expired certificates # only not expired certificates
current_time = arrow.utcnow() current_time = arrow.utcnow()
result = Certificate.query.filter(Certificate.cn.ilike(common_name)) \ result = (
.filter(Certificate.owner.ilike(owner))\ Certificate.query.filter(Certificate.cn.ilike(common_name))
.filter(Certificate.not_after >= current_time.format('YYYY-MM-DD')) \ .filter(Certificate.owner.ilike(owner))
.filter(Certificate.rotation.is_(True))\ .filter(Certificate.not_after >= current_time.format("YYYY-MM-DD"))
.filter(Certificate.rotation.is_(True))
.all() .all()
)
return result return result
@ -432,62 +470,77 @@ def create_csr(**csr_config):
:param csr_config: :param csr_config:
""" """
private_key = generate_private_key(csr_config.get('key_type')) private_key = generate_private_key(csr_config.get("key_type"))
builder = x509.CertificateSigningRequestBuilder() builder = x509.CertificateSigningRequestBuilder()
name_list = [x509.NameAttribute(x509.OID_COMMON_NAME, csr_config['common_name'])] name_list = [x509.NameAttribute(x509.OID_COMMON_NAME, csr_config["common_name"])]
if current_app.config.get('LEMUR_OWNER_EMAIL_IN_SUBJECT', True): if current_app.config.get("LEMUR_OWNER_EMAIL_IN_SUBJECT", True):
name_list.append(x509.NameAttribute(x509.OID_EMAIL_ADDRESS, csr_config['owner'])) name_list.append(
if 'organization' in csr_config and csr_config['organization'].strip(): x509.NameAttribute(x509.OID_EMAIL_ADDRESS, csr_config["owner"])
name_list.append(x509.NameAttribute(x509.OID_ORGANIZATION_NAME, csr_config['organization'])) )
if 'organizational_unit' in csr_config and csr_config['organizational_unit'].strip(): if "organization" in csr_config and csr_config["organization"].strip():
name_list.append(x509.NameAttribute(x509.OID_ORGANIZATIONAL_UNIT_NAME, csr_config['organizational_unit'])) name_list.append(
if 'country' in csr_config and csr_config['country'].strip(): x509.NameAttribute(x509.OID_ORGANIZATION_NAME, csr_config["organization"])
name_list.append(x509.NameAttribute(x509.OID_COUNTRY_NAME, csr_config['country'])) )
if 'state' in csr_config and csr_config['state'].strip(): if (
name_list.append(x509.NameAttribute(x509.OID_STATE_OR_PROVINCE_NAME, csr_config['state'])) "organizational_unit" in csr_config
if 'location' in csr_config and csr_config['location'].strip(): and csr_config["organizational_unit"].strip()
name_list.append(x509.NameAttribute(x509.OID_LOCALITY_NAME, csr_config['location'])) ):
name_list.append(
x509.NameAttribute(
x509.OID_ORGANIZATIONAL_UNIT_NAME, csr_config["organizational_unit"]
)
)
if "country" in csr_config and csr_config["country"].strip():
name_list.append(
x509.NameAttribute(x509.OID_COUNTRY_NAME, csr_config["country"])
)
if "state" in csr_config and csr_config["state"].strip():
name_list.append(
x509.NameAttribute(x509.OID_STATE_OR_PROVINCE_NAME, csr_config["state"])
)
if "location" in csr_config and csr_config["location"].strip():
name_list.append(
x509.NameAttribute(x509.OID_LOCALITY_NAME, csr_config["location"])
)
builder = builder.subject_name(x509.Name(name_list)) builder = builder.subject_name(x509.Name(name_list))
extensions = csr_config.get('extensions', {}) extensions = csr_config.get("extensions", {})
critical_extensions = ['basic_constraints', 'sub_alt_names', 'key_usage'] critical_extensions = ["basic_constraints", "sub_alt_names", "key_usage"]
noncritical_extensions = ['extended_key_usage'] noncritical_extensions = ["extended_key_usage"]
for k, v in extensions.items(): for k, v in extensions.items():
if v: if v:
if k in critical_extensions: if k in critical_extensions:
current_app.logger.debug('Adding Critical Extension: {0} {1}'.format(k, v)) current_app.logger.debug(
if k == 'sub_alt_names': "Adding Critical Extension: {0} {1}".format(k, v)
if v['names']: )
builder = builder.add_extension(v['names'], critical=True) if k == "sub_alt_names":
if v["names"]:
builder = builder.add_extension(v["names"], critical=True)
else: else:
builder = builder.add_extension(v, critical=True) builder = builder.add_extension(v, critical=True)
if k in noncritical_extensions: if k in noncritical_extensions:
current_app.logger.debug('Adding Extension: {0} {1}'.format(k, v)) current_app.logger.debug("Adding Extension: {0} {1}".format(k, v))
builder = builder.add_extension(v, critical=False) builder = builder.add_extension(v, critical=False)
ski = extensions.get('subject_key_identifier', {}) ski = extensions.get("subject_key_identifier", {})
if ski.get('include_ski', False): if ski.get("include_ski", False):
builder = builder.add_extension( builder = builder.add_extension(
x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()), x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()),
critical=False critical=False,
) )
request = builder.sign( request = builder.sign(private_key, hashes.SHA256(), default_backend())
private_key, hashes.SHA256(), default_backend()
)
# serialize our private key and CSR # serialize our private key and CSR
private_key = private_key.private_bytes( private_key = private_key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL, # would like to use PKCS8 but AWS ELBs don't like it format=serialization.PrivateFormat.TraditionalOpenSSL, # would like to use PKCS8 but AWS ELBs don't like it
encryption_algorithm=serialization.NoEncryption() encryption_algorithm=serialization.NoEncryption(),
).decode('utf-8') ).decode("utf-8")
csr = request.public_bytes( csr = request.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8")
encoding=serialization.Encoding.PEM
).decode('utf-8')
return csr, private_key return csr, private_key
@ -499,16 +552,19 @@ def stats(**kwargs):
:param kwargs: :param kwargs:
:return: :return:
""" """
if kwargs.get('metric') == 'not_after': if kwargs.get("metric") == "not_after":
start = arrow.utcnow() start = arrow.utcnow()
end = start.replace(weeks=+32) end = start.replace(weeks=+32)
items = database.db.session.query(Certificate.issuer, func.count(Certificate.id)) \ items = (
.group_by(Certificate.issuer) \ database.db.session.query(Certificate.issuer, func.count(Certificate.id))
.filter(Certificate.not_after <= end.format('YYYY-MM-DD')) \ .group_by(Certificate.issuer)
.filter(Certificate.not_after >= start.format('YYYY-MM-DD')).all() .filter(Certificate.not_after <= end.format("YYYY-MM-DD"))
.filter(Certificate.not_after >= start.format("YYYY-MM-DD"))
.all()
)
else: else:
attr = getattr(Certificate, kwargs.get('metric')) attr = getattr(Certificate, kwargs.get("metric"))
query = database.db.session.query(attr, func.count(attr)) query = database.db.session.query(attr, func.count(attr))
items = query.group_by(attr).all() items = query.group_by(attr).all()
@ -519,7 +575,7 @@ def stats(**kwargs):
keys.append(key) keys.append(key)
values.append(count) values.append(count)
return {'labels': keys, 'values': values} return {"labels": keys, "values": values}
def get_account_number(arn): def get_account_number(arn):
@ -566,22 +622,24 @@ def get_certificate_primitives(certificate):
certificate via `create`. certificate via `create`.
""" """
start, end = calculate_reissue_range(certificate.not_before, certificate.not_after) start, end = calculate_reissue_range(certificate.not_before, certificate.not_after)
ser = CertificateInputSchema().load(CertificateOutputSchema().dump(certificate).data) ser = CertificateInputSchema().load(
CertificateOutputSchema().dump(certificate).data
)
assert not ser.errors, "Error re-serializing certificate: %s" % ser.errors assert not ser.errors, "Error re-serializing certificate: %s" % ser.errors
data = ser.data data = ser.data
# we can't quite tell if we are using a custom name, as this is an automated process (typically) # we can't quite tell if we are using a custom name, as this is an automated process (typically)
# we will rely on the Lemur generated name # we will rely on the Lemur generated name
data.pop('name', None) data.pop("name", None)
# TODO this can be removed once we migrate away from cn # TODO this can be removed once we migrate away from cn
data['cn'] = data['common_name'] data["cn"] = data["common_name"]
# needed until we move off not_* # needed until we move off not_*
data['not_before'] = start data["not_before"] = start
data['not_after'] = end data["not_after"] = end
data['validity_start'] = start data["validity_start"] = start
data['validity_end'] = end data["validity_end"] = end
return data return data
@ -599,13 +657,13 @@ def reissue_certificate(certificate, replace=None, user=None):
# We do not want to re-use the CSR when creating a certificate because this defeats the purpose of rotation. # We do not want to re-use the CSR when creating a certificate because this defeats the purpose of rotation.
del primitives["csr"] del primitives["csr"]
if not user: if not user:
primitives['creator'] = certificate.user primitives["creator"] = certificate.user
else: else:
primitives['creator'] = user primitives["creator"] = user
if replace: if replace:
primitives['replaces'] = [certificate] primitives["replaces"] = [certificate]
new_cert = create(**primitives) new_cert = create(**primitives)

View File

@ -23,17 +23,18 @@ def get_sans_from_csr(data):
""" """
sub_alt_names = [] sub_alt_names = []
try: try:
request = x509.load_pem_x509_csr(data.encode('utf-8'), default_backend()) request = x509.load_pem_x509_csr(data.encode("utf-8"), default_backend())
except Exception: except Exception:
raise ValidationError('CSR presented is not valid.') raise ValidationError("CSR presented is not valid.")
try: try:
alt_names = request.extensions.get_extension_for_class(x509.SubjectAlternativeName) alt_names = request.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
for alt_name in alt_names.value: for alt_name in alt_names.value:
sub_alt_names.append({ sub_alt_names.append(
'nameType': type(alt_name).__name__, {"nameType": type(alt_name).__name__, "value": alt_name.value}
'value': alt_name.value )
})
except x509.ExtensionNotFound: except x509.ExtensionNotFound:
pass pass

View File

@ -29,31 +29,45 @@ def ocsp_verify(cert, cert_path, issuer_chain_path):
:param issuer_chain_path: :param issuer_chain_path:
:return bool: True if certificate is valid, False otherwise :return bool: True if certificate is valid, False otherwise
""" """
command = ['openssl', 'x509', '-noout', '-ocsp_uri', '-in', cert_path] command = ["openssl", "x509", "-noout", "-ocsp_uri", "-in", cert_path]
p1 = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) p1 = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
url, err = p1.communicate() url, err = p1.communicate()
if not url: if not url:
current_app.logger.debug("No OCSP URL in certificate {}".format(cert.serial_number)) current_app.logger.debug(
"No OCSP URL in certificate {}".format(cert.serial_number)
)
return None return None
p2 = subprocess.Popen(['openssl', 'ocsp', '-issuer', issuer_chain_path, p2 = subprocess.Popen(
'-cert', cert_path, "-url", url.strip()], [
"openssl",
"ocsp",
"-issuer",
issuer_chain_path,
"-cert",
cert_path,
"-url",
url.strip(),
],
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE) stderr=subprocess.PIPE,
)
message, err = p2.communicate() message, err = p2.communicate()
p_message = message.decode('utf-8') p_message = message.decode("utf-8")
if 'error' in p_message or 'Error' in p_message: if "error" in p_message or "Error" in p_message:
raise Exception("Got error when parsing OCSP url") raise Exception("Got error when parsing OCSP url")
elif 'revoked' in p_message: elif "revoked" in p_message:
current_app.logger.debug("OCSP reports certificate revoked: {}".format(cert.serial_number)) current_app.logger.debug(
"OCSP reports certificate revoked: {}".format(cert.serial_number)
)
return False return False
elif 'good' not in p_message: elif "good" not in p_message:
raise Exception("Did not receive a valid response") raise Exception("Did not receive a valid response")
return True return True
@ -73,7 +87,9 @@ def crl_verify(cert, cert_path):
x509.OID_CRL_DISTRIBUTION_POINTS x509.OID_CRL_DISTRIBUTION_POINTS
).value ).value
except x509.ExtensionNotFound: except x509.ExtensionNotFound:
current_app.logger.debug("No CRLDP extension in certificate {}".format(cert.serial_number)) current_app.logger.debug(
"No CRLDP extension in certificate {}".format(cert.serial_number)
)
return None return None
for p in distribution_points: for p in distribution_points:
@ -92,8 +108,9 @@ def crl_verify(cert, cert_path):
except ConnectionError: except ConnectionError:
raise Exception("Unable to retrieve CRL: {0}".format(point)) raise Exception("Unable to retrieve CRL: {0}".format(point))
crl_cache[point] = x509.load_der_x509_crl(response.content, crl_cache[point] = x509.load_der_x509_crl(
backend=default_backend()) response.content, backend=default_backend()
)
else: else:
current_app.logger.debug("CRL point is cached {}".format(point)) current_app.logger.debug("CRL point is cached {}".format(point))
@ -110,8 +127,9 @@ def crl_verify(cert, cert_path):
except x509.ExtensionNotFound: except x509.ExtensionNotFound:
pass pass
current_app.logger.debug("CRL reports certificate " current_app.logger.debug(
"revoked: {}".format(cert.serial_number)) "CRL reports certificate " "revoked: {}".format(cert.serial_number)
)
return False return False
return True return True
@ -125,7 +143,7 @@ def verify(cert_path, issuer_chain_path):
:param issuer_chain_path: :param issuer_chain_path:
:return: True if valid, False otherwise :return: True if valid, False otherwise
""" """
with open(cert_path, 'rt') as c: with open(cert_path, "rt") as c:
try: try:
cert = parse_certificate(c.read()) cert = parse_certificate(c.read())
except ValueError as e: except ValueError as e:
@ -154,10 +172,10 @@ def verify_string(cert_string, issuer_string):
:return: True if valid, False otherwise :return: True if valid, False otherwise
""" """
with mktempfile() as cert_tmp: with mktempfile() as cert_tmp:
with open(cert_tmp, 'w') as f: with open(cert_tmp, "w") as f:
f.write(cert_string) f.write(cert_string)
with mktempfile() as issuer_tmp: with mktempfile() as issuer_tmp:
with open(issuer_tmp, 'w') as f: with open(issuer_tmp, "w") as f:
f.write(issuer_string) f.write(issuer_string)
status = verify(cert_tmp, issuer_tmp) status = verify(cert_tmp, issuer_tmp)
return status return status

View File

@ -26,14 +26,14 @@ from lemur.certificates.schemas import (
certificate_upload_input_schema, certificate_upload_input_schema,
certificates_output_schema, certificates_output_schema,
certificate_export_input_schema, certificate_export_input_schema,
certificate_edit_input_schema certificate_edit_input_schema,
) )
from lemur.roles import service as role_service from lemur.roles import service as role_service
from lemur.logs import service as log_service from lemur.logs import service as log_service
mod = Blueprint('certificates', __name__) mod = Blueprint("certificates", __name__)
api = Api(mod) api = Api(mod)
@ -128,8 +128,8 @@ class CertificatesListValid(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
args = parser.parse_args() args = parser.parse_args()
args['user'] = g.user args["user"] = g.user
common_name = args['filter'].split(';')[1] common_name = args["filter"].split(";")[1]
return service.query_common_name(common_name, args) return service.query_common_name(common_name, args)
@ -228,16 +228,18 @@ class CertificatesNameQuery(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
parser.add_argument('timeRange', type=int, dest='time_range', location='args') parser.add_argument("timeRange", type=int, dest="time_range", location="args")
parser.add_argument('owner', type=inputs.boolean, location='args') parser.add_argument("owner", type=inputs.boolean, location="args")
parser.add_argument('id', type=str, location='args') parser.add_argument("id", type=str, location="args")
parser.add_argument('active', type=inputs.boolean, location='args') parser.add_argument("active", type=inputs.boolean, location="args")
parser.add_argument('destinationId', type=int, dest="destination_id", location='args') parser.add_argument(
parser.add_argument('creator', type=str, location='args') "destinationId", type=int, dest="destination_id", location="args"
parser.add_argument('show', type=str, location='args') )
parser.add_argument("creator", type=str, location="args")
parser.add_argument("show", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
args['user'] = g.user args["user"] = g.user
return service.query_name(certificate_name, args) return service.query_name(certificate_name, args)
@ -336,16 +338,18 @@ class CertificatesList(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
parser.add_argument('timeRange', type=int, dest='time_range', location='args') parser.add_argument("timeRange", type=int, dest="time_range", location="args")
parser.add_argument('owner', type=inputs.boolean, location='args') parser.add_argument("owner", type=inputs.boolean, location="args")
parser.add_argument('id', type=str, location='args') parser.add_argument("id", type=str, location="args")
parser.add_argument('active', type=inputs.boolean, location='args') parser.add_argument("active", type=inputs.boolean, location="args")
parser.add_argument('destinationId', type=int, dest="destination_id", location='args') parser.add_argument(
parser.add_argument('creator', type=str, location='args') "destinationId", type=int, dest="destination_id", location="args"
parser.add_argument('show', type=str, location='args') )
parser.add_argument("creator", type=str, location="args")
parser.add_argument("show", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
args['user'] = g.user args["user"] = g.user
return service.render(args) return service.render(args)
@validate_schema(certificate_input_schema, certificate_output_schema) @validate_schema(certificate_input_schema, certificate_output_schema)
@ -463,24 +467,31 @@ class CertificatesList(AuthenticatedResource):
:statuscode 403: unauthenticated :statuscode 403: unauthenticated
""" """
role = role_service.get_by_name(data['authority'].owner) role = role_service.get_by_name(data["authority"].owner)
# all the authority role members should be allowed # all the authority role members should be allowed
roles = [x.name for x in data['authority'].roles] roles = [x.name for x in data["authority"].roles]
# allow "owner" roles by team DL # allow "owner" roles by team DL
roles.append(role) roles.append(role)
authority_permission = AuthorityPermission(data['authority'].id, roles) authority_permission = AuthorityPermission(data["authority"].id, roles)
if authority_permission.can(): if authority_permission.can():
data['creator'] = g.user data["creator"] = g.user
cert = service.create(**data) cert = service.create(**data)
if isinstance(cert, Certificate): if isinstance(cert, Certificate):
# only log if created, not pending # only log if created, not pending
log_service.create(g.user, 'create_cert', certificate=cert) log_service.create(g.user, "create_cert", certificate=cert)
return cert return cert
return dict(message="You are not authorized to use the authority: {0}".format(data['authority'].name)), 403 return (
dict(
message="You are not authorized to use the authority: {0}".format(
data["authority"].name
)
),
403,
)
class CertificatesUpload(AuthenticatedResource): class CertificatesUpload(AuthenticatedResource):
@ -583,12 +594,14 @@ class CertificatesUpload(AuthenticatedResource):
:statuscode 200: no error :statuscode 200: no error
""" """
data['creator'] = g.user data["creator"] = g.user
if data.get('destinations'): if data.get("destinations"):
if data.get('private_key'): if data.get("private_key"):
return service.upload(**data) return service.upload(**data)
else: else:
raise Exception("Private key must be provided in order to upload certificate to AWS") raise Exception(
"Private key must be provided in order to upload certificate to AWS"
)
return service.upload(**data) return service.upload(**data)
@ -600,10 +613,12 @@ class CertificatesStats(AuthenticatedResource):
super(CertificatesStats, self).__init__() super(CertificatesStats, self).__init__()
def get(self): def get(self):
self.reqparse.add_argument('metric', type=str, location='args') self.reqparse.add_argument("metric", type=str, location="args")
self.reqparse.add_argument('range', default=32, type=int, location='args') self.reqparse.add_argument("range", default=32, type=int, location="args")
self.reqparse.add_argument('destinationId', dest='destination_id', location='args') self.reqparse.add_argument(
self.reqparse.add_argument('active', type=str, default='true', location='args') "destinationId", dest="destination_id", location="args"
)
self.reqparse.add_argument("active", type=str, default="true", location="args")
args = self.reqparse.parse_args() args = self.reqparse.parse_args()
@ -655,12 +670,12 @@ class CertificatePrivateKey(AuthenticatedResource):
permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) permission = CertificatePermission(owner_role, [x.name for x in cert.roles])
if not permission.can(): if not permission.can():
return dict(message='You are not authorized to view this key'), 403 return dict(message="You are not authorized to view this key"), 403
log_service.create(g.current_user, 'key_view', certificate=cert) log_service.create(g.current_user, "key_view", certificate=cert)
response = make_response(jsonify(key=cert.private_key), 200) response = make_response(jsonify(key=cert.private_key), 200)
response.headers['cache-control'] = 'private, max-age=0, no-cache, no-store' response.headers["cache-control"] = "private, max-age=0, no-cache, no-store"
response.headers['pragma'] = 'no-cache' response.headers["pragma"] = "no-cache"
return response return response
@ -850,19 +865,25 @@ class Certificates(AuthenticatedResource):
permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) permission = CertificatePermission(owner_role, [x.name for x in cert.roles])
if not permission.can(): if not permission.can():
return dict(message='You are not authorized to update this certificate'), 403 return (
dict(message="You are not authorized to update this certificate"),
403,
)
for destination in data['destinations']: for destination in data["destinations"]:
if destination.plugin.requires_key: if destination.plugin.requires_key:
if not cert.private_key: if not cert.private_key:
return dict( return (
message='Unable to add destination: {0}. Certificate does not have required private key.'.format( dict(
message="Unable to add destination: {0}. Certificate does not have required private key.".format(
destination.label destination.label
) )
), 400 ),
400,
)
cert = service.update(certificate_id, **data) cert = service.update(certificate_id, **data)
log_service.create(g.current_user, 'update_cert', certificate=cert) log_service.create(g.current_user, "update_cert", certificate=cert)
return cert return cert
def delete(self, certificate_id, data=None): def delete(self, certificate_id, data=None):
@ -891,7 +912,7 @@ class Certificates(AuthenticatedResource):
:statuscode 405: certificate deletion is disabled :statuscode 405: certificate deletion is disabled
""" """
if not current_app.config.get('ALLOW_CERT_DELETION', False): if not current_app.config.get("ALLOW_CERT_DELETION", False):
return dict(message="Certificate deletion is disabled"), 405 return dict(message="Certificate deletion is disabled"), 405
cert = service.get(certificate_id) cert = service.get(certificate_id)
@ -908,11 +929,14 @@ class Certificates(AuthenticatedResource):
permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) permission = CertificatePermission(owner_role, [x.name for x in cert.roles])
if not permission.can(): if not permission.can():
return dict(message='You are not authorized to delete this certificate'), 403 return (
dict(message="You are not authorized to delete this certificate"),
403,
)
service.update(certificate_id, deleted=True) service.update(certificate_id, deleted=True)
log_service.create(g.current_user, 'delete_cert', certificate=cert) log_service.create(g.current_user, "delete_cert", certificate=cert)
return 'Certificate deleted', 204 return "Certificate deleted", 204
class NotificationCertificatesList(AuthenticatedResource): class NotificationCertificatesList(AuthenticatedResource):
@ -1012,17 +1036,19 @@ class NotificationCertificatesList(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
parser.add_argument('timeRange', type=int, dest='time_range', location='args') parser.add_argument("timeRange", type=int, dest="time_range", location="args")
parser.add_argument('owner', type=inputs.boolean, location='args') parser.add_argument("owner", type=inputs.boolean, location="args")
parser.add_argument('id', type=str, location='args') parser.add_argument("id", type=str, location="args")
parser.add_argument('active', type=inputs.boolean, location='args') parser.add_argument("active", type=inputs.boolean, location="args")
parser.add_argument('destinationId', type=int, dest="destination_id", location='args') parser.add_argument(
parser.add_argument('creator', type=str, location='args') "destinationId", type=int, dest="destination_id", location="args"
parser.add_argument('show', type=str, location='args') )
parser.add_argument("creator", type=str, location="args")
parser.add_argument("show", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
args['notification_id'] = notification_id args["notification_id"] = notification_id
args['user'] = g.current_user args["user"] = g.current_user
return service.render(args) return service.render(args)
@ -1195,30 +1221,48 @@ class CertificateExport(AuthenticatedResource):
if not cert: if not cert:
return dict(message="Cannot find specified certificate"), 404 return dict(message="Cannot find specified certificate"), 404
plugin = data['plugin']['plugin_object'] plugin = data["plugin"]["plugin_object"]
if plugin.requires_key: if plugin.requires_key:
if not cert.private_key: if not cert.private_key:
return dict( return (
message='Unable to export certificate, plugin: {0} requires a private key but no key was found.'.format( dict(
plugin.slug)), 400 message="Unable to export certificate, plugin: {0} requires a private key but no key was found.".format(
plugin.slug
)
),
400,
)
else: else:
# allow creators # allow creators
if g.current_user != cert.user: if g.current_user != cert.user:
owner_role = role_service.get_by_name(cert.owner) owner_role = role_service.get_by_name(cert.owner)
permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) permission = CertificatePermission(
owner_role, [x.name for x in cert.roles]
)
if not permission.can(): if not permission.can():
return dict(message='You are not authorized to export this certificate.'), 403 return (
dict(
message="You are not authorized to export this certificate."
),
403,
)
options = data['plugin']['plugin_options'] options = data["plugin"]["plugin_options"]
log_service.create(g.current_user, 'key_view', certificate=cert) log_service.create(g.current_user, "key_view", certificate=cert)
extension, passphrase, data = plugin.export(cert.body, cert.chain, cert.private_key, options) extension, passphrase, data = plugin.export(
cert.body, cert.chain, cert.private_key, options
)
# we take a hit in message size when b64 encoding # we take a hit in message size when b64 encoding
return dict(extension=extension, passphrase=passphrase, data=base64.b64encode(data).decode('utf-8')) return dict(
extension=extension,
passphrase=passphrase,
data=base64.b64encode(data).decode("utf-8"),
)
class CertificateRevoke(AuthenticatedResource): class CertificateRevoke(AuthenticatedResource):
@ -1269,30 +1313,66 @@ class CertificateRevoke(AuthenticatedResource):
permission = CertificatePermission(owner_role, [x.name for x in cert.roles]) permission = CertificatePermission(owner_role, [x.name for x in cert.roles])
if not permission.can(): if not permission.can():
return dict(message='You are not authorized to revoke this certificate.'), 403 return (
dict(message="You are not authorized to revoke this certificate."),
403,
)
if not cert.external_id: if not cert.external_id:
return dict(message='Cannot revoke certificate. No external id found.'), 400 return dict(message="Cannot revoke certificate. No external id found."), 400
if cert.endpoints: if cert.endpoints:
return dict(message='Cannot revoke certificate. Endpoints are deployed with the given certificate.'), 403 return (
dict(
message="Cannot revoke certificate. Endpoints are deployed with the given certificate."
),
403,
)
plugin = plugins.get(cert.authority.plugin_name) plugin = plugins.get(cert.authority.plugin_name)
plugin.revoke_certificate(cert, data) plugin.revoke_certificate(cert, data)
log_service.create(g.current_user, 'revoke_cert', certificate=cert) log_service.create(g.current_user, "revoke_cert", certificate=cert)
return dict(id=cert.id) return dict(id=cert.id)
api.add_resource(CertificateRevoke, '/certificates/<int:certificate_id>/revoke', endpoint='revokeCertificate') api.add_resource(
api.add_resource(CertificatesNameQuery, '/certificates/name/<string:certificate_name>', endpoint='certificatesNameQuery') CertificateRevoke,
api.add_resource(CertificatesList, '/certificates', endpoint='certificates') "/certificates/<int:certificate_id>/revoke",
api.add_resource(CertificatesListValid, '/certificates/valid', endpoint='certificatesListValid') endpoint="revokeCertificate",
api.add_resource(Certificates, '/certificates/<int:certificate_id>', endpoint='certificate') )
api.add_resource(CertificatesStats, '/certificates/stats', endpoint='certificateStats') api.add_resource(
api.add_resource(CertificatesUpload, '/certificates/upload', endpoint='certificateUpload') CertificatesNameQuery,
api.add_resource(CertificatePrivateKey, '/certificates/<int:certificate_id>/key', endpoint='privateKeyCertificates') "/certificates/name/<string:certificate_name>",
api.add_resource(CertificateExport, '/certificates/<int:certificate_id>/export', endpoint='exportCertificate') endpoint="certificatesNameQuery",
api.add_resource(NotificationCertificatesList, '/notifications/<int:notification_id>/certificates', )
endpoint='notificationCertificates') api.add_resource(CertificatesList, "/certificates", endpoint="certificates")
api.add_resource(CertificatesReplacementsList, '/certificates/<int:certificate_id>/replacements', api.add_resource(
endpoint='replacements') CertificatesListValid, "/certificates/valid", endpoint="certificatesListValid"
)
api.add_resource(
Certificates, "/certificates/<int:certificate_id>", endpoint="certificate"
)
api.add_resource(CertificatesStats, "/certificates/stats", endpoint="certificateStats")
api.add_resource(
CertificatesUpload, "/certificates/upload", endpoint="certificateUpload"
)
api.add_resource(
CertificatePrivateKey,
"/certificates/<int:certificate_id>/key",
endpoint="privateKeyCertificates",
)
api.add_resource(
CertificateExport,
"/certificates/<int:certificate_id>/export",
endpoint="exportCertificate",
)
api.add_resource(
NotificationCertificatesList,
"/notifications/<int:notification_id>/certificates",
endpoint="notificationCertificates",
)
api.add_resource(
CertificatesReplacementsList,
"/certificates/<int:certificate_id>/replacements",
endpoint="replacements",
)

View File

@ -32,8 +32,11 @@ else:
def make_celery(app): def make_celery(app):
celery = Celery(app.import_name, backend=app.config.get('CELERY_RESULT_BACKEND'), celery = Celery(
broker=app.config.get('CELERY_BROKER_URL')) app.import_name,
backend=app.config.get("CELERY_RESULT_BACKEND"),
broker=app.config.get("CELERY_BROKER_URL"),
)
celery.conf.update(app.config) celery.conf.update(app.config)
TaskBase = celery.Task TaskBase = celery.Task
@ -53,6 +56,7 @@ celery = make_celery(flask_app)
def is_task_active(fun, task_id, args): def is_task_active(fun, task_id, args):
from celery.task.control import inspect from celery.task.control import inspect
i = inspect() i = inspect()
active_tasks = i.active() active_tasks = i.active()
for _, tasks in active_tasks.items(): for _, tasks in active_tasks.items():
@ -99,7 +103,7 @@ def fetch_acme_cert(id):
# We only care about certs using the acme-issuer plugin # We only care about certs using the acme-issuer plugin
for cert in pending_certs: for cert in pending_certs:
cert_authority = get_authority(cert.authority_id) cert_authority = get_authority(cert.authority_id)
if cert_authority.plugin_name == 'acme-issuer': if cert_authority.plugin_name == "acme-issuer":
acme_certs.append(cert) acme_certs.append(cert)
else: else:
wrong_issuer += 1 wrong_issuer += 1
@ -112,20 +116,22 @@ def fetch_acme_cert(id):
# It's necessary to reload the pending cert due to detached instance: http://sqlalche.me/e/bhk3 # It's necessary to reload the pending cert due to detached instance: http://sqlalche.me/e/bhk3
pending_cert = pending_certificate_service.get(cert.get("pending_cert").id) pending_cert = pending_certificate_service.get(cert.get("pending_cert").id)
if not pending_cert: if not pending_cert:
log_data["message"] = "Pending certificate doesn't exist anymore. Was it resolved by another process?" log_data[
"message"
] = "Pending certificate doesn't exist anymore. Was it resolved by another process?"
current_app.logger.error(log_data) current_app.logger.error(log_data)
continue continue
if real_cert: if real_cert:
# If a real certificate was returned from issuer, then create it in Lemur and mark # If a real certificate was returned from issuer, then create it in Lemur and mark
# the pending certificate as resolved # the pending certificate as resolved
final_cert = pending_certificate_service.create_certificate(pending_cert, real_cert, pending_cert.user) final_cert = pending_certificate_service.create_certificate(
pending_certificate_service.update( pending_cert, real_cert, pending_cert.user
cert.get("pending_cert").id,
resolved_cert_id=final_cert.id
) )
pending_certificate_service.update( pending_certificate_service.update(
cert.get("pending_cert").id, cert.get("pending_cert").id, resolved_cert_id=final_cert.id
resolved=True )
pending_certificate_service.update(
cert.get("pending_cert").id, resolved=True
) )
# add metrics to metrics extension # add metrics to metrics extension
new += 1 new += 1
@ -139,17 +145,17 @@ def fetch_acme_cert(id):
if pending_cert.number_attempts > 4: if pending_cert.number_attempts > 4:
error_log["message"] = "Deleting pending certificate" error_log["message"] = "Deleting pending certificate"
send_pending_failure_notification(pending_cert, notify_owner=pending_cert.notify) send_pending_failure_notification(
pending_cert, notify_owner=pending_cert.notify
)
# Mark the pending cert as resolved # Mark the pending cert as resolved
pending_certificate_service.update( pending_certificate_service.update(
cert.get("pending_cert").id, cert.get("pending_cert").id, resolved=True
resolved=True
) )
else: else:
pending_certificate_service.increment_attempt(pending_cert) pending_certificate_service.increment_attempt(pending_cert)
pending_certificate_service.update( pending_certificate_service.update(
cert.get("pending_cert").id, cert.get("pending_cert").id, status=str(cert.get("last_error"))
status=str(cert.get("last_error"))
) )
# Add failed pending cert task back to queue # Add failed pending cert task back to queue
fetch_acme_cert.delay(id) fetch_acme_cert.delay(id)
@ -161,9 +167,7 @@ def fetch_acme_cert(id):
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
print( print(
"[+] Certificates: New: {new} Failed: {failed} Not using ACME: {wrong_issuer}".format( "[+] Certificates: New: {new} Failed: {failed} Not using ACME: {wrong_issuer}".format(
new=new, new=new, failed=failed, wrong_issuer=wrong_issuer
failed=failed,
wrong_issuer=wrong_issuer
) )
) )
@ -175,7 +179,7 @@ def fetch_all_pending_acme_certs():
log_data = { log_data = {
"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name), "function": "{}.{}".format(__name__, sys._getframe().f_code.co_name),
"message": "Starting job." "message": "Starting job.",
} }
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
@ -183,7 +187,7 @@ def fetch_all_pending_acme_certs():
# We only care about certs using the acme-issuer plugin # We only care about certs using the acme-issuer plugin
for cert in pending_certs: for cert in pending_certs:
cert_authority = get_authority(cert.authority_id) cert_authority = get_authority(cert.authority_id)
if cert_authority.plugin_name == 'acme-issuer': if cert_authority.plugin_name == "acme-issuer":
if datetime.now(timezone.utc) - cert.last_updated > timedelta(minutes=5): if datetime.now(timezone.utc) - cert.last_updated > timedelta(minutes=5):
log_data["message"] = "Triggering job for cert {}".format(cert.name) log_data["message"] = "Triggering job for cert {}".format(cert.name)
log_data["cert_name"] = cert.name log_data["cert_name"] = cert.name
@ -195,17 +199,15 @@ def fetch_all_pending_acme_certs():
@celery.task() @celery.task()
def remove_old_acme_certs(): def remove_old_acme_certs():
"""Prune old pending acme certificates from the database""" """Prune old pending acme certificates from the database"""
log_data = { log_data = {"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name)}
"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name) pending_certs = pending_certificate_service.get_pending_certs("all")
}
pending_certs = pending_certificate_service.get_pending_certs('all')
# Delete pending certs more than a week old # Delete pending certs more than a week old
for cert in pending_certs: for cert in pending_certs:
if datetime.now(timezone.utc) - cert.last_updated > timedelta(days=7): if datetime.now(timezone.utc) - cert.last_updated > timedelta(days=7):
log_data['pending_cert_id'] = cert.id log_data["pending_cert_id"] = cert.id
log_data['pending_cert_name'] = cert.name log_data["pending_cert_name"] = cert.name
log_data['message'] = "Deleting pending certificate" log_data["message"] = "Deleting pending certificate"
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
pending_certificate_service.delete(cert) pending_certificate_service.delete(cert)
@ -218,7 +220,9 @@ def clean_all_sources():
""" """
sources = validate_sources("all") sources = validate_sources("all")
for source in sources: for source in sources:
current_app.logger.debug("Creating celery task to clean source {}".format(source.label)) current_app.logger.debug(
"Creating celery task to clean source {}".format(source.label)
)
clean_source.delay(source.label) clean_source.delay(source.label)
@ -242,7 +246,9 @@ def sync_all_sources():
""" """
sources = validate_sources("all") sources = validate_sources("all")
for source in sources: for source in sources:
current_app.logger.debug("Creating celery task to sync source {}".format(source.label)) current_app.logger.debug(
"Creating celery task to sync source {}".format(source.label)
)
sync_source.delay(source.label) sync_source.delay(source.label)
@ -277,7 +283,9 @@ def sync_source(source):
log_data["message"] = "Error syncing source: Time limit exceeded." log_data["message"] = "Error syncing source: Time limit exceeded."
current_app.logger.error(log_data) current_app.logger.error(log_data)
sentry.captureException() sentry.captureException()
metrics.send('sync_source_timeout', 'counter', 1, metric_tags={'source': source}) metrics.send(
"sync_source_timeout", "counter", 1, metric_tags={"source": source}
)
return return
log_data["message"] = "Done syncing source" log_data["message"] = "Done syncing source"

View File

@ -9,18 +9,20 @@ from lemur.extensions import sentry
from lemur.constants import SAN_NAMING_TEMPLATE, DEFAULT_NAMING_TEMPLATE from lemur.constants import SAN_NAMING_TEMPLATE, DEFAULT_NAMING_TEMPLATE
def text_to_slug(value, joiner='-'): def text_to_slug(value, joiner="-"):
""" """
Normalize a string to a "slug" value, stripping character accents and removing non-alphanum characters. Normalize a string to a "slug" value, stripping character accents and removing non-alphanum characters.
A series of non-alphanumeric characters is replaced with the joiner character. A series of non-alphanumeric characters is replaced with the joiner character.
""" """
# Strip all character accents: decompose Unicode characters and then drop combining chars. # Strip all character accents: decompose Unicode characters and then drop combining chars.
value = ''.join(c for c in unicodedata.normalize('NFKD', value) if not unicodedata.combining(c)) value = "".join(
c for c in unicodedata.normalize("NFKD", value) if not unicodedata.combining(c)
)
# Replace all remaining non-alphanumeric characters with joiner string. Multiple characters get collapsed into a # Replace all remaining non-alphanumeric characters with joiner string. Multiple characters get collapsed into a
# single joiner. Except, keep 'xn--' used in IDNA domain names as is. # single joiner. Except, keep 'xn--' used in IDNA domain names as is.
value = re.sub(r'[^A-Za-z0-9.]+(?<!xn--)', joiner, value) value = re.sub(r"[^A-Za-z0-9.]+(?<!xn--)", joiner, value)
# '-' in the beginning or end of string looks ugly. # '-' in the beginning or end of string looks ugly.
return value.strip(joiner) return value.strip(joiner)
@ -48,12 +50,12 @@ def certificate_name(common_name, issuer, not_before, not_after, san):
temp = t.format( temp = t.format(
subject=common_name, subject=common_name,
issuer=issuer.replace(' ', ''), issuer=issuer.replace(" ", ""),
not_before=not_before.strftime('%Y%m%d'), not_before=not_before.strftime("%Y%m%d"),
not_after=not_after.strftime('%Y%m%d') not_after=not_after.strftime("%Y%m%d"),
) )
temp = temp.replace('*', "WILDCARD") temp = temp.replace("*", "WILDCARD")
return text_to_slug(temp) return text_to_slug(temp)
@ -69,9 +71,9 @@ def common_name(cert):
:return: Common name or None :return: Common name or None
""" """
try: try:
return cert.subject.get_attributes_for_oid( return cert.subject.get_attributes_for_oid(x509.OID_COMMON_NAME)[
x509.OID_COMMON_NAME 0
)[0].value.strip() ].value.strip()
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
current_app.logger.error("Unable to get common name! {0}".format(e)) current_app.logger.error("Unable to get common name! {0}".format(e))
@ -84,9 +86,9 @@ def organization(cert):
:return: :return:
""" """
try: try:
return cert.subject.get_attributes_for_oid( return cert.subject.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME)[
x509.OID_ORGANIZATION_NAME 0
)[0].value.strip() ].value.strip()
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
current_app.logger.error("Unable to get organization! {0}".format(e)) current_app.logger.error("Unable to get organization! {0}".format(e))
@ -99,9 +101,9 @@ def organizational_unit(cert):
:return: :return:
""" """
try: try:
return cert.subject.get_attributes_for_oid( return cert.subject.get_attributes_for_oid(x509.OID_ORGANIZATIONAL_UNIT_NAME)[
x509.OID_ORGANIZATIONAL_UNIT_NAME 0
)[0].value.strip() ].value.strip()
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
current_app.logger.error("Unable to get organizational unit! {0}".format(e)) current_app.logger.error("Unable to get organizational unit! {0}".format(e))
@ -114,9 +116,9 @@ def country(cert):
:return: :return:
""" """
try: try:
return cert.subject.get_attributes_for_oid( return cert.subject.get_attributes_for_oid(x509.OID_COUNTRY_NAME)[
x509.OID_COUNTRY_NAME 0
)[0].value.strip() ].value.strip()
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
current_app.logger.error("Unable to get country! {0}".format(e)) current_app.logger.error("Unable to get country! {0}".format(e))
@ -129,9 +131,9 @@ def state(cert):
:return: :return:
""" """
try: try:
return cert.subject.get_attributes_for_oid( return cert.subject.get_attributes_for_oid(x509.OID_STATE_OR_PROVINCE_NAME)[
x509.OID_STATE_OR_PROVINCE_NAME 0
)[0].value.strip() ].value.strip()
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
current_app.logger.error("Unable to get state! {0}".format(e)) current_app.logger.error("Unable to get state! {0}".format(e))
@ -144,9 +146,9 @@ def location(cert):
:return: :return:
""" """
try: try:
return cert.subject.get_attributes_for_oid( return cert.subject.get_attributes_for_oid(x509.OID_LOCALITY_NAME)[
x509.OID_LOCALITY_NAME 0
)[0].value.strip() ].value.strip()
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
current_app.logger.error("Unable to get location! {0}".format(e)) current_app.logger.error("Unable to get location! {0}".format(e))
@ -224,7 +226,7 @@ def bitstrength(cert):
return cert.public_key().key_size return cert.public_key().key_size
except AttributeError: except AttributeError:
sentry.captureException() sentry.captureException()
current_app.logger.debug('Unable to get bitstrength.') current_app.logger.debug("Unable to get bitstrength.")
def issuer(cert): def issuer(cert):
@ -239,16 +241,19 @@ def issuer(cert):
""" """
# If certificate is self-signed, we return a special value -- there really is no distinct "issuer" for it # If certificate is self-signed, we return a special value -- there really is no distinct "issuer" for it
if is_selfsigned(cert): if is_selfsigned(cert):
return '<selfsigned>' return "<selfsigned>"
# Try Common Name or fall back to Organization name # Try Common Name or fall back to Organization name
attrs = (cert.issuer.get_attributes_for_oid(x509.OID_COMMON_NAME) or attrs = cert.issuer.get_attributes_for_oid(
cert.issuer.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME)) x509.OID_COMMON_NAME
) or cert.issuer.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME)
if not attrs: if not attrs:
current_app.logger.error("Unable to get issuer! Cert serial {:x}".format(cert.serial_number)) current_app.logger.error(
return '<unknown>' "Unable to get issuer! Cert serial {:x}".format(cert.serial_number)
)
return "<unknown>"
return text_to_slug(attrs[0].value, '') return text_to_slug(attrs[0].value, "")
def not_before(cert): def not_before(cert):

View File

@ -25,6 +25,7 @@ class Hex(Field):
""" """
A hex formatted string. A hex formatted string.
""" """
def _serialize(self, value, attr, obj): def _serialize(self, value, attr, obj):
if value: if value:
value = hex(int(value))[2:].upper() value = hex(int(value))[2:].upper()
@ -48,25 +49,25 @@ class ArrowDateTime(Field):
""" """
DATEFORMAT_SERIALIZATION_FUNCS = { DATEFORMAT_SERIALIZATION_FUNCS = {
'iso': utils.isoformat, "iso": utils.isoformat,
'iso8601': utils.isoformat, "iso8601": utils.isoformat,
'rfc': utils.rfcformat, "rfc": utils.rfcformat,
'rfc822': utils.rfcformat, "rfc822": utils.rfcformat,
} }
DATEFORMAT_DESERIALIZATION_FUNCS = { DATEFORMAT_DESERIALIZATION_FUNCS = {
'iso': utils.from_iso, "iso": utils.from_iso,
'iso8601': utils.from_iso, "iso8601": utils.from_iso,
'rfc': utils.from_rfc, "rfc": utils.from_rfc,
'rfc822': utils.from_rfc, "rfc822": utils.from_rfc,
} }
DEFAULT_FORMAT = 'iso' DEFAULT_FORMAT = "iso"
localtime = False localtime = False
default_error_messages = { default_error_messages = {
'invalid': 'Not a valid datetime.', "invalid": "Not a valid datetime.",
'format': '"{input}" cannot be formatted as a datetime.', "format": '"{input}" cannot be formatted as a datetime.',
} }
def __init__(self, format=None, **kwargs): def __init__(self, format=None, **kwargs):
@ -89,34 +90,36 @@ class ArrowDateTime(Field):
try: try:
return format_func(value, localtime=self.localtime) return format_func(value, localtime=self.localtime)
except (AttributeError, ValueError) as err: except (AttributeError, ValueError) as err:
self.fail('format', input=value) self.fail("format", input=value)
else: else:
return value.strftime(self.dateformat) return value.strftime(self.dateformat)
def _deserialize(self, value, attr, data): def _deserialize(self, value, attr, data):
if not value: # Falsy values, e.g. '', None, [] are not valid if not value: # Falsy values, e.g. '', None, [] are not valid
raise self.fail('invalid') raise self.fail("invalid")
self.dateformat = self.dateformat or self.DEFAULT_FORMAT self.dateformat = self.dateformat or self.DEFAULT_FORMAT
func = self.DATEFORMAT_DESERIALIZATION_FUNCS.get(self.dateformat) func = self.DATEFORMAT_DESERIALIZATION_FUNCS.get(self.dateformat)
if func: if func:
try: try:
return arrow.get(func(value)) return arrow.get(func(value))
except (TypeError, AttributeError, ValueError): except (TypeError, AttributeError, ValueError):
raise self.fail('invalid') raise self.fail("invalid")
elif self.dateformat: elif self.dateformat:
try: try:
return dt.datetime.strptime(value, self.dateformat) return dt.datetime.strptime(value, self.dateformat)
except (TypeError, AttributeError, ValueError): except (TypeError, AttributeError, ValueError):
raise self.fail('invalid') raise self.fail("invalid")
elif utils.dateutil_available: elif utils.dateutil_available:
try: try:
return arrow.get(utils.from_datestring(value)) return arrow.get(utils.from_datestring(value))
except TypeError: except TypeError:
raise self.fail('invalid') raise self.fail("invalid")
else: else:
warnings.warn('It is recommended that you install python-dateutil ' warnings.warn(
'for improved datetime deserialization.') "It is recommended that you install python-dateutil "
raise self.fail('invalid') "for improved datetime deserialization."
)
raise self.fail("invalid")
class KeyUsageExtension(Field): class KeyUsageExtension(Field):
@ -131,73 +134,75 @@ class KeyUsageExtension(Field):
def _serialize(self, value, attr, obj): def _serialize(self, value, attr, obj):
return { return {
'useDigitalSignature': value.digital_signature, "useDigitalSignature": value.digital_signature,
'useNonRepudiation': value.content_commitment, "useNonRepudiation": value.content_commitment,
'useKeyEncipherment': value.key_encipherment, "useKeyEncipherment": value.key_encipherment,
'useDataEncipherment': value.data_encipherment, "useDataEncipherment": value.data_encipherment,
'useKeyAgreement': value.key_agreement, "useKeyAgreement": value.key_agreement,
'useKeyCertSign': value.key_cert_sign, "useKeyCertSign": value.key_cert_sign,
'useCRLSign': value.crl_sign, "useCRLSign": value.crl_sign,
'useEncipherOnly': value._encipher_only, "useEncipherOnly": value._encipher_only,
'useDecipherOnly': value._decipher_only "useDecipherOnly": value._decipher_only,
} }
def _deserialize(self, value, attr, data): def _deserialize(self, value, attr, data):
keyusages = { keyusages = {
'digital_signature': False, "digital_signature": False,
'content_commitment': False, "content_commitment": False,
'key_encipherment': False, "key_encipherment": False,
'data_encipherment': False, "data_encipherment": False,
'key_agreement': False, "key_agreement": False,
'key_cert_sign': False, "key_cert_sign": False,
'crl_sign': False, "crl_sign": False,
'encipher_only': False, "encipher_only": False,
'decipher_only': False "decipher_only": False,
} }
for k, v in value.items(): for k, v in value.items():
if k == 'useDigitalSignature': if k == "useDigitalSignature":
keyusages['digital_signature'] = v keyusages["digital_signature"] = v
elif k == 'useNonRepudiation': elif k == "useNonRepudiation":
keyusages['content_commitment'] = v keyusages["content_commitment"] = v
elif k == 'useKeyEncipherment': elif k == "useKeyEncipherment":
keyusages['key_encipherment'] = v keyusages["key_encipherment"] = v
elif k == 'useDataEncipherment': elif k == "useDataEncipherment":
keyusages['data_encipherment'] = v keyusages["data_encipherment"] = v
elif k == 'useKeyCertSign': elif k == "useKeyCertSign":
keyusages['key_cert_sign'] = v keyusages["key_cert_sign"] = v
elif k == 'useCRLSign': elif k == "useCRLSign":
keyusages['crl_sign'] = v keyusages["crl_sign"] = v
elif k == 'useKeyAgreement': elif k == "useKeyAgreement":
keyusages['key_agreement'] = v keyusages["key_agreement"] = v
elif k == 'useEncipherOnly' and v: elif k == "useEncipherOnly" and v:
keyusages['encipher_only'] = True keyusages["encipher_only"] = True
keyusages['key_agreement'] = True keyusages["key_agreement"] = True
elif k == 'useDecipherOnly' and v: elif k == "useDecipherOnly" and v:
keyusages['decipher_only'] = True keyusages["decipher_only"] = True
keyusages['key_agreement'] = True keyusages["key_agreement"] = True
if keyusages['encipher_only'] and keyusages['decipher_only']: if keyusages["encipher_only"] and keyusages["decipher_only"]:
raise ValidationError('A certificate cannot have both Encipher Only and Decipher Only Extended Key Usages.') raise ValidationError(
"A certificate cannot have both Encipher Only and Decipher Only Extended Key Usages."
)
return x509.KeyUsage( return x509.KeyUsage(
digital_signature=keyusages['digital_signature'], digital_signature=keyusages["digital_signature"],
content_commitment=keyusages['content_commitment'], content_commitment=keyusages["content_commitment"],
key_encipherment=keyusages['key_encipherment'], key_encipherment=keyusages["key_encipherment"],
data_encipherment=keyusages['data_encipherment'], data_encipherment=keyusages["data_encipherment"],
key_agreement=keyusages['key_agreement'], key_agreement=keyusages["key_agreement"],
key_cert_sign=keyusages['key_cert_sign'], key_cert_sign=keyusages["key_cert_sign"],
crl_sign=keyusages['crl_sign'], crl_sign=keyusages["crl_sign"],
encipher_only=keyusages['encipher_only'], encipher_only=keyusages["encipher_only"],
decipher_only=keyusages['decipher_only'] decipher_only=keyusages["decipher_only"],
) )
@ -216,69 +221,77 @@ class ExtendedKeyUsageExtension(Field):
usage_list = {} usage_list = {}
for usage in usages: for usage in usages:
if usage == x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH: if usage == x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH:
usage_list['useClientAuthentication'] = True usage_list["useClientAuthentication"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.SERVER_AUTH: elif usage == x509.oid.ExtendedKeyUsageOID.SERVER_AUTH:
usage_list['useServerAuthentication'] = True usage_list["useServerAuthentication"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.CODE_SIGNING: elif usage == x509.oid.ExtendedKeyUsageOID.CODE_SIGNING:
usage_list['useCodeSigning'] = True usage_list["useCodeSigning"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION: elif usage == x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION:
usage_list['useEmailProtection'] = True usage_list["useEmailProtection"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.TIME_STAMPING: elif usage == x509.oid.ExtendedKeyUsageOID.TIME_STAMPING:
usage_list['useTimestamping'] = True usage_list["useTimestamping"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING: elif usage == x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING:
usage_list['useOCSPSigning'] = True usage_list["useOCSPSigning"] = True
elif usage.dotted_string == '1.3.6.1.5.5.7.3.14': elif usage.dotted_string == "1.3.6.1.5.5.7.3.14":
usage_list['useEapOverLAN'] = True usage_list["useEapOverLAN"] = True
elif usage.dotted_string == '1.3.6.1.5.5.7.3.13': elif usage.dotted_string == "1.3.6.1.5.5.7.3.13":
usage_list['useEapOverPPP'] = True usage_list["useEapOverPPP"] = True
elif usage.dotted_string == '1.3.6.1.4.1.311.20.2.2': elif usage.dotted_string == "1.3.6.1.4.1.311.20.2.2":
usage_list['useSmartCardLogon'] = True usage_list["useSmartCardLogon"] = True
else: else:
current_app.logger.warning('Unable to serialize ExtendedKeyUsage with OID: {usage}'.format(usage=usage.dotted_string)) current_app.logger.warning(
"Unable to serialize ExtendedKeyUsage with OID: {usage}".format(
usage=usage.dotted_string
)
)
return usage_list return usage_list
def _deserialize(self, value, attr, data): def _deserialize(self, value, attr, data):
usage_oids = [] usage_oids = []
for k, v in value.items(): for k, v in value.items():
if k == 'useClientAuthentication' and v: if k == "useClientAuthentication" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH) usage_oids.append(x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH)
elif k == 'useServerAuthentication' and v: elif k == "useServerAuthentication" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.SERVER_AUTH) usage_oids.append(x509.oid.ExtendedKeyUsageOID.SERVER_AUTH)
elif k == 'useCodeSigning' and v: elif k == "useCodeSigning" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.CODE_SIGNING) usage_oids.append(x509.oid.ExtendedKeyUsageOID.CODE_SIGNING)
elif k == 'useEmailProtection' and v: elif k == "useEmailProtection" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION) usage_oids.append(x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION)
elif k == 'useTimestamping' and v: elif k == "useTimestamping" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.TIME_STAMPING) usage_oids.append(x509.oid.ExtendedKeyUsageOID.TIME_STAMPING)
elif k == 'useOCSPSigning' and v: elif k == "useOCSPSigning" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING) usage_oids.append(x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING)
elif k == 'useEapOverLAN' and v: elif k == "useEapOverLAN" and v:
usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.14")) usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.14"))
elif k == 'useEapOverPPP' and v: elif k == "useEapOverPPP" and v:
usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.13")) usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.13"))
elif k == 'useSmartCardLogon' and v: elif k == "useSmartCardLogon" and v:
usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.4.1.311.20.2.2")) usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.4.1.311.20.2.2"))
else: else:
current_app.logger.warning('Unable to deserialize ExtendedKeyUsage with name: {key}'.format(key=k)) current_app.logger.warning(
"Unable to deserialize ExtendedKeyUsage with name: {key}".format(
key=k
)
)
return x509.ExtendedKeyUsage(usage_oids) return x509.ExtendedKeyUsage(usage_oids)
@ -294,15 +307,17 @@ class BasicConstraintsExtension(Field):
""" """
def _serialize(self, value, attr, obj): def _serialize(self, value, attr, obj):
return {'ca': value.ca, 'path_length': value.path_length} return {"ca": value.ca, "path_length": value.path_length}
def _deserialize(self, value, attr, data): def _deserialize(self, value, attr, data):
ca = value.get('ca', False) ca = value.get("ca", False)
path_length = value.get('path_length', None) path_length = value.get("path_length", None)
if ca: if ca:
if not isinstance(path_length, (type(None), int)): if not isinstance(path_length, (type(None), int)):
raise ValidationError('A CA certificate path_length (for BasicConstraints) must be None or an integer.') raise ValidationError(
"A CA certificate path_length (for BasicConstraints) must be None or an integer."
)
return x509.BasicConstraints(ca=True, path_length=path_length) return x509.BasicConstraints(ca=True, path_length=path_length)
else: else:
return x509.BasicConstraints(ca=False, path_length=None) return x509.BasicConstraints(ca=False, path_length=None)
@ -317,6 +332,7 @@ class SubjectAlternativeNameExtension(Field):
:param kwargs: The same keyword arguments that :class:`Field` receives. :param kwargs: The same keyword arguments that :class:`Field` receives.
""" """
def _serialize(self, value, attr, obj): def _serialize(self, value, attr, obj):
general_names = [] general_names = []
name_type = None name_type = None
@ -326,53 +342,59 @@ class SubjectAlternativeNameExtension(Field):
value = name.value value = name.value
if isinstance(name, x509.DNSName): if isinstance(name, x509.DNSName):
name_type = 'DNSName' name_type = "DNSName"
elif isinstance(name, x509.IPAddress): elif isinstance(name, x509.IPAddress):
if isinstance(value, ipaddress.IPv4Network): if isinstance(value, ipaddress.IPv4Network):
name_type = 'IPNetwork' name_type = "IPNetwork"
else: else:
name_type = 'IPAddress' name_type = "IPAddress"
value = str(value) value = str(value)
elif isinstance(name, x509.UniformResourceIdentifier): elif isinstance(name, x509.UniformResourceIdentifier):
name_type = 'uniformResourceIdentifier' name_type = "uniformResourceIdentifier"
elif isinstance(name, x509.DirectoryName): elif isinstance(name, x509.DirectoryName):
name_type = 'directoryName' name_type = "directoryName"
elif isinstance(name, x509.RFC822Name): elif isinstance(name, x509.RFC822Name):
name_type = 'rfc822Name' name_type = "rfc822Name"
elif isinstance(name, x509.RegisteredID): elif isinstance(name, x509.RegisteredID):
name_type = 'registeredID' name_type = "registeredID"
value = value.dotted_string value = value.dotted_string
else: else:
current_app.logger.warning('Unknown SubAltName type: {name}'.format(name=name)) current_app.logger.warning(
"Unknown SubAltName type: {name}".format(name=name)
)
continue continue
general_names.append({'nameType': name_type, 'value': value}) general_names.append({"nameType": name_type, "value": value})
return general_names return general_names
def _deserialize(self, value, attr, data): def _deserialize(self, value, attr, data):
general_names = [] general_names = []
for name in value: for name in value:
if name['nameType'] == 'DNSName': if name["nameType"] == "DNSName":
validators.sensitive_domain(name['value']) validators.sensitive_domain(name["value"])
general_names.append(x509.DNSName(name['value'])) general_names.append(x509.DNSName(name["value"]))
elif name['nameType'] == 'IPAddress': elif name["nameType"] == "IPAddress":
general_names.append(x509.IPAddress(ipaddress.ip_address(name['value']))) general_names.append(
x509.IPAddress(ipaddress.ip_address(name["value"]))
)
elif name['nameType'] == 'IPNetwork': elif name["nameType"] == "IPNetwork":
general_names.append(x509.IPAddress(ipaddress.ip_network(name['value']))) general_names.append(
x509.IPAddress(ipaddress.ip_network(name["value"]))
)
elif name['nameType'] == 'uniformResourceIdentifier': elif name["nameType"] == "uniformResourceIdentifier":
general_names.append(x509.UniformResourceIdentifier(name['value'])) general_names.append(x509.UniformResourceIdentifier(name["value"]))
elif name['nameType'] == 'directoryName': elif name["nameType"] == "directoryName":
# TODO: Need to parse a string in name['value'] like: # TODO: Need to parse a string in name['value'] like:
# 'CN=Common Name, O=Org Name, OU=OrgUnit Name, C=US, ST=ST, L=City/emailAddress=person@example.com' # 'CN=Common Name, O=Org Name, OU=OrgUnit Name, C=US, ST=ST, L=City/emailAddress=person@example.com'
# or # or
@ -390,26 +412,32 @@ class SubjectAlternativeNameExtension(Field):
# general_names.append(x509.DirectoryName(x509.Name(BLAH)))) # general_names.append(x509.DirectoryName(x509.Name(BLAH))))
pass pass
elif name['nameType'] == 'rfc822Name': elif name["nameType"] == "rfc822Name":
general_names.append(x509.RFC822Name(name['value'])) general_names.append(x509.RFC822Name(name["value"]))
elif name['nameType'] == 'registeredID': elif name["nameType"] == "registeredID":
general_names.append(x509.RegisteredID(x509.ObjectIdentifier(name['value']))) general_names.append(
x509.RegisteredID(x509.ObjectIdentifier(name["value"]))
)
elif name['nameType'] == 'otherName': elif name["nameType"] == "otherName":
# This has two inputs (type and value), so it doesn't fit the mold of the rest of these GeneralName entities. # This has two inputs (type and value), so it doesn't fit the mold of the rest of these GeneralName entities.
# general_names.append(x509.OtherName(name['type'], bytes(name['value']), 'utf-8')) # general_names.append(x509.OtherName(name['type'], bytes(name['value']), 'utf-8'))
pass pass
elif name['nameType'] == 'x400Address': elif name["nameType"] == "x400Address":
# The Python Cryptography library doesn't support x400Address types (yet?) # The Python Cryptography library doesn't support x400Address types (yet?)
pass pass
elif name['nameType'] == 'EDIPartyName': elif name["nameType"] == "EDIPartyName":
# The Python Cryptography library doesn't support EDIPartyName types (yet?) # The Python Cryptography library doesn't support EDIPartyName types (yet?)
pass pass
else: else:
current_app.logger.warning('Unable to deserialize SubAltName with type: {name_type}'.format(name_type=name['nameType'])) current_app.logger.warning(
"Unable to deserialize SubAltName with type: {name_type}".format(
name_type=name["nameType"]
)
)
return x509.SubjectAlternativeName(general_names) return x509.SubjectAlternativeName(general_names)

View File

@ -10,20 +10,20 @@ from flask import Blueprint
from lemur.database import db from lemur.database import db
from lemur.extensions import sentry from lemur.extensions import sentry
mod = Blueprint('healthCheck', __name__) mod = Blueprint("healthCheck", __name__)
@mod.route('/healthcheck') @mod.route("/healthcheck")
def health(): def health():
try: try:
if healthcheck(db): if healthcheck(db):
return 'ok' return "ok"
except Exception: except Exception:
sentry.captureException() sentry.captureException()
return 'db check failed' return "db check failed"
def healthcheck(db): def healthcheck(db):
with db.engine.connect() as connection: with db.engine.connect() as connection:
connection.execute('SELECT 1;') connection.execute("SELECT 1;")
return True return True

View File

@ -52,7 +52,7 @@ class InstanceManager(object):
results = [] results = []
for cls_path in class_list: for cls_path in class_list:
module_name, class_name = cls_path.rsplit('.', 1) module_name, class_name = cls_path.rsplit(".", 1)
try: try:
module = __import__(module_name, {}, {}, class_name) module = __import__(module_name, {}, {}, class_name)
cls = getattr(module, class_name) cls = getattr(module, class_name)
@ -62,10 +62,14 @@ class InstanceManager(object):
results.append(cls) results.append(cls)
except InvalidConfiguration as e: except InvalidConfiguration as e:
current_app.logger.warning("Plugin '{0}' may not work correctly. {1}".format(class_name, e)) current_app.logger.warning(
"Plugin '{0}' may not work correctly. {1}".format(class_name, e)
)
except Exception as e: except Exception as e:
current_app.logger.exception("Unable to import {0}. Reason: {1}".format(cls_path, e)) current_app.logger.exception(
"Unable to import {0}. Reason: {1}".format(cls_path, e)
)
continue continue
self.cache = results self.cache = results

View File

@ -11,15 +11,15 @@ def convert_validity_years(data):
:param data: :param data:
:return: :return:
""" """
if data.get('validity_years'): if data.get("validity_years"):
now = arrow.utcnow() now = arrow.utcnow()
data['validity_start'] = now.isoformat() data["validity_start"] = now.isoformat()
end = now.replace(years=+int(data['validity_years'])) end = now.replace(years=+int(data["validity_years"]))
if not current_app.config.get('LEMUR_ALLOW_WEEKEND_EXPIRATION', True): if not current_app.config.get("LEMUR_ALLOW_WEEKEND_EXPIRATION", True):
if is_weekend(end): if is_weekend(end):
end = end.replace(days=-2) end = end.replace(days=-2)
data['validity_end'] = end.isoformat() data["validity_end"] = end.isoformat()
return data return data

View File

@ -22,27 +22,26 @@ class LemurSchema(Schema):
""" """
Base schema from which all grouper schema's inherit Base schema from which all grouper schema's inherit
""" """
__envelope__ = True __envelope__ = True
def under(self, data, many=None): def under(self, data, many=None):
items = [] items = []
if many: if many:
for i in data: for i in data:
items.append( items.append({underscore(key): value for key, value in i.items()})
{underscore(key): value for key, value in i.items()}
)
return items return items
return { return {underscore(key): value for key, value in data.items()}
underscore(key): value
for key, value in data.items()
}
def camel(self, data, many=None): def camel(self, data, many=None):
items = [] items = []
if many: if many:
for i in data: for i in data:
items.append( items.append(
{camelize(key, uppercase_first_letter=False): value for key, value in i.items()} {
camelize(key, uppercase_first_letter=False): value
for key, value in i.items()
}
) )
return items return items
return { return {
@ -52,16 +51,16 @@ class LemurSchema(Schema):
def wrap_with_envelope(self, data, many): def wrap_with_envelope(self, data, many):
if many: if many:
if 'total' in self.context.keys(): if "total" in self.context.keys():
return dict(total=self.context['total'], items=data) return dict(total=self.context["total"], items=data)
return data return data
class LemurInputSchema(LemurSchema): class LemurInputSchema(LemurSchema):
@pre_load(pass_many=True) @pre_load(pass_many=True)
def preprocess(self, data, many): def preprocess(self, data, many):
if isinstance(data, dict) and data.get('owner'): if isinstance(data, dict) and data.get("owner"):
data['owner'] = data['owner'].lower() data["owner"] = data["owner"].lower()
return self.under(data, many=many) return self.under(data, many=many)
@ -74,17 +73,17 @@ class LemurOutputSchema(LemurSchema):
def unwrap_envelope(self, data, many): def unwrap_envelope(self, data, many):
if many: if many:
if data['items']: if data["items"]:
if isinstance(data, InstrumentedList) or isinstance(data, list): if isinstance(data, InstrumentedList) or isinstance(data, list):
self.context['total'] = len(data) self.context["total"] = len(data)
return data return data
else: else:
self.context['total'] = data['total'] self.context["total"] = data["total"]
else: else:
self.context['total'] = 0 self.context["total"] = 0
data = {'items': []} data = {"items": []}
return data['items'] return data["items"]
return data return data
@ -110,11 +109,11 @@ def format_errors(messages):
def wrap_errors(messages): def wrap_errors(messages):
errors = dict(message='Validation Error.') errors = dict(message="Validation Error.")
if messages.get('_schema'): if messages.get("_schema"):
errors['reasons'] = {'Schema': {'rule': messages['_schema']}} errors["reasons"] = {"Schema": {"rule": messages["_schema"]}}
else: else:
errors['reasons'] = format_errors(messages) errors["reasons"] = format_errors(messages)
return errors return errors
@ -123,19 +122,19 @@ def unwrap_pagination(data, output_schema):
return data return data
if isinstance(data, dict): if isinstance(data, dict):
if 'total' in data.keys(): if "total" in data.keys():
if data.get('total') == 0: if data.get("total") == 0:
return data return data
marshaled_data = {'total': data['total']} marshaled_data = {"total": data["total"]}
marshaled_data['items'] = output_schema.dump(data['items'], many=True).data marshaled_data["items"] = output_schema.dump(data["items"], many=True).data
return marshaled_data return marshaled_data
return output_schema.dump(data).data return output_schema.dump(data).data
elif isinstance(data, list): elif isinstance(data, list):
marshaled_data = {'total': len(data)} marshaled_data = {"total": len(data)}
marshaled_data['items'] = output_schema.dump(data, many=True).data marshaled_data["items"] = output_schema.dump(data, many=True).data
return marshaled_data return marshaled_data
return output_schema.dump(data).data return output_schema.dump(data).data
@ -155,7 +154,7 @@ def validate_schema(input_schema, output_schema):
if errors: if errors:
return wrap_errors(errors), 400 return wrap_errors(errors), 400
kwargs['data'] = data kwargs["data"] = data
try: try:
resp = f(*args, **kwargs) resp = f(*args, **kwargs)
@ -173,4 +172,5 @@ def validate_schema(input_schema, output_schema):
return unwrap_pagination(resp, output_schema), 200 return unwrap_pagination(resp, output_schema), 200
return decorated_function return decorated_function
return decorator return decorator

View File

@ -25,22 +25,22 @@ from lemur.exceptions import InvalidConfiguration
paginated_parser = RequestParser() paginated_parser = RequestParser()
paginated_parser.add_argument('count', type=int, default=10, location='args') paginated_parser.add_argument("count", type=int, default=10, location="args")
paginated_parser.add_argument('page', type=int, default=1, location='args') paginated_parser.add_argument("page", type=int, default=1, location="args")
paginated_parser.add_argument('sortDir', type=str, dest='sort_dir', location='args') paginated_parser.add_argument("sortDir", type=str, dest="sort_dir", location="args")
paginated_parser.add_argument('sortBy', type=str, dest='sort_by', location='args') paginated_parser.add_argument("sortBy", type=str, dest="sort_by", location="args")
paginated_parser.add_argument('filter', type=str, location='args') paginated_parser.add_argument("filter", type=str, location="args")
paginated_parser.add_argument('owner', type=str, location='args') paginated_parser.add_argument("owner", type=str, location="args")
def get_psuedo_random_string(): def get_psuedo_random_string():
""" """
Create a random and strongish challenge. Create a random and strongish challenge.
""" """
challenge = ''.join(random.choice(string.ascii_uppercase) for x in range(6)) # noqa challenge = "".join(random.choice(string.ascii_uppercase) for x in range(6)) # noqa
challenge += ''.join(random.choice("~!@#$%^&*()_+") for x in range(6)) # noqa challenge += "".join(random.choice("~!@#$%^&*()_+") for x in range(6)) # noqa
challenge += ''.join(random.choice(string.ascii_lowercase) for x in range(6)) challenge += "".join(random.choice(string.ascii_lowercase) for x in range(6))
challenge += ''.join(random.choice(string.digits) for x in range(6)) # noqa challenge += "".join(random.choice(string.digits) for x in range(6)) # noqa
return challenge return challenge
@ -53,7 +53,7 @@ def parse_certificate(body):
""" """
assert isinstance(body, str) assert isinstance(body, str)
return x509.load_pem_x509_certificate(body.encode('utf-8'), default_backend()) return x509.load_pem_x509_certificate(body.encode("utf-8"), default_backend())
def parse_private_key(private_key): def parse_private_key(private_key):
@ -66,7 +66,9 @@ def parse_private_key(private_key):
""" """
assert isinstance(private_key, str) assert isinstance(private_key, str)
return load_pem_private_key(private_key.encode('utf8'), password=None, backend=default_backend()) return load_pem_private_key(
private_key.encode("utf8"), password=None, backend=default_backend()
)
def split_pem(data): def split_pem(data):
@ -100,14 +102,15 @@ def parse_csr(csr):
""" """
assert isinstance(csr, str) assert isinstance(csr, str)
return x509.load_pem_x509_csr(csr.encode('utf-8'), default_backend()) return x509.load_pem_x509_csr(csr.encode("utf-8"), default_backend())
def get_authority_key(body): def get_authority_key(body):
"""Returns the authority key for a given certificate in hex format""" """Returns the authority key for a given certificate in hex format"""
parsed_cert = parse_certificate(body) parsed_cert = parse_certificate(body)
authority_key = parsed_cert.extensions.get_extension_for_class( authority_key = parsed_cert.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier).value.key_identifier x509.AuthorityKeyIdentifier
).value.key_identifier
return authority_key.hex() return authority_key.hex()
@ -127,20 +130,17 @@ def generate_private_key(key_type):
_CURVE_TYPES = { _CURVE_TYPES = {
"ECCPRIME192V1": ec.SECP192R1(), "ECCPRIME192V1": ec.SECP192R1(),
"ECCPRIME256V1": ec.SECP256R1(), "ECCPRIME256V1": ec.SECP256R1(),
"ECCSECP192R1": ec.SECP192R1(), "ECCSECP192R1": ec.SECP192R1(),
"ECCSECP224R1": ec.SECP224R1(), "ECCSECP224R1": ec.SECP224R1(),
"ECCSECP256R1": ec.SECP256R1(), "ECCSECP256R1": ec.SECP256R1(),
"ECCSECP384R1": ec.SECP384R1(), "ECCSECP384R1": ec.SECP384R1(),
"ECCSECP521R1": ec.SECP521R1(), "ECCSECP521R1": ec.SECP521R1(),
"ECCSECP256K1": ec.SECP256K1(), "ECCSECP256K1": ec.SECP256K1(),
"ECCSECT163K1": ec.SECT163K1(), "ECCSECT163K1": ec.SECT163K1(),
"ECCSECT233K1": ec.SECT233K1(), "ECCSECT233K1": ec.SECT233K1(),
"ECCSECT283K1": ec.SECT283K1(), "ECCSECT283K1": ec.SECT283K1(),
"ECCSECT409K1": ec.SECT409K1(), "ECCSECT409K1": ec.SECT409K1(),
"ECCSECT571K1": ec.SECT571K1(), "ECCSECT571K1": ec.SECT571K1(),
"ECCSECT163R2": ec.SECT163R2(), "ECCSECT163R2": ec.SECT163R2(),
"ECCSECT233R1": ec.SECT233R1(), "ECCSECT233R1": ec.SECT233R1(),
"ECCSECT283R1": ec.SECT283R1(), "ECCSECT283R1": ec.SECT283R1(),
@ -149,22 +149,20 @@ def generate_private_key(key_type):
} }
if key_type not in CERTIFICATE_KEY_TYPES: if key_type not in CERTIFICATE_KEY_TYPES:
raise Exception("Invalid key type: {key_type}. Supported key types: {choices}".format( raise Exception(
key_type=key_type, "Invalid key type: {key_type}. Supported key types: {choices}".format(
choices=",".join(CERTIFICATE_KEY_TYPES) key_type=key_type, choices=",".join(CERTIFICATE_KEY_TYPES)
)) )
)
if 'RSA' in key_type: if "RSA" in key_type:
key_size = int(key_type[3:]) key_size = int(key_type[3:])
return rsa.generate_private_key( return rsa.generate_private_key(
public_exponent=65537, public_exponent=65537, key_size=key_size, backend=default_backend()
key_size=key_size,
backend=default_backend()
) )
elif 'ECC' in key_type: elif "ECC" in key_type:
return ec.generate_private_key( return ec.generate_private_key(
curve=_CURVE_TYPES[key_type], curve=_CURVE_TYPES[key_type], backend=default_backend()
backend=default_backend()
) )
@ -184,11 +182,26 @@ def check_cert_signature(cert, issuer_public_key):
raise UnsupportedAlgorithm("RSASSA-PSS not supported") raise UnsupportedAlgorithm("RSASSA-PSS not supported")
else: else:
padder = padding.PKCS1v15() padder = padding.PKCS1v15()
issuer_public_key.verify(cert.signature, cert.tbs_certificate_bytes, padder, cert.signature_hash_algorithm) issuer_public_key.verify(
elif isinstance(issuer_public_key, ec.EllipticCurvePublicKey) and isinstance(ec.ECDSA(cert.signature_hash_algorithm), ec.ECDSA): cert.signature,
issuer_public_key.verify(cert.signature, cert.tbs_certificate_bytes, ec.ECDSA(cert.signature_hash_algorithm)) cert.tbs_certificate_bytes,
padder,
cert.signature_hash_algorithm,
)
elif isinstance(issuer_public_key, ec.EllipticCurvePublicKey) and isinstance(
ec.ECDSA(cert.signature_hash_algorithm), ec.ECDSA
):
issuer_public_key.verify(
cert.signature,
cert.tbs_certificate_bytes,
ec.ECDSA(cert.signature_hash_algorithm),
)
else: else:
raise UnsupportedAlgorithm("Unsupported Algorithm '{var}'.".format(var=cert.signature_algorithm_oid._name)) raise UnsupportedAlgorithm(
"Unsupported Algorithm '{var}'.".format(
var=cert.signature_algorithm_oid._name
)
)
def is_selfsigned(cert): def is_selfsigned(cert):
@ -224,7 +237,9 @@ def validate_conf(app, required_vars):
""" """
for var in required_vars: for var in required_vars:
if var not in app.config: if var not in app.config:
raise InvalidConfiguration("Required variable '{var}' is not set in Lemur's conf.".format(var=var)) raise InvalidConfiguration(
"Required variable '{var}' is not set in Lemur's conf.".format(var=var)
)
# https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/WindowedRangeQuery # https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/WindowedRangeQuery
@ -243,18 +258,15 @@ def column_windows(session, column, windowsize):
be computed. be computed.
""" """
def int_for_range(start_id, end_id): def int_for_range(start_id, end_id):
if end_id: if end_id:
return and_( return and_(column >= start_id, column < end_id)
column >= start_id,
column < end_id
)
else: else:
return column >= start_id return column >= start_id
q = session.query( q = session.query(
column, column, func.row_number().over(order_by=column).label("rownum")
func.row_number().over(order_by=column).label('rownum')
).from_self(column) ).from_self(column)
if windowsize > 1: if windowsize > 1:
@ -274,9 +286,7 @@ def column_windows(session, column, windowsize):
def windowed_query(q, column, windowsize): def windowed_query(q, column, windowsize):
""""Break a Query into windows on a given column.""" """"Break a Query into windows on a given column."""
for whereclause in column_windows( for whereclause in column_windows(q.session, column, windowsize):
q.session,
column, windowsize):
for row in q.filter(whereclause).order_by(column): for row in q.filter(whereclause).order_by(column):
yield row yield row
@ -284,7 +294,7 @@ def windowed_query(q, column, windowsize):
def truthiness(s): def truthiness(s):
"""If input string resembles something truthy then return True, else False.""" """If input string resembles something truthy then return True, else False."""
return s.lower() in ('true', 'yes', 'on', 't', '1') return s.lower() in ("true", "yes", "on", "t", "1")
def find_matching_certificates_by_hash(cert, matching_certs): def find_matching_certificates_by_hash(cert, matching_certs):
@ -292,6 +302,8 @@ def find_matching_certificates_by_hash(cert, matching_certs):
determine if any of the certificate hashes match and return the matches.""" determine if any of the certificate hashes match and return the matches."""
matching = [] matching = []
for c in matching_certs: for c in matching_certs:
if parse_certificate(c.body).fingerprint(hashes.SHA256()) == cert.fingerprint(hashes.SHA256()): if parse_certificate(c.body).fingerprint(hashes.SHA256()) == cert.fingerprint(
hashes.SHA256()
):
matching.append(c) matching.append(c)
return matching return matching

View File

@ -16,7 +16,7 @@ def common_name(value):
# Common name could be a domain name, or a human-readable name of the subject (often used in CA names or client # Common name could be a domain name, or a human-readable name of the subject (often used in CA names or client
# certificates). As a simple heuristic, we assume that human-readable names always include a space. # certificates). As a simple heuristic, we assume that human-readable names always include a space.
# However, to avoid confusion for humans, we also don't count spaces at the beginning or end of the string. # However, to avoid confusion for humans, we also don't count spaces at the beginning or end of the string.
if ' ' not in value.strip(): if " " not in value.strip():
return sensitive_domain(value) return sensitive_domain(value)
@ -30,17 +30,21 @@ def sensitive_domain(domain):
# User has permission, no need to check anything # User has permission, no need to check anything
return return
whitelist = current_app.config.get('LEMUR_WHITELISTED_DOMAINS', []) whitelist = current_app.config.get("LEMUR_WHITELISTED_DOMAINS", [])
if whitelist and not any(re.match(pattern, domain) for pattern in whitelist): if whitelist and not any(re.match(pattern, domain) for pattern in whitelist):
raise ValidationError('Domain {0} does not match whitelisted domain patterns. ' raise ValidationError(
'Contact an administrator to issue the certificate.'.format(domain)) "Domain {0} does not match whitelisted domain patterns. "
"Contact an administrator to issue the certificate.".format(domain)
)
# Avoid circular import. # Avoid circular import.
from lemur.domains import service as domain_service from lemur.domains import service as domain_service
if any(d.sensitive for d in domain_service.get_by_name(domain)): if any(d.sensitive for d in domain_service.get_by_name(domain)):
raise ValidationError('Domain {0} has been marked as sensitive. ' raise ValidationError(
'Contact an administrator to issue the certificate.'.format(domain)) "Domain {0} has been marked as sensitive. "
"Contact an administrator to issue the certificate.".format(domain)
)
def encoding(oid_encoding): def encoding(oid_encoding):
@ -49,9 +53,13 @@ def encoding(oid_encoding):
:param oid_encoding: :param oid_encoding:
:return: :return:
""" """
valid_types = ['b64asn1', 'string', 'ia5string'] valid_types = ["b64asn1", "string", "ia5string"]
if oid_encoding.lower() not in [o_type.lower() for o_type in valid_types]: if oid_encoding.lower() not in [o_type.lower() for o_type in valid_types]:
raise ValidationError('Invalid Oid Encoding: {0} choose from {1}'.format(oid_encoding, ",".join(valid_types))) raise ValidationError(
"Invalid Oid Encoding: {0} choose from {1}".format(
oid_encoding, ",".join(valid_types)
)
)
def sub_alt_type(alt_type): def sub_alt_type(alt_type):
@ -60,10 +68,23 @@ def sub_alt_type(alt_type):
:param alt_type: :param alt_type:
:return: :return:
""" """
valid_types = ['DNSName', 'IPAddress', 'uniFormResourceIdentifier', 'directoryName', 'rfc822Name', 'registrationID', valid_types = [
'otherName', 'x400Address', 'EDIPartyName'] "DNSName",
"IPAddress",
"uniFormResourceIdentifier",
"directoryName",
"rfc822Name",
"registrationID",
"otherName",
"x400Address",
"EDIPartyName",
]
if alt_type.lower() not in [a_type.lower() for a_type in valid_types]: if alt_type.lower() not in [a_type.lower() for a_type in valid_types]:
raise ValidationError('Invalid SubAltName Type: {0} choose from {1}'.format(type, ",".join(valid_types))) raise ValidationError(
"Invalid SubAltName Type: {0} choose from {1}".format(
type, ",".join(valid_types)
)
)
def csr(data): def csr(data):
@ -73,16 +94,18 @@ def csr(data):
:return: :return:
""" """
try: try:
request = x509.load_pem_x509_csr(data.encode('utf-8'), default_backend()) request = x509.load_pem_x509_csr(data.encode("utf-8"), default_backend())
except Exception: except Exception:
raise ValidationError('CSR presented is not valid.') raise ValidationError("CSR presented is not valid.")
# Validate common name and SubjectAltNames # Validate common name and SubjectAltNames
for name in request.subject.get_attributes_for_oid(NameOID.COMMON_NAME): for name in request.subject.get_attributes_for_oid(NameOID.COMMON_NAME):
common_name(name.value) common_name(name.value)
try: try:
alt_names = request.extensions.get_extension_for_class(x509.SubjectAlternativeName) alt_names = request.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
for name in alt_names.value.get_values_for_type(x509.DNSName): for name in alt_names.value.get_values_for_type(x509.DNSName):
sensitive_domain(name) sensitive_domain(name)
@ -91,26 +114,40 @@ def csr(data):
def dates(data): def dates(data):
if not data.get('validity_start') and data.get('validity_end'): if not data.get("validity_start") and data.get("validity_end"):
raise ValidationError('If validity start is specified so must validity end.') raise ValidationError("If validity start is specified so must validity end.")
if not data.get('validity_end') and data.get('validity_start'): if not data.get("validity_end") and data.get("validity_start"):
raise ValidationError('If validity end is specified so must validity start.') raise ValidationError("If validity end is specified so must validity start.")
if data.get('validity_start') and data.get('validity_end'): if data.get("validity_start") and data.get("validity_end"):
if not current_app.config.get('LEMUR_ALLOW_WEEKEND_EXPIRATION', True): if not current_app.config.get("LEMUR_ALLOW_WEEKEND_EXPIRATION", True):
if is_weekend(data.get('validity_end')): if is_weekend(data.get("validity_end")):
raise ValidationError('Validity end must not land on a weekend.') raise ValidationError("Validity end must not land on a weekend.")
if not data['validity_start'] < data['validity_end']: if not data["validity_start"] < data["validity_end"]:
raise ValidationError('Validity start must be before validity end.') raise ValidationError("Validity start must be before validity end.")
if data.get('authority'): if data.get("authority"):
if data.get('validity_start').date() < data['authority'].authority_certificate.not_before.date(): if (
raise ValidationError('Validity start must not be before {0}'.format(data['authority'].authority_certificate.not_before)) data.get("validity_start").date()
< data["authority"].authority_certificate.not_before.date()
):
raise ValidationError(
"Validity start must not be before {0}".format(
data["authority"].authority_certificate.not_before
)
)
if data.get('validity_end').date() > data['authority'].authority_certificate.not_after.date(): if (
raise ValidationError('Validity end must not be after {0}'.format(data['authority'].authority_certificate.not_after)) data.get("validity_end").date()
> data["authority"].authority_certificate.not_after.date()
):
raise ValidationError(
"Validity end must not be after {0}".format(
data["authority"].authority_certificate.not_after
)
)
return data return data
@ -148,8 +185,13 @@ def verify_cert_chain(certs, error_class=ValidationError):
# Avoid circular import. # Avoid circular import.
from lemur.common import defaults from lemur.common import defaults
raise error_class("Incorrect chain certificate(s) provided: '%s' is not signed by '%s'" raise error_class(
% (defaults.common_name(cert) or 'Unknown', defaults.common_name(issuer))) "Incorrect chain certificate(s) provided: '%s' is not signed by '%s'"
% (
defaults.common_name(cert) or "Unknown",
defaults.common_name(issuer),
)
)
except UnsupportedAlgorithm as err: except UnsupportedAlgorithm as err:
current_app.logger.warning("Skipping chain validation: %s", err) current_app.logger.warning("Skipping chain validation: %s", err)

View File

@ -7,28 +7,28 @@ SAN_NAMING_TEMPLATE = "SAN-{subject}-{issuer}-{not_before}-{not_after}"
DEFAULT_NAMING_TEMPLATE = "{subject}-{issuer}-{not_before}-{not_after}" DEFAULT_NAMING_TEMPLATE = "{subject}-{issuer}-{not_before}-{not_after}"
NONSTANDARD_NAMING_TEMPLATE = "{issuer}-{not_before}-{not_after}" NONSTANDARD_NAMING_TEMPLATE = "{issuer}-{not_before}-{not_after}"
SUCCESS_METRIC_STATUS = 'success' SUCCESS_METRIC_STATUS = "success"
FAILURE_METRIC_STATUS = 'failure' FAILURE_METRIC_STATUS = "failure"
CERTIFICATE_KEY_TYPES = [ CERTIFICATE_KEY_TYPES = [
'RSA2048', "RSA2048",
'RSA4096', "RSA4096",
'ECCPRIME192V1', "ECCPRIME192V1",
'ECCPRIME256V1', "ECCPRIME256V1",
'ECCSECP192R1', "ECCSECP192R1",
'ECCSECP224R1', "ECCSECP224R1",
'ECCSECP256R1', "ECCSECP256R1",
'ECCSECP384R1', "ECCSECP384R1",
'ECCSECP521R1', "ECCSECP521R1",
'ECCSECP256K1', "ECCSECP256K1",
'ECCSECT163K1', "ECCSECT163K1",
'ECCSECT233K1', "ECCSECT233K1",
'ECCSECT283K1', "ECCSECT283K1",
'ECCSECT409K1', "ECCSECT409K1",
'ECCSECT571K1', "ECCSECT571K1",
'ECCSECT163R2', "ECCSECT163R2",
'ECCSECT233R1', "ECCSECT233R1",
'ECCSECT283R1', "ECCSECT283R1",
'ECCSECT409R1', "ECCSECT409R1",
'ECCSECT571R2' "ECCSECT571R2",
] ]

View File

@ -43,7 +43,7 @@ def session_query(model):
:param model: sqlalchemy model :param model: sqlalchemy model
:return: query object for model :return: query object for model
""" """
return model.query if hasattr(model, 'query') else db.session.query(model) return model.query if hasattr(model, "query") else db.session.query(model)
def create_query(model, kwargs): def create_query(model, kwargs):
@ -77,7 +77,7 @@ def add(model):
def get_model_column(model, field): def get_model_column(model, field):
if field in getattr(model, 'sensitive_fields', ()): if field in getattr(model, "sensitive_fields", ()):
raise AttrNotFound(field) raise AttrNotFound(field)
column = model.__table__.columns._data.get(field, None) column = model.__table__.columns._data.get(field, None)
if column is None: if column is None:
@ -100,7 +100,7 @@ def find_all(query, model, kwargs):
kwargs = filter_none(kwargs) kwargs = filter_none(kwargs)
for attr, value in kwargs.items(): for attr, value in kwargs.items():
if not isinstance(value, list): if not isinstance(value, list):
value = value.split(',') value = value.split(",")
conditions.append(get_model_column(model, attr).in_(value)) conditions.append(get_model_column(model, attr).in_(value))
@ -200,7 +200,7 @@ def filter(query, model, terms):
:return: :return:
""" """
column = get_model_column(model, underscore(terms[0])) column = get_model_column(model, underscore(terms[0]))
return query.filter(column.ilike('%{}%'.format(terms[1]))) return query.filter(column.ilike("%{}%".format(terms[1])))
def sort(query, model, field, direction): def sort(query, model, field, direction):
@ -214,7 +214,7 @@ def sort(query, model, field, direction):
:param direction: :param direction:
""" """
column = get_model_column(model, underscore(field)) column = get_model_column(model, underscore(field))
return query.order_by(column.desc() if direction == 'desc' else column.asc()) return query.order_by(column.desc() if direction == "desc" else column.asc())
def paginate(query, page, count): def paginate(query, page, count):
@ -247,10 +247,10 @@ def update_list(model, model_attr, item_model, items):
for i in items: for i in items:
for item in getattr(model, model_attr): for item in getattr(model, model_attr):
if item.id == i['id']: if item.id == i["id"]:
break break
else: else:
getattr(model, model_attr).append(get(item_model, i['id'])) getattr(model, model_attr).append(get(item_model, i["id"]))
return model return model
@ -276,9 +276,9 @@ def get_count(q):
disable_group_by = False disable_group_by = False
if len(q._entities) > 1: if len(q._entities) > 1:
# currently support only one entity # currently support only one entity
raise Exception('only one entity is supported for get_count, got: %s' % q) raise Exception("only one entity is supported for get_count, got: %s" % q)
entity = q._entities[0] entity = q._entities[0]
if hasattr(entity, 'column'): if hasattr(entity, "column"):
# _ColumnEntity has column attr - on case: query(Model.column)... # _ColumnEntity has column attr - on case: query(Model.column)...
col = entity.column col = entity.column
if q._group_by and q._distinct: if q._group_by and q._distinct:
@ -295,7 +295,11 @@ def get_count(q):
count_func = func.count() count_func = func.count()
if q._group_by and not disable_group_by: if q._group_by and not disable_group_by:
count_func = count_func.over(None) count_func = count_func.over(None)
count_q = q.options(lazyload('*')).statement.with_only_columns([count_func]).order_by(None) count_q = (
q.options(lazyload("*"))
.statement.with_only_columns([count_func])
.order_by(None)
)
if disable_group_by: if disable_group_by:
count_q = count_q.group_by(None) count_q = count_q.group_by(None)
count = q.session.execute(count_q).scalar() count = q.session.execute(count_q).scalar()
@ -311,13 +315,13 @@ def sort_and_page(query, model, args):
:param args: :param args:
:return: :return:
""" """
sort_by = args.pop('sort_by') sort_by = args.pop("sort_by")
sort_dir = args.pop('sort_dir') sort_dir = args.pop("sort_dir")
page = args.pop('page') page = args.pop("page")
count = args.pop('count') count = args.pop("count")
if args.get('user'): if args.get("user"):
user = args.pop('user') user = args.pop("user")
query = find_all(query, model, args) query = find_all(query, model, args)

View File

@ -1,6 +1,7 @@
# This is just Python which means you can inherit and tweak settings # This is just Python which means you can inherit and tweak settings
import os import os
_basedir = os.path.abspath(os.path.dirname(__file__)) _basedir = os.path.abspath(os.path.dirname(__file__))
THREADS_PER_PAGE = 8 THREADS_PER_PAGE = 8

View File

@ -13,12 +13,13 @@ from lemur.auth.service import AuthenticatedResource
from lemur.defaults.schemas import default_output_schema from lemur.defaults.schemas import default_output_schema
mod = Blueprint('default', __name__) mod = Blueprint("default", __name__)
api = Api(mod) api = Api(mod)
class LemurDefaults(AuthenticatedResource): class LemurDefaults(AuthenticatedResource):
""" Defines the 'defaults' endpoint """ """ Defines the 'defaults' endpoint """
def __init__(self): def __init__(self):
super(LemurDefaults) super(LemurDefaults)
@ -59,17 +60,21 @@ class LemurDefaults(AuthenticatedResource):
:statuscode 403: unauthenticated :statuscode 403: unauthenticated
""" """
default_authority = get_by_name(current_app.config.get('LEMUR_DEFAULT_AUTHORITY')) default_authority = get_by_name(
current_app.config.get("LEMUR_DEFAULT_AUTHORITY")
)
return dict( return dict(
country=current_app.config.get('LEMUR_DEFAULT_COUNTRY'), country=current_app.config.get("LEMUR_DEFAULT_COUNTRY"),
state=current_app.config.get('LEMUR_DEFAULT_STATE'), state=current_app.config.get("LEMUR_DEFAULT_STATE"),
location=current_app.config.get('LEMUR_DEFAULT_LOCATION'), location=current_app.config.get("LEMUR_DEFAULT_LOCATION"),
organization=current_app.config.get('LEMUR_DEFAULT_ORGANIZATION'), organization=current_app.config.get("LEMUR_DEFAULT_ORGANIZATION"),
organizational_unit=current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT'), organizational_unit=current_app.config.get(
issuer_plugin=current_app.config.get('LEMUR_DEFAULT_ISSUER_PLUGIN'), "LEMUR_DEFAULT_ORGANIZATIONAL_UNIT"
),
issuer_plugin=current_app.config.get("LEMUR_DEFAULT_ISSUER_PLUGIN"),
authority=default_authority, authority=default_authority,
) )
api.add_resource(LemurDefaults, '/defaults', endpoint='default') api.add_resource(LemurDefaults, "/defaults", endpoint="default")

View File

@ -13,7 +13,7 @@ from lemur.plugins.base import plugins
class Destination(db.Model): class Destination(db.Model):
__tablename__ = 'destinations' __tablename__ = "destinations"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
label = Column(String(32)) label = Column(String(32))
options = Column(JSONType) options = Column(JSONType)

View File

@ -30,7 +30,7 @@ class DestinationOutputSchema(LemurOutputSchema):
@post_dump @post_dump
def fill_object(self, data): def fill_object(self, data):
if data: if data:
data['plugin']['pluginOptions'] = data['options'] data["plugin"]["pluginOptions"] = data["options"]
return data return data

View File

@ -26,10 +26,12 @@ def create(label, plugin_name, options, description=None):
""" """
# remove any sub-plugin objects before try to save the json options # remove any sub-plugin objects before try to save the json options
for option in options: for option in options:
if 'plugin' in option['type']: if "plugin" in option["type"]:
del option['value']['plugin_object'] del option["value"]["plugin_object"]
destination = Destination(label=label, options=options, plugin_name=plugin_name, description=description) destination = Destination(
label=label, options=options, plugin_name=plugin_name, description=description
)
current_app.logger.info("Destination: %s created", label) current_app.logger.info("Destination: %s created", label)
# add the destination as source, to avoid new destinations that are not in source, as long as an AWS destination # add the destination as source, to avoid new destinations that are not in source, as long as an AWS destination
@ -85,7 +87,7 @@ def get_by_label(label):
:param label: :param label:
:return: :return:
""" """
return database.get(Destination, label, field='label') return database.get(Destination, label, field="label")
def get_all(): def get_all():
@ -99,17 +101,19 @@ def get_all():
def render(args): def render(args):
filt = args.pop('filter') filt = args.pop("filter")
certificate_id = args.pop('certificate_id', None) certificate_id = args.pop("certificate_id", None)
if certificate_id: if certificate_id:
query = database.session_query(Destination).join(Certificate, Destination.certificate) query = database.session_query(Destination).join(
Certificate, Destination.certificate
)
query = query.filter(Certificate.id == certificate_id) query = query.filter(Certificate.id == certificate_id)
else: else:
query = database.session_query(Destination) query = database.session_query(Destination)
if filt: if filt:
terms = filt.split(';') terms = filt.split(";")
query = database.filter(query, Destination, terms) query = database.filter(query, Destination, terms)
return database.sort_and_page(query, Destination, args) return database.sort_and_page(query, Destination, args)
@ -122,9 +126,15 @@ def stats(**kwargs):
:param kwargs: :param kwargs:
:return: :return:
""" """
items = database.db.session.query(Destination.label, func.count(certificate_destination_associations.c.certificate_id))\ items = (
.join(certificate_destination_associations)\ database.db.session.query(
.group_by(Destination.label).all() Destination.label,
func.count(certificate_destination_associations.c.certificate_id),
)
.join(certificate_destination_associations)
.group_by(Destination.label)
.all()
)
keys = [] keys = []
values = [] values = []
@ -132,4 +142,4 @@ def stats(**kwargs):
keys.append(key) keys.append(key)
values.append(count) values.append(count)
return {'labels': keys, 'values': values} return {"labels": keys, "values": values}

View File

@ -15,15 +15,20 @@ from lemur.auth.permissions import admin_permission
from lemur.common.utils import paginated_parser from lemur.common.utils import paginated_parser
from lemur.common.schema import validate_schema from lemur.common.schema import validate_schema
from lemur.destinations.schemas import destinations_output_schema, destination_input_schema, destination_output_schema from lemur.destinations.schemas import (
destinations_output_schema,
destination_input_schema,
destination_output_schema,
)
mod = Blueprint('destinations', __name__) mod = Blueprint("destinations", __name__)
api = Api(mod) api = Api(mod)
class DestinationsList(AuthenticatedResource): class DestinationsList(AuthenticatedResource):
""" Defines the 'destinations' endpoint """ """ Defines the 'destinations' endpoint """
def __init__(self): def __init__(self):
self.reqparse = reqparse.RequestParser() self.reqparse = reqparse.RequestParser()
super(DestinationsList, self).__init__() super(DestinationsList, self).__init__()
@ -176,7 +181,12 @@ class DestinationsList(AuthenticatedResource):
:reqheader Authorization: OAuth token to authenticate :reqheader Authorization: OAuth token to authenticate
:statuscode 200: no error :statuscode 200: no error
""" """
return service.create(data['label'], data['plugin']['slug'], data['plugin']['plugin_options'], data['description']) return service.create(
data["label"],
data["plugin"]["slug"],
data["plugin"]["plugin_options"],
data["description"],
)
class Destinations(AuthenticatedResource): class Destinations(AuthenticatedResource):
@ -325,16 +335,22 @@ class Destinations(AuthenticatedResource):
:reqheader Authorization: OAuth token to authenticate :reqheader Authorization: OAuth token to authenticate
:statuscode 200: no error :statuscode 200: no error
""" """
return service.update(destination_id, data['label'], data['plugin']['plugin_options'], data['description']) return service.update(
destination_id,
data["label"],
data["plugin"]["plugin_options"],
data["description"],
)
@admin_permission.require(http_exception=403) @admin_permission.require(http_exception=403)
def delete(self, destination_id): def delete(self, destination_id):
service.delete(destination_id) service.delete(destination_id)
return {'result': True} return {"result": True}
class CertificateDestinations(AuthenticatedResource): class CertificateDestinations(AuthenticatedResource):
""" Defines the 'certificate/<int:certificate_id/destinations'' endpoint """ """ Defines the 'certificate/<int:certificate_id/destinations'' endpoint """
def __init__(self): def __init__(self):
super(CertificateDestinations, self).__init__() super(CertificateDestinations, self).__init__()
@ -401,25 +417,31 @@ class CertificateDestinations(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
args = parser.parse_args() args = parser.parse_args()
args['certificate_id'] = certificate_id args["certificate_id"] = certificate_id
return service.render(args) return service.render(args)
class DestinationsStats(AuthenticatedResource): class DestinationsStats(AuthenticatedResource):
""" Defines the 'certificates' stats endpoint """ """ Defines the 'certificates' stats endpoint """
def __init__(self): def __init__(self):
self.reqparse = reqparse.RequestParser() self.reqparse = reqparse.RequestParser()
super(DestinationsStats, self).__init__() super(DestinationsStats, self).__init__()
def get(self): def get(self):
self.reqparse.add_argument('metric', type=str, location='args') self.reqparse.add_argument("metric", type=str, location="args")
args = self.reqparse.parse_args() args = self.reqparse.parse_args()
items = service.stats(**args) items = service.stats(**args)
return dict(items=items, total=len(items)) return dict(items=items, total=len(items))
api.add_resource(DestinationsList, '/destinations', endpoint='destinations') api.add_resource(DestinationsList, "/destinations", endpoint="destinations")
api.add_resource(Destinations, '/destinations/<int:destination_id>', endpoint='destination') api.add_resource(
api.add_resource(CertificateDestinations, '/certificates/<int:certificate_id>/destinations', Destinations, "/destinations/<int:destination_id>", endpoint="destination"
endpoint='certificateDestinations') )
api.add_resource(DestinationsStats, '/destinations/stats', endpoint='destinationStats') api.add_resource(
CertificateDestinations,
"/certificates/<int:certificate_id>/destinations",
endpoint="certificateDestinations",
)
api.add_resource(DestinationsStats, "/destinations/stats", endpoint="destinationStats")

View File

@ -5,7 +5,9 @@ from lemur.dns_providers.service import get_all_dns_providers, set_domains
from lemur.extensions import metrics from lemur.extensions import metrics
from lemur.plugins.base import plugins from lemur.plugins.base import plugins
manager = Manager(usage="Iterates through all DNS providers and sets DNS zones in the database.") manager = Manager(
usage="Iterates through all DNS providers and sets DNS zones in the database."
)
@manager.command @manager.command
@ -27,5 +29,5 @@ def get_all_zones():
status = SUCCESS_METRIC_STATUS status = SUCCESS_METRIC_STATUS
metrics.send('get_all_zones', 'counter', 1, metric_tags={'status': status}) metrics.send("get_all_zones", "counter", 1, metric_tags={"status": status})
print("[+] Done with dns provider zone lookup and configuration.") print("[+] Done with dns provider zone lookup and configuration.")

View File

@ -9,22 +9,23 @@ from lemur.utils import Vault
class DnsProvider(db.Model): class DnsProvider(db.Model):
__tablename__ = 'dns_providers' __tablename__ = "dns_providers"
id = Column( id = Column(Integer(), primary_key=True)
Integer(),
primary_key=True,
)
name = Column(String(length=256), unique=True, nullable=True) name = Column(String(length=256), unique=True, nullable=True)
description = Column(Text(), nullable=True) description = Column(Text(), nullable=True)
provider_type = Column(String(length=256), nullable=True) provider_type = Column(String(length=256), nullable=True)
credentials = Column(Vault, nullable=True) credentials = Column(Vault, nullable=True)
api_endpoint = Column(String(length=256), nullable=True) api_endpoint = Column(String(length=256), nullable=True)
date_created = Column(ArrowType(), server_default=text('now()'), nullable=False) date_created = Column(ArrowType(), server_default=text("now()"), nullable=False)
status = Column(String(length=128), nullable=True) status = Column(String(length=128), nullable=True)
options = Column(JSON, nullable=True) options = Column(JSON, nullable=True)
domains = Column(JSON, nullable=True) domains = Column(JSON, nullable=True)
certificates = relationship("Certificate", backref='dns_provider', foreign_keys='Certificate.dns_provider_id', certificates = relationship(
lazy='dynamic') "Certificate",
backref="dns_provider",
foreign_keys="Certificate.dns_provider_id",
lazy="dynamic",
)
def __init__(self, name, description, provider_type, credentials): def __init__(self, name, description, provider_type, credentials):
self.name = name self.name = name

View File

@ -49,7 +49,9 @@ def get_friendly(dns_provider_id):
} }
if dns_provider.provider_type == "route53": if dns_provider.provider_type == "route53":
dns_provider_friendly["account_id"] = json.loads(dns_provider.credentials).get("account_id") dns_provider_friendly["account_id"] = json.loads(dns_provider.credentials).get(
"account_id"
)
return dns_provider_friendly return dns_provider_friendly
@ -64,40 +66,40 @@ def delete(dns_provider_id):
def get_types(): def get_types():
provider_config = current_app.config.get( provider_config = current_app.config.get(
'ACME_DNS_PROVIDER_TYPES', "ACME_DNS_PROVIDER_TYPES",
{"items": [
{ {
'name': 'route53', "items": [
'requirements': [
{ {
'name': 'account_id', "name": "route53",
'type': 'int', "requirements": [
'required': True, {
'helpMessage': 'AWS Account number' "name": "account_id",
"type": "int",
"required": True,
"helpMessage": "AWS Account number",
}
],
}, },
{
"name": "cloudflare",
"requirements": [
{
"name": "email",
"type": "str",
"required": True,
"helpMessage": "Cloudflare Email",
},
{
"name": "key",
"type": "str",
"required": True,
"helpMessage": "Cloudflare Key",
},
],
},
{"name": "dyn"},
] ]
}, },
{
'name': 'cloudflare',
'requirements': [
{
'name': 'email',
'type': 'str',
'required': True,
'helpMessage': 'Cloudflare Email'
},
{
'name': 'key',
'type': 'str',
'required': True,
'helpMessage': 'Cloudflare Key'
},
]
},
{
'name': 'dyn',
},
]}
) )
if not provider_config: if not provider_config:
raise Exception("No DNS Provider configuration specified.") raise Exception("No DNS Provider configuration specified.")

View File

@ -13,9 +13,12 @@ from lemur.auth.service import AuthenticatedResource
from lemur.common.schema import validate_schema from lemur.common.schema import validate_schema
from lemur.common.utils import paginated_parser from lemur.common.utils import paginated_parser
from lemur.dns_providers import service from lemur.dns_providers import service
from lemur.dns_providers.schemas import dns_provider_output_schema, dns_provider_input_schema from lemur.dns_providers.schemas import (
dns_provider_output_schema,
dns_provider_input_schema,
)
mod = Blueprint('dns_providers', __name__) mod = Blueprint("dns_providers", __name__)
api = Api(mod) api = Api(mod)
@ -71,12 +74,12 @@ class DnsProvidersList(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
parser.add_argument('dns_provider_id', type=int, location='args') parser.add_argument("dns_provider_id", type=int, location="args")
parser.add_argument('name', type=str, location='args') parser.add_argument("name", type=str, location="args")
parser.add_argument('type', type=str, location='args') parser.add_argument("type", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
args['user'] = g.user args["user"] = g.user
return service.render(args) return service.render(args)
@validate_schema(dns_provider_input_schema, None) @validate_schema(dns_provider_input_schema, None)
@ -152,7 +155,7 @@ class DnsProviders(AuthenticatedResource):
@admin_permission.require(http_exception=403) @admin_permission.require(http_exception=403)
def delete(self, dns_provider_id): def delete(self, dns_provider_id):
service.delete(dns_provider_id) service.delete(dns_provider_id)
return {'result': True} return {"result": True}
class DnsProviderOptions(AuthenticatedResource): class DnsProviderOptions(AuthenticatedResource):
@ -166,6 +169,10 @@ class DnsProviderOptions(AuthenticatedResource):
return service.get_types() return service.get_types()
api.add_resource(DnsProvidersList, '/dns_providers', endpoint='dns_providers') api.add_resource(DnsProvidersList, "/dns_providers", endpoint="dns_providers")
api.add_resource(DnsProviders, '/dns_providers/<int:dns_provider_id>', endpoint='dns_provider') api.add_resource(
api.add_resource(DnsProviderOptions, '/dns_provider_options', endpoint='dns_provider_options') DnsProviders, "/dns_providers/<int:dns_provider_id>", endpoint="dns_provider"
)
api.add_resource(
DnsProviderOptions, "/dns_provider_options", endpoint="dns_provider_options"
)

View File

@ -13,11 +13,14 @@ from lemur.database import db
class Domain(db.Model): class Domain(db.Model):
__tablename__ = 'domains' __tablename__ = "domains"
__table_args__ = ( __table_args__ = (
Index('ix_domains_name_gin', "name", Index(
"ix_domains_name_gin",
"name",
postgresql_ops={"name": "gin_trgm_ops"}, postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using='gin'), postgresql_using="gin",
),
) )
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
name = Column(String(256), index=True) name = Column(String(256), index=True)

View File

@ -77,11 +77,11 @@ def render(args):
:return: :return:
""" """
query = database.session_query(Domain) query = database.session_query(Domain)
filt = args.pop('filter') filt = args.pop("filter")
certificate_id = args.pop('certificate_id', None) certificate_id = args.pop("certificate_id", None)
if filt: if filt:
terms = filt.split(';') terms = filt.split(";")
query = database.filter(query, Domain, terms) query = database.filter(query, Domain, terms)
if certificate_id: if certificate_id:

View File

@ -17,14 +17,19 @@ from lemur.auth.permissions import SensitiveDomainPermission
from lemur.common.schema import validate_schema from lemur.common.schema import validate_schema
from lemur.common.utils import paginated_parser from lemur.common.utils import paginated_parser
from lemur.domains.schemas import domain_input_schema, domain_output_schema, domains_output_schema from lemur.domains.schemas import (
domain_input_schema,
domain_output_schema,
domains_output_schema,
)
mod = Blueprint('domains', __name__) mod = Blueprint("domains", __name__)
api = Api(mod) api = Api(mod)
class DomainsList(AuthenticatedResource): class DomainsList(AuthenticatedResource):
""" Defines the 'domains' endpoint """ """ Defines the 'domains' endpoint """
def __init__(self): def __init__(self):
super(DomainsList, self).__init__() super(DomainsList, self).__init__()
@ -123,7 +128,7 @@ class DomainsList(AuthenticatedResource):
:statuscode 200: no error :statuscode 200: no error
:statuscode 403: unauthenticated :statuscode 403: unauthenticated
""" """
return service.create(data['name'], data['sensitive']) return service.create(data["name"], data["sensitive"])
class Domains(AuthenticatedResource): class Domains(AuthenticatedResource):
@ -205,13 +210,14 @@ class Domains(AuthenticatedResource):
:statuscode 403: unauthenticated :statuscode 403: unauthenticated
""" """
if SensitiveDomainPermission().can(): if SensitiveDomainPermission().can():
return service.update(domain_id, data['name'], data['sensitive']) return service.update(domain_id, data["name"], data["sensitive"])
return dict(message='You are not authorized to modify this domain'), 403 return dict(message="You are not authorized to modify this domain"), 403
class CertificateDomains(AuthenticatedResource): class CertificateDomains(AuthenticatedResource):
""" Defines the 'domains' endpoint """ """ Defines the 'domains' endpoint """
def __init__(self): def __init__(self):
super(CertificateDomains, self).__init__() super(CertificateDomains, self).__init__()
@ -265,10 +271,14 @@ class CertificateDomains(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
args = parser.parse_args() args = parser.parse_args()
args['certificate_id'] = certificate_id args["certificate_id"] = certificate_id
return service.render(args) return service.render(args)
api.add_resource(DomainsList, '/domains', endpoint='domains') api.add_resource(DomainsList, "/domains", endpoint="domains")
api.add_resource(Domains, '/domains/<int:domain_id>', endpoint='domain') api.add_resource(Domains, "/domains/<int:domain_id>", endpoint="domain")
api.add_resource(CertificateDomains, '/certificates/<int:certificate_id>/domains', endpoint='certificateDomains') api.add_resource(
CertificateDomains,
"/certificates/<int:certificate_id>/domains",
endpoint="certificateDomains",
)

View File

@ -21,7 +21,14 @@ from lemur.endpoints.models import Endpoint
manager = Manager(usage="Handles all endpoint related tasks.") manager = Manager(usage="Handles all endpoint related tasks.")
@manager.option('-ttl', '--time-to-live', type=int, dest='ttl', default=2, help='Time in hours, which endpoint has not been refreshed to remove the endpoint.') @manager.option(
"-ttl",
"--time-to-live",
type=int,
dest="ttl",
default=2,
help="Time in hours, which endpoint has not been refreshed to remove the endpoint.",
)
def expire(ttl): def expire(ttl):
""" """
Removed all endpoints that have not been recently updated. Removed all endpoints that have not been recently updated.
@ -31,12 +38,18 @@ def expire(ttl):
try: try:
now = arrow.utcnow() now = arrow.utcnow()
expiration = now - timedelta(hours=ttl) expiration = now - timedelta(hours=ttl)
endpoints = database.session_query(Endpoint).filter(cast(Endpoint.last_updated, ArrowType) <= expiration) endpoints = database.session_query(Endpoint).filter(
cast(Endpoint.last_updated, ArrowType) <= expiration
)
for endpoint in endpoints: for endpoint in endpoints:
print("[!] Expiring endpoint: {name} Last Updated: {last_updated}".format(name=endpoint.name, last_updated=endpoint.last_updated)) print(
"[!] Expiring endpoint: {name} Last Updated: {last_updated}".format(
name=endpoint.name, last_updated=endpoint.last_updated
)
)
database.delete(endpoint) database.delete(endpoint)
metrics.send('endpoint_expired', 'counter', 1) metrics.send("endpoint_expired", "counter", 1)
print("[+] Finished expiration.") print("[+] Finished expiration.")
except Exception as e: except Exception as e:

View File

@ -20,15 +20,11 @@ from lemur.database import db
from lemur.models import policies_ciphers from lemur.models import policies_ciphers
BAD_CIPHERS = [ BAD_CIPHERS = ["Protocol-SSLv3", "Protocol-SSLv2", "Protocol-TLSv1"]
'Protocol-SSLv3',
'Protocol-SSLv2',
'Protocol-TLSv1'
]
class Cipher(db.Model): class Cipher(db.Model):
__tablename__ = 'ciphers' __tablename__ = "ciphers"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
name = Column(String(128), nullable=False) name = Column(String(128), nullable=False)
@ -38,23 +34,18 @@ class Cipher(db.Model):
@deprecated.expression @deprecated.expression
def deprecated(cls): def deprecated(cls):
return case( return case([(cls.name in BAD_CIPHERS, True)], else_=False)
[
(cls.name in BAD_CIPHERS, True)
],
else_=False
)
class Policy(db.Model): class Policy(db.Model):
___tablename__ = 'policies' ___tablename__ = "policies"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
name = Column(String(128), nullable=True) name = Column(String(128), nullable=True)
ciphers = relationship('Cipher', secondary=policies_ciphers, backref='policy') ciphers = relationship("Cipher", secondary=policies_ciphers, backref="policy")
class Endpoint(db.Model): class Endpoint(db.Model):
__tablename__ = 'endpoints' __tablename__ = "endpoints"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
owner = Column(String(128)) owner = Column(String(128))
name = Column(String(128)) name = Column(String(128))
@ -62,16 +53,18 @@ class Endpoint(db.Model):
type = Column(String(128)) type = Column(String(128))
active = Column(Boolean, default=True) active = Column(Boolean, default=True)
port = Column(Integer) port = Column(Integer)
policy_id = Column(Integer, ForeignKey('policy.id')) policy_id = Column(Integer, ForeignKey("policy.id"))
policy = relationship('Policy', backref='endpoint') policy = relationship("Policy", backref="endpoint")
certificate_id = Column(Integer, ForeignKey('certificates.id')) certificate_id = Column(Integer, ForeignKey("certificates.id"))
source_id = Column(Integer, ForeignKey('sources.id')) source_id = Column(Integer, ForeignKey("sources.id"))
sensitive = Column(Boolean, default=False) sensitive = Column(Boolean, default=False)
source = relationship('Source', back_populates='endpoints') source = relationship("Source", back_populates="endpoints")
last_updated = Column(ArrowType, default=arrow.utcnow, nullable=False) last_updated = Column(ArrowType, default=arrow.utcnow, nullable=False)
date_created = Column(ArrowType, default=arrow.utcnow, onupdate=arrow.utcnow, nullable=False) date_created = Column(
ArrowType, default=arrow.utcnow, onupdate=arrow.utcnow, nullable=False
)
replaced = association_proxy('certificate', 'replaced') replaced = association_proxy("certificate", "replaced")
@property @property
def issues(self): def issues(self):
@ -79,13 +72,30 @@ class Endpoint(db.Model):
for cipher in self.policy.ciphers: for cipher in self.policy.ciphers:
if cipher.deprecated: if cipher.deprecated:
issues.append({'name': 'deprecated cipher', 'value': '{0} has been deprecated consider removing it.'.format(cipher.name)}) issues.append(
{
"name": "deprecated cipher",
"value": "{0} has been deprecated consider removing it.".format(
cipher.name
),
}
)
if self.certificate.expired: if self.certificate.expired:
issues.append({'name': 'expired certificate', 'value': 'There is an expired certificate attached to this endpoint consider replacing it.'}) issues.append(
{
"name": "expired certificate",
"value": "There is an expired certificate attached to this endpoint consider replacing it.",
}
)
if self.certificate.revoked: if self.certificate.revoked:
issues.append({'name': 'revoked', 'value': 'There is a revoked certificate attached to this endpoint consider replacing it.'}) issues.append(
{
"name": "revoked",
"value": "There is a revoked certificate attached to this endpoint consider replacing it.",
}
)
return issues return issues

View File

@ -46,7 +46,7 @@ def get_by_name(name):
:param name: :param name:
:return: :return:
""" """
return database.get(Endpoint, name, field='name') return database.get(Endpoint, name, field="name")
def get_by_dnsname(dnsname): def get_by_dnsname(dnsname):
@ -56,7 +56,7 @@ def get_by_dnsname(dnsname):
:param dnsname: :param dnsname:
:return: :return:
""" """
return database.get(Endpoint, dnsname, field='dnsname') return database.get(Endpoint, dnsname, field="dnsname")
def get_by_dnsname_and_port(dnsname, port): def get_by_dnsname_and_port(dnsname, port):
@ -66,7 +66,11 @@ def get_by_dnsname_and_port(dnsname, port):
:param port: :param port:
:return: :return:
""" """
return Endpoint.query.filter(Endpoint.dnsname == dnsname).filter(Endpoint.port == port).scalar() return (
Endpoint.query.filter(Endpoint.dnsname == dnsname)
.filter(Endpoint.port == port)
.scalar()
)
def get_by_source(source_label): def get_by_source(source_label):
@ -95,12 +99,14 @@ def create(**kwargs):
""" """
endpoint = Endpoint(**kwargs) endpoint = Endpoint(**kwargs)
database.create(endpoint) database.create(endpoint)
metrics.send('endpoint_added', 'counter', 1, metric_tags={'source': endpoint.source.label}) metrics.send(
"endpoint_added", "counter", 1, metric_tags={"source": endpoint.source.label}
)
return endpoint return endpoint
def get_or_create_policy(**kwargs): def get_or_create_policy(**kwargs):
policy = database.get(Policy, kwargs['name'], field='name') policy = database.get(Policy, kwargs["name"], field="name")
if not policy: if not policy:
policy = Policy(**kwargs) policy = Policy(**kwargs)
@ -110,7 +116,7 @@ def get_or_create_policy(**kwargs):
def get_or_create_cipher(**kwargs): def get_or_create_cipher(**kwargs):
cipher = database.get(Cipher, kwargs['name'], field='name') cipher = database.get(Cipher, kwargs["name"], field="name")
if not cipher: if not cipher:
cipher = Cipher(**kwargs) cipher = Cipher(**kwargs)
@ -122,11 +128,13 @@ def get_or_create_cipher(**kwargs):
def update(endpoint_id, **kwargs): def update(endpoint_id, **kwargs):
endpoint = database.get(Endpoint, endpoint_id) endpoint = database.get(Endpoint, endpoint_id)
endpoint.policy = kwargs['policy'] endpoint.policy = kwargs["policy"]
endpoint.certificate = kwargs['certificate'] endpoint.certificate = kwargs["certificate"]
endpoint.source = kwargs['source'] endpoint.source = kwargs["source"]
endpoint.last_updated = arrow.utcnow() endpoint.last_updated = arrow.utcnow()
metrics.send('endpoint_updated', 'counter', 1, metric_tags={'source': endpoint.source.label}) metrics.send(
"endpoint_updated", "counter", 1, metric_tags={"source": endpoint.source.label}
)
database.update(endpoint) database.update(endpoint)
return endpoint return endpoint
@ -138,19 +146,17 @@ def render(args):
:return: :return:
""" """
query = database.session_query(Endpoint) query = database.session_query(Endpoint)
filt = args.pop('filter') filt = args.pop("filter")
if filt: if filt:
terms = filt.split(';') terms = filt.split(";")
if 'active' in filt: # this is really weird but strcmp seems to not work here?? if "active" in filt: # this is really weird but strcmp seems to not work here??
query = query.filter(Endpoint.active == truthiness(terms[1])) query = query.filter(Endpoint.active == truthiness(terms[1]))
elif 'port' in filt: elif "port" in filt:
if terms[1] != 'null': # ng-table adds 'null' if a number is removed if terms[1] != "null": # ng-table adds 'null' if a number is removed
query = query.filter(Endpoint.port == terms[1]) query = query.filter(Endpoint.port == terms[1])
elif 'ciphers' in filt: elif "ciphers" in filt:
query = query.filter( query = query.filter(Cipher.name == terms[1])
Cipher.name == terms[1]
)
else: else:
query = database.filter(query, Endpoint, terms) query = database.filter(query, Endpoint, terms)
@ -164,7 +170,7 @@ def stats(**kwargs):
:param kwargs: :param kwargs:
:return: :return:
""" """
attr = getattr(Endpoint, kwargs.get('metric')) attr = getattr(Endpoint, kwargs.get("metric"))
query = database.db.session.query(attr, func.count(attr)) query = database.db.session.query(attr, func.count(attr))
items = query.group_by(attr).all() items = query.group_by(attr).all()
@ -175,4 +181,4 @@ def stats(**kwargs):
keys.append(key) keys.append(key)
values.append(count) values.append(count)
return {'labels': keys, 'values': values} return {"labels": keys, "values": values}

View File

@ -16,12 +16,13 @@ from lemur.endpoints import service
from lemur.endpoints.schemas import endpoint_output_schema, endpoints_output_schema from lemur.endpoints.schemas import endpoint_output_schema, endpoints_output_schema
mod = Blueprint('endpoints', __name__) mod = Blueprint("endpoints", __name__)
api = Api(mod) api = Api(mod)
class EndpointsList(AuthenticatedResource): class EndpointsList(AuthenticatedResource):
""" Defines the 'endpoints' endpoint """ """ Defines the 'endpoints' endpoint """
def __init__(self): def __init__(self):
self.reqparse = reqparse.RequestParser() self.reqparse = reqparse.RequestParser()
super(EndpointsList, self).__init__() super(EndpointsList, self).__init__()
@ -63,7 +64,7 @@ class EndpointsList(AuthenticatedResource):
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
args = parser.parse_args() args = parser.parse_args()
args['user'] = g.current_user args["user"] = g.current_user
return service.render(args) return service.render(args)
@ -103,5 +104,5 @@ class Endpoints(AuthenticatedResource):
return service.get(endpoint_id) return service.get(endpoint_id)
api.add_resource(EndpointsList, '/endpoints', endpoint='endpoints') api.add_resource(EndpointsList, "/endpoints", endpoint="endpoints")
api.add_resource(Endpoints, '/endpoints/<int:endpoint_id>', endpoint='endpoint') api.add_resource(Endpoints, "/endpoints/<int:endpoint_id>", endpoint="endpoint")

View File

@ -21,7 +21,9 @@ class DuplicateError(LemurException):
class InvalidListener(LemurException): class InvalidListener(LemurException):
def __str__(self): def __str__(self):
return repr("Invalid listener, ensure you select a certificate if you are using a secure protocol") return repr(
"Invalid listener, ensure you select a certificate if you are using a secure protocol"
)
class AttrNotFound(LemurException): class AttrNotFound(LemurException):

View File

@ -15,25 +15,33 @@ class SQLAlchemy(SA):
db = SQLAlchemy() db = SQLAlchemy()
from flask_migrate import Migrate from flask_migrate import Migrate
migrate = Migrate() migrate = Migrate()
from flask_bcrypt import Bcrypt from flask_bcrypt import Bcrypt
bcrypt = Bcrypt() bcrypt = Bcrypt()
from flask_principal import Principal from flask_principal import Principal
principal = Principal(use_sessions=False) principal = Principal(use_sessions=False)
from flask_mail import Mail from flask_mail import Mail
smtp_mail = Mail() smtp_mail = Mail()
from lemur.metrics import Metrics from lemur.metrics import Metrics
metrics = Metrics() metrics = Metrics()
from raven.contrib.flask import Sentry from raven.contrib.flask import Sentry
sentry = Sentry() sentry = Sentry()
from blinker import Namespace from blinker import Namespace
signals = Namespace() signals = Namespace()
from flask_cors import CORS from flask_cors import CORS
cors = CORS() cors = CORS()

View File

@ -13,20 +13,21 @@ import os
import imp import imp
import errno import errno
import pkg_resources import pkg_resources
import socket
from logging import Formatter, StreamHandler from logging import Formatter, StreamHandler
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
from flask import Flask from flask import Flask
from flask_replicated import FlaskReplicated
import logmatic
from lemur.certificates.hooks import activate_debug_dump from lemur.certificates.hooks import activate_debug_dump
from lemur.common.health import mod as health from lemur.common.health import mod as health
from lemur.extensions import db, migrate, principal, smtp_mail, metrics, sentry, cors from lemur.extensions import db, migrate, principal, smtp_mail, metrics, sentry, cors
DEFAULT_BLUEPRINTS = ( DEFAULT_BLUEPRINTS = (health,)
health,
)
API_VERSION = 1 API_VERSION = 1
@ -53,6 +54,7 @@ def create_app(app_name=None, blueprints=None, config=None):
configure_blueprints(app, blueprints) configure_blueprints(app, blueprints)
configure_extensions(app) configure_extensions(app)
configure_logging(app) configure_logging(app)
configure_database(app)
install_plugins(app) install_plugins(app)
@app.teardown_appcontext @app.teardown_appcontext
@ -71,16 +73,17 @@ def from_file(file_path, silent=False):
:param file_path: :param file_path:
:param silent: :param silent:
""" """
d = imp.new_module('config') d = imp.new_module("config")
d.__file__ = file_path d.__file__ = file_path
try: try:
with open(file_path) as config_file: with open(file_path) as config_file:
exec(compile(config_file.read(), # nosec: config file safe exec( # nosec: config file safe
file_path, 'exec'), d.__dict__) compile(config_file.read(), file_path, "exec"), d.__dict__
)
except IOError as e: except IOError as e:
if silent and e.errno in (errno.ENOENT, errno.EISDIR): if silent and e.errno in (errno.ENOENT, errno.EISDIR):
return False return False
e.strerror = 'Unable to load configuration file (%s)' % e.strerror e.strerror = "Unable to load configuration file (%s)" % e.strerror
raise raise
return d return d
@ -94,8 +97,8 @@ def configure_app(app, config=None):
:return: :return:
""" """
# respect the config first # respect the config first
if config and config != 'None': if config and config != "None":
app.config['CONFIG_PATH'] = config app.config["CONFIG_PATH"] = config
app.config.from_object(from_file(config)) app.config.from_object(from_file(config))
else: else:
try: try:
@ -103,12 +106,21 @@ def configure_app(app, config=None):
except RuntimeError: except RuntimeError:
# look in default paths # look in default paths
if os.path.isfile(os.path.expanduser("~/.lemur/lemur.conf.py")): if os.path.isfile(os.path.expanduser("~/.lemur/lemur.conf.py")):
app.config.from_object(from_file(os.path.expanduser("~/.lemur/lemur.conf.py"))) app.config.from_object(
from_file(os.path.expanduser("~/.lemur/lemur.conf.py"))
)
else: else:
app.config.from_object(from_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'default.conf.py'))) app.config.from_object(
from_file(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"default.conf.py",
)
)
)
# we don't use this # we don't use this
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
def configure_extensions(app): def configure_extensions(app):
@ -125,9 +137,15 @@ def configure_extensions(app):
metrics.init_app(app) metrics.init_app(app)
sentry.init_app(app) sentry.init_app(app)
if app.config['CORS']: if app.config["CORS"]:
app.config['CORS_HEADERS'] = 'Content-Type' app.config["CORS_HEADERS"] = "Content-Type"
cors.init_app(app, resources=r'/api/*', headers='Content-Type', origin='*', supports_credentials=True) cors.init_app(
app,
resources=r"/api/*",
headers="Content-Type",
origin="*",
supports_credentials=True,
)
def configure_blueprints(app, blueprints): def configure_blueprints(app, blueprints):
@ -142,28 +160,41 @@ def configure_blueprints(app, blueprints):
app.register_blueprint(blueprint, url_prefix="/api/{0}".format(API_VERSION)) app.register_blueprint(blueprint, url_prefix="/api/{0}".format(API_VERSION))
def configure_database(app):
if app.config.get("SQLALCHEMY_ENABLE_FLASK_REPLICATED"):
FlaskReplicated(app)
def configure_logging(app): def configure_logging(app):
""" """
Sets up application wide logging. Sets up application wide logging.
:param app: :param app:
""" """
handler = RotatingFileHandler(app.config.get('LOG_FILE', 'lemur.log'), maxBytes=10000000, backupCount=100) handler = RotatingFileHandler(
app.config.get("LOG_FILE", "lemur.log"), maxBytes=10000000, backupCount=100
)
handler.setFormatter(Formatter( handler.setFormatter(
'%(asctime)s %(levelname)s: %(message)s ' Formatter(
'[in %(pathname)s:%(lineno)d]' "%(asctime)s %(levelname)s: %(message)s " "[in %(pathname)s:%(lineno)d]"
)) )
)
handler.setLevel(app.config.get('LOG_LEVEL', 'DEBUG')) if app.config.get("LOG_JSON", False):
app.logger.setLevel(app.config.get('LOG_LEVEL', 'DEBUG')) handler.setFormatter(
logmatic.JsonFormatter(extra={"hostname": socket.gethostname()})
)
handler.setLevel(app.config.get("LOG_LEVEL", "DEBUG"))
app.logger.setLevel(app.config.get("LOG_LEVEL", "DEBUG"))
app.logger.addHandler(handler) app.logger.addHandler(handler)
stream_handler = StreamHandler() stream_handler = StreamHandler()
stream_handler.setLevel(app.config.get('LOG_LEVEL', 'DEBUG')) stream_handler.setLevel(app.config.get("LOG_LEVEL", "DEBUG"))
app.logger.addHandler(stream_handler) app.logger.addHandler(stream_handler)
if app.config.get('DEBUG_DUMP', False): if app.config.get("DEBUG_DUMP", False):
activate_debug_dump() activate_debug_dump()
@ -176,17 +207,21 @@ def install_plugins(app):
""" """
from lemur.plugins import plugins from lemur.plugins import plugins
from lemur.plugins.base import register from lemur.plugins.base import register
# entry_points={ # entry_points={
# 'lemur.plugins': [ # 'lemur.plugins': [
# 'verisign = lemur_verisign.plugin:VerisignPlugin' # 'verisign = lemur_verisign.plugin:VerisignPlugin'
# ], # ],
# }, # },
for ep in pkg_resources.iter_entry_points('lemur.plugins'): for ep in pkg_resources.iter_entry_points("lemur.plugins"):
try: try:
plugin = ep.load() plugin = ep.load()
except Exception: except Exception:
import traceback import traceback
app.logger.error("Failed to load plugin %r:\n%s\n" % (ep.name, traceback.format_exc()))
app.logger.error(
"Failed to load plugin %r:\n%s\n" % (ep.name, traceback.format_exc())
)
else: else:
register(plugin) register(plugin)
@ -196,6 +231,9 @@ def install_plugins(app):
try: try:
plugins.get(slug) plugins.get(slug)
except KeyError: except KeyError:
raise Exception("Unable to location notification plugin: {slug}. Ensure that " raise Exception(
"LEMUR_DEFAULT_NOTIFICATION_PLUGIN is set to a valid and installed notification plugin." "Unable to location notification plugin: {slug}. Ensure that "
.format(slug=slug)) "LEMUR_DEFAULT_NOTIFICATION_PLUGIN is set to a valid and installed notification plugin.".format(
slug=slug
)
)

View File

@ -15,9 +15,19 @@ from lemur.database import db
class Log(db.Model): class Log(db.Model):
__tablename__ = 'logs' __tablename__ = "logs"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
certificate_id = Column(Integer, ForeignKey('certificates.id')) certificate_id = Column(Integer, ForeignKey("certificates.id"))
log_type = Column(Enum('key_view', 'create_cert', 'update_cert', 'revoke_cert', 'delete_cert', name='log_type'), nullable=False) log_type = Column(
Enum(
"key_view",
"create_cert",
"update_cert",
"revoke_cert",
"delete_cert",
name="log_type",
),
nullable=False,
)
logged_at = Column(ArrowType(), PassiveDefault(func.now()), nullable=False) logged_at = Column(ArrowType(), PassiveDefault(func.now()), nullable=False)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False) user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

View File

@ -24,7 +24,11 @@ def create(user, type, certificate=None):
:param certificate: :param certificate:
:return: :return:
""" """
current_app.logger.info("[lemur-audit] action: {0}, user: {1}, certificate: {2}.".format(type, user.email, certificate.name)) current_app.logger.info(
"[lemur-audit] action: {0}, user: {1}, certificate: {2}.".format(
type, user.email, certificate.name
)
)
view = Log(user_id=user.id, log_type=type, certificate_id=certificate.id) view = Log(user_id=user.id, log_type=type, certificate_id=certificate.id)
database.add(view) database.add(view)
database.commit() database.commit()
@ -50,20 +54,22 @@ def render(args):
""" """
query = database.session_query(Log) query = database.session_query(Log)
filt = args.pop('filter') filt = args.pop("filter")
if filt: if filt:
terms = filt.split(';') terms = filt.split(";")
if 'certificate.name' in terms: if "certificate.name" in terms:
sub_query = database.session_query(Certificate.id)\ sub_query = database.session_query(Certificate.id).filter(
.filter(Certificate.name.ilike('%{0}%'.format(terms[1]))) Certificate.name.ilike("%{0}%".format(terms[1]))
)
query = query.filter(Log.certificate_id.in_(sub_query)) query = query.filter(Log.certificate_id.in_(sub_query))
elif 'user.email' in terms: elif "user.email" in terms:
sub_query = database.session_query(User.id)\ sub_query = database.session_query(User.id).filter(
.filter(User.email.ilike('%{0}%'.format(terms[1]))) User.email.ilike("%{0}%".format(terms[1]))
)
query = query.filter(Log.user_id.in_(sub_query)) query = query.filter(Log.user_id.in_(sub_query))

View File

@ -17,12 +17,13 @@ from lemur.logs.schemas import logs_output_schema
from lemur.logs import service from lemur.logs import service
mod = Blueprint('logs', __name__) mod = Blueprint("logs", __name__)
api = Api(mod) api = Api(mod)
class LogsList(AuthenticatedResource): class LogsList(AuthenticatedResource):
""" Defines the 'logs' endpoint """ """ Defines the 'logs' endpoint """
def __init__(self): def __init__(self):
self.reqparse = reqparse.RequestParser() self.reqparse = reqparse.RequestParser()
super(LogsList, self).__init__() super(LogsList, self).__init__()
@ -65,10 +66,10 @@ class LogsList(AuthenticatedResource):
:statuscode 200: no error :statuscode 200: no error
""" """
parser = paginated_parser.copy() parser = paginated_parser.copy()
parser.add_argument('owner', type=str, location='args') parser.add_argument("owner", type=str, location="args")
parser.add_argument('id', type=str, location='args') parser.add_argument("id", type=str, location="args")
args = parser.parse_args() args = parser.parse_args()
return service.render(args) return service.render(args)
api.add_resource(LogsList, '/logs', endpoint='logs') api.add_resource(LogsList, "/logs", endpoint="logs")

View File

@ -52,24 +52,24 @@ from lemur.dns_providers.models import DnsProvider # noqa
from sqlalchemy.sql import text from sqlalchemy.sql import text
manager = Manager(create_app) manager = Manager(create_app)
manager.add_option('-c', '--config', dest='config_path', required=False) manager.add_option("-c", "--config", dest="config_path", required=False)
migrate = Migrate(create_app) migrate = Migrate(create_app)
REQUIRED_VARIABLES = [ REQUIRED_VARIABLES = [
'LEMUR_SECURITY_TEAM_EMAIL', "LEMUR_SECURITY_TEAM_EMAIL",
'LEMUR_DEFAULT_ORGANIZATIONAL_UNIT', "LEMUR_DEFAULT_ORGANIZATIONAL_UNIT",
'LEMUR_DEFAULT_ORGANIZATION', "LEMUR_DEFAULT_ORGANIZATION",
'LEMUR_DEFAULT_LOCATION', "LEMUR_DEFAULT_LOCATION",
'LEMUR_DEFAULT_COUNTRY', "LEMUR_DEFAULT_COUNTRY",
'LEMUR_DEFAULT_STATE', "LEMUR_DEFAULT_STATE",
'SQLALCHEMY_DATABASE_URI' "SQLALCHEMY_DATABASE_URI",
] ]
KEY_LENGTH = 40 KEY_LENGTH = 40
DEFAULT_CONFIG_PATH = '~/.lemur/lemur.conf.py' DEFAULT_CONFIG_PATH = "~/.lemur/lemur.conf.py"
DEFAULT_SETTINGS = 'lemur.conf.server' DEFAULT_SETTINGS = "lemur.conf.server"
SETTINGS_ENVVAR = 'LEMUR_CONF' SETTINGS_ENVVAR = "LEMUR_CONF"
CONFIG_TEMPLATE = """ CONFIG_TEMPLATE = """
# This is just Python which means you can inherit and tweak settings # This is just Python which means you can inherit and tweak settings
@ -144,9 +144,9 @@ SQLALCHEMY_DATABASE_URI = 'postgresql://lemur:lemur@localhost:5432/lemur'
@MigrateCommand.command @MigrateCommand.command
def create(): def create():
database.db.engine.execute(text('CREATE EXTENSION IF NOT EXISTS pg_trgm')) database.db.engine.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
database.db.create_all() database.db.create_all()
stamp(revision='head') stamp(revision="head")
@MigrateCommand.command @MigrateCommand.command
@ -174,9 +174,9 @@ def generate_settings():
output = CONFIG_TEMPLATE.format( output = CONFIG_TEMPLATE.format(
# we use Fernet.generate_key to make sure that the key length is # we use Fernet.generate_key to make sure that the key length is
# compatible with Fernet # compatible with Fernet
encryption_key=Fernet.generate_key().decode('utf-8'), encryption_key=Fernet.generate_key().decode("utf-8"),
secret_token=base64.b64encode(os.urandom(KEY_LENGTH)).decode('utf-8'), secret_token=base64.b64encode(os.urandom(KEY_LENGTH)).decode("utf-8"),
flask_secret_key=base64.b64encode(os.urandom(KEY_LENGTH)).decode('utf-8'), flask_secret_key=base64.b64encode(os.urandom(KEY_LENGTH)).decode("utf-8"),
) )
return output return output
@ -190,39 +190,44 @@ class InitializeApp(Command):
Additionally a Lemur user will be created as a default user Additionally a Lemur user will be created as a default user
and be used when certificates are discovered by Lemur. and be used when certificates are discovered by Lemur.
""" """
option_list = (
Option('-p', '--password', dest='password'), option_list = (Option("-p", "--password", dest="password"),)
)
def run(self, password): def run(self, password):
create() create()
user = user_service.get_by_username("lemur") user = user_service.get_by_username("lemur")
admin_role = role_service.get_by_name('admin') admin_role = role_service.get_by_name("admin")
if admin_role: if admin_role:
sys.stdout.write("[-] Admin role already created, skipping...!\n") sys.stdout.write("[-] Admin role already created, skipping...!\n")
else: else:
# we create an admin role # we create an admin role
admin_role = role_service.create('admin', description='This is the Lemur administrator role.') admin_role = role_service.create(
"admin", description="This is the Lemur administrator role."
)
sys.stdout.write("[+] Created 'admin' role\n") sys.stdout.write("[+] Created 'admin' role\n")
operator_role = role_service.get_by_name('operator') operator_role = role_service.get_by_name("operator")
if operator_role: if operator_role:
sys.stdout.write("[-] Operator role already created, skipping...!\n") sys.stdout.write("[-] Operator role already created, skipping...!\n")
else: else:
# we create an operator role # we create an operator role
operator_role = role_service.create('operator', description='This is the Lemur operator role.') operator_role = role_service.create(
"operator", description="This is the Lemur operator role."
)
sys.stdout.write("[+] Created 'operator' role\n") sys.stdout.write("[+] Created 'operator' role\n")
read_only_role = role_service.get_by_name('read-only') read_only_role = role_service.get_by_name("read-only")
if read_only_role: if read_only_role:
sys.stdout.write("[-] Read only role already created, skipping...!\n") sys.stdout.write("[-] Read only role already created, skipping...!\n")
else: else:
# we create an read only role # we create an read only role
read_only_role = role_service.create('read-only', description='This is the Lemur read only role.') read_only_role = role_service.create(
"read-only", description="This is the Lemur read only role."
)
sys.stdout.write("[+] Created 'read-only' role\n") sys.stdout.write("[+] Created 'read-only' role\n")
if not user: if not user:
@ -235,34 +240,54 @@ class InitializeApp(Command):
sys.stderr.write("[!] Passwords do not match!\n") sys.stderr.write("[!] Passwords do not match!\n")
sys.exit(1) sys.exit(1)
user_service.create("lemur", password, 'lemur@nobody.com', True, None, [admin_role]) user_service.create(
sys.stdout.write("[+] Created the user 'lemur' and granted it the 'admin' role!\n") "lemur", password, "lemur@nobody.com", True, None, [admin_role]
)
sys.stdout.write(
"[+] Created the user 'lemur' and granted it the 'admin' role!\n"
)
else: else:
sys.stdout.write("[-] Default user has already been created, skipping...!\n") sys.stdout.write(
"[-] Default user has already been created, skipping...!\n"
)
intervals = current_app.config.get("LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", []) intervals = current_app.config.get(
"LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", []
)
sys.stdout.write( sys.stdout.write(
"[!] Creating {num} notifications for {intervals} days as specified by LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS\n".format( "[!] Creating {num} notifications for {intervals} days as specified by LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS\n".format(
num=len(intervals), num=len(intervals), intervals=",".join([str(x) for x in intervals])
intervals=",".join([str(x) for x in intervals])
) )
) )
recipients = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL') recipients = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL")
sys.stdout.write("[+] Creating expiration email notifications!\n") sys.stdout.write("[+] Creating expiration email notifications!\n")
sys.stdout.write("[!] Using {0} as specified by LEMUR_SECURITY_TEAM_EMAIL for notifications\n".format(recipients)) sys.stdout.write(
notification_service.create_default_expiration_notifications("DEFAULT_SECURITY", recipients=recipients) "[!] Using {0} as specified by LEMUR_SECURITY_TEAM_EMAIL for notifications\n".format(
recipients
)
)
notification_service.create_default_expiration_notifications(
"DEFAULT_SECURITY", recipients=recipients
)
_DEFAULT_ROTATION_INTERVAL = 'default' _DEFAULT_ROTATION_INTERVAL = "default"
default_rotation_interval = policy_service.get_by_name(_DEFAULT_ROTATION_INTERVAL) default_rotation_interval = policy_service.get_by_name(
_DEFAULT_ROTATION_INTERVAL
)
if default_rotation_interval: if default_rotation_interval:
sys.stdout.write("[-] Default rotation interval policy already created, skipping...!\n") sys.stdout.write(
"[-] Default rotation interval policy already created, skipping...!\n"
)
else: else:
days = current_app.config.get("LEMUR_DEFAULT_ROTATION_INTERVAL", 30) days = current_app.config.get("LEMUR_DEFAULT_ROTATION_INTERVAL", 30)
sys.stdout.write("[+] Creating default certificate rotation policy of {days} days before issuance.\n".format( sys.stdout.write(
days=days)) "[+] Creating default certificate rotation policy of {days} days before issuance.\n".format(
days=days
)
)
policy_service.create(days=days, name=_DEFAULT_ROTATION_INTERVAL) policy_service.create(days=days, name=_DEFAULT_ROTATION_INTERVAL)
sys.stdout.write("[/] Done!\n") sys.stdout.write("[/] Done!\n")
@ -272,12 +297,13 @@ class CreateUser(Command):
""" """
This command allows for the creation of a new user within Lemur. This command allows for the creation of a new user within Lemur.
""" """
option_list = ( option_list = (
Option('-u', '--username', dest='username', required=True), Option("-u", "--username", dest="username", required=True),
Option('-e', '--email', dest='email', required=True), Option("-e", "--email", dest="email", required=True),
Option('-a', '--active', dest='active', default=True), Option("-a", "--active", dest="active", default=True),
Option('-r', '--roles', dest='roles', action='append', default=[]), Option("-r", "--roles", dest="roles", action="append", default=[]),
Option('-p', '--password', dest='password', default=None) Option("-p", "--password", dest="password", default=None),
) )
def run(self, username, email, active, roles, password): def run(self, username, email, active, roles, password):
@ -307,9 +333,8 @@ class ResetPassword(Command):
""" """
This command allows you to reset a user's password. This command allows you to reset a user's password.
""" """
option_list = (
Option('-u', '--username', dest='username', required=True), option_list = (Option("-u", "--username", dest="username", required=True),)
)
def run(self, username): def run(self, username):
user = user_service.get_by_username(username) user = user_service.get_by_username(username)
@ -335,10 +360,11 @@ class CreateRole(Command):
""" """
This command allows for the creation of a new role within Lemur This command allows for the creation of a new role within Lemur
""" """
option_list = ( option_list = (
Option('-n', '--name', dest='name', required=True), Option("-n", "--name", dest="name", required=True),
Option('-u', '--users', dest='users', default=[]), Option("-u", "--users", dest="users", default=[]),
Option('-d', '--description', dest='description', required=True) Option("-d", "--description", dest="description", required=True),
) )
def run(self, name, users, description): def run(self, name, users, description):
@ -369,7 +395,8 @@ class LemurServer(Command):
Will start gunicorn with 4 workers bound to 127.0.0.0:8002 Will start gunicorn with 4 workers bound to 127.0.0.0:8002
""" """
description = 'Run the app within Gunicorn'
description = "Run the app within Gunicorn"
def get_options(self): def get_options(self):
settings = make_settings() settings = make_settings()
@ -377,8 +404,10 @@ class LemurServer(Command):
for setting, klass in settings.items(): for setting, klass in settings.items():
if klass.cli: if klass.cli:
if klass.action: if klass.action:
if klass.action == 'store_const': if klass.action == "store_const":
options.append(Option(*klass.cli, const=klass.const, action=klass.action)) options.append(
Option(*klass.cli, const=klass.const, action=klass.action)
)
else: else:
options.append(Option(*klass.cli, action=klass.action)) options.append(Option(*klass.cli, action=klass.action))
else: else:
@ -394,7 +423,9 @@ class LemurServer(Command):
# run startup tasks on an app like object # run startup tasks on an app like object
validate_conf(current_app, REQUIRED_VARIABLES) validate_conf(current_app, REQUIRED_VARIABLES)
app.app_uri = 'lemur:create_app(config_path="{0}")'.format(current_app.config.get('CONFIG_PATH')) app.app_uri = 'lemur:create_app(config_path="{0}")'.format(
current_app.config.get("CONFIG_PATH")
)
return app.run() return app.run()
@ -414,7 +445,7 @@ def create_config(config_path=None):
os.makedirs(dir) os.makedirs(dir)
config = generate_settings() config = generate_settings()
with open(config_path, 'w') as f: with open(config_path, "w") as f:
f.write(config) f.write(config)
sys.stdout.write("[+] Created a new configuration file {0}\n".format(config_path)) sys.stdout.write("[+] Created a new configuration file {0}\n".format(config_path))
@ -436,7 +467,7 @@ def lock(path=None):
:param: path :param: path
""" """
if not path: if not path:
path = os.path.expanduser('~/.lemur/keys') path = os.path.expanduser("~/.lemur/keys")
dest_dir = os.path.join(path, "encrypted") dest_dir = os.path.join(path, "encrypted")
sys.stdout.write("[!] Generating a new key...\n") sys.stdout.write("[!] Generating a new key...\n")
@ -447,15 +478,17 @@ def lock(path=None):
sys.stdout.write("[+] Creating encryption directory: {0}\n".format(dest_dir)) sys.stdout.write("[+] Creating encryption directory: {0}\n".format(dest_dir))
os.makedirs(dest_dir) os.makedirs(dest_dir)
for root, dirs, files in os.walk(os.path.join(path, 'decrypted')): for root, dirs, files in os.walk(os.path.join(path, "decrypted")):
for f in files: for f in files:
source = os.path.join(root, f) source = os.path.join(root, f)
dest = os.path.join(dest_dir, f + ".enc") dest = os.path.join(dest_dir, f + ".enc")
with open(source, 'rb') as in_file, open(dest, 'wb') as out_file: with open(source, "rb") as in_file, open(dest, "wb") as out_file:
f = Fernet(key) f = Fernet(key)
data = f.encrypt(in_file.read()) data = f.encrypt(in_file.read())
out_file.write(data) out_file.write(data)
sys.stdout.write("[+] Writing file: {0} Source: {1}\n".format(dest, source)) sys.stdout.write(
"[+] Writing file: {0} Source: {1}\n".format(dest, source)
)
sys.stdout.write("[+] Keys have been encrypted with key {0}\n".format(key)) sys.stdout.write("[+] Keys have been encrypted with key {0}\n".format(key))
@ -475,7 +508,7 @@ def unlock(path=None):
key = prompt_pass("[!] Please enter the encryption password") key = prompt_pass("[!] Please enter the encryption password")
if not path: if not path:
path = os.path.expanduser('~/.lemur/keys') path = os.path.expanduser("~/.lemur/keys")
dest_dir = os.path.join(path, "decrypted") dest_dir = os.path.join(path, "decrypted")
source_dir = os.path.join(path, "encrypted") source_dir = os.path.join(path, "encrypted")
@ -488,11 +521,13 @@ def unlock(path=None):
for f in files: for f in files:
source = os.path.join(source_dir, f) source = os.path.join(source_dir, f)
dest = os.path.join(dest_dir, ".".join(f.split(".")[:-1])) dest = os.path.join(dest_dir, ".".join(f.split(".")[:-1]))
with open(source, 'rb') as in_file, open(dest, 'wb') as out_file: with open(source, "rb") as in_file, open(dest, "wb") as out_file:
f = Fernet(key) f = Fernet(key)
data = f.decrypt(in_file.read()) data = f.decrypt(in_file.read())
out_file.write(data) out_file.write(data)
sys.stdout.write("[+] Writing file: {0} Source: {1}\n".format(dest, source)) sys.stdout.write(
"[+] Writing file: {0} Source: {1}\n".format(dest, source)
)
sys.stdout.write("[+] Keys have been unencrypted!\n") sys.stdout.write("[+] Keys have been unencrypted!\n")
@ -505,15 +540,16 @@ def publish_verisign_units():
:return: :return:
""" """
from lemur.plugins import plugins from lemur.plugins import plugins
v = plugins.get('verisign-issuer')
v = plugins.get("verisign-issuer")
units = v.get_available_units() units = v.get_available_units()
metrics = {} metrics = {}
for item in units: for item in units:
if item['@type'] in metrics.keys(): if item["@type"] in metrics.keys():
metrics[item['@type']] += int(item['@remaining']) metrics[item["@type"]] += int(item["@remaining"])
else: else:
metrics.update({item['@type']: int(item['@remaining'])}) metrics.update({item["@type"]: int(item["@remaining"])})
for name, value in metrics.items(): for name, value in metrics.items():
metric = [ metric = [
@ -522,16 +558,16 @@ def publish_verisign_units():
"type": "GAUGE", "type": "GAUGE",
"name": "Symantec {0} Unit Count".format(name), "name": "Symantec {0} Unit Count".format(name),
"tags": {}, "tags": {},
"value": value "value": value,
} }
] ]
requests.post('http://localhost:8078/metrics', data=json.dumps(metric)) requests.post("http://localhost:8078/metrics", data=json.dumps(metric))
def main(): def main():
manager.add_command("start", LemurServer()) manager.add_command("start", LemurServer())
manager.add_command("runserver", Server(host='127.0.0.1', threaded=True)) manager.add_command("runserver", Server(host="127.0.0.1", threaded=True))
manager.add_command("clean", Clean()) manager.add_command("clean", Clean())
manager.add_command("show_urls", ShowUrls()) manager.add_command("show_urls", ShowUrls())
manager.add_command("db", MigrateCommand) manager.add_command("db", MigrateCommand)

View File

@ -11,6 +11,7 @@ class Metrics(object):
""" """
:param app: The Flask application object. Defaults to None. :param app: The Flask application object. Defaults to None.
""" """
_providers = [] _providers = []
def __init__(self, app=None): def __init__(self, app=None):
@ -22,11 +23,14 @@ class Metrics(object):
:param app: The Flask application object. :param app: The Flask application object.
""" """
self._providers = app.config.get('METRIC_PROVIDERS', []) self._providers = app.config.get("METRIC_PROVIDERS", [])
def send(self, metric_name, metric_type, metric_value, *args, **kwargs): def send(self, metric_name, metric_type, metric_value, *args, **kwargs):
for provider in self._providers: for provider in self._providers:
current_app.logger.debug( current_app.logger.debug(
"Sending metric '{metric}' to the {provider} provider.".format(metric=metric_name, provider=provider)) "Sending metric '{metric}' to the {provider} provider.".format(
metric=metric_name, provider=provider
)
)
p = plugins.get(provider) p = plugins.get(provider)
p.submit(metric_name, metric_type, metric_value, *args, **kwargs) p.submit(metric_name, metric_type, metric_value, *args, **kwargs)

View File

@ -19,8 +19,11 @@ fileConfig(config.config_file_name)
# from myapp import mymodel # from myapp import mymodel
# target_metadata = mymodel.Base.metadata # target_metadata = mymodel.Base.metadata
from flask import current_app from flask import current_app
config.set_main_option('sqlalchemy.url', current_app.config.get('SQLALCHEMY_DATABASE_URI'))
target_metadata = current_app.extensions['migrate'].db.metadata config.set_main_option(
"sqlalchemy.url", current_app.config.get("SQLALCHEMY_DATABASE_URI")
)
target_metadata = current_app.extensions["migrate"].db.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
# can be acquired: # can be acquired:
@ -54,14 +57,18 @@ def run_migrations_online():
and associate a connection with the context. and associate a connection with the context.
""" """
engine = engine_from_config(config.get_section(config.config_ini_section), engine = engine_from_config(
prefix='sqlalchemy.', config.get_section(config.config_ini_section),
poolclass=pool.NullPool) prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
connection = engine.connect() connection = engine.connect()
context.configure(connection=connection, context.configure(
connection=connection,
target_metadata=target_metadata, target_metadata=target_metadata,
**current_app.extensions['migrate'].configure_args) **current_app.extensions["migrate"].configure_args
)
try: try:
with context.begin_transaction(): with context.begin_transaction():
@ -69,8 +76,8 @@ def run_migrations_online():
finally: finally:
connection.close() connection.close()
if context.is_offline_mode(): if context.is_offline_mode():
run_migrations_offline() run_migrations_offline()
else: else:
run_migrations_online() run_migrations_online()

View File

@ -7,8 +7,8 @@ Create Date: 2016-12-07 17:29:42.049986
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '131ec6accff5' revision = "131ec6accff5"
down_revision = 'e3691fc396e9' down_revision = "e3691fc396e9"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@ -16,13 +16,24 @@ import sqlalchemy as sa
def upgrade(): def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.add_column('certificates', sa.Column('rotation', sa.Boolean(), nullable=False, server_default=sa.false())) op.add_column(
op.add_column('endpoints', sa.Column('last_updated', sa.DateTime(), server_default=sa.text('now()'), nullable=False)) "certificates",
sa.Column("rotation", sa.Boolean(), nullable=False, server_default=sa.false()),
)
op.add_column(
"endpoints",
sa.Column(
"last_updated",
sa.DateTime(),
server_default=sa.text("now()"),
nullable=False,
),
)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade(): def downgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_column('endpoints', 'last_updated') op.drop_column("endpoints", "last_updated")
op.drop_column('certificates', 'rotation') op.drop_column("certificates", "rotation")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -7,15 +7,19 @@ Create Date: 2017-07-13 12:32:09.162800
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1ae8e3104db8' revision = "1ae8e3104db8"
down_revision = 'a02a678ddc25' down_revision = "a02a678ddc25"
from alembic import op from alembic import op
def upgrade(): def upgrade():
op.sync_enum_values('public', 'log_type', ['key_view'], ['create_cert', 'key_view', 'update_cert']) op.sync_enum_values(
"public", "log_type", ["key_view"], ["create_cert", "key_view", "update_cert"]
)
def downgrade(): def downgrade():
op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'update_cert'], ['key_view']) op.sync_enum_values(
"public", "log_type", ["create_cert", "key_view", "update_cert"], ["key_view"]
)

View File

@ -7,8 +7,8 @@ Create Date: 2018-08-03 12:56:44.565230
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '1db4f82bc780' revision = "1db4f82bc780"
down_revision = '3adfdd6598df' down_revision = "3adfdd6598df"
import logging import logging
@ -20,12 +20,14 @@ log = logging.getLogger(__name__)
def upgrade(): def upgrade():
connection = op.get_bind() connection = op.get_bind()
result = connection.execute("""\ result = connection.execute(
"""\
UPDATE certificates UPDATE certificates
SET rotation_policy_id=(SELECT id FROM rotation_policies WHERE name='default') SET rotation_policy_id=(SELECT id FROM rotation_policies WHERE name='default')
WHERE rotation_policy_id IS NULL WHERE rotation_policy_id IS NULL
RETURNING id RETURNING id
""") """
)
log.info("Filled rotation_policy for %d certificates" % result.rowcount) log.info("Filled rotation_policy for %d certificates" % result.rowcount)

View File

@ -7,8 +7,8 @@ Create Date: 2016-06-28 16:05:25.720213
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '29d8c8455c86' revision = "29d8c8455c86"
down_revision = '3307381f3b88' down_revision = "3307381f3b88"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@ -17,46 +17,60 @@ from sqlalchemy.dialects import postgresql
def upgrade(): def upgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.create_table('ciphers', op.create_table(
sa.Column('id', sa.Integer(), nullable=False), "ciphers",
sa.Column('name', sa.String(length=128), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id') sa.Column("name", sa.String(length=128), nullable=False),
sa.PrimaryKeyConstraint("id"),
) )
op.create_table('policy', op.create_table(
sa.Column('id', sa.Integer(), nullable=False), "policy",
sa.Column('name', sa.String(length=128), nullable=True), sa.Column("id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id') sa.Column("name", sa.String(length=128), nullable=True),
sa.PrimaryKeyConstraint("id"),
) )
op.create_table('policies_ciphers', op.create_table(
sa.Column('cipher_id', sa.Integer(), nullable=True), "policies_ciphers",
sa.Column('policy_id', sa.Integer(), nullable=True), sa.Column("cipher_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['cipher_id'], ['ciphers.id'], ), sa.Column("policy_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['policy_id'], ['policy.id'], ) sa.ForeignKeyConstraint(["cipher_id"], ["ciphers.id"]),
sa.ForeignKeyConstraint(["policy_id"], ["policy.id"]),
) )
op.create_index('policies_ciphers_ix', 'policies_ciphers', ['cipher_id', 'policy_id'], unique=False) op.create_index(
op.create_table('endpoints', "policies_ciphers_ix",
sa.Column('id', sa.Integer(), nullable=False), "policies_ciphers",
sa.Column('owner', sa.String(length=128), nullable=True), ["cipher_id", "policy_id"],
sa.Column('name', sa.String(length=128), nullable=True), unique=False,
sa.Column('dnsname', sa.String(length=256), nullable=True), )
sa.Column('type', sa.String(length=128), nullable=True), op.create_table(
sa.Column('active', sa.Boolean(), nullable=True), "endpoints",
sa.Column('port', sa.Integer(), nullable=True), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('date_created', sa.DateTime(), server_default=sa.text(u'now()'), nullable=False), sa.Column("owner", sa.String(length=128), nullable=True),
sa.Column('policy_id', sa.Integer(), nullable=True), sa.Column("name", sa.String(length=128), nullable=True),
sa.Column('certificate_id', sa.Integer(), nullable=True), sa.Column("dnsname", sa.String(length=256), nullable=True),
sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ), sa.Column("type", sa.String(length=128), nullable=True),
sa.ForeignKeyConstraint(['policy_id'], ['policy.id'], ), sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("port", sa.Integer(), nullable=True),
sa.Column(
"date_created",
sa.DateTime(),
server_default=sa.text(u"now()"),
nullable=False,
),
sa.Column("policy_id", sa.Integer(), nullable=True),
sa.Column("certificate_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]),
sa.ForeignKeyConstraint(["policy_id"], ["policy.id"]),
sa.PrimaryKeyConstraint("id"),
) )
### end Alembic commands ### ### end Alembic commands ###
def downgrade(): def downgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.drop_table('endpoints') op.drop_table("endpoints")
op.drop_index('policies_ciphers_ix', table_name='policies_ciphers') op.drop_index("policies_ciphers_ix", table_name="policies_ciphers")
op.drop_table('policies_ciphers') op.drop_table("policies_ciphers")
op.drop_table('policy') op.drop_table("policy")
op.drop_table('ciphers') op.drop_table("ciphers")
### end Alembic commands ### ### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2019-02-05 15:42:25.477587
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '318b66568358' revision = "318b66568358"
down_revision = '9f79024fe67b' down_revision = "9f79024fe67b"
from alembic import op from alembic import op
@ -16,7 +16,7 @@ from alembic import op
def upgrade(): def upgrade():
connection = op.get_bind() connection = op.get_bind()
# Delete duplicate entries # Delete duplicate entries
connection.execute('UPDATE certificates SET deleted = false WHERE deleted IS NULL') connection.execute("UPDATE certificates SET deleted = false WHERE deleted IS NULL")
def downgrade(): def downgrade():

View File

@ -12,8 +12,8 @@ Create Date: 2016-05-20 17:33:04.360687
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '3307381f3b88' revision = "3307381f3b88"
down_revision = '412b22cb656a' down_revision = "412b22cb656a"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@ -23,109 +23,165 @@ from sqlalchemy.dialects import postgresql
def upgrade(): def upgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.alter_column('authorities', 'owner', op.alter_column(
existing_type=sa.VARCHAR(length=128), "authorities", "owner", existing_type=sa.VARCHAR(length=128), nullable=True
nullable=True) )
op.drop_column('authorities', 'not_after') op.drop_column("authorities", "not_after")
op.drop_column('authorities', 'bits') op.drop_column("authorities", "bits")
op.drop_column('authorities', 'cn') op.drop_column("authorities", "cn")
op.drop_column('authorities', 'not_before') op.drop_column("authorities", "not_before")
op.add_column('certificates', sa.Column('root_authority_id', sa.Integer(), nullable=True)) op.add_column(
op.alter_column('certificates', 'body', "certificates", sa.Column("root_authority_id", sa.Integer(), nullable=True)
existing_type=sa.TEXT(), )
nullable=False) op.alter_column("certificates", "body", existing_type=sa.TEXT(), nullable=False)
op.alter_column('certificates', 'owner', op.alter_column(
existing_type=sa.VARCHAR(length=128), "certificates", "owner", existing_type=sa.VARCHAR(length=128), nullable=True
nullable=True) )
op.drop_constraint(u'certificates_authority_id_fkey', 'certificates', type_='foreignkey') op.drop_constraint(
op.create_foreign_key(None, 'certificates', 'authorities', ['authority_id'], ['id'], ondelete='CASCADE') u"certificates_authority_id_fkey", "certificates", type_="foreignkey"
op.create_foreign_key(None, 'certificates', 'authorities', ['root_authority_id'], ['id'], ondelete='CASCADE') )
op.create_foreign_key(
None,
"certificates",
"authorities",
["authority_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
None,
"certificates",
"authorities",
["root_authority_id"],
["id"],
ondelete="CASCADE",
)
### end Alembic commands ### ### end Alembic commands ###
# link existing certificate to their authority certificates # link existing certificate to their authority certificates
conn = op.get_bind() conn = op.get_bind()
for id, body, owner in conn.execute(text('select id, body, owner from authorities')): for id, body, owner in conn.execute(
text("select id, body, owner from authorities")
):
if not owner: if not owner:
owner = "lemur@nobody" owner = "lemur@nobody"
# look up certificate by body, if duplications are found, pick one # look up certificate by body, if duplications are found, pick one
stmt = text('select id from certificates where body=:body') stmt = text("select id from certificates where body=:body")
stmt = stmt.bindparams(body=body) stmt = stmt.bindparams(body=body)
root_certificate = conn.execute(stmt).fetchone() root_certificate = conn.execute(stmt).fetchone()
if root_certificate: if root_certificate:
stmt = text('update certificates set root_authority_id=:root_authority_id where id=:id') stmt = text(
"update certificates set root_authority_id=:root_authority_id where id=:id"
)
stmt = stmt.bindparams(root_authority_id=id, id=root_certificate[0]) stmt = stmt.bindparams(root_authority_id=id, id=root_certificate[0])
op.execute(stmt) op.execute(stmt)
# link owner roles to their authorities # link owner roles to their authorities
stmt = text('select id from roles where name=:name') stmt = text("select id from roles where name=:name")
stmt = stmt.bindparams(name=owner) stmt = stmt.bindparams(name=owner)
owner_role = conn.execute(stmt).fetchone() owner_role = conn.execute(stmt).fetchone()
if not owner_role: if not owner_role:
stmt = text('insert into roles (name, description) values (:name, :description)') stmt = text(
stmt = stmt.bindparams(name=owner, description='Lemur generated role or existing owner.') "insert into roles (name, description) values (:name, :description)"
)
stmt = stmt.bindparams(
name=owner, description="Lemur generated role or existing owner."
)
op.execute(stmt) op.execute(stmt)
stmt = text('select id from roles where name=:name') stmt = text("select id from roles where name=:name")
stmt = stmt.bindparams(name=owner) stmt = stmt.bindparams(name=owner)
owner_role = conn.execute(stmt).fetchone() owner_role = conn.execute(stmt).fetchone()
stmt = text('select * from roles_authorities where role_id=:role_id and authority_id=:authority_id') stmt = text(
"select * from roles_authorities where role_id=:role_id and authority_id=:authority_id"
)
stmt = stmt.bindparams(role_id=owner_role[0], authority_id=id) stmt = stmt.bindparams(role_id=owner_role[0], authority_id=id)
exists = conn.execute(stmt).fetchone() exists = conn.execute(stmt).fetchone()
if not exists: if not exists:
stmt = text('insert into roles_authorities (role_id, authority_id) values (:role_id, :authority_id)') stmt = text(
"insert into roles_authorities (role_id, authority_id) values (:role_id, :authority_id)"
)
stmt = stmt.bindparams(role_id=owner_role[0], authority_id=id) stmt = stmt.bindparams(role_id=owner_role[0], authority_id=id)
op.execute(stmt) op.execute(stmt)
# link owner roles to their certificates # link owner roles to their certificates
for id, owner in conn.execute(text('select id, owner from certificates')): for id, owner in conn.execute(text("select id, owner from certificates")):
if not owner: if not owner:
owner = "lemur@nobody" owner = "lemur@nobody"
stmt = text('select id from roles where name=:name') stmt = text("select id from roles where name=:name")
stmt = stmt.bindparams(name=owner) stmt = stmt.bindparams(name=owner)
owner_role = conn.execute(stmt).fetchone() owner_role = conn.execute(stmt).fetchone()
if not owner_role: if not owner_role:
stmt = text('insert into roles (name, description) values (:name, :description)') stmt = text(
stmt = stmt.bindparams(name=owner, description='Lemur generated role or existing owner.') "insert into roles (name, description) values (:name, :description)"
)
stmt = stmt.bindparams(
name=owner, description="Lemur generated role or existing owner."
)
op.execute(stmt) op.execute(stmt)
# link owner roles to their authorities # link owner roles to their authorities
stmt = text('select id from roles where name=:name') stmt = text("select id from roles where name=:name")
stmt = stmt.bindparams(name=owner) stmt = stmt.bindparams(name=owner)
owner_role = conn.execute(stmt).fetchone() owner_role = conn.execute(stmt).fetchone()
stmt = text('select * from roles_certificates where role_id=:role_id and certificate_id=:certificate_id') stmt = text(
"select * from roles_certificates where role_id=:role_id and certificate_id=:certificate_id"
)
stmt = stmt.bindparams(role_id=owner_role[0], certificate_id=id) stmt = stmt.bindparams(role_id=owner_role[0], certificate_id=id)
exists = conn.execute(stmt).fetchone() exists = conn.execute(stmt).fetchone()
if not exists: if not exists:
stmt = text('insert into roles_certificates (role_id, certificate_id) values (:role_id, :certificate_id)') stmt = text(
"insert into roles_certificates (role_id, certificate_id) values (:role_id, :certificate_id)"
)
stmt = stmt.bindparams(role_id=owner_role[0], certificate_id=id) stmt = stmt.bindparams(role_id=owner_role[0], certificate_id=id)
op.execute(stmt) op.execute(stmt)
def downgrade(): def downgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'certificates', type_='foreignkey') op.drop_constraint(None, "certificates", type_="foreignkey")
op.drop_constraint(None, 'certificates', type_='foreignkey') op.drop_constraint(None, "certificates", type_="foreignkey")
op.create_foreign_key(u'certificates_authority_id_fkey', 'certificates', 'authorities', ['authority_id'], ['id']) op.create_foreign_key(
op.alter_column('certificates', 'owner', u"certificates_authority_id_fkey",
existing_type=sa.VARCHAR(length=128), "certificates",
nullable=True) "authorities",
op.alter_column('certificates', 'body', ["authority_id"],
existing_type=sa.TEXT(), ["id"],
nullable=True) )
op.drop_column('certificates', 'root_authority_id') op.alter_column(
op.add_column('authorities', sa.Column('not_before', postgresql.TIMESTAMP(), autoincrement=False, nullable=True)) "certificates", "owner", existing_type=sa.VARCHAR(length=128), nullable=True
op.add_column('authorities', sa.Column('cn', sa.VARCHAR(length=128), autoincrement=False, nullable=True)) )
op.add_column('authorities', sa.Column('bits', sa.INTEGER(), autoincrement=False, nullable=True)) op.alter_column("certificates", "body", existing_type=sa.TEXT(), nullable=True)
op.add_column('authorities', sa.Column('not_after', postgresql.TIMESTAMP(), autoincrement=False, nullable=True)) op.drop_column("certificates", "root_authority_id")
op.alter_column('authorities', 'owner', op.add_column(
existing_type=sa.VARCHAR(length=128), "authorities",
nullable=True) sa.Column(
"not_before", postgresql.TIMESTAMP(), autoincrement=False, nullable=True
),
)
op.add_column(
"authorities",
sa.Column("cn", sa.VARCHAR(length=128), autoincrement=False, nullable=True),
)
op.add_column(
"authorities",
sa.Column("bits", sa.INTEGER(), autoincrement=False, nullable=True),
)
op.add_column(
"authorities",
sa.Column(
"not_after", postgresql.TIMESTAMP(), autoincrement=False, nullable=True
),
)
op.alter_column(
"authorities", "owner", existing_type=sa.VARCHAR(length=128), nullable=True
)
### end Alembic commands ### ### end Alembic commands ###

View File

@ -7,25 +7,31 @@ Create Date: 2015-11-30 15:40:19.827272
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '33de094da890' revision = "33de094da890"
down_revision = None down_revision = None
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
def upgrade(): def upgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.create_table('certificate_replacement_associations', op.create_table(
sa.Column('replaced_certificate_id', sa.Integer(), nullable=True), "certificate_replacement_associations",
sa.Column('certificate_id', sa.Integer(), nullable=True), sa.Column("replaced_certificate_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ondelete='cascade'), sa.Column("certificate_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['replaced_certificate_id'], ['certificates.id'], ondelete='cascade') sa.ForeignKeyConstraint(
["certificate_id"], ["certificates.id"], ondelete="cascade"
),
sa.ForeignKeyConstraint(
["replaced_certificate_id"], ["certificates.id"], ondelete="cascade"
),
) )
### end Alembic commands ### ### end Alembic commands ###
def downgrade(): def downgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.drop_table('certificate_replacement_associations') op.drop_table("certificate_replacement_associations")
### end Alembic commands ### ### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2018-04-10 13:25:47.007556
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '3adfdd6598df' revision = "3adfdd6598df"
down_revision = '556ceb3e3c3e' down_revision = "556ceb3e3c3e"
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
@ -22,84 +22,90 @@ def upgrade():
# create provider table # create provider table
print("Creating dns_providers table") print("Creating dns_providers table")
op.create_table( op.create_table(
'dns_providers', "dns_providers",
sa.Column('id', sa.Integer(), nullable=False), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=256), nullable=True), sa.Column("name", sa.String(length=256), nullable=True),
sa.Column('description', sa.String(length=1024), nullable=True), sa.Column("description", sa.String(length=1024), nullable=True),
sa.Column('provider_type', sa.String(length=256), nullable=True), sa.Column("provider_type", sa.String(length=256), nullable=True),
sa.Column('credentials', Vault(), nullable=True), sa.Column("credentials", Vault(), nullable=True),
sa.Column('api_endpoint', sa.String(length=256), nullable=True), sa.Column("api_endpoint", sa.String(length=256), nullable=True),
sa.Column('date_created', ArrowType(), server_default=sa.text('now()'), nullable=False), sa.Column(
sa.Column('status', sa.String(length=128), nullable=True), "date_created", ArrowType(), server_default=sa.text("now()"), nullable=False
sa.Column('options', JSON), ),
sa.Column('domains', sa.JSON(), nullable=True), sa.Column("status", sa.String(length=128), nullable=True),
sa.PrimaryKeyConstraint('id'), sa.Column("options", JSON),
sa.UniqueConstraint('name') sa.Column("domains", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
) )
print("Adding dns_provider_id column to certificates") print("Adding dns_provider_id column to certificates")
op.add_column('certificates', sa.Column('dns_provider_id', sa.Integer(), nullable=True)) op.add_column(
"certificates", sa.Column("dns_provider_id", sa.Integer(), nullable=True)
)
print("Adding dns_provider_id column to pending_certs") print("Adding dns_provider_id column to pending_certs")
op.add_column('pending_certs', sa.Column('dns_provider_id', sa.Integer(), nullable=True)) op.add_column(
"pending_certs", sa.Column("dns_provider_id", sa.Integer(), nullable=True)
)
print("Adding options column to pending_certs") print("Adding options column to pending_certs")
op.add_column('pending_certs', sa.Column('options', JSON)) op.add_column("pending_certs", sa.Column("options", JSON))
print("Creating pending_dns_authorizations table") print("Creating pending_dns_authorizations table")
op.create_table( op.create_table(
'pending_dns_authorizations', "pending_dns_authorizations",
sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True), sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column('account_number', sa.String(length=128), nullable=True), sa.Column("account_number", sa.String(length=128), nullable=True),
sa.Column('domains', JSON, nullable=True), sa.Column("domains", JSON, nullable=True),
sa.Column('dns_provider_type', sa.String(length=128), nullable=True), sa.Column("dns_provider_type", sa.String(length=128), nullable=True),
sa.Column('options', JSON, nullable=True), sa.Column("options", JSON, nullable=True),
) )
print("Creating certificates_dns_providers_fk foreign key") print("Creating certificates_dns_providers_fk foreign key")
op.create_foreign_key('certificates_dns_providers_fk', 'certificates', 'dns_providers', ['dns_provider_id'], ['id'], op.create_foreign_key(
ondelete='cascade') "certificates_dns_providers_fk",
"certificates",
"dns_providers",
["dns_provider_id"],
["id"],
ondelete="cascade",
)
print("Altering column types in the api_keys table") print("Altering column types in the api_keys table")
op.alter_column('api_keys', 'issued_at', op.alter_column("api_keys", "issued_at", existing_type=sa.BIGINT(), nullable=True)
existing_type=sa.BIGINT(), op.alter_column("api_keys", "revoked", existing_type=sa.BOOLEAN(), nullable=True)
nullable=True) op.alter_column("api_keys", "ttl", existing_type=sa.BIGINT(), nullable=True)
op.alter_column('api_keys', 'revoked', op.alter_column("api_keys", "user_id", existing_type=sa.INTEGER(), nullable=True)
existing_type=sa.BOOLEAN(),
nullable=True)
op.alter_column('api_keys', 'ttl',
existing_type=sa.BIGINT(),
nullable=True)
op.alter_column('api_keys', 'user_id',
existing_type=sa.INTEGER(),
nullable=True)
print("Creating dns_providers_id foreign key on pending_certs table") print("Creating dns_providers_id foreign key on pending_certs table")
op.create_foreign_key(None, 'pending_certs', 'dns_providers', ['dns_provider_id'], ['id'], ondelete='CASCADE') op.create_foreign_key(
None,
"pending_certs",
"dns_providers",
["dns_provider_id"],
["id"],
ondelete="CASCADE",
)
def downgrade(): def downgrade():
print("Removing dns_providers_id foreign key on pending_certs table") print("Removing dns_providers_id foreign key on pending_certs table")
op.drop_constraint(None, 'pending_certs', type_='foreignkey') op.drop_constraint(None, "pending_certs", type_="foreignkey")
print("Reverting column types in the api_keys table") print("Reverting column types in the api_keys table")
op.alter_column('api_keys', 'user_id', op.alter_column("api_keys", "user_id", existing_type=sa.INTEGER(), nullable=False)
existing_type=sa.INTEGER(), op.alter_column("api_keys", "ttl", existing_type=sa.BIGINT(), nullable=False)
nullable=False) op.alter_column("api_keys", "revoked", existing_type=sa.BOOLEAN(), nullable=False)
op.alter_column('api_keys', 'ttl', op.alter_column("api_keys", "issued_at", existing_type=sa.BIGINT(), nullable=False)
existing_type=sa.BIGINT(),
nullable=False)
op.alter_column('api_keys', 'revoked',
existing_type=sa.BOOLEAN(),
nullable=False)
op.alter_column('api_keys', 'issued_at',
existing_type=sa.BIGINT(),
nullable=False)
print("Reverting certificates_dns_providers_fk foreign key") print("Reverting certificates_dns_providers_fk foreign key")
op.drop_constraint('certificates_dns_providers_fk', 'certificates', type_='foreignkey') op.drop_constraint(
"certificates_dns_providers_fk", "certificates", type_="foreignkey"
)
print("Dropping pending_dns_authorizations table") print("Dropping pending_dns_authorizations table")
op.drop_table('pending_dns_authorizations') op.drop_table("pending_dns_authorizations")
print("Undoing modifications to pending_certs table") print("Undoing modifications to pending_certs table")
op.drop_column('pending_certs', 'options') op.drop_column("pending_certs", "options")
op.drop_column('pending_certs', 'dns_provider_id') op.drop_column("pending_certs", "dns_provider_id")
print("Undoing modifications to certificates table") print("Undoing modifications to certificates table")
op.drop_column('certificates', 'dns_provider_id') op.drop_column("certificates", "dns_provider_id")
print("Deleting dns_providers table") print("Deleting dns_providers table")
op.drop_table('dns_providers') op.drop_table("dns_providers")

View File

@ -7,8 +7,8 @@ Create Date: 2016-05-17 17:37:41.210232
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '412b22cb656a' revision = "412b22cb656a"
down_revision = '4c50b903d1ae' down_revision = "4c50b903d1ae"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@ -17,47 +17,102 @@ from sqlalchemy.sql import text
def upgrade(): def upgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.create_table('roles_authorities', op.create_table(
sa.Column('authority_id', sa.Integer(), nullable=True), "roles_authorities",
sa.Column('role_id', sa.Integer(), nullable=True), sa.Column("authority_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['authority_id'], ['authorities.id'], ), sa.Column("role_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ) sa.ForeignKeyConstraint(["authority_id"], ["authorities.id"]),
sa.ForeignKeyConstraint(["role_id"], ["roles.id"]),
) )
op.create_index('roles_authorities_ix', 'roles_authorities', ['authority_id', 'role_id'], unique=True) op.create_index(
op.create_table('roles_certificates', "roles_authorities_ix",
sa.Column('certificate_id', sa.Integer(), nullable=True), "roles_authorities",
sa.Column('role_id', sa.Integer(), nullable=True), ["authority_id", "role_id"],
sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ), unique=True,
sa.ForeignKeyConstraint(['role_id'], ['roles.id'], ) )
op.create_table(
"roles_certificates",
sa.Column("certificate_id", sa.Integer(), nullable=True),
sa.Column("role_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]),
sa.ForeignKeyConstraint(["role_id"], ["roles.id"]),
)
op.create_index(
"roles_certificates_ix",
"roles_certificates",
["certificate_id", "role_id"],
unique=True,
)
op.create_index(
"certificate_associations_ix",
"certificate_associations",
["domain_id", "certificate_id"],
unique=True,
)
op.create_index(
"certificate_destination_associations_ix",
"certificate_destination_associations",
["destination_id", "certificate_id"],
unique=True,
)
op.create_index(
"certificate_notification_associations_ix",
"certificate_notification_associations",
["notification_id", "certificate_id"],
unique=True,
)
op.create_index(
"certificate_replacement_associations_ix",
"certificate_replacement_associations",
["certificate_id", "certificate_id"],
unique=True,
)
op.create_index(
"certificate_source_associations_ix",
"certificate_source_associations",
["source_id", "certificate_id"],
unique=True,
)
op.create_index(
"roles_users_ix", "roles_users", ["user_id", "role_id"], unique=True
) )
op.create_index('roles_certificates_ix', 'roles_certificates', ['certificate_id', 'role_id'], unique=True)
op.create_index('certificate_associations_ix', 'certificate_associations', ['domain_id', 'certificate_id'], unique=True)
op.create_index('certificate_destination_associations_ix', 'certificate_destination_associations', ['destination_id', 'certificate_id'], unique=True)
op.create_index('certificate_notification_associations_ix', 'certificate_notification_associations', ['notification_id', 'certificate_id'], unique=True)
op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['certificate_id', 'certificate_id'], unique=True)
op.create_index('certificate_source_associations_ix', 'certificate_source_associations', ['source_id', 'certificate_id'], unique=True)
op.create_index('roles_users_ix', 'roles_users', ['user_id', 'role_id'], unique=True)
### end Alembic commands ### ### end Alembic commands ###
# migrate existing authority_id relationship to many_to_many # migrate existing authority_id relationship to many_to_many
conn = op.get_bind() conn = op.get_bind()
for id, authority_id in conn.execute(text('select id, authority_id from roles where authority_id is not null')): for id, authority_id in conn.execute(
stmt = text('insert into roles_authoritties (role_id, authority_id) values (:role_id, :authority_id)') text("select id, authority_id from roles where authority_id is not null")
):
stmt = text(
"insert into roles_authoritties (role_id, authority_id) values (:role_id, :authority_id)"
)
stmt = stmt.bindparams(role_id=id, authority_id=authority_id) stmt = stmt.bindparams(role_id=id, authority_id=authority_id)
op.execute(stmt) op.execute(stmt)
def downgrade(): def downgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.drop_index('roles_users_ix', table_name='roles_users') op.drop_index("roles_users_ix", table_name="roles_users")
op.drop_index('certificate_source_associations_ix', table_name='certificate_source_associations') op.drop_index(
op.drop_index('certificate_replacement_associations_ix', table_name='certificate_replacement_associations') "certificate_source_associations_ix",
op.drop_index('certificate_notification_associations_ix', table_name='certificate_notification_associations') table_name="certificate_source_associations",
op.drop_index('certificate_destination_associations_ix', table_name='certificate_destination_associations') )
op.drop_index('certificate_associations_ix', table_name='certificate_associations') op.drop_index(
op.drop_index('roles_certificates_ix', table_name='roles_certificates') "certificate_replacement_associations_ix",
op.drop_table('roles_certificates') table_name="certificate_replacement_associations",
op.drop_index('roles_authorities_ix', table_name='roles_authorities') )
op.drop_table('roles_authorities') op.drop_index(
"certificate_notification_associations_ix",
table_name="certificate_notification_associations",
)
op.drop_index(
"certificate_destination_associations_ix",
table_name="certificate_destination_associations",
)
op.drop_index("certificate_associations_ix", table_name="certificate_associations")
op.drop_index("roles_certificates_ix", table_name="roles_certificates")
op.drop_table("roles_certificates")
op.drop_index("roles_authorities_ix", table_name="roles_authorities")
op.drop_table("roles_authorities")
### end Alembic commands ### ### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2018-02-24 22:51:35.369229
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '449c3d5c7299' revision = "449c3d5c7299"
down_revision = '5770674184de' down_revision = "5770674184de"
from alembic import op from alembic import op
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
@ -23,12 +23,14 @@ COLUMNS = ["notification_id", "certificate_id"]
def upgrade(): def upgrade():
connection = op.get_bind() connection = op.get_bind()
# Delete duplicate entries # Delete duplicate entries
connection.execute("""\ connection.execute(
"""\
DELETE FROM certificate_notification_associations WHERE ctid NOT IN ( DELETE FROM certificate_notification_associations WHERE ctid NOT IN (
-- Select the first tuple ID for each (notification_id, certificate_id) combination and keep that -- Select the first tuple ID for each (notification_id, certificate_id) combination and keep that
SELECT min(ctid) FROM certificate_notification_associations GROUP BY notification_id, certificate_id SELECT min(ctid) FROM certificate_notification_associations GROUP BY notification_id, certificate_id
) )
""") """
)
op.create_unique_constraint(CONSTRAINT_NAME, TABLE, COLUMNS) op.create_unique_constraint(CONSTRAINT_NAME, TABLE, COLUMNS)

View File

@ -7,20 +7,21 @@ Create Date: 2015-12-30 10:19:30.057791
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '4c50b903d1ae' revision = "4c50b903d1ae"
down_revision = '33de094da890' down_revision = "33de094da890"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
def upgrade(): def upgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.add_column('domains', sa.Column('sensitive', sa.Boolean(), nullable=True)) op.add_column("domains", sa.Column("sensitive", sa.Boolean(), nullable=True))
### end Alembic commands ### ### end Alembic commands ###
def downgrade(): def downgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.drop_column('domains', 'sensitive') op.drop_column("domains", "sensitive")
### end Alembic commands ### ### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2018-01-05 01:18:45.571595
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '556ceb3e3c3e' revision = "556ceb3e3c3e"
down_revision = '449c3d5c7299' down_revision = "449c3d5c7299"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@ -16,84 +16,150 @@ from lemur.utils import Vault
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy_utils import ArrowType from sqlalchemy_utils import ArrowType
def upgrade(): def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('pending_certs', op.create_table(
sa.Column('id', sa.Integer(), nullable=False), "pending_certs",
sa.Column('external_id', sa.String(length=128), nullable=True), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('owner', sa.String(length=128), nullable=False), sa.Column("external_id", sa.String(length=128), nullable=True),
sa.Column('name', sa.String(length=256), nullable=True), sa.Column("owner", sa.String(length=128), nullable=False),
sa.Column('description', sa.String(length=1024), nullable=True), sa.Column("name", sa.String(length=256), nullable=True),
sa.Column('notify', sa.Boolean(), nullable=True), sa.Column("description", sa.String(length=1024), nullable=True),
sa.Column('number_attempts', sa.Integer(), nullable=True), sa.Column("notify", sa.Boolean(), nullable=True),
sa.Column('rename', sa.Boolean(), nullable=True), sa.Column("number_attempts", sa.Integer(), nullable=True),
sa.Column('cn', sa.String(length=128), nullable=True), sa.Column("rename", sa.Boolean(), nullable=True),
sa.Column('csr', sa.Text(), nullable=False), sa.Column("cn", sa.String(length=128), nullable=True),
sa.Column('chain', sa.Text(), nullable=True), sa.Column("csr", sa.Text(), nullable=False),
sa.Column('private_key', Vault(), nullable=True), sa.Column("chain", sa.Text(), nullable=True),
sa.Column('date_created', ArrowType(), server_default=sa.text('now()'), nullable=False), sa.Column("private_key", Vault(), nullable=True),
sa.Column('status', sa.String(length=128), nullable=True), sa.Column(
sa.Column('rotation', sa.Boolean(), nullable=True), "date_created", ArrowType(), server_default=sa.text("now()"), nullable=False
sa.Column('user_id', sa.Integer(), nullable=True), ),
sa.Column('authority_id', sa.Integer(), nullable=True), sa.Column("status", sa.String(length=128), nullable=True),
sa.Column('root_authority_id', sa.Integer(), nullable=True), sa.Column("rotation", sa.Boolean(), nullable=True),
sa.Column('rotation_policy_id', sa.Integer(), nullable=True), sa.Column("user_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['authority_id'], ['authorities.id'], ondelete='CASCADE'), sa.Column("authority_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['root_authority_id'], ['authorities.id'], ondelete='CASCADE'), sa.Column("root_authority_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['rotation_policy_id'], ['rotation_policies.id'], ), sa.Column("rotation_policy_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.ForeignKeyConstraint(
sa.PrimaryKeyConstraint('id'), ["authority_id"], ["authorities.id"], ondelete="CASCADE"
sa.UniqueConstraint('name') ),
sa.ForeignKeyConstraint(
["root_authority_id"], ["authorities.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["rotation_policy_id"], ["rotation_policies.id"]),
sa.ForeignKeyConstraint(["user_id"], ["users.id"]),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
) )
op.create_table('pending_cert_destination_associations', op.create_table(
sa.Column('destination_id', sa.Integer(), nullable=True), "pending_cert_destination_associations",
sa.Column('pending_cert_id', sa.Integer(), nullable=True), sa.Column("destination_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['destination_id'], ['destinations.id'], ondelete='cascade'), sa.Column("pending_cert_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade') sa.ForeignKeyConstraint(
["destination_id"], ["destinations.id"], ondelete="cascade"
),
sa.ForeignKeyConstraint(
["pending_cert_id"], ["pending_certs.id"], ondelete="cascade"
),
) )
op.create_index('pending_cert_destination_associations_ix', 'pending_cert_destination_associations', ['destination_id', 'pending_cert_id'], unique=False) op.create_index(
op.create_table('pending_cert_notification_associations', "pending_cert_destination_associations_ix",
sa.Column('notification_id', sa.Integer(), nullable=True), "pending_cert_destination_associations",
sa.Column('pending_cert_id', sa.Integer(), nullable=True), ["destination_id", "pending_cert_id"],
sa.ForeignKeyConstraint(['notification_id'], ['notifications.id'], ondelete='cascade'), unique=False,
sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade')
) )
op.create_index('pending_cert_notification_associations_ix', 'pending_cert_notification_associations', ['notification_id', 'pending_cert_id'], unique=False) op.create_table(
op.create_table('pending_cert_replacement_associations', "pending_cert_notification_associations",
sa.Column('replaced_certificate_id', sa.Integer(), nullable=True), sa.Column("notification_id", sa.Integer(), nullable=True),
sa.Column('pending_cert_id', sa.Integer(), nullable=True), sa.Column("pending_cert_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade'), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(['replaced_certificate_id'], ['certificates.id'], ondelete='cascade') ["notification_id"], ["notifications.id"], ondelete="cascade"
),
sa.ForeignKeyConstraint(
["pending_cert_id"], ["pending_certs.id"], ondelete="cascade"
),
) )
op.create_index('pending_cert_replacement_associations_ix', 'pending_cert_replacement_associations', ['replaced_certificate_id', 'pending_cert_id'], unique=False) op.create_index(
op.create_table('pending_cert_role_associations', "pending_cert_notification_associations_ix",
sa.Column('pending_cert_id', sa.Integer(), nullable=True), "pending_cert_notification_associations",
sa.Column('role_id', sa.Integer(), nullable=True), ["notification_id", "pending_cert_id"],
sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ), unique=False,
sa.ForeignKeyConstraint(['role_id'], ['roles.id'], )
) )
op.create_index('pending_cert_role_associations_ix', 'pending_cert_role_associations', ['pending_cert_id', 'role_id'], unique=False) op.create_table(
op.create_table('pending_cert_source_associations', "pending_cert_replacement_associations",
sa.Column('source_id', sa.Integer(), nullable=True), sa.Column("replaced_certificate_id", sa.Integer(), nullable=True),
sa.Column('pending_cert_id', sa.Integer(), nullable=True), sa.Column("pending_cert_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade'), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='cascade') ["pending_cert_id"], ["pending_certs.id"], ondelete="cascade"
),
sa.ForeignKeyConstraint(
["replaced_certificate_id"], ["certificates.id"], ondelete="cascade"
),
)
op.create_index(
"pending_cert_replacement_associations_ix",
"pending_cert_replacement_associations",
["replaced_certificate_id", "pending_cert_id"],
unique=False,
)
op.create_table(
"pending_cert_role_associations",
sa.Column("pending_cert_id", sa.Integer(), nullable=True),
sa.Column("role_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["pending_cert_id"], ["pending_certs.id"]),
sa.ForeignKeyConstraint(["role_id"], ["roles.id"]),
)
op.create_index(
"pending_cert_role_associations_ix",
"pending_cert_role_associations",
["pending_cert_id", "role_id"],
unique=False,
)
op.create_table(
"pending_cert_source_associations",
sa.Column("source_id", sa.Integer(), nullable=True),
sa.Column("pending_cert_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["pending_cert_id"], ["pending_certs.id"], ondelete="cascade"
),
sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="cascade"),
)
op.create_index(
"pending_cert_source_associations_ix",
"pending_cert_source_associations",
["source_id", "pending_cert_id"],
unique=False,
) )
op.create_index('pending_cert_source_associations_ix', 'pending_cert_source_associations', ['source_id', 'pending_cert_id'], unique=False)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade(): def downgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_index('pending_cert_source_associations_ix', table_name='pending_cert_source_associations') op.drop_index(
op.drop_table('pending_cert_source_associations') "pending_cert_source_associations_ix",
op.drop_index('pending_cert_role_associations_ix', table_name='pending_cert_role_associations') table_name="pending_cert_source_associations",
op.drop_table('pending_cert_role_associations') )
op.drop_index('pending_cert_replacement_associations_ix', table_name='pending_cert_replacement_associations') op.drop_table("pending_cert_source_associations")
op.drop_table('pending_cert_replacement_associations') op.drop_index(
op.drop_index('pending_cert_notification_associations_ix', table_name='pending_cert_notification_associations') "pending_cert_role_associations_ix", table_name="pending_cert_role_associations"
op.drop_table('pending_cert_notification_associations') )
op.drop_index('pending_cert_destination_associations_ix', table_name='pending_cert_destination_associations') op.drop_table("pending_cert_role_associations")
op.drop_table('pending_cert_destination_associations') op.drop_index(
op.drop_table('pending_certs') "pending_cert_replacement_associations_ix",
table_name="pending_cert_replacement_associations",
)
op.drop_table("pending_cert_replacement_associations")
op.drop_index(
"pending_cert_notification_associations_ix",
table_name="pending_cert_notification_associations",
)
op.drop_table("pending_cert_notification_associations")
op.drop_index(
"pending_cert_destination_associations_ix",
table_name="pending_cert_destination_associations",
)
op.drop_table("pending_cert_destination_associations")
op.drop_table("pending_certs")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2018-02-23 15:27:30.335435
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '5770674184de' revision = "5770674184de"
down_revision = 'ce547319f7be' down_revision = "ce547319f7be"
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
from lemur.models import certificate_notification_associations from lemur.models import certificate_notification_associations
@ -32,7 +32,9 @@ def upgrade():
# If we've seen a pair already, delete the duplicates # If we've seen a pair already, delete the duplicates
if seen.get("{}-{}".format(x.certificate_id, x.notification_id)): if seen.get("{}-{}".format(x.certificate_id, x.notification_id)):
print("Deleting duplicate: {}".format(x)) print("Deleting duplicate: {}".format(x))
d = session.query(certificate_notification_associations).filter(certificate_notification_associations.c.id==x.id) d = session.query(certificate_notification_associations).filter(
certificate_notification_associations.c.id == x.id
)
d.delete(synchronize_session=False) d.delete(synchronize_session=False)
seen["{}-{}".format(x.certificate_id, x.notification_id)] = True seen["{}-{}".format(x.certificate_id, x.notification_id)] = True
db.session.commit() db.session.commit()

View File

@ -7,8 +7,8 @@ Create Date: 2018-08-14 08:16:43.329316
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '5ae0ecefb01f' revision = "5ae0ecefb01f"
down_revision = '1db4f82bc780' down_revision = "1db4f82bc780"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@ -16,17 +16,14 @@ import sqlalchemy as sa
def upgrade(): def upgrade():
op.alter_column( op.alter_column(
table_name='pending_certs', table_name="pending_certs", column_name="status", nullable=True, type_=sa.TEXT()
column_name='status',
nullable=True,
type_=sa.TEXT()
) )
def downgrade(): def downgrade():
op.alter_column( op.alter_column(
table_name='pending_certs', table_name="pending_certs",
column_name='status', column_name="status",
nullable=True, nullable=True,
type_=sa.VARCHAR(128) type_=sa.VARCHAR(128),
) )

View File

@ -7,16 +7,18 @@ Create Date: 2017-12-08 14:19:11.903864
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '5bc47fa7cac4' revision = "5bc47fa7cac4"
down_revision = 'c05a8998b371' down_revision = "c05a8998b371"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('roles', sa.Column('third_party', sa.Boolean(), nullable=True, default=False)) op.add_column(
"roles", sa.Column("third_party", sa.Boolean(), nullable=True, default=False)
)
def downgrade(): def downgrade():
op.drop_column('roles', 'third_party') op.drop_column("roles", "third_party")

View File

@ -7,20 +7,20 @@ Create Date: 2017-01-26 05:05:25.168125
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '5e680529b666' revision = "5e680529b666"
down_revision = '131ec6accff5' down_revision = "131ec6accff5"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('endpoints', sa.Column('sensitive', sa.Boolean(), nullable=True)) op.add_column("endpoints", sa.Column("sensitive", sa.Boolean(), nullable=True))
op.add_column('endpoints', sa.Column('source_id', sa.Integer(), nullable=True)) op.add_column("endpoints", sa.Column("source_id", sa.Integer(), nullable=True))
op.create_foreign_key(None, 'endpoints', 'sources', ['source_id'], ['id']) op.create_foreign_key(None, "endpoints", "sources", ["source_id"], ["id"])
def downgrade(): def downgrade():
op.drop_constraint(None, 'endpoints', type_='foreignkey') op.drop_constraint(None, "endpoints", type_="foreignkey")
op.drop_column('endpoints', 'source_id') op.drop_column("endpoints", "source_id")
op.drop_column('endpoints', 'sensitive') op.drop_column("endpoints", "sensitive")

View File

@ -7,15 +7,15 @@ Create Date: 2018-10-19 15:23:06.750510
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '6006c79b6011' revision = "6006c79b6011"
down_revision = '984178255c83' down_revision = "984178255c83"
from alembic import op from alembic import op
def upgrade(): def upgrade():
op.create_unique_constraint("uq_label", 'sources', ['label']) op.create_unique_constraint("uq_label", "sources", ["label"])
def downgrade(): def downgrade():
op.drop_constraint("uq_label", 'sources', type_='unique') op.drop_constraint("uq_label", "sources", type_="unique")

View File

@ -7,15 +7,16 @@ Create Date: 2018-10-21 22:06:23.056906
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '7ead443ba911' revision = "7ead443ba911"
down_revision = '6006c79b6011' down_revision = "6006c79b6011"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('certificates', sa.Column('csr', sa.TEXT(), nullable=True)) op.add_column("certificates", sa.Column("csr", sa.TEXT(), nullable=True))
def downgrade(): def downgrade():
op.drop_column('certificates', 'csr') op.drop_column("certificates", "csr")

View File

@ -9,8 +9,8 @@ Create Date: 2016-07-28 09:39:12.736506
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '7f71c0cea31a' revision = "7f71c0cea31a"
down_revision = '29d8c8455c86' down_revision = "29d8c8455c86"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@ -19,17 +19,25 @@ from sqlalchemy.sql import text
def upgrade(): def upgrade():
conn = op.get_bind() conn = op.get_bind()
for name in conn.execute(text('select name from certificates group by name having count(*) > 1')): for name in conn.execute(
for idx, id in enumerate(conn.execute(text("select id from certificates where certificates.name like :name order by id ASC").bindparams(name=name[0]))): text("select name from certificates group by name having count(*) > 1")
):
for idx, id in enumerate(
conn.execute(
text(
"select id from certificates where certificates.name like :name order by id ASC"
).bindparams(name=name[0])
)
):
if not idx: if not idx:
continue continue
new_name = name[0] + '-' + str(idx) new_name = name[0] + "-" + str(idx)
stmt = text('update certificates set name=:name where id=:id') stmt = text("update certificates set name=:name where id=:id")
stmt = stmt.bindparams(name=new_name, id=id[0]) stmt = stmt.bindparams(name=new_name, id=id[0])
op.execute(stmt) op.execute(stmt)
op.create_unique_constraint(None, 'certificates', ['name']) op.create_unique_constraint(None, "certificates", ["name"])
def downgrade(): def downgrade():
op.drop_constraint(None, 'certificates', type_='unique') op.drop_constraint(None, "certificates", type_="unique")

View File

@ -7,18 +7,28 @@ Create Date: 2017-05-10 11:56:13.999332
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '8ae67285ff14' revision = "8ae67285ff14"
down_revision = '5e680529b666' down_revision = "5e680529b666"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.drop_index('certificate_replacement_associations_ix') op.drop_index("certificate_replacement_associations_ix")
op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['replaced_certificate_id', 'certificate_id'], unique=True) op.create_index(
"certificate_replacement_associations_ix",
"certificate_replacement_associations",
["replaced_certificate_id", "certificate_id"],
unique=True,
)
def downgrade(): def downgrade():
op.drop_index('certificate_replacement_associations_ix') op.drop_index("certificate_replacement_associations_ix")
op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['certificate_id', 'certificate_id'], unique=True) op.create_index(
"certificate_replacement_associations_ix",
"certificate_replacement_associations",
["certificate_id", "certificate_id"],
unique=True,
)

View File

@ -7,15 +7,15 @@ Create Date: 2016-10-13 20:14:33.928029
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '932525b82f1a' revision = "932525b82f1a"
down_revision = '7f71c0cea31a' down_revision = "7f71c0cea31a"
from alembic import op from alembic import op
def upgrade(): def upgrade():
op.alter_column('certificates', 'active', new_column_name='notify') op.alter_column("certificates", "active", new_column_name="notify")
def downgrade(): def downgrade():
op.alter_column('certificates', 'notify', new_column_name='active') op.alter_column("certificates", "notify", new_column_name="active")

View File

@ -6,8 +6,8 @@ Create Date: 2018-09-17 08:33:37.087488
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '9392b9f9a805' revision = "9392b9f9a805"
down_revision = '5ae0ecefb01f' down_revision = "5ae0ecefb01f"
from alembic import op from alembic import op
from sqlalchemy_utils import ArrowType from sqlalchemy_utils import ArrowType
@ -15,10 +15,17 @@ import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('pending_certs', sa.Column('last_updated', ArrowType, server_default=sa.text('now()'), onupdate=sa.text('now()'), op.add_column(
nullable=False)) "pending_certs",
sa.Column(
"last_updated",
ArrowType,
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
nullable=False,
),
)
def downgrade(): def downgrade():
op.drop_column('pending_certs', 'last_updated') op.drop_column("pending_certs", "last_updated")

View File

@ -7,18 +7,20 @@ Create Date: 2018-10-11 20:49:12.704563
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '984178255c83' revision = "984178255c83"
down_revision = 'f2383bf08fbc' down_revision = "f2383bf08fbc"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.add_column('pending_certs', sa.Column('resolved', sa.Boolean(), nullable=True)) op.add_column("pending_certs", sa.Column("resolved", sa.Boolean(), nullable=True))
op.add_column('pending_certs', sa.Column('resolved_cert_id', sa.Integer(), nullable=True)) op.add_column(
"pending_certs", sa.Column("resolved_cert_id", sa.Integer(), nullable=True)
)
def downgrade(): def downgrade():
op.drop_column('pending_certs', 'resolved_cert_id') op.drop_column("pending_certs", "resolved_cert_id")
op.drop_column('pending_certs', 'resolved') op.drop_column("pending_certs", "resolved")

View File

@ -7,16 +7,26 @@ Create Date: 2019-01-03 15:36:59.181911
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '9f79024fe67b' revision = "9f79024fe67b"
down_revision = 'ee827d1e1974' down_revision = "ee827d1e1974"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
def upgrade(): def upgrade():
op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'revoke_cert', 'update_cert'], ['create_cert', 'delete_cert', 'key_view', 'revoke_cert', 'update_cert']) op.sync_enum_values(
"public",
"log_type",
["create_cert", "key_view", "revoke_cert", "update_cert"],
["create_cert", "delete_cert", "key_view", "revoke_cert", "update_cert"],
)
def downgrade(): def downgrade():
op.sync_enum_values('public', 'log_type', ['create_cert', 'delete_cert', 'key_view', 'revoke_cert', 'update_cert'], ['create_cert', 'key_view', 'revoke_cert', 'update_cert']) op.sync_enum_values(
"public",
"log_type",
["create_cert", "delete_cert", "key_view", "revoke_cert", "update_cert"],
["create_cert", "key_view", "revoke_cert", "update_cert"],
)

View File

@ -10,8 +10,8 @@ Create Date: 2017-07-12 11:45:49.257927
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'a02a678ddc25' revision = "a02a678ddc25"
down_revision = '8ae67285ff14' down_revision = "8ae67285ff14"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@ -20,25 +20,30 @@ from sqlalchemy.sql import text
def upgrade(): def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('rotation_policies', op.create_table(
sa.Column('id', sa.Integer(), nullable=False), "rotation_policies",
sa.Column('name', sa.String(), nullable=True), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('days', sa.Integer(), nullable=True), sa.Column("name", sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id') sa.Column("days", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.add_column(
"certificates", sa.Column("rotation_policy_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
None, "certificates", "rotation_policies", ["rotation_policy_id"], ["id"]
) )
op.add_column('certificates', sa.Column('rotation_policy_id', sa.Integer(), nullable=True))
op.create_foreign_key(None, 'certificates', 'rotation_policies', ['rotation_policy_id'], ['id'])
conn = op.get_bind() conn = op.get_bind()
stmt = text('insert into rotation_policies (days, name) values (:days, :name)') stmt = text("insert into rotation_policies (days, name) values (:days, :name)")
stmt = stmt.bindparams(days=30, name='default') stmt = stmt.bindparams(days=30, name="default")
conn.execute(stmt) conn.execute(stmt)
stmt = text('select id from rotation_policies where name=:name') stmt = text("select id from rotation_policies where name=:name")
stmt = stmt.bindparams(name='default') stmt = stmt.bindparams(name="default")
rotation_policy_id = conn.execute(stmt).fetchone()[0] rotation_policy_id = conn.execute(stmt).fetchone()[0]
stmt = text('update certificates set rotation_policy_id=:rotation_policy_id') stmt = text("update certificates set rotation_policy_id=:rotation_policy_id")
stmt = stmt.bindparams(rotation_policy_id=rotation_policy_id) stmt = stmt.bindparams(rotation_policy_id=rotation_policy_id)
conn.execute(stmt) conn.execute(stmt)
# ### end Alembic commands ### # ### end Alembic commands ###
@ -46,9 +51,17 @@ def upgrade():
def downgrade(): def downgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'certificates', type_='foreignkey') op.drop_constraint(None, "certificates", type_="foreignkey")
op.drop_column('certificates', 'rotation_policy_id') op.drop_column("certificates", "rotation_policy_id")
op.drop_index('certificate_replacement_associations_ix', table_name='certificate_replacement_associations') op.drop_index(
op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['replaced_certificate_id', 'certificate_id'], unique=True) "certificate_replacement_associations_ix",
op.drop_table('rotation_policies') table_name="certificate_replacement_associations",
)
op.create_index(
"certificate_replacement_associations_ix",
"certificate_replacement_associations",
["replaced_certificate_id", "certificate_id"],
unique=True,
)
op.drop_table("rotation_policies")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2017-10-11 10:16:39.682591
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'ac483cfeb230' revision = "ac483cfeb230"
down_revision = 'b29e2c4bf8c9' down_revision = "b29e2c4bf8c9"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@ -16,12 +16,18 @@ from sqlalchemy.dialects import postgresql
def upgrade(): def upgrade():
op.alter_column('certificates', 'name', op.alter_column(
"certificates",
"name",
existing_type=sa.VARCHAR(length=128), existing_type=sa.VARCHAR(length=128),
type_=sa.String(length=256)) type_=sa.String(length=256),
)
def downgrade(): def downgrade():
op.alter_column('certificates', 'name', op.alter_column(
"certificates",
"name",
existing_type=sa.VARCHAR(length=256), existing_type=sa.VARCHAR(length=256),
type_=sa.String(length=128)) type_=sa.String(length=128),
)

View File

@ -7,8 +7,8 @@ Create Date: 2017-09-26 10:50:35.740367
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'b29e2c4bf8c9' revision = "b29e2c4bf8c9"
down_revision = '1ae8e3104db8' down_revision = "1ae8e3104db8"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
@ -16,13 +16,25 @@ import sqlalchemy as sa
def upgrade(): def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.add_column('certificates', sa.Column('external_id', sa.String(128), nullable=True)) op.add_column(
op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'update_cert'], ['create_cert', 'key_view', 'revoke_cert', 'update_cert']) "certificates", sa.Column("external_id", sa.String(128), nullable=True)
)
op.sync_enum_values(
"public",
"log_type",
["create_cert", "key_view", "update_cert"],
["create_cert", "key_view", "revoke_cert", "update_cert"],
)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade(): def downgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'revoke_cert', 'update_cert'], ['create_cert', 'key_view', 'update_cert']) op.sync_enum_values(
op.drop_column('certificates', 'external_id') "public",
"log_type",
["create_cert", "key_view", "revoke_cert", "update_cert"],
["create_cert", "key_view", "update_cert"],
)
op.drop_column("certificates", "external_id")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -7,25 +7,27 @@ Create Date: 2017-11-10 14:51:28.975927
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'c05a8998b371' revision = "c05a8998b371"
down_revision = 'ac483cfeb230' down_revision = "ac483cfeb230"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
import sqlalchemy_utils import sqlalchemy_utils
def upgrade(): def upgrade():
op.create_table('api_keys', op.create_table(
sa.Column('id', sa.Integer(), nullable=False), "api_keys",
sa.Column('name', sa.String(length=128), nullable=True), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False), sa.Column("name", sa.String(length=128), nullable=True),
sa.Column('ttl', sa.BigInteger(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column('issued_at', sa.BigInteger(), nullable=False), sa.Column("ttl", sa.BigInteger(), nullable=False),
sa.Column('revoked', sa.Boolean(), nullable=False), sa.Column("issued_at", sa.BigInteger(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sa.Column("revoked", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('id') sa.ForeignKeyConstraint(["user_id"], ["users.id"]),
sa.PrimaryKeyConstraint("id"),
) )
def downgrade(): def downgrade():
op.drop_table('api_keys') op.drop_table("api_keys")

View File

@ -5,15 +5,15 @@ Create Date: 2018-10-11 09:44:57.099854
""" """
revision = 'c87cb989af04' revision = "c87cb989af04"
down_revision = '9392b9f9a805' down_revision = "9392b9f9a805"
from alembic import op from alembic import op
def upgrade(): def upgrade():
op.create_index(op.f('ix_domains_name'), 'domains', ['name'], unique=False) op.create_index(op.f("ix_domains_name"), "domains", ["name"], unique=False)
def downgrade(): def downgrade():
op.drop_index(op.f('ix_domains_name'), table_name='domains') op.drop_index(op.f("ix_domains_name"), table_name="domains")

View File

@ -7,8 +7,8 @@ Create Date: 2018-02-23 11:00:02.150561
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'ce547319f7be' revision = "ce547319f7be"
down_revision = '5bc47fa7cac4' down_revision = "5bc47fa7cac4"
import sqlalchemy as sa import sqlalchemy as sa
@ -24,12 +24,12 @@ TABLE = "certificate_notification_associations"
def upgrade(): def upgrade():
print("Adding id column") print("Adding id column")
op.add_column( op.add_column(
TABLE, TABLE, sa.Column("id", sa.Integer, primary_key=True, autoincrement=True)
sa.Column('id', sa.Integer, primary_key=True, autoincrement=True)
) )
db.session.commit() db.session.commit()
db.session.flush() db.session.flush()
def downgrade(): def downgrade():
op.drop_column(TABLE, "id") op.drop_column(TABLE, "id")
db.session.commit() db.session.commit()

View File

@ -7,29 +7,36 @@ Create Date: 2016-11-28 13:15:46.995219
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'e3691fc396e9' revision = "e3691fc396e9"
down_revision = '932525b82f1a' down_revision = "932525b82f1a"
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
import sqlalchemy_utils import sqlalchemy_utils
def upgrade(): def upgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.create_table('logs', op.create_table(
sa.Column('id', sa.Integer(), nullable=False), "logs",
sa.Column('certificate_id', sa.Integer(), nullable=True), sa.Column("id", sa.Integer(), nullable=False),
sa.Column('log_type', sa.Enum('key_view', name='log_type'), nullable=False), sa.Column("certificate_id", sa.Integer(), nullable=True),
sa.Column('logged_at', sqlalchemy_utils.types.arrow.ArrowType(), server_default=sa.text('now()'), nullable=False), sa.Column("log_type", sa.Enum("key_view", name="log_type"), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False), sa.Column(
sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ), "logged_at",
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), sqlalchemy_utils.types.arrow.ArrowType(),
sa.PrimaryKeyConstraint('id') server_default=sa.text("now()"),
nullable=False,
),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]),
sa.ForeignKeyConstraint(["user_id"], ["users.id"]),
sa.PrimaryKeyConstraint("id"),
) )
### end Alembic commands ### ### end Alembic commands ###
def downgrade(): def downgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.drop_table('logs') op.drop_table("logs")
### end Alembic commands ### ### end Alembic commands ###

View File

@ -7,25 +7,44 @@ Create Date: 2018-11-05 09:49:40.226368
""" """
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'ee827d1e1974' revision = "ee827d1e1974"
down_revision = '7ead443ba911' down_revision = "7ead443ba911"
from alembic import op from alembic import op
from sqlalchemy.exc import ProgrammingError from sqlalchemy.exc import ProgrammingError
def upgrade(): def upgrade():
connection = op.get_bind() connection = op.get_bind()
connection.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm") connection.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
op.create_index('ix_certificates_cn', 'certificates', ['cn'], unique=False, postgresql_ops={'cn': 'gin_trgm_ops'}, op.create_index(
postgresql_using='gin') "ix_certificates_cn",
op.create_index('ix_certificates_name', 'certificates', ['name'], unique=False, "certificates",
postgresql_ops={'name': 'gin_trgm_ops'}, postgresql_using='gin') ["cn"],
op.create_index('ix_domains_name_gin', 'domains', ['name'], unique=False, postgresql_ops={'name': 'gin_trgm_ops'}, unique=False,
postgresql_using='gin') postgresql_ops={"cn": "gin_trgm_ops"},
postgresql_using="gin",
)
op.create_index(
"ix_certificates_name",
"certificates",
["name"],
unique=False,
postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using="gin",
)
op.create_index(
"ix_domains_name_gin",
"domains",
["name"],
unique=False,
postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using="gin",
)
def downgrade(): def downgrade():
op.drop_index('ix_domains_name', table_name='domains') op.drop_index("ix_domains_name", table_name="domains")
op.drop_index('ix_certificates_name', table_name='certificates') op.drop_index("ix_certificates_name", table_name="certificates")
op.drop_index('ix_certificates_cn', table_name='certificates') op.drop_index("ix_certificates_cn", table_name="certificates")

View File

@ -7,17 +7,22 @@ Create Date: 2018-10-11 11:23:31.195471
""" """
revision = 'f2383bf08fbc' revision = "f2383bf08fbc"
down_revision = 'c87cb989af04' down_revision = "c87cb989af04"
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op from alembic import op
def upgrade(): def upgrade():
op.create_index('ix_certificates_id_desc', 'certificates', [sa.text('id DESC')], unique=True, op.create_index(
postgresql_using='btree') "ix_certificates_id_desc",
"certificates",
[sa.text("id DESC")],
unique=True,
postgresql_using="btree",
)
def downgrade(): def downgrade():
op.drop_index('ix_certificates_id_desc', table_name='certificates') op.drop_index("ix_certificates_id_desc", table_name="certificates")

View File

@ -12,121 +12,201 @@ from sqlalchemy import Column, Integer, ForeignKey, Index, UniqueConstraint
from lemur.database import db from lemur.database import db
certificate_associations = db.Table('certificate_associations', certificate_associations = db.Table(
Column('domain_id', Integer, ForeignKey('domains.id')), "certificate_associations",
Column('certificate_id', Integer, ForeignKey('certificates.id')) Column("domain_id", Integer, ForeignKey("domains.id")),
Column("certificate_id", Integer, ForeignKey("certificates.id")),
) )
Index('certificate_associations_ix', certificate_associations.c.domain_id, certificate_associations.c.certificate_id) Index(
"certificate_associations_ix",
certificate_destination_associations = db.Table('certificate_destination_associations', certificate_associations.c.domain_id,
Column('destination_id', Integer, certificate_associations.c.certificate_id,
ForeignKey('destinations.id', ondelete='cascade')),
Column('certificate_id', Integer,
ForeignKey('certificates.id', ondelete='cascade'))
) )
Index('certificate_destination_associations_ix', certificate_destination_associations.c.destination_id, certificate_destination_associations.c.certificate_id) certificate_destination_associations = db.Table(
"certificate_destination_associations",
certificate_source_associations = db.Table('certificate_source_associations', Column(
Column('source_id', Integer, "destination_id", Integer, ForeignKey("destinations.id", ondelete="cascade")
ForeignKey('sources.id', ondelete='cascade')), ),
Column('certificate_id', Integer, Column(
ForeignKey('certificates.id', ondelete='cascade')) "certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade")
),
) )
Index('certificate_source_associations_ix', certificate_source_associations.c.source_id, certificate_source_associations.c.certificate_id) Index(
"certificate_destination_associations_ix",
certificate_notification_associations = db.Table('certificate_notification_associations', certificate_destination_associations.c.destination_id,
Column('notification_id', Integer, certificate_destination_associations.c.certificate_id,
ForeignKey('notifications.id', ondelete='cascade')),
Column('certificate_id', Integer,
ForeignKey('certificates.id', ondelete='cascade')),
Column('id', Integer, primary_key=True, autoincrement=True),
UniqueConstraint('notification_id', 'certificate_id', name='uq_dest_not_ids')
) )
Index('certificate_notification_associations_ix', certificate_notification_associations.c.notification_id, certificate_notification_associations.c.certificate_id) certificate_source_associations = db.Table(
"certificate_source_associations",
certificate_replacement_associations = db.Table('certificate_replacement_associations', Column("source_id", Integer, ForeignKey("sources.id", ondelete="cascade")),
Column('replaced_certificate_id', Integer, Column(
ForeignKey('certificates.id', ondelete='cascade')), "certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade")
Column('certificate_id', Integer, ),
ForeignKey('certificates.id', ondelete='cascade'))
) )
Index('certificate_replacement_associations_ix', certificate_replacement_associations.c.replaced_certificate_id, certificate_replacement_associations.c.certificate_id, unique=True) Index(
"certificate_source_associations_ix",
roles_authorities = db.Table('roles_authorities', certificate_source_associations.c.source_id,
Column('authority_id', Integer, ForeignKey('authorities.id')), certificate_source_associations.c.certificate_id,
Column('role_id', Integer, ForeignKey('roles.id'))
) )
Index('roles_authorities_ix', roles_authorities.c.authority_id, roles_authorities.c.role_id) certificate_notification_associations = db.Table(
"certificate_notification_associations",
roles_certificates = db.Table('roles_certificates', Column(
Column('certificate_id', Integer, ForeignKey('certificates.id')), "notification_id", Integer, ForeignKey("notifications.id", ondelete="cascade")
Column('role_id', Integer, ForeignKey('roles.id')) ),
Column(
"certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade")
),
Column("id", Integer, primary_key=True, autoincrement=True),
UniqueConstraint("notification_id", "certificate_id", name="uq_dest_not_ids"),
) )
Index('roles_certificates_ix', roles_certificates.c.certificate_id, roles_certificates.c.role_id) Index(
"certificate_notification_associations_ix",
certificate_notification_associations.c.notification_id,
roles_users = db.Table('roles_users', certificate_notification_associations.c.certificate_id,
Column('user_id', Integer, ForeignKey('users.id')),
Column('role_id', Integer, ForeignKey('roles.id'))
) )
Index('roles_users_ix', roles_users.c.user_id, roles_users.c.role_id) certificate_replacement_associations = db.Table(
"certificate_replacement_associations",
Column(
policies_ciphers = db.Table('policies_ciphers', "replaced_certificate_id",
Column('cipher_id', Integer, ForeignKey('ciphers.id')), Integer,
Column('policy_id', Integer, ForeignKey('policy.id'))) ForeignKey("certificates.id", ondelete="cascade"),
),
Index('policies_ciphers_ix', policies_ciphers.c.cipher_id, policies_ciphers.c.policy_id) Column(
"certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade")
),
pending_cert_destination_associations = db.Table('pending_cert_destination_associations',
Column('destination_id', Integer,
ForeignKey('destinations.id', ondelete='cascade')),
Column('pending_cert_id', Integer,
ForeignKey('pending_certs.id', ondelete='cascade'))
) )
Index('pending_cert_destination_associations_ix', pending_cert_destination_associations.c.destination_id, pending_cert_destination_associations.c.pending_cert_id) Index(
"certificate_replacement_associations_ix",
certificate_replacement_associations.c.replaced_certificate_id,
pending_cert_notification_associations = db.Table('pending_cert_notification_associations', certificate_replacement_associations.c.certificate_id,
Column('notification_id', Integer, unique=True,
ForeignKey('notifications.id', ondelete='cascade')),
Column('pending_cert_id', Integer,
ForeignKey('pending_certs.id', ondelete='cascade'))
) )
Index('pending_cert_notification_associations_ix', pending_cert_notification_associations.c.notification_id, pending_cert_notification_associations.c.pending_cert_id) roles_authorities = db.Table(
"roles_authorities",
pending_cert_source_associations = db.Table('pending_cert_source_associations', Column("authority_id", Integer, ForeignKey("authorities.id")),
Column('source_id', Integer, Column("role_id", Integer, ForeignKey("roles.id")),
ForeignKey('sources.id', ondelete='cascade')),
Column('pending_cert_id', Integer,
ForeignKey('pending_certs.id', ondelete='cascade'))
) )
Index('pending_cert_source_associations_ix', pending_cert_source_associations.c.source_id, pending_cert_source_associations.c.pending_cert_id) Index(
"roles_authorities_ix",
pending_cert_replacement_associations = db.Table('pending_cert_replacement_associations', roles_authorities.c.authority_id,
Column('replaced_certificate_id', Integer, roles_authorities.c.role_id,
ForeignKey('certificates.id', ondelete='cascade')),
Column('pending_cert_id', Integer,
ForeignKey('pending_certs.id', ondelete='cascade'))
) )
Index('pending_cert_replacement_associations_ix', pending_cert_replacement_associations.c.replaced_certificate_id, pending_cert_replacement_associations.c.pending_cert_id) roles_certificates = db.Table(
"roles_certificates",
pending_cert_role_associations = db.Table('pending_cert_role_associations', Column("certificate_id", Integer, ForeignKey("certificates.id")),
Column('pending_cert_id', Integer, ForeignKey('pending_certs.id')), Column("role_id", Integer, ForeignKey("roles.id")),
Column('role_id', Integer, ForeignKey('roles.id'))
) )
Index('pending_cert_role_associations_ix', pending_cert_role_associations.c.pending_cert_id, pending_cert_role_associations.c.role_id) Index(
"roles_certificates_ix",
roles_certificates.c.certificate_id,
roles_certificates.c.role_id,
)
roles_users = db.Table(
"roles_users",
Column("user_id", Integer, ForeignKey("users.id")),
Column("role_id", Integer, ForeignKey("roles.id")),
)
Index("roles_users_ix", roles_users.c.user_id, roles_users.c.role_id)
policies_ciphers = db.Table(
"policies_ciphers",
Column("cipher_id", Integer, ForeignKey("ciphers.id")),
Column("policy_id", Integer, ForeignKey("policy.id")),
)
Index("policies_ciphers_ix", policies_ciphers.c.cipher_id, policies_ciphers.c.policy_id)
pending_cert_destination_associations = db.Table(
"pending_cert_destination_associations",
Column(
"destination_id", Integer, ForeignKey("destinations.id", ondelete="cascade")
),
Column(
"pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade")
),
)
Index(
"pending_cert_destination_associations_ix",
pending_cert_destination_associations.c.destination_id,
pending_cert_destination_associations.c.pending_cert_id,
)
pending_cert_notification_associations = db.Table(
"pending_cert_notification_associations",
Column(
"notification_id", Integer, ForeignKey("notifications.id", ondelete="cascade")
),
Column(
"pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade")
),
)
Index(
"pending_cert_notification_associations_ix",
pending_cert_notification_associations.c.notification_id,
pending_cert_notification_associations.c.pending_cert_id,
)
pending_cert_source_associations = db.Table(
"pending_cert_source_associations",
Column("source_id", Integer, ForeignKey("sources.id", ondelete="cascade")),
Column(
"pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade")
),
)
Index(
"pending_cert_source_associations_ix",
pending_cert_source_associations.c.source_id,
pending_cert_source_associations.c.pending_cert_id,
)
pending_cert_replacement_associations = db.Table(
"pending_cert_replacement_associations",
Column(
"replaced_certificate_id",
Integer,
ForeignKey("certificates.id", ondelete="cascade"),
),
Column(
"pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade")
),
)
Index(
"pending_cert_replacement_associations_ix",
pending_cert_replacement_associations.c.replaced_certificate_id,
pending_cert_replacement_associations.c.pending_cert_id,
)
pending_cert_role_associations = db.Table(
"pending_cert_role_associations",
Column("pending_cert_id", Integer, ForeignKey("pending_certs.id")),
Column("role_id", Integer, ForeignKey("roles.id")),
)
Index(
"pending_cert_role_associations_ix",
pending_cert_role_associations.c.pending_cert_id,
pending_cert_role_associations.c.role_id,
)

View File

@ -14,7 +14,14 @@ from lemur.notifications.messaging import send_expiration_notifications
manager = Manager(usage="Handles notification related tasks.") manager = Manager(usage="Handles notification related tasks.")
@manager.option('-e', '--exclude', dest='exclude', action='append', default=[], help='Common name matching of certificates that should be excluded from notification') @manager.option(
"-e",
"--exclude",
dest="exclude",
action="append",
default=[],
help="Common name matching of certificates that should be excluded from notification",
)
def expirations(exclude): def expirations(exclude):
""" """
Runs Lemur's notification engine, that looks for expired certificates and sends Runs Lemur's notification engine, that looks for expired certificates and sends
@ -33,12 +40,13 @@ def expirations(exclude):
success, failed = send_expiration_notifications(exclude) success, failed = send_expiration_notifications(exclude)
print( print(
"Finished notifying subscribers about expiring certificates! Sent: {success} Failed: {failed}".format( "Finished notifying subscribers about expiring certificates! Sent: {success} Failed: {failed}".format(
success=success, success=success, failed=failed
failed=failed
) )
) )
status = SUCCESS_METRIC_STATUS status = SUCCESS_METRIC_STATUS
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
metrics.send('expiration_notification_job', 'counter', 1, metric_tags={'status': status}) metrics.send(
"expiration_notification_job", "counter", 1, metric_tags={"status": status}
)

View File

@ -36,15 +36,17 @@ def get_certificates(exclude=None):
now = arrow.utcnow() now = arrow.utcnow()
max = now + timedelta(days=90) max = now + timedelta(days=90)
q = database.db.session.query(Certificate) \ q = (
.filter(Certificate.not_after <= max) \ database.db.session.query(Certificate)
.filter(Certificate.notify == True) \ .filter(Certificate.not_after <= max)
.filter(Certificate.expired == False) # noqa .filter(Certificate.notify == True)
.filter(Certificate.expired == False)
) # noqa
exclude_conditions = [] exclude_conditions = []
if exclude: if exclude:
for e in exclude: for e in exclude:
exclude_conditions.append(~Certificate.name.ilike('%{}%'.format(e))) exclude_conditions.append(~Certificate.name.ilike("%{}%".format(e)))
q = q.filter(and_(*exclude_conditions)) q = q.filter(and_(*exclude_conditions))
@ -101,7 +103,12 @@ def send_notification(event_type, data, targets, notification):
except Exception as e: except Exception as e:
sentry.captureException() sentry.captureException()
metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': event_type}) metrics.send(
"notification",
"counter",
1,
metric_tags={"status": status, "event_type": event_type},
)
if status == SUCCESS_METRIC_STATUS: if status == SUCCESS_METRIC_STATUS:
return True return True
@ -115,7 +122,7 @@ def send_expiration_notifications(exclude):
success = failure = 0 success = failure = 0
# security team gets all # security team gets all
security_email = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL') security_email = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL")
security_data = [] security_data = []
for owner, notification_group in get_eligible_certificates(exclude=exclude).items(): for owner, notification_group in get_eligible_certificates(exclude=exclude).items():
@ -127,26 +134,43 @@ def send_expiration_notifications(exclude):
for data in certificates: for data in certificates:
n, certificate = data n, certificate = data
cert_data = certificate_notification_output_schema.dump(certificate).data cert_data = certificate_notification_output_schema.dump(
certificate
).data
notification_data.append(cert_data) notification_data.append(cert_data)
security_data.append(cert_data) security_data.append(cert_data)
notification_recipient = get_plugin_option('recipients', notification.options) notification_recipient = get_plugin_option(
"recipients", notification.options
)
if notification_recipient: if notification_recipient:
notification_recipient = notification_recipient.split(",") notification_recipient = notification_recipient.split(",")
if send_notification('expiration', notification_data, [owner], notification): if send_notification(
"expiration", notification_data, [owner], notification
):
success += 1 success += 1
else: else:
failure += 1 failure += 1
if notification_recipient and owner != notification_recipient and security_email != notification_recipient: if (
if send_notification('expiration', notification_data, notification_recipient, notification): notification_recipient
and owner != notification_recipient
and security_email != notification_recipient
):
if send_notification(
"expiration",
notification_data,
notification_recipient,
notification,
):
success += 1 success += 1
else: else:
failure += 1 failure += 1
if send_notification('expiration', security_data, security_email, notification): if send_notification(
"expiration", security_data, security_email, notification
):
success += 1 success += 1
else: else:
failure += 1 failure += 1
@ -165,24 +189,35 @@ def send_rotation_notification(certificate, notification_plugin=None):
""" """
status = FAILURE_METRIC_STATUS status = FAILURE_METRIC_STATUS
if not notification_plugin: if not notification_plugin:
notification_plugin = plugins.get(current_app.config.get('LEMUR_DEFAULT_NOTIFICATION_PLUGIN')) notification_plugin = plugins.get(
current_app.config.get("LEMUR_DEFAULT_NOTIFICATION_PLUGIN")
)
data = certificate_notification_output_schema.dump(certificate).data data = certificate_notification_output_schema.dump(certificate).data
try: try:
notification_plugin.send('rotation', data, [data['owner']]) notification_plugin.send("rotation", data, [data["owner"]])
status = SUCCESS_METRIC_STATUS status = SUCCESS_METRIC_STATUS
except Exception as e: except Exception as e:
current_app.logger.error('Unable to send notification to {}.'.format(data['owner']), exc_info=True) current_app.logger.error(
"Unable to send notification to {}.".format(data["owner"]), exc_info=True
)
sentry.captureException() sentry.captureException()
metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': 'rotation'}) metrics.send(
"notification",
"counter",
1,
metric_tags={"status": status, "event_type": "rotation"},
)
if status == SUCCESS_METRIC_STATUS: if status == SUCCESS_METRIC_STATUS:
return True return True
def send_pending_failure_notification(pending_cert, notify_owner=True, notify_security=True, notification_plugin=None): def send_pending_failure_notification(
pending_cert, notify_owner=True, notify_security=True, notification_plugin=None
):
""" """
Sends a report to certificate owners when their pending certificate failed to be created. Sends a report to certificate owners when their pending certificate failed to be created.
@ -194,32 +229,47 @@ def send_pending_failure_notification(pending_cert, notify_owner=True, notify_se
if not notification_plugin: if not notification_plugin:
notification_plugin = plugins.get( notification_plugin = plugins.get(
current_app.config.get('LEMUR_DEFAULT_NOTIFICATION_PLUGIN', 'email-notification') current_app.config.get(
"LEMUR_DEFAULT_NOTIFICATION_PLUGIN", "email-notification"
)
) )
data = pending_certificate_output_schema.dump(pending_cert).data data = pending_certificate_output_schema.dump(pending_cert).data
data["security_email"] = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL') data["security_email"] = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL")
if notify_owner: if notify_owner:
try: try:
notification_plugin.send('failed', data, [data['owner']], pending_cert) notification_plugin.send("failed", data, [data["owner"]], pending_cert)
status = SUCCESS_METRIC_STATUS status = SUCCESS_METRIC_STATUS
except Exception as e: except Exception as e:
current_app.logger.error('Unable to send pending failure notification to {}.'.format(data['owner']), current_app.logger.error(
exc_info=True) "Unable to send pending failure notification to {}.".format(
data["owner"]
),
exc_info=True,
)
sentry.captureException() sentry.captureException()
if notify_security: if notify_security:
try: try:
notification_plugin.send('failed', data, data["security_email"], pending_cert) notification_plugin.send(
"failed", data, data["security_email"], pending_cert
)
status = SUCCESS_METRIC_STATUS status = SUCCESS_METRIC_STATUS
except Exception as e: except Exception as e:
current_app.logger.error('Unable to send pending failure notification to ' current_app.logger.error(
'{}.'.format(data['security_email']), "Unable to send pending failure notification to "
exc_info=True) "{}.".format(data["security_email"]),
exc_info=True,
)
sentry.captureException() sentry.captureException()
metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': 'rotation'}) metrics.send(
"notification",
"counter",
1,
metric_tags={"status": status, "event_type": "rotation"},
)
if status == SUCCESS_METRIC_STATUS: if status == SUCCESS_METRIC_STATUS:
return True return True
@ -242,20 +292,22 @@ def needs_notification(certificate):
if not notification.active or not notification.options: if not notification.active or not notification.options:
return return
interval = get_plugin_option('interval', notification.options) interval = get_plugin_option("interval", notification.options)
unit = get_plugin_option('unit', notification.options) unit = get_plugin_option("unit", notification.options)
if unit == 'weeks': if unit == "weeks":
interval *= 7 interval *= 7
elif unit == 'months': elif unit == "months":
interval *= 30 interval *= 30
elif unit == 'days': # it's nice to be explicit about the base unit elif unit == "days": # it's nice to be explicit about the base unit
pass pass
else: else:
raise Exception("Invalid base unit for expiration interval: {0}".format(unit)) raise Exception(
"Invalid base unit for expiration interval: {0}".format(unit)
)
if days == interval: if days == interval:
notifications.append(notification) notifications.append(notification)

View File

@ -11,12 +11,14 @@ from sqlalchemy_utils import JSONType
from lemur.database import db from lemur.database import db
from lemur.plugins.base import plugins from lemur.plugins.base import plugins
from lemur.models import certificate_notification_associations, \ from lemur.models import (
pending_cert_notification_associations certificate_notification_associations,
pending_cert_notification_associations,
)
class Notification(db.Model): class Notification(db.Model):
__tablename__ = 'notifications' __tablename__ = "notifications"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
label = Column(String(128), unique=True) label = Column(String(128), unique=True)
description = Column(Text()) description = Column(Text())
@ -28,14 +30,14 @@ class Notification(db.Model):
secondary=certificate_notification_associations, secondary=certificate_notification_associations,
passive_deletes=True, passive_deletes=True,
backref="notification", backref="notification",
cascade='all,delete' cascade="all,delete",
) )
pending_certificates = relationship( pending_certificates = relationship(
"PendingCertificate", "PendingCertificate",
secondary=pending_cert_notification_associations, secondary=pending_cert_notification_associations,
passive_deletes=True, passive_deletes=True,
backref="notification", backref="notification",
cascade='all,delete' cascade="all,delete",
) )
@property @property

View File

@ -7,7 +7,11 @@
""" """
from marshmallow import fields, post_dump from marshmallow import fields, post_dump
from lemur.common.schema import LemurInputSchema, LemurOutputSchema from lemur.common.schema import LemurInputSchema, LemurOutputSchema
from lemur.schemas import PluginInputSchema, PluginOutputSchema, AssociatedCertificateSchema from lemur.schemas import (
PluginInputSchema,
PluginOutputSchema,
AssociatedCertificateSchema,
)
class NotificationInputSchema(LemurInputSchema): class NotificationInputSchema(LemurInputSchema):
@ -30,7 +34,7 @@ class NotificationOutputSchema(LemurOutputSchema):
@post_dump @post_dump
def fill_object(self, data): def fill_object(self, data):
if data: if data:
data['plugin']['pluginOptions'] = data['options'] data["plugin"]["pluginOptions"] = data["options"]
return data return data

Some files were not shown because too many files have changed in this diff Show More