Merge pull request #2792 from castrapel/black-052019

Black lint all the things
This commit is contained in:
Curtis 2019-05-16 08:25:12 -07:00 committed by GitHub
commit 37e5857406
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
226 changed files with 9356 additions and 5943 deletions

View File

@ -8,3 +8,17 @@
sha: v2.9.5
hooks:
- id: jshint
- repo: https://github.com/ambv/black
rev: stable
hooks:
- id: black
language_version: python3.7
- repo: local
hooks:
- id: python-bandit-vulnerability-check
name: bandit
entry: bandit
args: ['--ini', 'tox.ini', '-r', 'consoleme']
language: system
pass_filenames: false

View File

@ -1,12 +1,18 @@
from __future__ import absolute_import, division, print_function
__all__ = [
"__title__", "__summary__", "__uri__", "__version__", "__author__",
"__email__", "__license__", "__copyright__",
"__title__",
"__summary__",
"__uri__",
"__version__",
"__author__",
"__email__",
"__license__",
"__copyright__",
]
__title__ = "lemur"
__summary__ = ("Certificate management and orchestration service")
__summary__ = "Certificate management and orchestration service"
__uri__ = "https://github.com/Netflix/lemur"
__version__ = "0.7.0"

View File

@ -32,14 +32,26 @@ from lemur.pending_certificates.views import mod as pending_certificates_bp
from lemur.dns_providers.views import mod as dns_providers_bp
from lemur.__about__ import (
__author__, __copyright__, __email__, __license__, __summary__, __title__,
__uri__, __version__
__author__,
__copyright__,
__email__,
__license__,
__summary__,
__title__,
__uri__,
__version__,
)
__all__ = [
"__title__", "__summary__", "__uri__", "__version__", "__author__",
"__email__", "__license__", "__copyright__",
"__title__",
"__summary__",
"__uri__",
"__version__",
"__author__",
"__email__",
"__license__",
"__copyright__",
]
LEMUR_BLUEPRINTS = (
@ -63,7 +75,9 @@ LEMUR_BLUEPRINTS = (
def create_app(config_path=None):
app = factory.create_app(app_name=__name__, blueprints=LEMUR_BLUEPRINTS, config=config_path)
app = factory.create_app(
app_name=__name__, blueprints=LEMUR_BLUEPRINTS, config=config_path
)
configure_hook(app)
return app
@ -93,7 +107,7 @@ def configure_hook(app):
@app.after_request
def after_request(response):
# Return early if we don't have the start time
if not hasattr(g, 'request_start_time'):
if not hasattr(g, "request_start_time"):
return response
# Get elapsed time in milliseconds
@ -102,12 +116,12 @@ def configure_hook(app):
# Collect request/response tags
tags = {
'endpoint': request.endpoint,
'request_method': request.method.lower(),
'status_code': response.status_code
"endpoint": request.endpoint,
"request_method": request.method.lower(),
"status_code": response.status_code,
}
# Record our response time metric
metrics.send('response_time', 'TIMER', elapsed, metric_tags=tags)
metrics.send('status_code_{}'.format(response.status_code), 'counter', 1)
metrics.send("response_time", "TIMER", elapsed, metric_tags=tags)
metrics.send("status_code_{}".format(response.status_code), "counter", 1)
return response

View File

@ -14,23 +14,32 @@ from datetime import datetime
manager = Manager(usage="Handles all api key related tasks.")
@manager.option('-u', '--user-id', dest='uid', help='The User ID this access key belongs too.')
@manager.option('-n', '--name', dest='name', help='The name of this API Key.')
@manager.option('-t', '--ttl', dest='ttl', help='The TTL of this API Key. -1 for forever.')
@manager.option(
"-u", "--user-id", dest="uid", help="The User ID this access key belongs too."
)
@manager.option("-n", "--name", dest="name", help="The name of this API Key.")
@manager.option(
"-t", "--ttl", dest="ttl", help="The TTL of this API Key. -1 for forever."
)
def create(uid, name, ttl):
"""
Create a new api key for a user.
:return:
"""
print("[+] Creating a new api key.")
key = api_key_service.create(user_id=uid, name=name,
ttl=ttl, issued_at=int(datetime.utcnow().timestamp()), revoked=False)
key = api_key_service.create(
user_id=uid,
name=name,
ttl=ttl,
issued_at=int(datetime.utcnow().timestamp()),
revoked=False,
)
print("[+] Successfully created a new api key. Generating a JWT...")
jwt = create_token(uid, key.id, key.ttl)
print("[+] Your JWT is: {jwt}".format(jwt=jwt))
@manager.option('-a', '--api-key-id', dest='aid', help='The API Key ID to revoke.')
@manager.option("-a", "--api-key-id", dest="aid", help="The API Key ID to revoke.")
def revoke(aid):
"""
Revokes an api key for a user.

View File

@ -12,14 +12,19 @@ from lemur.database import db
class ApiKey(db.Model):
__tablename__ = 'api_keys'
__tablename__ = "api_keys"
id = Column(Integer, primary_key=True)
name = Column(String)
user_id = Column(Integer, ForeignKey('users.id'))
user_id = Column(Integer, ForeignKey("users.id"))
ttl = Column(BigInteger)
issued_at = Column(BigInteger)
revoked = Column(Boolean)
def __repr__(self):
return "ApiKey(name={name}, user_id={user_id}, ttl={ttl}, issued_at={iat}, revoked={revoked})".format(
user_id=self.user_id, name=self.name, ttl=self.ttl, iat=self.issued_at, revoked=self.revoked)
user_id=self.user_id,
name=self.name,
ttl=self.ttl,
iat=self.issued_at,
revoked=self.revoked,
)

View File

@ -13,12 +13,18 @@ from lemur.users.schemas import UserNestedOutputSchema, UserInputSchema
def current_user_id():
return {'id': g.current_user.id, 'email': g.current_user.email, 'username': g.current_user.username}
return {
"id": g.current_user.id,
"email": g.current_user.email,
"username": g.current_user.username,
}
class ApiKeyInputSchema(LemurInputSchema):
name = fields.String(required=False)
user = fields.Nested(UserInputSchema, missing=current_user_id, default=current_user_id)
user = fields.Nested(
UserInputSchema, missing=current_user_id, default=current_user_id
)
ttl = fields.Integer()

View File

@ -34,7 +34,7 @@ def revoke(aid):
:return:
"""
api_key = get(aid)
setattr(api_key, 'revoked', False)
setattr(api_key, "revoked", False)
return database.update(api_key)
@ -80,10 +80,10 @@ def render(args):
:return:
"""
query = database.session_query(ApiKey)
user_id = args.pop('user_id', None)
aid = args.pop('id', None)
has_permission = args.pop('has_permission', False)
requesting_user_id = args.pop('requesting_user_id')
user_id = args.pop("user_id", None)
aid = args.pop("id", None)
has_permission = args.pop("has_permission", False)
requesting_user_id = args.pop("requesting_user_id")
if user_id:
query = query.filter(ApiKey.user_id == user_id)

View File

@ -19,10 +19,16 @@ from lemur.auth.permissions import ApiKeyCreatorPermission
from lemur.common.schema import validate_schema
from lemur.common.utils import paginated_parser
from lemur.api_keys.schemas import api_key_input_schema, api_key_revoke_schema, api_key_output_schema, \
api_keys_output_schema, api_key_described_output_schema, user_api_key_input_schema
from lemur.api_keys.schemas import (
api_key_input_schema,
api_key_revoke_schema,
api_key_output_schema,
api_keys_output_schema,
api_key_described_output_schema,
user_api_key_input_schema,
)
mod = Blueprint('api_keys', __name__)
mod = Blueprint("api_keys", __name__)
api = Api(mod)
@ -81,8 +87,8 @@ class ApiKeyList(AuthenticatedResource):
"""
parser = paginated_parser.copy()
args = parser.parse_args()
args['has_permission'] = ApiKeyCreatorPermission().can()
args['requesting_user_id'] = g.current_user.id
args["has_permission"] = ApiKeyCreatorPermission().can()
args["requesting_user_id"] = g.current_user.id
return service.render(args)
@validate_schema(api_key_input_schema, api_key_output_schema)
@ -124,12 +130,26 @@ class ApiKeyList(AuthenticatedResource):
:statuscode 403: unauthenticated
"""
if not ApiKeyCreatorPermission().can():
if data['user']['id'] != g.current_user.id:
return dict(message="You are not authorized to create tokens for: {0}".format(data['user']['username'])), 403
if data["user"]["id"] != g.current_user.id:
return (
dict(
message="You are not authorized to create tokens for: {0}".format(
data["user"]["username"]
)
),
403,
)
access_token = service.create(name=data['name'], user_id=data['user']['id'], ttl=data['ttl'],
revoked=False, issued_at=int(datetime.utcnow().timestamp()))
return dict(jwt=create_token(access_token.user_id, access_token.id, access_token.ttl))
access_token = service.create(
name=data["name"],
user_id=data["user"]["id"],
ttl=data["ttl"],
revoked=False,
issued_at=int(datetime.utcnow().timestamp()),
)
return dict(
jwt=create_token(access_token.user_id, access_token.id, access_token.ttl)
)
class ApiKeyUserList(AuthenticatedResource):
@ -186,9 +206,9 @@ class ApiKeyUserList(AuthenticatedResource):
"""
parser = paginated_parser.copy()
args = parser.parse_args()
args['has_permission'] = ApiKeyCreatorPermission().can()
args['requesting_user_id'] = g.current_user.id
args['user_id'] = user_id
args["has_permission"] = ApiKeyCreatorPermission().can()
args["requesting_user_id"] = g.current_user.id
args["user_id"] = user_id
return service.render(args)
@validate_schema(user_api_key_input_schema, api_key_output_schema)
@ -230,11 +250,25 @@ class ApiKeyUserList(AuthenticatedResource):
"""
if not ApiKeyCreatorPermission().can():
if user_id != g.current_user.id:
return dict(message="You are not authorized to create tokens for: {0}".format(user_id)), 403
return (
dict(
message="You are not authorized to create tokens for: {0}".format(
user_id
)
),
403,
)
access_token = service.create(name=data['name'], user_id=user_id, ttl=data['ttl'],
revoked=False, issued_at=int(datetime.utcnow().timestamp()))
return dict(jwt=create_token(access_token.user_id, access_token.id, access_token.ttl))
access_token = service.create(
name=data["name"],
user_id=user_id,
ttl=data["ttl"],
revoked=False,
issued_at=int(datetime.utcnow().timestamp()),
)
return dict(
jwt=create_token(access_token.user_id, access_token.id, access_token.ttl)
)
class ApiKeys(AuthenticatedResource):
@ -329,7 +363,9 @@ class ApiKeys(AuthenticatedResource):
if not ApiKeyCreatorPermission().can():
return dict(message="You are not authorized to update this token!"), 403
service.update(access_key, name=data['name'], revoked=data['revoked'], ttl=data['ttl'])
service.update(
access_key, name=data["name"], revoked=data["revoked"], ttl=data["ttl"]
)
return dict(jwt=create_token(access_key.user_id, access_key.id, access_key.ttl))
def delete(self, aid):
@ -371,7 +407,7 @@ class ApiKeys(AuthenticatedResource):
return dict(message="You are not authorized to delete this token!"), 403
service.delete(access_key)
return {'result': True}
return {"result": True}
class UserApiKeys(AuthenticatedResource):
@ -472,7 +508,9 @@ class UserApiKeys(AuthenticatedResource):
if access_key.user_id != uid:
return dict(message="You are not authorized to update this token!"), 403
service.update(access_key, name=data['name'], revoked=data['revoked'], ttl=data['ttl'])
service.update(
access_key, name=data["name"], revoked=data["revoked"], ttl=data["ttl"]
)
return dict(jwt=create_token(access_key.user_id, access_key.id, access_key.ttl))
def delete(self, uid, aid):
@ -517,7 +555,7 @@ class UserApiKeys(AuthenticatedResource):
return dict(message="You are not authorized to delete this token!"), 403
service.delete(access_key)
return {'result': True}
return {"result": True}
class ApiKeysDescribed(AuthenticatedResource):
@ -572,8 +610,12 @@ class ApiKeysDescribed(AuthenticatedResource):
return access_key
api.add_resource(ApiKeyList, '/keys', endpoint='api_keys')
api.add_resource(ApiKeys, '/keys/<int:aid>', endpoint='api_key')
api.add_resource(ApiKeysDescribed, '/keys/<int:aid>/described', endpoint='api_key_described')
api.add_resource(ApiKeyUserList, '/users/<int:user_id>/keys', endpoint='user_api_keys')
api.add_resource(UserApiKeys, '/users/<int:uid>/keys/<int:aid>', endpoint='user_api_key')
api.add_resource(ApiKeyList, "/keys", endpoint="api_keys")
api.add_resource(ApiKeys, "/keys/<int:aid>", endpoint="api_key")
api.add_resource(
ApiKeysDescribed, "/keys/<int:aid>/described", endpoint="api_key_described"
)
api.add_resource(ApiKeyUserList, "/users/<int:user_id>/keys", endpoint="user_api_keys")
api.add_resource(
UserApiKeys, "/users/<int:uid>/keys/<int:aid>", endpoint="user_api_key"
)

View File

@ -14,35 +14,41 @@ from lemur.roles import service as role_service
from lemur.common.utils import validate_conf, get_psuedo_random_string
class LdapPrincipal():
class LdapPrincipal:
"""
Provides methods for authenticating against an LDAP server.
"""
def __init__(self, args):
self._ldap_validate_conf()
# setup ldap config
if not args['username']:
if not args["username"]:
raise Exception("missing ldap username")
if not args['password']:
if not args["password"]:
self.error_message = "missing ldap password"
raise Exception("missing ldap password")
self.ldap_principal = args['username']
self.ldap_principal = args["username"]
self.ldap_email_domain = current_app.config.get("LDAP_EMAIL_DOMAIN", None)
if '@' not in self.ldap_principal:
self.ldap_principal = '%s@%s' % (self.ldap_principal, self.ldap_email_domain)
self.ldap_username = args['username']
if '@' in self.ldap_username:
self.ldap_username = args['username'].split("@")[0]
self.ldap_password = args['password']
self.ldap_server = current_app.config.get('LDAP_BIND_URI', None)
if "@" not in self.ldap_principal:
self.ldap_principal = "%s@%s" % (
self.ldap_principal,
self.ldap_email_domain,
)
self.ldap_username = args["username"]
if "@" in self.ldap_username:
self.ldap_username = args["username"].split("@")[0]
self.ldap_password = args["password"]
self.ldap_server = current_app.config.get("LDAP_BIND_URI", None)
self.ldap_base_dn = current_app.config.get("LDAP_BASE_DN", None)
self.ldap_use_tls = current_app.config.get("LDAP_USE_TLS", False)
self.ldap_cacert_file = current_app.config.get("LDAP_CACERT_FILE", None)
self.ldap_default_role = current_app.config.get("LEMUR_DEFAULT_ROLE", None)
self.ldap_required_group = current_app.config.get("LDAP_REQUIRED_GROUP", None)
self.ldap_groups_to_roles = current_app.config.get("LDAP_GROUPS_TO_ROLES", None)
self.ldap_is_active_directory = current_app.config.get("LDAP_IS_ACTIVE_DIRECTORY", False)
self.ldap_attrs = ['memberOf']
self.ldap_is_active_directory = current_app.config.get(
"LDAP_IS_ACTIVE_DIRECTORY", False
)
self.ldap_attrs = ["memberOf"]
self.ldap_client = None
self.ldap_groups = None
@ -60,8 +66,8 @@ class LdapPrincipal():
get_psuedo_random_string(),
self.ldap_principal,
True,
'', # thumbnailPhotoUrl
list(roles)
"", # thumbnailPhotoUrl
list(roles),
)
else:
# we add 'lemur' specific roles, so they do not get marked as removed
@ -76,7 +82,7 @@ class LdapPrincipal():
self.ldap_principal,
user.active,
user.profile_picture,
list(roles)
list(roles),
)
return user
@ -105,9 +111,12 @@ class LdapPrincipal():
# update their 'roles'
role = role_service.get_by_name(self.ldap_principal)
if not role:
description = "auto generated role based on owner: {0}".format(self.ldap_principal)
role = role_service.create(self.ldap_principal, description=description,
third_party=True)
description = "auto generated role based on owner: {0}".format(
self.ldap_principal
)
role = role_service.create(
self.ldap_principal, description=description, third_party=True
)
if not role.third_party:
role = role_service.set_third_party(role.id, third_party_status=True)
roles.add(role)
@ -118,9 +127,15 @@ class LdapPrincipal():
role = role_service.get_by_name(role_name)
if role:
if ldap_group_name in self.ldap_groups:
current_app.logger.debug("assigning role {0} to ldap user {1}".format(self.ldap_principal, role))
current_app.logger.debug(
"assigning role {0} to ldap user {1}".format(
self.ldap_principal, role
)
)
if not role.third_party:
role = role_service.set_third_party(role.id, third_party_status=True)
role = role_service.set_third_party(
role.id, third_party_status=True
)
roles.add(role)
return roles
@ -132,7 +147,7 @@ class LdapPrincipal():
self._bind()
roles = self._authorize()
if not roles:
raise Exception('ldap authorization failed')
raise Exception("ldap authorization failed")
return self._update_user(roles)
def _bind(self):
@ -141,9 +156,12 @@ class LdapPrincipal():
list groups for a user.
raise an exception on error.
"""
if '@' not in self.ldap_principal:
self.ldap_principal = '%s@%s' % (self.ldap_principal, self.ldap_email_domain)
ldap_filter = 'userPrincipalName=%s' % self.ldap_principal
if "@" not in self.ldap_principal:
self.ldap_principal = "%s@%s" % (
self.ldap_principal,
self.ldap_email_domain,
)
ldap_filter = "userPrincipalName=%s" % self.ldap_principal
# query ldap for auth
try:
@ -159,37 +177,47 @@ class LdapPrincipal():
self.ldap_client.set_option(ldap.OPT_X_TLS_DEMAND, True)
self.ldap_client.set_option(ldap.OPT_DEBUG_LEVEL, 255)
if self.ldap_cacert_file:
self.ldap_client.set_option(ldap.OPT_X_TLS_CACERTFILE, self.ldap_cacert_file)
self.ldap_client.set_option(
ldap.OPT_X_TLS_CACERTFILE, self.ldap_cacert_file
)
self.ldap_client.simple_bind_s(self.ldap_principal, self.ldap_password)
except ldap.INVALID_CREDENTIALS:
self.ldap_client.unbind()
raise Exception('The supplied ldap credentials are invalid')
raise Exception("The supplied ldap credentials are invalid")
except ldap.SERVER_DOWN:
raise Exception('ldap server unavailable')
raise Exception("ldap server unavailable")
except ldap.LDAPError as e:
raise Exception("ldap error: {0}".format(e))
if self.ldap_is_active_directory:
# Lookup user DN, needed to search for group membership
userdn = self.ldap_client.search_s(self.ldap_base_dn,
ldap.SCOPE_SUBTREE, ldap_filter,
['distinguishedName'])[0][1]['distinguishedName'][0]
userdn = userdn.decode('utf-8')
userdn = self.ldap_client.search_s(
self.ldap_base_dn,
ldap.SCOPE_SUBTREE,
ldap_filter,
["distinguishedName"],
)[0][1]["distinguishedName"][0]
userdn = userdn.decode("utf-8")
# Search all groups that have the userDN as a member
groupfilter = '(&(objectclass=group)(member:1.2.840.113556.1.4.1941:={0}))'.format(userdn)
lgroups = self.ldap_client.search_s(self.ldap_base_dn, ldap.SCOPE_SUBTREE, groupfilter, ['cn'])
groupfilter = "(&(objectclass=group)(member:1.2.840.113556.1.4.1941:={0}))".format(
userdn
)
lgroups = self.ldap_client.search_s(
self.ldap_base_dn, ldap.SCOPE_SUBTREE, groupfilter, ["cn"]
)
# Create a list of group CN's from the result
self.ldap_groups = []
for group in lgroups:
(dn, values) = group
self.ldap_groups.append(values['cn'][0].decode('ascii'))
self.ldap_groups.append(values["cn"][0].decode("ascii"))
else:
lgroups = self.ldap_client.search_s(self.ldap_base_dn,
ldap.SCOPE_SUBTREE, ldap_filter, self.ldap_attrs)[0][1]['memberOf']
lgroups = self.ldap_client.search_s(
self.ldap_base_dn, ldap.SCOPE_SUBTREE, ldap_filter, self.ldap_attrs
)[0][1]["memberOf"]
# lgroups is a list of utf-8 encoded strings
# convert to a single string of groups to allow matching
self.ldap_groups = b''.join(lgroups).decode('ascii')
self.ldap_groups = b"".join(lgroups).decode("ascii")
self.ldap_client.unbind()
@ -197,9 +225,5 @@ class LdapPrincipal():
"""
Confirms required ldap config settings exist.
"""
required_vars = [
'LDAP_BIND_URI',
'LDAP_BASE_DN',
'LDAP_EMAIL_DOMAIN',
]
required_vars = ["LDAP_BIND_URI", "LDAP_BASE_DN", "LDAP_EMAIL_DOMAIN"]
validate_conf(current_app, required_vars)

View File

@ -12,21 +12,21 @@ from collections import namedtuple
from flask_principal import Permission, RoleNeed
# Permissions
operator_permission = Permission(RoleNeed('operator'))
admin_permission = Permission(RoleNeed('admin'))
operator_permission = Permission(RoleNeed("operator"))
admin_permission = Permission(RoleNeed("admin"))
CertificateOwner = namedtuple('certificate', ['method', 'value'])
CertificateOwnerNeed = partial(CertificateOwner, 'role')
CertificateOwner = namedtuple("certificate", ["method", "value"])
CertificateOwnerNeed = partial(CertificateOwner, "role")
class SensitiveDomainPermission(Permission):
def __init__(self):
super(SensitiveDomainPermission, self).__init__(RoleNeed('admin'))
super(SensitiveDomainPermission, self).__init__(RoleNeed("admin"))
class CertificatePermission(Permission):
def __init__(self, owner, roles):
needs = [RoleNeed('admin'), RoleNeed(owner), RoleNeed('creator')]
needs = [RoleNeed("admin"), RoleNeed(owner), RoleNeed("creator")]
for r in roles:
needs.append(CertificateOwnerNeed(str(r)))
# Backwards compatibility with mixed-case role names
@ -38,29 +38,29 @@ class CertificatePermission(Permission):
class ApiKeyCreatorPermission(Permission):
def __init__(self):
super(ApiKeyCreatorPermission, self).__init__(RoleNeed('admin'))
super(ApiKeyCreatorPermission, self).__init__(RoleNeed("admin"))
RoleMember = namedtuple('role', ['method', 'value'])
RoleMemberNeed = partial(RoleMember, 'member')
RoleMember = namedtuple("role", ["method", "value"])
RoleMemberNeed = partial(RoleMember, "member")
class RoleMemberPermission(Permission):
def __init__(self, role_id):
needs = [RoleNeed('admin'), RoleMemberNeed(role_id)]
needs = [RoleNeed("admin"), RoleMemberNeed(role_id)]
super(RoleMemberPermission, self).__init__(*needs)
AuthorityCreator = namedtuple('authority', ['method', 'value'])
AuthorityCreatorNeed = partial(AuthorityCreator, 'authorityUse')
AuthorityCreator = namedtuple("authority", ["method", "value"])
AuthorityCreatorNeed = partial(AuthorityCreator, "authorityUse")
AuthorityOwner = namedtuple('authority', ['method', 'value'])
AuthorityOwnerNeed = partial(AuthorityOwner, 'role')
AuthorityOwner = namedtuple("authority", ["method", "value"])
AuthorityOwnerNeed = partial(AuthorityOwner, "role")
class AuthorityPermission(Permission):
def __init__(self, authority_id, roles):
needs = [RoleNeed('admin'), AuthorityCreatorNeed(str(authority_id))]
needs = [RoleNeed("admin"), AuthorityCreatorNeed(str(authority_id))]
for r in roles:
needs.append(AuthorityOwnerNeed(str(r)))

View File

@ -39,13 +39,13 @@ def get_rsa_public_key(n, e):
:param e:
:return: a RSA Public Key in PEM format
"""
n = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(n, 'utf-8'))), 16)
e = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(e, 'utf-8'))), 16)
n = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(n, "utf-8"))), 16)
e = int(binascii.hexlify(jwt.utils.base64url_decode(bytes(e, "utf-8"))), 16)
pub = RSAPublicNumbers(e, n).public_key(default_backend())
return pub.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
@ -57,28 +57,27 @@ def create_token(user, aid=None, ttl=None):
:param user:
:return:
"""
expiration_delta = timedelta(days=int(current_app.config.get('LEMUR_TOKEN_EXPIRATION', 1)))
payload = {
'iat': datetime.utcnow(),
'exp': datetime.utcnow() + expiration_delta
}
expiration_delta = timedelta(
days=int(current_app.config.get("LEMUR_TOKEN_EXPIRATION", 1))
)
payload = {"iat": datetime.utcnow(), "exp": datetime.utcnow() + expiration_delta}
# Handle Just a User ID & User Object.
if isinstance(user, int):
payload['sub'] = user
payload["sub"] = user
else:
payload['sub'] = user.id
payload["sub"] = user.id
if aid is not None:
payload['aid'] = aid
payload["aid"] = aid
# Custom TTLs are only supported on Access Keys.
if ttl is not None and aid is not None:
# Tokens that are forever until revoked.
if ttl == -1:
del payload['exp']
del payload["exp"]
else:
payload['exp'] = ttl
token = jwt.encode(payload, current_app.config['LEMUR_TOKEN_SECRET'])
return token.decode('unicode_escape')
payload["exp"] = ttl
token = jwt.encode(payload, current_app.config["LEMUR_TOKEN_SECRET"])
return token.decode("unicode_escape")
def login_required(f):
@ -88,49 +87,54 @@ def login_required(f):
:param f:
:return:
"""
@wraps(f)
def decorated_function(*args, **kwargs):
if not request.headers.get('Authorization'):
response = jsonify(message='Missing authorization header')
if not request.headers.get("Authorization"):
response = jsonify(message="Missing authorization header")
response.status_code = 401
return response
try:
token = request.headers.get('Authorization').split()[1]
token = request.headers.get("Authorization").split()[1]
except Exception as e:
return dict(message='Token is invalid'), 403
return dict(message="Token is invalid"), 403
try:
payload = jwt.decode(token, current_app.config['LEMUR_TOKEN_SECRET'])
payload = jwt.decode(token, current_app.config["LEMUR_TOKEN_SECRET"])
except jwt.DecodeError:
return dict(message='Token is invalid'), 403
return dict(message="Token is invalid"), 403
except jwt.ExpiredSignatureError:
return dict(message='Token has expired'), 403
return dict(message="Token has expired"), 403
except jwt.InvalidTokenError:
return dict(message='Token is invalid'), 403
return dict(message="Token is invalid"), 403
if 'aid' in payload:
access_key = api_key_service.get(payload['aid'])
if "aid" in payload:
access_key = api_key_service.get(payload["aid"])
if access_key.revoked:
return dict(message='Token has been revoked'), 403
return dict(message="Token has been revoked"), 403
if access_key.ttl != -1:
current_time = datetime.utcnow()
expired_time = datetime.fromtimestamp(access_key.issued_at + access_key.ttl)
expired_time = datetime.fromtimestamp(
access_key.issued_at + access_key.ttl
)
if current_time >= expired_time:
return dict(message='Token has expired'), 403
return dict(message="Token has expired"), 403
user = user_service.get(payload['sub'])
user = user_service.get(payload["sub"])
if not user.active:
return dict(message='User is not currently active'), 403
return dict(message="User is not currently active"), 403
g.current_user = user
if not g.current_user:
return dict(message='You are not logged in'), 403
return dict(message="You are not logged in"), 403
# Tell Flask-Principal the identity changed
identity_changed.send(current_app._get_current_object(), identity=Identity(g.current_user.id))
identity_changed.send(
current_app._get_current_object(), identity=Identity(g.current_user.id)
)
return f(*args, **kwargs)
@ -144,18 +148,18 @@ def fetch_token_header(token):
:param token:
:return: :raise jwt.DecodeError:
"""
token = token.encode('utf-8')
token = token.encode("utf-8")
try:
signing_input, crypto_segment = token.rsplit(b'.', 1)
header_segment, payload_segment = signing_input.split(b'.', 1)
signing_input, crypto_segment = token.rsplit(b".", 1)
header_segment, payload_segment = signing_input.split(b".", 1)
except ValueError:
raise jwt.DecodeError('Not enough segments')
raise jwt.DecodeError("Not enough segments")
try:
return json.loads(jwt.utils.base64url_decode(header_segment).decode('utf-8'))
return json.loads(jwt.utils.base64url_decode(header_segment).decode("utf-8"))
except TypeError as e:
current_app.logger.exception(e)
raise jwt.DecodeError('Invalid header padding')
raise jwt.DecodeError("Invalid header padding")
@identity_loaded.connect
@ -174,13 +178,13 @@ def on_identity_loaded(sender, identity):
identity.provides.add(UserNeed(identity.id))
# identity with the roles that the user provides
if hasattr(user, 'roles'):
if hasattr(user, "roles"):
for role in user.roles:
identity.provides.add(RoleNeed(role.name))
identity.provides.add(RoleMemberNeed(role.id))
# apply ownership for authorities
if hasattr(user, 'authorities'):
if hasattr(user, "authorities"):
for authority in user.authorities:
identity.provides.add(AuthorityCreatorNeed(authority.id))
@ -191,6 +195,7 @@ class AuthenticatedResource(Resource):
"""
Inherited by all resources that need to be protected by authentication.
"""
method_decorators = [login_required]
def __init__(self):

View File

@ -24,11 +24,13 @@ from lemur.auth.service import create_token, fetch_token_header, get_rsa_public_
from lemur.auth import ldap
mod = Blueprint('auth', __name__)
mod = Blueprint("auth", __name__)
api = Api(mod)
def exchange_for_access_token(code, redirect_uri, client_id, secret, access_token_url=None, verify_cert=True):
def exchange_for_access_token(
code, redirect_uri, client_id, secret, access_token_url=None, verify_cert=True
):
"""
Exchanges authorization code for access token.
@ -43,28 +45,32 @@ def exchange_for_access_token(code, redirect_uri, client_id, secret, access_toke
"""
# take the information we have received from the provider to create a new request
params = {
'grant_type': 'authorization_code',
'scope': 'openid email profile address',
'code': code,
'redirect_uri': redirect_uri,
'client_id': client_id
"grant_type": "authorization_code",
"scope": "openid email profile address",
"code": code,
"redirect_uri": redirect_uri,
"client_id": client_id,
}
# the secret and cliendId will be given to you when you signup for the provider
token = '{0}:{1}'.format(client_id, secret)
token = "{0}:{1}".format(client_id, secret)
basic = base64.b64encode(bytes(token, 'utf-8'))
basic = base64.b64encode(bytes(token, "utf-8"))
headers = {
'Content-Type': 'application/x-www-form-urlencoded',
'authorization': 'basic {0}'.format(basic.decode('utf-8'))
"Content-Type": "application/x-www-form-urlencoded",
"authorization": "basic {0}".format(basic.decode("utf-8")),
}
# exchange authorization code for access token.
r = requests.post(access_token_url, headers=headers, params=params, verify=verify_cert)
r = requests.post(
access_token_url, headers=headers, params=params, verify=verify_cert
)
if r.status_code == 400:
r = requests.post(access_token_url, headers=headers, data=params, verify=verify_cert)
id_token = r.json()['id_token']
access_token = r.json()['access_token']
r = requests.post(
access_token_url, headers=headers, data=params, verify=verify_cert
)
id_token = r.json()["id_token"]
access_token = r.json()["access_token"]
return id_token, access_token
@ -83,23 +89,25 @@ def validate_id_token(id_token, client_id, jwks_url):
# retrieve the key material as specified by the token header
r = requests.get(jwks_url)
for key in r.json()['keys']:
if key['kid'] == header_data['kid']:
secret = get_rsa_public_key(key['n'], key['e'])
algo = header_data['alg']
for key in r.json()["keys"]:
if key["kid"] == header_data["kid"]:
secret = get_rsa_public_key(key["n"], key["e"])
algo = header_data["alg"]
break
else:
return dict(message='Key not found'), 401
return dict(message="Key not found"), 401
# validate your token based on the key it was signed with
try:
jwt.decode(id_token, secret.decode('utf-8'), algorithms=[algo], audience=client_id)
jwt.decode(
id_token, secret.decode("utf-8"), algorithms=[algo], audience=client_id
)
except jwt.DecodeError:
return dict(message='Token is invalid'), 401
return dict(message="Token is invalid"), 401
except jwt.ExpiredSignatureError:
return dict(message='Token has expired'), 401
return dict(message="Token has expired"), 401
except jwt.InvalidTokenError:
return dict(message='Token is invalid'), 401
return dict(message="Token is invalid"), 401
def retrieve_user(user_api_url, access_token):
@ -110,22 +118,18 @@ def retrieve_user(user_api_url, access_token):
:param access_token:
:return:
"""
user_params = dict(access_token=access_token, schema='profile')
user_params = dict(access_token=access_token, schema="profile")
headers = {}
if current_app.config.get('PING_INCLUDE_BEARER_TOKEN'):
headers = {'Authorization': f'Bearer {access_token}'}
if current_app.config.get("PING_INCLUDE_BEARER_TOKEN"):
headers = {"Authorization": f"Bearer {access_token}"}
# retrieve information about the current user.
r = requests.get(
user_api_url,
params=user_params,
headers=headers,
)
r = requests.get(user_api_url, params=user_params, headers=headers)
profile = r.json()
user = user_service.get_by_email(profile['email'])
user = user_service.get_by_email(profile["email"])
return user, profile
@ -138,31 +142,44 @@ def create_user_roles(profile):
roles = []
# update their google 'roles'
if 'googleGroups' in profile:
for group in profile['googleGroups']:
if "googleGroups" in profile:
for group in profile["googleGroups"]:
role = role_service.get_by_name(group)
if not role:
role = role_service.create(group, description='This is a google group based role created by Lemur', third_party=True)
role = role_service.create(
group,
description="This is a google group based role created by Lemur",
third_party=True,
)
if not role.third_party:
role = role_service.set_third_party(role.id, third_party_status=True)
roles.append(role)
else:
current_app.logger.warning("'googleGroups' not sent by identity provider, no specific roles will assigned to the user.")
current_app.logger.warning(
"'googleGroups' not sent by identity provider, no specific roles will assigned to the user."
)
role = role_service.get_by_name(profile['email'])
role = role_service.get_by_name(profile["email"])
if not role:
role = role_service.create(profile['email'], description='This is a user specific role', third_party=True)
role = role_service.create(
profile["email"],
description="This is a user specific role",
third_party=True,
)
if not role.third_party:
role = role_service.set_third_party(role.id, third_party_status=True)
roles.append(role)
# every user is an operator (tied to a default role)
if current_app.config.get('LEMUR_DEFAULT_ROLE'):
default = role_service.get_by_name(current_app.config['LEMUR_DEFAULT_ROLE'])
if current_app.config.get("LEMUR_DEFAULT_ROLE"):
default = role_service.get_by_name(current_app.config["LEMUR_DEFAULT_ROLE"])
if not default:
default = role_service.create(current_app.config['LEMUR_DEFAULT_ROLE'], description='This is the default Lemur role.')
default = role_service.create(
current_app.config["LEMUR_DEFAULT_ROLE"],
description="This is the default Lemur role.",
)
if not default.third_party:
role_service.set_third_party(default.id, third_party_status=True)
roles.append(default)
@ -181,12 +198,12 @@ def update_user(user, profile, roles):
# if we get an sso user create them an account
if not user:
user = user_service.create(
profile['email'],
profile["email"],
get_psuedo_random_string(),
profile['email'],
profile["email"],
True,
profile.get('thumbnailPhotoUrl'),
roles
profile.get("thumbnailPhotoUrl"),
roles,
)
else:
@ -198,11 +215,11 @@ def update_user(user, profile, roles):
# update any changes to the user
user_service.update(
user.id,
profile['email'],
profile['email'],
profile["email"],
profile["email"],
True,
profile.get('thumbnailPhotoUrl'), # profile isn't google+ enabled
roles
profile.get("thumbnailPhotoUrl"), # profile isn't google+ enabled
roles,
)
@ -223,6 +240,7 @@ class Login(Resource):
on your uses cases but. It is important to not that there is currently no build in method to revoke a users token \
and force re-authentication.
"""
def __init__(self):
self.reqparse = reqparse.RequestParser()
super(Login, self).__init__()
@ -263,23 +281,26 @@ class Login(Resource):
:statuscode 401: invalid credentials
:statuscode 200: no error
"""
self.reqparse.add_argument('username', type=str, required=True, location='json')
self.reqparse.add_argument('password', type=str, required=True, location='json')
self.reqparse.add_argument("username", type=str, required=True, location="json")
self.reqparse.add_argument("password", type=str, required=True, location="json")
args = self.reqparse.parse_args()
if '@' in args['username']:
user = user_service.get_by_email(args['username'])
if "@" in args["username"]:
user = user_service.get_by_email(args["username"])
else:
user = user_service.get_by_username(args['username'])
user = user_service.get_by_username(args["username"])
# default to local authentication
if user and user.check_password(args['password']) and user.active:
if user and user.check_password(args["password"]) and user.active:
# Tell Flask-Principal the identity changed
identity_changed.send(current_app._get_current_object(),
identity=Identity(user.id))
identity_changed.send(
current_app._get_current_object(), identity=Identity(user.id)
)
metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS})
metrics.send(
"login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS}
)
return dict(token=create_token(user))
# try ldap login
@ -289,19 +310,29 @@ class Login(Resource):
user = ldap_principal.authenticate()
if user and user.active:
# Tell Flask-Principal the identity changed
identity_changed.send(current_app._get_current_object(),
identity=Identity(user.id))
metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS})
identity_changed.send(
current_app._get_current_object(), identity=Identity(user.id)
)
metrics.send(
"login",
"counter",
1,
metric_tags={"status": SUCCESS_METRIC_STATUS},
)
return dict(token=create_token(user))
except Exception as e:
current_app.logger.error("ldap error: {0}".format(e))
ldap_message = 'ldap error: %s' % e
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS})
ldap_message = "ldap error: %s" % e
metrics.send(
"login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
return dict(message=ldap_message), 403
# if not valid user - no certificates for you
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS})
return dict(message='The supplied credentials are invalid'), 403
metrics.send(
"login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
return dict(message="The supplied credentials are invalid"), 403
class Ping(Resource):
@ -314,36 +345,39 @@ class Ping(Resource):
provider uses for its callbacks.
2. Add or change the Lemur AngularJS Configuration to point to your new provider
"""
def __init__(self):
self.reqparse = reqparse.RequestParser()
super(Ping, self).__init__()
def get(self):
return 'Redirecting...'
return "Redirecting..."
def post(self):
self.reqparse.add_argument('clientId', type=str, required=True, location='json')
self.reqparse.add_argument('redirectUri', type=str, required=True, location='json')
self.reqparse.add_argument('code', type=str, required=True, location='json')
self.reqparse.add_argument("clientId", type=str, required=True, location="json")
self.reqparse.add_argument(
"redirectUri", type=str, required=True, location="json"
)
self.reqparse.add_argument("code", type=str, required=True, location="json")
args = self.reqparse.parse_args()
# you can either discover these dynamically or simply configure them
access_token_url = current_app.config.get('PING_ACCESS_TOKEN_URL')
user_api_url = current_app.config.get('PING_USER_API_URL')
access_token_url = current_app.config.get("PING_ACCESS_TOKEN_URL")
user_api_url = current_app.config.get("PING_USER_API_URL")
secret = current_app.config.get('PING_SECRET')
secret = current_app.config.get("PING_SECRET")
id_token, access_token = exchange_for_access_token(
args['code'],
args['redirectUri'],
args['clientId'],
args["code"],
args["redirectUri"],
args["clientId"],
secret,
access_token_url=access_token_url
access_token_url=access_token_url,
)
jwks_url = current_app.config.get('PING_JWKS_URL')
error_code = validate_id_token(id_token, args['clientId'], jwks_url)
jwks_url = current_app.config.get("PING_JWKS_URL")
error_code = validate_id_token(id_token, args["clientId"], jwks_url)
if error_code:
return error_code
user, profile = retrieve_user(user_api_url, access_token)
@ -351,13 +385,19 @@ class Ping(Resource):
update_user(user, profile, roles)
if not user or not user.active:
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS})
return dict(message='The supplied credentials are invalid'), 403
metrics.send(
"login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
return dict(message="The supplied credentials are invalid"), 403
# Tell Flask-Principal the identity changed
identity_changed.send(current_app._get_current_object(), identity=Identity(user.id))
identity_changed.send(
current_app._get_current_object(), identity=Identity(user.id)
)
metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS})
metrics.send(
"login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS}
)
return dict(token=create_token(user))
@ -367,33 +407,35 @@ class OAuth2(Resource):
super(OAuth2, self).__init__()
def get(self):
return 'Redirecting...'
return "Redirecting..."
def post(self):
self.reqparse.add_argument('clientId', type=str, required=True, location='json')
self.reqparse.add_argument('redirectUri', type=str, required=True, location='json')
self.reqparse.add_argument('code', type=str, required=True, location='json')
self.reqparse.add_argument("clientId", type=str, required=True, location="json")
self.reqparse.add_argument(
"redirectUri", type=str, required=True, location="json"
)
self.reqparse.add_argument("code", type=str, required=True, location="json")
args = self.reqparse.parse_args()
# you can either discover these dynamically or simply configure them
access_token_url = current_app.config.get('OAUTH2_ACCESS_TOKEN_URL')
user_api_url = current_app.config.get('OAUTH2_USER_API_URL')
verify_cert = current_app.config.get('OAUTH2_VERIFY_CERT')
access_token_url = current_app.config.get("OAUTH2_ACCESS_TOKEN_URL")
user_api_url = current_app.config.get("OAUTH2_USER_API_URL")
verify_cert = current_app.config.get("OAUTH2_VERIFY_CERT")
secret = current_app.config.get('OAUTH2_SECRET')
secret = current_app.config.get("OAUTH2_SECRET")
id_token, access_token = exchange_for_access_token(
args['code'],
args['redirectUri'],
args['clientId'],
args["code"],
args["redirectUri"],
args["clientId"],
secret,
access_token_url=access_token_url,
verify_cert=verify_cert
verify_cert=verify_cert,
)
jwks_url = current_app.config.get('PING_JWKS_URL')
error_code = validate_id_token(id_token, args['clientId'], jwks_url)
jwks_url = current_app.config.get("PING_JWKS_URL")
error_code = validate_id_token(id_token, args["clientId"], jwks_url)
if error_code:
return error_code
@ -402,13 +444,19 @@ class OAuth2(Resource):
update_user(user, profile, roles)
if not user.active:
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS})
return dict(message='The supplied credentials are invalid'), 403
metrics.send(
"login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
return dict(message="The supplied credentials are invalid"), 403
# Tell Flask-Principal the identity changed
identity_changed.send(current_app._get_current_object(), identity=Identity(user.id))
identity_changed.send(
current_app._get_current_object(), identity=Identity(user.id)
)
metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS})
metrics.send(
"login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS}
)
return dict(token=create_token(user))
@ -419,44 +467,52 @@ class Google(Resource):
super(Google, self).__init__()
def post(self):
access_token_url = 'https://accounts.google.com/o/oauth2/token'
people_api_url = 'https://www.googleapis.com/plus/v1/people/me/openIdConnect'
access_token_url = "https://accounts.google.com/o/oauth2/token"
people_api_url = "https://www.googleapis.com/plus/v1/people/me/openIdConnect"
self.reqparse.add_argument('clientId', type=str, required=True, location='json')
self.reqparse.add_argument('redirectUri', type=str, required=True, location='json')
self.reqparse.add_argument('code', type=str, required=True, location='json')
self.reqparse.add_argument("clientId", type=str, required=True, location="json")
self.reqparse.add_argument(
"redirectUri", type=str, required=True, location="json"
)
self.reqparse.add_argument("code", type=str, required=True, location="json")
args = self.reqparse.parse_args()
# Step 1. Exchange authorization code for access token
payload = {
'client_id': args['clientId'],
'grant_type': 'authorization_code',
'redirect_uri': args['redirectUri'],
'code': args['code'],
'client_secret': current_app.config.get('GOOGLE_SECRET')
"client_id": args["clientId"],
"grant_type": "authorization_code",
"redirect_uri": args["redirectUri"],
"code": args["code"],
"client_secret": current_app.config.get("GOOGLE_SECRET"),
}
r = requests.post(access_token_url, data=payload)
token = r.json()
# Step 2. Retrieve information about the current user
headers = {'Authorization': 'Bearer {0}'.format(token['access_token'])}
headers = {"Authorization": "Bearer {0}".format(token["access_token"])}
r = requests.get(people_api_url, headers=headers)
profile = r.json()
user = user_service.get_by_email(profile['email'])
user = user_service.get_by_email(profile["email"])
if not (user and user.active):
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS})
return dict(message='The supplied credentials are invalid.'), 403
metrics.send(
"login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
return dict(message="The supplied credentials are invalid."), 403
if user:
metrics.send('login', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS})
metrics.send(
"login", "counter", 1, metric_tags={"status": SUCCESS_METRIC_STATUS}
)
return dict(token=create_token(user))
metrics.send('login', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS})
metrics.send(
"login", "counter", 1, metric_tags={"status": FAILURE_METRIC_STATUS}
)
class Providers(Resource):
@ -467,47 +523,57 @@ class Providers(Resource):
provider = provider.lower()
if provider == "google":
active_providers.append({
'name': 'google',
'clientId': current_app.config.get("GOOGLE_CLIENT_ID"),
'url': api.url_for(Google)
})
active_providers.append(
{
"name": "google",
"clientId": current_app.config.get("GOOGLE_CLIENT_ID"),
"url": api.url_for(Google),
}
)
elif provider == "ping":
active_providers.append({
'name': current_app.config.get("PING_NAME"),
'url': current_app.config.get('PING_REDIRECT_URI'),
'redirectUri': current_app.config.get("PING_REDIRECT_URI"),
'clientId': current_app.config.get("PING_CLIENT_ID"),
'responseType': 'code',
'scope': ['openid', 'email', 'profile', 'address'],
'scopeDelimiter': ' ',
'authorizationEndpoint': current_app.config.get("PING_AUTH_ENDPOINT"),
'requiredUrlParams': ['scope'],
'type': '2.0'
})
active_providers.append(
{
"name": current_app.config.get("PING_NAME"),
"url": current_app.config.get("PING_REDIRECT_URI"),
"redirectUri": current_app.config.get("PING_REDIRECT_URI"),
"clientId": current_app.config.get("PING_CLIENT_ID"),
"responseType": "code",
"scope": ["openid", "email", "profile", "address"],
"scopeDelimiter": " ",
"authorizationEndpoint": current_app.config.get(
"PING_AUTH_ENDPOINT"
),
"requiredUrlParams": ["scope"],
"type": "2.0",
}
)
elif provider == "oauth2":
active_providers.append({
'name': current_app.config.get("OAUTH2_NAME"),
'url': current_app.config.get('OAUTH2_REDIRECT_URI'),
'redirectUri': current_app.config.get("OAUTH2_REDIRECT_URI"),
'clientId': current_app.config.get("OAUTH2_CLIENT_ID"),
'responseType': 'code',
'scope': ['openid', 'email', 'profile', 'groups'],
'scopeDelimiter': ' ',
'authorizationEndpoint': current_app.config.get("OAUTH2_AUTH_ENDPOINT"),
'requiredUrlParams': ['scope', 'state', 'nonce'],
'state': 'STATE',
'nonce': get_psuedo_random_string(),
'type': '2.0'
})
active_providers.append(
{
"name": current_app.config.get("OAUTH2_NAME"),
"url": current_app.config.get("OAUTH2_REDIRECT_URI"),
"redirectUri": current_app.config.get("OAUTH2_REDIRECT_URI"),
"clientId": current_app.config.get("OAUTH2_CLIENT_ID"),
"responseType": "code",
"scope": ["openid", "email", "profile", "groups"],
"scopeDelimiter": " ",
"authorizationEndpoint": current_app.config.get(
"OAUTH2_AUTH_ENDPOINT"
),
"requiredUrlParams": ["scope", "state", "nonce"],
"state": "STATE",
"nonce": get_psuedo_random_string(),
"type": "2.0",
}
)
return active_providers
api.add_resource(Login, '/auth/login', endpoint='login')
api.add_resource(Ping, '/auth/ping', endpoint='ping')
api.add_resource(Google, '/auth/google', endpoint='google')
api.add_resource(OAuth2, '/auth/oauth2', endpoint='oauth2')
api.add_resource(Providers, '/auth/providers', endpoint='providers')
api.add_resource(Login, "/auth/login", endpoint="login")
api.add_resource(Ping, "/auth/ping", endpoint="ping")
api.add_resource(Google, "/auth/google", endpoint="google")
api.add_resource(OAuth2, "/auth/oauth2", endpoint="oauth2")
api.add_resource(Providers, "/auth/providers", endpoint="providers")

