Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions api/admin/router.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from fastapi import APIRouter, Depends, HTTPException, Query

from common.api_key_manager import key_manager
from common.auth import admin_only
from common.redis_manager import redis_manager

router = APIRouter(prefix="/admin", tags=["Admin"], dependencies=[Depends(admin_only)])


@router.post("/keys", operation_id="keys_create")
def create(name: str = Query(..., min_length=3, max_length=50, pattern=r"^[a-zA-Z0-9 _-]+$")):
    """Create a new API key and return the plaintext token.

    Raises HTTP 400 when the manager rejects the name (ValueError).
    """
    try:
        # Stale pre-rename call to key_manager.create_key removed;
        # redis_manager is the single source of truth after the refactor.
        token = redis_manager.create_key(name)
        return {"api_key": token, "name": name}
    except ValueError as e:
        raise HTTPException(400, str(e))


@router.get("/keys", operation_id="keys_list")
def list():
    """Return the stored API keys (public metadata only).

    NOTE: intentionally shadows the builtin ``list``; kept to preserve the
    route's function name and ``operation_id``.
    """
    # Stale pre-rename call to key_manager.list_keys removed.
    return redis_manager.list_keys()


@router.delete("/keys/{key_id}", operation_id="keys_delete")
def delete(key_id: str):
    """Delete the API key with the given id; raise HTTP 404 if absent."""
    # Stale pre-rename call to key_manager.delete_key removed.
    if redis_manager.delete_key(key_id):
        return {"deleted": True}

    raise HTTPException(404, "Key not found")
16 changes: 9 additions & 7 deletions api/common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from fastapi import Depends, HTTPException, Request
from fastapi.security import APIKeyHeader

from common.api_key_manager import key_manager
from common.config import settings
from common.logger import log_request
from common.redis_manager import redis_manager
from common.schemas import Identity


