Black lint all the things

This commit is contained in:
Curtis Castrapel
2019-05-16 07:57:02 -07:00
parent 3680d523d4
commit 68fd1556b2
226 changed files with 9340 additions and 5940 deletions

View File

@ -32,8 +32,11 @@ else:
def make_celery(app):
celery = Celery(app.import_name, backend=app.config.get('CELERY_RESULT_BACKEND'),
broker=app.config.get('CELERY_BROKER_URL'))
celery = Celery(
app.import_name,
backend=app.config.get("CELERY_RESULT_BACKEND"),
broker=app.config.get("CELERY_BROKER_URL"),
)
celery.conf.update(app.config)
TaskBase = celery.Task
@ -53,6 +56,7 @@ celery = make_celery(flask_app)
def is_task_active(fun, task_id, args):
from celery.task.control import inspect
i = inspect()
active_tasks = i.active()
for _, tasks in active_tasks.items():
@ -99,7 +103,7 @@ def fetch_acme_cert(id):
# We only care about certs using the acme-issuer plugin
for cert in pending_certs:
cert_authority = get_authority(cert.authority_id)
if cert_authority.plugin_name == 'acme-issuer':
if cert_authority.plugin_name == "acme-issuer":
acme_certs.append(cert)
else:
wrong_issuer += 1
@ -112,20 +116,22 @@ def fetch_acme_cert(id):
# It's necessary to reload the pending cert due to detached instance: http://sqlalche.me/e/bhk3
pending_cert = pending_certificate_service.get(cert.get("pending_cert").id)
if not pending_cert:
log_data["message"] = "Pending certificate doesn't exist anymore. Was it resolved by another process?"
log_data[
"message"
] = "Pending certificate doesn't exist anymore. Was it resolved by another process?"
current_app.logger.error(log_data)
continue
if real_cert:
# If a real certificate was returned from issuer, then create it in Lemur and mark
# the pending certificate as resolved
final_cert = pending_certificate_service.create_certificate(pending_cert, real_cert, pending_cert.user)
pending_certificate_service.update(
cert.get("pending_cert").id,
resolved_cert_id=final_cert.id
final_cert = pending_certificate_service.create_certificate(
pending_cert, real_cert, pending_cert.user
)
pending_certificate_service.update(
cert.get("pending_cert").id,
resolved=True
cert.get("pending_cert").id, resolved_cert_id=final_cert.id
)
pending_certificate_service.update(
cert.get("pending_cert").id, resolved=True
)
# add metrics to metrics extension
new += 1
@ -139,17 +145,17 @@ def fetch_acme_cert(id):
if pending_cert.number_attempts > 4:
error_log["message"] = "Deleting pending certificate"
send_pending_failure_notification(pending_cert, notify_owner=pending_cert.notify)
send_pending_failure_notification(
pending_cert, notify_owner=pending_cert.notify
)
# Mark the pending cert as resolved
pending_certificate_service.update(
cert.get("pending_cert").id,
resolved=True
cert.get("pending_cert").id, resolved=True
)
else:
pending_certificate_service.increment_attempt(pending_cert)
pending_certificate_service.update(
cert.get("pending_cert").id,
status=str(cert.get("last_error"))
cert.get("pending_cert").id, status=str(cert.get("last_error"))
)
# Add failed pending cert task back to queue
fetch_acme_cert.delay(id)
@ -161,9 +167,7 @@ def fetch_acme_cert(id):
current_app.logger.debug(log_data)
print(
"[+] Certificates: New: {new} Failed: {failed} Not using ACME: {wrong_issuer}".format(
new=new,
failed=failed,
wrong_issuer=wrong_issuer
new=new, failed=failed, wrong_issuer=wrong_issuer
)
)
@ -175,7 +179,7 @@ def fetch_all_pending_acme_certs():
log_data = {
"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name),
"message": "Starting job."
"message": "Starting job.",
}
current_app.logger.debug(log_data)
@ -183,7 +187,7 @@ def fetch_all_pending_acme_certs():
# We only care about certs using the acme-issuer plugin
for cert in pending_certs:
cert_authority = get_authority(cert.authority_id)
if cert_authority.plugin_name == 'acme-issuer':
if cert_authority.plugin_name == "acme-issuer":
if datetime.now(timezone.utc) - cert.last_updated > timedelta(minutes=5):
log_data["message"] = "Triggering job for cert {}".format(cert.name)
log_data["cert_name"] = cert.name
@ -195,17 +199,15 @@ def fetch_all_pending_acme_certs():
@celery.task()
def remove_old_acme_certs():
"""Prune old pending acme certificates from the database"""
log_data = {
"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name)
}
pending_certs = pending_certificate_service.get_pending_certs('all')
log_data = {"function": "{}.{}".format(__name__, sys._getframe().f_code.co_name)}
pending_certs = pending_certificate_service.get_pending_certs("all")
# Delete pending certs more than a week old
for cert in pending_certs:
if datetime.now(timezone.utc) - cert.last_updated > timedelta(days=7):
log_data['pending_cert_id'] = cert.id
log_data['pending_cert_name'] = cert.name
log_data['message'] = "Deleting pending certificate"
log_data["pending_cert_id"] = cert.id
log_data["pending_cert_name"] = cert.name
log_data["message"] = "Deleting pending certificate"
current_app.logger.debug(log_data)
pending_certificate_service.delete(cert)
@ -218,7 +220,9 @@ def clean_all_sources():
"""
sources = validate_sources("all")
for source in sources:
current_app.logger.debug("Creating celery task to clean source {}".format(source.label))
current_app.logger.debug(
"Creating celery task to clean source {}".format(source.label)
)
clean_source.delay(source.label)
@ -242,7 +246,9 @@ def sync_all_sources():
"""
sources = validate_sources("all")
for source in sources:
current_app.logger.debug("Creating celery task to sync source {}".format(source.label))
current_app.logger.debug(
"Creating celery task to sync source {}".format(source.label)
)
sync_source.delay(source.label)
@ -277,7 +283,9 @@ def sync_source(source):
log_data["message"] = "Error syncing source: Time limit exceeded."
current_app.logger.error(log_data)
sentry.captureException()
metrics.send('sync_source_timeout', 'counter', 1, metric_tags={'source': source})
metrics.send(
"sync_source_timeout", "counter", 1, metric_tags={"source": source}
)
return
log_data["message"] = "Done syncing source"

View File

@ -9,18 +9,20 @@ from lemur.extensions import sentry
from lemur.constants import SAN_NAMING_TEMPLATE, DEFAULT_NAMING_TEMPLATE
def text_to_slug(value, joiner='-'):
def text_to_slug(value, joiner="-"):
"""
Normalize a string to a "slug" value, stripping character accents and removing non-alphanum characters.
A series of non-alphanumeric characters is replaced with the joiner character.
"""
# Strip all character accents: decompose Unicode characters and then drop combining chars.
value = ''.join(c for c in unicodedata.normalize('NFKD', value) if not unicodedata.combining(c))
value = "".join(
c for c in unicodedata.normalize("NFKD", value) if not unicodedata.combining(c)
)
# Replace all remaining non-alphanumeric characters with joiner string. Multiple characters get collapsed into a
# single joiner. Except, keep 'xn--' used in IDNA domain names as is.
value = re.sub(r'[^A-Za-z0-9.]+(?<!xn--)', joiner, value)
value = re.sub(r"[^A-Za-z0-9.]+(?<!xn--)", joiner, value)
# '-' in the beginning or end of string looks ugly.
return value.strip(joiner)
@ -48,12 +50,12 @@ def certificate_name(common_name, issuer, not_before, not_after, san):
temp = t.format(
subject=common_name,
issuer=issuer.replace(' ', ''),
not_before=not_before.strftime('%Y%m%d'),
not_after=not_after.strftime('%Y%m%d')
issuer=issuer.replace(" ", ""),
not_before=not_before.strftime("%Y%m%d"),
not_after=not_after.strftime("%Y%m%d"),
)
temp = temp.replace('*', "WILDCARD")
temp = temp.replace("*", "WILDCARD")
return text_to_slug(temp)
@ -69,9 +71,9 @@ def common_name(cert):
:return: Common name or None
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_COMMON_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_COMMON_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get common name! {0}".format(e))
@ -84,9 +86,9 @@ def organization(cert):
:return:
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_ORGANIZATION_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get organization! {0}".format(e))
@ -99,9 +101,9 @@ def organizational_unit(cert):
:return:
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_ORGANIZATIONAL_UNIT_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_ORGANIZATIONAL_UNIT_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get organizational unit! {0}".format(e))
@ -114,9 +116,9 @@ def country(cert):
:return:
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_COUNTRY_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_COUNTRY_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get country! {0}".format(e))
@ -129,9 +131,9 @@ def state(cert):
:return:
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_STATE_OR_PROVINCE_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_STATE_OR_PROVINCE_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get state! {0}".format(e))
@ -144,9 +146,9 @@ def location(cert):
:return:
"""
try:
return cert.subject.get_attributes_for_oid(
x509.OID_LOCALITY_NAME
)[0].value.strip()
return cert.subject.get_attributes_for_oid(x509.OID_LOCALITY_NAME)[
0
].value.strip()
except Exception as e:
sentry.captureException()
current_app.logger.error("Unable to get location! {0}".format(e))
@ -224,7 +226,7 @@ def bitstrength(cert):
return cert.public_key().key_size
except AttributeError:
sentry.captureException()
current_app.logger.debug('Unable to get bitstrength.')
current_app.logger.debug("Unable to get bitstrength.")
def issuer(cert):
@ -239,16 +241,19 @@ def issuer(cert):
"""
# If certificate is self-signed, we return a special value -- there really is no distinct "issuer" for it
if is_selfsigned(cert):
return '<selfsigned>'
return "<selfsigned>"
# Try Common Name or fall back to Organization name
attrs = (cert.issuer.get_attributes_for_oid(x509.OID_COMMON_NAME) or
cert.issuer.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME))
attrs = cert.issuer.get_attributes_for_oid(
x509.OID_COMMON_NAME
) or cert.issuer.get_attributes_for_oid(x509.OID_ORGANIZATION_NAME)
if not attrs:
current_app.logger.error("Unable to get issuer! Cert serial {:x}".format(cert.serial_number))
return '<unknown>'
current_app.logger.error(
"Unable to get issuer! Cert serial {:x}".format(cert.serial_number)
)
return "<unknown>"
return text_to_slug(attrs[0].value, '')
return text_to_slug(attrs[0].value, "")
def not_before(cert):