View File

@ -7,7 +7,17 @@
.. moduleauthor:: Kevin Glisson <kglisson@netflix.com>
"""
from sqlalchemy.orm import relationship
from sqlalchemy import Column, Integer, String, Text, func, ForeignKey, DateTime, PassiveDefault, Boolean
from sqlalchemy import (
Column,
Integer,
String,
Text,
func,
ForeignKey,
DateTime,
PassiveDefault,
Boolean,
)
from sqlalchemy.dialects.postgresql import JSON
from lemur.database import db
@ -16,7 +26,7 @@ from lemur.models import roles_authorities
class Authority(db.Model):
__tablename__ = 'authorities'
__tablename__ = "authorities"
id = Column(Integer, primary_key=True)
owner = Column(String(128), nullable=False)
name = Column(String(128), unique=True)
@ -27,22 +37,44 @@ class Authority(db.Model):
description = Column(Text)
options = Column(JSON)
date_created = Column(DateTime, PassiveDefault(func.now()), nullable=False)
roles = relationship('Role', secondary=roles_authorities, passive_deletes=True, backref=db.backref('authority'), lazy='dynamic')
user_id = Column(Integer, ForeignKey('users.id'))
authority_certificate = relationship("Certificate", backref='root_authority', uselist=False, foreign_keys='Certificate.root_authority_id')
certificates = relationship("Certificate", backref='authority', foreign_keys='Certificate.authority_id')
roles = relationship(
"Role",
secondary=roles_authorities,
passive_deletes=True,
backref=db.backref("authority"),
lazy="dynamic",
)
user_id = Column(Integer, ForeignKey("users.id"))
authority_certificate = relationship(
"Certificate",
backref="root_authority",
uselist=False,
foreign_keys="Certificate.root_authority_id",
)
certificates = relationship(
"Certificate", backref="authority", foreign_keys="Certificate.authority_id"
)
authority_pending_certificate = relationship("PendingCertificate", backref='root_authority', uselist=False, foreign_keys='PendingCertificate.root_authority_id')
pending_certificates = relationship('PendingCertificate', backref='authority', foreign_keys='PendingCertificate.authority_id')
authority_pending_certificate = relationship(
"PendingCertificate",
backref="root_authority",
uselist=False,
foreign_keys="PendingCertificate.root_authority_id",
)
pending_certificates = relationship(
"PendingCertificate",
backref="authority",
foreign_keys="PendingCertificate.authority_id",
)
def __init__(self, **kwargs):
self.owner = kwargs['owner']
self.roles = kwargs.get('roles', [])
self.name = kwargs.get('name')
self.description = kwargs.get('description')
self.authority_certificate = kwargs['authority_certificate']
self.plugin_name = kwargs['plugin']['slug']
self.options = kwargs.get('options')
self.owner = kwargs["owner"]
self.roles = kwargs.get("roles", [])
self.name = kwargs.get("name")
self.description = kwargs.get("description")
self.authority_certificate = kwargs["authority_certificate"]
self.plugin_name = kwargs["plugin"]["slug"]
self.options = kwargs.get("options")
@property
def plugin(self):

View File

@ -11,7 +11,13 @@ from marshmallow import fields, validates_schema, pre_load
from marshmallow import validate
from marshmallow.exceptions import ValidationError
from lemur.schemas import PluginInputSchema, PluginOutputSchema, ExtensionSchema, AssociatedAuthoritySchema, AssociatedRoleSchema
from lemur.schemas import (
PluginInputSchema,
PluginOutputSchema,
ExtensionSchema,
AssociatedAuthoritySchema,
AssociatedRoleSchema,
)
from lemur.users.schemas import UserNestedOutputSchema
from lemur.common.schema import LemurInputSchema, LemurOutputSchema
from lemur.common import validators, missing
@ -30,21 +36,36 @@ class AuthorityInputSchema(LemurInputSchema):
validity_years = fields.Integer()
# certificate body fields
organizational_unit = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT'))
organization = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATION'))
location = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_LOCATION'))
country = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_COUNTRY'))
state = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_STATE'))
organizational_unit = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATIONAL_UNIT")
)
organization = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATION")
)
location = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_LOCATION")
)
country = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_COUNTRY")
)
state = fields.String(missing=lambda: current_app.config.get("LEMUR_DEFAULT_STATE"))
plugin = fields.Nested(PluginInputSchema)
# signing related options
type = fields.String(validate=validate.OneOf(['root', 'subca']), missing='root')
type = fields.String(validate=validate.OneOf(["root", "subca"]), missing="root")
parent = fields.Nested(AssociatedAuthoritySchema)
signing_algorithm = fields.String(validate=validate.OneOf(['sha256WithRSA', 'sha1WithRSA']), missing='sha256WithRSA')
key_type = fields.String(validate=validate.OneOf(['RSA2048', 'RSA4096']), missing='RSA2048')
signing_algorithm = fields.String(
validate=validate.OneOf(["sha256WithRSA", "sha1WithRSA"]),
missing="sha256WithRSA",
)
key_type = fields.String(
validate=validate.OneOf(["RSA2048", "RSA4096"]), missing="RSA2048"
)
key_name = fields.String()
sensitivity = fields.String(validate=validate.OneOf(['medium', 'high']), missing='medium')
sensitivity = fields.String(
validate=validate.OneOf(["medium", "high"]), missing="medium"
)
serial_number = fields.Integer()
first_serial = fields.Integer(missing=1)
@ -58,9 +79,11 @@ class AuthorityInputSchema(LemurInputSchema):
@validates_schema
def validate_subca(self, data):
if data['type'] == 'subca':
if not data.get('parent'):
raise ValidationError("If generating a subca, parent 'authority' must be specified.")
if data["type"] == "subca":
if not data.get("parent"):
raise ValidationError(
"If generating a subca, parent 'authority' must be specified."
)
@pre_load
def ensure_dates(self, data):

View File

@ -43,7 +43,7 @@ def mint(**kwargs):
"""
Creates the authority based on the plugin provided.
"""
issuer = kwargs['plugin']['plugin_object']
issuer = kwargs["plugin"]["plugin_object"]
values = issuer.create_authority(kwargs)
# support older plugins
@ -53,7 +53,12 @@ def mint(**kwargs):
elif len(values) == 4:
body, private_key, chain, roles = values
roles = create_authority_roles(roles, kwargs['owner'], kwargs['plugin']['plugin_object'].title, kwargs['creator'])
roles = create_authority_roles(
roles,
kwargs["owner"],
kwargs["plugin"]["plugin_object"].title,
kwargs["creator"],
)
return body, private_key, chain, roles
@ -66,16 +71,17 @@ def create_authority_roles(roles, owner, plugin_title, creator):
"""
role_objs = []
for r in roles:
role = role_service.get_by_name(r['name'])
role = role_service.get_by_name(r["name"])
if not role:
role = role_service.create(
r['name'],
password=r['password'],
r["name"],
password=r["password"],
description="Auto generated role for {0}".format(plugin_title),
username=r['username'])
username=r["username"],
)
# the user creating the authority should be able to administer it
if role.username == 'admin':
if role.username == "admin":
creator.roles.append(role)
role_objs.append(role)
@ -84,8 +90,7 @@ def create_authority_roles(roles, owner, plugin_title, creator):
owner_role = role_service.get_by_name(owner)
if not owner_role:
owner_role = role_service.create(
owner,
description="Auto generated role based on owner: {0}".format(owner)
owner, description="Auto generated role based on owner: {0}".format(owner)
)
role_objs.append(owner_role)
@ -98,27 +103,29 @@ def create(**kwargs):
"""
body, private_key, chain, roles = mint(**kwargs)
kwargs['creator'].roles = list(set(list(kwargs['creator'].roles) + roles))
kwargs["creator"].roles = list(set(list(kwargs["creator"].roles) + roles))
kwargs['body'] = body
kwargs['private_key'] = private_key
kwargs['chain'] = chain
kwargs["body"] = body
kwargs["private_key"] = private_key
kwargs["chain"] = chain
if kwargs.get('roles'):
kwargs['roles'] += roles
if kwargs.get("roles"):
kwargs["roles"] += roles
else:
kwargs['roles'] = roles
kwargs["roles"] = roles
cert = upload(**kwargs)
kwargs['authority_certificate'] = cert
if kwargs.get('plugin', {}).get('plugin_options', []):
kwargs['options'] = json.dumps(kwargs['plugin']['plugin_options'])
kwargs["authority_certificate"] = cert
if kwargs.get("plugin", {}).get("plugin_options", []):
kwargs["options"] = json.dumps(kwargs["plugin"]["plugin_options"])
authority = Authority(**kwargs)
authority = database.create(authority)
kwargs['creator'].authorities.append(authority)
kwargs["creator"].authorities.append(authority)
metrics.send('authority_created', 'counter', 1, metric_tags=dict(owner=authority.owner))
metrics.send(
"authority_created", "counter", 1, metric_tags=dict(owner=authority.owner)
)
return authority
@ -150,7 +157,7 @@ def get_by_name(authority_name):
:param authority_name:
:return:
"""
return database.get(Authority, authority_name, field='name')
return database.get(Authority, authority_name, field="name")
def get_authority_role(ca_name, creator=None):
@ -173,29 +180,31 @@ def render(args):
:return:
"""
query = database.session_query(Authority)
filt = args.pop('filter')
filt = args.pop("filter")
if filt:
terms = filt.split(';')
if 'active' in filt:
terms = filt.split(";")
if "active" in filt:
query = query.filter(Authority.active == truthiness(terms[1]))
elif 'cn' in filt:
term = '%{0}%'.format(terms[1])
sub_query = database.session_query(Certificate.root_authority_id) \
.filter(Certificate.cn.ilike(term)) \
elif "cn" in filt:
term = "%{0}%".format(terms[1])
sub_query = (
database.session_query(Certificate.root_authority_id)
.filter(Certificate.cn.ilike(term))
.subquery()
)
query = query.filter(Authority.id.in_(sub_query))
else:
query = database.filter(query, Authority, terms)
# we make sure that a user can only use an authority they either own are a member of - admins can see all
if not args['user'].is_admin:
if not args["user"].is_admin:
authority_ids = []
for authority in args['user'].authorities:
for authority in args["user"].authorities:
authority_ids.append(authority.id)
for role in args['user'].roles:
for role in args["user"].roles:
for authority in role.authorities:
authority_ids.append(authority.id)
query = query.filter(Authority.id.in_(authority_ids))

View File

@ -16,15 +16,21 @@ from lemur.auth.permissions import AuthorityPermission
from lemur.certificates import service as certificate_service
from lemur.authorities import service
from lemur.authorities.schemas import authority_input_schema, authority_output_schema, authorities_output_schema, authority_update_schema
from lemur.authorities.schemas import (
authority_input_schema,
authority_output_schema,
authorities_output_schema,
authority_update_schema,
)
mod = Blueprint('authorities', __name__)
mod = Blueprint("authorities", __name__)
api = Api(mod)
class AuthoritiesList(AuthenticatedResource):
""" Defines the 'authorities' endpoint """
def __init__(self):
self.reqparse = reqparse.RequestParser()
super(AuthoritiesList, self).__init__()
@ -107,7 +113,7 @@ class AuthoritiesList(AuthenticatedResource):
"""
parser = paginated_parser.copy()
args = parser.parse_args()
args['user'] = g.current_user
args["user"] = g.current_user
return service.render(args)
@validate_schema(authority_input_schema, authority_output_schema)
@ -220,7 +226,7 @@ class AuthoritiesList(AuthenticatedResource):
:statuscode 403: unauthenticated
:statuscode 200: no error
"""
data['creator'] = g.current_user
data["creator"] = g.current_user
return service.create(**data)
@ -388,7 +394,7 @@ class Authorities(AuthenticatedResource):
authority = service.get(authority_id)
if not authority:
return dict(message='Not Found'), 404
return dict(message="Not Found"), 404
# all the authority role members should be allowed
roles = [x.name for x in authority.roles]
@ -397,10 +403,10 @@ class Authorities(AuthenticatedResource):
if permission.can():
return service.update(
authority_id,
owner=data['owner'],
description=data['description'],
active=data['active'],
roles=data['roles']
owner=data["owner"],
description=data["description"],
active=data["active"],
roles=data["roles"],
)
return dict(message="You are not authorized to update this authority."), 403
@ -505,10 +511,21 @@ class AuthorityVisualizations(AuthenticatedResource):
]}
"""
authority = service.get(authority_id)
return dict(name=authority.name, children=[{"name": c.name} for c in authority.certificates])
return dict(
name=authority.name,
children=[{"name": c.name} for c in authority.certificates],
)
api.add_resource(AuthoritiesList, '/authorities', endpoint='authorities')
api.add_resource(Authorities, '/authorities/<int:authority_id>', endpoint='authority')
api.add_resource(AuthorityVisualizations, '/authorities/<int:authority_id>/visualize', endpoint='authority_visualizations')
api.add_resource(CertificateAuthority, '/certificates/<int:certificate_id>/authority', endpoint='certificateAuthority')
api.add_resource(AuthoritiesList, "/authorities", endpoint="authorities")
api.add_resource(Authorities, "/authorities/<int:authority_id>", endpoint="authority")
api.add_resource(
AuthorityVisualizations,
"/authorities/<int:authority_id>/visualize",
endpoint="authority_visualizations",
)
api.add_resource(
CertificateAuthority,
"/certificates/<int:certificate_id>/authority",
endpoint="certificateAuthority",
)

View File

@ -13,7 +13,7 @@ from lemur.plugins.base import plugins
class Authorization(db.Model):
__tablename__ = 'pending_dns_authorizations'
__tablename__ = "pending_dns_authorizations"
id = Column(Integer, primary_key=True, autoincrement=True)
account_number = Column(String(128))
domains = Column(JSONType)

View File

