Merge branch 'master' into master

This commit is contained in:
e11it 2020-03-20 20:44:46 +03:00 committed by GitHub
commit 6049e8f117
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 242 additions and 64 deletions

View File

@ -102,12 +102,13 @@ def get_all_certs():
return Certificate.query.all() return Certificate.query.all()
def get_all_pending_cleaning(source): def get_all_pending_cleaning_expired(source):
""" """
Retrieves all certificates that are available for cleaning. Retrieves all certificates that are available for cleaning. These are certificates which are expired and are not
attached to any endpoints.
:param source: :param source: the source to search for certificates
:return: :return: list of pending certificates
""" """
return ( return (
Certificate.query.filter(Certificate.sources.any(id=source.id)) Certificate.query.filter(Certificate.sources.any(id=source.id))
@ -117,6 +118,41 @@ def get_all_pending_cleaning(source):
) )
def get_all_pending_cleaning_expiring_in_days(source, days_to_expire):
"""
Retrieves all certificates that are available for cleaning, not attached to endpoint,
and within X days from expiration.
:param days_to_expire: defines how many days till the certificate is expired
:param source: the source to search for certificates
:return: list of pending certificates
"""
expiration_window = arrow.now().shift(days=+days_to_expire).format("YYYY-MM-DD")
return (
Certificate.query.filter(Certificate.sources.any(id=source.id))
.filter(not_(Certificate.endpoints.any()))
.filter(Certificate.not_after < expiration_window)
.all()
)
def get_all_pending_cleaning_issued_since_days(source, days_since_issuance):
"""
Retrieves all certificates that are available for cleaning: not attached to endpoint, and X days since issuance.
:param days_since_issuance: defines how many days since the certificate is issued
:param source: the source to search for certificates
:return: list of pending certificates
"""
not_in_use_window = arrow.now().shift(days=-days_since_issuance).format("YYYY-MM-DD")
return (
Certificate.query.filter(Certificate.sources.any(id=source.id))
.filter(not_(Certificate.endpoints.any()))
.filter(Certificate.date_created > not_in_use_window)
.all()
)
def get_all_pending_reissue(): def get_all_pending_reissue():
""" """
Retrieves all certificates that need to be rotated. Retrieves all certificates that need to be rotated.

View File