View File

@ -25,6 +25,7 @@ class Hex(Field):
"""
A hex formatted string.
"""
def _serialize(self, value, attr, obj):
if value:
value = hex(int(value))[2:].upper()
@ -48,25 +49,25 @@ class ArrowDateTime(Field):
"""
DATEFORMAT_SERIALIZATION_FUNCS = {
'iso': utils.isoformat,
'iso8601': utils.isoformat,
'rfc': utils.rfcformat,
'rfc822': utils.rfcformat,
"iso": utils.isoformat,
"iso8601": utils.isoformat,
"rfc": utils.rfcformat,
"rfc822": utils.rfcformat,
}
DATEFORMAT_DESERIALIZATION_FUNCS = {
'iso': utils.from_iso,
'iso8601': utils.from_iso,
'rfc': utils.from_rfc,
'rfc822': utils.from_rfc,
"iso": utils.from_iso,
"iso8601": utils.from_iso,
"rfc": utils.from_rfc,
"rfc822": utils.from_rfc,
}
DEFAULT_FORMAT = 'iso'
DEFAULT_FORMAT = "iso"
localtime = False
default_error_messages = {
'invalid': 'Not a valid datetime.',
'format': '"{input}" cannot be formatted as a datetime.',
"invalid": "Not a valid datetime.",
"format": '"{input}" cannot be formatted as a datetime.',
}
def __init__(self, format=None, **kwargs):
@ -89,34 +90,36 @@ class ArrowDateTime(Field):
try:
return format_func(value, localtime=self.localtime)
except (AttributeError, ValueError) as err:
self.fail('format', input=value)
self.fail("format", input=value)
else:
return value.strftime(self.dateformat)
def _deserialize(self, value, attr, data):
if not value: # Falsy values, e.g. '', None, [] are not valid
raise self.fail('invalid')
raise self.fail("invalid")
self.dateformat = self.dateformat or self.DEFAULT_FORMAT
func = self.DATEFORMAT_DESERIALIZATION_FUNCS.get(self.dateformat)
if func:
try:
return arrow.get(func(value))
except (TypeError, AttributeError, ValueError):
raise self.fail('invalid')
raise self.fail("invalid")
elif self.dateformat:
try:
return dt.datetime.strptime(value, self.dateformat)
except (TypeError, AttributeError, ValueError):
raise self.fail('invalid')
raise self.fail("invalid")
elif utils.dateutil_available:
try:
return arrow.get(utils.from_datestring(value))
except TypeError:
raise self.fail('invalid')
raise self.fail("invalid")
else:
warnings.warn('It is recommended that you install python-dateutil '
'for improved datetime deserialization.')
raise self.fail('invalid')
warnings.warn(
"It is recommended that you install python-dateutil "
"for improved datetime deserialization."
)
raise self.fail("invalid")
class KeyUsageExtension(Field):
@ -131,73 +134,75 @@ class KeyUsageExtension(Field):
def _serialize(self, value, attr, obj):
return {
'useDigitalSignature': value.digital_signature,
'useNonRepudiation': value.content_commitment,
'useKeyEncipherment': value.key_encipherment,
'useDataEncipherment': value.data_encipherment,
'useKeyAgreement': value.key_agreement,
'useKeyCertSign': value.key_cert_sign,
'useCRLSign': value.crl_sign,
'useEncipherOnly': value._encipher_only,
'useDecipherOnly': value._decipher_only
"useDigitalSignature": value.digital_signature,
"useNonRepudiation": value.content_commitment,
"useKeyEncipherment": value.key_encipherment,
"useDataEncipherment": value.data_encipherment,
"useKeyAgreement": value.key_agreement,
"useKeyCertSign": value.key_cert_sign,
"useCRLSign": value.crl_sign,
"useEncipherOnly": value._encipher_only,
"useDecipherOnly": value._decipher_only,
}
def _deserialize(self, value, attr, data):
keyusages = {
'digital_signature': False,
'content_commitment': False,
'key_encipherment': False,
'data_encipherment': False,
'key_agreement': False,
'key_cert_sign': False,
'crl_sign': False,
'encipher_only': False,
'decipher_only': False
"digital_signature": False,
"content_commitment": False,
"key_encipherment": False,
"data_encipherment": False,
"key_agreement": False,
"key_cert_sign": False,
"crl_sign": False,
"encipher_only": False,
"decipher_only": False,
}
for k, v in value.items():
if k == 'useDigitalSignature':
keyusages['digital_signature'] = v
if k == "useDigitalSignature":
keyusages["digital_signature"] = v
elif k == 'useNonRepudiation':
keyusages['content_commitment'] = v
elif k == "useNonRepudiation":
keyusages["content_commitment"] = v
elif k == 'useKeyEncipherment':
keyusages['key_encipherment'] = v
elif k == "useKeyEncipherment":
keyusages["key_encipherment"] = v
elif k == 'useDataEncipherment':
keyusages['data_encipherment'] = v
elif k == "useDataEncipherment":
keyusages["data_encipherment"] = v
elif k == 'useKeyCertSign':
keyusages['key_cert_sign'] = v
elif k == "useKeyCertSign":
keyusages["key_cert_sign"] = v
elif k == 'useCRLSign':
keyusages['crl_sign'] = v
elif k == "useCRLSign":
keyusages["crl_sign"] = v
elif k == 'useKeyAgreement':
keyusages['key_agreement'] = v
elif k == "useKeyAgreement":
keyusages["key_agreement"] = v
elif k == 'useEncipherOnly' and v:
keyusages['encipher_only'] = True
keyusages['key_agreement'] = True
elif k == "useEncipherOnly" and v:
keyusages["encipher_only"] = True
keyusages["key_agreement"] = True
elif k == 'useDecipherOnly' and v:
keyusages['decipher_only'] = True
keyusages['key_agreement'] = True
elif k == "useDecipherOnly" and v:
keyusages["decipher_only"] = True
keyusages["key_agreement"] = True
if keyusages['encipher_only'] and keyusages['decipher_only']:
raise ValidationError('A certificate cannot have both Encipher Only and Decipher Only Extended Key Usages.')
if keyusages["encipher_only"] and keyusages["decipher_only"]:
raise ValidationError(
"A certificate cannot have both Encipher Only and Decipher Only Extended Key Usages."
)
return x509.KeyUsage(
digital_signature=keyusages['digital_signature'],
content_commitment=keyusages['content_commitment'],
key_encipherment=keyusages['key_encipherment'],
data_encipherment=keyusages['data_encipherment'],
key_agreement=keyusages['key_agreement'],
key_cert_sign=keyusages['key_cert_sign'],
crl_sign=keyusages['crl_sign'],
encipher_only=keyusages['encipher_only'],
decipher_only=keyusages['decipher_only']
digital_signature=keyusages["digital_signature"],
content_commitment=keyusages["content_commitment"],
key_encipherment=keyusages["key_encipherment"],
data_encipherment=keyusages["data_encipherment"],
key_agreement=keyusages["key_agreement"],
key_cert_sign=keyusages["key_cert_sign"],
crl_sign=keyusages["crl_sign"],
encipher_only=keyusages["encipher_only"],
decipher_only=keyusages["decipher_only"],
)
@ -216,69 +221,77 @@ class ExtendedKeyUsageExtension(Field):
usage_list = {}
for usage in usages:
if usage == x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH:
usage_list['useClientAuthentication'] = True
usage_list["useClientAuthentication"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.SERVER_AUTH:
usage_list['useServerAuthentication'] = True
usage_list["useServerAuthentication"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.CODE_SIGNING:
usage_list['useCodeSigning'] = True
usage_list["useCodeSigning"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION:
usage_list['useEmailProtection'] = True
usage_list["useEmailProtection"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.TIME_STAMPING:
usage_list['useTimestamping'] = True
usage_list["useTimestamping"] = True
elif usage == x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING:
usage_list['useOCSPSigning'] = True
usage_list["useOCSPSigning"] = True
elif usage.dotted_string == '1.3.6.1.5.5.7.3.14':
usage_list['useEapOverLAN'] = True
elif usage.dotted_string == "1.3.6.1.5.5.7.3.14":
usage_list["useEapOverLAN"] = True
elif usage.dotted_string == '1.3.6.1.5.5.7.3.13':
usage_list['useEapOverPPP'] = True
elif usage.dotted_string == "1.3.6.1.5.5.7.3.13":
usage_list["useEapOverPPP"] = True
elif usage.dotted_string == '1.3.6.1.4.1.311.20.2.2':
usage_list['useSmartCardLogon'] = True
elif usage.dotted_string == "1.3.6.1.4.1.311.20.2.2":
usage_list["useSmartCardLogon"] = True
else:
current_app.logger.warning('Unable to serialize ExtendedKeyUsage with OID: {usage}'.format(usage=usage.dotted_string))
current_app.logger.warning(
"Unable to serialize ExtendedKeyUsage with OID: {usage}".format(
usage=usage.dotted_string
)
)
return usage_list
def _deserialize(self, value, attr, data):
usage_oids = []
for k, v in value.items():
if k == 'useClientAuthentication' and v:
if k == "useClientAuthentication" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH)
elif k == 'useServerAuthentication' and v:
elif k == "useServerAuthentication" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.SERVER_AUTH)
elif k == 'useCodeSigning' and v:
elif k == "useCodeSigning" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.CODE_SIGNING)
elif k == 'useEmailProtection' and v:
elif k == "useEmailProtection" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.EMAIL_PROTECTION)
elif k == 'useTimestamping' and v:
elif k == "useTimestamping" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.TIME_STAMPING)
elif k == 'useOCSPSigning' and v:
elif k == "useOCSPSigning" and v:
usage_oids.append(x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING)
elif k == 'useEapOverLAN' and v:
elif k == "useEapOverLAN" and v:
usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.14"))
elif k == 'useEapOverPPP' and v:
elif k == "useEapOverPPP" and v:
usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.5.5.7.3.13"))
elif k == 'useSmartCardLogon' and v:
elif k == "useSmartCardLogon" and v:
usage_oids.append(x509.oid.ObjectIdentifier("1.3.6.1.4.1.311.20.2.2"))
else:
current_app.logger.warning('Unable to deserialize ExtendedKeyUsage with name: {key}'.format(key=k))
current_app.logger.warning(
"Unable to deserialize ExtendedKeyUsage with name: {key}".format(
key=k
)
)
return x509.ExtendedKeyUsage(usage_oids)
@ -294,15 +307,17 @@ class BasicConstraintsExtension(Field):
"""
def _serialize(self, value, attr, obj):
return {'ca': value.ca, 'path_length': value.path_length}
return {"ca": value.ca, "path_length": value.path_length}
def _deserialize(self, value, attr, data):
ca = value.get('ca', False)
path_length = value.get('path_length', None)
ca = value.get("ca", False)
path_length = value.get("path_length", None)
if ca:
if not isinstance(path_length, (type(None), int)):
raise ValidationError('A CA certificate path_length (for BasicConstraints) must be None or an integer.')
raise ValidationError(
"A CA certificate path_length (for BasicConstraints) must be None or an integer."
)
return x509.BasicConstraints(ca=True, path_length=path_length)
else:
return x509.BasicConstraints(ca=False, path_length=None)
@ -317,6 +332,7 @@ class SubjectAlternativeNameExtension(Field):
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""
def _serialize(self, value, attr, obj):
general_names = []
name_type = None
@ -326,53 +342,59 @@ class SubjectAlternativeNameExtension(Field):
value = name.value
if isinstance(name, x509.DNSName):
name_type = 'DNSName'
name_type = "DNSName"
elif isinstance(name, x509.IPAddress):
if isinstance(value, ipaddress.IPv4Network):
name_type = 'IPNetwork'
name_type = "IPNetwork"
else:
name_type = 'IPAddress'
name_type = "IPAddress"
value = str(value)
elif isinstance(name, x509.UniformResourceIdentifier):
name_type = 'uniformResourceIdentifier'
name_type = "uniformResourceIdentifier"
elif isinstance(name, x509.DirectoryName):
name_type = 'directoryName'
name_type = "directoryName"
elif isinstance(name, x509.RFC822Name):
name_type = 'rfc822Name'
name_type = "rfc822Name"
elif isinstance(name, x509.RegisteredID):
name_type = 'registeredID'
name_type = "registeredID"
value = value.dotted_string
else:
current_app.logger.warning('Unknown SubAltName type: {name}'.format(name=name))
current_app.logger.warning(
"Unknown SubAltName type: {name}".format(name=name)
)
continue
general_names.append({'nameType': name_type, 'value': value})
general_names.append({"nameType": name_type, "value": value})
return general_names
def _deserialize(self, value, attr, data):
general_names = []
for name in value:
if name['nameType'] == 'DNSName':
validators.sensitive_domain(name['value'])
general_names.append(x509.DNSName(name['value']))
if name["nameType"] == "DNSName":
validators.sensitive_domain(name["value"])
general_names.append(x509.DNSName(name["value"]))
elif name['nameType'] == 'IPAddress':
general_names.append(x509.IPAddress(ipaddress.ip_address(name['value'])))
elif name["nameType"] == "IPAddress":
general_names.append(
x509.IPAddress(ipaddress.ip_address(name["value"]))
)
elif name['nameType'] == 'IPNetwork':
general_names.append(x509.IPAddress(ipaddress.ip_network(name['value'])))
elif name["nameType"] == "IPNetwork":
general_names.append(
x509.IPAddress(ipaddress.ip_network(name["value"]))
)
elif name['nameType'] == 'uniformResourceIdentifier':
general_names.append(x509.UniformResourceIdentifier(name['value']))
elif name["nameType"] == "uniformResourceIdentifier":
general_names.append(x509.UniformResourceIdentifier(name["value"]))
elif name['nameType'] == 'directoryName':
elif name["nameType"] == "directoryName":
# TODO: Need to parse a string in name['value'] like:
# 'CN=Common Name, O=Org Name, OU=OrgUnit Name, C=US, ST=ST, L=City/emailAddress=person@example.com'
# or
@ -390,26 +412,32 @@ class SubjectAlternativeNameExtension(Field):
# general_names.append(x509.DirectoryName(x509.Name(BLAH))))
pass
elif name['nameType'] == 'rfc822Name':
general_names.append(x509.RFC822Name(name['value']))
elif name["nameType"] == "rfc822Name":
general_names.append(x509.RFC822Name(name["value"]))
elif name['nameType'] == 'registeredID':
general_names.append(x509.RegisteredID(x509.ObjectIdentifier(name['value'])))
elif name["nameType"] == "registeredID":
general_names.append(
x509.RegisteredID(x509.ObjectIdentifier(name["value"]))
)
elif name['nameType'] == 'otherName':
elif name["nameType"] == "otherName":
# This has two inputs (type and value), so it doesn't fit the mold of the rest of these GeneralName entities.
# general_names.append(x509.OtherName(name['type'], bytes(name['value']), 'utf-8'))
pass
elif name['nameType'] == 'x400Address':
elif name["nameType"] == "x400Address":
# The Python Cryptography library doesn't support x400Address types (yet?)
pass
elif name['nameType'] == 'EDIPartyName':
elif name["nameType"] == "EDIPartyName":
# The Python Cryptography library doesn't support EDIPartyName types (yet?)
pass
else:
current_app.logger.warning('Unable to deserialize SubAltName with type: {name_type}'.format(name_type=name['nameType']))
current_app.logger.warning(
"Unable to deserialize SubAltName with type: {name_type}".format(
name_type=name["nameType"]
)
)
return x509.SubjectAlternativeName(general_names)