@ -34,7 +34,7 @@ from lemur.certificates.service import (
get_all_pending_reissue,
get_by_name,
get_all_certs,
get
get,
)
from lemur.certificates.verify import verify_string
@ -56,11 +56,14 @@ def print_certificate_details(details):
"\t[+] Authority: {authority_name}\n"
"\t[+] Validity Start: {validity_start}\n"
"\t[+] Validity End: {validity_end}\n".format(
common_name=details['commonName'],
sans=",".join(x['value'] for x in details['extensions']['subAltNames']['names']) or None,
authority_name=details['authority']['name'],
validity_start=details['validityStart'],
validity_end=details['validityEnd']
common_name=details["commonName"],
sans=",".join(
x["value"] for x in details["extensions"]["subAltNames"]["names"]
)
or None,
authority_name=details["authority"]["name"],
validity_start=details["validityStart"],
validity_end=details["validityEnd"],
)
)
@ -120,13 +123,11 @@ def request_rotation(endpoint, certificate, message, commit):
except Exception as e:
print(
"[!] Failed to rotate endpoint {0} to certificate {1} reason: {2}".format(
endpoint.name,
certificate.name,
e
endpoint.name, certificate.name, e
)
)
metrics.send('endpoint_rotation', 'counter', 1, metric_tags={'status': status})
metrics.send("endpoint_rotation", "counter", 1, metric_tags={"status": status})
def request_reissue(certificate, commit):
@ -154,17 +155,52 @@ def request_reissue(certificate, commit):
except Exception as e:
sentry.captureException(extra={"certificate_name": str(certificate.name)})
current_app.logger.exception(f"Error reissuing certificate: {certificate.name}", exc_info=True)
current_app.logger.exception(
f"Error reissuing certificate: {certificate.name}", exc_info=True
)
print(f"[!] Failed to reissue certificate: {certificate.name}. Reason: {e}")
metrics.send('certificate_reissue', 'counter', 1, metric_tags={'status': status, 'certificate': certificate.name})
metrics.send(
"certificate_reissue",
"counter",
1,
metric_tags={"status": status, "certificate": certificate.name},
)
@manager.option('-e', '--endpoint', dest='endpoint_name', help='Name of the endpoint you wish to rotate.')
@manager.option('-n', '--new-certificate', dest='new_certificate_name', help='Name of the certificate you wish to rotate to.')
@manager.option('-o', '--old-certificate', dest='old_certificate_name', help='Name of the certificate you wish to rotate.')
@manager.option('-a', '--notify', dest='message', action='store_true', help='Send a rotation notification to the certificates owner.')
@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.')
@manager.option(
"-e",
"--endpoint",
dest="endpoint_name",
help="Name of the endpoint you wish to rotate.",
)
@manager.option(
"-n",
"--new-certificate",
dest="new_certificate_name",
help="Name of the certificate you wish to rotate to.",
)
@manager.option(
"-o",
"--old-certificate",
dest="old_certificate_name",
help="Name of the certificate you wish to rotate.",
)
@manager.option(
"-a",
"--notify",
dest="message",
action="store_true",
help="Send a rotation notification to the certificates owner.",
)
@manager.option(
"-c",
"--commit",
dest="commit",
action="store_true",
default=False,
help="Persist changes.",
)
def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, commit):
"""
Rotates an endpoint and reissues it if it has not already been replaced. If it has
@ -183,7 +219,9 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c
endpoint = validate_endpoint(endpoint_name)
if endpoint and new_cert:
print(f"[+] Rotating endpoint: {endpoint.name} to certificate {new_cert.name}")
print(
f"[+] Rotating endpoint: {endpoint.name} to certificate {new_cert.name}"
)
request_rotation(endpoint, new_cert, message, commit)
elif old_cert and new_cert:
@ -197,16 +235,27 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c
print("[+] Rotating all endpoints that have new certificates available")
for endpoint in endpoint_service.get_all_pending_rotation():
if len(endpoint.certificate.replaced) == 1:
print(f"[+] Rotating {endpoint.name} to {endpoint.certificate.replaced[0].name}")
request_rotation(endpoint, endpoint.certificate.replaced[0], message, commit)
print(
f"[+] Rotating {endpoint.name} to {endpoint.certificate.replaced[0].name}"
)
request_rotation(
endpoint, endpoint.certificate.replaced[0], message, commit
)
else:
metrics.send('endpoint_rotation', 'counter', 1, metric_tags={
'status': FAILURE_METRIC_STATUS,
metrics.send(
"endpoint_rotation",
"counter",
1,
metric_tags={
"status": FAILURE_METRIC_STATUS,
"old_certificate_name": str(old_cert),
"new_certificate_name": str(endpoint.certificate.replaced[0].name),
"new_certificate_name": str(
endpoint.certificate.replaced[0].name
),
"endpoint_name": str(endpoint.name),
"message": str(message),
})
},
)
print(
f"[!] Failed to rotate endpoint {endpoint.name} reason: "
"Multiple replacement certificates found."
@ -222,20 +271,38 @@ def rotate(endpoint_name, new_certificate_name, old_certificate_name, message, c
"new_certificate_name": str(new_certificate_name),
"endpoint_name": str(endpoint_name),
"message": str(message),
})
}
)
metrics.send('endpoint_rotation_job', 'counter', 1, metric_tags={
metrics.send(
"endpoint_rotation_job",
"counter",
1,
metric_tags={
"status": status,
"old_certificate_name": str(old_certificate_name),
"new_certificate_name": str(new_certificate_name),
"endpoint_name": str(endpoint_name),
"message": str(message),
"endpoint": str(globals().get("endpoint"))
})
"endpoint": str(globals().get("endpoint")),
},
)
@manager.option('-o', '--old-certificate', dest='old_certificate_name', help='Name of the certificate you wish to reissue.')
@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.')
@manager.option(
"-o",
"--old-certificate",
dest="old_certificate_name",
help="Name of the certificate you wish to reissue.",
)
@manager.option(
"-c",
"--commit",
dest="commit",
action="store_true",
default=False,
help="Persist changes.",
)
def reissue(old_certificate_name, commit):
"""
Reissues certificate with the same parameters as it was originally issued with.
@ -263,76 +330,94 @@ def reissue(old_certificate_name, commit):
except Exception as e:
sentry.captureException()
current_app.logger.exception("Error reissuing certificate.", exc_info=True)
print(
"[!] Failed to reissue certificates. Reason: {}".format(
e
)
print("[!] Failed to reissue certificates. Reason: {}".format(e))
metrics.send(
"certificate_reissue_job", "counter", 1, metric_tags={"status": status}
)
metrics.send('certificate_reissue_job', 'counter', 1, metric_tags={'status': status})
@manager.option('-f', '--fqdns', dest='fqdns', help='FQDNs to query. Multiple fqdns specified via comma.')
@manager.option('-i', '--issuer', dest='issuer', help='Issuer to query for.')
@manager.option('-o', '--owner', dest='owner', help='Owner to query for.')
@manager.option('-e', '--expired', dest='expired', type=bool, default=False, help='Include expired certificates.')
@manager.option(
"-f",
"--fqdns",
dest="fqdns",
help="FQDNs to query. Multiple fqdns specified via comma.",
)
@manager.option("-i", "--issuer", dest="issuer", help="Issuer to query for.")
@manager.option("-o", "--owner", dest="owner", help="Owner to query for.")
@manager.option(
"-e",
"--expired",
dest="expired",
type=bool,
default=False,
help="Include expired certificates.",
)
def query(fqdns, issuer, owner, expired):
"""Prints certificates that match the query params."""
table = []
q = database.session_query(Certificate)
if issuer:
sub_query = database.session_query(Authority.id) \
.filter(Authority.name.ilike('%{0}%'.format(issuer))) \
sub_query = (
database.session_query(Authority.id)
.filter(Authority.name.ilike("%{0}%".format(issuer)))
.subquery()
)
q = q.filter(
or_(
Certificate.issuer.ilike('%{0}%'.format(issuer)),
Certificate.authority_id.in_(sub_query)
Certificate.issuer.ilike("%{0}%".format(issuer)),
Certificate.authority_id.in_(sub_query),
)
)
if owner:
q = q.filter(Certificate.owner.ilike('%{0}%'.format(owner)))
q = q.filter(Certificate.owner.ilike("%{0}%".format(owner)))
if not expired:
q = q.filter(Certificate.expired == False) # noqa
if fqdns:
for f in fqdns.split(','):
for f in fqdns.split(","):
q = q.filter(
or_(
Certificate.cn.ilike('%{0}%'.format(f)),
Certificate.domains.any(Domain.name.ilike('%{0}%'.format(f)))
Certificate.cn.ilike("%{0}%".format(f)),
Certificate.domains.any(Domain.name.ilike("%{0}%".format(f))),
)
)
for c in q.all():
table.append([c.id, c.name, c.owner, c.issuer])
print(tabulate(table, headers=['Id', 'Name', 'Owner', 'Issuer'], tablefmt='csv'))
print(tabulate(table, headers=["Id", "Name", "Owner", "Issuer"], tablefmt="csv"))
def worker(data, commit, reason):
parts = [x for x in data.split(' ') if x]
parts = [x for x in data.split(" ") if x]
try:
cert = get(int(parts[0].strip()))
plugin = plugins.get(cert.authority.plugin_name)
print('[+] Revoking certificate. Id: {0} Name: {1}'.format(cert.id, cert.name))
print("[+] Revoking certificate. Id: {0} Name: {1}".format(cert.id, cert.name))
if commit:
plugin.revoke_certificate(cert, reason)
metrics.send('certificate_revoke', 'counter', 1, metric_tags={'status': SUCCESS_METRIC_STATUS})
metrics.send(
"certificate_revoke",
"counter",
1,
metric_tags={"status": SUCCESS_METRIC_STATUS},
)
except Exception as e:
sentry.captureException()
metrics.send('certificate_revoke', 'counter', 1, metric_tags={'status': FAILURE_METRIC_STATUS})
print(
"[!] Failed to revoke certificates. Reason: {}".format(
e
)
metrics.send(
"certificate_revoke",
"counter",
1,
metric_tags={"status": FAILURE_METRIC_STATUS},
)
print("[!] Failed to revoke certificates. Reason: {}".format(e))
@manager.command
@ -341,13 +426,22 @@ def clear_pending():
Function clears all pending certificates.
:return:
"""
v = plugins.get('verisign-issuer')
v = plugins.get("verisign-issuer")
v.clear_pending_certificates()
@manager.option('-p', '--path', dest='path', help='Absolute file path to a Lemur query csv.')
@manager.option('-r', '--reason', dest='reason', help='Reason to revoke certificate.')
@manager.option('-c', '--commit', dest='commit', action='store_true', default=False, help='Persist changes.')
@manager.option(
"-p", "--path", dest="path", help="Absolute file path to a Lemur query csv."
)
@manager.option("-r", "--reason", dest="reason", help="Reason to revoke certificate.")
@manager.option(
"-c",
"--commit",
dest="commit",
action="store_true",
default=False,
help="Persist changes.",
)
def revoke(path, reason, commit):
"""
Revokes given certificate.
@ -357,7 +451,7 @@ def revoke(path, reason, commit):
print("[+] Starting certificate revocation.")
with open(path, 'r') as f:
with open(path, "r") as f:
args = [[x, commit, reason] for x in f.readlines()[2:]]
with multiprocessing.Pool(processes=3) as pool:
@ -380,11 +474,11 @@ def check_revoked():
else:
status = verify_string(cert.body, "")
cert.status = 'valid' if status else 'revoked'
cert.status = "valid" if status else "revoked"
except Exception as e:
sentry.captureException()
current_app.logger.exception(e)
cert.status = 'unknown'
cert.status = "unknown"
database.update(cert)

View File

@ -12,21 +12,30 @@ import subprocess
from flask import current_app
from lemur.certificates.service import csr_created, csr_imported, certificate_issued, certificate_imported
from lemur.certificates.service import (
csr_created,
csr_imported,
certificate_issued,
certificate_imported,
)
def csr_dump_handler(sender, csr, **kwargs):
try:
subprocess.run(['openssl', 'req', '-text', '-noout', '-reqopt', 'no_sigdump,no_pubkey'],
input=csr.encode('utf8'))
subprocess.run(
["openssl", "req", "-text", "-noout", "-reqopt", "no_sigdump,no_pubkey"],
input=csr.encode("utf8"),
)
except Exception as err:
current_app.logger.warning("Error inspecting CSR: %s", err)
def cert_dump_handler(sender, certificate, **kwargs):
try:
subprocess.run(['openssl', 'x509', '-text', '-noout', '-certopt', 'no_sigdump,no_pubkey'],
input=certificate.body.encode('utf8'))
subprocess.run(
["openssl", "x509", "-text", "-noout", "-certopt", "no_sigdump,no_pubkey"],
input=certificate.body.encode("utf8"),
)
except Exception as err:
current_app.logger.warning("Error inspecting certificate: %s", err)

View File

@ -12,7 +12,18 @@ from cryptography import x509
from cryptography.hazmat.primitives.asymmetric import rsa
from flask import current_app
from idna.core import InvalidCodepoint
from sqlalchemy import event, Integer, ForeignKey, String, PassiveDefault, func, Column, Text, Boolean, Index
from sqlalchemy import (
event,
Integer,
ForeignKey,
String,
PassiveDefault,
func,
Column,
Text,
Boolean,
Index,
)
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship
from sqlalchemy.sql.expression import case, extract
@ -25,19 +36,25 @@ from lemur.database import db
from lemur.domains.models import Domain
from lemur.extensions import metrics
from lemur.extensions import sentry
from lemur.models import certificate_associations, certificate_source_associations, \
certificate_destination_associations, certificate_notification_associations, \
certificate_replacement_associations, roles_certificates, pending_cert_replacement_associations
from lemur.models import (
certificate_associations,
certificate_source_associations,
certificate_destination_associations,
certificate_notification_associations,
certificate_replacement_associations,
roles_certificates,
pending_cert_replacement_associations,
)
from lemur.plugins.base import plugins
from lemur.policies.models import RotationPolicy
from lemur.utils import Vault
def get_sequence(name):
if '-' not in name:
if "-" not in name:
return name, None
parts = name.split('-')
parts = name.split("-")
# see if we have an int at the end of our name
try:
@ -49,18 +66,22 @@ def get_sequence(name):
if len(parts[-1]) == 8:
return name, None
root = '-'.join(parts[:-1])
root = "-".join(parts[:-1])
return root, seq
def get_or_increase_name(name, serial):
certificates = Certificate.query.filter(Certificate.name.ilike('{0}%'.format(name))).all()
certificates = Certificate.query.filter(
Certificate.name.ilike("{0}%".format(name))
).all()
if not certificates:
return name
serial_name = '{0}-{1}'.format(name, hex(int(serial))[2:].upper())
certificates = Certificate.query.filter(Certificate.name.ilike('{0}%'.format(serial_name))).all()
serial_name = "{0}-{1}".format(name, hex(int(serial))[2:].upper())
certificates = Certificate.query.filter(
Certificate.name.ilike("{0}%".format(serial_name))
).all()
if not certificates:
return serial_name
@ -72,21 +93,29 @@ def get_or_increase_name(name, serial):
if end:
ends.append(end)
return '{0}-{1}'.format(root, max(ends) + 1)
return "{0}-{1}".format(root, max(ends) + 1)
class Certificate(db.Model):
__tablename__ = 'certificates'
__tablename__ = "certificates"
__table_args__ = (
Index('ix_certificates_cn', "cn",
Index(
"ix_certificates_cn",
"cn",
postgresql_ops={"cn": "gin_trgm_ops"},
postgresql_using='gin'),
Index('ix_certificates_name', "name",
postgresql_using="gin",
),
Index(
"ix_certificates_name",
"name",
postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using='gin'),
postgresql_using="gin",
),
)
id = Column(Integer, primary_key=True)
ix = Index('ix_certificates_id_desc', id.desc(), postgresql_using='btree', unique=True)
ix = Index(
"ix_certificates_id_desc", id.desc(), postgresql_using="btree", unique=True
)
external_id = Column(String(128))
owner = Column(String(128), nullable=False)
name = Column(String(256), unique=True)
@ -102,7 +131,9 @@ class Certificate(db.Model):
serial = Column(String(128))
cn = Column(String(128))
deleted = Column(Boolean, index=True, default=False)
dns_provider_id = Column(Integer(), ForeignKey('dns_providers.id', ondelete='CASCADE'), nullable=True)
dns_provider_id = Column(
Integer(), ForeignKey("dns_providers.id", ondelete="CASCADE"), nullable=True
)
not_before = Column(ArrowType)
not_after = Column(ArrowType)
@ -114,34 +145,53 @@ class Certificate(db.Model):
san = Column(String(1024)) # TODO this should be migrated to boolean
rotation = Column(Boolean, default=False)
user_id = Column(Integer, ForeignKey('users.id'))
authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE"))
root_authority_id = Column(Integer, ForeignKey('authorities.id', ondelete="CASCADE"))
rotation_policy_id = Column(Integer, ForeignKey('rotation_policies.id'))
user_id = Column(Integer, ForeignKey("users.id"))
authority_id = Column(Integer, ForeignKey("authorities.id", ondelete="CASCADE"))
root_authority_id = Column(
Integer, ForeignKey("authorities.id", ondelete="CASCADE")
)
rotation_policy_id = Column(Integer, ForeignKey("rotation_policies.id"))
notifications = relationship('Notification', secondary=certificate_notification_associations, backref='certificate')
destinations = relationship('Destination', secondary=certificate_destination_associations, backref='certificate')
sources = relationship('Source', secondary=certificate_source_associations, backref='certificate')
domains = relationship('Domain', secondary=certificate_associations, backref='certificate')
roles = relationship('Role', secondary=roles_certificates, backref='certificate')
replaces = relationship('Certificate',
notifications = relationship(
"Notification",
secondary=certificate_notification_associations,
backref="certificate",
)
destinations = relationship(
"Destination",
secondary=certificate_destination_associations,
backref="certificate",
)
sources = relationship(
"Source", secondary=certificate_source_associations, backref="certificate"
)
domains = relationship(
"Domain", secondary=certificate_associations, backref="certificate"
)
roles = relationship("Role", secondary=roles_certificates, backref="certificate")
replaces = relationship(
"Certificate",
secondary=certificate_replacement_associations,
primaryjoin=id == certificate_replacement_associations.c.certificate_id, # noqa
secondaryjoin=id == certificate_replacement_associations.c.replaced_certificate_id, # noqa
backref='replaced')
secondaryjoin=id
== certificate_replacement_associations.c.replaced_certificate_id, # noqa
backref="replaced",
)
replaced_by_pending = relationship('PendingCertificate',
replaced_by_pending = relationship(
"PendingCertificate",
secondary=pending_cert_replacement_associations,
backref='pending_replace',
viewonly=True)
backref="pending_replace",
viewonly=True,
)
logs = relationship('Log', backref='certificate')
endpoints = relationship('Endpoint', backref='certificate')
logs = relationship("Log", backref="certificate")
endpoints = relationship("Endpoint", backref="certificate")
rotation_policy = relationship("RotationPolicy")
sensitive_fields = ('private_key',)
sensitive_fields = ("private_key",)
def __init__(self, **kwargs):
self.body = kwargs['body'].strip()
self.body = kwargs["body"].strip()
cert = self.parsed_cert
self.issuer = defaults.issuer(cert)
@ -152,36 +202,42 @@ class Certificate(db.Model):
self.serial = defaults.serial(cert)
# when destinations are appended they require a valid name.
if kwargs.get('name'):
self.name = get_or_increase_name(defaults.text_to_slug(kwargs['name']), self.serial)
if kwargs.get("name"):
self.name = get_or_increase_name(
defaults.text_to_slug(kwargs["name"]), self.serial
)
else:
self.name = get_or_increase_name(
defaults.certificate_name(self.cn, self.issuer, self.not_before, self.not_after, self.san), self.serial)
defaults.certificate_name(
self.cn, self.issuer, self.not_before, self.not_after, self.san
),
self.serial,
)
self.owner = kwargs['owner']
self.owner = kwargs["owner"]
if kwargs.get('private_key'):
self.private_key = kwargs['private_key'].strip()
if kwargs.get("private_key"):
self.private_key = kwargs["private_key"].strip()
if kwargs.get('chain'):
self.chain = kwargs['chain'].strip()
if kwargs.get("chain"):
self.chain = kwargs["chain"].strip()
if kwargs.get('csr'):
self.csr = kwargs['csr'].strip()
if kwargs.get("csr"):
self.csr = kwargs["csr"].strip()
self.notify = kwargs.get('notify', True)
self.destinations = kwargs.get('destinations', [])
self.notifications = kwargs.get('notifications', [])
self.description = kwargs.get('description')
self.roles = list(set(kwargs.get('roles', [])))
self.replaces = kwargs.get('replaces', [])
self.rotation = kwargs.get('rotation')
self.rotation_policy = kwargs.get('rotation_policy')
self.notify = kwargs.get("notify", True)
self.destinations = kwargs.get("destinations", [])
self.notifications = kwargs.get("notifications", [])
self.description = kwargs.get("description")
self.roles = list(set(kwargs.get("roles", [])))
self.replaces = kwargs.get("replaces", [])
self.rotation = kwargs.get("rotation")
self.rotation_policy = kwargs.get("rotation_policy")
self.signing_algorithm = defaults.signing_algorithm(cert)
self.bits = defaults.bitstrength(cert)
self.external_id = kwargs.get('external_id')
self.authority_id = kwargs.get('authority_id')
self.dns_provider_id = kwargs.get('dns_provider_id')
self.external_id = kwargs.get("external_id")
self.authority_id = kwargs.get("authority_id")
self.dns_provider_id = kwargs.get("dns_provider_id")
for domain in defaults.domains(cert):
self.domains.append(Domain(name=domain))
@ -195,8 +251,11 @@ class Certificate(db.Model):
Integrity checks: Does the cert have a valid chain and matching private key?
"""
if self.private_key:
validators.verify_private_key_match(utils.parse_private_key(self.private_key), self.parsed_cert,
error_class=AssertionError)
validators.verify_private_key_match(
utils.parse_private_key(self.private_key),
self.parsed_cert,
error_class=AssertionError,
)
if self.chain:
chain = [self.parsed_cert] + utils.parse_cert_chain(self.chain)
@ -238,7 +297,9 @@ class Certificate(db.Model):
@property
def key_type(self):
if isinstance(self.parsed_cert.public_key(), rsa.RSAPublicKey):
return 'RSA{key_size}'.format(key_size=self.parsed_cert.public_key().key_size)
return "RSA{key_size}".format(
key_size=self.parsed_cert.public_key().key_size
)
@property
def validity_remaining(self):
@ -263,26 +324,16 @@ class Certificate(db.Model):
@expired.expression
def expired(cls):
return case(
[
(cls.not_after <= arrow.utcnow(), True)
],
else_=False
)
return case([(cls.not_after <= arrow.utcnow(), True)], else_=False)
@hybrid_property
def revoked(self):
if 'revoked' == self.status:
if "revoked" == self.status:
return True
@revoked.expression
def revoked(cls):
return case(
[
(cls.status == 'revoked', True)
],
else_=False
)
return case([(cls.status == "revoked", True)], else_=False)
@hybrid_property
def in_rotation_window(self):
@ -305,66 +356,65 @@ class Certificate(db.Model):
:return:
"""
return case(
[
(extract('day', cls.not_after - func.now()) <= RotationPolicy.days, True)
],
else_=False
[(extract("day", cls.not_after - func.now()) <= RotationPolicy.days, True)],
else_=False,
)
@property
def extensions(self):
# setup default values
return_extensions = {
'sub_alt_names': {'names': []}
}
return_extensions = {"sub_alt_names": {"names": []}}
try:
for extension in self.parsed_cert.extensions:
value = extension.value
if isinstance(value, x509.BasicConstraints):
return_extensions['basic_constraints'] = value
return_extensions["basic_constraints"] = value
elif isinstance(value, x509.SubjectAlternativeName):
return_extensions['sub_alt_names']['names'] = value
return_extensions["sub_alt_names"]["names"] = value
elif isinstance(value, x509.ExtendedKeyUsage):
return_extensions['extended_key_usage'] = value
return_extensions["extended_key_usage"] = value
elif isinstance(value, x509.KeyUsage):
return_extensions['key_usage'] = value
return_extensions["key_usage"] = value
elif isinstance(value, x509.SubjectKeyIdentifier):
return_extensions['subject_key_identifier'] = {'include_ski': True}
return_extensions["subject_key_identifier"] = {"include_ski": True}
elif isinstance(value, x509.AuthorityInformationAccess):
return_extensions['certificate_info_access'] = {'include_aia': True}
return_extensions["certificate_info_access"] = {"include_aia": True}
elif isinstance(value, x509.AuthorityKeyIdentifier):
aki = {
'use_key_identifier': False,
'use_authority_cert': False
}
aki = {"use_key_identifier": False, "use_authority_cert": False}
if value.key_identifier:
aki['use_key_identifier'] = True
aki["use_key_identifier"] = True
if value.authority_cert_issuer:
aki['use_authority_cert'] = True
aki["use_authority_cert"] = True
return_extensions['authority_key_identifier'] = aki
return_extensions["authority_key_identifier"] = aki
elif isinstance(value, x509.CRLDistributionPoints):
return_extensions['crl_distribution_points'] = {'include_crl_dp': value}
return_extensions["crl_distribution_points"] = {
"include_crl_dp": value
}
# TODO: Not supporting custom OIDs yet. https://github.com/Netflix/lemur/issues/665
else:
current_app.logger.warning('Custom OIDs not yet supported for clone operation.')
current_app.logger.warning(
"Custom OIDs not yet supported for clone operation."
)
except InvalidCodepoint as e:
sentry.captureException()
current_app.logger.warning('Unable to parse extensions due to underscore in dns name')
current_app.logger.warning(
"Unable to parse extensions due to underscore in dns name"
)
except ValueError as e:
sentry.captureException()
current_app.logger.warning('Unable to parse')
current_app.logger.warning("Unable to parse")
current_app.logger.exception(e)
return return_extensions
@ -373,7 +423,7 @@ class Certificate(db.Model):
return "Certificate(name={name})".format(name=self.name)
@event.listens_for(Certificate.destinations, 'append')
@event.listens_for(Certificate.destinations, "append")
def update_destinations(target, value, initiator):
"""
Attempt to upload certificate to the new destination
@ -387,17 +437,31 @@ def update_destinations(target, value, initiator):
status = FAILURE_METRIC_STATUS
try:
if target.private_key or not destination_plugin.requires_key:
destination_plugin.upload(target.name, target.body, target.private_key, target.chain, value.options)
destination_plugin.upload(
target.name,
target.body,
target.private_key,
target.chain,
value.options,
)
status = SUCCESS_METRIC_STATUS
except Exception as e:
sentry.captureException()
raise
metrics.send('destination_upload', 'counter', 1,
metric_tags={'status': status, 'certificate': target.name, 'destination': value.label})
metrics.send(
"destination_upload",
"counter",
1,
metric_tags={
"status": status,
"certificate": target.name,
"destination": value.label,
},
)
@event.listens_for(Certificate.replaces, 'append')
@event.listens_for(Certificate.replaces, "append")
def update_replacement(target, value, initiator):
"""
When a certificate is marked as 'replaced' we should not notify.

View File

@ -39,22 +39,26 @@ from lemur.users.schemas import UserNestedOutputSchema
class CertificateSchema(LemurInputSchema):
owner = fields.Email(required=True)
description = fields.String(missing='', allow_none=True)
description = fields.String(missing="", allow_none=True)
class CertificateCreationSchema(CertificateSchema):
@post_load
def default_notification(self, data):
if not data['notifications']:
data['notifications'] += notification_service.create_default_expiration_notifications(
"DEFAULT_{0}".format(data['owner'].split('@')[0].upper()),
[data['owner']],
if not data["notifications"]:
data[
"notifications"
] += notification_service.create_default_expiration_notifications(
"DEFAULT_{0}".format(data["owner"].split("@")[0].upper()),
[data["owner"]],
)
data['notifications'] += notification_service.create_default_expiration_notifications(
'DEFAULT_SECURITY',
current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL'),
current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL_INTERVALS', None)
data[
"notifications"
] += notification_service.create_default_expiration_notifications(
"DEFAULT_SECURITY",
current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL"),
current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL_INTERVALS", None),
)
return data
@ -71,37 +75,53 @@ class CertificateInputSchema(CertificateCreationSchema):
destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True)
notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True)
replaces = fields.Nested(AssociatedCertificateSchema, missing=[], many=True)
replacements = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) # deprecated
replacements = fields.Nested(
AssociatedCertificateSchema, missing=[], many=True
) # deprecated
roles = fields.Nested(AssociatedRoleSchema, missing=[], many=True)
dns_provider = fields.Nested(AssociatedDnsProviderSchema, missing=None, allow_none=True, required=False)
dns_provider = fields.Nested(
AssociatedDnsProviderSchema, missing=None, allow_none=True, required=False
)
csr = fields.String(allow_none=True, validate=validators.csr)
key_type = fields.String(
validate=validate.OneOf(CERTIFICATE_KEY_TYPES),
missing='RSA2048')
validate=validate.OneOf(CERTIFICATE_KEY_TYPES), missing="RSA2048"
)
notify = fields.Boolean(default=True)
rotation = fields.Boolean()
rotation_policy = fields.Nested(AssociatedRotationPolicySchema, missing={'name': 'default'}, allow_none=True,
default={'name': 'default'})
rotation_policy = fields.Nested(
AssociatedRotationPolicySchema,
missing={"name": "default"},
allow_none=True,
default={"name": "default"},
)
# certificate body fields
organizational_unit = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT'))
organization = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_ORGANIZATION'))
location = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_LOCATION'))
country = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_COUNTRY'))
state = fields.String(missing=lambda: current_app.config.get('LEMUR_DEFAULT_STATE'))
organizational_unit = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATIONAL_UNIT")
)
organization = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_ORGANIZATION")
)
location = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_LOCATION")
)
country = fields.String(
missing=lambda: current_app.config.get("LEMUR_DEFAULT_COUNTRY")
)
state = fields.String(missing=lambda: current_app.config.get("LEMUR_DEFAULT_STATE"))
extensions = fields.Nested(ExtensionSchema)
@validates_schema
def validate_authority(self, data):
if isinstance(data['authority'], str):
if isinstance(data["authority"], str):
raise ValidationError("Authority not found.")
if not data['authority'].active:
raise ValidationError("The authority is inactive.", ['authority'])
if not data["authority"].active:
raise ValidationError("The authority is inactive.", ["authority"])
@validates_schema
def validate_dates(self, data):
@ -109,23 +129,19 @@ class CertificateInputSchema(CertificateCreationSchema):
@pre_load
def load_data(self, data):
if data.get('replacements'):
data['replaces'] = data['replacements'] # TODO remove when field is deprecated
if data.get('csr'):
csr_sans = cert_utils.get_sans_from_csr(data['csr'])
if not data.get('extensions'):
data['extensions'] = {
'subAltNames': {
'names': []
}
}
elif not data['extensions'].get('subAltNames'):
data['extensions']['subAltNames'] = {
'names': []
}
elif not data['extensions']['subAltNames'].get('names'):
data['extensions']['subAltNames']['names'] = []
data['extensions']['subAltNames']['names'] += csr_sans
if data.get("replacements"):
data["replaces"] = data[
"replacements"
] # TODO remove when field is deprecated
if data.get("csr"):
csr_sans = cert_utils.get_sans_from_csr(data["csr"])
if not data.get("extensions"):
data["extensions"] = {"subAltNames": {"names": []}}
elif not data["extensions"].get("subAltNames"):
data["extensions"]["subAltNames"] = {"names": []}
elif not data["extensions"]["subAltNames"].get("names"):
data["extensions"]["subAltNames"]["names"] = []
data["extensions"]["subAltNames"]["names"] += csr_sans
return missing.convert_validity_years(data)
@ -138,13 +154,17 @@ class CertificateEditInputSchema(CertificateSchema):
destinations = fields.Nested(AssociatedDestinationSchema, missing=[], many=True)
notifications = fields.Nested(AssociatedNotificationSchema, missing=[], many=True)
replaces = fields.Nested(AssociatedCertificateSchema, missing=[], many=True)
replacements = fields.Nested(AssociatedCertificateSchema, missing=[], many=True) # deprecated
replacements = fields.Nested(
AssociatedCertificateSchema, missing=[], many=True
) # deprecated
roles = fields.Nested(AssociatedRoleSchema, missing=[], many=True)
@pre_load
def load_data(self, data):
if data.get('replacements'):
data['replaces'] = data['replacements'] # TODO remove when field is deprecated
if data.get("replacements"):
data["replaces"] = data[
"replacements"
] # TODO remove when field is deprecated
return data
@post_load
@ -155,10 +175,15 @@ class CertificateEditInputSchema(CertificateSchema):
:param data:
:return:
"""
if data['owner']:
notification_name = "DEFAULT_{0}".format(data['owner'].split('@')[0].upper())
data['notifications'] += notification_service.create_default_expiration_notifications(notification_name,
[data['owner']])
if data["owner"]:
notification_name = "DEFAULT_{0}".format(
data["owner"].split("@")[0].upper()
)
data[
"notifications"
] += notification_service.create_default_expiration_notifications(
notification_name, [data["owner"]]
)
return data
@ -184,13 +209,13 @@ class CertificateNestedOutputSchema(LemurOutputSchema):
# Note aliasing is the first step in deprecating these fields.
cn = fields.String() # deprecated
common_name = fields.String(attribute='cn')
common_name = fields.String(attribute="cn")
not_after = fields.DateTime() # deprecated
validity_end = ArrowDateTime(attribute='not_after')
validity_end = ArrowDateTime(attribute="not_after")
not_before = fields.DateTime() # deprecated
validity_start = ArrowDateTime(attribute='not_before')
validity_start = ArrowDateTime(attribute="not_before")
issuer = fields.Nested(AuthorityNestedOutputSchema)
@ -221,22 +246,22 @@ class CertificateOutputSchema(LemurOutputSchema):
# Note aliasing is the first step in deprecating these fields.
notify = fields.Boolean()
active = fields.Boolean(attribute='notify')
active = fields.Boolean(attribute="notify")
cn = fields.String()
common_name = fields.String(attribute='cn')
common_name = fields.String(attribute="cn")
distinguished_name = fields.String()
not_after = fields.DateTime()
validity_end = ArrowDateTime(attribute='not_after')
validity_end = ArrowDateTime(attribute="not_after")
not_before = fields.DateTime()
validity_start = ArrowDateTime(attribute='not_before')
validity_start = ArrowDateTime(attribute="not_before")
owner = fields.Email()
san = fields.Boolean()
serial = fields.String()
serial_hex = Hex(attribute='serial')
serial_hex = Hex(attribute="serial")
signing_algorithm = fields.String()
status = fields.String()
@ -253,7 +278,9 @@ class CertificateOutputSchema(LemurOutputSchema):
dns_provider = fields.Nested(DnsProvidersNestedOutputSchema)
roles = fields.Nested(RoleNestedOutputSchema, many=True)
endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[])
replaced_by = fields.Nested(CertificateNestedOutputSchema, many=True, attribute='replaced')
replaced_by = fields.Nested(
CertificateNestedOutputSchema, many=True, attribute="replaced"
)
rotation_policy = fields.Nested(RotationPolicyNestedOutputSchema)
@ -274,35 +301,41 @@ class CertificateUploadInputSchema(CertificateCreationSchema):
@validates_schema
def keys(self, data):
if data.get('destinations'):
if not data.get('private_key'):
raise ValidationError('Destinations require private key.')
if data.get("destinations"):
if not data.get("private_key"):
raise ValidationError("Destinations require private key.")
@validates_schema
def validate_cert_private_key_chain(self, data):
cert = None
key = None
if data.get('body'):
if data.get("body"):
try:
cert = utils.parse_certificate(data['body'])
cert = utils.parse_certificate(data["body"])
except ValueError:
raise ValidationError("Public certificate presented is not valid.", field_names=['body'])
raise ValidationError(
"Public certificate presented is not valid.", field_names=["body"]
)
if data.get('private_key'):
if data.get("private_key"):
try:
key = utils.parse_private_key(data['private_key'])
key = utils.parse_private_key(data["private_key"])
except ValueError:
raise ValidationError("Private key presented is not valid.", field_names=['private_key'])
raise ValidationError(
"Private key presented is not valid.", field_names=["private_key"]
)
if cert and key:
# Throws ValidationError
validators.verify_private_key_match(key, cert)
if data.get('chain'):
if data.get("chain"):
try:
chain = utils.parse_cert_chain(data['chain'])
chain = utils.parse_cert_chain(data["chain"])
except ValueError:
raise ValidationError("Invalid certificate in certificate chain.", field_names=['chain'])
raise ValidationError(
"Invalid certificate in certificate chain.", field_names=["chain"]
)
# Throws ValidationError
validators.verify_cert_chain([cert] + chain)
@ -318,8 +351,10 @@ class CertificateNotificationOutputSchema(LemurOutputSchema):
name = fields.String()
owner = fields.Email()
user = fields.Nested(UserNestedOutputSchema)
validity_end = ArrowDateTime(attribute='not_after')
replaced_by = fields.Nested(CertificateNestedOutputSchema, many=True, attribute='replaced')
validity_end = ArrowDateTime(attribute="not_after")
replaced_by = fields.Nested(
CertificateNestedOutputSchema, many=True, attribute="replaced"
)
endpoints = fields.Nested(EndpointNestedOutputSchema, many=True, missing=[])

View File

@ -26,10 +26,14 @@ from lemur.plugins.base import plugins
from lemur.roles import service as role_service
from lemur.roles.models import Role
csr_created = signals.signal('csr_created', "CSR generated")
csr_imported = signals.signal('csr_imported', "CSR imported from external source")
certificate_issued = signals.signal('certificate_issued', "Authority issued a certificate")
certificate_imported = signals.signal('certificate_imported', "Certificate imported from external source")
csr_created = signals.signal("csr_created", "CSR generated")
csr_imported = signals.signal("csr_imported", "CSR imported from external source")
certificate_issued = signals.signal(
"certificate_issued", "Authority issued a certificate"
)
certificate_imported = signals.signal(
"certificate_imported", "Certificate imported from external source"
)
def get(cert_id):
@ -49,7 +53,7 @@ def get_by_name(name):
:param name:
:return:
"""
return database.get(Certificate, name, field='name')
return database.get(Certificate, name, field="name")
def get_by_serial(serial):
@ -105,8 +109,12 @@ def get_all_pending_cleaning(source):
:param source:
:return:
"""
return Certificate.query.filter(Certificate.sources.any(id=source.id)) \
.filter(not_(Certificate.endpoints.any())).filter(Certificate.expired).all()
return (
Certificate.query.filter(Certificate.sources.any(id=source.id))
.filter(not_(Certificate.endpoints.any()))
.filter(Certificate.expired)
.all()
)
def get_all_pending_reissue():
@ -119,9 +127,12 @@ def get_all_pending_reissue():
:return:
"""
return Certificate.query.filter(Certificate.rotation == True) \
.filter(not_(Certificate.replaced.any())) \
.filter(Certificate.in_rotation_window == True).all() # noqa
return (
Certificate.query.filter(Certificate.rotation == True)
.filter(not_(Certificate.replaced.any()))
.filter(Certificate.in_rotation_window == True)
.all()
) # noqa
def find_duplicates(cert):
@ -133,10 +144,12 @@ def find_duplicates(cert):
:param cert:
:return:
"""
if cert['chain']:
return Certificate.query.filter_by(body=cert['body'].strip(), chain=cert['chain'].strip()).all()
if cert["chain"]:
return Certificate.query.filter_by(
body=cert["body"].strip(), chain=cert["chain"].strip()
).all()
else:
return Certificate.query.filter_by(body=cert['body'].strip(), chain=None).all()
return Certificate.query.filter_by(body=cert["body"].strip(), chain=None).all()
def export(cert, export_plugin):
@ -148,8 +161,10 @@ def export(cert, export_plugin):
:param cert:
:return:
"""
plugin = plugins.get(export_plugin['slug'])
return plugin.export(cert.body, cert.chain, cert.private_key, export_plugin['pluginOptions'])
plugin = plugins.get(export_plugin["slug"])
return plugin.export(
cert.body, cert.chain, cert.private_key, export_plugin["pluginOptions"]
)
def update(cert_id, **kwargs):
@ -168,17 +183,19 @@ def update(cert_id, **kwargs):
def create_certificate_roles(**kwargs):
# create an role for the owner and assign it
owner_role = role_service.get_by_name(kwargs['owner'])
owner_role = role_service.get_by_name(kwargs["owner"])
if not owner_role:
owner_role = role_service.create(
kwargs['owner'],
description="Auto generated role based on owner: {0}".format(kwargs['owner'])
kwargs["owner"],
description="Auto generated role based on owner: {0}".format(
kwargs["owner"]
),
)
# ensure that the authority's owner is also associated with the certificate
if kwargs.get('authority'):
authority_owner_role = role_service.get_by_name(kwargs['authority'].owner)
if kwargs.get("authority"):
authority_owner_role = role_service.get_by_name(kwargs["authority"].owner)
return [owner_role, authority_owner_role]
return [owner_role]
@ -190,16 +207,16 @@ def mint(**kwargs):
Support for multiple authorities is handled by individual plugins.
"""
authority = kwargs['authority']
authority = kwargs["authority"]
issuer = plugins.get(authority.plugin_name)
# allow the CSR to be specified by the user
if not kwargs.get('csr'):
if not kwargs.get("csr"):
csr, private_key = create_csr(**kwargs)
csr_created.send(authority=authority, csr=csr)
else:
csr = str(kwargs.get('csr'))
csr = str(kwargs.get("csr"))
private_key = None
csr_imported.send(authority=authority, csr=csr)
@ -220,8 +237,8 @@ def import_certificate(**kwargs):
:param kwargs:
"""
if not kwargs.get('owner'):
kwargs['owner'] = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL')[0]
if not kwargs.get("owner"):
kwargs["owner"] = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL")[0]
return upload(**kwargs)
@ -232,16 +249,16 @@ def upload(**kwargs):
"""
roles = create_certificate_roles(**kwargs)
if kwargs.get('roles'):
kwargs['roles'] += roles
if kwargs.get("roles"):
kwargs["roles"] += roles
else:
kwargs['roles'] = roles
kwargs["roles"] = roles
cert = Certificate(**kwargs)
cert.authority = kwargs.get('authority')
cert.authority = kwargs.get("authority")
cert = database.create(cert)
kwargs['creator'].certificates.append(cert)
kwargs["creator"].certificates.append(cert)
cert = database.update(cert)
certificate_imported.send(certificate=cert, authority=cert.authority)
@ -258,39 +275,45 @@ def create(**kwargs):
current_app.logger.error("Exception minting certificate", exc_info=True)
sentry.captureException()
raise
kwargs['body'] = cert_body
kwargs['private_key'] = private_key
kwargs['chain'] = cert_chain
kwargs['external_id'] = external_id
kwargs['csr'] = csr
kwargs["body"] = cert_body
kwargs["private_key"] = private_key
kwargs["chain"] = cert_chain
kwargs["external_id"] = external_id
kwargs["csr"] = csr
roles = create_certificate_roles(**kwargs)
if kwargs.get('roles'):
kwargs['roles'] += roles
if kwargs.get("roles"):
kwargs["roles"] += roles
else:
kwargs['roles'] = roles
kwargs["roles"] = roles
if cert_body:
cert = Certificate(**kwargs)
kwargs['creator'].certificates.append(cert)
kwargs["creator"].certificates.append(cert)
else:
cert = PendingCertificate(**kwargs)
kwargs['creator'].pending_certificates.append(cert)
kwargs["creator"].pending_certificates.append(cert)
cert.authority = kwargs['authority']
cert.authority = kwargs["authority"]
database.commit()
if isinstance(cert, Certificate):
certificate_issued.send(certificate=cert, authority=cert.authority)
metrics.send('certificate_issued', 'counter', 1, metric_tags=dict(owner=cert.owner, issuer=cert.issuer))
metrics.send(
"certificate_issued",
"counter",
1,
metric_tags=dict(owner=cert.owner, issuer=cert.issuer),
)
if isinstance(cert, PendingCertificate):
# We need to refresh the pending certificate to avoid "Instance is not bound to a Session; "
# "attribute refresh operation cannot proceed"
pending_cert = database.session_query(PendingCertificate).get(cert.id)
from lemur.common.celery import fetch_acme_cert
if not current_app.config.get("ACME_DISABLE_AUTORESOLVE", False):
fetch_acme_cert.apply_async((pending_cert.id,), countdown=5)
@ -306,51 +329,55 @@ def render(args):
"""
query = database.session_query(Certificate)
time_range = args.pop('time_range')
destination_id = args.pop('destination_id')
notification_id = args.pop('notification_id', None)
show = args.pop('show')
time_range = args.pop("time_range")
destination_id = args.pop("destination_id")
notification_id = args.pop("notification_id", None)
show = args.pop("show")
# owner = args.pop('owner')
# creator = args.pop('creator') # TODO we should enabling filtering by owner
filt = args.pop('filter')
filt = args.pop("filter")
if filt:
terms = filt.split(';')
term = '%{0}%'.format(terms[1])
terms = filt.split(";")
term = "%{0}%".format(terms[1])
# Exact matches for quotes. Only applies to name, issuer, and cn
if terms[1].startswith('"') and terms[1].endswith('"'):
term = terms[1][1:-1]
if 'issuer' in terms:
if "issuer" in terms:
# we can't rely on issuer being correct in the cert directly so we combine queries
sub_query = database.session_query(Authority.id) \
.filter(Authority.name.ilike(term)) \
sub_query = (
database.session_query(Authority.id)
.filter(Authority.name.ilike(term))
.subquery()
)
query = query.filter(
or_(
Certificate.issuer.ilike(term),
Certificate.authority_id.in_(sub_query)
Certificate.authority_id.in_(sub_query),
)
)
elif 'destination' in terms:
query = query.filter(Certificate.destinations.any(Destination.id == terms[1]))
elif 'notify' in filt:
elif "destination" in terms:
query = query.filter(
Certificate.destinations.any(Destination.id == terms[1])
)
elif "notify" in filt:
query = query.filter(Certificate.notify == truthiness(terms[1]))
elif 'active' in filt:
elif "active" in filt:
query = query.filter(Certificate.active == truthiness(terms[1]))
elif 'cn' in terms:
elif "cn" in terms:
query = query.filter(
or_(
Certificate.cn.ilike(term),
Certificate.domains.any(Domain.name.ilike(term))
Certificate.domains.any(Domain.name.ilike(term)),
)
)
elif 'id' in terms:
elif "id" in terms:
query = query.filter(Certificate.id == cast(terms[1], Integer))
elif 'name' in terms:
elif "name" in terms:
query = query.filter(
or_(
Certificate.name.ilike(term),
@ -362,26 +389,35 @@ def render(args):
query = database.filter(query, Certificate, terms)
if show:
sub_query = database.session_query(Role.name).filter(Role.user_id == args['user'].id).subquery()
sub_query = (
database.session_query(Role.name)
.filter(Role.user_id == args["user"].id)
.subquery()
)
query = query.filter(
or_(
Certificate.user_id == args['user'].id,
Certificate.owner.in_(sub_query)
Certificate.user_id == args["user"].id, Certificate.owner.in_(sub_query)
)
)
if destination_id:
query = query.filter(Certificate.destinations.any(Destination.id == destination_id))
query = query.filter(
Certificate.destinations.any(Destination.id == destination_id)
)
if notification_id:
query = query.filter(Certificate.notifications.any(Notification.id == notification_id))
query = query.filter(
Certificate.notifications.any(Notification.id == notification_id)
)
if time_range:
to = arrow.now().replace(weeks=+time_range).format('YYYY-MM-DD')
now = arrow.now().format('YYYY-MM-DD')
query = query.filter(Certificate.not_after <= to).filter(Certificate.not_after >= now)
to = arrow.now().replace(weeks=+time_range).format("YYYY-MM-DD")
now = arrow.now().format("YYYY-MM-DD")
query = query.filter(Certificate.not_after <= to).filter(
Certificate.not_after >= now
)
if current_app.config.get('ALLOW_CERT_DELETION', False):
if current_app.config.get("ALLOW_CERT_DELETION", False):
query = query.filter(Certificate.deleted == False) # noqa
result = database.sort_and_page(query, Certificate, args)
@ -409,18 +445,20 @@ def query_common_name(common_name, args):
:param args:
:return:
"""
owner = args.pop('owner')
owner = args.pop("owner")
if not owner:
owner = '%'
owner = "%"
# only not expired certificates
current_time = arrow.utcnow()
result = Certificate.query.filter(Certificate.cn.ilike(common_name)) \
.filter(Certificate.owner.ilike(owner))\
.filter(Certificate.not_after >= current_time.format('YYYY-MM-DD')) \
.filter(Certificate.rotation.is_(True))\
result = (
Certificate.query.filter(Certificate.cn.ilike(common_name))
.filter(Certificate.owner.ilike(owner))
.filter(Certificate.not_after >= current_time.format("YYYY-MM-DD"))
.filter(Certificate.rotation.is_(True))
.all()
)
return result
@ -432,62 +470,77 @@ def create_csr(**csr_config):
:param csr_config:
"""
private_key = generate_private_key(csr_config.get('key_type'))
private_key = generate_private_key(csr_config.get("key_type"))
builder = x509.CertificateSigningRequestBuilder()
name_list = [x509.NameAttribute(x509.OID_COMMON_NAME, csr_config['common_name'])]
if current_app.config.get('LEMUR_OWNER_EMAIL_IN_SUBJECT', True):
name_list.append(x509.NameAttribute(x509.OID_EMAIL_ADDRESS, csr_config['owner']))
if 'organization' in csr_config and csr_config['organization'].strip():
name_list.append(x509.NameAttribute(x509.OID_ORGANIZATION_NAME, csr_config['organization']))
if 'organizational_unit' in csr_config and csr_config['organizational_unit'].strip():
name_list.append(x509.NameAttribute(x509.OID_ORGANIZATIONAL_UNIT_NAME, csr_config['organizational_unit']))
if 'country' in csr_config and csr_config['country'].strip():
name_list.append(x509.NameAttribute(x509.OID_COUNTRY_NAME, csr_config['country']))
if 'state' in csr_config and csr_config['state'].strip():
name_list.append(x509.NameAttribute(x509.OID_STATE_OR_PROVINCE_NAME, csr_config['state']))
if 'location' in csr_config and csr_config['location'].strip():
name_list.append(x509.NameAttribute(x509.OID_LOCALITY_NAME, csr_config['location']))
name_list = [x509.NameAttribute(x509.OID_COMMON_NAME, csr_config["common_name"])]
if current_app.config.get("LEMUR_OWNER_EMAIL_IN_SUBJECT", True):
name_list.append(
x509.NameAttribute(x509.OID_EMAIL_ADDRESS, csr_config["owner"])
)
if "organization" in csr_config and csr_config["organization"].strip():
name_list.append(
x509.NameAttribute(x509.OID_ORGANIZATION_NAME, csr_config["organization"])
)
if (
"organizational_unit" in csr_config
and csr_config["organizational_unit"].strip()
):
name_list.append(
x509.NameAttribute(
x509.OID_ORGANIZATIONAL_UNIT_NAME, csr_config["organizational_unit"]
)
)
if "country" in csr_config and csr_config["country"].strip():
name_list.append(
x509.NameAttribute(x509.OID_COUNTRY_NAME, csr_config["country"])
)
if "state" in csr_config and csr_config["state"].strip():
name_list.append(
x509.NameAttribute(x509.OID_STATE_OR_PROVINCE_NAME, csr_config["state"])
)
if "location" in csr_config and csr_config["location"].strip():
name_list.append(
x509.NameAttribute(x509.OID_LOCALITY_NAME, csr_config["location"])
)
builder = builder.subject_name(x509.Name(name_list))
extensions = csr_config.get('extensions', {})
critical_extensions = ['basic_constraints', 'sub_alt_names', 'key_usage']
noncritical_extensions = ['extended_key_usage']
extensions = csr_config.get("extensions", {})
critical_extensions = ["basic_constraints", "sub_alt_names", "key_usage"]
noncritical_extensions = ["extended_key_usage"]
for k, v in extensions.items():
if v:
if k in critical_extensions:
current_app.logger.debug('Adding Critical Extension: {0} {1}'.format(k, v))
if k == 'sub_alt_names':
if v['names']:
builder = builder.add_extension(v['names'], critical=True)
current_app.logger.debug(
"Adding Critical Extension: {0} {1}".format(k, v)
)
if k == "sub_alt_names":
if v["names"]:
builder = builder.add_extension(v["names"], critical=True)
else:
builder = builder.add_extension(v, critical=True)
if k in noncritical_extensions:
current_app.logger.debug('Adding Extension: {0} {1}'.format(k, v))
current_app.logger.debug("Adding Extension: {0} {1}".format(k, v))
builder = builder.add_extension(v, critical=False)
ski = extensions.get('subject_key_identifier', {})
if ski.get('include_ski', False):
ski = extensions.get("subject_key_identifier", {})
if ski.get("include_ski", False):
builder = builder.add_extension(
x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()),
critical=False
critical=False,
)
request = builder.sign(
private_key, hashes.SHA256(), default_backend()
)
request = builder.sign(private_key, hashes.SHA256(), default_backend())
# serialize our private key and CSR
private_key = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL, # would like to use PKCS8 but AWS ELBs don't like it
encryption_algorithm=serialization.NoEncryption()
).decode('utf-8')
encryption_algorithm=serialization.NoEncryption(),
).decode("utf-8")
csr = request.public_bytes(
encoding=serialization.Encoding.PEM
).decode('utf-8')
csr = request.public_bytes(encoding=serialization.Encoding.PEM).decode("utf-8")
return csr, private_key
@ -499,16 +552,19 @@ def stats(**kwargs):
:param kwargs:
:return:
"""
if kwargs.get('metric') == 'not_after':
if kwargs.get("metric") == "not_after":
start = arrow.utcnow()
end = start.replace(weeks=+32)
items = database.db.session.query(Certificate.issuer, func.count(Certificate.id)) \
.group_by(Certificate.issuer) \
.filter(Certificate.not_after <= end.format('YYYY-MM-DD')) \
.filter(Certificate.not_after >= start.format('YYYY-MM-DD')).all()
items = (
database.db.session.query(Certificate.issuer, func.count(Certificate.id))
.group_by(Certificate.issuer)
.filter(Certificate.not_after <= end.format("YYYY-MM-DD"))
.filter(Certificate.not_after >= start.format("YYYY-MM-DD"))
.all()
)
else:
attr = getattr(Certificate, kwargs.get('metric'))
attr = getattr(Certificate, kwargs.get("metric"))
query = database.db.session.query(attr, func.count(attr))
items = query.group_by(attr).all()
@ -519,7 +575,7 @@ def stats(**kwargs):
keys.append(key)
values.append(count)
return {'labels': keys, 'values': values}
return {"labels": keys, "values": values}
def get_account_number(arn):
@ -566,22 +622,24 @@ def get_certificate_primitives(certificate):
certificate via `create`.
"""
start, end = calculate_reissue_range(certificate.not_before, certificate.not_after)
ser = CertificateInputSchema().load(CertificateOutputSchema().dump(certificate).data)
ser = CertificateInputSchema().load(
CertificateOutputSchema().dump(certificate).data
)
assert not ser.errors, "Error re-serializing certificate: %s" % ser.errors
data = ser.data
# we can't quite tell if we are using a custom name, as this is an automated process (typically)
# we will rely on the Lemur generated name
data.pop('name', None)
data.pop("name", None)
# TODO this can be removed once we migrate away from cn
data['cn'] = data['common_name']
data["cn"] = data["common_name"]
# needed until we move off not_*
data['not_before'] = start
data['not_after'] = end
data['validity_start'] = start
data['validity_end'] = end
data["not_before"] = start
data["not_after"] = end
data["validity_start"] = start
data["validity_end"] = end
return data
@ -599,13 +657,13 @@ def reissue_certificate(certificate, replace=None, user=None):
# We do not want to re-use the CSR when creating a certificate because this defeats the purpose of rotation.
del primitives["csr"]
if not user:
primitives['creator'] = certificate.user
primitives["creator"] = certificate.user
else:
primitives['creator'] = user
primitives["creator"] = user
if replace:
primitives['replaces'] = [certificate]
primitives["replaces"] = [certificate]
new_cert = create(**primitives)

