Skip to content

Commit 3761c0f

Browse files
authored
Support native advanced queries in LightningStore (#318)
1 parent d433418 commit 3761c0f

File tree

28 files changed

+2251
-605
lines changed

28 files changed

+2251
-605
lines changed

agentlightning/adapter/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

3-
from typing import Generic, List, TypeVar
3+
from typing import Generic, Sequence, TypeVar
44

55
from opentelemetry.sdk.trace import ReadableSpan
66

@@ -66,7 +66,7 @@ def adapt(self, source: T_from, /) -> T_to:
6666
raise NotImplementedError("Adapter.adapt() is not implemented")
6767

6868

69-
class OtelTraceAdapter(Adapter[List[ReadableSpan], T_to], Generic[T_to]):
69+
class OtelTraceAdapter(Adapter[Sequence[ReadableSpan], T_to], Generic[T_to]):
7070
"""Base class for adapters that convert OpenTelemetry trace spans into other formats.
7171
7272
This specialization of [`Adapter`][agentlightning.Adapter] expects a list of
@@ -84,7 +84,7 @@ class OtelTraceAdapter(Adapter[List[ReadableSpan], T_to], Generic[T_to]):
8484
"""
8585

8686

87-
class TraceAdapter(Adapter[List[Span], T_to], Generic[T_to]):
87+
class TraceAdapter(Adapter[Sequence[Span], T_to], Generic[T_to]):
8888
"""Base class for adapters that convert trace spans into other formats.
8989
9090
This class specializes [`Adapter`][agentlightning.Adapter] for working with

agentlightning/adapter/messages.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import json
66
from collections import defaultdict
7-
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, List, Optional, TypedDict, Union, cast
7+
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, List, Optional, Sequence, TypedDict, Union, cast
88

99
from pydantic import TypeAdapter
1010

@@ -208,7 +208,7 @@ class TraceToMessages(TraceAdapter[List[OpenAIMessages]]):
208208
children of the associated completion span.
209209
"""
210210

211-
def get_tool_calls(self, completion: Span, all_spans: List[Span], /) -> Iterable[Dict[str, Any]]:
211+
def get_tool_calls(self, completion: Span, all_spans: Sequence[Span], /) -> Iterable[Dict[str, Any]]:
212212
"""Yield tool call payloads for a completion span.
213213
214214
Args:
@@ -231,7 +231,7 @@ def get_tool_calls(self, completion: Span, all_spans: List[Span], /) -> Iterable
231231
if tool_call:
232232
yield tool_call
233233

234-
def adapt(self, source: List[Span], /) -> List[OpenAIMessages]:
234+
def adapt(self, source: Sequence[Span], /) -> List[OpenAIMessages]:
235235
"""Transform trace spans into OpenAI chat payloads.
236236
237237
Args:

agentlightning/adapter/triplet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import re
88
from enum import Enum
9-
from typing import Any, Dict, List, Optional, Tuple, Union, cast
9+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
1010

1111
from opentelemetry.sdk.trace import ReadableSpan
1212
from pydantic import BaseModel
@@ -670,7 +670,7 @@ def visualize(
670670
trace_tree.visualize(filename, interested_span_match=interested_span_match)
671671
return trace_tree
672672

673-
def adapt(self, source: Union[List[Span], List[ReadableSpan]], /) -> List[Triplet]: # type: ignore
673+
def adapt(self, source: Union[Sequence[Span], Sequence[ReadableSpan]], /) -> List[Triplet]: # type: ignore
674674
"""Convert tracer spans into [`Triplet`][agentlightning.Triplet] trajectories.
675675
676676
Args:
@@ -800,7 +800,7 @@ def _request_id_from_attrs(self, attrs: Dict[str, Any]) -> Optional[str]:
800800
rid = attrs.get("gen_ai.response.id") or attrs.get("llm.hosted_vllm.id")
801801
return str(rid) if isinstance(rid, str) and rid else None
802802

803-
def adapt(self, source: List[Span], /) -> List[Triplet]: # type: ignore
803+
def adapt(self, source: Sequence[Span], /) -> List[Triplet]: # type: ignore
804804
"""Convert LLM Proxy spans into [`Triplet`][agentlightning.Triplet] trajectories.
805805
806806
Args:

agentlightning/algorithm/fast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ async def _enqueue_rollouts(
143143
store = self.get_store()
144144

145145
for index in train_indices + val_indices:
146-
queuing_rollouts = await store.query_rollouts(status=["queuing", "requeuing"])
146+
queuing_rollouts = await store.query_rollouts(status_in=["queuing", "requeuing"])
147147
if len(queuing_rollouts) <= 1:
148148
# Only enqueue a new rollout when there is at most 1 rollout in the queue.
149149
sample = dataset[index]
@@ -222,7 +222,7 @@ async def run(
222222
f"Processing index {index}. {len(train_indices)} train indices and {len(val_indices)} val indices in total."
223223
)
224224
while True:
225-
queuing_rollouts = await store.query_rollouts(status=["queuing", "requeuing"])
225+
queuing_rollouts = await store.query_rollouts(status_in=["queuing", "requeuing"])
226226
if len(queuing_rollouts) <= self.max_queue_length:
227227
# Only enqueue a new rollout when there is at most "max_queue_length" rollout in the queue.
228228
sample = concatenated_dataset[index]

agentlightning/store/base.py

Lines changed: 146 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Span,
1919
TaskInput,
2020
Worker,
21+
WorkerStatus,
2122
)
2223

2324

@@ -292,30 +293,77 @@ async def add_otel_span(
292293
raise NotImplementedError()
293294

294295
async def query_rollouts(
295-
self, *, status: Optional[Sequence[RolloutStatus]] = None, rollout_ids: Optional[Sequence[str]] = None
296-
) -> List[Rollout]:
296+
self,
297+
*,
298+
status_in: Optional[Sequence[RolloutStatus]] = None,
299+
rollout_id_in: Optional[Sequence[str]] = None,
300+
rollout_id_contains: Optional[str] = None,
301+
filter_logic: Literal["and", "or"] = "and",
302+
sort_by: Optional[str] = None,
303+
sort_order: Literal["asc", "desc"] = "asc",
304+
limit: int = -1,
305+
offset: int = 0,
306+
# Deprecated fields
307+
status: Optional[Sequence[RolloutStatus]] = None,
308+
rollout_ids: Optional[Sequence[str]] = None,
309+
) -> Sequence[Rollout]:
297310
"""Retrieve rollouts filtered by status and/or explicit identifiers.
298311
312+
This interface supports structured filtering, sorting, and pagination so
313+
callers can build simple dashboards without copying data out of the
314+
store. The legacy parameters `status` and `rollout_ids` remain valid and
315+
are treated as aliases for `status_in` and `rollout_id_in`
316+
respectively—when both the new and deprecated parameters are supplied
317+
the new parameters take precedence.
318+
299319
Args:
300-
status: Optional whitelist of [`RolloutStatus`][agentlightning.RolloutStatus] values.
301-
rollout_ids: Optional whitelist of rollout identifiers to include.
320+
status_in: Optional whitelist of [`RolloutStatus`][agentlightning.RolloutStatus] values.
321+
rollout_id_in: Optional whitelist of rollout identifiers to include.
322+
rollout_id_contains: Optional substring match for rollout identifiers.
323+
filter_logic: Logical operator to combine filters.
324+
sort_by: Optional field to sort by. Must reference a numeric or string
325+
field on [`Rollout`][agentlightning.Rollout].
326+
sort_order: Direction to sort when `sort_by` is provided.
327+
limit: Maximum number of rows to return. Use `-1` for "no limit".
328+
offset: Number of rows to skip before returning results.
329+
status: Deprecated field. Use `status_in` instead.
330+
rollout_ids: Deprecated field. Use `rollout_id_in` instead.
302331
303332
Returns:
304-
A list of matching rollouts. Ordering is backend-defined but must be deterministic.
333+
A sequence of matching rollouts (or [`AttemptedRollout`][agentlightning.AttemptedRollout]
334+
when attempts exist). Ordering is deterministic when `sort_by` is set.
335+
The return value is not guaranteed to be a list.
305336
306337
Raises:
307338
NotImplementedError: Subclasses must implement the query.
308339
"""
309340
raise NotImplementedError()
310341

311-
async def query_attempts(self, rollout_id: str) -> List[Attempt]:
342+
async def query_attempts(
343+
self,
344+
rollout_id: str,
345+
*,
346+
sort_by: Optional[str] = "sequence_id",
347+
sort_order: Literal["asc", "desc"] = "asc",
348+
limit: int = -1,
349+
offset: int = 0,
350+
) -> Sequence[Attempt]:
312351
"""Return every attempt ever created for `rollout_id` in ascending sequence order.
313352
353+
The parameters allow callers to re-order or paginate the attempts so that
354+
large retry histories can be streamed lazily.
355+
314356
Args:
315357
rollout_id: Identifier of the rollout being inspected.
358+
sort_by: Field to sort by. Must be a numeric or string field of
359+
[`Attempt`][agentlightning.Attempt]. Defaults to `sequence_id` (oldest first).
360+
sort_order: Order to sort by.
361+
limit: Limit on the number of results. `-1` for unlimited.
362+
offset: Offset into the results.
316363
317364
Returns:
318-
Attempts sorted by `sequence_id` (oldest first). Returns an empty list when none exist.
365+
Sequence of Attempts. Returns an empty sequence when none exist.
366+
The return value is not guaranteed to be a list.
319367
320368
Raises:
321369
NotImplementedError: Subclasses must implement the query.
@@ -352,11 +400,35 @@ async def get_latest_attempt(self, rollout_id: str) -> Optional[Attempt]:
352400
"""
353401
raise NotImplementedError()
354402

355-
async def query_resources(self) -> List[ResourcesUpdate]:
403+
async def query_resources(
404+
self,
405+
*,
406+
resources_id: Optional[str] = None,
407+
resources_id_contains: Optional[str] = None,
408+
# Filter logic is not supported here because I can't see why it's needed.
409+
sort_by: Optional[str] = None,
410+
sort_order: Literal["asc", "desc"] = "asc",
411+
limit: int = -1,
412+
offset: int = 0,
413+
) -> Sequence[ResourcesUpdate]:
356414
"""List every stored resource snapshot in insertion order.
357415
416+
Supports lightweight filtering, sorting, and pagination for embedding in
417+
dashboards.
418+
419+
Args:
420+
resources_id: Optional identifier of the resources to include.
421+
resources_id_contains: Optional substring match for resources identifiers.
422+
sort_by: Optional field to sort by (must be numeric or string on
423+
[`ResourcesUpdate`][agentlightning.ResourcesUpdate]).
424+
sort_order: Order to sort by.
425+
limit: Limit on the number of results. `-1` for unlimited.
426+
offset: Offset into the results.
427+
358428
Returns:
359-
A chronological list of [`ResourcesUpdate`][agentlightning.ResourcesUpdate] objects.
429+
[`ResourcesUpdate`][agentlightning.ResourcesUpdate] objects.
430+
By default, resources are sorted in a deterministic but undefined order.
431+
The return value is not guaranteed to be a list.
360432
361433
Raises:
362434
NotImplementedError: Subclasses must implement retrieval.
@@ -439,19 +511,61 @@ async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[f
439511
"""
440512
raise NotImplementedError()
441513

442-
async def query_spans(self, rollout_id: str, attempt_id: str | Literal["latest"] | None = None) -> List[Span]:
514+
async def query_spans(
515+
self,
516+
rollout_id: str,
517+
attempt_id: str | Literal["latest"] | None = None,
518+
*,
519+
# Filtering
520+
trace_id: Optional[str] = None,
521+
trace_id_contains: Optional[str] = None,
522+
span_id: Optional[str] = None,
523+
span_id_contains: Optional[str] = None,
524+
parent_id: Optional[str] = None,
525+
parent_id_contains: Optional[str] = None,
526+
name: Optional[str] = None,
527+
name_contains: Optional[str] = None,
528+
filter_logic: Literal["and", "or"] = "and",
529+
# Pagination
530+
limit: int = -1,
531+
offset: int = 0,
532+
# Sorting
533+
sort_by: Optional[str] = "sequence_id",
534+
sort_order: Literal["asc", "desc"] = "asc",
535+
) -> Sequence[Span]:
443536
"""Return the stored spans for a rollout, optionally scoped to one attempt.
444537
445-
Spans must be returned in ascending `sequence_id` order. Implementations may raise
446-
a `RuntimeError` when spans were evicted or expired.
538+
Supports a handful of filters that cover the most common debugging
539+
scenarios (matching `trace_id`/`span_id`/`parent_id` or substring
540+
matches on the span name). `attempt_id="latest"` acts as a convenience
541+
that resolves the most recent attempt before evaluating filters. When
542+
`attempt_id=None`, spans across every attempt are eligible. By default
543+
results are sorted by `sequence_id` (oldest first). Implementations may
544+
raise a `RuntimeError` when spans were evicted or expired.
447545
448546
Args:
449547
rollout_id: Identifier of the rollout being inspected.
450548
attempt_id: Attempt identifier to filter by. Pass `"latest"` to retrieve only the
451549
most recent attempt, or `None` to return all spans across attempts.
550+
trace_id: Optional trace ID to filter by.
551+
trace_id_contains: Optional substring match for trace IDs.
552+
span_id: Optional span ID to filter by.
553+
span_id_contains: Optional substring match for span IDs.
554+
parent_id: Optional parent span ID to filter by.
555+
parent_id_contains: Optional substring match for parent span IDs.
556+
name: Optional span name to filter by.
557+
name_contains: Optional substring match for span names.
558+
filter_logic: Logical operator to combine the optional filters above.
559+
The `rollout_id` argument is always applied with AND semantics.
560+
limit: Limit on the number of results. `-1` for unlimited.
561+
offset: Offset into the results.
562+
sort_by: Field to sort by. Must be a numeric or string field of
563+
[`Span`][agentlightning.Span].
564+
sort_order: Order to sort by.
452565
453566
Returns:
454567
An ordered list of spans (possibly empty).
568+
The return value is not guaranteed to be a list.
455569
456570
Raises:
457571
NotImplementedError: Subclasses must implement the query.
@@ -578,11 +692,29 @@ async def update_attempt(
578692

579693
async def query_workers(
580694
self,
581-
) -> List[Worker]:
695+
*,
696+
status_in: Optional[Sequence[WorkerStatus]] = None,
697+
worker_id_contains: Optional[str] = None,
698+
filter_logic: Literal["and", "or"] = "and",
699+
sort_by: Optional[str] = None,
700+
sort_order: Literal["asc", "desc"] = "asc",
701+
limit: int = -1,
702+
offset: int = 0,
703+
) -> Sequence[Worker]:
582704
"""Query all workers in the system.
583705
706+
Args:
707+
status_in: Optional whitelist of [`WorkerStatus`][agentlightning.WorkerStatus] values.
708+
worker_id_contains: Optional substring match for worker identifiers.
709+
filter_logic: Logical operator to combine the optional filters above.
710+
sort_by: Field to sort by. Must be a numeric or string field of [`Worker`][agentlightning.Worker].
711+
sort_order: Order to sort by.
712+
limit: Limit on the number of results. `-1` for unlimited.
713+
offset: Offset into the results.
714+
584715
Returns:
585-
A list of all workers.
716+
Sequence of Workers. Returns an empty sequence when none exist.
717+
The return value is not guaranteed to be a list.
586718
"""
587719
raise NotImplementedError()
588720

0 commit comments

Comments
 (0)