View File

@ -10,20 +10,20 @@ from flask import Blueprint
from lemur.database import db
from lemur.extensions import sentry
mod = Blueprint('healthCheck', __name__)
mod = Blueprint("healthCheck", __name__)
@mod.route('/healthcheck')
@mod.route("/healthcheck")
def health():
try:
if healthcheck(db):
return 'ok'
return "ok"
except Exception:
sentry.captureException()
return 'db check failed'
return "db check failed"
def healthcheck(db):
with db.engine.connect() as connection:
connection.execute('SELECT 1;')
connection.execute("SELECT 1;")
return True

View File

@ -52,7 +52,7 @@ class InstanceManager(object):
results = []
for cls_path in class_list:
module_name, class_name = cls_path.rsplit('.', 1)
module_name, class_name = cls_path.rsplit(".", 1)
try:
module = __import__(module_name, {}, {}, class_name)
cls = getattr(module, class_name)
@ -62,10 +62,14 @@ class InstanceManager(object):
results.append(cls)
except InvalidConfiguration as e:
current_app.logger.warning("Plugin '{0}' may not work correctly. {1}".format(class_name, e))
current_app.logger.warning(
"Plugin '{0}' may not work correctly. {1}".format(class_name, e)
)
except Exception as e:
current_app.logger.exception("Unable to import {0}. Reason: {1}".format(cls_path, e))
current_app.logger.exception(
"Unable to import {0}. Reason: {1}".format(cls_path, e)
)
continue
self.cache = results