View File

@ -23,17 +23,18 @@ def get_sans_from_csr(data):
"""
sub_alt_names = []
try:
request = x509.load_pem_x509_csr(data.encode('utf-8'), default_backend())
request = x509.load_pem_x509_csr(data.encode("utf-8"), default_backend())
except Exception:
raise ValidationError('CSR presented is not valid.')
raise ValidationError("CSR presented is not valid.")
try:
alt_names = request.extensions.get_extension_for_class(x509.SubjectAlternativeName)
alt_names = request.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
for alt_name in alt_names.value:
sub_alt_names.append({
'nameType': type(alt_name).__name__,
'value': alt_name.value
})
sub_alt_names.append(
{"nameType": type(alt_name).__name__, "value": alt_name.value}
)
except x509.ExtensionNotFound:
pass

View File

@ -29,31 +29,45 @@ def ocsp_verify(cert, cert_path, issuer_chain_path):
:param issuer_chain_path:
:return bool: True if certificate is valid, False otherwise
"""
command = ['openssl', 'x509', '-noout', '-ocsp_uri', '-in', cert_path]
command = ["openssl", "x509", "-noout", "-ocsp_uri", "-in", cert_path]
p1 = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
url, err = p1.communicate()
if not url:
current_app.logger.debug("No OCSP URL in certificate {}".format(cert.serial_number))
current_app.logger.debug(
"No OCSP URL in certificate {}".format(cert.serial_number)
)
return None
p2 = subprocess.Popen(['openssl', 'ocsp', '-issuer', issuer_chain_path,
'-cert', cert_path, "-url", url.strip()],
p2 = subprocess.Popen(
[
"openssl",
"ocsp",
"-issuer",
issuer_chain_path,
"-cert",
cert_path,
"-url",
url.strip(),
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stderr=subprocess.PIPE,
)
message, err = p2.communicate()
p_message = message.decode('utf-8')
p_message = message.decode("utf-8")
if 'error' in p_message or 'Error' in p_message:
if "error" in p_message or "Error" in p_message:
raise Exception("Got error when parsing OCSP url")
elif 'revoked' in p_message:
current_app.logger.debug("OCSP reports certificate revoked: {}".format(cert.serial_number))
elif "revoked" in p_message:
current_app.logger.debug(
"OCSP reports certificate revoked: {}".format(cert.serial_number)
)
return False
elif 'good' not in p_message:
elif "good" not in p_message:
raise Exception("Did not receive a valid response")
return True
@ -73,7 +87,9 @@ def crl_verify(cert, cert_path):
x509.OID_CRL_DISTRIBUTION_POINTS
).value
except x509.ExtensionNotFound:
current_app.logger.debug("No CRLDP extension in certificate {}".format(cert.serial_number))
current_app.logger.debug(
"No CRLDP extension in certificate {}".format(cert.serial_number)
)
return None
for p in distribution_points:
@ -92,8 +108,9 @@ def crl_verify(cert, cert_path):
except ConnectionError:
raise Exception("Unable to retrieve CRL: {0}".format(point))
crl_cache[point] = x509.load_der_x509_crl(response.content,
backend=default_backend())
crl_cache[point] = x509.load_der_x509_crl(
response.content, backend=default_backend()
)
else:
current_app.logger.debug("CRL point is cached {}".format(point))
@ -110,8 +127,9 @@ def crl_verify(cert, cert_path):
except x509.ExtensionNotFound:
pass
current_app.logger.debug("CRL reports certificate "
"revoked: {}".format(cert.serial_number))
current_app.logger.debug(
"CRL reports certificate " "revoked: {}".format(cert.serial_number)
)
return False
return True
@ -125,7 +143,7 @@ def verify(cert_path, issuer_chain_path):
:param issuer_chain_path:
:return: True if valid, False otherwise
"""
with open(cert_path, 'rt') as c:
with open(cert_path, "rt") as c:
try:
cert = parse_certificate(c.read())
except ValueError as e:
@ -154,10 +172,10 @@ def verify_string(cert_string, issuer_string):
:return: True if valid, False otherwise
"""
with mktempfile() as cert_tmp:
with open(cert_tmp, 'w') as f:
with open(cert_tmp, "w") as f:
f.write(cert_string)
with mktempfile() as issuer_tmp:
with open(issuer_tmp, 'w') as f:
with open(issuer_tmp, "w") as f:
f.write(issuer_string)
status = verify(cert_tmp, issuer_tmp)
return status

View File

@ -26,14 +26,14 @@ from lemur.certificates.schemas import (
certificate_upload_input_schema,
certificates_output_schema,
certificate_export_input_schema,
certificate_edit_input_schema
certificate_edit_input_schema,
)
from lemur.roles import service as role_service
from lemur.logs import service as log_service
mod = Blueprint('certificates', __name__)
mod = Blueprint("certificates", __name__)
api = Api(mod)
@ -128,8 +128,8 @@ class CertificatesListValid(AuthenticatedResource):
"""
parser = paginated_parser.copy()
args = parser.parse_args()
args['user'] = g.user
common_name = args['filter'].split(';')[1]
args["user"] = g.user
common_name = args["filter"].split(";")[1]
return service.query_common_name(common_name, args)
@ -228,16 +228,18 @@ class CertificatesNameQuery(AuthenticatedResource):
"""
parser = paginated_parser.copy()
parser.add_argument('timeRange', type=int, dest='time_range', location='args')
parser.add_argument('owner', type=inputs.boolean, location='args')
parser.add_argument('id', type=str, location='args')
parser.add_argument('active', type=inputs.boolean, location='args')
parser.add_argument('destinationId', type=int, dest="destination_id", location='args')
parser.add_argument('creator', type=str, location='args')
parser.add_argument('show', type=str, location='args')
parser.add_argument("timeRange", type=int, dest="time_range", location="args")
parser.add_argument("owner", type=inputs.boolean, location="args")
parser.add_argument("id", type=str, location="args")
parser.add_argument("active", type=inputs.boolean, location="args")
parser.add_argument(
"destinationId", type=int, dest="destination_id", location="args"
)
parser.add_argument("creator", type=str, location="args")
parser.add_argument("show", type=str, location="args")
args = parser.parse_args()
args['user'] = g.user
args["user"] = g.user
return service.query_name(certificate_name, args)
@ -336,16 +338,18 @@ class CertificatesList(AuthenticatedResource):
"""
parser = paginated_parser.copy()
parser.add_argument('timeRange', type=int, dest='time_range', location='args')
parser.add_argument('owner', type=inputs.boolean, location='args')
parser.add_argument('id', type=str, location='args')
parser.add_argument('active', type=inputs.boolean, location='args')
parser.add_argument('destinationId', type=int, dest="destination_id", location='args')
parser.add_argument('creator', type=str, location='args')
parser.add_argument('show', type=str, location='args')
parser.add_argument("timeRange", type=int, dest="time_range", location="args")
parser.add_argument("owner", type=inputs.boolean, location="args")
parser.add_argument("id", type=str, location="args")
parser.add_argument("active", type=inputs.boolean, location="args")
parser.add_argument(
"destinationId", type=int, dest="destination_id", location="args"
)
parser.add_argument("creator", type=str, location="args")
parser.add_argument("show", type=str, location="args")
args = parser.parse_args()
args['user'] = g.user
args["user"] = g.user
return service.render(args)
@validate_schema(certificate_input_schema, certificate_output_schema)
@ -463,24 +467,31 @@ class CertificatesList(AuthenticatedResource):
:statuscode 403: unauthenticated
"""
role = role_service.get_by_name(data['authority'].owner)
role = role_service.get_by_name(data["authority"].owner)
# all the authority role members should be allowed
roles = [x.name for x in data['authority'].roles]
roles = [x.name for x in data["authority"].roles]
# allow "owner" roles by team DL
roles.append(role)
authority_permission = AuthorityPermission(data['authority'].id, roles)
authority_permission = AuthorityPermission(data["authority"].id, roles)
if authority_permission.can():
data['creator'] = g.user
data["creator"] = g.user
cert = service.create(**data)
if isinstance(cert, Certificate):
# only log if created, not pending
log_service.create(g.user, 'create_cert', certificate=cert)
log_service.create(g.user, "create_cert", certificate=cert)
return cert
return dict(message="You are not authorized to use the authority: {0}".format(data['authority'].name)), 403
return (
dict(
message="You are not authorized to use the authority: {0}".format(
data["authority"].name
)
),
403,
)
class CertificatesUpload(AuthenticatedResource):
@ -583,12 +594,14 @@ class CertificatesUpload(AuthenticatedResource):
:statuscode 200: no error
"""
data['creator'] = g.user
if data.get('destinations'):
if data.get('private_key'):
data["creator"] = g.user
if data.get("destinations"):
if data.get("private_key"):
return service.upload(**data)
else:
raise Exception("Private key must be provided in order to upload certificate to AWS")
raise Exception(
"Private key must be provided in order to upload certificate to AWS"
)
return service.upload(**data)
@ -600,10 +613,12 @@ class CertificatesStats(AuthenticatedResource):
super(CertificatesStats, self).__init__()
def get(self):
self.reqparse.add_argument('metric', type=str, location='args')
self.reqparse.add_argument('range', default=32, type=int, location='args')
self.reqparse.add_argument('destinationId', dest='destination_id', location='args')
self.reqparse.add_argument('active', type=str, default='true', location='args')
self.reqparse.add_argument("metric", type=str, location="args")
self.reqparse.add_argument("range", default=32, type=int, location="args")
self.reqparse.add_argument(
"destinationId", dest="destination_id", location="args"
)
self.reqparse.add_argument("active", type=str, default="true", location="args")
args = self.reqparse.parse_args()
@ -655,12 +670,12 @@ class CertificatePrivateKey(AuthenticatedResource):
permission = CertificatePermission(owner_role, [x.name for x in cert.roles])
if not permission.can():
return dict(message='You are not authorized to view this key'), 403
return dict(message="You are not authorized to view this key"), 403
log_service.create(g.current_user, 'key_view', certificate=cert)
log_service.create(g.current_user, "key_view", certificate=cert)
response = make_response(jsonify(key=cert.private_key), 200)
response.headers['cache-control'] = 'private, max-age=0, no-cache, no-store'
response.headers['pragma'] = 'no-cache'
response.headers["cache-control"] = "private, max-age=0, no-cache, no-store"
response.headers["pragma"] = "no-cache"
return response
@ -850,19 +865,25 @@ class Certificates(AuthenticatedResource):
permission = CertificatePermission(owner_role, [x.name for x in cert.roles])
if not permission.can():
return dict(message='You are not authorized to update this certificate'), 403
return (
dict(message="You are not authorized to update this certificate"),
403,
)
for destination in data['destinations']:
for destination in data["destinations"]:
if destination.plugin.requires_key:
if not cert.private_key:
return dict(
message='Unable to add destination: {0}. Certificate does not have required private key.'.format(
return (
dict(
message="Unable to add destination: {0}. Certificate does not have required private key.".format(
destination.label
)
), 400
),
400,
)
cert = service.update(certificate_id, **data)
log_service.create(g.current_user, 'update_cert', certificate=cert)
log_service.create(g.current_user, "update_cert", certificate=cert)
return cert
def delete(self, certificate_id, data=None):
@ -891,7 +912,7 @@ class Certificates(AuthenticatedResource):
:statuscode 405: certificate deletion is disabled
"""
if not current_app.config.get('ALLOW_CERT_DELETION', False):
if not current_app.config.get("ALLOW_CERT_DELETION", False):
return dict(message="Certificate deletion is disabled"), 405
cert = service.get(certificate_id)
@ -908,11 +929,14 @@ class Certificates(AuthenticatedResource):
permission = CertificatePermission(owner_role, [x.name for x in cert.roles])
if not permission.can():
return dict(message='You are not authorized to delete this certificate'), 403
return (
dict(message="You are not authorized to delete this certificate"),
403,
)
service.update(certificate_id, deleted=True)
log_service.create(g.current_user, 'delete_cert', certificate=cert)
return 'Certificate deleted', 204
log_service.create(g.current_user, "delete_cert", certificate=cert)
return "Certificate deleted", 204
class NotificationCertificatesList(AuthenticatedResource):
@ -1012,17 +1036,19 @@ class NotificationCertificatesList(AuthenticatedResource):
"""
parser = paginated_parser.copy()
parser.add_argument('timeRange', type=int, dest='time_range', location='args')
parser.add_argument('owner', type=inputs.boolean, location='args')
parser.add_argument('id', type=str, location='args')
parser.add_argument('active', type=inputs.boolean, location='args')
parser.add_argument('destinationId', type=int, dest="destination_id", location='args')
parser.add_argument('creator', type=str, location='args')
parser.add_argument('show', type=str, location='args')
parser.add_argument("timeRange", type=int, dest="time_range", location="args")
parser.add_argument("owner", type=inputs.boolean, location="args")
parser.add_argument("id", type=str, location="args")
parser.add_argument("active", type=inputs.boolean, location="args")
parser.add_argument(
"destinationId", type=int, dest="destination_id", location="args"
)
parser.add_argument("creator", type=str, location="args")
parser.add_argument("show", type=str, location="args")
args = parser.parse_args()
args['notification_id'] = notification_id
args['user'] = g.current_user
args["notification_id"] = notification_id
args["user"] = g.current_user
return service.render(args)
@ -1195,30 +1221,48 @@ class CertificateExport(AuthenticatedResource):
if not cert:
return dict(message="Cannot find specified certificate"), 404
plugin = data['plugin']['plugin_object']
plugin = data["plugin"]["plugin_object"]
if plugin.requires_key:
if not cert.private_key:
return dict(
message='Unable to export certificate, plugin: {0} requires a private key but no key was found.'.format(
plugin.slug)), 400
return (
dict(
message="Unable to export certificate, plugin: {0} requires a private key but no key was found.".format(
plugin.slug
)
),
400,
)
else:
# allow creators
if g.current_user != cert.user:
owner_role = role_service.get_by_name(cert.owner)
permission = CertificatePermission(owner_role, [x.name for x in cert.roles])
permission = CertificatePermission(
owner_role, [x.name for x in cert.roles]
)
if not permission.can():
return dict(message='You are not authorized to export this certificate.'), 403
return (
dict(
message="You are not authorized to export this certificate."
),
403,
)
options = data['plugin']['plugin_options']
options = data["plugin"]["plugin_options"]
log_service.create(g.current_user, 'key_view', certificate=cert)
extension, passphrase, data = plugin.export(cert.body, cert.chain, cert.private_key, options)
log_service.create(g.current_user, "key_view", certificate=cert)
extension, passphrase, data = plugin.export(
cert.body, cert.chain, cert.private_key, options
)
# we take a hit in message size when b64 encoding
return dict(extension=extension, passphrase=passphrase, data=base64.b64encode(data).decode('utf-8'))
return dict(
extension=extension,
passphrase=passphrase,
data=base64.b64encode(data).decode("utf-8"),
)
class CertificateRevoke(AuthenticatedResource):
@ -1269,30 +1313,66 @@ class CertificateRevoke(AuthenticatedResource):
permission = CertificatePermission(owner_role, [x.name for x in cert.roles])
if not permission.can():
return dict(message='You are not authorized to revoke this certificate.'), 403
return (
dict(message="You are not authorized to revoke this certificate."),
403,
)
if not cert.external_id:
return dict(message='Cannot revoke certificate. No external id found.'), 400
return dict(message="Cannot revoke certificate. No external id found."), 400
if cert.endpoints:
return dict(message='Cannot revoke certificate. Endpoints are deployed with the given certificate.'), 403
return (
dict(
message="Cannot revoke certificate. Endpoints are deployed with the given certificate."
),
403,
)
plugin = plugins.get(cert.authority.plugin_name)
plugin.revoke_certificate(cert, data)
log_service.create(g.current_user, 'revoke_cert', certificate=cert)
log_service.create(g.current_user, "revoke_cert", certificate=cert)
return dict(id=cert.id)
api.add_resource(CertificateRevoke, '/certificates/<int:certificate_id>/revoke', endpoint='revokeCertificate')
api.add_resource(CertificatesNameQuery, '/certificates/name/<string:certificate_name>', endpoint='certificatesNameQuery')
api.add_resource(CertificatesList, '/certificates', endpoint='certificates')
api.add_resource(CertificatesListValid, '/certificates/valid', endpoint='certificatesListValid')
api.add_resource(Certificates, '/certificates/<int:certificate_id>', endpoint='certificate')
api.add_resource(CertificatesStats, '/certificates/stats', endpoint='certificateStats')
api.add_resource(CertificatesUpload, '/certificates/upload', endpoint='certificateUpload')
api.add_resource(CertificatePrivateKey, '/certificates/<int:certificate_id>/key', endpoint='privateKeyCertificates')
api.add_resource(CertificateExport, '/certificates/<int:certificate_id>/export', endpoint='exportCertificate')
api.add_resource(NotificationCertificatesList, '/notifications/<int:notification_id>/certificates',
endpoint='notificationCertificates')
api.add_resource(CertificatesReplacementsList, '/certificates/<int:certificate_id>/replacements',
endpoint='replacements')
api.add_resource(
CertificateRevoke,
"/certificates/<int:certificate_id>/revoke",
endpoint="revokeCertificate",
)
api.add_resource(
CertificatesNameQuery,
"/certificates/name/<string:certificate_name>",
endpoint="certificatesNameQuery",
)
api.add_resource(CertificatesList, "/certificates", endpoint="certificates")
api.add_resource(
CertificatesListValid, "/certificates/valid", endpoint="certificatesListValid"
)
api.add_resource(
Certificates, "/certificates/<int:certificate_id>", endpoint="certificate"
)
api.add_resource(CertificatesStats, "/certificates/stats", endpoint="certificateStats")
api.add_resource(
CertificatesUpload, "/certificates/upload", endpoint="certificateUpload"
)
api.add_resource(
CertificatePrivateKey,
"/certificates/<int:certificate_id>/key",
endpoint="privateKeyCertificates",
)
api.add_resource(
CertificateExport,
"/certificates/<int:certificate_id>/export",
endpoint="exportCertificate",
)
api.add_resource(
NotificationCertificatesList,
"/notifications/<int:notification_id>/certificates",
endpoint="notificationCertificates",
)
api.add_resource(
CertificatesReplacementsList,
"/certificates/<int:certificate_id>/replacements",
endpoint="replacements",
)

View File

@ -32,8 +32,11 @@ else:
def make_celery(app):
celery = Celery(app.import_name, backend=app.config.get('CELERY_RESULT_BACKEND'),
broker=app.config.get('CELERY_BROKER_URL'))
celery = Celery(
app.import_name,
backend=app.config.get("CELERY_RESULT_BACKEND"),
broker=app.config.get("CELERY_BROKER_URL"),
)
celery.conf.update(app.config)
TaskBase = celery.Task
@ -53,6 +56,7 @@ celery = make_celery(flask_app)
def is_task_active(fun, task_id, args):
from celery.task.control import inspect
i = inspect()
active_tasks = i.active()
for _, tasks in active_tasks.items():
@ -99,7 +103,7 @@ def fetch_acme_cert(id):
# We only care about certs using the acme-issuer plugin
for cert in pending_certs:
cert_authority = get_authority(cert.authority_id)
if cert_authority.plugin_name == 'acme-issuer':
if cert_authority.plugin_name == "acme-issuer":
acme_certs.append(cert)
else:
wrong_issuer += 1
@ -112,20 +116,22 @@ def fetch_acme_cert(id):
# It's necessary to reload the pending cert due to detached instance: http://sqlalche.me/e/bhk3
pending_cert = pending_certificate_service.get(cert.get("pending_cert").id)
if not pending_cert:
log_data["message"] = "Pending certificate doesn't exist anymore. Was it resolved by another process?"
log_data[
"message"
] = "Pending certificate doesn't exist anymore. Was it resolved by another process?"
current_app.logger.error(log_data)
continue
if real_cert:
# If a real certificate was returned from issuer, then create it in Lemur and mark
# the pending certificate as resolved
final_cert = pending_certificate_service.create_certificate(pending_cert, real_cert, pending_cert.user)
pending_certificate_service.update(
cert.get("pending_cert").id,
resolved_cert_id=final_cert.id
final_cert = pending_certificate_service.create_certificate(
pending_cert, real_cert, pending_cert.user
)
pending_certificate_service.update(
cert.get("pending_cert").id,
resolved=True
cert.get("pending_cert").id, resolved_cert_id=final_cert.id
)
pending_certificate_service.update(
cert.get("pending_cert").id, resolved=True
)
# add metrics to metrics extension
new += 1
@ -139,17 +145,17 @@ def fetch_acme_cert(id):
if pending_cert.number_attempts > 4:
error_log["message"] = "Deleting pending certificate"
send_pending_failure_notification(pending_cert, notify_owner=pending_cert.notify)
send_pending_failure_notification(
pending_cert, notify_owner=pending_cert.notify
)
# Mark the pending cert as resolved
pending_certificate_service.update(
cert.get("pending_cert").id,
resolved=True
cert.get("pending_cert").id, resolved=True
)
else:
pending_certificate_service.increment_attempt(pending_cert)
pending_certificate_service.update(
cert.get("pending_cert").id,
status=str(cert.get("last_error"))
cert.get("pending_cert").id, status=str(cert.get("last_error"))
)
# Add failed pending cert task back to queue
fetch_acme_cert.delay(id)
@ -161,9 +167,7 @@ def fetch_acme_cert(id):
current_app.logger.debug(log_data)
print(
"[+] Certificates: New: {new} Failed: {failed} Not using ACME: {wrong_issuer}".format(
new=new,
failed=failed,
wrong_issuer=wrong_issuer
new=new, failed=failed, wrong_issuer=wrong_issuer
)
)
@ -175,7 +179,7 @@ def fetch_all_pending_acme_certs():
log_data = {
"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name),
"message": "Starting job."
"message": "Starting job.",
}
current_app.logger.debug(log_data)
@ -183,7 +187,7 @@ def fetch_all_pending_acme_certs():
# We only care about certs using the acme-issuer plugin
for cert in pending_certs:
cert_authority = get_authority(cert.authority_id)
if cert_authority.plugin_name == 'acme-issuer':
if cert_authority.plugin_name == "acme-issuer":
if datetime.now(timezone.utc) - cert.last_updated > timedelta(minutes=5):
log_data["message"] = "Triggering job for cert {}".format(cert.name)
log_data["cert_name"] = cert.name
@ -195,17 +199,15 @@ def fetch_all_pending_acme_certs():
@celery.task()
def remove_old_acme_certs():
"""Prune old pending acme certificates from the database"""
log_data = {
"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name)
}
pending_certs = pending_certificate_service.get_pending_certs('all')
log_data = {"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name)}
pending_certs = pending_certificate_service.get_pending_certs("all")
# Delete pending certs more than a week old
for cert in pending_certs:
if datetime.now(timezone.utc) - cert.last_updated > timedelta(days=7):
log_data['pending_cert_id'] = cert.id
log_data['pending_cert_name'] = cert.name
log_data['message'] = "Deleting pending certificate"
log_data["pending_cert_id"] = cert.id
log_data["pending_cert_name"] = cert.name
log_data["message"] = "Deleting pending certificate"
current_app.logger.debug(log_data)
pending_certificate_service.delete(cert)
@ -218,7 +220,9 @@ def clean_all_sources():
"""
sources = validate_sources("all")
for source in sources:
current_app.logger.debug("Creating celery task to clean source {}".format(source.label))
current_app.logger.debug(
"Creating celery task to clean source {}".format(source.label)
)
clean_source.delay(source.label)
@ -242,7 +246,9 @@ def sync_all_sources():
"""
sources = validate_sources("all")
for source in sources:
current_app.logger.debug("Creating celery task to sync source {}".format(source.label))
current_app.logger.debug(
"Creating celery task to sync source {}".format(source.label)
)
sync_source.delay(source.label)
@ -277,7 +283,9 @@ def sync_source(source):
log_data["message"] = "Error syncing source: Time limit exceeded."
current_app.logger.error(log_data)
sentry.captureException()
metrics.send('sync_source_timeout', 'counter', 1, metric_tags={'source': source})
metrics.send(
"sync_source_timeout", "counter", 1, metric_tags={"source": source}
)
return
log_data["message"] = "Done syncing source"

View File

@ -9,18 +9,20 @@ from lemur.extensions import sentry
from lemur.constants import SAN_NAMING_TEMPLATE, DEFAULT_NAMING_TEMPLATE
def text_to_slug(value, joiner='-'):
def text_to_slug(value, joiner="-"):
"""
Normalize a string to a "slug" value, stripping character accents and removing non-alphanum characters.
A series of non-alphanumeric characters is replaced with the joiner character.
"""
# Strip all character accents: decompose Unicode characters and then drop combining chars.
value = ''.join(c for c in unicodedata.normalize('NFKD', value) if not unicodedata.combining(c))
value = "".join(
c for c in unicodedata.normalize("NFKD", value) if not unicodedata.combining(c)
)
# Replace all remaining non-alphanumeric characters with joiner string. Multiple characters get collapsed into a
# single joiner. Except, keep 'xn--' used in IDNA domain names as is.
value = re.sub(r'[^A-Za-z0-9.]+(?<!xn--)', joiner, value)
value = re.sub(r"[^A-Za-z0-9.]+(?<!xn--)", joiner, value)
# '-' in the beginning or end of string looks ugly.
return value.strip(joiner)
@ -48,12 +50,12 @@ def certificate_name(common_name, issuer, not_before, not_after, san):
temp = t.format(
subject=common_name,
issuer=issuer.replace(' ', ''),
not_before=not_before.strftime('%Y%m%d'),
not_after=not_after.strftime('%Y%m%d')
issuer=issuer.replace(" ", ""),
not_before=not_before.strftime("%Y%m%d"),
not_after=not_after.strftime("%Y%m%d"),
)
temp = temp.replace('*', "WILDCARD")
temp = temp.replace("*", "WILDCARD")
return text_to_slug(temp)
@ -69,9 +71,9 @@ def common_name(cert):
:return: Common name or None
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_COMMON_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_COMMON_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get common name! {0}".format(e))
@ -84,9 +86,9 @@ def organization(cert):
:return:
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_ORGANIZATION_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get organization! {0}".format(e))
@ -99,9 +101,9 @@ def organizational_unit(cert):
:return:
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_ORGANIZATIONAL_UNIT_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_ORGANIZATIONAL_UNIT_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get organizational unit! {0}".format(e))
@ -114,9 +116,9 @@ def country(cert):
:return:
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_COUNTRY_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_COUNTRY_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get country! {0}".format(e))
@ -129,9 +131,9 @@ def state(cert):
:return:
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_STATE_OR_PROVINCE_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_STATE_OR_PROVINCE_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get state! {0}".format(e))
@ -144,9 +146,9 @@ def location(cert):
:return:
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_LOCALITY_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_LOCALITY_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get location! {0}".format(e))
@ -224,7 +226,7 @@ def bitstrength(cert):
return cert.public_key().key_size
except AttributeError:
sentry.captureException()
current_app.logger.debug('Unable to get bitstrength.')
current_app.logger.debug("Unable to get bitstrength.")
def issuer(cert):
@ -239,16 +241,19 @@ def issuer(cert):
"""
# If certificate is self-signed, we return a special value -- there really is no distinct "issuer" for it
if is_selfsigned(cert):
return '<selfsigned>'
return "<selfsigned>"
# Try Common Name or fall back to Organization name
attrs = (cert.issuer.get_attributes_for_oid(x509.OID_COMMON_NAME) or
cert.issuer.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME))
attrs = cert.issuer.get_attributes_for_oid(
x509.OID_COMMON_NAME
) or cert.issuer.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME)
if not attrs:
current_app.logger.error("Unable to get issuer! Cert serial {:x}".format(cert.serial_number))
return '<unknown>'
current_app.logger.error(
"Unable to get issuer! Cert serial {:x}".format(cert.serial_number)
)
return "<unknown>"
return text_to_slug(attrs[0].value, '')
return text_to_slug(attrs[0].value, "")
def not_before(cert):

View File

@ -25,6 +25,7 @@ class Hex(Field):
"""
A hex formatted string.
"""
def _serialize(self, value, attr, obj):
if value:
value = hex(int(value))[2:].upper()
@ -48,25 +49,25 @@ class ArrowDateTime(Field):
"""
DATEFORMAT_SERIALIZATION_FUNCS = {
'iso': utils.isoformat,
'iso8601': utils.isoformat,
'rfc': utils.rfcformat,
'rfc822': utils.rfcformat,
"iso": utils.isoformat,
"iso8601": utils.isoformat,
"rfc": utils.rfcformat,
"rfc822": utils.rfcformat,
}
DATEFORMAT_DESERIALIZATION_FUNCS = {
'iso': utils.from_iso,
'iso8601': utils.from_iso,
'rfc': utils.from_rfc,
'rfc822': utils.from_rfc,
"iso": utils.from_iso,
"iso8601": utils.from_iso,
"rfc": utils.from_rfc,
"rfc822": utils.from_rfc,
}
DEFAULT_FORMAT = 'iso'
DEFAULT_FORMAT = "iso"
localtime = False
default_error_messages = {
'invalid': 'Not a valid datetime.',
'format': '"{input}" cannot be formatted as a datetime.',
"invalid": "Not a valid datetime.",
"format": '"{input}" cannot be formatted as a datetime.',
}
def __init__(self, format=None, **kwargs):
@ -89,34 +90,36 @@ class ArrowDateTime(Field):
try:
return format_func(value, localtime=self.localtime)
except (AttributeError, ValueError) as err:
self.fail('format', input=value)
self.fail("format", input=value)
else:
return value.strftime(self.dateformat)
def _deserialize(self, value, attr, data):
if not value: # Falsy values, e.g. '', None, [] are not valid
raise self.fail('invalid')
raise self.fail("invalid")
self.dateformat = self.dateformat or self.DEFAULT_FORMAT
func = self.DATEFORMAT_DESERIALIZATION_FUNCS.get(self.dateformat)
if func:
try:
return arrow.get(func(value))
except (TypeError, AttributeError, ValueError):
raise self.fail('invalid')
raise self.fail("invalid")
elif self.dateformat:
try:
return dt.datetime.strptime(value, self.dateformat)
except (TypeError, AttributeError, ValueError):
raise self.fail('invalid')
raise self.fail("invalid")
elif utils.dateutil_available:
try:
return arrow.get(utils.from_datestring(value))
except TypeError:
raise self.fail('invalid')
raise self.fail("invalid")
else:
warnings.warn('It is recommended that you install python-dateutil '
'for improved datetime deserialization.')
raise self.fail('invalid')
warnings.warn(
"It is recommended that you install python-dateutil "
"for improved datetime deserialization."
)
raise self.fail("invalid")
class KeyUsageExtension(Field):
@ -131,73 +134,75 @@ class KeyUsageExtension(Field):
def _serialize(self, value, attr, obj):
return {
'useDigitalSignature': value.digital_signature,
'useNonRepudiation': value.content_commitment,
'useKeyEncipherment': value.key_encipherment,
'useDataEncipherment': value.data_encipherment,
'useKeyAgreement': value.key_agreement,
'useKeyCertSign': value.key_cert_sign,
'useCRLSign': value.crl_sign,
'useEncipherOnly': value._encipher_only,
'useDecipherOnly': value._decipher_only
"useDigitalSignature": value.digital_signature,
"useNonRepudiation": value.content_commitment,
"useKeyEncipherment": value.key_encipherment,
"useDataEncipherment": value.data_encipherment,
"useKeyAgreement": value.key_agreement,
"useKeyCertSign": value.key_cert_sign,
"useCRLSign": value.crl_sign,
"useEncipherOnly": value._encipher_only,
"useDecipherOnly": value._decipher_only,
}
def _deserialize(self, value, attr, data):
keyusages = {
'digital_signature': False,
'content_commitment': False,
'key_encipherment': False,
'data_encipherment': False,
'key_agreement': False,
'key_cert_sign': False,
'crl_sign': False,
'encipher_only': False,
'decipher_only': False
"digital_signature": False,
"content_commitment": False,
"key_encipherment": False,
"data_encipherment": False,
"key_agreement": False,
"key_cert_sign": False,
"crl_sign": False,
"encipher_only": False,
"decipher_only": False,
}
for k, v in value.items():
if k == 'useDigitalSignature':
keyusages['digital_signature'] = v
if k == "useDigitalSignature":
keyusages["digital_signature"] = v
elif k == 'useNonRepudiation':
keyusages['content_commitment'] = v
elif k == "useNonRepudiation":
keyusages["content_commitment"] = v
elif k == 'useKeyEncipherment':
keyusages['key_encipherment'] = v
elif k == "useKeyEncipherment":
keyusages["key_encipherment"] = v
elif k == 'useDataEncipherment':
keyusages['data_encipherment'] = v
elif k == "useDataEncipherment":
keyusages["data_encipherment"] = v
elif k == 'useKeyCertSign':
keyusages['key_cert_sign'] = v
elif k == "useKeyCertSign":
keyusages["key_cert_sign"] = v
elif k == 'useCRLSign':
keyusages['crl_sign'] = v
elif k == "useCRLSign":
keyusages["crl_sign"] = v
elif k == 'useKeyAgreement':
keyusages['key_agreement'] = v
elif k == "useKeyAgreement":
keyusages["key_agreement"] = v
elif k == 'useEncipherOnly' and v:
keyusages['encipher_only'] = True
keyusages['key_agreement'] = True
elif k == "useEncipherOnly" and v:
keyusages["encipher_only"] = True
keyusages["key_agreement"] = True
elif k == 'useDecipherOnly' and v:
keyusages['decipher_only'] = True
keyusages['key_agreement'] = True
elif k == "useDecipherOnly" and v:
keyusages["decipher_only"] = True
keyusages["key_agreement"] = True
if keyusages['encipher_only'] and keyusages['decipher_only']:
raise ValidationError('A certificate cannot have both Encipher Only and Decipher Only Extended Key Usages.')
if keyusages["encipher_only"] and keyusages["decipher_only"]:
raise ValidationError(
"A certificate cannot have both Encipher Only and Decipher Only Extended Key Usages."
)
return x509.KeyUsage(
digital_signature=keyusages['digital_signature'],
content_commitment=keyusages['content_commitment'],
key_encipherment=keyusages['key_encipherment'],
data_encipherment=keyusages['data_encipherment'],
key_agreement=keyusages['key_agreement'],
key_cert_sign=keyusages['key_cert_sign'],
crl_sign=keyusages['crl_sign'],
encipher_only=keyusages['encipher_only'],
decipher_only=keyusages['decipher_only']
digital_signature=keyusages["digital_signature"],
content_commitment=keyusages["content_commitment"],
key_encipherment=keyusages["key_encipherment"],
data_encipherment=keyusages["data_encipherment"],
key_agreement=keyusages["key_agreement"],
key_cert_sign=keyusages["key_cert_sign"],
crl_sign=keyusages["crl_sign"],
encipher_only=keyusages["encipher_only"],
decipher_only=keyusages["decipher_only"],
)
@ -216,69 +221,77 @@ class ExtendedKeyUsageExtension(Field):
usage_list = {}
for usage in usages:
if usage == x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH:
usage_list['useClientAuthentication'] = True
usage_list["useClientAuthentication"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.SERVER_AUTH:
usage_list['useServerAuthentication'] = True
usage_list["useServerAuthentication"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.CODE_SIGNING:
usage_list['useCodeSigning'] = True
usage_list["useCodeSigning"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION:
usage_list['useEmailProtection'] = True
usage_list["useEmailProtection"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.TIME_STAMPING:
usage_list['useTimestamping'] = True
usage_list["useTimestamping"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING:
usage_list['useOCSPSigning'] = True
usage_list["useOCSPSigning"] = True
elif usage.dotted_string == '1.3.6.1.5.5.7.3.14':
usage_list['useEapOverLAN'] = True
elif usage.dotted_string == "1.3.6.1.5.5.7.3.14":
usage_list["useEapOverLAN"] = True
elif usage.dotted_string == '1.3.6.1.5.5.7.3.13':
usage_list['useEapOverPPP'] = True
elif usage.dotted_string == "1.3.6.1.5.5.7.3.13":
usage_list["useEapOverPPP"] = True
elif usage.dotted_string == '1.3.6.1.4.1.311.20.2.2':
usage_list['useSmartCardLogon'] = True
elif usage.dotted_string == "1.3.6.1.4.1.311.20.2.2":
usage_list["useSmartCardLogon"] = True
else:
current_app.logger.warning('Unable to serialize ExtendedKeyUsage with OID: {usage}'.format(usage=usage.dotted_string))
current_app.logger.warning(
"Unable to serialize ExtendedKeyUsage with OID: {usage}".format(
usage=usage.dotted_string
)
)
return usage_list
def _deserialize(self, value, attr, data):
usage_oids = []
for k, v in value.items():
if k == 'useClientAuthentication' and v:
if k == "useClientAuthentication" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH)
elif k == 'useServerAuthentication' and v:
elif k == "useServerAuthentication" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.SERVER_AUTH)
elif k == 'useCodeSigning' and v:
elif k == "useCodeSigning" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.CODE_SIGNING)
elif k == 'useEmailProtection' and v:
elif k == "useEmailProtection" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION)
elif k == 'useTimestamping' and v:
elif k == "useTimestamping" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.TIME_STAMPING)
elif k == 'useOCSPSigning' and v:
elif k == "useOCSPSigning" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING)
elif k == 'useEapOverLAN' and v:
elif k == "useEapOverLAN" and v:
usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.14"))
elif k == 'useEapOverPPP' and v:
elif k == "useEapOverPPP" and v:
usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.13"))
elif k == 'useSmartCardLogon' and v:
elif k == "useSmartCardLogon" and v:
usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.4.1.311.20.2.2"))
else:
current_app.logger.warning('Unable to deserialize ExtendedKeyUsage with name: {key}'.format(key=k))
current_app.logger.warning(
"Unable to deserialize ExtendedKeyUsage with name: {key}".format(
key=k
)
)
return x509.ExtendedKeyUsage(usage_oids)
@ -294,15 +307,17 @@ class BasicConstraintsExtension(Field):
"""
def _serialize(self, value, attr, obj):
return {'ca': value.ca, 'path_length': value.path_length}
return {"ca": value.ca, "path_length": value.path_length}
def _deserialize(self, value, attr, data):
ca = value.get('ca', False)
path_length = value.get('path_length', None)
ca = value.get("ca", False)
path_length = value.get("path_length", None)
if ca:
if not isinstance(path_length, (type(None), int)):
raise ValidationError('A CA certificate path_length (for BasicConstraints) must be None or an integer.')
raise ValidationError(
"A CA certificate path_length (for BasicConstraints) must be None or an integer."
)
return x509.BasicConstraints(ca=True, path_length=path_length)
else:
return x509.BasicConstraints(ca=False, path_length=None)
@ -317,6 +332,7 @@ class SubjectAlternativeNameExtension(Field):
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""
def _serialize(self, value, attr, obj):
general_names = []
name_type = None
@ -326,53 +342,59 @@ class SubjectAlternativeNameExtension(Field):
value = name.value
if isinstance(name, x509.DNSName):
name_type = 'DNSName'
name_type = "DNSName"
elif isinstance(name, x509.IPAddress):
if isinstance(value, ipaddress.IPv4Network):
name_type = 'IPNetwork'
name_type = "IPNetwork"
else:
name_type = 'IPAddress'
name_type = "IPAddress"
value = str(value)
elif isinstance(name, x509.UniformResourceIdentifier):
name_type = 'uniformResourceIdentifier'
name_type = "uniformResourceIdentifier"
elif isinstance(name, x509.DirectoryName):
name_type = 'directoryName'
name_type = "directoryName"
elif isinstance(name, x509.RFC822Name):
name_type = 'rfc822Name'
name_type = "rfc822Name"
elif isinstance(name, x509.RegisteredID):
name_type = 'registeredID'
name_type = "registeredID"
value = value.dotted_string
else:
current_app.logger.warning('Unknown SubAltName type: {name}'.format(name=name))
current_app.logger.warning(
"Unknown SubAltName type: {name}".format(name=name)
)
continue
general_names.append({'nameType': name_type, 'value': value})
general_names.append({"nameType": name_type, "value": value})
return general_names
def _deserialize(self, value, attr, data):
general_names = []
for name in value:
if name['nameType'] == 'DNSName':
validators.sensitive_domain(name['value'])
general_names.append(x509.DNSName(name['value']))
if name["nameType"] == "DNSName":
validators.sensitive_domain(name["value"])
general_names.append(x509.DNSName(name["value"]))
elif name['nameType'] == 'IPAddress':
general_names.append(x509.IPAddress(ipaddress.ip_address(name['value'])))
elif name["nameType"] == "IPAddress":
general_names.append(
x509.IPAddress(ipaddress.ip_address(name["value"]))
)
elif name['nameType'] == 'IPNetwork':
general_names.append(x509.IPAddress(ipaddress.ip_network(name['value'])))
elif name["nameType"] == "IPNetwork":
general_names.append(
x509.IPAddress(ipaddress.ip_network(name["value"]))
)
elif name['nameType'] == 'uniformResourceIdentifier':
general_names.append(x509.UniformResourceIdentifier(name['value']))
elif name["nameType"] == "uniformResourceIdentifier":
general_names.append(x509.UniformResourceIdentifier(name["value"]))
elif name['nameType'] == 'directoryName':
elif name["nameType"] == "directoryName":
# TODO: Need to parse a string in name['value'] like:
# 'CN=Common Name, O=Org Name, OU=OrgUnit Name, C=US, ST=ST, L=City/emailAddress=person@example.com'
# or
@ -390,26 +412,32 @@ class SubjectAlternativeNameExtension(Field):
# general_names.append(x509.DirectoryName(x509.Name(BLAH))))
pass
elif name['nameType'] == 'rfc822Name':
general_names.append(x509.RFC822Name(name['value']))
elif name["nameType"] == "rfc822Name":
general_names.append(x509.RFC822Name(name["value"]))
elif name['nameType'] == 'registeredID':
general_names.append(x509.RegisteredID(x509.ObjectIdentifier(name['value'])))
elif name["nameType"] == "registeredID":
general_names.append(
x509.RegisteredID(x509.ObjectIdentifier(name["value"]))
)
elif name['nameType'] == 'otherName':
elif name["nameType"] == "otherName":
# This has two inputs (type and value), so it doesn't fit the mold of the rest of these GeneralName entities.
# general_names.append(x509.OtherName(name['type'], bytes(name['value']), 'utf-8'))
pass
elif name['nameType'] == 'x400Address':
elif name["nameType"] == "x400Address":
# The Python Cryptography library doesn't support x400Address types (yet?)
pass
elif name['nameType'] == 'EDIPartyName':
elif name["nameType"] == "EDIPartyName":
# The Python Cryptography library doesn't support EDIPartyName types (yet?)
pass
else:
current_app.logger.warning('Unable to deserialize SubAltName with type: {name_type}'.format(name_type=name['nameType']))
current_app.logger.warning(
"Unable to deserialize SubAltName with type: {name_type}".format(
name_type=name["nameType"]
)
)
return x509.SubjectAlternativeName(general_names)

