# Print a unique Celery node name of the form "<id>@<hostname>" on stdout.
#
# <id> is resolved in this order:
#   1. When ECS_CONTAINER_METADATA_URI_V4 is set, the last path segment of
#      the "TaskARN" field returned by "${ECS_CONTAINER_METADATA_URI_V4}/task".
#   2. Otherwise, or when step 1 yields nothing (request failed, payload not
#      JSON, key missing), a freshly generated random UUID in hex form.
#
# The result is consumed by start_worker as the value of `celery worker -n`.
resolve_worker_hostname() {
  TASK_ID=""

  if [ -n "$ECS_CONTAINER_METADATA_URI_V4" ]; then
    # --tries=1: GNU wget retries timed-out requests (20 tries by default),
    # so --timeout alone does not bound how long a hung metadata endpoint
    # can delay worker start-up. One attempt, then fall back to the UUID.
    TASK_ID=$(wget -qO- --timeout=2 --tries=1 "${ECS_CONTAINER_METADATA_URI_V4}/task" | \
      python3 -c "import sys,json; print(json.load(sys.stdin)['TaskARN'].split('/')[-1])" 2>/dev/null)
  fi

  # Any failure above leaves TASK_ID empty; a random id still keeps the
  # node name unique for this process.
  if [ -z "$TASK_ID" ]; then
    TASK_ID=$(python3 -c "import uuid; print(uuid.uuid4().hex)")
  fi

  echo "${TASK_ID}@$(hostname)"
}
# Name used both for the PeriodicTask row and for the Celery task it triggers.
TASK_NAME = "attack-paths-cleanup-stale-scans"
# How often (in hours) the cleanup task is scheduled.
INTERVAL_HOURS = 1


def create_periodic_task(apps, schema_editor):
    """
    Forward migration: register the hourly stale-scan cleanup in celery beat.

    Reuses an existing ``every=INTERVAL_HOURS, period="hours"`` interval when
    one is present and creates it otherwise, then creates or updates the
    ``PeriodicTask`` named ``TASK_NAME`` so that it points at that interval
    and is enabled.

    Args:
        apps: Historical app registry supplied by ``RunPython``.
        schema_editor: Unused; part of the ``RunPython`` callable signature.
    """
    IntervalSchedule = apps.get_model("django_celery_beat", "IntervalSchedule")
    PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask")

    # `get_or_create` raises MultipleObjectsReturned when more than one row
    # matches, which would abort the migration on databases that already hold
    # duplicate hourly schedules. Pick the oldest match deterministically.
    schedule = (
        IntervalSchedule.objects.filter(every=INTERVAL_HOURS, period="hours")
        .order_by("id")
        .first()
    )
    if schedule is None:
        schedule = IntervalSchedule.objects.create(
            every=INTERVAL_HOURS,
            period="hours",
        )

    PeriodicTask.objects.update_or_create(
        name=TASK_NAME,
        defaults={
            "task": TASK_NAME,
            "interval": schedule,
            "enabled": True,
        },
    )


def delete_periodic_task(apps, schema_editor):
    """
    Reverse migration: remove the cleanup ``PeriodicTask``.

    Also deletes hourly ``IntervalSchedule`` rows matching
    ``every=INTERVAL_HOURS`` that no periodic task references any more.

    Args:
        apps: Historical app registry supplied by ``RunPython``.
        schema_editor: Unused; part of the ``RunPython`` callable signature.
    """
    IntervalSchedule = apps.get_model("django_celery_beat", "IntervalSchedule")
    PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask")

    PeriodicTask.objects.filter(name=TASK_NAME).delete()

    # Clean up the schedule if no other task references it
    IntervalSchedule.objects.filter(
        every=INTERVAL_HOURS,
        period="hours",
        periodictask__isnull=True,
    ).delete()
from datetime import datetime, timedelta, timezone

from celery import current_app, states
from celery.utils.log import get_task_logger
from config.django.base import ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES
from tasks.jobs.attack_paths.db_utils import (
    finish_attack_paths_scan,
    recover_graph_data_ready,
)

