updating variables based on feedback

This commit is contained in:
csine-nflx 2020-10-29 13:51:22 -07:00
parent 33a006bbeb
commit 2b91077d92
2 changed files with 29 additions and 30 deletions

View File

@ -37,9 +37,9 @@ from retrying import retry
class AuthorizationRecord(object): class AuthorizationRecord(object):
def __init__(self, domain, host, authz, dns_challenge, change_id): def __init__(self, domain, target_domain, authz, dns_challenge, change_id):
self.domain = domain self.domain = domain
self.host = host self.target_domain = target_domain
self.authz = authz self.authz = authz
self.dns_challenge = dns_challenge self.dns_challenge = dns_challenge
self.change_id = change_id self.change_id = change_id
@ -93,16 +93,16 @@ class AcmeHandler(object):
acme_client, acme_client,
account_number, account_number,
domain, domain,
host, target_domain,
dns_provider, dns_provider,
order, order,
dns_provider_options, dns_provider_options,
): ):
current_app.logger.debug("Starting DNS challenge for {0}".format(host)) current_app.logger.debug("Starting DNS challenge for {0}".format(target_domain))
change_ids = [] change_ids = []
dns_challenges = self.get_dns_challenges(domain, order.authorizations) dns_challenges = self.get_dns_challenges(domain, order.authorizations)
host_to_validate, _ = self.strip_wildcard(host) host_to_validate, _ = self.strip_wildcard(target_domain)
host_to_validate = self.maybe_add_extension(host_to_validate, dns_provider_options) host_to_validate = self.maybe_add_extension(host_to_validate, dns_provider_options)
if not dns_challenges: if not dns_challenges:
@ -113,7 +113,7 @@ class AcmeHandler(object):
for dns_challenge in dns_challenges: for dns_challenge in dns_challenges:
# Only prepend '_acme-challenge' if not using CNAME redirection # Only prepend '_acme-challenge' if not using CNAME redirection
if domain == host: if domain == target_domain:
host_to_validate = dns_challenge.validation_domain_name(host_to_validate) host_to_validate = dns_challenge.validation_domain_name(host_to_validate)
change_id = dns_provider.create_txt_record( change_id = dns_provider.create_txt_record(
@ -124,7 +124,7 @@ class AcmeHandler(object):
change_ids.append(change_id) change_ids.append(change_id)
return AuthorizationRecord( return AuthorizationRecord(
domain, host, order.authorizations, dns_challenges, change_ids domain, target_domain, order.authorizations, dns_challenges, change_ids
) )
def complete_dns_challenge(self, acme_client, authz_record): def complete_dns_challenge(self, acme_client, authz_record):
@ -133,11 +133,11 @@ class AcmeHandler(object):
authz_record.authz[0].body.identifier.value authz_record.authz[0].body.identifier.value
) )
) )
dns_providers = self.dns_providers_for_domain.get(authz_record.host) dns_providers = self.dns_providers_for_domain.get(authz_record.target_domain)
if not dns_providers: if not dns_providers:
metrics.send("complete_dns_challenge_error_no_dnsproviders", "counter", 1) metrics.send("complete_dns_challenge_error_no_dnsproviders", "counter", 1)
raise Exception( raise Exception(
"No DNS providers found for domain: {}".format(authz_record.host) "No DNS providers found for domain: {}".format(authz_record.target_domain)
) )
for dns_provider in dns_providers: for dns_provider in dns_providers:
@ -165,7 +165,7 @@ class AcmeHandler(object):
verified = response.simple_verify( verified = response.simple_verify(
dns_challenge.chall, dns_challenge.chall,
authz_record.host, authz_record.target_domain,
acme_client.client.net.key.public_key(), acme_client.client.net.key.public_key(),
) )
@ -318,22 +318,22 @@ class AcmeHandler(object):
for domain in order_info.domains: for domain in order_info.domains:
# If CNAME exists, set host to the target address # If CNAME exists, set host to the target address
host = domain target_domain = domain
if current_app.config.get("ACME_ENABLE_DELEGATED_CNAME", False): if current_app.config.get("ACME_ENABLE_DELEGATED_CNAME", False):
val_domain, _ = self.strip_wildcard(domain) cname_result, _ = self.strip_wildcard(domain)
val_domain = challenges.DNS01().validation_domain_name(val_domain) cname_result = challenges.DNS01().validation_domain_name(cname_result)
cname_res = self.get_cname(val_domain) cname_result = self.get_cname(cname_result)
if cname_res: if cname_result:
host = cname_res target_domain = cname_result
self.autodetect_dns_providers(host) self.autodetect_dns_providers(target_domain)
if not self.dns_providers_for_domain.get(host): if not self.dns_providers_for_domain.get(target_domain):
metrics.send( metrics.send(
"get_authorizations_no_dns_provider_for_domain", "counter", 1 "get_authorizations_no_dns_provider_for_domain", "counter", 1
) )
raise Exception("No DNS providers found for domain: {}".format(host)) raise Exception("No DNS providers found for domain: {}".format(target_domain))
for dns_provider in self.dns_providers_for_domain[host]: for dns_provider in self.dns_providers_for_domain[target_domain]:
dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type) dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type)
dns_provider_options = json.loads(dns_provider.credentials) dns_provider_options = json.loads(dns_provider.credentials)
account_number = dns_provider_options.get("account_id") account_number = dns_provider_options.get("account_id")
@ -341,7 +341,7 @@ class AcmeHandler(object):
acme_client, acme_client,
account_number, account_number,
domain, domain,
host, target_domain,
dns_provider_plugin, dns_provider_plugin,
order, order,
dns_provider.options, dns_provider.options,
@ -376,7 +376,7 @@ class AcmeHandler(object):
for authz_record in authorizations: for authz_record in authorizations:
dns_challenges = authz_record.dns_challenge dns_challenges = authz_record.dns_challenge
for dns_challenge in dns_challenges: for dns_challenge in dns_challenges:
dns_providers = self.dns_providers_for_domain.get(authz_record.host) dns_providers = self.dns_providers_for_domain.get(authz_record.target_domain)
for dns_provider in dns_providers: for dns_provider in dns_providers:
# Grab account number (For Route53) # Grab account number (For Route53)
dns_provider_plugin = self.get_dns_provider( dns_provider_plugin = self.get_dns_provider(
@ -384,9 +384,9 @@ class AcmeHandler(object):
) )
dns_provider_options = json.loads(dns_provider.credentials) dns_provider_options = json.loads(dns_provider.credentials)
account_number = dns_provider_options.get("account_id") account_number = dns_provider_options.get("account_id")
host_to_validate, _ = self.strip_wildcard(authz_record.host) host_to_validate, _ = self.strip_wildcard(authz_record.target_domain)
host_to_validate = self.maybe_add_extension(host_to_validate, dns_provider_options) host_to_validate = self.maybe_add_extension(host_to_validate, dns_provider_options)
if authz_record.domain == authz_record.host: if authz_record.domain == authz_record.target_domain:
host_to_validate = challenges.DNS01().validation_domain_name(host_to_validate) host_to_validate = challenges.DNS01().validation_domain_name(host_to_validate)
dns_provider_plugin.delete_txt_record( dns_provider_plugin.delete_txt_record(
authz_record.change_id, authz_record.change_id,
@ -410,20 +410,20 @@ class AcmeHandler(object):
:return: :return:
""" """
for authz_record in authorizations: for authz_record in authorizations:
dns_providers = self.dns_providers_for_domain.get(authz_record.host) dns_providers = self.dns_providers_for_domain.get(authz_record.target_domain)
for dns_provider in dns_providers: for dns_provider in dns_providers:
# Grab account number (For Route53) # Grab account number (For Route53)
dns_provider_options = json.loads(dns_provider.credentials) dns_provider_options = json.loads(dns_provider.credentials)
account_number = dns_provider_options.get("account_id") account_number = dns_provider_options.get("account_id")
dns_challenges = authz_record.dns_challenge dns_challenges = authz_record.dns_challenge
host_to_validate, _ = self.strip_wildcard(authz_record.host) host_to_validate, _ = self.strip_wildcard(authz_record.target_domain)
host_to_validate = self.maybe_add_extension( host_to_validate = self.maybe_add_extension(
host_to_validate, dns_provider_options host_to_validate, dns_provider_options
) )
dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type) dns_provider_plugin = self.get_dns_provider(dns_provider.provider_type)
for dns_challenge in dns_challenges: for dns_challenge in dns_challenges:
if authz_record.domain == authz_record.host: if authz_record.domain == authz_record.target_domain:
host_to_validate = dns_challenge.validation_domain_name(host_to_validate), host_to_validate = dns_challenge.validation_domain_name(host_to_validate),
try: try:
dns_provider_plugin.delete_txt_record( dns_provider_plugin.delete_txt_record(
@ -455,7 +455,6 @@ class AcmeHandler(object):
def get_cname(self, domain): def get_cname(self, domain):
""" """
:param domain: Domain name to look up a CNAME for. :param domain: Domain name to look up a CNAME for.
:param record_type: Type of DNS record to lookup.
:return: First CNAME target or False if no CNAME record exists. :return: First CNAME target or False if no CNAME record exists.
""" """
try: try:

View File

@ -97,7 +97,7 @@ class TestAcme(unittest.TestCase):
mock_authz.dns_challenge.response = Mock() mock_authz.dns_challenge.response = Mock()
mock_authz.dns_challenge.response.simple_verify = Mock(return_value=True) mock_authz.dns_challenge.response.simple_verify = Mock(return_value=True)
mock_authz.authz = [] mock_authz.authz = []
mock_authz.host = "www.test.com" mock_authz.target_domain = "www.test.com"
mock_authz_record = Mock() mock_authz_record = Mock()
mock_authz_record.body.identifier.value = "test" mock_authz_record.body.identifier.value = "test"
mock_authz.authz.append(mock_authz_record) mock_authz.authz.append(mock_authz_record)
@ -121,7 +121,7 @@ class TestAcme(unittest.TestCase):
mock_authz.dns_challenge.response = Mock() mock_authz.dns_challenge.response = Mock()
mock_authz.dns_challenge.response.simple_verify = Mock(return_value=False) mock_authz.dns_challenge.response.simple_verify = Mock(return_value=False)
mock_authz.authz = [] mock_authz.authz = []
mock_authz.host = "www.test.com" mock_authz.target_domain = "www.test.com"
mock_authz_record = Mock() mock_authz_record = Mock()
mock_authz_record.body.identifier.value = "test" mock_authz_record.body.identifier.value = "test"
mock_authz.authz.append(mock_authz_record) mock_authz.authz.append(mock_authz_record)