Merge branch 'master' into soft_time_outs

This commit is contained in:
Hossein Shafagh 2019-08-13 19:42:22 -07:00 committed by GitHub
commit 296a315a3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 733 additions and 18 deletions

View File

@ -66,6 +66,9 @@ celery = make_celery(flask_app)
def is_task_active(fun, task_id, args): def is_task_active(fun, task_id, args):
from celery.task.control import inspect from celery.task.control import inspect
if not args:
args = '()' # empty args
i = inspect() i = inspect()
active_tasks = i.active() active_tasks = i.active()
for _, tasks in active_tasks.items(): for _, tasks in active_tasks.items():
@ -89,6 +92,21 @@ def report_celery_last_success_metrics():
""" """
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = {
"function": function,
"message": "recurrent task",
"task_id": task_id,
}
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
current_time = int(time.time()) current_time = int(time.time())
schedule = current_app.config.get('CELERYBEAT_SCHEDULE') schedule = current_app.config.get('CELERYBEAT_SCHEDULE')
for _, t in schedule.items(): for _, t in schedule.items():
@ -213,15 +231,25 @@ def fetch_acme_cert(id):
@celery.task() @celery.task()
def fetch_all_pending_acme_certs(): def fetch_all_pending_acme_certs():
"""Instantiate celery workers to resolve all pending Acme certificates""" """Instantiate celery workers to resolve all pending Acme certificates"""
pending_certs = pending_certificate_service.get_unresolved_pending_certs()
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "Starting job.", "message": "Starting job.",
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
pending_certs = pending_certificate_service.get_unresolved_pending_certs()
# We only care about certs using the acme-issuer plugin # We only care about certs using the acme-issuer plugin
for cert in pending_certs: for cert in pending_certs:
@ -242,10 +270,21 @@ def fetch_all_pending_acme_certs():
def remove_old_acme_certs(): def remove_old_acme_certs():
"""Prune old pending acme certificates from the database""" """Prune old pending acme certificates from the database"""
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "Starting job.", "message": "Starting job.",
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
pending_certs = pending_certificate_service.get_pending_certs("all") pending_certs = pending_certificate_service.get_pending_certs("all")
# Delete pending certs more than a week old # Delete pending certs more than a week old
@ -268,10 +307,21 @@ def clean_all_sources():
be ran periodically. This function triggers one celery task per source. be ran periodically. This function triggers one celery task per source.
""" """
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "Creating celery task to clean source", "message": "Creating celery task to clean source",
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
sources = validate_sources("all") sources = validate_sources("all")
for source in sources: for source in sources:
log_data["source"] = source.label log_data["source"] = source.label
@ -292,11 +342,22 @@ def clean_source(source):
:return: :return:
""" """
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "Cleaning source", "message": "Cleaning source",
"source": source, "source": source,
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, (source,)):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
try: try:
clean([source], True) clean([source], True)
@ -313,10 +374,21 @@ def sync_all_sources():
This function will sync certificates from all sources. This function triggers one celery task per source. This function will sync certificates from all sources. This function triggers one celery task per source.
""" """
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "creating celery task to sync source", "message": "creating celery task to sync source",
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
sources = validate_sources("all") sources = validate_sources("all")
for source in sources: for source in sources:
log_data["source"] = source.label log_data["source"] = source.label
@ -340,21 +412,23 @@ def sync_source(source):
task_id = None task_id = None
if celery.current_task: if celery.current_task:
task_id = celery.current_task.request.id task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "Syncing source", "message": "Syncing source",
"source": source, "source": source,
"task_id": task_id, "task_id": task_id,
} }
current_app.logger.debug(log_data)
if task_id and is_task_active(function, task_id, (source,)): if task_id and is_task_active(function, task_id, (source,)):
log_data["message"] = "Skipping task: Task is already active" log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
return return
current_app.logger.debug(log_data)
try: try:
sync([source]) sync([source])
metrics.send(f"{function}.success", 'counter', '1', metric_tags={"source": source}) metrics.send(f"{function}.success", 'counter', 1, metric_tags={"source": source})
except SoftTimeLimitExceeded: except SoftTimeLimitExceeded:
log_data["message"] = "Error syncing source: Time limit exceeded." log_data["message"] = "Error syncing source: Time limit exceeded."
current_app.logger.error(log_data) current_app.logger.error(log_data)
@ -379,10 +453,21 @@ def sync_source_destination():
We rely on account numbers to avoid duplicates. We rely on account numbers to avoid duplicates.
""" """
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "syncing AWS destinations and sources", "message": "syncing AWS destinations and sources",
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
for dst in destinations_service.get_all(): for dst in destinations_service.get_all():
if add_aws_destination_to_sources(dst): if add_aws_destination_to_sources(dst):
@ -403,10 +488,21 @@ def certificate_reissue():
:return: :return:
""" """
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "reissuing certificates", "message": "reissuing certificates",
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
try: try:
cli_certificate.reissue(None, True) cli_certificate.reissue(None, True)
@ -430,10 +526,22 @@ def certificate_rotate():
:return: :return:
""" """
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "rotating certificates", "message": "rotating certificates",
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
try: try:
cli_certificate.rotate(None, None, None, None, True) cli_certificate.rotate(None, None, None, None, True)
@ -457,10 +565,21 @@ def endpoints_expire():
:return: :return:
""" """
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "endpoints expire", "message": "endpoints expire",
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
try: try:
cli_endpoints.expire(2) # Time in hours cli_endpoints.expire(2) # Time in hours
@ -482,10 +601,21 @@ def get_all_zones():
:return: :return:
""" """
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "refresh all zones from available DNS providers", "message": "refresh all zones from available DNS providers",
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
try: try:
cli_dns_providers.get_all_zones() cli_dns_providers.get_all_zones()
@ -543,10 +673,21 @@ def notify_expirations():
:return: :return:
""" """
function = f"{__name__}.{sys._getframe().f_code.co_name}" function = f"{__name__}.{sys._getframe().f_code.co_name}"
task_id = None
if celery.current_task:
task_id = celery.current_task.request.id
log_data = { log_data = {
"function": function, "function": function,
"message": "notify for cert expiration", "message": "notify for cert expiration",
"task_id": task_id,
} }
if task_id and is_task_active(function, task_id, None):
log_data["message"] = "Skipping task: Task is already active"
current_app.logger.debug(log_data)
return
current_app.logger.debug(log_data) current_app.logger.debug(log_data)
try: try:
cli_notification.expirations(current_app.config.get("EXCLUDE_CN_FROM_NOTIFICATION", [])) cli_notification.expirations(current_app.config.get("EXCLUDE_CN_FROM_NOTIFICATION", []))

