|
1 | 1 | from datetime import datetime, timezone |
2 | | -from typing import Any, Dict |
| 2 | +from typing import Any, Dict, Optional |
3 | 3 | from uuid import UUID |
4 | 4 |
|
5 | 5 | import httpx |
|
10 | 10 | from common.config import settings |
11 | 11 | from common.logger import logger |
12 | 12 | from common.redis_manager import redis_manager |
13 | | -from common.schemas import DeleteResponse, TaskStatus |
| 13 | +from common.schemas import DeleteResponse, Identity, TaskStatus |
| 14 | +from worker import celery_app |
14 | 15 |
|
15 | 16 |
|
16 | 17 | @cached(cache=TTLCache(maxsize=128, ttl=5)) |
17 | | -def get_task_info(task_id: str) -> Dict[str, Any]: |
| 18 | +def _get_task_info(task_id: str) -> Dict[str, Any]: |
18 | 19 | """ |
19 | 20 | Fetch task information from Flower API. |
20 | 21 | So we can use this to provide more detailed task status in the API responses. |
@@ -54,18 +55,63 @@ def get_task_info(task_id: str) -> Dict[str, Any]: |
54 | 55 | return result |
55 | 56 |
|
56 | 57 |
|
57 | | -def get_queue_position_logs(task_id: str) -> list[str]: |
def create_task(task_name: str, task_queue: str, payload: dict, identity: Identity) -> AsyncResult:
    """
    Unified helper to create a task in Celery.

    Args:
        task_name: Registered Celery task name to dispatch.
        task_queue: Name of the queue the task is routed to.
        payload: Task payload, passed as the single positional argument.
        identity: Caller identity; its fields are forwarded as task kwargs.

    Returns:
        The AsyncResult handle for the dispatched task.

    Raises:
        HTTPException: 500 if the broker rejects or fails the dispatch.
    """
    try:
        return celery_app.send_task(
            task_name,
            queue=task_queue,
            args=[payload],
            kwargs=identity.model_dump(),
        )
    except Exception as e:
        # Chain the original exception so the broker failure's traceback
        # is preserved for logging/debugging instead of being swallowed.
        raise HTTPException(status_code=500, detail=f"Error creating task: {str(e)}") from e
| 71 | + |
| 72 | + |
def get_task_detailed(id: UUID) -> tuple[AsyncResult, dict, list[str]]:
    """
    Fetches the task across current Redis storage (Broker and Result Backend).

    Args:
        id: UUID of the task to look up.

    Returns:
        A tuple of (AsyncResult, task_info dict from Flower, initial log lines).

    Raises:
        HTTPException: 404 if the task is not in Redis (either never existed
            or has expired).
    """
    task_id = str(id)  # hoisted: used for the broker lookup, AsyncResult, and Flower

    def _queue_position_log(tid: str) -> Optional[str]:
        """Inner helper to check the broker and format the queue position log."""
        pos_data = redis_manager.get_queue_position(tid)
        if pos_data:
            return f"Queue {pos_data.queue} position: {pos_data.position} / {pos_data.total}"
        return None

    result = AsyncResult(task_id, app=celery_app)
    logs: list[str] = []

    # Celery reports waiting tasks as PENDING and also unknown tasks as PENDING,
    # so a broker queue-position lookup disambiguates the two cases.
    if result.status == TaskStatus.PENDING:
        queue_position = _queue_position_log(task_id)
        if queue_position is None:
            # Truly not found
            raise HTTPException(status_code=404, detail="Task not found or has expired")
        # Keep the queue position logs to return to the user
        logs = [queue_position]
    elif result.info and isinstance(result.info, dict):
        # Running/finished task: surface any logs the worker stored in its meta.
        logs = result.info.get("logs", [])

    # Enrich with Flower metadata if available (metrics, worker info, etc)
    task_info = _get_task_info(task_id)

    return result, task_info, logs
66 | 112 |
|
67 | 113 |
|
68 | | -def cancel_task(id: UUID, celery_app) -> DeleteResponse: |
| 114 | +def cancel_task(id: UUID) -> DeleteResponse: |
69 | 115 | result = AsyncResult(str(id), app=celery_app) |
70 | 116 |
|
71 | 117 | if result.status in ["SUCCESS", "FAILURE", "REVOKED"]: |
|
0 commit comments