View File

@ -10,20 +10,20 @@ from flask import Blueprint
from lemur.database import db
from lemur.extensions import sentry
mod = Blueprint('healthCheck', __name__)
mod = Blueprint("healthCheck", __name__)
@mod.route('/healthcheck')
@mod.route("/healthcheck")
def health():
try:
if healthcheck(db):
return 'ok'
return "ok"
except Exception:
sentry.captureException()
return 'db check failed'
return "db check failed"
def healthcheck(db):
with db.engine.connect() as connection:
connection.execute('SELECT 1;')
connection.execute("SELECT 1;")
return True

View File

@ -52,7 +52,7 @@ class InstanceManager(object):
results = []
for cls_path in class_list:
module_name, class_name = cls_path.rsplit('.', 1)
module_name, class_name = cls_path.rsplit(".", 1)
try:
module = __import__(module_name, {}, {}, class_name)
cls = getattr(module, class_name)
@ -62,10 +62,14 @@ class InstanceManager(object):
results.append(cls)
except InvalidConfiguration as e:
current_app.logger.warning("Plugin '{0}' may not work correctly. {1}".format(class_name, e))
current_app.logger.warning(
"Plugin '{0}' may not work correctly. {1}".format(class_name, e)
)
except Exception as e:
current_app.logger.exception("Unable to import {0}. Reason: {1}".format(cls_path, e))
current_app.logger.exception(
"Unable to import {0}. Reason: {1}".format(cls_path, e)
)
continue
self.cache = results

View File

@ -11,15 +11,15 @@ def convert_validity_years(data):
:param data:
:return:
"""
if data.get('validity_years'):
if data.get("validity_years"):
now = arrow.utcnow()
data['validity_start'] = now.isoformat()
data["validity_start"] = now.isoformat()
end = now.replace(years=+int(data['validity_years']))
end = now.replace(years=+int(data["validity_years"]))
if not current_app.config.get('LEMUR_ALLOW_WEEKEND_EXPIRATION', True):
if not current_app.config.get("LEMUR_ALLOW_WEEKEND_EXPIRATION", True):
if is_weekend(end):
end = end.replace(days=-2)
data['validity_end'] = end.isoformat()
data["validity_end"] = end.isoformat()
return data

View File

@ -22,27 +22,26 @@ class LemurSchema(Schema):
"""
Base schema from which all grouper schema's inherit
"""
__envelope__ = True
def under(self, data, many=None):
items = []
if many:
for i in data:
items.append(
{underscore(key): value for key, value in i.items()}
)
items.append({underscore(key): value for key, value in i.items()})
return items
return {
underscore(key): value
for key, value in data.items()
}
return {underscore(key): value for key, value in data.items()}
def camel(self, data, many=None):
items = []
if many:
for i in data:
items.append(
{camelize(key, uppercase_first_letter=False): value for key, value in i.items()}
{
camelize(key, uppercase_first_letter=False): value
for key, value in i.items()
}
)
return items
return {
@ -52,16 +51,16 @@ class LemurSchema(Schema):
def wrap_with_envelope(self, data, many):
if many:
if 'total' in self.context.keys():
return dict(total=self.context['total'], items=data)
if "total" in self.context.keys():
return dict(total=self.context["total"], items=data)
return data
class LemurInputSchema(LemurSchema):
@pre_load(pass_many=True)
def preprocess(self, data, many):
if isinstance(data, dict) and data.get('owner'):
data['owner'] = data['owner'].lower()
if isinstance(data, dict) and data.get("owner"):
data["owner"] = data["owner"].lower()
return self.under(data, many=many)
@ -74,17 +73,17 @@ class LemurOutputSchema(LemurSchema):
def unwrap_envelope(self, data, many):
if many:
if data['items']:
if data["items"]:
if isinstance(data, InstrumentedList) or isinstance(data, list):
self.context['total'] = len(data)
self.context["total"] = len(data)
return data
else:
self.context['total'] = data['total']
self.context["total"] = data["total"]
else:
self.context['total'] = 0
data = {'items': []}
self.context["total"] = 0
data = {"items": []}
return data['items']
return data["items"]
return data
@ -110,11 +109,11 @@ def format_errors(messages):
def wrap_errors(messages):
errors = dict(message='Validation Error.')
if messages.get('_schema'):
errors['reasons'] = {'Schema': {'rule': messages['_schema']}}
errors = dict(message="Validation Error.")
if messages.get("_schema"):
errors["reasons"] = {"Schema": {"rule": messages["_schema"]}}
else:
errors['reasons'] = format_errors(messages)
errors["reasons"] = format_errors(messages)
return errors
@ -123,19 +122,19 @@ def unwrap_pagination(data, output_schema):
return data
if isinstance(data, dict):
if 'total' in data.keys():
if data.get('total') == 0:
if "total" in data.keys():
if data.get("total") == 0:
return data
marshaled_data = {'total': data['total']}
marshaled_data['items'] = output_schema.dump(data['items'], many=True).data
marshaled_data = {"total": data["total"]}
marshaled_data["items"] = output_schema.dump(data["items"], many=True).data
return marshaled_data
return output_schema.dump(data).data
elif isinstance(data, list):
marshaled_data = {'total': len(data)}
marshaled_data['items'] = output_schema.dump(data, many=True).data
marshaled_data = {"total": len(data)}
marshaled_data["items"] = output_schema.dump(data, many=True).data
return marshaled_data
return output_schema.dump(data).data
@ -155,7 +154,7 @@ def validate_schema(input_schema, output_schema):
if errors:
return wrap_errors(errors), 400
kwargs['data'] = data
kwargs["data"] = data
try:
resp = f(*args, **kwargs)
@ -173,4 +172,5 @@ def validate_schema(input_schema, output_schema):
return unwrap_pagination(resp, output_schema), 200
return decorated_function
return decorator

View File

@ -25,22 +25,22 @@ from lemur.exceptions import InvalidConfiguration
paginated_parser = RequestParser()
paginated_parser.add_argument('count', type=int, default=10, location='args')
paginated_parser.add_argument('page', type=int, default=1, location='args')
paginated_parser.add_argument('sortDir', type=str, dest='sort_dir', location='args')
paginated_parser.add_argument('sortBy', type=str, dest='sort_by', location='args')
paginated_parser.add_argument('filter', type=str, location='args')
paginated_parser.add_argument('owner', type=str, location='args')
paginated_parser.add_argument("count", type=int, default=10, location="args")
paginated_parser.add_argument("page", type=int, default=1, location="args")
paginated_parser.add_argument("sortDir", type=str, dest="sort_dir", location="args")
paginated_parser.add_argument("sortBy", type=str, dest="sort_by", location="args")
paginated_parser.add_argument("filter", type=str, location="args")
paginated_parser.add_argument("owner", type=str, location="args")
def get_psuedo_random_string():
"""
Create a random and strongish challenge.
"""
challenge = ''.join(random.choice(string.ascii_uppercase) for x in range(6)) # noqa
challenge += ''.join(random.choice("~!@#$%^&*()_+") for x in range(6)) # noqa
challenge += ''.join(random.choice(string.ascii_lowercase) for x in range(6))
challenge += ''.join(random.choice(string.digits) for x in range(6)) # noqa
challenge = "".join(random.choice(string.ascii_uppercase) for x in range(6)) # noqa
challenge += "".join(random.choice("~!@#$%^&*()_+") for x in range(6)) # noqa
challenge += "".join(random.choice(string.ascii_lowercase) for x in range(6))
challenge += "".join(random.choice(string.digits) for x in range(6)) # noqa
return challenge
@ -53,7 +53,7 @@ def parse_certificate(body):
"""
assert isinstance(body, str)
return x509.load_pem_x509_certificate(body.encode('utf-8'), default_backend())
return x509.load_pem_x509_certificate(body.encode("utf-8"), default_backend())
def parse_private_key(private_key):
@ -66,7 +66,9 @@ def parse_private_key(private_key):
"""
assert isinstance(private_key, str)
return load_pem_private_key(private_key.encode('utf8'), password=None, backend=default_backend())
return load_pem_private_key(
private_key.encode("utf8"), password=None, backend=default_backend()
)
def split_pem(data):
@ -100,14 +102,15 @@ def parse_csr(csr):
"""
assert isinstance(csr, str)
return x509.load_pem_x509_csr(csr.encode('utf-8'), default_backend())
return x509.load_pem_x509_csr(csr.encode("utf-8"), default_backend())
def get_authority_key(body):
"""Returns the authority key for a given certificate in hex format"""
parsed_cert = parse_certificate(body)
authority_key = parsed_cert.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier).value.key_identifier
x509.AuthorityKeyIdentifier
).value.key_identifier
return authority_key.hex()
@ -127,20 +130,17 @@ def generate_private_key(key_type):
_CURVE_TYPES = {
"ECCPRIME192V1": ec.SECP192R1(),
"ECCPRIME256V1": ec.SECP256R1(),
"ECCSECP192R1": ec.SECP192R1(),
"ECCSECP224R1": ec.SECP224R1(),
"ECCSECP256R1": ec.SECP256R1(),
"ECCSECP384R1": ec.SECP384R1(),
"ECCSECP521R1": ec.SECP521R1(),
"ECCSECP256K1": ec.SECP256K1(),
"ECCSECT163K1": ec.SECT163K1(),
"ECCSECT233K1": ec.SECT233K1(),
"ECCSECT283K1": ec.SECT283K1(),
"ECCSECT409K1": ec.SECT409K1(),
"ECCSECT571K1": ec.SECT571K1(),
"ECCSECT163R2": ec.SECT163R2(),
"ECCSECT233R1": ec.SECT233R1(),
"ECCSECT283R1": ec.SECT283R1(),
@ -149,22 +149,20 @@ def generate_private_key(key_type):
}
if key_type not in CERTIFICATE_KEY_TYPES:
raise Exception("Invalid key type: {key_type}. Supported key types: {choices}".format(
key_type=key_type,
choices=",".join(CERTIFICATE_KEY_TYPES)
))
raise Exception(
"Invalid key type: {key_type}. Supported key types: {choices}".format(
key_type=key_type, choices=",".join(CERTIFICATE_KEY_TYPES)
)
)
if 'RSA' in key_type:
if "RSA" in key_type:
key_size = int(key_type[3:])
return rsa.generate_private_key(
public_exponent=65537,
key_size=key_size,
backend=default_backend()
public_exponent=65537, key_size=key_size, backend=default_backend()
)
elif 'ECC' in key_type:
elif "ECC" in key_type:
return ec.generate_private_key(
curve=_CURVE_TYPES[key_type],
backend=default_backend()
curve=_CURVE_TYPES[key_type], backend=default_backend()
)
@ -184,11 +182,26 @@ def check_cert_signature(cert, issuer_public_key):
raise UnsupportedAlgorithm("RSASSA-PSS not supported")
else:
padder = padding.PKCS1v15()
issuer_public_key.verify(cert.signature, cert.tbs_certificate_bytes, padder, cert.signature_hash_algorithm)
elif isinstance(issuer_public_key, ec.EllipticCurvePublicKey) and isinstance(ec.ECDSA(cert.signature_hash_algorithm), ec.ECDSA):
issuer_public_key.verify(cert.signature, cert.tbs_certificate_bytes, ec.ECDSA(cert.signature_hash_algorithm))
issuer_public_key.verify(
cert.signature,
cert.tbs_certificate_bytes,
padder,
cert.signature_hash_algorithm,
)
elif isinstance(issuer_public_key, ec.EllipticCurvePublicKey) and isinstance(
ec.ECDSA(cert.signature_hash_algorithm), ec.ECDSA
):
issuer_public_key.verify(
cert.signature,
cert.tbs_certificate_bytes,
ec.ECDSA(cert.signature_hash_algorithm),
)
else:
raise UnsupportedAlgorithm("Unsupported Algorithm '{var}'.".format(var=cert.signature_algorithm_oid._name))
raise UnsupportedAlgorithm(
"Unsupported Algorithm '{var}'.".format(
var=cert.signature_algorithm_oid._name
)
)
def is_selfsigned(cert):
@ -224,7 +237,9 @@ def validate_conf(app, required_vars):
"""
for var in required_vars:
if var not in app.config:
raise InvalidConfiguration("Required variable '{var}' is not set in Lemur's conf.".format(var=var))
raise InvalidConfiguration(
"Required variable '{var}' is not set in Lemur's conf.".format(var=var)
)
# https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/WindowedRangeQuery
@ -243,18 +258,15 @@ def column_windows(session, column, windowsize):
be computed.
"""
def int_for_range(start_id, end_id):
if end_id:
return and_(
column >= start_id,
column < end_id
)
return and_(column >= start_id, column < end_id)
else:
return column >= start_id
q = session.query(
column,
func.row_number().over(order_by=column).label('rownum')
column, func.row_number().over(order_by=column).label("rownum")
).from_self(column)
if windowsize > 1:
@ -274,9 +286,7 @@ def column_windows(session, column, windowsize):
def windowed_query(q, column, windowsize):
""""Break a Query into windows on a given column."""
for whereclause in column_windows(
q.session,
column, windowsize):
for whereclause in column_windows(q.session, column, windowsize):
for row in q.filter(whereclause).order_by(column):
yield row
@ -284,7 +294,7 @@ def windowed_query(q, column, windowsize):
def truthiness(s):
"""If input string resembles something truthy then return True, else False."""
return s.lower() in ('true', 'yes', 'on', 't', '1')
return s.lower() in ("true", "yes", "on", "t", "1")
def find_matching_certificates_by_hash(cert, matching_certs):
@ -292,6 +302,8 @@ def find_matching_certificates_by_hash(cert, matching_certs):
determine if any of the certificate hashes match and return the matches."""
matching = []
for c in matching_certs:
if parse_certificate(c.body).fingerprint(hashes.SHA256()) == cert.fingerprint(hashes.SHA256()):
if parse_certificate(c.body).fingerprint(hashes.SHA256()) == cert.fingerprint(
hashes.SHA256()
):
matching.append(c)
return matching

View File

@ -16,7 +16,7 @@ def common_name(value):
# Common name could be a domain name, or a human-readable name of the subject (often used in CA names or client
# certificates). As a simple heuristic, we assume that human-readable names always include a space.
# However, to avoid confusion for humans, we also don't count spaces at the beginning or end of the string.
if ' ' not in value.strip():
if " " not in value.strip():
return sensitive_domain(value)
@ -30,17 +30,21 @@ def sensitive_domain(domain):
# User has permission, no need to check anything
return
whitelist = current_app.config.get('LEMUR_WHITELISTED_DOMAINS', [])
whitelist = current_app.config.get("LEMUR_WHITELISTED_DOMAINS", [])
if whitelist and not any(re.match(pattern, domain) for pattern in whitelist):
raise ValidationError('Domain {0} does not match whitelisted domain patterns. '
'Contact an administrator to issue the certificate.'.format(domain))
raise ValidationError(
"Domain {0} does not match whitelisted domain patterns. "
"Contact an administrator to issue the certificate.".format(domain)
)
# Avoid circular import.
from lemur.domains import service as domain_service
if any(d.sensitive for d in domain_service.get_by_name(domain)):
raise ValidationError('Domain {0} has been marked as sensitive. '
'Contact an administrator to issue the certificate.'.format(domain))
raise ValidationError(
"Domain {0} has been marked as sensitive. "
"Contact an administrator to issue the certificate.".format(domain)
)
def encoding(oid_encoding):
@ -49,9 +53,13 @@ def encoding(oid_encoding):
:param oid_encoding:
:return:
"""
valid_types = ['b64asn1', 'string', 'ia5string']
valid_types = ["b64asn1", "string", "ia5string"]
if oid_encoding.lower() not in [o_type.lower() for o_type in valid_types]:
raise ValidationError('Invalid Oid Encoding: {0} choose from {1}'.format(oid_encoding, ",".join(valid_types)))
raise ValidationError(
"Invalid Oid Encoding: {0} choose from {1}".format(
oid_encoding, ",".join(valid_types)
)
)
def sub_alt_type(alt_type):
@ -60,10 +68,23 @@ def sub_alt_type(alt_type):
:param alt_type:
:return:
"""
valid_types = ['DNSName', 'IPAddress', 'uniFormResourceIdentifier', 'directoryName', 'rfc822Name', 'registrationID',
'otherName', 'x400Address', 'EDIPartyName']
valid_types = [
"DNSName",
"IPAddress",
"uniFormResourceIdentifier",
"directoryName",
"rfc822Name",
"registrationID",
"otherName",
"x400Address",
"EDIPartyName",
]
if alt_type.lower() not in [a_type.lower() for a_type in valid_types]:
raise ValidationError('Invalid SubAltName Type: {0} choose from {1}'.format(type, ",".join(valid_types)))
raise ValidationError(
"Invalid SubAltName Type: {0} choose from {1}".format(
type, ",".join(valid_types)
)
)
def csr(data):
@ -73,16 +94,18 @@ def csr(data):
:return:
"""
try:
request = x509.load_pem_x509_csr(data.encode('utf-8'), default_backend())
request = x509.load_pem_x509_csr(data.encode("utf-8"), default_backend())
except Exception:
raise ValidationError('CSR presented is not valid.')
raise ValidationError("CSR presented is not valid.")
# Validate common name and SubjectAltNames
for name in request.subject.get_attributes_for_oid(NameOID.COMMON_NAME):
common_name(name.value)
try:
alt_names = request.extensions.get_extension_for_class(x509.SubjectAlternativeName)
alt_names = request.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
for name in alt_names.value.get_values_for_type(x509.DNSName):
sensitive_domain(name)
@ -91,26 +114,40 @@ def csr(data):
def dates(data):
if not data.get('validity_start') and data.get('validity_end'):
raise ValidationError('If validity start is specified so must validity end.')
if not data.get("validity_start") and data.get("validity_end"):
raise ValidationError("If validity start is specified so must validity end.")
if not data.get('validity_end') and data.get('validity_start'):
raise ValidationError('If validity end is specified so must validity start.')
if not data.get("validity_end") and data.get("validity_start"):
raise ValidationError("If validity end is specified so must validity start.")
if data.get('validity_start') and data.get('validity_end'):
if not current_app.config.get('LEMUR_ALLOW_WEEKEND_EXPIRATION', True):
if is_weekend(data.get('validity_end')):
raise ValidationError('Validity end must not land on a weekend.')
if data.get("validity_start") and data.get("validity_end"):
if not current_app.config.get("LEMUR_ALLOW_WEEKEND_EXPIRATION", True):
if is_weekend(data.get("validity_end")):
raise ValidationError("Validity end must not land on a weekend.")
if not data['validity_start'] < data['validity_end']:
raise ValidationError('Validity start must be before validity end.')
if not data["validity_start"] < data["validity_end"]:
raise ValidationError("Validity start must be before validity end.")
if data.get('authority'):
if data.get('validity_start').date() < data['authority'].authority_certificate.not_before.date():
raise ValidationError('Validity start must not be before {0}'.format(data['authority'].authority_certificate.not_before))
if data.get("authority"):
if (
data.get("validity_start").date()
< data["authority"].authority_certificate.not_before.date()
):
raise ValidationError(
"Validity start must not be before {0}".format(
data["authority"].authority_certificate.not_before
)
)
if data.get('validity_end').date() > data['authority'].authority_certificate.not_after.date():
raise ValidationError('Validity end must not be after {0}'.format(data['authority'].authority_certificate.not_after))
if (
data.get("validity_end").date()
> data["authority"].authority_certificate.not_after.date()
):
raise ValidationError(
"Validity end must not be after {0}".format(
data["authority"].authority_certificate.not_after
)
)
return data
@ -148,8 +185,13 @@ def verify_cert_chain(certs, error_class=ValidationError):
# Avoid circular import.
from lemur.common import defaults
raise error_class("Incorrect chain certificate(s) provided: '%s' is not signed by '%s'"
% (defaults.common_name(cert) or 'Unknown', defaults.common_name(issuer)))
raise error_class(
"Incorrect chain certificate(s) provided: '%s' is not signed by '%s'"
% (
defaults.common_name(cert) or "Unknown",
defaults.common_name(issuer),
)
)
except UnsupportedAlgorithm as err:
current_app.logger.warning("Skipping chain validation: %s", err)

View File

@ -7,28 +7,28 @@ SAN_NAMING_TEMPLATE = "SAN-{subject}-{issuer}-{not_before}-{not_after}"
DEFAULT_NAMING_TEMPLATE = "{subject}-{issuer}-{not_before}-{not_after}"
NONSTANDARD_NAMING_TEMPLATE = "{issuer}-{not_before}-{not_after}"
SUCCESS_METRIC_STATUS = 'success'
FAILURE_METRIC_STATUS = 'failure'
SUCCESS_METRIC_STATUS = "success"
FAILURE_METRIC_STATUS = "failure"
CERTIFICATE_KEY_TYPES = [
'RSA2048',
'RSA4096',
'ECCPRIME192V1',
'ECCPRIME256V1',
'ECCSECP192R1',
'ECCSECP224R1',
'ECCSECP256R1',
'ECCSECP384R1',
'ECCSECP521R1',
'ECCSECP256K1',
'ECCSECT163K1',
'ECCSECT233K1',
'ECCSECT283K1',
'ECCSECT409K1',
'ECCSECT571K1',
'ECCSECT163R2',
'ECCSECT233R1',
'ECCSECT283R1',
'ECCSECT409R1',
'ECCSECT571R2'
"RSA2048",
"RSA4096",
"ECCPRIME192V1",
"ECCPRIME256V1",
"ECCSECP192R1",
"ECCSECP224R1",
"ECCSECP256R1",
"ECCSECP384R1",
"ECCSECP521R1",
"ECCSECP256K1",
"ECCSECT163K1",
"ECCSECT233K1",
"ECCSECT283K1",
"ECCSECT409K1",
"ECCSECT571K1",
"ECCSECT163R2",
"ECCSECT233R1",
"ECCSECT283R1",
"ECCSECT409R1",
"ECCSECT571R2",
]

View File

@ -43,7 +43,7 @@ def session_query(model):
:param model: sqlalchemy model
:return: query object for model
"""
return model.query if hasattr(model, 'query') else db.session.query(model)
return model.query if hasattr(model, "query") else db.session.query(model)
def create_query(model, kwargs):
@ -77,7 +77,7 @@ def add(model):
def get_model_column(model, field):
if field in getattr(model, 'sensitive_fields', ()):
if field in getattr(model, "sensitive_fields", ()):
raise AttrNotFound(field)
column = model.__table__.columns._data.get(field, None)
if column is None:
@ -100,7 +100,7 @@ def find_all(query, model, kwargs):
kwargs = filter_none(kwargs)
for attr, value in kwargs.items():
if not isinstance(value, list):
value = value.split(',')
value = value.split(",")
conditions.append(get_model_column(model, attr).in_(value))
@ -200,7 +200,7 @@ def filter(query, model, terms):
:return:
"""
column = get_model_column(model, underscore(terms[0]))
return query.filter(column.ilike('%{}%'.format(terms[1])))
return query.filter(column.ilike("%{}%".format(terms[1])))
def sort(query, model, field, direction):
@ -214,7 +214,7 @@ def sort(query, model, field, direction):
:param direction:
"""
column = get_model_column(model, underscore(field))
return query.order_by(column.desc() if direction == 'desc' else column.asc())
return query.order_by(column.desc() if direction == "desc" else column.asc())
def paginate(query, page, count):
@ -247,10 +247,10 @@ def update_list(model, model_attr, item_model, items):
for i in items:
for item in getattr(model, model_attr):
if item.id == i['id']:
if item.id == i["id"]:
break
else:
getattr(model, model_attr).append(get(item_model, i['id']))
getattr(model, model_attr).append(get(item_model, i["id"]))
return model
@ -276,9 +276,9 @@ def get_count(q):
disable_group_by = False
if len(q._entities) > 1:
# currently support only one entity
raise Exception('only one entity is supported for get_count, got: %s' % q)
raise Exception("only one entity is supported for get_count, got: %s" % q)
entity = q._entities[0]
if hasattr(entity, 'column'):
if hasattr(entity, "column"):
# _ColumnEntity has column attr - on case: query(Model.column)...
col = entity.column
if q._group_by and q._distinct:
@ -295,7 +295,11 @@ def get_count(q):
count_func = func.count()
if q._group_by and not disable_group_by:
count_func = count_func.over(None)
count_q = q.options(lazyload('*')).statement.with_only_columns([count_func]).order_by(None)
count_q = (
q.options(lazyload("*"))
.statement.with_only_columns([count_func])
.order_by(None)
)
if disable_group_by:
count_q = count_q.group_by(None)
count = q.session.execute(count_q).scalar()
@ -311,13 +315,13 @@ def sort_and_page(query, model, args):
:param args:
:return:
"""
sort_by = args.pop('sort_by')
sort_dir = args.pop('sort_dir')
page = args.pop('page')
count = args.pop('count')
sort_by = args.pop("sort_by")
sort_dir = args.pop("sort_dir")
page = args.pop("page")
count = args.pop("count")
if args.get('user'):
user = args.pop('user')
if args.get("user"):
user = args.pop("user")
query = find_all(query, model, args)

View File

@ -1,6 +1,7 @@
# This is just Python which means you can inherit and tweak settings
import os
_basedir = os.path.abspath(os.path.dirname(__file__))
THREADS_PER_PAGE = 8

View File

@ -13,12 +13,13 @@ from lemur.auth.service import AuthenticatedResource
from lemur.defaults.schemas import default_output_schema
mod = Blueprint('default', __name__)
mod = Blueprint("default", __name__)
api = Api(mod)
class LemurDefaults(AuthenticatedResource):
""" Defines the 'defaults' endpoint """
def __init__(self):
super(LemurDefaults)
@ -59,17 +60,21 @@ class LemurDefaults(AuthenticatedResource):
:statuscode 403: unauthenticated
"""
default_authority = get_by_name(current_app.config.get('LEMUR_DEFAULT_AUTHORITY'))
default_authority = get_by_name(
current_app.config.get("LEMUR_DEFAULT_AUTHORITY")
)
return dict(
country=current_app.config.get('LEMUR_DEFAULT_COUNTRY'),
state=current_app.config.get('LEMUR_DEFAULT_STATE'),
location=current_app.config.get('LEMUR_DEFAULT_LOCATION'),
organization=current_app.config.get('LEMUR_DEFAULT_ORGANIZATION'),
organizational_unit=current_app.config.get('LEMUR_DEFAULT_ORGANIZATIONAL_UNIT'),
issuer_plugin=current_app.config.get('LEMUR_DEFAULT_ISSUER_PLUGIN'),
country=current_app.config.get("LEMUR_DEFAULT_COUNTRY"),
state=current_app.config.get("LEMUR_DEFAULT_STATE"),
location=current_app.config.get("LEMUR_DEFAULT_LOCATION"),
organization=current_app.config.get("LEMUR_DEFAULT_ORGANIZATION"),
organizational_unit=current_app.config.get(
"LEMUR_DEFAULT_ORGANIZATIONAL_UNIT"
),
issuer_plugin=current_app.config.get("LEMUR_DEFAULT_ISSUER_PLUGIN"),
authority=default_authority,
)
api.add_resource(LemurDefaults, '/defaults', endpoint='default')
api.add_resource(LemurDefaults, "/defaults", endpoint="default")

View File

@ -13,7 +13,7 @@ from lemur.plugins.base import plugins
class Destination(db.Model):
__tablename__ = 'destinations'
__tablename__ = "destinations"
id = Column(Integer, primary_key=True)
label = Column(String(32))
options = Column(JSONType)

View File

@ -30,7 +30,7 @@ class DestinationOutputSchema(LemurOutputSchema):
@post_dump
def fill_object(self, data):
if data:
data['plugin']['pluginOptions'] = data['options']
data["plugin"]["pluginOptions"] = data["options"]
return data

View File

@ -26,10 +26,12 @@ def create(label, plugin_name, options, description=None):
"""
# remove any sub-plugin objects before try to save the json options
for option in options:
if 'plugin' in option['type']:
del option['value']['plugin_object']
if "plugin" in option["type"]:
del option["value"]["plugin_object"]
destination = Destination(label=label, options=options, plugin_name=plugin_name, description=description)
destination = Destination(
label=label, options=options, plugin_name=plugin_name, description=description
)
current_app.logger.info("Destination: %s created", label)
# add the destination as source, to avoid new destinations that are not in source, as long as an AWS destination
@ -85,7 +87,7 @@ def get_by_label(label):
:param label:
:return:
"""
return database.get(Destination, label, field='label')
return database.get(Destination, label, field="label")
def get_all():
@ -99,17 +101,19 @@ def get_all():
def render(args):
filt = args.pop('filter')
certificate_id = args.pop('certificate_id', None)
filt = args.pop("filter")
certificate_id = args.pop("certificate_id", None)
if certificate_id:
query = database.session_query(Destination).join(Certificate, Destination.certificate)
query = database.session_query(Destination).join(
Certificate, Destination.certificate
)
query = query.filter(Certificate.id == certificate_id)
else:
query = database.session_query(Destination)
if filt:
terms = filt.split(';')
terms = filt.split(";")
query = database.filter(query, Destination, terms)
return database.sort_and_page(query, Destination, args)
@ -122,9 +126,15 @@ def stats(**kwargs):
:param kwargs:
:return:
"""
items = database.db.session.query(Destination.label, func.count(certificate_destination_associations.c.certificate_id))\
.join(certificate_destination_associations)\
.group_by(Destination.label).all()
items = (
database.db.session.query(
Destination.label,
func.count(certificate_destination_associations.c.certificate_id),
)
.join(certificate_destination_associations)
.group_by(Destination.label)
.all()
)
keys = []
values = []
@ -132,4 +142,4 @@ def stats(**kwargs):
keys.append(key)
values.append(count)
return {'labels': keys, 'values': values}
return {"labels": keys, "values": values}

View File

@ -15,15 +15,20 @@ from lemur.auth.permissions import admin_permission
from lemur.common.utils import paginated_parser
from lemur.common.schema import validate_schema
from lemur.destinations.schemas import destinations_output_schema, destination_input_schema, destination_output_schema
from lemur.destinations.schemas import (
destinations_output_schema,
destination_input_schema,
destination_output_schema,
)
mod = Blueprint('destinations', __name__)
mod = Blueprint("destinations", __name__)
api = Api(mod)
class DestinationsList(AuthenticatedResource):
""" Defines the 'destinations' endpoint """
def __init__(self):
self.reqparse = reqparse.RequestParser()
super(DestinationsList, self).__init__()
@ -176,7 +181,12 @@ class DestinationsList(AuthenticatedResource):
:reqheader Authorization: OAuth token to authenticate
:statuscode 200: no error
"""
return service.create(data['label'], data['plugin']['slug'], data['plugin']['plugin_options'], data['description'])
return service.create(
data["label"],
data["plugin"]["slug"],
data["plugin"]["plugin_options"],
data["description"],
)
class Destinations(AuthenticatedResource):
@ -325,16 +335,22 @@ class Destinations(AuthenticatedResource):
:reqheader Authorization: OAuth token to authenticate
:statuscode 200: no error
"""
return service.update(destination_id, data['label'], data['plugin']['plugin_options'], data['description'])
return service.update(
destination_id,
data["label"],
data["plugin"]["plugin_options"],
data["description"],
)
@admin_permission.require(http_exception=403)
def delete(self, destination_id):
service.delete(destination_id)
return {'result': True}
return {"result": True}
class CertificateDestinations(AuthenticatedResource):
""" Defines the 'certificate/<int:certificate_id/destinations'' endpoint """
def __init__(self):
super(CertificateDestinations, self).__init__()
@ -401,25 +417,31 @@ class CertificateDestinations(AuthenticatedResource):
"""
parser = paginated_parser.copy()
args = parser.parse_args()
args['certificate_id'] = certificate_id
args["certificate_id"] = certificate_id
return service.render(args)
class DestinationsStats(AuthenticatedResource):
""" Defines the 'certificates' stats endpoint """
def __init__(self):
self.reqparse = reqparse.RequestParser()
super(DestinationsStats, self).__init__()
def get(self):
self.reqparse.add_argument('metric', type=str, location='args')
self.reqparse.add_argument("metric", type=str, location="args")
args = self.reqparse.parse_args()
items = service.stats(**args)
return dict(items=items, total=len(items))
api.add_resource(DestinationsList, '/destinations', endpoint='destinations')
api.add_resource(Destinations, '/destinations/<int:destination_id>', endpoint='destination')
api.add_resource(CertificateDestinations, '/certificates/<int:certificate_id>/destinations',
endpoint='certificateDestinations')
api.add_resource(DestinationsStats, '/destinations/stats', endpoint='destinationStats')
api.add_resource(DestinationsList, "/destinations", endpoint="destinations")
api.add_resource(
Destinations, "/destinations/<int:destination_id>", endpoint="destination"
)
api.add_resource(
CertificateDestinations,
"/certificates/<int:certificate_id>/destinations",
endpoint="certificateDestinations",
)
api.add_resource(DestinationsStats, "/destinations/stats", endpoint="destinationStats")

View File

@ -5,7 +5,9 @@ from lemur.dns_providers.service import get_all_dns_providers, set_domains
from lemur.extensions import metrics
from lemur.plugins.base import plugins
manager = Manager(usage="Iterates through all DNS providers and sets DNS zones in the database.")
manager = Manager(
usage="Iterates through all DNS providers and sets DNS zones in the database."
)
@manager.command
@ -27,5 +29,5 @@ def get_all_zones():
status = SUCCESS_METRIC_STATUS
metrics.send('get_all_zones', 'counter', 1, metric_tags={'status': status})
metrics.send("get_all_zones", "counter", 1, metric_tags={"status": status})
print("[+] Done with dns provider zone lookup and configuration.")

View File

@ -9,22 +9,23 @@ from lemur.utils import Vault
class DnsProvider(db.Model):
__tablename__ = 'dns_providers'
id = Column(
Integer(),
primary_key=True,
)
__tablename__ = "dns_providers"
id = Column(Integer(), primary_key=True)
name = Column(String(length=256), unique=True, nullable=True)
description = Column(Text(), nullable=True)
provider_type = Column(String(length=256), nullable=True)
credentials = Column(Vault, nullable=True)
api_endpoint = Column(String(length=256), nullable=True)
date_created = Column(ArrowType(), server_default=text('now()'), nullable=False)
date_created = Column(ArrowType(), server_default=text("now()"), nullable=False)
status = Column(String(length=128), nullable=True)
options = Column(JSON, nullable=True)
domains = Column(JSON, nullable=True)
certificates = relationship("Certificate", backref='dns_provider', foreign_keys='Certificate.dns_provider_id',
lazy='dynamic')
certificates = relationship(
"Certificate",
backref="dns_provider",
foreign_keys="Certificate.dns_provider_id",
lazy="dynamic",
)
def __init__(self, name, description, provider_type, credentials):
self.name = name

View File

@ -49,7 +49,9 @@ def get_friendly(dns_provider_id):
}
if dns_provider.provider_type == "route53":
dns_provider_friendly["account_id"] = json.loads(dns_provider.credentials).get("account_id")
dns_provider_friendly["account_id"] = json.loads(dns_provider.credentials).get(
"account_id"
)
return dns_provider_friendly
@ -64,40 +66,40 @@ def delete(dns_provider_id):
def get_types():
provider_config = current_app.config.get(
'ACME_DNS_PROVIDER_TYPES',
{"items": [
"ACME_DNS_PROVIDER_TYPES",
{
'name': 'route53',
'requirements': [
"items": [
{
'name': 'account_id',
'type': 'int',
'required': True,
'helpMessage': 'AWS Account number'
"name": "route53",
"requirements": [
{
"name": "account_id",
"type": "int",
"required": True,
"helpMessage": "AWS Account number",
}
],
},
{
"name": "cloudflare",
"requirements": [
{
"name": "email",
"type": "str",
"required": True,
"helpMessage": "Cloudflare Email",
},
{
"name": "key",
"type": "str",
"required": True,
"helpMessage": "Cloudflare Key",
},
],
},
{"name": "dyn"},
]
},
{
'name': 'cloudflare',
'requirements': [
{
'name': 'email',
'type': 'str',
'required': True,
'helpMessage': 'Cloudflare Email'
},
{
'name': 'key',
'type': 'str',
'required': True,
'helpMessage': 'Cloudflare Key'
},
]
},
{
'name': 'dyn',
},
]}
)
if not provider_config:
raise Exception("No DNS Provider configuration specified.")

View File

@ -13,9 +13,12 @@ from lemur.auth.service import AuthenticatedResource
from lemur.common.schema import validate_schema
from lemur.common.utils import paginated_parser
from lemur.dns_providers import service
from lemur.dns_providers.schemas import dns_provider_output_schema, dns_provider_input_schema
from lemur.dns_providers.schemas import (
dns_provider_output_schema,
dns_provider_input_schema,
)
mod = Blueprint('dns_providers', __name__)
mod = Blueprint("dns_providers", __name__)
api = Api(mod)
@ -71,12 +74,12 @@ class DnsProvidersList(AuthenticatedResource):
"""
parser = paginated_parser.copy()
parser.add_argument('dns_provider_id', type=int, location='args')
parser.add_argument('name', type=str, location='args')
parser.add_argument('type', type=str, location='args')
parser.add_argument("dns_provider_id", type=int, location="args")
parser.add_argument("name", type=str, location="args")
parser.add_argument("type", type=str, location="args")
args = parser.parse_args()
args['user'] = g.user
args["user"] = g.user
return service.render(args)
@validate_schema(dns_provider_input_schema, None)
@ -152,7 +155,7 @@ class DnsProviders(AuthenticatedResource):
@admin_permission.require(http_exception=403)
def delete(self, dns_provider_id):
service.delete(dns_provider_id)
return {'result': True}
return {"result": True}
class DnsProviderOptions(AuthenticatedResource):
@ -166,6 +169,10 @@ class DnsProviderOptions(AuthenticatedResource):
return service.get_types()
api.add_resource(DnsProvidersList, '/dns_providers', endpoint='dns_providers')
api.add_resource(DnsProviders, '/dns_providers/<int:dns_provider_id>', endpoint='dns_provider')
api.add_resource(DnsProviderOptions, '/dns_provider_options', endpoint='dns_provider_options')
api.add_resource(DnsProvidersList, "/dns_providers", endpoint="dns_providers")
api.add_resource(
DnsProviders, "/dns_providers/<int:dns_provider_id>", endpoint="dns_provider"
)
api.add_resource(
DnsProviderOptions, "/dns_provider_options", endpoint="dns_provider_options"
)