@ -54,18 +54,30 @@ class AcmeHandler(object):
current_app.logger.error(f"Unable to fetch DNS Providers: {e}") current_app.logger.error(f"Unable to fetch DNS Providers: {e}")
self.all_dns_providers = [] self.all_dns_providers = []
def find_dns_challenge(self, host, authorizations): def get_dns_challenges(self, host, authorizations):
"""Get dns challenges for provided domain"""
domain_to_validate, is_wildcard = self.strip_wildcard(host)
dns_challenges = [] dns_challenges = []
for authz in authorizations: for authz in authorizations:
if not authz.body.identifier.value.lower() == host.lower(): if not authz.body.identifier.value.lower() == domain_to_validate.lower():
continue
if is_wildcard and not authz.body.wildcard:
continue
if not is_wildcard and authz.body.wildcard:
continue continue
for combo in authz.body.challenges: for combo in authz.body.challenges:
if isinstance(combo.chall, challenges.DNS01): if isinstance(combo.chall, challenges.DNS01):
dns_challenges.append(combo) dns_challenges.append(combo)
return dns_challenges return dns_challenges
def maybe_remove_wildcard(self, host): def strip_wildcard(self, host):
return host.replace("*.", "") """Removes the leading *. and returns Host and whether it was removed or not (True/False)"""
prefix = "*."
if host.startswith(prefix):
return host[len(prefix):], True
return host, False
def maybe_add_extension(self, host, dns_provider_options): def maybe_add_extension(self, host, dns_provider_options):
if dns_provider_options and dns_provider_options.get( if dns_provider_options and dns_provider_options.get(
@ -86,9 +98,8 @@ class AcmeHandler(object):
current_app.logger.debug("Starting DNS challenge for {0}".format(host)) current_app.logger.debug("Starting DNS challenge for {0}".format(host))
change_ids = [] change_ids = []
dns_challenges = self.get_dns_challenges(host, order.authorizations)
host_to_validate = self.maybe_remove_wildcard(host) host_to_validate, _ = self.strip_wildcard(host)
dns_challenges = self.find_dns_challenge(host_to_validate, order.authorizations)
host_to_validate = self.maybe_add_extension( host_to_validate = self.maybe_add_extension(
host_to_validate, dns_provider_options host_to_validate, dns_provider_options
) )
@ -325,7 +336,7 @@ class AcmeHandler(object):
) )
dns_provider_options = json.loads(dns_provider.credentials) dns_provider_options = json.loads(dns_provider.credentials)
account_number = dns_provider_options.get("account_id") account_number = dns_provider_options.get("account_id")
host_to_validate = self.maybe_remove_wildcard(authz_record.host) host_to_validate, _ = self.strip_wildcard(authz_record.host)
host_to_validate = self.maybe_add_extension( host_to_validate = self.maybe_add_extension(
host_to_validate, dns_provider_options host_to_validate, dns_provider_options
) )
@ -357,7 +368,7 @@ class AcmeHandler(object):
dns_provider_options = json.loads(dns_provider.credentials) dns_provider_options = json.loads(dns_provider.credentials)
account_number = dns_provider_options.get("account_id") account_number = dns_provider_options.get("account_id")
dns_challenges = authz_record.dns_challenge dns_challenges = authz_record.dns_challenge
host_to_validate = self.maybe_remove_wildcard(authz_record.host) host_to_validate, _ = self.strip_wildcard(authz_record.host)
host_to_validate = self.maybe_add_extension( host_to_validate = self.maybe_add_extension(
host_to_validate, dns_provider_options host_to_validate, dns_provider_options
) )

View File