from api.attack_paths import database as graph_database
from api.db_router import MainRouter
from api.db_utils import rls_transaction
from api.models import AttackPathsScan, StateChoices

logger = get_task_logger(__name__)

# A worker that misses one ping deadline is not necessarily dead (it may be
# slow to answer, or the reply may be delayed by the broker). Because scans on
# "dead" workers are failed regardless of age, require several consecutive
# missed pings before declaring a worker dead.
WORKER_PING_TIMEOUT_SECONDS = 1.0
WORKER_PING_ATTEMPTS = 3


def cleanup_stale_attack_paths_scans() -> dict:
    """
    Find `EXECUTING` `AttackPathsScan` scans whose workers are dead or that have
    exceeded the stale threshold, and mark them as `FAILED`.

    Decision per scan:
        1. A worker is recorded on the scan's `TaskResult`:
            - Dead worker: cleanup immediately (any age).
            - Alive + past threshold: revoke the task, then cleanup.
            - Alive + within threshold: skip.
        2. No worker recorded: time-based heuristic only.

    A scan without `started_at` is treated as past the threshold.

    A failure while cleaning up one scan is logged and does not prevent the
    remaining scans from being processed.

    Returns:
        `{"cleaned_up_count": int, "scan_ids": list[str]}` describing the scans
        that were actually cleaned up in this run.
    """
    threshold = timedelta(minutes=ATTACK_PATHS_SCAN_STALE_THRESHOLD_MINUTES)
    cutoff = datetime.now(tz=timezone.utc) - threshold

    executing_scans = list(
        AttackPathsScan.all_objects.using(MainRouter.admin_db)
        .filter(state=StateChoices.EXECUTING)
        .select_related("task__task_runner_task")
    )

    task_results = {scan.id: _get_task_result(scan) for scan in executing_scans}

    # Cache worker liveness so each worker is pinged at most once
    workers = {tr.worker for tr in task_results.values() if tr and tr.worker}
    worker_alive = {w: _is_worker_alive(w) for w in workers}

    cleaned_up = []

    for scan in executing_scans:
        task_result = task_results[scan.id]
        worker = task_result.worker if task_result else None
        within_threshold = bool(scan.started_at and scan.started_at >= cutoff)

        if worker and not worker_alive.get(worker, True):
            reason = "Worker dead — cleaned up by periodic task"
        elif within_threshold:
            continue
        elif worker:
            # Alive but stale — revoke before cleanup
            _revoke_task(task_result)
            reason = "Scan exceeded stale threshold — cleaned up by periodic task"
        else:
            # No worker recorded — time-based heuristic only
            reason = (
                "No worker recorded, scan exceeded stale threshold — "
                "cleaned up by periodic task"
            )

        # Isolate failures: one broken scan must not block cleanup of the rest
        # on this run and on every following run.
        try:
            cleaned = _cleanup_scan(scan, task_result, reason)
        except Exception:
            logger.exception(f"Failed to clean up stale scan {scan.id}, continuing")
            continue

        if cleaned:
            cleaned_up.append(str(scan.id))

    logger.info(
        f"Stale `AttackPathsScan` cleanup: {len(cleaned_up)} scan(s) cleaned up"
    )
    return {"cleaned_up_count": len(cleaned_up), "scan_ids": cleaned_up}


def _get_task_result(scan):
    """Return the `TaskResult` linked to `scan` through its task, or `None`."""
    if not scan.task:
        return None
    return getattr(scan.task, "task_runner_task", None)