View File

@ -13,11 +13,14 @@ from lemur.database import db
class Domain(db.Model):
__tablename__ = 'domains'
__tablename__ = "domains"
__table_args__ = (
Index('ix_domains_name_gin', "name",
Index(
"ix_domains_name_gin",
"name",
postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using='gin'),
postgresql_using="gin",
),
)
id = Column(Integer, primary_key=True)
name = Column(String(256), index=True)

View File

@ -77,11 +77,11 @@ def render(args):
:return:
"""
query = database.session_query(Domain)
filt = args.pop('filter')
certificate_id = args.pop('certificate_id', None)
filt = args.pop("filter")
certificate_id = args.pop("certificate_id", None)
if filt:
terms = filt.split(';')
terms = filt.split(";")
query = database.filter(query, Domain, terms)
if certificate_id:

View File

@ -17,14 +17,19 @@ from lemur.auth.permissions import SensitiveDomainPermission
from lemur.common.schema import validate_schema
from lemur.common.utils import paginated_parser
from lemur.domains.schemas import domain_input_schema, domain_output_schema, domains_output_schema
from lemur.domains.schemas import (
domain_input_schema,
domain_output_schema,
domains_output_schema,
)
mod = Blueprint('domains', __name__)
mod = Blueprint("domains", __name__)
api = Api(mod)
class DomainsList(AuthenticatedResource):
""" Defines the 'domains' endpoint """
def __init__(self):
super(DomainsList, self).__init__()
@ -123,7 +128,7 @@ class DomainsList(AuthenticatedResource):
:statuscode 200: no error
:statuscode 403: unauthenticated
"""
return service.create(data['name'], data['sensitive'])
return service.create(data["name"], data["sensitive"])
class Domains(AuthenticatedResource):
@ -205,13 +210,14 @@ class Domains(AuthenticatedResource):
:statuscode 403: unauthenticated
"""
if SensitiveDomainPermission().can():
return service.update(domain_id, data['name'], data['sensitive'])
return service.update(domain_id, data["name"], data["sensitive"])
return dict(message='You are not authorized to modify this domain'), 403
return dict(message="You are not authorized to modify this domain"), 403
class CertificateDomains(AuthenticatedResource):
""" Defines the 'domains' endpoint """
def __init__(self):
super(CertificateDomains, self).__init__()
@ -265,10 +271,14 @@ class CertificateDomains(AuthenticatedResource):
"""
parser = paginated_parser.copy()
args = parser.parse_args()
args['certificate_id'] = certificate_id
args["certificate_id"] = certificate_id
return service.render(args)
api.add_resource(DomainsList, '/domains', endpoint='domains')
api.add_resource(Domains, '/domains/<int:domain_id>', endpoint='domain')
api.add_resource(CertificateDomains, '/certificates/<int:certificate_id>/domains', endpoint='certificateDomains')
api.add_resource(DomainsList, "/domains", endpoint="domains")
api.add_resource(Domains, "/domains/<int:domain_id>", endpoint="domain")
api.add_resource(
CertificateDomains,
"/certificates/<int:certificate_id>/domains",
endpoint="certificateDomains",
)

View File

@ -21,7 +21,14 @@ from lemur.endpoints.models import Endpoint
manager = Manager(usage="Handles all endpoint related tasks.")
@manager.option('-ttl', '--time-to-live', type=int, dest='ttl', default=2, help='Time in hours, which endpoint has not been refreshed to remove the endpoint.')
@manager.option(
"-ttl",
"--time-to-live",
type=int,
dest="ttl",
default=2,
help="Time in hours, which endpoint has not been refreshed to remove the endpoint.",
)
def expire(ttl):
"""
Removed all endpoints that have not been recently updated.
@ -31,12 +38,18 @@ def expire(ttl):
try:
now = arrow.utcnow()
expiration = now - timedelta(hours=ttl)
endpoints = database.session_query(Endpoint).filter(cast(Endpoint.last_updated, ArrowType) <= expiration)
endpoints = database.session_query(Endpoint).filter(
cast(Endpoint.last_updated, ArrowType) <= expiration
)
for endpoint in endpoints:
print("[!] Expiring endpoint: {name} Last Updated: {last_updated}".format(name=endpoint.name, last_updated=endpoint.last_updated))
print(
"[!] Expiring endpoint: {name} Last Updated: {last_updated}".format(
name=endpoint.name, last_updated=endpoint.last_updated
)
)
database.delete(endpoint)
metrics.send('endpoint_expired', 'counter', 1)
metrics.send("endpoint_expired", "counter", 1)
print("[+] Finished expiration.")
except Exception as e:

View File

@ -20,15 +20,11 @@ from lemur.database import db
from lemur.models import policies_ciphers
BAD_CIPHERS = [
'Protocol-SSLv3',
'Protocol-SSLv2',
'Protocol-TLSv1'
]
BAD_CIPHERS = ["Protocol-SSLv3", "Protocol-SSLv2", "Protocol-TLSv1"]
class Cipher(db.Model):
__tablename__ = 'ciphers'
__tablename__ = "ciphers"
id = Column(Integer, primary_key=True)
name = Column(String(128), nullable=False)
@ -38,23 +34,18 @@ class Cipher(db.Model):
@deprecated.expression
def deprecated(cls):
return case(
[
(cls.name in BAD_CIPHERS, True)
],
else_=False
)
return case([(cls.name in BAD_CIPHERS, True)], else_=False)
class Policy(db.Model):
___tablename__ = 'policies'
___tablename__ = "policies"
id = Column(Integer, primary_key=True)
name = Column(String(128), nullable=True)
ciphers = relationship('Cipher', secondary=policies_ciphers, backref='policy')
ciphers = relationship("Cipher", secondary=policies_ciphers, backref="policy")
class Endpoint(db.Model):
__tablename__ = 'endpoints'
__tablename__ = "endpoints"
id = Column(Integer, primary_key=True)
owner = Column(String(128))
name = Column(String(128))
@ -62,16 +53,18 @@ class Endpoint(db.Model):
type = Column(String(128))
active = Column(Boolean, default=True)
port = Column(Integer)
policy_id = Column(Integer, ForeignKey('policy.id'))
policy = relationship('Policy', backref='endpoint')
certificate_id = Column(Integer, ForeignKey('certificates.id'))
source_id = Column(Integer, ForeignKey('sources.id'))
policy_id = Column(Integer, ForeignKey("policy.id"))
policy = relationship("Policy", backref="endpoint")
certificate_id = Column(Integer, ForeignKey("certificates.id"))
source_id = Column(Integer, ForeignKey("sources.id"))
sensitive = Column(Boolean, default=False)
source = relationship('Source', back_populates='endpoints')
source = relationship("Source", back_populates="endpoints")
last_updated = Column(ArrowType, default=arrow.utcnow, nullable=False)
date_created = Column(ArrowType, default=arrow.utcnow, onupdate=arrow.utcnow, nullable=False)
date_created = Column(
ArrowType, default=arrow.utcnow, onupdate=arrow.utcnow, nullable=False
)
replaced = association_proxy('certificate', 'replaced')
replaced = association_proxy("certificate", "replaced")
@property
def issues(self):
@ -79,13 +72,30 @@ class Endpoint(db.Model):
for cipher in self.policy.ciphers:
if cipher.deprecated:
issues.append({'name': 'deprecated cipher', 'value': '{0} has been deprecated consider removing it.'.format(cipher.name)})
issues.append(
{
"name": "deprecated cipher",
"value": "{0} has been deprecated consider removing it.".format(
cipher.name
),
}
)
if self.certificate.expired:
issues.append({'name': 'expired certificate', 'value': 'There is an expired certificate attached to this endpoint consider replacing it.'})
issues.append(
{
"name": "expired certificate",
"value": "There is an expired certificate attached to this endpoint consider replacing it.",
}
)
if self.certificate.revoked:
issues.append({'name': 'revoked', 'value': 'There is a revoked certificate attached to this endpoint consider replacing it.'})
issues.append(
{
"name": "revoked",
"value": "There is a revoked certificate attached to this endpoint consider replacing it.",
}
)
return issues

View File

@ -46,7 +46,7 @@ def get_by_name(name):
:param name:
:return:
"""
return database.get(Endpoint, name, field='name')
return database.get(Endpoint, name, field="name")
def get_by_dnsname(dnsname):
@ -56,7 +56,7 @@ def get_by_dnsname(dnsname):
:param dnsname:
:return:
"""
return database.get(Endpoint, dnsname, field='dnsname')
return database.get(Endpoint, dnsname, field="dnsname")
def get_by_dnsname_and_port(dnsname, port):
@ -66,7 +66,11 @@ def get_by_dnsname_and_port(dnsname, port):
:param port:
:return:
"""
return Endpoint.query.filter(Endpoint.dnsname == dnsname).filter(Endpoint.port == port).scalar()
return (
Endpoint.query.filter(Endpoint.dnsname == dnsname)
.filter(Endpoint.port == port)
.scalar()
)
def get_by_source(source_label):
@ -95,12 +99,14 @@ def create(**kwargs):
"""
endpoint = Endpoint(**kwargs)
database.create(endpoint)
metrics.send('endpoint_added', 'counter', 1, metric_tags={'source': endpoint.source.label})
metrics.send(
"endpoint_added", "counter", 1, metric_tags={"source": endpoint.source.label}
)
return endpoint
def get_or_create_policy(**kwargs):
policy = database.get(Policy, kwargs['name'], field='name')
policy = database.get(Policy, kwargs["name"], field="name")
if not policy:
policy = Policy(**kwargs)
@ -110,7 +116,7 @@ def get_or_create_policy(**kwargs):
def get_or_create_cipher(**kwargs):
cipher = database.get(Cipher, kwargs['name'], field='name')
cipher = database.get(Cipher, kwargs["name"], field="name")
if not cipher:
cipher = Cipher(**kwargs)
@ -122,11 +128,13 @@ def get_or_create_cipher(**kwargs):
def update(endpoint_id, **kwargs):
endpoint = database.get(Endpoint, endpoint_id)
endpoint.policy = kwargs['policy']
endpoint.certificate = kwargs['certificate']
endpoint.source = kwargs['source']
endpoint.policy = kwargs["policy"]
endpoint.certificate = kwargs["certificate"]
endpoint.source = kwargs["source"]
endpoint.last_updated = arrow.utcnow()
metrics.send('endpoint_updated', 'counter', 1, metric_tags={'source': endpoint.source.label})
metrics.send(
"endpoint_updated", "counter", 1, metric_tags={"source": endpoint.source.label}
)
database.update(endpoint)
return endpoint
@ -138,19 +146,17 @@ def render(args):
:return:
"""
query = database.session_query(Endpoint)
filt = args.pop('filter')
filt = args.pop("filter")
if filt:
terms = filt.split(';')
if 'active' in filt: # this is really weird but strcmp seems to not work here??
terms = filt.split(";")
if "active" in filt: # this is really weird but strcmp seems to not work here??
query = query.filter(Endpoint.active == truthiness(terms[1]))
elif 'port' in filt:
if terms[1] != 'null': # ng-table adds 'null' if a number is removed
elif "port" in filt:
if terms[1] != "null": # ng-table adds 'null' if a number is removed
query = query.filter(Endpoint.port == terms[1])
elif 'ciphers' in filt:
query = query.filter(
Cipher.name == terms[1]
)
elif "ciphers" in filt:
query = query.filter(Cipher.name == terms[1])
else:
query = database.filter(query, Endpoint, terms)
@ -164,7 +170,7 @@ def stats(**kwargs):
:param kwargs:
:return:
"""
attr = getattr(Endpoint, kwargs.get('metric'))
attr = getattr(Endpoint, kwargs.get("metric"))
query = database.db.session.query(attr, func.count(attr))
items = query.group_by(attr).all()
@ -175,4 +181,4 @@ def stats(**kwargs):
keys.append(key)
values.append(count)
return {'labels': keys, 'values': values}
return {"labels": keys, "values": values}

View File

@ -16,12 +16,13 @@ from lemur.endpoints import service
from lemur.endpoints.schemas import endpoint_output_schema, endpoints_output_schema
mod = Blueprint('endpoints', __name__)
mod = Blueprint("endpoints", __name__)
api = Api(mod)
class EndpointsList(AuthenticatedResource):
""" Defines the 'endpoints' endpoint """
def __init__(self):
self.reqparse = reqparse.RequestParser()
super(EndpointsList, self).__init__()
@ -63,7 +64,7 @@ class EndpointsList(AuthenticatedResource):
"""
parser = paginated_parser.copy()
args = parser.parse_args()
args['user'] = g.current_user
args["user"] = g.current_user
return service.render(args)
@ -103,5 +104,5 @@ class Endpoints(AuthenticatedResource):
return service.get(endpoint_id)
api.add_resource(EndpointsList, '/endpoints', endpoint='endpoints')
api.add_resource(Endpoints, '/endpoints/<int:endpoint_id>', endpoint='endpoint')
api.add_resource(EndpointsList, "/endpoints", endpoint="endpoints")
api.add_resource(Endpoints, "/endpoints/<int:endpoint_id>", endpoint="endpoint")

View File

@ -21,7 +21,9 @@ class DuplicateError(LemurException):
class InvalidListener(LemurException):
def __str__(self):
return repr("Invalid listener, ensure you select a certificate if you are using a secure protocol")
return repr(
"Invalid listener, ensure you select a certificate if you are using a secure protocol"
)
class AttrNotFound(LemurException):

View File

@ -15,25 +15,33 @@ class SQLAlchemy(SA):
db = SQLAlchemy()
from flask_migrate import Migrate
migrate = Migrate()
from flask_bcrypt import Bcrypt
bcrypt = Bcrypt()
from flask_principal import Principal
principal = Principal(use_sessions=False)
from flask_mail import Mail
smtp_mail = Mail()
from lemur.metrics import Metrics
metrics = Metrics()
from raven.contrib.flask import Sentry
sentry = Sentry()
from blinker import Namespace
signals = Namespace()
from flask_cors import CORS
cors = CORS()

View File

@ -24,9 +24,7 @@ from lemur.common.health import mod as health
from lemur.extensions import db, migrate, principal, smtp_mail, metrics, sentry, cors
DEFAULT_BLUEPRINTS = (
health,
)
DEFAULT_BLUEPRINTS = (health,)
API_VERSION = 1
@ -71,16 +69,17 @@ def from_file(file_path, silent=False):
:param file_path:
:param silent:
"""
d = imp.new_module('config')
d = imp.new_module("config")
d.__file__ = file_path
try:
with open(file_path) as config_file:
exec(compile(config_file.read(), # nosec: config file safe
file_path, 'exec'), d.__dict__)
exec( # nosec: config file safe
compile(config_file.read(), file_path, "exec"), d.__dict__
)
except IOError as e:
if silent and e.errno in (errno.ENOENT, errno.EISDIR):
return False
e.strerror = 'Unable to load configuration file (%s)' % e.strerror
e.strerror = "Unable to load configuration file (%s)" % e.strerror
raise
return d
@ -94,8 +93,8 @@ def configure_app(app, config=None):
:return:
"""
# respect the config first
if config and config != 'None':
app.config['CONFIG_PATH'] = config
if config and config != "None":
app.config["CONFIG_PATH"] = config
app.config.from_object(from_file(config))
else:
try:
@ -103,12 +102,21 @@ def configure_app(app, config=None):
except RuntimeError:
# look in default paths
if os.path.isfile(os.path.expanduser("~/.lemur/lemur.conf.py")):
app.config.from_object(from_file(os.path.expanduser("~/.lemur/lemur.conf.py")))
app.config.from_object(
from_file(os.path.expanduser("~/.lemur/lemur.conf.py"))
)
else:
app.config.from_object(from_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'default.conf.py')))
app.config.from_object(
from_file(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"default.conf.py",
)
)
)
# we don't use this
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
def configure_extensions(app):
@ -125,9 +133,15 @@ def configure_extensions(app):
metrics.init_app(app)
sentry.init_app(app)
if app.config['CORS']:
app.config['CORS_HEADERS'] = 'Content-Type'
cors.init_app(app, resources=r'/api/*', headers='Content-Type', origin='*', supports_credentials=True)
if app.config["CORS"]:
app.config["CORS_HEADERS"] = "Content-Type"
cors.init_app(
app,
resources=r"/api/*",
headers="Content-Type",
origin="*",
supports_credentials=True,
)
def configure_blueprints(app, blueprints):
@ -148,22 +162,25 @@ def configure_logging(app):
:param app:
"""
handler = RotatingFileHandler(app.config.get('LOG_FILE', 'lemur.log'), maxBytes=10000000, backupCount=100)
handler = RotatingFileHandler(
app.config.get("LOG_FILE", "lemur.log"), maxBytes=10000000, backupCount=100
)
handler.setFormatter(Formatter(
'%(asctime)s %(levelname)s: %(message)s '
'[in %(pathname)s:%(lineno)d]'
))
handler.setFormatter(
Formatter(
"%(asctime)s %(levelname)s: %(message)s " "[in %(pathname)s:%(lineno)d]"
)
)
handler.setLevel(app.config.get('LOG_LEVEL', 'DEBUG'))
app.logger.setLevel(app.config.get('LOG_LEVEL', 'DEBUG'))
handler.setLevel(app.config.get("LOG_LEVEL", "DEBUG"))
app.logger.setLevel(app.config.get("LOG_LEVEL", "DEBUG"))
app.logger.addHandler(handler)
stream_handler = StreamHandler()
stream_handler.setLevel(app.config.get('LOG_LEVEL', 'DEBUG'))
stream_handler.setLevel(app.config.get("LOG_LEVEL", "DEBUG"))
app.logger.addHandler(stream_handler)
if app.config.get('DEBUG_DUMP', False):
if app.config.get("DEBUG_DUMP", False):
activate_debug_dump()
@ -176,17 +193,21 @@ def install_plugins(app):
"""
from lemur.plugins import plugins
from lemur.plugins.base import register
# entry_points={
# 'lemur.plugins': [
# 'verisign = lemur_verisign.plugin:VerisignPlugin'
# ],
# },
for ep in pkg_resources.iter_entry_points('lemur.plugins'):
for ep in pkg_resources.iter_entry_points("lemur.plugins"):
try:
plugin = ep.load()
except Exception:
import traceback
app.logger.error("Failed to load plugin %r:\n%s\n" % (ep.name, traceback.format_exc()))
app.logger.error(
"Failed to load plugin %r:\n%s\n" % (ep.name, traceback.format_exc())
)
else:
register(plugin)
@ -196,6 +217,9 @@ def install_plugins(app):
try:
plugins.get(slug)
except KeyError:
raise Exception("Unable to location notification plugin: {slug}. Ensure that "
"LEMUR_DEFAULT_NOTIFICATION_PLUGIN is set to a valid and installed notification plugin."
.format(slug=slug))
raise Exception(
"Unable to location notification plugin: {slug}. Ensure that "
"LEMUR_DEFAULT_NOTIFICATION_PLUGIN is set to a valid and installed notification plugin.".format(
slug=slug
)
)

View File

@ -15,9 +15,19 @@ from lemur.database import db
class Log(db.Model):
__tablename__ = 'logs'
__tablename__ = "logs"
id = Column(Integer, primary_key=True)
certificate_id = Column(Integer, ForeignKey('certificates.id'))
log_type = Column(Enum('key_view', 'create_cert', 'update_cert', 'revoke_cert', 'delete_cert', name='log_type'), nullable=False)
certificate_id = Column(Integer, ForeignKey("certificates.id"))
log_type = Column(
Enum(
"key_view",
"create_cert",
"update_cert",
"revoke_cert",
"delete_cert",
name="log_type",
),
nullable=False,
)
logged_at = Column(ArrowType(), PassiveDefault(func.now()), nullable=False)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

View File

@ -24,7 +24,11 @@ def create(user, type, certificate=None):
:param certificate:
:return:
"""
current_app.logger.info("[lemur-audit] action: {0}, user: {1}, certificate: {2}.".format(type, user.email, certificate.name))
current_app.logger.info(
"[lemur-audit] action: {0}, user: {1}, certificate: {2}.".format(
type, user.email, certificate.name
)
)
view = Log(user_id=user.id, log_type=type, certificate_id=certificate.id)
database.add(view)
database.commit()
@ -50,20 +54,22 @@ def render(args):
"""
query = database.session_query(Log)
filt = args.pop('filter')
filt = args.pop("filter")
if filt:
terms = filt.split(';')
terms = filt.split(";")
if 'certificate.name' in terms:
sub_query = database.session_query(Certificate.id)\
.filter(Certificate.name.ilike('%{0}%'.format(terms[1])))
if "certificate.name" in terms:
sub_query = database.session_query(Certificate.id).filter(
Certificate.name.ilike("%{0}%".format(terms[1]))
)
query = query.filter(Log.certificate_id.in_(sub_query))
elif 'user.email' in terms:
sub_query = database.session_query(User.id)\
.filter(User.email.ilike('%{0}%'.format(terms[1])))
elif "user.email" in terms:
sub_query = database.session_query(User.id).filter(
User.email.ilike("%{0}%".format(terms[1]))
)
query = query.filter(Log.user_id.in_(sub_query))

View File

@ -17,12 +17,13 @@ from lemur.logs.schemas import logs_output_schema
from lemur.logs import service
mod = Blueprint('logs', __name__)
mod = Blueprint("logs", __name__)
api = Api(mod)
class LogsList(AuthenticatedResource):
""" Defines the 'logs' endpoint """
def __init__(self):
self.reqparse = reqparse.RequestParser()
super(LogsList, self).__init__()
@ -65,10 +66,10 @@ class LogsList(AuthenticatedResource):
:statuscode 200: no error
"""
parser = paginated_parser.copy()
parser.add_argument('owner', type=str, location='args')
parser.add_argument('id', type=str, location='args')
parser.add_argument("owner", type=str, location="args")
parser.add_argument("id", type=str, location="args")
args = parser.parse_args()
return service.render(args)
api.add_resource(LogsList, '/logs', endpoint='logs')
api.add_resource(LogsList, "/logs", endpoint="logs")

View File

@ -52,24 +52,24 @@ from lemur.dns_providers.models import DnsProvider # noqa
from sqlalchemy.sql import text
manager = Manager(create_app)
manager.add_option('-c', '--config', dest='config_path', required=False)
manager.add_option("-c", "--config", dest="config_path", required=False)
migrate = Migrate(create_app)
REQUIRED_VARIABLES = [
'LEMUR_SECURITY_TEAM_EMAIL',
'LEMUR_DEFAULT_ORGANIZATIONAL_UNIT',
'LEMUR_DEFAULT_ORGANIZATION',
'LEMUR_DEFAULT_LOCATION',
'LEMUR_DEFAULT_COUNTRY',
'LEMUR_DEFAULT_STATE',
'SQLALCHEMY_DATABASE_URI'
"LEMUR_SECURITY_TEAM_EMAIL",
"LEMUR_DEFAULT_ORGANIZATIONAL_UNIT",
"LEMUR_DEFAULT_ORGANIZATION",
"LEMUR_DEFAULT_LOCATION",
"LEMUR_DEFAULT_COUNTRY",
"LEMUR_DEFAULT_STATE",
"SQLALCHEMY_DATABASE_URI",
]
KEY_LENGTH = 40
DEFAULT_CONFIG_PATH = '~/.lemur/lemur.conf.py'
DEFAULT_SETTINGS = 'lemur.conf.server'
SETTINGS_ENVVAR = 'LEMUR_CONF'
DEFAULT_CONFIG_PATH = "~/.lemur/lemur.conf.py"
DEFAULT_SETTINGS = "lemur.conf.server"
SETTINGS_ENVVAR = "LEMUR_CONF"
CONFIG_TEMPLATE = """
# This is just Python which means you can inherit and tweak settings
@ -144,9 +144,9 @@ SQLALCHEMY_DATABASE_URI = 'postgresql://lemur:lemur@localhost:5432/lemur'
@MigrateCommand.command
def create():
database.db.engine.execute(text('CREATE EXTENSION IF NOT EXISTS pg_trgm'))
database.db.engine.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
database.db.create_all()
stamp(revision='head')
stamp(revision="head")
@MigrateCommand.command
@ -174,9 +174,9 @@ def generate_settings():
output = CONFIG_TEMPLATE.format(
# we use Fernet.generate_key to make sure that the key length is
# compatible with Fernet
encryption_key=Fernet.generate_key().decode('utf-8'),
secret_token=base64.b64encode(os.urandom(KEY_LENGTH)).decode('utf-8'),
flask_secret_key=base64.b64encode(os.urandom(KEY_LENGTH)).decode('utf-8'),
encryption_key=Fernet.generate_key().decode("utf-8"),
secret_token=base64.b64encode(os.urandom(KEY_LENGTH)).decode("utf-8"),
flask_secret_key=base64.b64encode(os.urandom(KEY_LENGTH)).decode("utf-8"),
)
return output
@ -190,39 +190,44 @@ class InitializeApp(Command):
Additionally a Lemur user will be created as a default user
and be used when certificates are discovered by Lemur.
"""
option_list = (
Option('-p', '--password', dest='password'),
)
option_list = (Option("-p", "--password", dest="password"),)
def run(self, password):
create()
user = user_service.get_by_username("lemur")
admin_role = role_service.get_by_name('admin')
admin_role = role_service.get_by_name("admin")
if admin_role:
sys.stdout.write("[-] Admin role already created, skipping...!\n")
else:
# we create an admin role
admin_role = role_service.create('admin', description='This is the Lemur administrator role.')
admin_role = role_service.create(
"admin", description="This is the Lemur administrator role."
)
sys.stdout.write("[+] Created 'admin' role\n")
operator_role = role_service.get_by_name('operator')
operator_role = role_service.get_by_name("operator")
if operator_role:
sys.stdout.write("[-] Operator role already created, skipping...!\n")
else:
# we create an operator role
operator_role = role_service.create('operator', description='This is the Lemur operator role.')
operator_role = role_service.create(
"operator", description="This is the Lemur operator role."
)
sys.stdout.write("[+] Created 'operator' role\n")
read_only_role = role_service.get_by_name('read-only')
read_only_role = role_service.get_by_name("read-only")
if read_only_role:
sys.stdout.write("[-] Read only role already created, skipping...!\n")
else:
# we create an read only role
read_only_role = role_service.create('read-only', description='This is the Lemur read only role.')
read_only_role = role_service.create(
"read-only", description="This is the Lemur read only role."
)
sys.stdout.write("[+] Created 'read-only' role\n")
if not user:
@ -235,34 +240,54 @@ class InitializeApp(Command):
sys.stderr.write("[!] Passwords do not match!\n")
sys.exit(1)
user_service.create("lemur", password, 'lemur@nobody.com', True, None, [admin_role])
sys.stdout.write("[+] Created the user 'lemur' and granted it the 'admin' role!\n")
user_service.create(
"lemur", password, "lemur@nobody.com", True, None, [admin_role]
)
sys.stdout.write(
"[+] Created the user 'lemur' and granted it the 'admin' role!\n"
)
else:
sys.stdout.write("[-] Default user has already been created, skipping...!\n")
sys.stdout.write(
"[-] Default user has already been created, skipping...!\n"
)
intervals = current_app.config.get("LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", [])
intervals = current_app.config.get(
"LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS", []
)
sys.stdout.write(
"[!] Creating {num} notifications for {intervals} days as specified by LEMUR_DEFAULT_EXPIRATION_NOTIFICATION_INTERVALS\n".format(
num=len(intervals),
intervals=",".join([str(x) for x in intervals])
num=len(intervals), intervals=",".join([str(x) for x in intervals])
)
)
recipients = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL')
recipients = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL")
sys.stdout.write("[+] Creating expiration email notifications!\n")
sys.stdout.write("[!] Using {0} as specified by LEMUR_SECURITY_TEAM_EMAIL for notifications\n".format(recipients))
notification_service.create_default_expiration_notifications("DEFAULT_SECURITY", recipients=recipients)
sys.stdout.write(
"[!] Using {0} as specified by LEMUR_SECURITY_TEAM_EMAIL for notifications\n".format(
recipients
)
)
notification_service.create_default_expiration_notifications(
"DEFAULT_SECURITY", recipients=recipients
)
_DEFAULT_ROTATION_INTERVAL = 'default'
default_rotation_interval = policy_service.get_by_name(_DEFAULT_ROTATION_INTERVAL)
_DEFAULT_ROTATION_INTERVAL = "default"
default_rotation_interval = policy_service.get_by_name(
_DEFAULT_ROTATION_INTERVAL
)
if default_rotation_interval:
sys.stdout.write("[-] Default rotation interval policy already created, skipping...!\n")
sys.stdout.write(
"[-] Default rotation interval policy already created, skipping...!\n"
)
else:
days = current_app.config.get("LEMUR_DEFAULT_ROTATION_INTERVAL", 30)
sys.stdout.write("[+] Creating default certificate rotation policy of {days} days before issuance.\n".format(
days=days))
sys.stdout.write(
"[+] Creating default certificate rotation policy of {days} days before issuance.\n".format(
days=days
)
)
policy_service.create(days=days, name=_DEFAULT_ROTATION_INTERVAL)
sys.stdout.write("[/] Done!\n")
@ -272,12 +297,13 @@ class CreateUser(Command):
"""
This command allows for the creation of a new user within Lemur.
"""
option_list = (
Option('-u', '--username', dest='username', required=True),
Option('-e', '--email', dest='email', required=True),
Option('-a', '--active', dest='active', default=True),
Option('-r', '--roles', dest='roles', action='append', default=[]),
Option('-p', '--password', dest='password', default=None)
Option("-u", "--username", dest="username", required=True),
Option("-e", "--email", dest="email", required=True),
Option("-a", "--active", dest="active", default=True),
Option("-r", "--roles", dest="roles", action="append", default=[]),
Option("-p", "--password", dest="password", default=None),
)
def run(self, username, email, active, roles, password):
@ -307,9 +333,8 @@ class ResetPassword(Command):
"""
This command allows you to reset a user's password.
"""
option_list = (
Option('-u', '--username', dest='username', required=True),
)
option_list = (Option("-u", "--username", dest="username", required=True),)
def run(self, username):
user = user_service.get_by_username(username)
@ -335,10 +360,11 @@ class CreateRole(Command):
"""
This command allows for the creation of a new role within Lemur
"""
option_list = (
Option('-n', '--name', dest='name', required=True),
Option('-u', '--users', dest='users', default=[]),
Option('-d', '--description', dest='description', required=True)
Option("-n", "--name", dest="name", required=True),
Option("-u", "--users", dest="users", default=[]),
Option("-d", "--description", dest="description", required=True),
)
def run(self, name, users, description):
@ -369,7 +395,8 @@ class LemurServer(Command):
Will start gunicorn with 4 workers bound to 127.0.0.0:8002
"""
description = 'Run the app within Gunicorn'
description = "Run the app within Gunicorn"
def get_options(self):
settings = make_settings()
@ -377,8 +404,10 @@ class LemurServer(Command):
for setting, klass in settings.items():
if klass.cli:
if klass.action:
if klass.action == 'store_const':
options.append(Option(*klass.cli, const=klass.const, action=klass.action))
if klass.action == "store_const":
options.append(
Option(*klass.cli, const=klass.const, action=klass.action)
)
else:
options.append(Option(*klass.cli, action=klass.action))
else:
@ -394,7 +423,9 @@ class LemurServer(Command):
# run startup tasks on an app like object
validate_conf(current_app, REQUIRED_VARIABLES)
app.app_uri = 'lemur:create_app(config_path="{0}")'.format(current_app.config.get('CONFIG_PATH'))
app.app_uri = 'lemur:create_app(config_path="{0}")'.format(
current_app.config.get("CONFIG_PATH")
)
return app.run()
@ -414,7 +445,7 @@ def create_config(config_path=None):
os.makedirs(dir)
config = generate_settings()
with open(config_path, 'w') as f:
with open(config_path, "w") as f:
f.write(config)
sys.stdout.write("[+] Created a new configuration file {0}\n".format(config_path))
@ -436,7 +467,7 @@ def lock(path=None):
:param: path
"""
if not path:
path = os.path.expanduser('~/.lemur/keys')
path = os.path.expanduser("~/.lemur/keys")
dest_dir = os.path.join(path, "encrypted")
sys.stdout.write("[!] Generating a new key...\n")
@ -447,15 +478,17 @@ def lock(path=None):
sys.stdout.write("[+] Creating encryption directory: {0}\n".format(dest_dir))
os.makedirs(dest_dir)
for root, dirs, files in os.walk(os.path.join(path, 'decrypted')):
for root, dirs, files in os.walk(os.path.join(path, "decrypted")):
for f in files:
source = os.path.join(root, f)
dest = os.path.join(dest_dir, f + ".enc")
with open(source, 'rb') as in_file, open(dest, 'wb') as out_file:
with open(source, "rb") as in_file, open(dest, "wb") as out_file:
f = Fernet(key)
data = f.encrypt(in_file.read())
out_file.write(data)
sys.stdout.write("[+] Writing file: {0} Source: {1}\n".format(dest, source))
sys.stdout.write(
"[+] Writing file: {0} Source: {1}\n".format(dest, source)
)
sys.stdout.write("[+] Keys have been encrypted with key {0}\n".format(key))
@ -475,7 +508,7 @@ def unlock(path=None):
key = prompt_pass("[!] Please enter the encryption password")
if not path:
path = os.path.expanduser('~/.lemur/keys')
path = os.path.expanduser("~/.lemur/keys")
dest_dir = os.path.join(path, "decrypted")
source_dir = os.path.join(path, "encrypted")
@ -488,11 +521,13 @@ def unlock(path=None):
for f in files:
source = os.path.join(source_dir, f)
dest = os.path.join(dest_dir, ".".join(f.split(".")[:-1]))
with open(source, 'rb') as in_file, open(dest, 'wb') as out_file:
with open(source, "rb") as in_file, open(dest, "wb") as out_file:
f = Fernet(key)
data = f.decrypt(in_file.read())
out_file.write(data)
sys.stdout.write("[+] Writing file: {0} Source: {1}\n".format(dest, source))
sys.stdout.write(
"[+] Writing file: {0} Source: {1}\n".format(dest, source)
)
sys.stdout.write("[+] Keys have been unencrypted!\n")
@ -505,15 +540,16 @@ def publish_verisign_units():
:return:
"""
from lemur.plugins import plugins
v = plugins.get('verisign-issuer')
v = plugins.get("verisign-issuer")
units = v.get_available_units()
metrics = {}
for item in units:
if item['@type'] in metrics.keys():
metrics[item['@type']] += int(item['@remaining'])
if item["@type"] in metrics.keys():
metrics[item["@type"]] += int(item["@remaining"])
else:
metrics.update({item['@type']: int(item['@remaining'])})
metrics.update({item["@type"]: int(item["@remaining"])})
for name, value in metrics.items():
metric = [
@ -522,16 +558,16 @@ def publish_verisign_units():
"type": "GAUGE",
"name": "Symantec {0} Unit Count".format(name),
"tags": {},
"value": value
"value": value,
}
]
requests.post('http://localhost:8078/metrics', data=json.dumps(metric))
requests.post("http://localhost:8078/metrics", data=json.dumps(metric))
def main():
manager.add_command("start", LemurServer())
manager.add_command("runserver", Server(host='127.0.0.1', threaded=True))
manager.add_command("runserver", Server(host="127.0.0.1", threaded=True))
manager.add_command("clean", Clean())
manager.add_command("show_urls", ShowUrls())
manager.add_command("db", MigrateCommand)

View File