View File

@ -11,15 +11,15 @@ def convert_validity_years(data):
:param data:
:return:
"""
if data.get('validity_years'):
if data.get("validity_years"):
now = arrow.utcnow()
data['validity_start'] = now.isoformat()
data["validity_start"] = now.isoformat()
end = now.replace(years=+int(data['validity_years']))
end = now.replace(years=+int(data["validity_years"]))
if not current_app.config.get('LEMUR_ALLOW_WEEKEND_EXPIRATION', True):
if not current_app.config.get("LEMUR_ALLOW_WEEKEND_EXPIRATION", True):
if is_weekend(end):
end = end.replace(days=-2)
data['validity_end'] = end.isoformat()
data["validity_end"] = end.isoformat()
return data

View File

@ -22,27 +22,26 @@ class LemurSchema(Schema):
"""
Base schema from which all grouper schema's inherit
"""
__envelope__ = True
def under(self, data, many=None):
items = []
if many:
for i in data:
items.append(
{underscore(key): value for key, value in i.items()}
)
items.append({underscore(key): value for key, value in i.items()})
return items
return {
underscore(key): value
for key, value in data.items()
}
return {underscore(key): value for key, value in data.items()}
def camel(self, data, many=None):
items = []
if many:
for i in data:
items.append(
{camelize(key, uppercase_first_letter=False): value for key, value in i.items()}
{
camelize(key, uppercase_first_letter=False): value
for key, value in i.items()
}
)
return items
return {
@ -52,16 +51,16 @@ class LemurSchema(Schema):
def wrap_with_envelope(self, data, many):
if many:
if 'total' in self.context.keys():
return dict(total=self.context['total'], items=data)
if "total" in self.context.keys():
return dict(total=self.context["total"], items=data)
return data
class LemurInputSchema(LemurSchema):
@pre_load(pass_many=True)
def preprocess(self, data, many):
if isinstance(data, dict) and data.get('owner'):
data['owner'] = data['owner'].lower()
if isinstance(data, dict) and data.get("owner"):
data["owner"] = data["owner"].lower()
return self.under(data, many=many)
@ -74,17 +73,17 @@ class LemurOutputSchema(LemurSchema):
def unwrap_envelope(self, data, many):
if many:
if data['items']:
if data["items"]:
if isinstance(data, InstrumentedList) or isinstance(data, list):
self.context['total'] = len(data)
self.context["total"] = len(data)
return data
else:
self.context['total'] = data['total']
self.context["total"] = data["total"]
else:
self.context['total'] = 0
data = {'items': []}
self.context["total"] = 0
data = {"items": []}
return data['items']
return data["items"]
return data
@ -110,11 +109,11 @@ def format_errors(messages):
def wrap_errors(messages):
errors = dict(message='Validation Error.')
if messages.get('_schema'):
errors['reasons'] = {'Schema': {'rule': messages['_schema']}}
errors = dict(message="Validation Error.")
if messages.get("_schema"):
errors["reasons"] = {"Schema": {"rule": messages["_schema"]}}
else:
errors['reasons'] = format_errors(messages)
errors["reasons"] = format_errors(messages)
return errors
@ -123,19 +122,19 @@ def unwrap_pagination(data, output_schema):
return data
if isinstance(data, dict):
if 'total' in data.keys():
if data.get('total') == 0:
if "total" in data.keys():
if data.get("total") == 0:
return data
marshaled_data = {'total': data['total']}
marshaled_data['items'] = output_schema.dump(data['items'], many=True).data
marshaled_data = {"total": data["total"]}
marshaled_data["items"] = output_schema.dump(data["items"], many=True).data
return marshaled_data
return output_schema.dump(data).data
elif isinstance(data, list):
marshaled_data = {'total': len(data)}
marshaled_data['items'] = output_schema.dump(data, many=True).data
marshaled_data = {"total": len(data)}
marshaled_data["items"] = output_schema.dump(data, many=True).data
return marshaled_data
return output_schema.dump(data).data
@ -155,7 +154,7 @@ def validate_schema(input_schema, output_schema):
if errors:
return wrap_errors(errors), 400
kwargs['data'] = data
kwargs["data"] = data
try:
resp = f(*args, **kwargs)
@ -173,4 +172,5 @@ def validate_schema(input_schema, output_schema):
return unwrap_pagination(resp, output_schema), 200
return decorated_function
return decorator

View File

@ -25,22 +25,22 @@ from lemur.exceptions import InvalidConfiguration
paginated_parser = RequestParser()
paginated_parser.add_argument('count', type=int, default=10, location='args')
paginated_parser.add_argument('page', type=int, default=1, location='args')
paginated_parser.add_argument('sortDir', type=str, dest='sort_dir', location='args')
paginated_parser.add_argument('sortBy', type=str, dest='sort_by', location='args')
paginated_parser.add_argument('filter', type=str, location='args')
paginated_parser.add_argument('owner', type=str, location='args')
paginated_parser.add_argument("count", type=int, default=10, location="args")
paginated_parser.add_argument("page", type=int, default=1, location="args")
paginated_parser.add_argument("sortDir", type=str, dest="sort_dir", location="args")
paginated_parser.add_argument("sortBy", type=str, dest="sort_by", location="args")
paginated_parser.add_argument("filter", type=str, location="args")
paginated_parser.add_argument("owner", type=str, location="args")
def get_psuedo_random_string():
"""
Create a random and strongish challenge.
"""
challenge = ''.join(random.choice(string.ascii_uppercase) for x in range(6)) # noqa
challenge += ''.join(random.choice("~!@#$%^&*()_+") for x in range(6)) # noqa
challenge += ''.join(random.choice(string.ascii_lowercase) for x in range(6))
challenge += ''.join(random.choice(string.digits) for x in range(6)) # noqa
challenge = "".join(random.choice(string.ascii_uppercase) for x in range(6)) # noqa
challenge += "".join(random.choice("~!@#$%^&*()_+") for x in range(6)) # noqa
challenge += "".join(random.choice(string.ascii_lowercase) for x in range(6))
challenge += "".join(random.choice(string.digits) for x in range(6)) # noqa
return challenge
@ -53,7 +53,7 @@ def parse_certificate(body):
"""
assert isinstance(body, str)
return x509.load_pem_x509_certificate(body.encode('utf-8'), default_backend())
return x509.load_pem_x509_certificate(body.encode("utf-8"), default_backend())
def parse_private_key(private_key):
@ -66,7 +66,9 @@ def parse_private_key(private_key):
"""
assert isinstance(private_key, str)
return load_pem_private_key(private_key.encode('utf8'), password=None, backend=default_backend())
return load_pem_private_key(
private_key.encode("utf8"), password=None, backend=default_backend()
)
def split_pem(data):
@ -100,14 +102,15 @@ def parse_csr(csr):
"""
assert isinstance(csr, str)
return x509.load_pem_x509_csr(csr.encode('utf-8'), default_backend())
return x509.load_pem_x509_csr(csr.encode("utf-8"), default_backend())
def get_authority_key(body):
"""Returns the authority key for a given certificate in hex format"""
parsed_cert = parse_certificate(body)
authority_key = parsed_cert.extensions.get_extension_for_class(
x509.AuthorityKeyIdentifier).value.key_identifier
x509.AuthorityKeyIdentifier
).value.key_identifier
return authority_key.hex()
@ -127,20 +130,17 @@ def generate_private_key(key_type):
_CURVE_TYPES = {
"ECCPRIME192V1": ec.SECP192R1(),
"ECCPRIME256V1": ec.SECP256R1(),
"ECCSECP192R1": ec.SECP192R1(),
"ECCSECP224R1": ec.SECP224R1(),
"ECCSECP256R1": ec.SECP256R1(),
"ECCSECP384R1": ec.SECP384R1(),
"ECCSECP521R1": ec.SECP521R1(),
"ECCSECP256K1": ec.SECP256K1(),
"ECCSECT163K1": ec.SECT163K1(),
"ECCSECT233K1": ec.SECT233K1(),
"ECCSECT283K1": ec.SECT283K1(),
"ECCSECT409K1": ec.SECT409K1(),
"ECCSECT571K1": ec.SECT571K1(),
"ECCSECT163R2": ec.SECT163R2(),
"ECCSECT233R1": ec.SECT233R1(),
"ECCSECT283R1": ec.SECT283R1(),
@ -149,22 +149,20 @@ def generate_private_key(key_type):
}
if key_type not in CERTIFICATE_KEY_TYPES:
raise Exception("Invalid key type: {key_type}. Supported key types: {choices}".format(
key_type=key_type,
choices=",".join(CERTIFICATE_KEY_TYPES)
))
raise Exception(
"Invalid key type: {key_type}. Supported key types: {choices}".format(
key_type=key_type, choices=",".join(CERTIFICATE_KEY_TYPES)
)
)
if 'RSA' in key_type:
if "RSA" in key_type:
key_size = int(key_type[3:])
return rsa.generate_private_key(
public_exponent=65537,
key_size=key_size,
backend=default_backend()
public_exponent=65537, key_size=key_size, backend=default_backend()
)
elif 'ECC' in key_type:
elif "ECC" in key_type:
return ec.generate_private_key(
curve=_CURVE_TYPES[key_type],
backend=default_backend()
curve=_CURVE_TYPES[key_type], backend=default_backend()
)
@ -184,11 +182,26 @@ def check_cert_signature(cert, issuer_public_key):
raise UnsupportedAlgorithm("RSASSA-PSS not supported")
else:
padder = padding.PKCS1v15()
issuer_public_key.verify(cert.signature, cert.tbs_certificate_bytes, padder, cert.signature_hash_algorithm)
elif isinstance(issuer_public_key, ec.EllipticCurvePublicKey) and isinstance(ec.ECDSA(cert.signature_hash_algorithm), ec.ECDSA):
issuer_public_key.verify(cert.signature, cert.tbs_certificate_bytes, ec.ECDSA(cert.signature_hash_algorithm))
issuer_public_key.verify(
cert.signature,
cert.tbs_certificate_bytes,
padder,
cert.signature_hash_algorithm,
)
elif isinstance(issuer_public_key, ec.EllipticCurvePublicKey) and isinstance(
ec.ECDSA(cert.signature_hash_algorithm), ec.ECDSA
):
issuer_public_key.verify(
cert.signature,
cert.tbs_certificate_bytes,
ec.ECDSA(cert.signature_hash_algorithm),
)
else:
raise UnsupportedAlgorithm("Unsupported Algorithm '{var}'.".format(var=cert.signature_algorithm_oid._name))
raise UnsupportedAlgorithm(
"Unsupported Algorithm '{var}'.".format(
var=cert.signature_algorithm_oid._name
)
)
def is_selfsigned(cert):
@ -224,7 +237,9 @@ def validate_conf(app, required_vars):
"""
for var in required_vars:
if var not in app.config:
raise InvalidConfiguration("Required variable '{var}' is not set in Lemur's conf.".format(var=var))
raise InvalidConfiguration(
"Required variable '{var}' is not set in Lemur's conf.".format(var=var)
)
# https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/WindowedRangeQuery
@ -243,18 +258,15 @@ def column_windows(session, column, windowsize):
be computed.
"""
def int_for_range(start_id, end_id):
if end_id:
return and_(
column >= start_id,
column < end_id
)
return and_(column >= start_id, column < end_id)
else:
return column >= start_id
q = session.query(
column,
func.row_number().over(order_by=column).label('rownum')
column, func.row_number().over(order_by=column).label("rownum")
).from_self(column)
if windowsize > 1:
@ -274,9 +286,7 @@ def column_windows(session, column, windowsize):
def windowed_query(q, column, windowsize):
""""Break a Query into windows on a given column."""
for whereclause in column_windows(
q.session,
column, windowsize):
for whereclause in column_windows(q.session, column, windowsize):
for row in q.filter(whereclause).order_by(column):
yield row
@ -284,7 +294,7 @@ def windowed_query(q, column, windowsize):
def truthiness(s):
"""If input string resembles something truthy then return True, else False."""
return s.lower() in ('true', 'yes', 'on', 't', '1')
return s.lower() in ("true", "yes", "on", "t", "1")
def find_matching_certificates_by_hash(cert, matching_certs):
@ -292,6 +302,8 @@ def find_matching_certificates_by_hash(cert, matching_certs):
determine if any of the certificate hashes match and return the matches."""
matching = []
for c in matching_certs:
if parse_certificate(c.body).fingerprint(hashes.SHA256()) == cert.fingerprint(hashes.SHA256()):
if parse_certificate(c.body).fingerprint(hashes.SHA256()) == cert.fingerprint(
hashes.SHA256()
):
matching.append(c)
return matching

View File

@ -16,7 +16,7 @@ def common_name(value):
# Common name could be a domain name, or a human-readable name of the subject (often used in CA names or client
# certificates). As a simple heuristic, we assume that human-readable names always include a space.
# However, to avoid confusion for humans, we also don't count spaces at the beginning or end of the string.
if ' ' not in value.strip():
if " " not in value.strip():
return sensitive_domain(value)
@ -30,17 +30,21 @@ def sensitive_domain(domain):
# User has permission, no need to check anything
return
whitelist = current_app.config.get('LEMUR_WHITELISTED_DOMAINS', [])
whitelist = current_app.config.get("LEMUR_WHITELISTED_DOMAINS", [])
if whitelist and not any(re.match(pattern, domain) for pattern in whitelist):
raise ValidationError('Domain {0} does not match whitelisted domain patterns. '
'Contact an administrator to issue the certificate.'.format(domain))
raise ValidationError(
"Domain {0} does not match whitelisted domain patterns. "
"Contact an administrator to issue the certificate.".format(domain)
)
# Avoid circular import.
from lemur.domains import service as domain_service
if any(d.sensitive for d in domain_service.get_by_name(domain)):
raise ValidationError('Domain {0} has been marked as sensitive. '
'Contact an administrator to issue the certificate.'.format(domain))
raise ValidationError(
"Domain {0} has been marked as sensitive. "
"Contact an administrator to issue the certificate.".format(domain)
)
def encoding(oid_encoding):
@ -49,9 +53,13 @@ def encoding(oid_encoding):
:param oid_encoding:
:return:
"""
valid_types = ['b64asn1', 'string', 'ia5string']
valid_types = ["b64asn1", "string", "ia5string"]
if oid_encoding.lower() not in [o_type.lower() for o_type in valid_types]:
raise ValidationError('Invalid Oid Encoding: {0} choose from {1}'.format(oid_encoding, ",".join(valid_types)))
raise ValidationError(
"Invalid Oid Encoding: {0} choose from {1}".format(
oid_encoding, ",".join(valid_types)
)
)
def sub_alt_type(alt_type):
@ -60,10 +68,23 @@ def sub_alt_type(alt_type):
:param alt_type:
:return:
"""
valid_types = ['DNSName', 'IPAddress', 'uniFormResourceIdentifier', 'directoryName', 'rfc822Name', 'registrationID',
'otherName', 'x400Address', 'EDIPartyName']
valid_types = [
"DNSName",
"IPAddress",
"uniFormResourceIdentifier",
"directoryName",
"rfc822Name",
"registrationID",
"otherName",
"x400Address",
"EDIPartyName",
]
if alt_type.lower() not in [a_type.lower() for a_type in valid_types]:
raise ValidationError('Invalid SubAltName Type: {0} choose from {1}'.format(type, ",".join(valid_types)))
raise ValidationError(
"Invalid SubAltName Type: {0} choose from {1}".format(
type, ",".join(valid_types)
)
)
def csr(data):
@ -73,16 +94,18 @@ def csr(data):
:return:
"""
try:
request = x509.load_pem_x509_csr(data.encode('utf-8'), default_backend())
request = x509.load_pem_x509_csr(data.encode("utf-8"), default_backend())
except Exception:
raise ValidationError('CSR presented is not valid.')
raise ValidationError("CSR presented is not valid.")
# Validate common name and SubjectAltNames
for name in request.subject.get_attributes_for_oid(NameOID.COMMON_NAME):
common_name(name.value)
try:
alt_names = request.extensions.get_extension_for_class(x509.SubjectAlternativeName)
alt_names = request.extensions.get_extension_for_class(
x509.SubjectAlternativeName
)
for name in alt_names.value.get_values_for_type(x509.DNSName):
sensitive_domain(name)
@ -91,26 +114,40 @@ def csr(data):
def dates(data):
if not data.get('validity_start') and data.get('validity_end'):
raise ValidationError('If validity start is specified so must validity end.')
if not data.get("validity_start") and data.get("validity_end"):
raise ValidationError("If validity start is specified so must validity end.")
if not data.get('validity_end') and data.get('validity_start'):
raise ValidationError('If validity end is specified so must validity start.')
if not data.get("validity_end") and data.get("validity_start"):
raise ValidationError("If validity end is specified so must validity start.")
if data.get('validity_start') and data.get('validity_end'):
if not current_app.config.get('LEMUR_ALLOW_WEEKEND_EXPIRATION', True):
if is_weekend(data.get('validity_end')):
raise ValidationError('Validity end must not land on a weekend.')
if data.get("validity_start") and data.get("validity_end"):
if not current_app.config.get("LEMUR_ALLOW_WEEKEND_EXPIRATION", True):
if is_weekend(data.get("validity_end")):
raise ValidationError("Validity end must not land on a weekend.")
if not data['validity_start'] < data['validity_end']:
raise ValidationError('Validity start must be before validity end.')
if not data["validity_start"] < data["validity_end"]:
raise ValidationError("Validity start must be before validity end.")
if data.get('authority'):
if data.get('validity_start').date() < data['authority'].authority_certificate.not_before.date():
raise ValidationError('Validity start must not be before {0}'.format(data['authority'].authority_certificate.not_before))
if data.get("authority"):
if (
data.get("validity_start").date()
< data["authority"].authority_certificate.not_before.date()
):
raise ValidationError(
"Validity start must not be before {0}".format(
data["authority"].authority_certificate.not_before
)
)
if data.get('validity_end').date() > data['authority'].authority_certificate.not_after.date():
raise ValidationError('Validity end must not be after {0}'.format(data['authority'].authority_certificate.not_after))
if (
data.get("validity_end").date()
> data["authority"].authority_certificate.not_after.date()
):
raise ValidationError(
"Validity end must not be after {0}".format(
data["authority"].authority_certificate.not_after
)
)
return data
@ -148,8 +185,13 @@ def verify_cert_chain(certs, error_class=ValidationError):
# Avoid circular import.
from lemur.common import defaults
raise error_class("Incorrect chain certificate(s) provided: '%s' is not signed by '%s'"
% (defaults.common_name(cert) or 'Unknown', defaults.common_name(issuer)))
raise error_class(
"Incorrect chain certificate(s) provided: '%s' is not signed by '%s'"
% (
defaults.common_name(cert) or "Unknown",
defaults.common_name(issuer),
)
)
except UnsupportedAlgorithm as err:
current_app.logger.warning("Skipping chain validation: %s", err)