View File

@ -98,6 +98,7 @@ def get_types():
], ],
}, },
{"name": "dyn"}, {"name": "dyn"},
{"name": "ultradns"},
] ]
}, },
) )

View File

@ -31,7 +31,7 @@ from lemur.exceptions import InvalidAuthority, InvalidConfiguration, UnknownProv
from lemur.extensions import metrics, sentry from lemur.extensions import metrics, sentry
from lemur.plugins import lemur_acme as acme from lemur.plugins import lemur_acme as acme
from lemur.plugins.bases import IssuerPlugin from lemur.plugins.bases import IssuerPlugin
from lemur.plugins.lemur_acme import cloudflare, dyn, route53 from lemur.plugins.lemur_acme import cloudflare, dyn, route53, ultradns
class AuthorizationRecord(object): class AuthorizationRecord(object):
@ -370,7 +370,7 @@ class AcmeHandler(object):
pass pass
def get_dns_provider(self, type): def get_dns_provider(self, type):
provider_types = {"cloudflare": cloudflare, "dyn": dyn, "route53": route53} provider_types = {"cloudflare": cloudflare, "dyn": dyn, "route53": route53, "ultradns": ultradns}
provider = provider_types.get(type) provider = provider_types.get(type)
if not provider: if not provider:
raise UnknownProvider("No such DNS provider: {}".format(type)) raise UnknownProvider("No such DNS provider: {}".format(type))
@ -424,7 +424,7 @@ class ACMEIssuerPlugin(IssuerPlugin):
def get_dns_provider(self, type): def get_dns_provider(self, type):
self.acme = AcmeHandler() self.acme = AcmeHandler()
provider_types = {"cloudflare": cloudflare, "dyn": dyn, "route53": route53} provider_types = {"cloudflare": cloudflare, "dyn": dyn, "route53": route53, "ultradns": ultradns}
provider = provider_types.get(type) provider = provider_types.get(type)
if not provider: if not provider:
raise UnknownProvider("No such DNS provider: {}".format(type)) raise UnknownProvider("No such DNS provider: {}".format(type))

