Skip to content

Commit dd838dc

Browse files
authored
feature: centralizes Celery task handling in shared helpers and configures result expiration (#70)
* refactor: consolidated task info and in progress logs * refactor: consolidated create task into helpers
1 parent 19d6851 commit dd838dc

9 files changed

Lines changed: 118 additions & 124 deletions

File tree

api/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class Settings(BaseSettings):
2121
signed_url_expiry_seconds: int = 3600 * 1 # 1 hour
2222
task_backlog_limit: int = 100 # Max number of waiting tasks allowed before rejecting new ones
2323
enable_mcp: bool = True
24+
result_expires_days: int = 30 # Number of days to keep task results
2425

2526
@property
2627
def encoded_storage_key(self) -> bytes:

api/common/task_helpers.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime, timezone
2-
from typing import Any, Dict
2+
from typing import Any, Dict, Optional
33
from uuid import UUID
44

55
import httpx
@@ -10,11 +10,12 @@
1010
from common.config import settings
1111
from common.logger import logger
1212
from common.redis_manager import redis_manager
13-
from common.schemas import DeleteResponse, TaskStatus
13+
from common.schemas import DeleteResponse, Identity, TaskStatus
14+
from worker import celery_app
1415

1516

1617
@cached(cache=TTLCache(maxsize=128, ttl=5))
17-
def get_task_info(task_id: str) -> Dict[str, Any]:
18+
def _get_task_info(task_id: str) -> Dict[str, Any]:
1819
"""
1920
Fetch task information from Flower API.
2021
So we can use this to provide more detailed task status in the API responses.
@@ -54,18 +55,63 @@ def get_task_info(task_id: str) -> Dict[str, Any]:
5455
return result
5556

5657

57-
def get_queue_position_logs(task_id: str) -> list[str]:
58+
def create_task(task_name: str, task_queue: str, payload: dict, identity: Identity) -> AsyncResult:
5859
"""
59-
Returns a list containing a log string with the task's queue position.
60+
Unified helper to create a task in Celery.
6061
"""
61-
pos_data = redis_manager.get_queue_position(task_id)
62-
if pos_data:
63-
return [f"Queue {pos_data.queue} position: {pos_data.position} / {pos_data.total}"]
62+
try:
63+
return celery_app.send_task(
64+
task_name,
65+
queue=task_queue,
66+
args=[payload],
67+
kwargs=identity.model_dump(),
68+
)
69+
except Exception as e:
70+
raise HTTPException(status_code=500, detail=f"Error creating task: {str(e)}")
71+
72+
73+
def get_task_detailed(id: UUID) -> tuple[AsyncResult, dict, list[str]]:
74+
"""
75+
Fetches the task across current Redis storage (Broker and Result Backend).
76+
Returns (AsyncResult, task_info, initial_logs).
77+
Raises 404 if the task is not in Redis (either never existed or has expired).
78+
"""
79+
80+
def get_queue_position(task_id: str) -> Optional[str]:
81+
"""
82+
Inner helper to check the broker and format the queue position log.
83+
"""
84+
pos_data = redis_manager.get_queue_position(task_id)
85+
if pos_data:
86+
return f"Queue {pos_data.queue} position: {pos_data.position} / {pos_data.total}"
87+
88+
return None
89+
90+
result = AsyncResult(str(id), app=celery_app)
91+
logs = []
92+
93+
# Celery reports waiting tasks as PENDING and also unknown tasks as PENDING.
94+
if result.status == TaskStatus.PENDING:
95+
queue_position = get_queue_position(str(id))
96+
if queue_position is None:
97+
# Truly not found
98+
raise HTTPException(status_code=404, detail="Task not found or has expired")
99+
100+
# Keep the queue position logs to return to the user
101+
logs = [queue_position]
102+
else:
103+
# get the running logs of the task if available
104+
if result.info:
105+
if isinstance(result.info, dict):
106+
logs = result.info.get("logs", [])
107+
108+
# Enrich with Flower metadata if available (metrics, worker info, etc)
109+
task_info = _get_task_info(str(id))
64110

65-
return [f"Task not found"]
111+
return result, task_info, logs
66112

67113

68-
def cancel_task(id: UUID, celery_app) -> DeleteResponse:
114+
def cancel_task(id: UUID) -> DeleteResponse:
69115
result = AsyncResult(str(id), app=celery_app)
70116

71117
if result.status in ["SUCCESS", "FAILURE", "REVOKED"]:

api/images/router.py

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from uuid import UUID
22

3-
from celery.result import AsyncResult
4-
from fastapi import APIRouter, Depends, HTTPException, Response
3+
from fastapi import APIRouter, Depends
54

65
from common.auth import verify_token
76
from common.schemas import DeleteResponse, Identity
87
from common.storage import signed_url_for_file
9-
from common.task_helpers import cancel_task, get_queue_position_logs, get_task_info
8+
from common.task_helpers import cancel_task, create_task, get_task_detailed
109
from images.schemas import (
1110
MODEL_META,
1211
ImageCreateResponse,
@@ -16,30 +15,21 @@
1615
ImageWorkerResponse,
1716
generate_model_docs,
1817
)
19-
from worker import celery_app
2018

2119
router = APIRouter(
2220
prefix="/images", tags=["Images"], dependencies=[Depends(verify_token)] # This will apply to all routes
2321
)
2422

2523

2624
@router.post("", response_model=ImageCreateResponse, description=generate_model_docs(), operation_id="images_create")
27-
def create(
28-
image_request: ImageRequest,
29-
response: Response,
30-
identity: Identity = Depends(verify_token),
31-
):
32-
try:
33-
result = celery_app.send_task(
34-
image_request.task_name,
35-
queue=image_request.task_queue,
36-
args=[image_request.model_dump()],
37-
kwargs=identity.model_dump(),
38-
)
39-
response.headers["Location"] = f"/images/{result.id}"
40-
return ImageCreateResponse(id=result.id, status=result.status)
41-
except Exception as e:
42-
raise HTTPException(status_code=500, detail=f"Error creating task: {str(e)}")
25+
def create(image_request: ImageRequest, identity: Identity = Depends(verify_token)):
26+
result = create_task(
27+
image_request.task_name,
28+
image_request.task_queue,
29+
image_request.model_dump(),
30+
identity,
31+
)
32+
return ImageCreateResponse(id=UUID(str(result.id)), status=result.status)
4333

4434

4535
@router.get(
@@ -51,18 +41,8 @@ def models():
5141

5242
@router.get("/{id}", response_model=ImageResponse, operation_id="images_get")
5343
def get(id: UUID):
54-
result = AsyncResult(str(id), app=celery_app)
55-
56-
# Initialize response with common fields
57-
response = ImageResponse(id=id, status=result.status, task_info=get_task_info(str(id)))
58-
59-
# Use the helper to inject queue position into logs if still pending
60-
if result.status == "PENDING":
61-
response.logs = get_queue_position_logs(str(id))
62-
63-
if result.info:
64-
if isinstance(result.info, dict):
65-
response.logs = result.info.get("logs", [])
44+
result, task_info, logs = get_task_detailed(id)
45+
response = ImageResponse(id=id, status=result.status, task_info=task_info, logs=logs)
6646

6747
# Add appropriate fields based on status
6848
if result.successful():
@@ -79,4 +59,4 @@ def get(id: UUID):
7959

8060
@router.delete("/{id}", response_model=DeleteResponse, operation_id="images_delete")
8161
def delete(id: UUID):
82-
return cancel_task(id, celery_app)
62+
return cancel_task(id)

api/texts/router.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from uuid import UUID
22

3-
from celery.result import AsyncResult
4-
from fastapi import APIRouter, Depends, HTTPException, Response
3+
from fastapi import APIRouter, Depends
54

65
from common.auth import verify_token
7-
from common.schemas import DeleteResponse, Identity, TaskStatus
8-
from common.task_helpers import cancel_task, get_queue_position_logs, get_task_info
6+
from common.schemas import DeleteResponse, Identity
7+
from common.task_helpers import cancel_task, create_task, get_task_detailed
98
from texts.schemas import (
109
MODEL_META,
1110
TextCreateResponse,
@@ -15,24 +14,19 @@
1514
TextWorkerResponse,
1615
generate_model_docs,
1716
)
18-
from worker import celery_app
1917

2018
router = APIRouter(prefix="/texts", tags=["Texts"], dependencies=[Depends(verify_token)])
2119

2220

2321
@router.post("", response_model=TextCreateResponse, operation_id="texts_create", description=generate_model_docs())
24-
def create(text_request: TextRequest, response: Response, identity: Identity = Depends(verify_token)):
25-
try:
26-
result = celery_app.send_task(
27-
text_request.task_name,
28-
queue=text_request.task_queue,
29-
args=[text_request.model_dump()],
30-
kwargs=identity.model_dump(),
31-
)
32-
response.headers["Location"] = f"/texts/{result.id}"
33-
return TextCreateResponse(id=result.id, status=result.status)
34-
except Exception as e:
35-
raise HTTPException(status_code=500, detail=f"Error creating task: {str(e)}")
22+
def create(text_request: TextRequest, identity: Identity = Depends(verify_token)):
23+
result = create_task(
24+
text_request.task_name,
25+
text_request.task_queue,
26+
text_request.model_dump(),
27+
identity,
28+
)
29+
return TextCreateResponse(id=UUID(str(result.id)), status=result.status)
3630

3731

3832
@router.get("/models", response_model=TextModelsResponse, summary="List text models", operation_id="texts_list_models")
@@ -42,19 +36,16 @@ def models():
4236

4337
@router.get("/{id}", response_model=TextResponse, operation_id="texts_get")
4438
def get(id: UUID):
45-
result = AsyncResult(str(id), app=celery_app)
39+
result, task_info, logs = get_task_detailed(id)
4640

4741
# Initialize response with common fields
4842
response = TextResponse(
4943
id=id,
5044
status=result.status,
51-
task_info=get_task_info(str(id)),
45+
task_info=task_info,
46+
logs=logs,
5247
)
5348

54-
# Use the helper to inject queue position into logs if still pending
55-
if result.status == "PENDING":
56-
response.logs = get_queue_position_logs(str(id))
57-
5849
# Add appropriate fields based on status
5950
if result.successful():
6051
result_data = TextWorkerResponse.model_validate(result.result)
@@ -67,4 +58,4 @@ def get(id: UUID):
6758

6859
@router.delete("/{id}", response_model=DeleteResponse, operation_id="texts_delete")
6960
def delete(id: UUID):
70-
return cancel_task(id, celery_app)
61+
return cancel_task(id)

api/videos/router.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from uuid import UUID
22

3-
from celery.result import AsyncResult
4-
from fastapi import APIRouter, Depends, HTTPException, Response
3+
from fastapi import APIRouter, Depends
54

65
from common.auth import verify_token
76
from common.schemas import DeleteResponse, Identity
87
from common.storage import signed_url_for_file
9-
from common.task_helpers import cancel_task, get_queue_position_logs, get_task_info
8+
from common.task_helpers import cancel_task, create_task, get_task_detailed
109
from videos.schemas import (
1110
MODEL_META,
1211
VideoCreateResponse,
@@ -16,24 +15,19 @@
1615
VideoWorkerResponse,
1716
generate_model_docs,
1817
)
19-
from worker import celery_app
2018

2119
router = APIRouter(prefix="/videos", tags=["Videos"], dependencies=[Depends(verify_token)])
2220

2321

2422
@router.post("", response_model=VideoCreateResponse, operation_id="videos_create", description=generate_model_docs())
25-
def create(video_request: VideoRequest, response: Response, identity: Identity = Depends(verify_token)):
26-
try:
27-
result = celery_app.send_task(
28-
video_request.task_name,
29-
queue=video_request.task_queue,
30-
args=[video_request.model_dump()],
31-
kwargs=identity.model_dump(),
32-
)
33-
response.headers["Location"] = f"/videos/{result.id}"
34-
return VideoCreateResponse(id=result.id, status=result.status)
35-
except Exception as e:
36-
raise HTTPException(status_code=500, detail=f"Error creating task: {str(e)}")
23+
def create(video_request: VideoRequest, identity: Identity = Depends(verify_token)):
24+
result = create_task(
25+
video_request.task_name,
26+
video_request.task_queue,
27+
video_request.model_dump(),
28+
identity,
29+
)
30+
return VideoCreateResponse(id=UUID(str(result.id)), status=result.status)
3731

3832

3933
@router.get(
@@ -45,18 +39,10 @@ def models():
4539

4640
@router.get("/{id}", response_model=VideoResponse, operation_id="videos_get")
4741
def get(id: UUID):
48-
result = AsyncResult(str(id), app=celery_app)
42+
result, task_info, logs = get_task_detailed(id)
4943

5044
# Initialize response with common fields
51-
response = VideoResponse(id=id, status=result.status, task_info=get_task_info(str(id)))
52-
53-
# Use the helper to inject queue position into logs if still pending
54-
if result.status == "PENDING":
55-
response.logs = get_queue_position_logs(str(id))
56-
57-
if result.info:
58-
if isinstance(result.info, dict):
59-
response.logs = result.info.get("logs", [])
45+
response = VideoResponse(id=id, status=result.status, task_info=task_info, logs=logs)
6046

6147
# Add appropriate fields based on status
6248
if result.successful():
@@ -73,4 +59,4 @@ def get(id: UUID):
7359

7460
@router.delete("/{id}", response_model=DeleteResponse, operation_id="videos_delete")
7561
def delete(id: UUID):
76-
return cancel_task(id, celery_app)
62+
return cancel_task(id)

api/worker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from datetime import timedelta
2+
13
from celery import Celery
24

35
from common.config import settings
@@ -13,3 +15,4 @@
1315
result_backend_always_retry=False, # Do not always retry result backend operations
1416
result_backend_max_retries=2, # Number of retries for result backend operations
1517
)
18+
celery_app.conf.result_expires = timedelta(days=settings.result_expires_days)

0 commit comments

Comments
 (0)