@ -11,6 +11,7 @@ class Metrics(object):
"""
:param app: The Flask application object. Defaults to None.
"""
_providers = []
def __init__(self, app=None):
@ -22,11 +23,14 @@ class Metrics(object):
:param app: The Flask application object.
"""
self._providers = app.config.get('METRIC_PROVIDERS', [])
self._providers = app.config.get("METRIC_PROVIDERS", [])
def send(self, metric_name, metric_type, metric_value, *args, **kwargs):
for provider in self._providers:
current_app.logger.debug(
"Sending metric '{metric}' to the {provider} provider.".format(metric=metric_name, provider=provider))
"Sending metric '{metric}' to the {provider} provider.".format(
metric=metric_name, provider=provider
)
)
p = plugins.get(provider)
p.submit(metric_name, metric_type, metric_value, *args, **kwargs)

View File

@ -19,8 +19,11 @@ fileConfig(config.config_file_name)
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from flask import current_app
config.set_main_option('sqlalchemy.url', current_app.config.get('SQLALCHEMY_DATABASE_URI'))
target_metadata = current_app.extensions['migrate'].db.metadata
config.set_main_option(
"sqlalchemy.url", current_app.config.get("SQLALCHEMY_DATABASE_URI")
)
target_metadata = current_app.extensions["migrate"].db.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
@ -54,14 +57,18 @@ def run_migrations_online():
and associate a connection with the context.
"""
engine = engine_from_config(config.get_section(config.config_ini_section),
prefix='sqlalchemy.',
poolclass=pool.NullPool)
engine = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
connection = engine.connect()
context.configure(connection=connection,
context.configure(
connection=connection,
target_metadata=target_metadata,
**current_app.extensions['migrate'].configure_args)
**current_app.extensions["migrate"].configure_args
)
try:
with context.begin_transaction():
@ -69,8 +76,8 @@ def run_migrations_online():
finally:
connection.close()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@ -7,8 +7,8 @@ Create Date: 2016-12-07 17:29:42.049986
"""
# revision identifiers, used by Alembic.
revision = '131ec6accff5'
down_revision = 'e3691fc396e9'
revision = "131ec6accff5"
down_revision = "e3691fc396e9"
from alembic import op
import sqlalchemy as sa
@ -16,13 +16,24 @@ import sqlalchemy as sa
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('certificates', sa.Column('rotation', sa.Boolean(), nullable=False, server_default=sa.false()))
op.add_column('endpoints', sa.Column('last_updated', sa.DateTime(), server_default=sa.text('now()'), nullable=False))
op.add_column(
"certificates",
sa.Column("rotation", sa.Boolean(), nullable=False, server_default=sa.false()),
)
op.add_column(
"endpoints",
sa.Column(
"last_updated",
sa.DateTime(),
server_default=sa.text("now()"),
nullable=False,
),
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('endpoints', 'last_updated')
op.drop_column('certificates', 'rotation')
op.drop_column("endpoints", "last_updated")
op.drop_column("certificates", "rotation")
# ### end Alembic commands ###

View File

@ -7,15 +7,19 @@ Create Date: 2017-07-13 12:32:09.162800
"""
# revision identifiers, used by Alembic.
revision = '1ae8e3104db8'
down_revision = 'a02a678ddc25'
revision = "1ae8e3104db8"
down_revision = "a02a678ddc25"
from alembic import op
def upgrade():
op.sync_enum_values('public', 'log_type', ['key_view'], ['create_cert', 'key_view', 'update_cert'])
op.sync_enum_values(
"public", "log_type", ["key_view"], ["create_cert", "key_view", "update_cert"]
)
def downgrade():
op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'update_cert'], ['key_view'])
op.sync_enum_values(
"public", "log_type", ["create_cert", "key_view", "update_cert"], ["key_view"]
)

View File