View File

@ -1,8 +1,9 @@
import unittest import unittest
from requests.models import Response
from mock import MagicMock, Mock, patch from mock import MagicMock, Mock, patch
from lemur.plugins.lemur_acme import plugin from lemur.plugins.lemur_acme import plugin, ultradns
class TestAcme(unittest.TestCase): class TestAcme(unittest.TestCase):
@ -360,3 +361,121 @@ class TestAcme(unittest.TestCase):
mock_request_certificate.return_value = ("pem_certificate", "chain") mock_request_certificate.return_value = ("pem_certificate", "chain")
result = provider.create_certificate(csr, issuer_options) result = provider.create_certificate(csr, issuer_options)
assert result assert result
@patch("lemur.plugins.lemur_acme.ultradns.requests")
@patch("lemur.plugins.lemur_acme.ultradns.current_app")
def test_get_ultradns_token(self, mock_current_app, mock_requests):
# ret_val = json.dumps({"access_token": "access"})
the_response = Response()
the_response._content = b'{"access_token": "access"}'
mock_requests.post = Mock(return_value=the_response)
mock_current_app.config.get = Mock(return_value="Test")
result = ultradns.get_ultradns_token()
self.assertTrue(len(result) > 0)
@patch("lemur.plugins.lemur_acme.ultradns.current_app")
def test_create_txt_record(self, mock_current_app):
domain = "_acme_challenge.test.example.com"
zone = "test.example.com"
token = "ABCDEFGHIJ"
account_number = "1234567890"
change_id = (domain, token)
ultradns.get_zone_name = Mock(return_value=zone)
mock_current_app.logger.debug = Mock()
ultradns._post = Mock()
log_data = {
"function": "create_txt_record",
"fqdn": domain,
"token": token,
"message": "TXT record created"
}
result = ultradns.create_txt_record(domain, token, account_number)
mock_current_app.logger.debug.assert_called_with(log_data)
self.assertEqual(result, change_id)
@patch("lemur.plugins.lemur_acme.ultradns.current_app")
@patch("lemur.extensions.metrics")
def test_delete_txt_record(self, mock_metrics, mock_current_app):
domain = "_acme_challenge.test.example.com"
zone = "test.example.com"
token = "ABCDEFGHIJ"
account_number = "1234567890"
change_id = (domain, token)
mock_current_app.logger.debug = Mock()
ultradns.get_zone_name = Mock(return_value=zone)
ultradns._post = Mock()
ultradns._get = Mock()
ultradns._get.return_value = {'zoneName': 'test.example.com.com',
'rrSets': [{'ownerName': '_acme-challenge.test.example.com.',
'rrtype': 'TXT (16)', 'ttl': 5, 'rdata': ['ABCDEFGHIJ']}],
'queryInfo': {'sort': 'OWNER', 'reverse': False, 'limit': 100},
'resultInfo': {'totalCount': 1, 'offset': 0, 'returnedCount': 1}}
ultradns._delete = Mock()
mock_metrics.send = Mock()
ultradns.delete_txt_record(change_id, account_number, domain, token)
mock_current_app.logger.debug.assert_not_called()
mock_metrics.send.assert_not_called()
@patch("lemur.plugins.lemur_acme.ultradns.current_app")
@patch("lemur.extensions.metrics")
def test_wait_for_dns_change(self, mock_metrics, mock_current_app):
ultradns._has_dns_propagated = Mock(return_value=True)
nameserver = "1.1.1.1"
ultradns.get_authoritative_nameserver = Mock(return_value=nameserver)
mock_metrics.send = Mock()
domain = "_acme-challenge.test.example.com"
token = "ABCDEFGHIJ"
change_id = (domain, token)
mock_current_app.logger.debug = Mock()
ultradns.wait_for_dns_change(change_id)
# mock_metrics.send.assert_not_called()
log_data = {
"function": "wait_for_dns_change",
"fqdn": domain,
"status": True,
"message": "Record status on Public DNS"
}
mock_current_app.logger.debug.assert_called_with(log_data)
def test_get_zone_name(self):
zones = ['example.com', 'test.example.com']
zone = "test.example.com"
domain = "_acme-challenge.test.example.com"
account_number = "1234567890"
ultradns.get_zones = Mock(return_value=zones)
result = ultradns.get_zone_name(domain, account_number)
self.assertEqual(result, zone)
def test_get_zones(self):
account_number = "1234567890"
path = "a/b/c"
zones = ['example.com', 'test.example.com']
paginate_response = [{
'properties': {
'name': 'example.com.', 'accountName': 'example', 'type': 'PRIMARY',
'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9,
'lastModifiedDateTime': '2017-06-14T06:45Z'},
'registrarInfo': {
'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.',
'example.ultradns.biz.', 'example.ultradns.org.']}},
'inherit': 'ALL'}, {
'properties': {
'name': 'test.example.com.', 'accountName': 'example', 'type': 'PRIMARY',
'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9,
'lastModifiedDateTime': '2017-06-14T06:45Z'},
'registrarInfo': {
'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.',
'example.ultradns.biz.', 'example.ultradns.org.']}},
'inherit': 'ALL'}, {
'properties': {
'name': 'example2.com.', 'accountName': 'example', 'type': 'SECONDARY',
'dnssecStatus': 'UNSIGNED', 'status': 'ACTIVE', 'resourceRecordCount': 9,
'lastModifiedDateTime': '2017-06-14T06:45Z'},
'registrarInfo': {
'nameServers': {'missing': ['example.ultradns.com.', 'example.ultradns.net.',
'example.ultradns.biz.', 'example.ultradns.org.']}},
'inherit': 'ALL'}]
ultradns._paginate = Mock(path, "zones")
ultradns._paginate.side_effect = [[paginate_response]]
result = ultradns.get_zones(account_number)
self.assertEqual(result, zones)