def _is_worker_alive(
    worker: str,
    timeout: float = WORKER_PING_TIMEOUT_SECONDS,
    attempts: int = WORKER_PING_ATTEMPTS,
) -> bool:
    """
    Ping a specific Celery worker.

    Args:
        worker: Node name of the worker to ping.
        timeout: Seconds to wait for a reply on each attempt.
        attempts: Maximum number of pings sent before giving up (at least one
            ping is always sent).

    Returns:
        `True` as soon as the worker answers a ping, or if pinging raises
        (unknown state is treated as alive so nothing is cleaned up by
        mistake). `False` only when every attempt went unanswered.
    """
    try:
        for _ in range(max(1, attempts)):
            response = current_app.control.inspect(
                destination=[worker], timeout=timeout
            ).ping()
            if response and worker in response:
                return True
        return False
    except Exception:
        logger.exception(f"Failed to ping worker {worker}, treating as alive")
        return True


def _revoke_task(task_result) -> None:
    """
    Revoke the Celery task behind `task_result`, terminating it with `SIGTERM`.

    Non-fatal on failure: any exception is logged and swallowed so the caller
    can still proceed with the cleanup.
    """
    try:
        current_app.control.revoke(
            task_result.task_id, terminate=True, signal="SIGTERM"
        )
        logger.info(f"Revoked task {task_result.task_id}")
    except Exception:
        logger.exception(f"Failed to revoke task {task_result.task_id}")


def _cleanup_scan(scan, task_result, reason: str) -> bool:
    """
    Clean up a single stale `AttackPathsScan`:
    drop temp DB, mark `FAILED`, update `TaskResult`, recover `graph_data_ready`.

    Args:
        scan: The scan as loaded by the caller (may be outdated by now).
        task_result: The scan's `TaskResult`, or `None` when there is none.
        reason: Message stored as the scan's `global_error`.

    Returns:
        `True` if the scan was actually cleaned up, `False` if it was skipped
        because it no longer exists or is no longer `EXECUTING`.
    """
    scan_id_str = str(scan.id)

    # 1. Drop temp Neo4j database. Best effort: a failure here is logged and
    #    must not prevent the scan from being marked as failed below.
    tmp_db_name = graph_database.get_database_name(scan.id, temporary=True)
    try:
        graph_database.drop_database(tmp_db_name)
    except Exception:
        logger.exception(f"Failed to drop temp database {tmp_db_name}")

    # 2. Re-fetch within RLS (race guard against normal completion)
    with rls_transaction(str(scan.tenant_id)):
        try:
            fresh_scan = AttackPathsScan.objects.get(id=scan.id)
        except AttackPathsScan.DoesNotExist:
            logger.warning(f"Scan {scan_id_str} no longer exists, skipping")
            return False

        if fresh_scan.state != StateChoices.EXECUTING:
            logger.info(f"Scan {scan_id_str} is now {fresh_scan.state}, skipping")
            return False

        # 3. Mark `AttackPathsScan` as `FAILED`
        finish_attack_paths_scan(
            fresh_scan,
            StateChoices.FAILED,
            {"global_error": reason},
        )

        # 4. Mark `TaskResult` as `FAILURE`
        if task_result:
            task_result.status = states.FAILURE
            # NOTE(review): `date_done` is saved without being assigned here —
            # presumably the field refreshes itself on save; confirm on the model.
            task_result.save(update_fields=["status", "date_done"])

        # 5. Recover graph_data_ready if provider data still exists
        recover_graph_data_ready(fresh_scan)

    logger.info(f"Cleaned up stale scan {scan_id_str}: {reason}")
    return True