@ -7,8 +7,8 @@ Create Date: 2018-08-03 12:56:44.565230
"""
# revision identifiers, used by Alembic.
revision = '1db4f82bc780'
down_revision = '3adfdd6598df'
revision = "1db4f82bc780"
down_revision = "3adfdd6598df"
import logging
@ -20,12 +20,14 @@ log = logging.getLogger(__name__)
def upgrade():
connection = op.get_bind()
result = connection.execute("""\
result = connection.execute(
"""\
UPDATE certificates
SET rotation_policy_id=(SELECT id FROM rotation_policies WHERE name='default')
WHERE rotation_policy_id IS NULL
RETURNING id
""")
"""
)
log.info("Filled rotation_policy for %d certificates" % result.rowcount)

View File

@ -7,8 +7,8 @@ Create Date: 2016-06-28 16:05:25.720213
"""
# revision identifiers, used by Alembic.
revision = '29d8c8455c86'
down_revision = '3307381f3b88'
revision = "29d8c8455c86"
down_revision = "3307381f3b88"
from alembic import op
import sqlalchemy as sa
@ -17,46 +17,60 @@ from sqlalchemy.dialects import postgresql
def upgrade():
### commands auto generated by Alembic - please adjust! ###
op.create_table('ciphers',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=128), nullable=False),
sa.PrimaryKeyConstraint('id')
op.create_table(
"ciphers",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(length=128), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_table('policy',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=128), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"policy",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(length=128), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table('policies_ciphers',
sa.Column('cipher_id', sa.Integer(), nullable=True),
sa.Column('policy_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['cipher_id'], ['ciphers.id'], ),
sa.ForeignKeyConstraint(['policy_id'], ['policy.id'], )
op.create_table(
"policies_ciphers",
sa.Column("cipher_id", sa.Integer(), nullable=True),
sa.Column("policy_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["cipher_id"], ["ciphers.id"]),
sa.ForeignKeyConstraint(["policy_id"], ["policy.id"]),
)
op.create_index('policies_ciphers_ix', 'policies_ciphers', ['cipher_id', 'policy_id'], unique=False)
op.create_table('endpoints',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('owner', sa.String(length=128), nullable=True),
sa.Column('name', sa.String(length=128), nullable=True),
sa.Column('dnsname', sa.String(length=256), nullable=True),
sa.Column('type', sa.String(length=128), nullable=True),
sa.Column('active', sa.Boolean(), nullable=True),
sa.Column('port', sa.Integer(), nullable=True),
sa.Column('date_created', sa.DateTime(), server_default=sa.text(u'now()'), nullable=False),
sa.Column('policy_id', sa.Integer(), nullable=True),
sa.Column('certificate_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ),
sa.ForeignKeyConstraint(['policy_id'], ['policy.id'], ),
sa.PrimaryKeyConstraint('id')
op.create_index(
"policies_ciphers_ix",
"policies_ciphers",
["cipher_id", "policy_id"],
unique=False,
)
op.create_table(
"endpoints",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("owner", sa.String(length=128), nullable=True),
sa.Column("name", sa.String(length=128), nullable=True),
sa.Column("dnsname", sa.String(length=256), nullable=True),
sa.Column("type", sa.String(length=128), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.Column("port", sa.Integer(), nullable=True),
sa.Column(
"date_created",
sa.DateTime(),
server_default=sa.text(u"now()"),
nullable=False,
),
sa.Column("policy_id", sa.Integer(), nullable=True),
sa.Column("certificate_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]),
sa.ForeignKeyConstraint(["policy_id"], ["policy.id"]),
sa.PrimaryKeyConstraint("id"),
)
### end Alembic commands ###
def downgrade():
### commands auto generated by Alembic - please adjust! ###
op.drop_table('endpoints')
op.drop_index('policies_ciphers_ix', table_name='policies_ciphers')
op.drop_table('policies_ciphers')
op.drop_table('policy')
op.drop_table('ciphers')
op.drop_table("endpoints")
op.drop_index("policies_ciphers_ix", table_name="policies_ciphers")
op.drop_table("policies_ciphers")
op.drop_table("policy")
op.drop_table("ciphers")
### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2019-02-05 15:42:25.477587
"""
# revision identifiers, used by Alembic.
revision = '318b66568358'
down_revision = '9f79024fe67b'
revision = "318b66568358"
down_revision = "9f79024fe67b"
from alembic import op
@ -16,7 +16,7 @@ from alembic import op
def upgrade():
connection = op.get_bind()
# Delete duplicate entries
connection.execute('UPDATE certificates SET deleted = false WHERE deleted IS NULL')
connection.execute("UPDATE certificates SET deleted = false WHERE deleted IS NULL")
def downgrade():

View File

@ -12,8 +12,8 @@ Create Date: 2016-05-20 17:33:04.360687
"""
# revision identifiers, used by Alembic.
revision = '3307381f3b88'
down_revision = '412b22cb656a'
revision = "3307381f3b88"
down_revision = "412b22cb656a"
from alembic import op
import sqlalchemy as sa
@ -23,109 +23,165 @@ from sqlalchemy.dialects import postgresql
def upgrade():
### commands auto generated by Alembic - please adjust! ###
op.alter_column('authorities', 'owner',
existing_type=sa.VARCHAR(length=128),
nullable=True)
op.drop_column('authorities', 'not_after')
op.drop_column('authorities', 'bits')
op.drop_column('authorities', 'cn')
op.drop_column('authorities', 'not_before')
op.add_column('certificates', sa.Column('root_authority_id', sa.Integer(), nullable=True))
op.alter_column('certificates', 'body',
existing_type=sa.TEXT(),
nullable=False)
op.alter_column('certificates', 'owner',
existing_type=sa.VARCHAR(length=128),
nullable=True)
op.drop_constraint(u'certificates_authority_id_fkey', 'certificates', type_='foreignkey')
op.create_foreign_key(None, 'certificates', 'authorities', ['authority_id'], ['id'], ondelete='CASCADE')
op.create_foreign_key(None, 'certificates', 'authorities', ['root_authority_id'], ['id'], ondelete='CASCADE')
op.alter_column(
"authorities", "owner", existing_type=sa.VARCHAR(length=128), nullable=True
)
op.drop_column("authorities", "not_after")
op.drop_column("authorities", "bits")
op.drop_column("authorities", "cn")
op.drop_column("authorities", "not_before")
op.add_column(
"certificates", sa.Column("root_authority_id", sa.Integer(), nullable=True)
)
op.alter_column("certificates", "body", existing_type=sa.TEXT(), nullable=False)
op.alter_column(
"certificates", "owner", existing_type=sa.VARCHAR(length=128), nullable=True
)
op.drop_constraint(
u"certificates_authority_id_fkey", "certificates", type_="foreignkey"
)
op.create_foreign_key(
None,
"certificates",
"authorities",
["authority_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
None,
"certificates",
"authorities",
["root_authority_id"],
["id"],
ondelete="CASCADE",
)
### end Alembic commands ###
# link existing certificate to their authority certificates
conn = op.get_bind()
for id, body, owner in conn.execute(text('select id, body, owner from authorities')):
for id, body, owner in conn.execute(
text("select id, body, owner from authorities")
):
if not owner:
owner = "lemur@nobody"
# look up certificate by body, if duplications are found, pick one
stmt = text('select id from certificates where body=:body')
stmt = text("select id from certificates where body=:body")
stmt = stmt.bindparams(body=body)
root_certificate = conn.execute(stmt).fetchone()
if root_certificate:
stmt = text('update certificates set root_authority_id=:root_authority_id where id=:id')
stmt = text(
"update certificates set root_authority_id=:root_authority_id where id=:id"
)
stmt = stmt.bindparams(root_authority_id=id, id=root_certificate[0])
op.execute(stmt)
# link owner roles to their authorities
stmt = text('select id from roles where name=:name')
stmt = text("select id from roles where name=:name")
stmt = stmt.bindparams(name=owner)
owner_role = conn.execute(stmt).fetchone()
if not owner_role:
stmt = text('insert into roles (name, description) values (:name, :description)')
stmt = stmt.bindparams(name=owner, description='Lemur generated role or existing owner.')
stmt = text(
"insert into roles (name, description) values (:name, :description)"
)
stmt = stmt.bindparams(
name=owner, description="Lemur generated role or existing owner."
)
op.execute(stmt)
stmt = text('select id from roles where name=:name')
stmt = text("select id from roles where name=:name")
stmt = stmt.bindparams(name=owner)
owner_role = conn.execute(stmt).fetchone()
stmt = text('select * from roles_authorities where role_id=:role_id and authority_id=:authority_id')
stmt = text(
"select * from roles_authorities where role_id=:role_id and authority_id=:authority_id"
)
stmt = stmt.bindparams(role_id=owner_role[0], authority_id=id)
exists = conn.execute(stmt).fetchone()
if not exists:
stmt = text('insert into roles_authorities (role_id, authority_id) values (:role_id, :authority_id)')
stmt = text(
"insert into roles_authorities (role_id, authority_id) values (:role_id, :authority_id)"
)
stmt = stmt.bindparams(role_id=owner_role[0], authority_id=id)
op.execute(stmt)
# link owner roles to their certificates
for id, owner in conn.execute(text('select id, owner from certificates')):
for id, owner in conn.execute(text("select id, owner from certificates")):
if not owner:
owner = "lemur@nobody"
stmt = text('select id from roles where name=:name')
stmt = text("select id from roles where name=:name")
stmt = stmt.bindparams(name=owner)
owner_role = conn.execute(stmt).fetchone()
if not owner_role:
stmt = text('insert into roles (name, description) values (:name, :description)')
stmt = stmt.bindparams(name=owner, description='Lemur generated role or existing owner.')
stmt = text(
"insert into roles (name, description) values (:name, :description)"
)
stmt = stmt.bindparams(
name=owner, description="Lemur generated role or existing owner."
)
op.execute(stmt)
# link owner roles to their authorities
stmt = text('select id from roles where name=:name')
stmt = text("select id from roles where name=:name")
stmt = stmt.bindparams(name=owner)
owner_role = conn.execute(stmt).fetchone()
stmt = text('select * from roles_certificates where role_id=:role_id and certificate_id=:certificate_id')
stmt = text(
"select * from roles_certificates where role_id=:role_id and certificate_id=:certificate_id"
)
stmt = stmt.bindparams(role_id=owner_role[0], certificate_id=id)
exists = conn.execute(stmt).fetchone()
if not exists:
stmt = text('insert into roles_certificates (role_id, certificate_id) values (:role_id, :certificate_id)')
stmt = text(
"insert into roles_certificates (role_id, certificate_id) values (:role_id, :certificate_id)"
)
stmt = stmt.bindparams(role_id=owner_role[0], certificate_id=id)
op.execute(stmt)
def downgrade():
### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'certificates', type_='foreignkey')
op.drop_constraint(None, 'certificates', type_='foreignkey')
op.create_foreign_key(u'certificates_authority_id_fkey', 'certificates', 'authorities', ['authority_id'], ['id'])
op.alter_column('certificates', 'owner',
existing_type=sa.VARCHAR(length=128),
nullable=True)
op.alter_column('certificates', 'body',
existing_type=sa.TEXT(),
nullable=True)
op.drop_column('certificates', 'root_authority_id')
op.add_column('authorities', sa.Column('not_before', postgresql.TIMESTAMP(), autoincrement=False, nullable=True))
op.add_column('authorities', sa.Column('cn', sa.VARCHAR(length=128), autoincrement=False, nullable=True))
op.add_column('authorities', sa.Column('bits', sa.INTEGER(), autoincrement=False, nullable=True))
op.add_column('authorities', sa.Column('not_after', postgresql.TIMESTAMP(), autoincrement=False, nullable=True))
op.alter_column('authorities', 'owner',
existing_type=sa.VARCHAR(length=128),
nullable=True)
op.drop_constraint(None, "certificates", type_="foreignkey")
op.drop_constraint(None, "certificates", type_="foreignkey")
op.create_foreign_key(
u"certificates_authority_id_fkey",
"certificates",
"authorities",
["authority_id"],
["id"],
)
op.alter_column(
"certificates", "owner", existing_type=sa.VARCHAR(length=128), nullable=True
)
op.alter_column("certificates", "body", existing_type=sa.TEXT(), nullable=True)
op.drop_column("certificates", "root_authority_id")
op.add_column(
"authorities",
sa.Column(
"not_before", postgresql.TIMESTAMP(), autoincrement=False, nullable=True
),
)
op.add_column(
"authorities",
sa.Column("cn", sa.VARCHAR(length=128), autoincrement=False, nullable=True),
)
op.add_column(
"authorities",
sa.Column("bits", sa.INTEGER(), autoincrement=False, nullable=True),
)
op.add_column(
"authorities",
sa.Column(
"not_after", postgresql.TIMESTAMP(), autoincrement=False, nullable=True
),
)
op.alter_column(
"authorities", "owner", existing_type=sa.VARCHAR(length=128), nullable=True
)
### end Alembic commands ###

View File

@ -7,25 +7,31 @@ Create Date: 2015-11-30 15:40:19.827272
"""
# revision identifiers, used by Alembic.
revision = '33de094da890'
revision = "33de094da890"
down_revision = None
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
def upgrade():
### commands auto generated by Alembic - please adjust! ###
op.create_table('certificate_replacement_associations',
sa.Column('replaced_certificate_id', sa.Integer(), nullable=True),
sa.Column('certificate_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ondelete='cascade'),
sa.ForeignKeyConstraint(['replaced_certificate_id'], ['certificates.id'], ondelete='cascade')
op.create_table(
"certificate_replacement_associations",
sa.Column("replaced_certificate_id", sa.Integer(), nullable=True),
sa.Column("certificate_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["certificate_id"], ["certificates.id"], ondelete="cascade"
),
sa.ForeignKeyConstraint(
["replaced_certificate_id"], ["certificates.id"], ondelete="cascade"
),
)
### end Alembic commands ###
def downgrade():
### commands auto generated by Alembic - please adjust! ###
op.drop_table('certificate_replacement_associations')
op.drop_table("certificate_replacement_associations")
### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2018-04-10 13:25:47.007556
"""
# revision identifiers, used by Alembic.
revision = '3adfdd6598df'
down_revision = '556ceb3e3c3e'
revision = "3adfdd6598df"
down_revision = "556ceb3e3c3e"
import sqlalchemy as sa
from alembic import op
@ -22,84 +22,90 @@ def upgrade():
# create provider table
print("Creating dns_providers table")
op.create_table(
'dns_providers',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=256), nullable=True),
sa.Column('description', sa.String(length=1024), nullable=True),
sa.Column('provider_type', sa.String(length=256), nullable=True),
sa.Column('credentials', Vault(), nullable=True),
sa.Column('api_endpoint', sa.String(length=256), nullable=True),
sa.Column('date_created', ArrowType(), server_default=sa.text('now()'), nullable=False),
sa.Column('status', sa.String(length=128), nullable=True),
sa.Column('options', JSON),
sa.Column('domains', sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name')
"dns_providers",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(length=256), nullable=True),
sa.Column("description", sa.String(length=1024), nullable=True),
sa.Column("provider_type", sa.String(length=256), nullable=True),
sa.Column("credentials", Vault(), nullable=True),
sa.Column("api_endpoint", sa.String(length=256), nullable=True),
sa.Column(
"date_created", ArrowType(), server_default=sa.text("now()"), nullable=False
),
sa.Column("status", sa.String(length=128), nullable=True),
sa.Column("options", JSON),
sa.Column("domains", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
print("Adding dns_provider_id column to certificates")
op.add_column('certificates', sa.Column('dns_provider_id', sa.Integer(), nullable=True))
op.add_column(
"certificates", sa.Column("dns_provider_id", sa.Integer(), nullable=True)
)
print("Adding dns_provider_id column to pending_certs")
op.add_column('pending_certs', sa.Column('dns_provider_id', sa.Integer(), nullable=True))
op.add_column(
"pending_certs", sa.Column("dns_provider_id", sa.Integer(), nullable=True)
)
print("Adding options column to pending_certs")
op.add_column('pending_certs', sa.Column('options', JSON))
op.add_column("pending_certs", sa.Column("options", JSON))
print("Creating pending_dns_authorizations table")
op.create_table(
'pending_dns_authorizations',
sa.Column('id', sa.Integer(), primary_key=True, autoincrement=True),
sa.Column('account_number', sa.String(length=128), nullable=True),
sa.Column('domains', JSON, nullable=True),
sa.Column('dns_provider_type', sa.String(length=128), nullable=True),
sa.Column('options', JSON, nullable=True),
"pending_dns_authorizations",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("account_number", sa.String(length=128), nullable=True),
sa.Column("domains", JSON, nullable=True),
sa.Column("dns_provider_type", sa.String(length=128), nullable=True),
sa.Column("options", JSON, nullable=True),
)
print("Creating certificates_dns_providers_fk foreign key")
op.create_foreign_key('certificates_dns_providers_fk', 'certificates', 'dns_providers', ['dns_provider_id'], ['id'],
ondelete='cascade')
op.create_foreign_key(
"certificates_dns_providers_fk",
"certificates",
"dns_providers",
["dns_provider_id"],
["id"],
ondelete="cascade",
)
print("Altering column types in the api_keys table")
op.alter_column('api_keys', 'issued_at',
existing_type=sa.BIGINT(),
nullable=True)
op.alter_column('api_keys', 'revoked',
existing_type=sa.BOOLEAN(),
nullable=True)
op.alter_column('api_keys', 'ttl',
existing_type=sa.BIGINT(),
nullable=True)
op.alter_column('api_keys', 'user_id',
existing_type=sa.INTEGER(),
nullable=True)
op.alter_column("api_keys", "issued_at", existing_type=sa.BIGINT(), nullable=True)
op.alter_column("api_keys", "revoked", existing_type=sa.BOOLEAN(), nullable=True)
op.alter_column("api_keys", "ttl", existing_type=sa.BIGINT(), nullable=True)
op.alter_column("api_keys", "user_id", existing_type=sa.INTEGER(), nullable=True)
print("Creating dns_providers_id foreign key on pending_certs table")
op.create_foreign_key(None, 'pending_certs', 'dns_providers', ['dns_provider_id'], ['id'], ondelete='CASCADE')
op.create_foreign_key(
None,
"pending_certs",
"dns_providers",
["dns_provider_id"],
["id"],
ondelete="CASCADE",
)
def downgrade():
print("Removing dns_providers_id foreign key on pending_certs table")
op.drop_constraint(None, 'pending_certs', type_='foreignkey')
op.drop_constraint(None, "pending_certs", type_="foreignkey")
print("Reverting column types in the api_keys table")
op.alter_column('api_keys', 'user_id',
existing_type=sa.INTEGER(),
nullable=False)
op.alter_column('api_keys', 'ttl',
existing_type=sa.BIGINT(),
nullable=False)
op.alter_column('api_keys', 'revoked',
existing_type=sa.BOOLEAN(),
nullable=False)
op.alter_column('api_keys', 'issued_at',
existing_type=sa.BIGINT(),
nullable=False)
op.alter_column("api_keys", "user_id", existing_type=sa.INTEGER(), nullable=False)
op.alter_column("api_keys", "ttl", existing_type=sa.BIGINT(), nullable=False)
op.alter_column("api_keys", "revoked", existing_type=sa.BOOLEAN(), nullable=False)
op.alter_column("api_keys", "issued_at", existing_type=sa.BIGINT(), nullable=False)
print("Reverting certificates_dns_providers_fk foreign key")
op.drop_constraint('certificates_dns_providers_fk', 'certificates', type_='foreignkey')
op.drop_constraint(
"certificates_dns_providers_fk", "certificates", type_="foreignkey"
)
print("Dropping pending_dns_authorizations table")
op.drop_table('pending_dns_authorizations')
op.drop_table("pending_dns_authorizations")
print("Undoing modifications to pending_certs table")
op.drop_column('pending_certs', 'options')
op.drop_column('pending_certs', 'dns_provider_id')
op.drop_column("pending_certs", "options")
op.drop_column("pending_certs", "dns_provider_id")
print("Undoing modifications to certificates table")
op.drop_column('certificates', 'dns_provider_id')
op.drop_column("certificates", "dns_provider_id")
print("Deleting dns_providers table")
op.drop_table('dns_providers')
op.drop_table("dns_providers")

View File

@ -7,8 +7,8 @@ Create Date: 2016-05-17 17:37:41.210232
"""
# revision identifiers, used by Alembic.
revision = '412b22cb656a'
down_revision = '4c50b903d1ae'
revision = "412b22cb656a"
down_revision = "4c50b903d1ae"
from alembic import op
import sqlalchemy as sa
@ -17,47 +17,102 @@ from sqlalchemy.sql import text
def upgrade():
### commands auto generated by Alembic - please adjust! ###
op.create_table('roles_authorities',
sa.Column('authority_id', sa.Integer(), nullable=True),
sa.Column('role_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['authority_id'], ['authorities.id'], ),
sa.ForeignKeyConstraint(['role_id'], ['roles.id'], )
op.create_table(
"roles_authorities",
sa.Column("authority_id", sa.Integer(), nullable=True),
sa.Column("role_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["authority_id"], ["authorities.id"]),
sa.ForeignKeyConstraint(["role_id"], ["roles.id"]),
)
op.create_index('roles_authorities_ix', 'roles_authorities', ['authority_id', 'role_id'], unique=True)
op.create_table('roles_certificates',
sa.Column('certificate_id', sa.Integer(), nullable=True),
sa.Column('role_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ),
sa.ForeignKeyConstraint(['role_id'], ['roles.id'], )
op.create_index(
"roles_authorities_ix",
"roles_authorities",
["authority_id", "role_id"],
unique=True,
)
op.create_table(
"roles_certificates",
sa.Column("certificate_id", sa.Integer(), nullable=True),
sa.Column("role_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]),
sa.ForeignKeyConstraint(["role_id"], ["roles.id"]),
)
op.create_index(
"roles_certificates_ix",
"roles_certificates",
["certificate_id", "role_id"],
unique=True,
)
op.create_index(
"certificate_associations_ix",
"certificate_associations",
["domain_id", "certificate_id"],
unique=True,
)
op.create_index(
"certificate_destination_associations_ix",
"certificate_destination_associations",
["destination_id", "certificate_id"],
unique=True,
)
op.create_index(
"certificate_notification_associations_ix",
"certificate_notification_associations",
["notification_id", "certificate_id"],
unique=True,
)
op.create_index(
"certificate_replacement_associations_ix",
"certificate_replacement_associations",
["certificate_id", "certificate_id"],
unique=True,
)
op.create_index(
"certificate_source_associations_ix",
"certificate_source_associations",
["source_id", "certificate_id"],
unique=True,
)
op.create_index(
"roles_users_ix", "roles_users", ["user_id", "role_id"], unique=True
)
op.create_index('roles_certificates_ix', 'roles_certificates', ['certificate_id', 'role_id'], unique=True)
op.create_index('certificate_associations_ix', 'certificate_associations', ['domain_id', 'certificate_id'], unique=True)
op.create_index('certificate_destination_associations_ix', 'certificate_destination_associations', ['destination_id', 'certificate_id'], unique=True)
op.create_index('certificate_notification_associations_ix', 'certificate_notification_associations', ['notification_id', 'certificate_id'], unique=True)
op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['certificate_id', 'certificate_id'], unique=True)
op.create_index('certificate_source_associations_ix', 'certificate_source_associations', ['source_id', 'certificate_id'], unique=True)
op.create_index('roles_users_ix', 'roles_users', ['user_id', 'role_id'], unique=True)
### end Alembic commands ###
# migrate existing authority_id relationship to many_to_many
conn = op.get_bind()
for id, authority_id in conn.execute(text('select id, authority_id from roles where authority_id is not null')):
stmt = text('insert into roles_authoritties (role_id, authority_id) values (:role_id, :authority_id)')
for id, authority_id in conn.execute(
text("select id, authority_id from roles where authority_id is not null")
):
stmt = text(
"insert into roles_authoritties (role_id, authority_id) values (:role_id, :authority_id)"
)
stmt = stmt.bindparams(role_id=id, authority_id=authority_id)
op.execute(stmt)
def downgrade():
### commands auto generated by Alembic - please adjust! ###
op.drop_index('roles_users_ix', table_name='roles_users')
op.drop_index('certificate_source_associations_ix', table_name='certificate_source_associations')
op.drop_index('certificate_replacement_associations_ix', table_name='certificate_replacement_associations')
op.drop_index('certificate_notification_associations_ix', table_name='certificate_notification_associations')
op.drop_index('certificate_destination_associations_ix', table_name='certificate_destination_associations')
op.drop_index('certificate_associations_ix', table_name='certificate_associations')
op.drop_index('roles_certificates_ix', table_name='roles_certificates')
op.drop_table('roles_certificates')
op.drop_index('roles_authorities_ix', table_name='roles_authorities')
op.drop_table('roles_authorities')
op.drop_index("roles_users_ix", table_name="roles_users")
op.drop_index(
"certificate_source_associations_ix",
table_name="certificate_source_associations",
)
op.drop_index(
"certificate_replacement_associations_ix",
table_name="certificate_replacement_associations",
)
op.drop_index(
"certificate_notification_associations_ix",
table_name="certificate_notification_associations",
)
op.drop_index(
"certificate_destination_associations_ix",
table_name="certificate_destination_associations",
)
op.drop_index("certificate_associations_ix", table_name="certificate_associations")
op.drop_index("roles_certificates_ix", table_name="roles_certificates")
op.drop_table("roles_certificates")
op.drop_index("roles_authorities_ix", table_name="roles_authorities")
op.drop_table("roles_authorities")
### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2018-02-24 22:51:35.369229
"""
# revision identifiers, used by Alembic.
revision = '449c3d5c7299'
down_revision = '5770674184de'
revision = "449c3d5c7299"
down_revision = "5770674184de"
from alembic import op
from flask_sqlalchemy import SQLAlchemy
@ -23,12 +23,14 @@ COLUMNS = ["notification_id", "certificate_id"]
def upgrade():
connection = op.get_bind()
# Delete duplicate entries
connection.execute("""\
connection.execute(
"""\
DELETE FROM certificate_notification_associations WHERE ctid NOT IN (
-- Select the first tuple ID for each (notification_id, certificate_id) combination and keep that
SELECT min(ctid) FROM certificate_notification_associations GROUP BY notification_id, certificate_id
)
""")
"""
)
op.create_unique_constraint(CONSTRAINT_NAME, TABLE, COLUMNS)

View File

@ -7,20 +7,21 @@ Create Date: 2015-12-30 10:19:30.057791
"""
# revision identifiers, used by Alembic.
revision = '4c50b903d1ae'
down_revision = '33de094da890'
revision = "4c50b903d1ae"
down_revision = "33de094da890"
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
def upgrade():
### commands auto generated by Alembic - please adjust! ###
op.add_column('domains', sa.Column('sensitive', sa.Boolean(), nullable=True))
op.add_column("domains", sa.Column("sensitive", sa.Boolean(), nullable=True))
### end Alembic commands ###
def downgrade():
### commands auto generated by Alembic - please adjust! ###
op.drop_column('domains', 'sensitive')
op.drop_column("domains", "sensitive")
### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2018-01-05 01:18:45.571595
"""
# revision identifiers, used by Alembic.
revision = '556ceb3e3c3e'
down_revision = '449c3d5c7299'
revision = "556ceb3e3c3e"
down_revision = "449c3d5c7299"
from alembic import op
import sqlalchemy as sa
@ -16,84 +16,150 @@ from lemur.utils import Vault
from sqlalchemy.dialects import postgresql
from sqlalchemy_utils import ArrowType
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('pending_certs',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('external_id', sa.String(length=128), nullable=True),
sa.Column('owner', sa.String(length=128), nullable=False),
sa.Column('name', sa.String(length=256), nullable=True),
sa.Column('description', sa.String(length=1024), nullable=True),
sa.Column('notify', sa.Boolean(), nullable=True),
sa.Column('number_attempts', sa.Integer(), nullable=True),
sa.Column('rename', sa.Boolean(), nullable=True),
sa.Column('cn', sa.String(length=128), nullable=True),
sa.Column('csr', sa.Text(), nullable=False),
sa.Column('chain', sa.Text(), nullable=True),
sa.Column('private_key', Vault(), nullable=True),
sa.Column('date_created', ArrowType(), server_default=sa.text('now()'), nullable=False),
sa.Column('status', sa.String(length=128), nullable=True),
sa.Column('rotation', sa.Boolean(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('authority_id', sa.Integer(), nullable=True),
sa.Column('root_authority_id', sa.Integer(), nullable=True),
sa.Column('rotation_policy_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['authority_id'], ['authorities.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['root_authority_id'], ['authorities.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['rotation_policy_id'], ['rotation_policies.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name')
op.create_table(
"pending_certs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_id", sa.String(length=128), nullable=True),
sa.Column("owner", sa.String(length=128), nullable=False),
sa.Column("name", sa.String(length=256), nullable=True),
sa.Column("description", sa.String(length=1024), nullable=True),
sa.Column("notify", sa.Boolean(), nullable=True),
sa.Column("number_attempts", sa.Integer(), nullable=True),
sa.Column("rename", sa.Boolean(), nullable=True),
sa.Column("cn", sa.String(length=128), nullable=True),
sa.Column("csr", sa.Text(), nullable=False),
sa.Column("chain", sa.Text(), nullable=True),
sa.Column("private_key", Vault(), nullable=True),
sa.Column(
"date_created", ArrowType(), server_default=sa.text("now()"), nullable=False
),
sa.Column("status", sa.String(length=128), nullable=True),
sa.Column("rotation", sa.Boolean(), nullable=True),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("authority_id", sa.Integer(), nullable=True),
sa.Column("root_authority_id", sa.Integer(), nullable=True),
sa.Column("rotation_policy_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["authority_id"], ["authorities.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["root_authority_id"], ["authorities.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["rotation_policy_id"], ["rotation_policies.id"]),
sa.ForeignKeyConstraint(["user_id"], ["users.id"]),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
op.create_table('pending_cert_destination_associations',
sa.Column('destination_id', sa.Integer(), nullable=True),
sa.Column('pending_cert_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['destination_id'], ['destinations.id'], ondelete='cascade'),
sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade')
op.create_table(
"pending_cert_destination_associations",
sa.Column("destination_id", sa.Integer(), nullable=True),
sa.Column("pending_cert_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["destination_id"], ["destinations.id"], ondelete="cascade"
),
sa.ForeignKeyConstraint(
["pending_cert_id"], ["pending_certs.id"], ondelete="cascade"
),
)
op.create_index('pending_cert_destination_associations_ix', 'pending_cert_destination_associations', ['destination_id', 'pending_cert_id'], unique=False)
op.create_table('pending_cert_notification_associations',
sa.Column('notification_id', sa.Integer(), nullable=True),
sa.Column('pending_cert_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['notification_id'], ['notifications.id'], ondelete='cascade'),
sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade')
op.create_index(
"pending_cert_destination_associations_ix",
"pending_cert_destination_associations",
["destination_id", "pending_cert_id"],
unique=False,
)
op.create_index('pending_cert_notification_associations_ix', 'pending_cert_notification_associations', ['notification_id', 'pending_cert_id'], unique=False)
op.create_table('pending_cert_replacement_associations',
sa.Column('replaced_certificate_id', sa.Integer(), nullable=True),
sa.Column('pending_cert_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade'),
sa.ForeignKeyConstraint(['replaced_certificate_id'], ['certificates.id'], ondelete='cascade')
op.create_table(
"pending_cert_notification_associations",
sa.Column("notification_id", sa.Integer(), nullable=True),
sa.Column("pending_cert_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["notification_id"], ["notifications.id"], ondelete="cascade"
),
sa.ForeignKeyConstraint(
["pending_cert_id"], ["pending_certs.id"], ondelete="cascade"
),
)
op.create_index('pending_cert_replacement_associations_ix', 'pending_cert_replacement_associations', ['replaced_certificate_id', 'pending_cert_id'], unique=False)
op.create_table('pending_cert_role_associations',
sa.Column('pending_cert_id', sa.Integer(), nullable=True),
sa.Column('role_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ),
sa.ForeignKeyConstraint(['role_id'], ['roles.id'], )
op.create_index(
"pending_cert_notification_associations_ix",
"pending_cert_notification_associations",
["notification_id", "pending_cert_id"],
unique=False,
)
op.create_index('pending_cert_role_associations_ix', 'pending_cert_role_associations', ['pending_cert_id', 'role_id'], unique=False)
op.create_table('pending_cert_source_associations',
sa.Column('source_id', sa.Integer(), nullable=True),
sa.Column('pending_cert_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['pending_cert_id'], ['pending_certs.id'], ondelete='cascade'),
sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='cascade')
op.create_table(
"pending_cert_replacement_associations",
sa.Column("replaced_certificate_id", sa.Integer(), nullable=True),
sa.Column("pending_cert_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["pending_cert_id"], ["pending_certs.id"], ondelete="cascade"
),
sa.ForeignKeyConstraint(
["replaced_certificate_id"], ["certificates.id"], ondelete="cascade"
),
)
op.create_index(
"pending_cert_replacement_associations_ix",
"pending_cert_replacement_associations",
["replaced_certificate_id", "pending_cert_id"],
unique=False,
)
op.create_table(
"pending_cert_role_associations",
sa.Column("pending_cert_id", sa.Integer(), nullable=True),
sa.Column("role_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["pending_cert_id"], ["pending_certs.id"]),
sa.ForeignKeyConstraint(["role_id"], ["roles.id"]),
)
op.create_index(
"pending_cert_role_associations_ix",
"pending_cert_role_associations",
["pending_cert_id", "role_id"],
unique=False,
)
op.create_table(
"pending_cert_source_associations",
sa.Column("source_id", sa.Integer(), nullable=True),
sa.Column("pending_cert_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["pending_cert_id"], ["pending_certs.id"], ondelete="cascade"
),
sa.ForeignKeyConstraint(["source_id"], ["sources.id"], ondelete="cascade"),
)
op.create_index(
"pending_cert_source_associations_ix",
"pending_cert_source_associations",
["source_id", "pending_cert_id"],
unique=False,
)
op.create_index('pending_cert_source_associations_ix', 'pending_cert_source_associations', ['source_id', 'pending_cert_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('pending_cert_source_associations_ix', table_name='pending_cert_source_associations')
op.drop_table('pending_cert_source_associations')
op.drop_index('pending_cert_role_associations_ix', table_name='pending_cert_role_associations')
op.drop_table('pending_cert_role_associations')
op.drop_index('pending_cert_replacement_associations_ix', table_name='pending_cert_replacement_associations')
op.drop_table('pending_cert_replacement_associations')
op.drop_index('pending_cert_notification_associations_ix', table_name='pending_cert_notification_associations')
op.drop_table('pending_cert_notification_associations')
op.drop_index('pending_cert_destination_associations_ix', table_name='pending_cert_destination_associations')
op.drop_table('pending_cert_destination_associations')
op.drop_table('pending_certs')
op.drop_index(
"pending_cert_source_associations_ix",
table_name="pending_cert_source_associations",
)
op.drop_table("pending_cert_source_associations")
op.drop_index(
"pending_cert_role_associations_ix", table_name="pending_cert_role_associations"
)
op.drop_table("pending_cert_role_associations")
op.drop_index(
"pending_cert_replacement_associations_ix",
table_name="pending_cert_replacement_associations",
)
op.drop_table("pending_cert_replacement_associations")
op.drop_index(
"pending_cert_notification_associations_ix",
table_name="pending_cert_notification_associations",
)
op.drop_table("pending_cert_notification_associations")
op.drop_index(
"pending_cert_destination_associations_ix",
table_name="pending_cert_destination_associations",
)
op.drop_table("pending_cert_destination_associations")
op.drop_table("pending_certs")
# ### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2018-02-23 15:27:30.335435
"""
# revision identifiers, used by Alembic.
revision = '5770674184de'
down_revision = 'ce547319f7be'
revision = "5770674184de"
down_revision = "ce547319f7be"
from flask_sqlalchemy import SQLAlchemy
from lemur.models import certificate_notification_associations
@ -32,7 +32,9 @@ def upgrade():
# If we've seen a pair already, delete the duplicates
if seen.get("{}-{}".format(x.certificate_id, x.notification_id)):
print("Deleting duplicate: {}".format(x))
d = session.query(certificate_notification_associations).filter(certificate_notification_associations.c.id==x.id)
d = session.query(certificate_notification_associations).filter(
certificate_notification_associations.c.id == x.id
)
d.delete(synchronize_session=False)
seen["{}-{}".format(x.certificate_id, x.notification_id)] = True
db.session.commit()

View File

@ -7,8 +7,8 @@ Create Date: 2018-08-14 08:16:43.329316
"""
# revision identifiers, used by Alembic.
revision = '5ae0ecefb01f'
down_revision = '1db4f82bc780'
revision = "5ae0ecefb01f"
down_revision = "1db4f82bc780"
from alembic import op
import sqlalchemy as sa
@ -16,17 +16,14 @@ import sqlalchemy as sa
def upgrade():
op.alter_column(
table_name='pending_certs',
column_name='status',
nullable=True,
type_=sa.TEXT()
table_name="pending_certs", column_name="status", nullable=True, type_=sa.TEXT()
)
def downgrade():
op.alter_column(
table_name='pending_certs',
column_name='status',
table_name="pending_certs",
column_name="status",
nullable=True,
type_=sa.VARCHAR(128)
type_=sa.VARCHAR(128),
)

View File

@ -7,16 +7,18 @@ Create Date: 2017-12-08 14:19:11.903864
"""
# revision identifiers, used by Alembic.
revision = '5bc47fa7cac4'
down_revision = 'c05a8998b371'
revision = "5bc47fa7cac4"
down_revision = "c05a8998b371"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('roles', sa.Column('third_party', sa.Boolean(), nullable=True, default=False))
op.add_column(
"roles", sa.Column("third_party", sa.Boolean(), nullable=True, default=False)
)
def downgrade():
op.drop_column('roles', 'third_party')
op.drop_column("roles", "third_party")

View File

@ -7,20 +7,20 @@ Create Date: 2017-01-26 05:05:25.168125
"""
# revision identifiers, used by Alembic.
revision = '5e680529b666'
down_revision = '131ec6accff5'
revision = "5e680529b666"
down_revision = "131ec6accff5"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('endpoints', sa.Column('sensitive', sa.Boolean(), nullable=True))
op.add_column('endpoints', sa.Column('source_id', sa.Integer(), nullable=True))
op.create_foreign_key(None, 'endpoints', 'sources', ['source_id'], ['id'])
op.add_column("endpoints", sa.Column("sensitive", sa.Boolean(), nullable=True))
op.add_column("endpoints", sa.Column("source_id", sa.Integer(), nullable=True))
op.create_foreign_key(None, "endpoints", "sources", ["source_id"], ["id"])
def downgrade():
op.drop_constraint(None, 'endpoints', type_='foreignkey')
op.drop_column('endpoints', 'source_id')
op.drop_column('endpoints', 'sensitive')
op.drop_constraint(None, "endpoints", type_="foreignkey")
op.drop_column("endpoints", "source_id")
op.drop_column("endpoints", "sensitive")

View File

@ -7,15 +7,15 @@ Create Date: 2018-10-19 15:23:06.750510
"""
# revision identifiers, used by Alembic.
revision = '6006c79b6011'
down_revision = '984178255c83'
revision = "6006c79b6011"
down_revision = "984178255c83"
from alembic import op
def upgrade():
op.create_unique_constraint("uq_label", 'sources', ['label'])
op.create_unique_constraint("uq_label", "sources", ["label"])
def downgrade():
op.drop_constraint("uq_label", 'sources', type_='unique')
op.drop_constraint("uq_label", "sources", type_="unique")

View File

@ -7,15 +7,16 @@ Create Date: 2018-10-21 22:06:23.056906
"""
# revision identifiers, used by Alembic.
revision = '7ead443ba911'
down_revision = '6006c79b6011'
revision = "7ead443ba911"
down_revision = "6006c79b6011"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('certificates', sa.Column('csr', sa.TEXT(), nullable=True))
op.add_column("certificates", sa.Column("csr", sa.TEXT(), nullable=True))
def downgrade():
op.drop_column('certificates', 'csr')
op.drop_column("certificates", "csr")

View File

@ -9,8 +9,8 @@ Create Date: 2016-07-28 09:39:12.736506
"""
# revision identifiers, used by Alembic.
revision = '7f71c0cea31a'
down_revision = '29d8c8455c86'
revision = "7f71c0cea31a"
down_revision = "29d8c8455c86"
from alembic import op
import sqlalchemy as sa
@ -19,17 +19,25 @@ from sqlalchemy.sql import text
def upgrade():
conn = op.get_bind()
for name in conn.execute(text('select name from certificates group by name having count(*) > 1')):
for idx, id in enumerate(conn.execute(text("select id from certificates where certificates.name like :name order by id ASC").bindparams(name=name[0]))):
for name in conn.execute(
text("select name from certificates group by name having count(*) > 1")
):
for idx, id in enumerate(
conn.execute(
text(
"select id from certificates where certificates.name like :name order by id ASC"
).bindparams(name=name[0])
)
):
if not idx:
continue
new_name = name[0] + '-' + str(idx)
stmt = text('update certificates set name=:name where id=:id')
new_name = name[0] + "-" + str(idx)
stmt = text("update certificates set name=:name where id=:id")
stmt = stmt.bindparams(name=new_name, id=id[0])
op.execute(stmt)
op.create_unique_constraint(None, 'certificates', ['name'])
op.create_unique_constraint(None, "certificates", ["name"])
def downgrade():
op.drop_constraint(None, 'certificates', type_='unique')
op.drop_constraint(None, "certificates", type_="unique")

View File

@ -7,18 +7,28 @@ Create Date: 2017-05-10 11:56:13.999332
"""
# revision identifiers, used by Alembic.
revision = '8ae67285ff14'
down_revision = '5e680529b666'
revision = "8ae67285ff14"
down_revision = "5e680529b666"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.drop_index('certificate_replacement_associations_ix')
op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['replaced_certificate_id', 'certificate_id'], unique=True)
op.drop_index("certificate_replacement_associations_ix")
op.create_index(
"certificate_replacement_associations_ix",
"certificate_replacement_associations",
["replaced_certificate_id", "certificate_id"],
unique=True,
)
def downgrade():
op.drop_index('certificate_replacement_associations_ix')
op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['certificate_id', 'certificate_id'], unique=True)
op.drop_index("certificate_replacement_associations_ix")
op.create_index(
"certificate_replacement_associations_ix",
"certificate_replacement_associations",
["certificate_id", "certificate_id"],
unique=True,
)

View File

@ -7,15 +7,15 @@ Create Date: 2016-10-13 20:14:33.928029
"""
# revision identifiers, used by Alembic.
revision = '932525b82f1a'
down_revision = '7f71c0cea31a'
revision = "932525b82f1a"
down_revision = "7f71c0cea31a"
from alembic import op
def upgrade():
op.alter_column('certificates', 'active', new_column_name='notify')
op.alter_column("certificates", "active", new_column_name="notify")
def downgrade():
op.alter_column('certificates', 'notify', new_column_name='active')
op.alter_column("certificates", "notify", new_column_name="active")

View File

@ -6,8 +6,8 @@ Create Date: 2018-09-17 08:33:37.087488
"""
# revision identifiers, used by Alembic.
revision = '9392b9f9a805'
down_revision = '5ae0ecefb01f'
revision = "9392b9f9a805"
down_revision = "5ae0ecefb01f"
from alembic import op
from sqlalchemy_utils import ArrowType
@ -15,10 +15,17 @@ import sqlalchemy as sa
def upgrade():
op.add_column('pending_certs', sa.Column('last_updated', ArrowType, server_default=sa.text('now()'), onupdate=sa.text('now()'),
nullable=False))
op.add_column(
"pending_certs",
sa.Column(
"last_updated",
ArrowType,
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
nullable=False,
),
)
def downgrade():
op.drop_column('pending_certs', 'last_updated')
op.drop_column("pending_certs", "last_updated")

View File

@ -7,18 +7,20 @@ Create Date: 2018-10-11 20:49:12.704563
"""
# revision identifiers, used by Alembic.
revision = '984178255c83'
down_revision = 'f2383bf08fbc'
revision = "984178255c83"
down_revision = "f2383bf08fbc"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column('pending_certs', sa.Column('resolved', sa.Boolean(), nullable=True))
op.add_column('pending_certs', sa.Column('resolved_cert_id', sa.Integer(), nullable=True))
op.add_column("pending_certs", sa.Column("resolved", sa.Boolean(), nullable=True))
op.add_column(
"pending_certs", sa.Column("resolved_cert_id", sa.Integer(), nullable=True)
)
def downgrade():
op.drop_column('pending_certs', 'resolved_cert_id')
op.drop_column('pending_certs', 'resolved')
op.drop_column("pending_certs", "resolved_cert_id")
op.drop_column("pending_certs", "resolved")

View File

@ -7,16 +7,26 @@ Create Date: 2019-01-03 15:36:59.181911
"""
# revision identifiers, used by Alembic.
revision = '9f79024fe67b'
down_revision = 'ee827d1e1974'
revision = "9f79024fe67b"
down_revision = "ee827d1e1974"
from alembic import op
import sqlalchemy as sa
def upgrade():
op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'revoke_cert', 'update_cert'], ['create_cert', 'delete_cert', 'key_view', 'revoke_cert', 'update_cert'])
op.sync_enum_values(
"public",
"log_type",
["create_cert", "key_view", "revoke_cert", "update_cert"],
["create_cert", "delete_cert", "key_view", "revoke_cert", "update_cert"],
)
def downgrade():
op.sync_enum_values('public', 'log_type', ['create_cert', 'delete_cert', 'key_view', 'revoke_cert', 'update_cert'], ['create_cert', 'key_view', 'revoke_cert', 'update_cert'])
op.sync_enum_values(
"public",
"log_type",
["create_cert", "delete_cert", "key_view", "revoke_cert", "update_cert"],
["create_cert", "key_view", "revoke_cert", "update_cert"],
)

View File

@ -10,8 +10,8 @@ Create Date: 2017-07-12 11:45:49.257927
"""
# revision identifiers, used by Alembic.
revision = 'a02a678ddc25'
down_revision = '8ae67285ff14'
revision = "a02a678ddc25"
down_revision = "8ae67285ff14"
from alembic import op
import sqlalchemy as sa
@ -20,25 +20,30 @@ from sqlalchemy.sql import text
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('rotation_policies',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('days', sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"rotation_policies",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("days", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.add_column(
"certificates", sa.Column("rotation_policy_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
None, "certificates", "rotation_policies", ["rotation_policy_id"], ["id"]
)
op.add_column('certificates', sa.Column('rotation_policy_id', sa.Integer(), nullable=True))
op.create_foreign_key(None, 'certificates', 'rotation_policies', ['rotation_policy_id'], ['id'])
conn = op.get_bind()
stmt = text('insert into rotation_policies (days, name) values (:days, :name)')
stmt = stmt.bindparams(days=30, name='default')
stmt = text("insert into rotation_policies (days, name) values (:days, :name)")
stmt = stmt.bindparams(days=30, name="default")
conn.execute(stmt)
stmt = text('select id from rotation_policies where name=:name')
stmt = stmt.bindparams(name='default')
stmt = text("select id from rotation_policies where name=:name")
stmt = stmt.bindparams(name="default")
rotation_policy_id = conn.execute(stmt).fetchone()[0]
stmt = text('update certificates set rotation_policy_id=:rotation_policy_id')
stmt = text("update certificates set rotation_policy_id=:rotation_policy_id")
stmt = stmt.bindparams(rotation_policy_id=rotation_policy_id)
conn.execute(stmt)
# ### end Alembic commands ###
@ -46,9 +51,17 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'certificates', type_='foreignkey')
op.drop_column('certificates', 'rotation_policy_id')
op.drop_index('certificate_replacement_associations_ix', table_name='certificate_replacement_associations')
op.create_index('certificate_replacement_associations_ix', 'certificate_replacement_associations', ['replaced_certificate_id', 'certificate_id'], unique=True)
op.drop_table('rotation_policies')
op.drop_constraint(None, "certificates", type_="foreignkey")
op.drop_column("certificates", "rotation_policy_id")
op.drop_index(
"certificate_replacement_associations_ix",
table_name="certificate_replacement_associations",
)
op.create_index(
"certificate_replacement_associations_ix",
"certificate_replacement_associations",
["replaced_certificate_id", "certificate_id"],
unique=True,
)
op.drop_table("rotation_policies")
# ### end Alembic commands ###

View File

@ -7,8 +7,8 @@ Create Date: 2017-10-11 10:16:39.682591
"""
# revision identifiers, used by Alembic.
revision = 'ac483cfeb230'
down_revision = 'b29e2c4bf8c9'
revision = "ac483cfeb230"
down_revision = "b29e2c4bf8c9"
from alembic import op
import sqlalchemy as sa
@ -16,12 +16,18 @@ from sqlalchemy.dialects import postgresql
def upgrade():
op.alter_column('certificates', 'name',
op.alter_column(
"certificates",
"name",
existing_type=sa.VARCHAR(length=128),
type_=sa.String(length=256))
type_=sa.String(length=256),
)
def downgrade():
op.alter_column('certificates', 'name',
op.alter_column(
"certificates",
"name",
existing_type=sa.VARCHAR(length=256),
type_=sa.String(length=128))
type_=sa.String(length=128),
)

View File

@ -7,8 +7,8 @@ Create Date: 2017-09-26 10:50:35.740367
"""
# revision identifiers, used by Alembic.
revision = 'b29e2c4bf8c9'
down_revision = '1ae8e3104db8'
revision = "b29e2c4bf8c9"
down_revision = "1ae8e3104db8"
from alembic import op
import sqlalchemy as sa
@ -16,13 +16,25 @@ import sqlalchemy as sa
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('certificates', sa.Column('external_id', sa.String(128), nullable=True))
op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'update_cert'], ['create_cert', 'key_view', 'revoke_cert', 'update_cert'])
op.add_column(
"certificates", sa.Column("external_id", sa.String(128), nullable=True)
)
op.sync_enum_values(
"public",
"log_type",
["create_cert", "key_view", "update_cert"],
["create_cert", "key_view", "revoke_cert", "update_cert"],
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('public', 'log_type', ['create_cert', 'key_view', 'revoke_cert', 'update_cert'], ['create_cert', 'key_view', 'update_cert'])
op.drop_column('certificates', 'external_id')
op.sync_enum_values(
"public",
"log_type",
["create_cert", "key_view", "revoke_cert", "update_cert"],
["create_cert", "key_view", "update_cert"],
)
op.drop_column("certificates", "external_id")
# ### end Alembic commands ###

View File

@ -7,25 +7,27 @@ Create Date: 2017-11-10 14:51:28.975927
"""
# revision identifiers, used by Alembic.
revision = 'c05a8998b371'
down_revision = 'ac483cfeb230'
revision = "c05a8998b371"
down_revision = "ac483cfeb230"
from alembic import op
import sqlalchemy as sa
import sqlalchemy_utils
def upgrade():
op.create_table('api_keys',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=128), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('ttl', sa.BigInteger(), nullable=False),
sa.Column('issued_at', sa.BigInteger(), nullable=False),
sa.Column('revoked', sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
op.create_table(
"api_keys",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(length=128), nullable=True),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("ttl", sa.BigInteger(), nullable=False),
sa.Column("issued_at", sa.BigInteger(), nullable=False),
sa.Column("revoked", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.id"]),
sa.PrimaryKeyConstraint("id"),
)
def downgrade():
op.drop_table('api_keys')
op.drop_table("api_keys")

View File

@ -5,15 +5,15 @@ Create Date: 2018-10-11 09:44:57.099854
"""
revision = 'c87cb989af04'
down_revision = '9392b9f9a805'
revision = "c87cb989af04"
down_revision = "9392b9f9a805"
from alembic import op
def upgrade():
op.create_index(op.f('ix_domains_name'), 'domains', ['name'], unique=False)
op.create_index(op.f("ix_domains_name"), "domains", ["name"], unique=False)
def downgrade():
op.drop_index(op.f('ix_domains_name'), table_name='domains')
op.drop_index(op.f("ix_domains_name"), table_name="domains")

View File

@ -7,8 +7,8 @@ Create Date: 2018-02-23 11:00:02.150561
"""
# revision identifiers, used by Alembic.
revision = 'ce547319f7be'
down_revision = '5bc47fa7cac4'
revision = "ce547319f7be"
down_revision = "5bc47fa7cac4"
import sqlalchemy as sa
@ -24,12 +24,12 @@ TABLE = "certificate_notification_associations"
def upgrade():
print("Adding id column")
op.add_column(
TABLE,
sa.Column('id', sa.Integer, primary_key=True, autoincrement=True)
TABLE, sa.Column("id", sa.Integer, primary_key=True, autoincrement=True)
)
db.session.commit()
db.session.flush()
def downgrade():
op.drop_column(TABLE, "id")
db.session.commit()

View File

@ -7,29 +7,36 @@ Create Date: 2016-11-28 13:15:46.995219
"""
# revision identifiers, used by Alembic.
revision = 'e3691fc396e9'
down_revision = '932525b82f1a'
revision = "e3691fc396e9"
down_revision = "932525b82f1a"
from alembic import op
import sqlalchemy as sa
import sqlalchemy_utils
def upgrade():
### commands auto generated by Alembic - please adjust! ###
op.create_table('logs',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('certificate_id', sa.Integer(), nullable=True),
sa.Column('log_type', sa.Enum('key_view', name='log_type'), nullable=False),
sa.Column('logged_at', sqlalchemy_utils.types.arrow.ArrowType(), server_default=sa.text('now()'), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['certificate_id'], ['certificates.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
op.create_table(
"logs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("certificate_id", sa.Integer(), nullable=True),
sa.Column("log_type", sa.Enum("key_view", name="log_type"), nullable=False),
sa.Column(
"logged_at",
sqlalchemy_utils.types.arrow.ArrowType(),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["certificate_id"], ["certificates.id"]),
sa.ForeignKeyConstraint(["user_id"], ["users.id"]),
sa.PrimaryKeyConstraint("id"),
)
### end Alembic commands ###
def downgrade():
### commands auto generated by Alembic - please adjust! ###
op.drop_table('logs')
op.drop_table("logs")
### end Alembic commands ###

View File

@ -7,25 +7,44 @@ Create Date: 2018-11-05 09:49:40.226368
"""
# revision identifiers, used by Alembic.
revision = 'ee827d1e1974'
down_revision = '7ead443ba911'
revision = "ee827d1e1974"
down_revision = "7ead443ba911"
from alembic import op
from sqlalchemy.exc import ProgrammingError
def upgrade():
connection = op.get_bind()
connection.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
op.create_index('ix_certificates_cn', 'certificates', ['cn'], unique=False, postgresql_ops={'cn': 'gin_trgm_ops'},
postgresql_using='gin')
op.create_index('ix_certificates_name', 'certificates', ['name'], unique=False,
postgresql_ops={'name': 'gin_trgm_ops'}, postgresql_using='gin')
op.create_index('ix_domains_name_gin', 'domains', ['name'], unique=False, postgresql_ops={'name': 'gin_trgm_ops'},
postgresql_using='gin')
op.create_index(
"ix_certificates_cn",
"certificates",
["cn"],
unique=False,
postgresql_ops={"cn": "gin_trgm_ops"},
postgresql_using="gin",
)
op.create_index(
"ix_certificates_name",
"certificates",
["name"],
unique=False,
postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using="gin",
)
op.create_index(
"ix_domains_name_gin",
"domains",
["name"],
unique=False,
postgresql_ops={"name": "gin_trgm_ops"},
postgresql_using="gin",
)
def downgrade():
op.drop_index('ix_domains_name', table_name='domains')
op.drop_index('ix_certificates_name', table_name='certificates')
op.drop_index('ix_certificates_cn', table_name='certificates')
op.drop_index("ix_domains_name", table_name="domains")
op.drop_index("ix_certificates_name", table_name="certificates")
op.drop_index("ix_certificates_cn", table_name="certificates")

View File

@ -7,17 +7,22 @@ Create Date: 2018-10-11 11:23:31.195471
"""
revision = 'f2383bf08fbc'
down_revision = 'c87cb989af04'
revision = "f2383bf08fbc"
down_revision = "c87cb989af04"
import sqlalchemy as sa
from alembic import op
def upgrade():
op.create_index('ix_certificates_id_desc', 'certificates', [sa.text('id DESC')], unique=True,
postgresql_using='btree')
op.create_index(
"ix_certificates_id_desc",
"certificates",
[sa.text("id DESC")],
unique=True,
postgresql_using="btree",
)
def downgrade():
op.drop_index('ix_certificates_id_desc', table_name='certificates')
op.drop_index("ix_certificates_id_desc", table_name="certificates")

View File

@ -12,121 +12,201 @@ from sqlalchemy import Column, Integer, ForeignKey, Index, UniqueConstraint
from lemur.database import db
certificate_associations = db.Table('certificate_associations',
Column('domain_id', Integer, ForeignKey('domains.id')),
Column('certificate_id', Integer, ForeignKey('certificates.id'))
)
certificate_associations = db.Table(
"certificate_associations",
Column("domain_id", Integer, ForeignKey("domains.id")),
Column("certificate_id", Integer, ForeignKey("certificates.id")),
)
Index('certificate_associations_ix', certificate_associations.c.domain_id, certificate_associations.c.certificate_id)
Index(
"certificate_associations_ix",
certificate_associations.c.domain_id,
certificate_associations.c.certificate_id,
)
certificate_destination_associations = db.Table('certificate_destination_associations',
Column('destination_id', Integer,
ForeignKey('destinations.id', ondelete='cascade')),
Column('certificate_id', Integer,
ForeignKey('certificates.id', ondelete='cascade'))
)
certificate_destination_associations = db.Table(
"certificate_destination_associations",
Column(
"destination_id", Integer, ForeignKey("destinations.id", ondelete="cascade")
),
Column(
"certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade")
),
)
Index('certificate_destination_associations_ix', certificate_destination_associations.c.destination_id, certificate_destination_associations.c.certificate_id)
Index(
"certificate_destination_associations_ix",
certificate_destination_associations.c.destination_id,
certificate_destination_associations.c.certificate_id,
)
certificate_source_associations = db.Table('certificate_source_associations',
Column('source_id', Integer,
ForeignKey('sources.id', ondelete='cascade')),
Column('certificate_id', Integer,
ForeignKey('certificates.id', ondelete='cascade'))
)
certificate_source_associations = db.Table(
"certificate_source_associations",
Column("source_id", Integer, ForeignKey("sources.id", ondelete="cascade")),
Column(
"certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade")
),
)
Index('certificate_source_associations_ix', certificate_source_associations.c.source_id, certificate_source_associations.c.certificate_id)
Index(
"certificate_source_associations_ix",
certificate_source_associations.c.source_id,
certificate_source_associations.c.certificate_id,
)
certificate_notification_associations = db.Table('certificate_notification_associations',
Column('notification_id', Integer,
ForeignKey('notifications.id', ondelete='cascade')),
Column('certificate_id', Integer,
ForeignKey('certificates.id', ondelete='cascade')),
Column('id', Integer, primary_key=True, autoincrement=True),
UniqueConstraint('notification_id', 'certificate_id', name='uq_dest_not_ids')
)
certificate_notification_associations = db.Table(
"certificate_notification_associations",
Column(
"notification_id", Integer, ForeignKey("notifications.id", ondelete="cascade")
),
Column(
"certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade")
),
Column("id", Integer, primary_key=True, autoincrement=True),
UniqueConstraint("notification_id", "certificate_id", name="uq_dest_not_ids"),
)
Index('certificate_notification_associations_ix', certificate_notification_associations.c.notification_id, certificate_notification_associations.c.certificate_id)
Index(
"certificate_notification_associations_ix",
certificate_notification_associations.c.notification_id,
certificate_notification_associations.c.certificate_id,
)
certificate_replacement_associations = db.Table('certificate_replacement_associations',
Column('replaced_certificate_id', Integer,
ForeignKey('certificates.id', ondelete='cascade')),
Column('certificate_id', Integer,
ForeignKey('certificates.id', ondelete='cascade'))
)
certificate_replacement_associations = db.Table(
"certificate_replacement_associations",
Column(
"replaced_certificate_id",
Integer,
ForeignKey("certificates.id", ondelete="cascade"),
),
Column(
"certificate_id", Integer, ForeignKey("certificates.id", ondelete="cascade")
),
)
Index('certificate_replacement_associations_ix', certificate_replacement_associations.c.replaced_certificate_id, certificate_replacement_associations.c.certificate_id, unique=True)
Index(
"certificate_replacement_associations_ix",
certificate_replacement_associations.c.replaced_certificate_id,
certificate_replacement_associations.c.certificate_id,
unique=True,
)
roles_authorities = db.Table('roles_authorities',
Column('authority_id', Integer, ForeignKey('authorities.id')),
Column('role_id', Integer, ForeignKey('roles.id'))
)
roles_authorities = db.Table(
"roles_authorities",
Column("authority_id", Integer, ForeignKey("authorities.id")),
Column("role_id", Integer, ForeignKey("roles.id")),
)
Index('roles_authorities_ix', roles_authorities.c.authority_id, roles_authorities.c.role_id)
Index(
"roles_authorities_ix",
roles_authorities.c.authority_id,
roles_authorities.c.role_id,
)
roles_certificates = db.Table('roles_certificates',
Column('certificate_id', Integer, ForeignKey('certificates.id')),
Column('role_id', Integer, ForeignKey('roles.id'))
)
roles_certificates = db.Table(
"roles_certificates",
Column("certificate_id", Integer, ForeignKey("certificates.id")),
Column("role_id", Integer, ForeignKey("roles.id")),
)
Index('roles_certificates_ix', roles_certificates.c.certificate_id, roles_certificates.c.role_id)
Index(
"roles_certificates_ix",
roles_certificates.c.certificate_id,
roles_certificates.c.role_id,
)
roles_users = db.Table('roles_users',
Column('user_id', Integer, ForeignKey('users.id')),
Column('role_id', Integer, ForeignKey('roles.id'))
)
roles_users = db.Table(
"roles_users",
Column("user_id", Integer, ForeignKey("users.id")),
Column("role_id", Integer, ForeignKey("roles.id")),
)
Index('roles_users_ix', roles_users.c.user_id, roles_users.c.role_id)
Index("roles_users_ix", roles_users.c.user_id, roles_users.c.role_id)
policies_ciphers = db.Table('policies_ciphers',
Column('cipher_id', Integer, ForeignKey('ciphers.id')),
Column('policy_id', Integer, ForeignKey('policy.id')))
policies_ciphers = db.Table(
"policies_ciphers",
Column("cipher_id", Integer, ForeignKey("ciphers.id")),
Column("policy_id", Integer, ForeignKey("policy.id")),
)
Index('policies_ciphers_ix', policies_ciphers.c.cipher_id, policies_ciphers.c.policy_id)
Index("policies_ciphers_ix", policies_ciphers.c.cipher_id, policies_ciphers.c.policy_id)
pending_cert_destination_associations = db.Table('pending_cert_destination_associations',
Column('destination_id', Integer,
ForeignKey('destinations.id', ondelete='cascade')),
Column('pending_cert_id', Integer,
ForeignKey('pending_certs.id', ondelete='cascade'))
)
pending_cert_destination_associations = db.Table(
"pending_cert_destination_associations",
Column(
"destination_id", Integer, ForeignKey("destinations.id", ondelete="cascade")
),
Column(
"pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade")
),
)
Index('pending_cert_destination_associations_ix', pending_cert_destination_associations.c.destination_id, pending_cert_destination_associations.c.pending_cert_id)
Index(
"pending_cert_destination_associations_ix",
pending_cert_destination_associations.c.destination_id,
pending_cert_destination_associations.c.pending_cert_id,
)
pending_cert_notification_associations = db.Table('pending_cert_notification_associations',
Column('notification_id', Integer,
ForeignKey('notifications.id', ondelete='cascade')),
Column('pending_cert_id', Integer,
ForeignKey('pending_certs.id', ondelete='cascade'))
)
pending_cert_notification_associations = db.Table(
"pending_cert_notification_associations",
Column(
"notification_id", Integer, ForeignKey("notifications.id", ondelete="cascade")
),
Column(
"pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade")
),
)
Index('pending_cert_notification_associations_ix', pending_cert_notification_associations.c.notification_id, pending_cert_notification_associations.c.pending_cert_id)
Index(
"pending_cert_notification_associations_ix",
pending_cert_notification_associations.c.notification_id,
pending_cert_notification_associations.c.pending_cert_id,
)
pending_cert_source_associations = db.Table('pending_cert_source_associations',
Column('source_id', Integer,
ForeignKey('sources.id', ondelete='cascade')),
Column('pending_cert_id', Integer,
ForeignKey('pending_certs.id', ondelete='cascade'))
)
pending_cert_source_associations = db.Table(
"pending_cert_source_associations",
Column("source_id", Integer, ForeignKey("sources.id", ondelete="cascade")),
Column(
"pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade")
),
)
Index('pending_cert_source_associations_ix', pending_cert_source_associations.c.source_id, pending_cert_source_associations.c.pending_cert_id)
Index(
"pending_cert_source_associations_ix",
pending_cert_source_associations.c.source_id,
pending_cert_source_associations.c.pending_cert_id,
)
pending_cert_replacement_associations = db.Table('pending_cert_replacement_associations',
Column('replaced_certificate_id', Integer,
ForeignKey('certificates.id', ondelete='cascade')),
Column('pending_cert_id', Integer,
ForeignKey('pending_certs.id', ondelete='cascade'))
)
pending_cert_replacement_associations = db.Table(
"pending_cert_replacement_associations",
Column(
"replaced_certificate_id",
Integer,
ForeignKey("certificates.id", ondelete="cascade"),
),
Column(
"pending_cert_id", Integer, ForeignKey("pending_certs.id", ondelete="cascade")
),
)
Index('pending_cert_replacement_associations_ix', pending_cert_replacement_associations.c.replaced_certificate_id, pending_cert_replacement_associations.c.pending_cert_id)
Index(
"pending_cert_replacement_associations_ix",
pending_cert_replacement_associations.c.replaced_certificate_id,
pending_cert_replacement_associations.c.pending_cert_id,
)
pending_cert_role_associations = db.Table('pending_cert_role_associations',
Column('pending_cert_id', Integer, ForeignKey('pending_certs.id')),
Column('role_id', Integer, ForeignKey('roles.id'))
)
pending_cert_role_associations = db.Table(
"pending_cert_role_associations",
Column("pending_cert_id", Integer, ForeignKey("pending_certs.id")),
Column("role_id", Integer, ForeignKey("roles.id")),
)
Index('pending_cert_role_associations_ix', pending_cert_role_associations.c.pending_cert_id, pending_cert_role_associations.c.role_id)
Index(
"pending_cert_role_associations_ix",
pending_cert_role_associations.c.pending_cert_id,
pending_cert_role_associations.c.role_id,
)

View File

@ -14,7 +14,14 @@ from lemur.notifications.messaging import send_expiration_notifications
manager = Manager(usage="Handles notification related tasks.")
@manager.option('-e', '--exclude', dest='exclude', action='append', default=[], help='Common name matching of certificates that should be excluded from notification')
@manager.option(
"-e",
"--exclude",
dest="exclude",
action="append",
default=[],
help="Common name matching of certificates that should be excluded from notification",
)
def expirations(exclude):
"""
Runs Lemur's notification engine, that looks for expired certificates and sends
@ -33,12 +40,13 @@ def expirations(exclude):
success, failed = send_expiration_notifications(exclude)
print(
"Finished notifying subscribers about expiring certificates! Sent: {success} Failed: {failed}".format(
success=success,
failed=failed
success=success, failed=failed
)
)
status = SUCCESS_METRIC_STATUS
except Exception as e:
sentry.captureException()
metrics.send('expiration_notification_job', 'counter', 1, metric_tags={'status': status})
metrics.send(
"expiration_notification_job", "counter", 1, metric_tags={"status": status}
)

View File

@ -36,15 +36,17 @@ def get_certificates(exclude=None):
now = arrow.utcnow()
max = now + timedelta(days=90)
q = database.db.session.query(Certificate) \
.filter(Certificate.not_after <= max) \
.filter(Certificate.notify == True) \
.filter(Certificate.expired == False) # noqa
q = (
database.db.session.query(Certificate)
.filter(Certificate.not_after <= max)
.filter(Certificate.notify == True)
.filter(Certificate.expired == False)
) # noqa
exclude_conditions = []
if exclude:
for e in exclude:
exclude_conditions.append(~Certificate.name.ilike('%{}%'.format(e)))
exclude_conditions.append(~Certificate.name.ilike("%{}%".format(e)))
q = q.filter(and_(*exclude_conditions))
@ -101,7 +103,12 @@ def send_notification(event_type, data, targets, notification):
except Exception as e:
sentry.captureException()
metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': event_type})
metrics.send(
"notification",
"counter",
1,
metric_tags={"status": status, "event_type": event_type},
)
if status == SUCCESS_METRIC_STATUS:
return True
@ -115,7 +122,7 @@ def send_expiration_notifications(exclude):
success = failure = 0
# security team gets all
security_email = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL')
security_email = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL")
security_data = []
for owner, notification_group in get_eligible_certificates(exclude=exclude).items():
@ -127,26 +134,43 @@ def send_expiration_notifications(exclude):
for data in certificates:
n, certificate = data
cert_data = certificate_notification_output_schema.dump(certificate).data
cert_data = certificate_notification_output_schema.dump(
certificate
).data
notification_data.append(cert_data)
security_data.append(cert_data)
notification_recipient = get_plugin_option('recipients', notification.options)
notification_recipient = get_plugin_option(
"recipients", notification.options
)
if notification_recipient:
notification_recipient = notification_recipient.split(",")
if send_notification('expiration', notification_data, [owner], notification):
if send_notification(
"expiration", notification_data, [owner], notification
):
success += 1
else:
failure += 1
if notification_recipient and owner != notification_recipient and security_email != notification_recipient:
if send_notification('expiration', notification_data, notification_recipient, notification):
if (
notification_recipient
and owner != notification_recipient
and security_email != notification_recipient
):
if send_notification(
"expiration",
notification_data,
notification_recipient,
notification,
):
success += 1
else:
failure += 1
if send_notification('expiration', security_data, security_email, notification):
if send_notification(
"expiration", security_data, security_email, notification
):
success += 1
else:
failure += 1
@ -165,24 +189,35 @@ def send_rotation_notification(certificate, notification_plugin=None):
"""
status = FAILURE_METRIC_STATUS
if not notification_plugin:
notification_plugin = plugins.get(current_app.config.get('LEMUR_DEFAULT_NOTIFICATION_PLUGIN'))
notification_plugin = plugins.get(
current_app.config.get("LEMUR_DEFAULT_NOTIFICATION_PLUGIN")
)
data = certificate_notification_output_schema.dump(certificate).data
try:
notification_plugin.send('rotation', data, [data['owner']])
notification_plugin.send("rotation", data, [data["owner"]])
status = SUCCESS_METRIC_STATUS
except Exception as e:
current_app.logger.error('Unable to send notification to {}.'.format(data['owner']), exc_info=True)
current_app.logger.error(
"Unable to send notification to {}.".format(data["owner"]), exc_info=True
)
sentry.captureException()
metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': 'rotation'})
metrics.send(
"notification",
"counter",
1,
metric_tags={"status": status, "event_type": "rotation"},
)
if status == SUCCESS_METRIC_STATUS:
return True
def send_pending_failure_notification(pending_cert, notify_owner=True, notify_security=True, notification_plugin=None):
def send_pending_failure_notification(
pending_cert, notify_owner=True, notify_security=True, notification_plugin=None
):
"""
Sends a report to certificate owners when their pending certificate failed to be created.
@ -194,32 +229,47 @@ def send_pending_failure_notification(pending_cert, notify_owner=True, notify_se
if not notification_plugin:
notification_plugin = plugins.get(
current_app.config.get('LEMUR_DEFAULT_NOTIFICATION_PLUGIN', 'email-notification')
current_app.config.get(
"LEMUR_DEFAULT_NOTIFICATION_PLUGIN", "email-notification"
)
)
data = pending_certificate_output_schema.dump(pending_cert).data
data["security_email"] = current_app.config.get('LEMUR_SECURITY_TEAM_EMAIL')
data["security_email"] = current_app.config.get("LEMUR_SECURITY_TEAM_EMAIL")
if notify_owner:
try:
notification_plugin.send('failed', data, [data['owner']], pending_cert)
notification_plugin.send("failed", data, [data["owner"]], pending_cert)
status = SUCCESS_METRIC_STATUS
except Exception as e:
current_app.logger.error('Unable to send pending failure notification to {}.'.format(data['owner']),
exc_info=True)
current_app.logger.error(
"Unable to send pending failure notification to {}.".format(
data["owner"]
),
exc_info=True,
)
sentry.captureException()
if notify_security:
try:
notification_plugin.send('failed', data, data["security_email"], pending_cert)
notification_plugin.send(
"failed", data, data["security_email"], pending_cert
)
status = SUCCESS_METRIC_STATUS
except Exception as e:
current_app.logger.error('Unable to send pending failure notification to '
'{}.'.format(data['security_email']),
exc_info=True)
current_app.logger.error(
"Unable to send pending failure notification to "
"{}.".format(data["security_email"]),
exc_info=True,
)
sentry.captureException()
metrics.send('notification', 'counter', 1, metric_tags={'status': status, 'event_type': 'rotation'})
metrics.send(
"notification",
"counter",
1,
metric_tags={"status": status, "event_type": "rotation"},
)
if status == SUCCESS_METRIC_STATUS:
return True
@ -242,20 +292,22 @@ def needs_notification(certificate):
if not notification.active or not notification.options:
return
interval = get_plugin_option('interval', notification.options)
unit = get_plugin_option('unit', notification.options)
interval = get_plugin_option("interval", notification.options)
unit = get_plugin_option("unit", notification.options)
if unit == 'weeks':
if unit == "weeks":
interval *= 7
elif unit == 'months':
elif unit == "months":
interval *= 30
elif unit == 'days': # it's nice to be explicit about the base unit
elif unit == "days": # it's nice to be explicit about the base unit
pass
else:
raise Exception("Invalid base unit for expiration interval: {0}".format(unit))
raise Exception(
"Invalid base unit for expiration interval: {0}".format(unit)
)
if days == interval:
notifications.append(notification)

View File

@ -11,12 +11,14 @@ from sqlalchemy_utils import JSONType
from lemur.database import db
from lemur.plugins.base import plugins
from lemur.models import certificate_notification_associations, \
pending_cert_notification_associations
from lemur.models import (
certificate_notification_associations,
pending_cert_notification_associations,
)
class Notification(db.Model):
__tablename__ = 'notifications'
__tablename__ = "notifications"
id = Column(Integer, primary_key=True)
label = Column(String(128), unique=True)
description = Column(Text())
@ -28,14 +30,14 @@ class Notification(db.Model):
secondary=certificate_notification_associations,
passive_deletes=True,
backref="notification",
cascade='all,delete'
cascade="all,delete",
)
pending_certificates = relationship(
"PendingCertificate",
secondary=pending_cert_notification_associations,
passive_deletes=True,
backref="notification",
cascade='all,delete'
cascade="all,delete",
)
@property

View File

@ -7,7 +7,11 @@
"""
from marshmallow import fields, post_dump
from lemur.common.schema import LemurInputSchema, LemurOutputSchema
from lemur.schemas import PluginInputSchema, PluginOutputSchema, AssociatedCertificateSchema
from lemur.schemas import (
PluginInputSchema,
PluginOutputSchema,
AssociatedCertificateSchema,
)
class NotificationInputSchema(LemurInputSchema):
@ -30,7 +34,7 @@ class NotificationOutputSchema(LemurOutputSchema):
@post_dump
def fill_object(self, data):
if data:
data['plugin']['pluginOptions'] = data['options']
data["plugin"]["pluginOptions"] = data["options"]
return data

Some files were not shown because too many files have changed in this diff Show More