From bf47f87c215f9c6042374ddf3a43f5f4bbc24d43 Mon Sep 17 00:00:00 2001 From: Hossein Shafagh Date: Mon, 12 Aug 2019 13:52:01 -0700 Subject: [PATCH] preventing celery duplicate tasks --- lemur/common/celery.py | 91 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/lemur/common/celery.py b/lemur/common/celery.py index b19a9607..a79ec838 100644 --- a/lemur/common/celery.py +++ b/lemur/common/celery.py @@ -248,6 +248,15 @@ def remove_old_acme_certs(): } pending_certs = pending_certificate_service.get_pending_certs("all") + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + if task_id and is_task_active(function, task_id, (id,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + # Delete pending certs more than a week old for cert in pending_certs: if datetime.now(timezone.utc) - cert.last_updated > timedelta(days=7): @@ -311,6 +320,17 @@ def sync_all_sources(): "function": function, "message": "creating celery task to sync source", } + + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + if task_id and is_task_active(function, task_id, (id,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + sources = validate_sources("all") for source in sources: log_data["source"] = source.label @@ -340,6 +360,17 @@ def sync_source(source): "source": source, "task_id": task_id, } + + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + if task_id and is_task_active(function, task_id, (id,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + + current_app.logger.debug(log_data) if task_id and is_task_active(function, task_id, (source,)): @@ -378,6 +409,16 @@ def sync_source_destination(): "function": function, "message": "syncing AWS destinations and sources", } + + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + if task_id and is_task_active(function, task_id, (id,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + current_app.logger.debug(log_data) for dst in destinations_service.get_all(): if add_aws_destination_to_sources(dst): @@ -402,6 +443,16 @@ def certificate_reissue(): "function": function, "message": "reissuing certificates", } + + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + if task_id and is_task_active(function, task_id, (id,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + current_app.logger.debug(log_data) cli_certificate.reissue(None, True) log_data["message"] = "reissuance completed" @@ -421,6 +472,16 @@ def certificate_rotate(): "function": function, "message": "rotating certificates", } + + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + if task_id and is_task_active(function, task_id, (id,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + current_app.logger.debug(log_data) cli_certificate.rotate(None, None, None, None, True) log_data["message"] = "rotation completed" @@ -440,6 +501,16 @@ def endpoints_expire(): "function": function, "message": "endpoints expire", } + + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + if task_id and is_task_active(function, task_id, (id,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + current_app.logger.debug(log_data) cli_endpoints.expire(2) # Time in hours red.set(f'{function}.last_success', int(time.time())) @@ -457,6 +528,16 @@ def get_all_zones(): "function": function, "message": "refresh all zones from available DNS providers", } + + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + if task_id and is_task_active(function, task_id, (id,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + current_app.logger.debug(log_data) cli_dns_providers.get_all_zones() red.set(f'{function}.last_success', int(time.time())) @@ -491,6 +572,16 @@ def notify_expirations(): "function": function, "message": "notify for cert expiration", } + + task_id = None + if celery.current_task: + task_id = celery.current_task.request.id + + if task_id and is_task_active(function, task_id, (id,)): + log_data["message"] = "Skipping task: Task is already active" + current_app.logger.debug(log_data) + return + current_app.logger.debug(log_data) cli_notification.expirations(current_app.config.get("EXCLUDE_CN_FROM_NOTIFICATION", [])) red.set(f'{function}.last_success', int(time.time()))