@ -23,11 +23,12 @@ class TestAcme(unittest.TestCase):
} }
@patch("lemur.plugins.lemur_acme.plugin.len", return_value=1) @patch("lemur.plugins.lemur_acme.plugin.len", return_value=1)
def test_find_dns_challenge(self, mock_len): def test_get_dns_challenges(self, mock_len):
assert mock_len assert mock_len
from acme import challenges from acme import challenges
host = "example.com"
c = challenges.DNS01() c = challenges.DNS01()
mock_authz = Mock() mock_authz = Mock()
@ -35,9 +36,18 @@ class TestAcme(unittest.TestCase):
mock_entry = Mock() mock_entry = Mock()
mock_entry.chall = c mock_entry.chall = c
mock_authz.body.resolved_combinations.append(mock_entry) mock_authz.body.resolved_combinations.append(mock_entry)
result = yield self.acme.find_dns_challenge(mock_authz) result = yield self.acme.get_dns_challenges(host, mock_authz)
self.assertEqual(result, mock_entry) self.assertEqual(result, mock_entry)
def test_strip_wildcard(self):
expected = ("example.com", False)
result = self.acme.strip_wildcard("example.com")
self.assertEqual(expected, result)
expected = ("example.com", True)
result = self.acme.strip_wildcard("*.example.com")
self.assertEqual(expected, result)
def test_authz_record(self): def test_authz_record(self):
a = plugin.AuthorizationRecord("host", "authz", "challenge", "id") a = plugin.AuthorizationRecord("host", "authz", "challenge", "id")
self.assertEqual(type(a), plugin.AuthorizationRecord) self.assertEqual(type(a), plugin.AuthorizationRecord)
@ -45,9 +55,9 @@ class TestAcme(unittest.TestCase):
@patch("acme.client.Client") @patch("acme.client.Client")
@patch("lemur.plugins.lemur_acme.plugin.current_app") @patch("lemur.plugins.lemur_acme.plugin.current_app")
@patch("lemur.plugins.lemur_acme.plugin.len", return_value=1) @patch("lemur.plugins.lemur_acme.plugin.len", return_value=1)
@patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.find_dns_challenge") @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.get_dns_challenges")
def test_start_dns_challenge( def test_start_dns_challenge(
self, mock_find_dns_challenge, mock_len, mock_app, mock_acme self, mock_get_dns_challenges, mock_len, mock_app, mock_acme
): ):
assert mock_len assert mock_len
mock_order = Mock() mock_order = Mock()
@ -65,7 +75,7 @@ class TestAcme(unittest.TestCase):
mock_dns_provider.create_txt_record = Mock(return_value=1) mock_dns_provider.create_txt_record = Mock(return_value=1)
values = [mock_entry] values = [mock_entry]
iterable = mock_find_dns_challenge.return_value iterable = mock_get_dns_challenges.return_value
iterator = iter(values) iterator = iter(values)
iterable.__iter__.return_value = iterator iterable.__iter__.return_value = iterator
result = self.acme.start_dns_challenge( result = self.acme.start_dns_challenge(
@ -127,12 +137,12 @@ class TestAcme(unittest.TestCase):
@patch("acme.client.Client") @patch("acme.client.Client")
@patch("OpenSSL.crypto", return_value="mock_cert") @patch("OpenSSL.crypto", return_value="mock_cert")
@patch("josepy.util.ComparableX509") @patch("josepy.util.ComparableX509")
@patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.find_dns_challenge") @patch("lemur.plugins.lemur_acme.plugin.AcmeHandler.get_dns_challenges")
@patch("lemur.plugins.lemur_acme.plugin.current_app") @patch("lemur.plugins.lemur_acme.plugin.current_app")
def test_request_certificate( def test_request_certificate(
self, self,
mock_current_app, mock_current_app,
mock_find_dns_challenge, mock_get_dns_challenges,
mock_jose, mock_jose,
mock_crypto, mock_crypto,
mock_acme, mock_acme,

View File

@ -54,6 +54,17 @@ def validate_sources(source_strings):
return sources return sources
def execute_clean(plugin, certificate, source):
try:
plugin.clean(certificate, source.options)
certificate.sources.remove(source)
certificate_service.database.update(certificate)
return SUCCESS_METRIC_STATUS
except Exception as e:
current_app.logger.exception(e)
sentry.captureException()
@manager.option( @manager.option(
"-s", "-s",
"--sources", "--sources",
@ -132,11 +143,9 @@ def clean(source_strings, commit):
s = plugins.get(source.plugin_name) s = plugins.get(source.plugin_name)
if not hasattr(s, "clean"): if not hasattr(s, "clean"):
print( info_text = f"Cannot clean source: {source.label}, source plugin does not implement 'clean()'"
"Cannot clean source: {0}, source plugin does not implement 'clean()'".format( current_app.logger.warning(info_text)
source.label print(info_text)
)
)
continue continue
start_time = time.time() start_time = time.time()
@ -144,35 +153,147 @@ def clean(source_strings, commit):
print("[+] Staring to clean source: {label}!\n".format(label=source.label)) print("[+] Staring to clean source: {label}!\n".format(label=source.label))
cleaned = 0 cleaned = 0
for certificate in certificate_service.get_all_pending_cleaning(source): certificates = certificate_service.get_all_pending_cleaning_expired(source)
for certificate in certificates:
status = FAILURE_METRIC_STATUS status = FAILURE_METRIC_STATUS
if commit: if commit:
try: status = execute_clean(s, certificate, source)
s.clean(certificate, source.options)
certificate.sources.remove(source)
certificate_service.database.update(certificate)
status = SUCCESS_METRIC_STATUS
except Exception as e:
current_app.logger.exception(e)
sentry.captureException()
metrics.send( metrics.send(
"clean", "certificate_clean",
"counter", "counter",
1, 1,
metric_tags={"source": source.label, "status": status}, metric_tags={"status": status, "source": source.label, "certificate": certificate.name},
) )
current_app.logger.warning(f"Removed {certificate.name} from source {source.label} during cleaning")
current_app.logger.warning(
"Removed {0} from source {1} during cleaning".format(
certificate.name, source.label
)
)
cleaned += 1 cleaned += 1
print( info_text = f"[+] Finished cleaning source: {source.label}. " \
"[+] Finished cleaning source: {label}. Removed {cleaned} certificates from source. Run Time: {time}\n".format( f"Removed {cleaned} certificates from source. " \
label=source.label, time=(time.time() - start_time), cleaned=cleaned f"Run Time: {(time.time() - start_time)}\n"
print(info_text)
current_app.logger.warning(info_text)
@manager.option(
"-s",
"--sources",
dest="source_strings",
action="append",
help="Sources to operate on.",
) )
@manager.option(
"-d",
"--days",
dest="days_to_expire",
type=int,
action="store",
required=True,
help="The expiry range within days.",
) )
@manager.option(
"-c",
"--commit",
dest="commit",
action="store_true",
default=False,
help="Persist changes.",
)
def clean_unused_and_expiring_within_days(source_strings, days_to_expire, commit):
sources = validate_sources(source_strings)
for source in sources:
s = plugins.get(source.plugin_name)
if not hasattr(s, "clean"):
info_text = f"Cannot clean source: {source.label}, source plugin does not implement 'clean()'"
current_app.logger.warning(info_text)
print(info_text)
continue
start_time = time.time()
print("[+] Staring to clean source: {label}!\n".format(label=source.label))
cleaned = 0
certificates = certificate_service.get_all_pending_cleaning_expiring_in_days(source, days_to_expire)
for certificate in certificates:
status = FAILURE_METRIC_STATUS
if commit:
status = execute_clean(s, certificate, source)
metrics.send(
"certificate_clean",
"counter",
1,
metric_tags={"status": status, "source": source.label, "certificate": certificate.name},
)
current_app.logger.warning(f"Removed {certificate.name} from source {source.label} during cleaning")
cleaned += 1
info_text = f"[+] Finished cleaning source: {source.label}. " \
f"Removed {cleaned} certificates from source. " \
f"Run Time: {(time.time() - start_time)}\n"
print(info_text)
current_app.logger.warning(info_text)
@manager.option(
"-s",
"--sources",
dest="source_strings",
action="append",
help="Sources to operate on.",
)
@manager.option(
"-d",
"--days",
dest="days_since_issuance",
type=int,
action="store",
required=True,
help="Days since issuance.",
)
@manager.option(
"-c",
"--commit",
dest="commit",
action="store_true",
default=False,
help="Persist changes.",
)
def clean_unused_and_issued_since_days(source_strings, days_since_issuance, commit):
sources = validate_sources(source_strings)
for source in sources:
s = plugins.get(source.plugin_name)
if not hasattr(s, "clean"):
info_text = f"Cannot clean source: {source.label}, source plugin does not implement 'clean()'"
current_app.logger.warning(info_text)
print(info_text)
continue
start_time = time.time()
print("[+] Staring to clean source: {label}!\n".format(label=source.label))
cleaned = 0
certificates = certificate_service.get_all_pending_cleaning_issued_since_days(source, days_since_issuance)
for certificate in certificates:
status = FAILURE_METRIC_STATUS
if commit:
status = execute_clean(s, certificate, source)
metrics.send(
"certificate_clean",
"counter",
1,
metric_tags={"status": status, "source": source.label, "certificate": certificate.name},
)
current_app.logger.warning(f"Removed {certificate.name} from source {source.label} during cleaning")
cleaned += 1
info_text = f"[+] Finished cleaning source: {source.label}. " \
f"Removed {cleaned} certificates from source. " \
f"Run Time: {(time.time() - start_time)}\n"
print(info_text)
current_app.logger.warning(info_text)