Expand All @@ -32,7 +32,7 @@ async def verify_token(

token = authorization.replace("Bearer ", "")

key_data = key_manager.verify_token(token)
key_data = redis_manager.verify_token(token)
if not key_data:
raise HTTPException(status_code=403, detail="Invalid or revoked token")

Expand All @@ -45,12 +45,14 @@ async def verify_token(
)
await log_request(request, identity)

# Only rate limit POST requests (task creation) so polling doesn't consume quota
# Only limit POST requests (task creation)
if request.method == "POST":
limit = settings.creates_per_minute

if not key_manager.check_rate_limit(key_data.key_id, limit=limit, window=60):
raise HTTPException(status_code=429, detail="Rate limit exceeded")
waiting_tasks = redis_manager.waiting_tasks()
if waiting_tasks >= settings.task_backlog_limit:
raise HTTPException(
status_code=429,
detail=f"Too many waiting tasks {waiting_tasks} / {settings.task_backlog_limit}, please try later",
)

return identity

Expand Down
4 changes: 2 additions & 2 deletions api/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class Settings(BaseSettings):
ddiffusion_storage_address: str = "http://127.0.0.1:5000"
ddiffusion_storage_directory: str = "/STORAGE"
flower_url: str = "http://flower:5555"
signed_url_expiry_seconds: int = 3600 # 1 hour
creates_per_minute: int = 30
signed_url_expiry_seconds: int = 3600 * 1 # 1 hour
task_backlog_limit: int = 100 # Max number of waiting tasks allowed before rejecting new ones
enable_mcp: bool = True

@property
Expand Down
44 changes: 33 additions & 11 deletions api/common/api_key_manager.py → api/common/redis_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,36 @@
import hashlib
import hmac
import secrets
from typing import Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast

import redis
from redis import Redis

from common.config import settings
from common.logger import logger
from common.schemas import APIKeyPublic
from common.schemas import APIKeyPublic, QueuePosition

_redis_client = redis.from_url(settings.celery_broker_url, decode_responses=True)


class RedisManager:
    """Wrapper around the Celery-broker Redis connection.

    Bundles API-key storage helpers (namespaced under ``self.prefix``) with
    queue-inspection utilities (waiting-task counts and per-task queue
    position via a server-side Lua script).
    """

    def __init__(self):
        # Shared module-level connection (decode_responses=True), so values
        # read back from Redis are str, not bytes.
        self.client: Redis = _redis_client
        # Namespace prefix for every API-key entry.
        self.prefix = "DDIFFUSION_API_KEY"
        # Register once at startup - see get_queue_position
        self._pos_script = self.client.register_script(
            """
local tasks = redis.call('LRANGE', KEYS[1], 0, -1)
local total = #tasks
for i, task in ipairs(tasks) do
if string.find(task, ARGV[1], 1, true) then
-- FIFO correction: The tail of the list is position 1
return {total - i + 1, total}
end
end
return nil
"""
        )

    def _get_redis_key(self, key_id: str) -> str:
        """Build the namespaced Redis key for a given key_id."""
        return f"{self.prefix}:{key_id}"
Expand Down Expand Up @@ -114,17 +128,25 @@ def delete_key(self, key_id: str) -> bool:
key = self._get_redis_key(key_id)
return bool(self.client.delete(key))

def waiting_tasks(self, queues=("gpu", "cpu", "comfy")) -> int:
    """
    Return the total number of waiting tasks across the given broker queues.

    Each Celery queue is a Redis list, so the backlog is the sum of LLEN
    over every queue name.
    """
    # Tuple default instead of a mutable list default (B006: a list default
    # is shared across calls and a latent footgun if ever mutated).
    return sum(cast(int, self.client.llen(q)) for q in queues)
def get_queue_position(self, task_id: str, queues=("gpu", "cpu", "comfy")) -> Optional[QueuePosition]:
    """
    Find the 1-based position of a task inside one of the Redis queues.

    Runs the server-side Lua script registered in __init__, which is much
    faster than LRANGE from the client because large task payloads
    (e.g. Base64 images) never travel over the network to the API.

    Returns None when the task id is not found in any queue.
    """
    # Tuple default instead of a mutable list default (B006).
    for queue_name in queues:
        result = cast(list, self._pos_script(keys=[queue_name], args=[task_id]))
        if result:
            return QueuePosition(position=result[0], queue=queue_name, total=result[1])
    return None


# Module-level singleton shared by the API routers and auth dependencies.
# (Stale pre-rename `key_manager = APIKeyManager()` residue removed.)
redis_manager = RedisManager()
6 changes: 6 additions & 0 deletions api/common/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,9 @@ class APIKeyPublic(BaseModel):
key_id: str
name: str
created_at: str


class QueuePosition(BaseModel):
    """Location of a single waiting task inside one broker queue."""

    position: int = Field(description="1-based position in the queue")
    queue: str = Field(description="Name of the queue")
    total: int = Field(description="Total tasks waiting in this queue")
12 changes: 12 additions & 0 deletions api/common/task_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from common.config import settings
from common.logger import logger
from common.redis_manager import redis_manager
from common.schemas import DeleteResponse, TaskStatus


Expand Down Expand Up @@ -53,6 +54,17 @@ def get_task_info(task_id: str) -> Dict[str, Any]:
return result


def get_queue_position_logs(task_id: str) -> list[str]:
    """
    Return a single-element log list with the task's queue position, or a
    fallback message when the task is not waiting in any queue.
    """
    pos_data = redis_manager.get_queue_position(task_id)
    if pos_data:
        return [f"Queue {pos_data.queue} position: {pos_data.position} / {pos_data.total}"]

    # Plain literal: the original used an f-string with no placeholders (F541).
    return ["Task not found"]


def cancel_task(id: UUID, celery_app) -> DeleteResponse:
result = AsyncResult(str(id), app=celery_app)

Expand Down
6 changes: 5 additions & 1 deletion api/images/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from common.auth import verify_token
from common.schemas import DeleteResponse, Identity
from common.storage import signed_url_for_file
from common.task_helpers import cancel_task, get_task_info
from common.task_helpers import cancel_task, get_queue_position_logs, get_task_info
from images.schemas import (
MODEL_META,
ImageCreateResponse,
Expand Down Expand Up @@ -56,6 +56,10 @@ def get(id: UUID):
# Initialize response with common fields
response = ImageResponse(id=id, status=result.status, task_info=get_task_info(str(id)))

# Use the helper to inject queue position into logs if still pending
if result.status == "PENDING":
response.logs = get_queue_position_logs(str(id))

if result.info:
if isinstance(result.info, dict):
response.logs = result.info.get("logs", [])
Expand Down
59 changes: 30 additions & 29 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,6 @@
fastapi_app = FastAPI(title="API")


@fastapi_app.exception_handler(RequestValidationError)
async def validation_handler(request: Request, exc: RequestValidationError):
cleaned = []
for err in exc.errors():
d = dict(err)
cleaned.append(
{
"loc": d.get("loc", []),
"msg": d.get("msg", ""),
"type": d.get("type", ""),
}
)

logger.warning(f"Validation error on {request.url.path}: {cleaned}")
return JSONResponse(status_code=422, content={"detail": cleaned})


@fastapi_app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
return JSONResponse(
status_code=500,
content={
"message": "Internal server error",
"detail": truncate_strings(str(exc), 1000),
"path": request.url.path,
},
)


fastapi_app.include_router(images.router, prefix="/api")
fastapi_app.include_router(texts.router, prefix="/api")
fastapi_app.include_router(videos.router, prefix="/api")
Expand Down Expand Up @@ -90,6 +61,36 @@ def health():
else:
app = fastapi_app


@app.exception_handler(RequestValidationError)
async def validation_handler(request: Request, exc: RequestValidationError):
    """Log request-validation failures and return a trimmed 422 payload."""
    cleaned = [
        {
            "loc": err.get("loc", []),
            "msg": err.get("msg", ""),
            "type": err.get("type", ""),
        }
        for err in (dict(e) for e in exc.errors())
    ]

    logger.warning(f"Validation error on {request.url.path}: {cleaned}")
    return JSONResponse(status_code=422, content={"detail": cleaned})


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: wrap any unhandled error in a JSON 500 response."""
    payload = {
        "message": "Internal server error",
        "detail": truncate_strings(str(exc), 1000),
        "path": request.url.path,
    }
    return JSONResponse(status_code=500, content=payload)


app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
Expand Down
14 changes: 7 additions & 7 deletions api/texts/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from common.auth import verify_token
from common.schemas import DeleteResponse, Identity, TaskStatus
from common.task_helpers import cancel_task, get_task_info
from common.task_helpers import cancel_task, get_queue_position_logs, get_task_info
from texts.schemas import (
MODEL_META,
TextCreateResponse,
Expand Down Expand Up @@ -51,14 +51,14 @@ def get(id: UUID):
task_info=get_task_info(str(id)),
)

# Use the helper to inject queue position into logs if still pending
if result.status == "PENDING":
response.logs = get_queue_position_logs(str(id))

# Add appropriate fields based on status
if result.successful():
try:
result_data = TextWorkerResponse.model_validate(result.result)
response.output = result_data.response
except Exception as e:
response.status = TaskStatus.FAILURE
response.error_message = f"Error parsing result: {str(e)}"
result_data = TextWorkerResponse.model_validate(result.result)
response.output = result_data.response
elif result.failed():
response.error_message = f"Task failed with error: {str(result.result)}"

Expand Down
2 changes: 2 additions & 0 deletions api/texts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class TextResponse(BaseModel):
status: TaskStatus
output: str = ""
error_message: Optional[str] = None
logs: List[str] = []
task_info: dict = Field(default_factory=dict)
model_config = ConfigDict(
json_schema_extra={
Expand All @@ -157,6 +158,7 @@ class TextResponse(BaseModel):
"status": "SUCCESS",
"output": "This is the generated text response from the model.",
"error_message": None,
"logs": ["Processing..."],
}
}
)
Expand Down
6 changes: 5 additions & 1 deletion api/videos/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from common.auth import verify_token
from common.schemas import DeleteResponse, Identity
from common.storage import signed_url_for_file
from common.task_helpers import cancel_task, get_task_info
from common.task_helpers import cancel_task, get_queue_position_logs, get_task_info
from videos.schemas import (
MODEL_META,
VideoCreateResponse,
Expand Down Expand Up @@ -50,6 +50,10 @@ def get(id: UUID):
# Initialize response with common fields
response = VideoResponse(id=id, status=result.status, task_info=get_task_info(str(id)))

# Use the helper to inject queue position into logs if still pending
if result.status == "PENDING":
response.logs = get_queue_position_logs(str(id))

if result.info:
if isinstance(result.info, dict):
response.logs = result.info.get("logs", [])
Expand Down
6 changes: 5 additions & 1 deletion api/workflows/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from common.auth import verify_token
from common.schemas import DeleteResponse, Identity
from common.storage import signed_url_for_file
from common.task_helpers import cancel_task, get_task_info
from common.task_helpers import cancel_task, get_queue_position_logs, get_task_info
from worker import celery_app
from workflows.schemas import (
WorkflowCreateResponse,
Expand Down Expand Up @@ -47,6 +47,10 @@ def get(id: UUID):
# Initialize response with common fields
response = WorkflowResponse(id=id, status=result.status, task_info=get_task_info(str(id)))

# Use the helper to inject queue position into logs if still pending
if result.status == "PENDING":
response.logs = get_queue_position_logs(str(id))

if result.info:
if isinstance(result.info, dict):
response.logs = result.info.get("logs", [])
Expand Down
2 changes: 1 addition & 1 deletion clients/openapi.json

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions workers/common/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,9 @@ class APIKeyPublic(BaseModel):
key_id: str
name: str
created_at: str


class QueuePosition(BaseModel):
    """Location of a single waiting task inside one broker queue."""

    position: int = Field(description="1-based position in the queue")
    queue: str = Field(description="Name of the queue")
    total: int = Field(description="Total tasks waiting in this queue")
2 changes: 2 additions & 0 deletions workers/texts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class TextResponse(BaseModel):
status: TaskStatus
output: str = ""
error_message: Optional[str] = None
logs: List[str] = []
task_info: dict = Field(default_factory=dict)
model_config = ConfigDict(
json_schema_extra={
Expand All @@ -157,6 +158,7 @@ class TextResponse(BaseModel):
"status": "SUCCESS",
"output": "This is the generated text response from the model.",
"error_message": None,
"logs": ["Processing..."],
}
}
)
Expand Down
Loading
Loading