@pytest.mark.django_db
class TestCleanupStaleAttackPathsScans:
    """
    Tests for `cleanup_stale_attack_paths_scans`.

    Scans, tasks and task results are real ORM rows. The collaborators of the
    cleanup module (graph database drop, scan finishing, `graph_data_ready`
    recovery, RLS transaction and worker ping/revoke) are patched.
    """

    CLEANUP_MODULE = "tasks.jobs.attack_paths.cleanup"

    @pytest.fixture
    def cleanup_mocks(self):
        """Patch the collaborators every test replaces and expose the mocks."""
        target = self.CLEANUP_MODULE
        with patch(f"{target}.recover_graph_data_ready") as recover, patch(
            f"{target}.finish_attack_paths_scan"
        ) as finish, patch(
            f"{target}.graph_database.drop_database"
        ) as drop_db, patch(
            f"{target}.rls_transaction",
            new=lambda *args, **kwargs: nullcontext(),
        ):
            yield SimpleNamespace(recover=recover, finish=finish, drop_db=drop_db)

    def _worker_liveness(self, alive):
        """Patcher that makes every worker ping report `alive`."""
        return patch(f"{self.CLEANUP_MODULE}._is_worker_alive", return_value=alive)

    @staticmethod
    def _run_cleanup():
        """Import lazily (as each test originally did) and run the cleanup."""
        from tasks.jobs.attack_paths.cleanup import cleanup_stale_attack_paths_scans

        return cleanup_stale_attack_paths_scans()

    @staticmethod
    def _aws_provider(providers_fixture):
        """Return the first fixture provider, persisted as an AWS provider."""
        aws = providers_fixture[0]
        aws.provider = Provider.ProviderChoices.AWS
        aws.save()
        return aws

    def _create_executing_scan(
        self, tenant, provider, scan=None, started_at=None, worker=None
    ):
        """Helper to create an EXECUTING AttackPathsScan with optional Task+TaskResult."""
        ap_scan = AttackPathsScan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            scan=scan,
            state=StateChoices.EXECUTING,
            started_at=started_at or datetime.now(tz=timezone.utc),
        )

        # Without a worker there is no Task/TaskResult to attach.
        if worker is None:
            return ap_scan, None

        task_result = TaskResult.objects.create(
            task_id=str(ap_scan.id),
            task_name="attack-paths-scan-perform",
            status="STARTED",
            worker=worker,
        )
        ap_scan.task = Task.objects.create(
            id=task_result.task_id,
            task_runner_task=task_result,
            tenant_id=tenant.id,
        )
        ap_scan.save(update_fields=["task_id"])

        return ap_scan, task_result

    def test_cleans_up_scan_with_dead_worker(
        self,
        tenants_fixture,
        providers_fixture,
        scans_fixture,
        cleanup_mocks,
    ):
        tenant = tenants_fixture[0]
        provider = self._aws_provider(providers_fixture)

        # Recent scan — should still be cleaned up because worker is dead
        ap_scan, task_result = self._create_executing_scan(
            tenant, provider, worker="dead-worker@host"
        )

        with self._worker_liveness(False):
            result = self._run_cleanup()

        assert result["cleaned_up_count"] == 1
        assert str(ap_scan.id) in result["scan_ids"]
        cleanup_mocks.drop_db.assert_called_once()
        cleanup_mocks.finish.assert_called_once()
        assert cleanup_mocks.finish.call_args[0][1] == StateChoices.FAILED
        cleanup_mocks.recover.assert_called_once()

        task_result.refresh_from_db()
        assert task_result.status == "FAILURE"

    def test_revokes_and_cleans_scan_exceeding_threshold_on_live_worker(
        self,
        tenants_fixture,
        providers_fixture,
        scans_fixture,
        cleanup_mocks,
    ):
        tenant = tenants_fixture[0]
        provider = self._aws_provider(providers_fixture)

        old_start = datetime.now(tz=timezone.utc) - timedelta(hours=49)
        ap_scan, task_result = self._create_executing_scan(
            tenant, provider, started_at=old_start, worker="live-worker@host"
        )

        with self._worker_liveness(True), patch(
            f"{self.CLEANUP_MODULE}._revoke_task"
        ) as revoke:
            result = self._run_cleanup()

        assert result["cleaned_up_count"] == 1
        revoke.assert_called_once_with(task_result)
        cleanup_mocks.finish.assert_called_once()
        cleanup_mocks.recover.assert_called_once()

    def test_ignores_recent_executing_scans_on_live_worker(
        self,
        tenants_fixture,
        providers_fixture,
        scans_fixture,
        cleanup_mocks,
    ):
        tenant = tenants_fixture[0]
        provider = self._aws_provider(providers_fixture)

        # Recent scan on live worker — should be skipped
        self._create_executing_scan(tenant, provider, worker="live-worker@host")

        with self._worker_liveness(True):
            result = self._run_cleanup()

        assert result["cleaned_up_count"] == 0
        cleanup_mocks.drop_db.assert_not_called()
        cleanup_mocks.finish.assert_not_called()
        cleanup_mocks.recover.assert_not_called()

    def test_ignores_completed_and_failed_scans(
        self,
        tenants_fixture,
        providers_fixture,
        scans_fixture,
        cleanup_mocks,
    ):
        tenant = tenants_fixture[0]
        provider = self._aws_provider(providers_fixture)

        for finished_state in (StateChoices.COMPLETED, StateChoices.FAILED):
            AttackPathsScan.objects.create(
                tenant_id=tenant.id,
                provider=provider,
                state=finished_state,
            )

        result = self._run_cleanup()

        assert result["cleaned_up_count"] == 0
        cleanup_mocks.drop_db.assert_not_called()
        cleanup_mocks.finish.assert_not_called()

    def test_handles_drop_database_failure_gracefully(
        self,
        tenants_fixture,
        providers_fixture,
        scans_fixture,
        cleanup_mocks,
    ):
        cleanup_mocks.drop_db.side_effect = Exception("Neo4j unreachable")

        tenant = tenants_fixture[0]
        provider = self._aws_provider(providers_fixture)

        self._create_executing_scan(tenant, provider, worker="dead-worker@host")

        with self._worker_liveness(False):
            result = self._run_cleanup()

        assert result["cleaned_up_count"] == 1
        cleanup_mocks.drop_db.assert_called_once()
        cleanup_mocks.finish.assert_called_once()

    def test_cross_tenant_cleanup(
        self,
        tenants_fixture,
        providers_fixture,
        cleanup_mocks,
    ):
        tenant1, tenant2 = tenants_fixture[0], tenants_fixture[1]
        provider1 = self._aws_provider(providers_fixture)
        provider2 = Provider.objects.create(
            provider="aws",
            uid="999888777666",
            alias="aws_tenant2",
            tenant_id=tenant2.id,
        )

        self._create_executing_scan(tenant1, provider1, worker="dead-worker-1@host")
        self._create_executing_scan(tenant2, provider2, worker="dead-worker-2@host")

        with self._worker_liveness(False):
            result = self._run_cleanup()

        assert result["cleaned_up_count"] == 2
        assert cleanup_mocks.finish.call_count == 2
        assert cleanup_mocks.recover.call_count == 2

    def test_recovers_graph_data_ready_for_stale_scan(
        self,
        tenants_fixture,
        providers_fixture,
        scans_fixture,
        cleanup_mocks,
    ):
        tenant = tenants_fixture[0]
        provider = self._aws_provider(providers_fixture)

        ap_scan, _ = self._create_executing_scan(
            tenant, provider, worker="dead-worker@host"
        )

        with self._worker_liveness(False):
            self._run_cleanup()

        cleanup_mocks.recover.assert_called_once()
        recovered_scan = cleanup_mocks.recover.call_args[0][0]
        assert recovered_scan.id == ap_scan.id

    def test_fallback_to_time_heuristic_when_no_worker_field(
        self,
        tenants_fixture,
        providers_fixture,
        scans_fixture,
        cleanup_mocks,
    ):
        tenant = tenants_fixture[0]
        provider = self._aws_provider(providers_fixture)

        # Old scan with no Task/TaskResult
        old_start = datetime.now(tz=timezone.utc) - timedelta(hours=49)
        AttackPathsScan.objects.create(
            tenant_id=tenant.id,
            provider=provider,
            state=StateChoices.EXECUTING,
            started_at=old_start,
        )

        result = self._run_cleanup()

        assert result["cleaned_up_count"] == 1
        cleanup_mocks.finish.assert_called_once()
        assert cleanup_mocks.finish.call_args[0][1] == StateChoices.FAILED