View File

@ -0,0 +1,445 @@
import time
import requests
import json
import sys
import dns
import dns.exception
import dns.name
import dns.query
import dns.resolver
from flask import current_app
from lemur.extensions import metrics, sentry
class Record:
"""
This class implements an Ultra DNS record.
Accepts the response from the API call as the argument.
"""
def __init__(self, _data):
# Since we are dealing with only TXT records for Lemur, we expect only 1 RRSet in the response.
# Thus we default to picking up the first entry (_data["rrsets"][0]) from the response.
self._data = _data["rrSets"][0]
@property
def name(self):
return self._data["ownerName"]
@property
def rrtype(self):
return self._data["rrtype"]
@property
def rdata(self):
return self._data["rdata"]
@property
def ttl(self):
return self._data["ttl"]
class Zone:
"""
This class implements an Ultra DNS zone.
"""
def __init__(self, _data, _client="Client"):
self._data = _data
self._client = _client
@property
def name(self):
"""
Zone name, has a trailing "." at the end, which we manually remove.
"""
return self._data["properties"]["name"][:-1]
@property
def authoritative_type(self):
"""
Indicates whether the zone is setup as a PRIMARY or SECONDARY
"""
return self._data["properties"]["type"]
@property
def record_count(self):
return self._data["properties"]["resourceRecordCount"]
@property
def status(self):
"""
Returns the status of the zone - ACTIVE, SUSPENDED, etc
"""
return self._data["properties"]["status"]
def get_ultradns_token():
"""
Function to call the UltraDNS Authorization API.
Returns the Authorization access_token which is valid for 1 hour.
Each request calls this function and we generate a new token every time.
"""
path = "/v2/authorization/token"
data = {
"grant_type": "password",
"username": current_app.config.get("ACME_ULTRADNS_USERNAME", ""),
"password": current_app.config.get("ACME_ULTRADNS_PASSWORD", ""),
}
base_uri = current_app.config.get("ACME_ULTRADNS_DOMAIN", "")
resp = requests.post(f"{base_uri}{path}", data=data, verify=True)
return resp.json()["access_token"]
def _generate_header():
"""
Function to generate the header for a request.
Contains the Authorization access_key obtained from the get_ultradns_token() function.
"""
access_token = get_ultradns_token()
return {"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}
def _paginate(path, key):
limit = 100
params = {"offset": 0, "limit": 1}
resp = _get(path, params)
for index in range(0, resp["resultInfo"]["totalCount"], limit):
params["offset"] = index
params["limit"] = limit
resp = _get(path, params)
yield resp[key]
def _get(path, params=None):
"""Function to execute a GET request on the given URL (base_uri + path) with given params"""
base_uri = current_app.config.get("ACME_ULTRADNS_DOMAIN", "")
resp = requests.get(
f"{base_uri}{path}",
headers=_generate_header(),
params=params,
verify=True,
)
resp.raise_for_status()
return resp.json()
def _delete(path):
"""Function to execute a DELETE request on the given URL"""
base_uri = current_app.config.get("ACME_ULTRADNS_DOMAIN", "")
resp = requests.delete(
f"{base_uri}{path}",
headers=_generate_header(),
verify=True,
)
resp.raise_for_status()
def _post(path, params):
"""Executes a POST request on given URL. Body is sent in JSON format"""
base_uri = current_app.config.get("ACME_ULTRADNS_DOMAIN", "")
resp = requests.post(
f"{base_uri}{path}",
headers=_generate_header(),
data=json.dumps(params),
verify=True,
)
resp.raise_for_status()
def _has_dns_propagated(name, token, domain):
"""
Check whether the DNS change made by Lemur have propagated to the public DNS or not.
Invoked by wait_for_dns_change() function
"""
txt_records = []
try:
dns_resolver = dns.resolver.Resolver()
dns_resolver.nameservers = [domain]
dns_response = dns_resolver.query(name, "TXT")
for rdata in dns_response:
for txt_record in rdata.strings:
txt_records.append(txt_record.decode("utf-8"))
except dns.exception.DNSException:
function = sys._getframe().f_code.co_name
metrics.send(f"{function}.fail", "counter", 1)
return False
for txt_record in txt_records:
if txt_record == token:
function = sys._getframe().f_code.co_name
metrics.send(f"{function}.success", "counter", 1)
return True
return False
def wait_for_dns_change(change_id, account_number=None):
"""
Waits and checks if the DNS changes have propagated or not.
First check the domains authoritative server. Once this succeeds,
we ask a public DNS server (Google <8.8.8.8> in our case).
"""
fqdn, token = change_id
number_of_attempts = 20
nameserver = get_authoritative_nameserver(fqdn)
for attempts in range(0, number_of_attempts):
status = _has_dns_propagated(fqdn, token, nameserver)
function = sys._getframe().f_code.co_name
log_data = {
"function": function,
"fqdn": fqdn,
"status": status,
"message": "Record status on ultraDNS authoritative server"
}
current_app.logger.debug(log_data)
if status:
time.sleep(10)
break
time.sleep(10)
if status:
nameserver = get_public_authoritative_nameserver()
for attempts in range(0, number_of_attempts):
status = _has_dns_propagated(fqdn, token, nameserver)
log_data = {
"function": function,
"fqdn": fqdn,
"status": status,
"message": "Record status on Public DNS"
}
current_app.logger.debug(log_data)
if status:
metrics.send(f"{function}.success", "counter", 1)
break
time.sleep(10)
if not status:
metrics.send(f"{function}.fail", "counter", 1, metric_tags={"fqdn": fqdn, "txt_record": token})
sentry.captureException(extra={"fqdn": str(fqdn), "txt_record": str(token)})
return
def get_zones(account_number):
"""Get zones from the UltraDNS"""
path = "/v2/zones"
zones = []
for page in _paginate(path, "zones"):
for elem in page:
# UltraDNS zone names end with a "." - Example - lemur.example.com.
# We pick out the names minus the "." at the end while returning the list
zone = Zone(elem)
if zone.authoritative_type == "PRIMARY" and zone.status == "ACTIVE":
zones.append(zone.name)
return zones
def get_zone_name(domain, account_number):
"""Get the matching zone for the given domain"""
zones = get_zones(account_number)
zone_name = ""
for z in zones:
if domain.endswith(z):
# Find the most specific zone possible for the domain
# Ex: If fqdn is a.b.c.com, there is a zone for c.com,
# and a zone for b.c.com, we want to use b.c.com.
if z.count(".") > zone_name.count("."):
zone_name = z
if not zone_name:
function = sys._getframe().f_code.co_name
metrics.send(f"{function}.fail", "counter", 1)
raise Exception(f"No UltraDNS zone found for domain: {domain}")
return zone_name
def create_txt_record(domain, token, account_number):
"""
Create a TXT record for the given domain.
The part of the domain that matches with the zone becomes the zone name.
The remainder becomes the owner name (referred to as node name here)
Example: Let's say we have a zone named "exmaple.com" in UltraDNS and we
get a request to create a cert for lemur.example.com
Domain - _acme-challenge.lemur.example.com
Matching zone - example.com
Owner name - _acme-challenge.lemur
"""
zone_name = get_zone_name(domain, account_number)
zone_parts = len(zone_name.split("."))
node_name = ".".join(domain.split(".")[:-zone_parts])
fqdn = f"{node_name}.{zone_name}"
path = f"/v2/zones/{zone_name}/rrsets/TXT/{node_name}"
params = {
"ttl": 5,
"rdata": [
f"{token}"
],
}
try:
_post(path, params)
function = sys._getframe().f_code.co_name
log_data = {
"function": function,
"fqdn": fqdn,
"token": token,
"message": "TXT record created"
}
current_app.logger.debug(log_data)
except Exception as e:
function = sys._getframe().f_code.co_name
log_data = {
"function": function,
"domain": domain,
"token": token,
"Exception": e,
"message": "Unable to add record. Record already exists."
}
current_app.logger.debug(log_data)
change_id = (fqdn, token)
return change_id
def delete_txt_record(change_id, account_number, domain, token):
"""
Delete the TXT record that was created in the create_txt_record() function.
UltraDNS handles records differently compared to Dyn. It creates an RRSet
which is a set of records of the same type and owner. This means
that while deleting the record, we cannot delete any individual record from
the RRSet. Instead, we have to delete the entire RRSet. If multiple certs are
being created for the same domain at the same time, the challenge TXT records
that are created will be added under the same RRSet. If the RRSet had more
than 1 record, then we create a new RRSet on UltraDNS minus the record that
has to be deleted.
"""
if not domain:
function = sys._getframe().f_code.co_name
log_data = {
"function": function,
"message": "No domain passed"
}
current_app.logger.debug(log_data)
return
zone_name = get_zone_name(domain, account_number)
zone_parts = len(zone_name.split("."))
node_name = ".".join(domain.split(".")[:-zone_parts])
path = f"/v2/zones/{zone_name}/rrsets/16/{node_name}"
try:
rrsets = _get(path)
record = Record(rrsets)
except Exception as e:
function = sys._getframe().f_code.co_name
metrics.send(f"{function}.geterror", "counter", 1)
# No Text Records remain or host is not in the zone anymore because all records have been deleted.
return
try:
# Remove the record from the RRSet locally
record.rdata.remove(f"{token}")
except ValueError:
function = sys._getframe().f_code.co_name
log_data = {
"function": function,
"token": token,
"message": "Token not found"
}
current_app.logger.debug(log_data)
return
# Delete the RRSet from UltraDNS
_delete(path)
# Check if the RRSet has more records. If yes, add the modified RRSet back to UltraDNS
if len(record.rdata) > 0:
params = {
"ttl": 5,
"rdata": record.rdata,
}
_post(path, params)
def delete_acme_txt_records(domain):
if not domain:
function = sys._getframe().f_code.co_name
log_data = {
"function": function,
"message": "No domain passed"
}
current_app.logger.debug(log_data)
return
acme_challenge_string = "_acme-challenge"
if not domain.startswith(acme_challenge_string):
function = sys._getframe().f_code.co_name
log_data = {
"function": function,
"domain": domain,
"acme_challenge_string": acme_challenge_string,
"message": "Domain does not start with the acme challenge string"
}
current_app.logger.debug(log_data)
return
zone_name = get_zone_name(domain)
zone_parts = len(zone_name.split("."))
node_name = ".".join(domain.split(".")[:-zone_parts])
path = f"/v2/zones/{zone_name}/rrsets/16/{node_name}"
_delete(path)
def get_authoritative_nameserver(domain):
"""Get the authoritative nameserver for the given domain"""
n = dns.name.from_text(domain)
depth = 2
default = dns.resolver.get_default_resolver()
nameserver = default.nameservers[0]
last = False
while not last:
s = n.split(depth)
last = s[0].to_unicode() == u"@"
sub = s[1]
query = dns.message.make_query(sub, dns.rdatatype.NS)
response = dns.query.udp(query, nameserver)
rcode = response.rcode()
if rcode != dns.rcode.NOERROR:
function = sys._getframe().f_code.co_name
metrics.send(f"{function}.error", "counter", 1)
if rcode == dns.rcode.NXDOMAIN:
raise Exception("%s does not exist." % sub)
else:
raise Exception("Error %s" % dns.rcode.to_text(rcode))
if len(response.authority) > 0:
rrset = response.authority[0]
else:
rrset = response.answer[0]
rr = rrset[0]
if rr.rdtype != dns.rdatatype.SOA:
authority = rr.target
nameserver = default.query(authority).rrset[0].to_text()
depth += 1
return nameserver
def get_public_authoritative_nameserver():
return "8.8.8.8"

View File

@ -158,7 +158,7 @@ def map_cis_fields(options, csr):
) )
data = { data = {
"profile_name": current_app.config.get("DIGICERT_CIS_PROFILE_NAME"), "profile_name": current_app.config.get("DIGICERT_CIS_PROFILE_NAMES", {}).get(options['authority'].name),
"common_name": options["common_name"], "common_name": options["common_name"],
"additional_dns_names": get_additional_names(options), "additional_dns_names": get_additional_names(options),
"csr": csr, "csr": csr,
@ -423,9 +423,9 @@ class DigiCertCISSourcePlugin(SourcePlugin):
required_vars = [ required_vars = [
"DIGICERT_CIS_API_KEY", "DIGICERT_CIS_API_KEY",
"DIGICERT_CIS_URL", "DIGICERT_CIS_URL",
"DIGICERT_CIS_ROOT", "DIGICERT_CIS_ROOTS",
"DIGICERT_CIS_INTERMEDIATE", "DIGICERT_CIS_INTERMEDIATES",
"DIGICERT_CIS_PROFILE_NAME", "DIGICERT_CIS_PROFILE_NAMES",
] ]
validate_conf(current_app, required_vars) validate_conf(current_app, required_vars)
@ -498,9 +498,9 @@ class DigiCertCISIssuerPlugin(IssuerPlugin):
required_vars = [ required_vars = [
"DIGICERT_CIS_API_KEY", "DIGICERT_CIS_API_KEY",
"DIGICERT_CIS_URL", "DIGICERT_CIS_URL",
"DIGICERT_CIS_ROOT", "DIGICERT_CIS_ROOTS",
"DIGICERT_CIS_INTERMEDIATE", "DIGICERT_CIS_INTERMEDIATES",
"DIGICERT_CIS_PROFILE_NAME", "DIGICERT_CIS_PROFILE_NAMES",
] ]
validate_conf(current_app, required_vars) validate_conf(current_app, required_vars)
@ -537,14 +537,14 @@ class DigiCertCISIssuerPlugin(IssuerPlugin):
if "ECC" in issuer_options["key_type"]: if "ECC" in issuer_options["key_type"]:
return ( return (
"\n".join(str(end_entity).splitlines()), "\n".join(str(end_entity).splitlines()),
current_app.config.get("DIGICERT_ECC_CIS_INTERMEDIATE"), current_app.config.get("DIGICERT_ECC_CIS_INTERMEDIATES", {}).get(issuer_options['authority'].name),
data["id"], data["id"],
) )
# By default return RSA # By default return RSA
return ( return (
"\n".join(str(end_entity).splitlines()), "\n".join(str(end_entity).splitlines()),
current_app.config.get("DIGICERT_CIS_INTERMEDIATE"), current_app.config.get("DIGICERT_CIS_INTERMEDIATES", {}).get(issuer_options['authority'].name),
data["id"], data["id"],
) )
@ -577,4 +577,4 @@ class DigiCertCISIssuerPlugin(IssuerPlugin):
:return: :return:
""" """
role = {"username": "", "password": "", "name": "digicert"} role = {"username": "", "password": "", "name": "digicert"}
return current_app.config.get("DIGICERT_CIS_ROOT"), "", [role] return current_app.config.get("DIGICERT_CIS_ROOTS", {}).get(options['authority'].name), "", [role]

View File

@ -66,7 +66,7 @@ def test_map_fields_with_validity_years(app):
} }
def test_map_cis_fields(app): def test_map_cis_fields(app, authority):
from lemur.plugins.lemur_digicert.plugin import map_cis_fields from lemur.plugins.lemur_digicert.plugin import map_cis_fields
names = [u"one.example.com", u"two.example.com", u"three.example.com"] names = [u"one.example.com", u"two.example.com", u"three.example.com"]
@ -80,6 +80,7 @@ def test_map_cis_fields(app):
"organizational_unit": "Example Org", "organizational_unit": "Example Org",
"validity_end": arrow.get(2017, 5, 7), "validity_end": arrow.get(2017, 5, 7),
"validity_start": arrow.get(2016, 10, 30), "validity_start": arrow.get(2016, 10, 30),
"authority": authority,
} }
data = map_cis_fields(options, CSR_STR) data = map_cis_fields(options, CSR_STR)
@ -104,6 +105,7 @@ def test_map_cis_fields(app):
"organization": "Example, Inc.", "organization": "Example, Inc.",
"organizational_unit": "Example Org", "organizational_unit": "Example Org",
"validity_years": 2, "validity_years": 2,
"authority": authority,
} }
with freeze_time(time_to_freeze=arrow.get(2016, 11, 3).datetime): with freeze_time(time_to_freeze=arrow.get(2016, 11, 3).datetime):

View File

@ -80,6 +80,13 @@ DIGICERT_API_KEY = "api-key"
DIGICERT_ORG_ID = 111111 DIGICERT_ORG_ID = 111111
DIGICERT_ROOT = "ROOT" DIGICERT_ROOT = "ROOT"
DIGICERT_CIS_URL = "mock://www.digicert.com"
DIGICERT_CIS_PROFILE_NAMES = {"sha2-rsa-ecc-root": "ssl_plus"}
DIGICERT_CIS_API_KEY = "api-key"
DIGICERT_CIS_ROOTS = {"root": "ROOT"}
DIGICERT_CIS_INTERMEDIATES = {"inter": "INTERMEDIATE_CA_CERT"}
VERISIGN_URL = "http://example.com" VERISIGN_URL = "http://example.com"
VERISIGN_PEM_PATH = "~/" VERISIGN_PEM_PATH = "~/"
VERISIGN_FIRST_NAME = "Jim" VERISIGN_FIRST_NAME = "Jim"