Skip to content

Commit f3fd58e

Browse files
authored
Put store init in the right place of tracer (#321)
1 parent b3cb5e1 commit f3fd58e

File tree

8 files changed, with 82 additions and 34 deletions.

agentlightning/runner/agent.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,7 @@ def init_worker(self, worker_id: int, store: LightningStore, **kwargs: Any) -> N
138138
self._store = store
139139
self.worker_id = worker_id
140140

141-
self._tracer.init_worker(worker_id)
141+
self._tracer.init_worker(worker_id, store)
142142

143143
def teardown(self, *args: Any, **kwargs: Any) -> None:
144144
"""Teardown the runner and clean up all resources.
@@ -469,7 +469,7 @@ async def _step_impl(self, next_rollout: AttemptedRollout, raise_on_exception: b
469469

470470
start_time = time.time()
471471
async with self._tracer.trace_context(
472-
name=rollout_id, store=store, rollout_id=rollout_id, attempt_id=next_rollout.attempt.attempt_id
472+
name=rollout_id, rollout_id=rollout_id, attempt_id=next_rollout.attempt.attempt_id
473473
):
474474
await self._trigger_hooks(
475475
hook_type="on_trace_start", agent=agent, runner=self, tracer=self._tracer, rollout=next_rollout

agentlightning/tracer/agentops.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
import logging
66
import os
7+
import warnings
78
from contextlib import asynccontextmanager, contextmanager
89
from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterator, List, Optional
910

@@ -114,6 +115,14 @@ async def trace_context(
114115
Yields:
115116
The OpenTelemetry tracer instance to collect spans.
116117
"""
118+
if store is not None:
119+
warnings.warn(
120+
"store is deprecated in favor of init_worker(). It will be removed in the future.",
121+
DeprecationWarning,
122+
stacklevel=3,
123+
)
124+
else:
125+
store = self._store
117126
with self._trace_context_sync(name=name, store=store, rollout_id=rollout_id, attempt_id=attempt_id) as tracer:
118127
yield tracer
119128

@@ -164,9 +173,10 @@ def _agentops_trace_context(self, rollout_id: Optional[str], attempt_id: Optiona
164173
try:
165174
yield
166175
except Exception as e:
167-
# TODO: I'm not sure whether this will catch errors in user code.
176+
# This will catch errors in user code.
168177
status = StatusCode.ERROR # type: ignore
169178
logger.error(f"Trace failed for rollout_id={rollout_id}, attempt_id={attempt_id}: {e}")
179+
raise # should reraise the error here so that runner can handle it
170180
finally:
171181
agentops.end_trace(trace, end_state=status) # type: ignore
172182

agentlightning/tracer/base.py

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,18 @@ class Tracer(ParallelWorkerBase):
5252
```
5353
"""
5454

55+
_store: Optional[LightningStore] = None
56+
57+
def init_worker(self, worker_id: int, store: Optional[LightningStore] = None) -> None:
58+
"""Initialize the tracer for a worker.
59+
60+
Args:
61+
worker_id: The ID of the worker.
62+
store: The store to add the spans to. If it's provided, traces will be added to the store when tracing.
63+
"""
64+
super().init_worker(worker_id)
65+
self._store = store
66+
5567
def trace_context(
5668
self,
5769
name: Optional[str] = None,
@@ -68,11 +80,9 @@ def trace_context(
6880
within the `with` block are collected and made available via
6981
[`get_last_trace`][agentlightning.Tracer.get_last_trace].
7082
71-
If a store is provided, the spans will be added to the store when tracing.
72-
7383
Args:
7484
name: The name for the root span of this trace context.
75-
store: The store to add the spans to.
85+
store: The store to add the spans to. Deprecated in favor of passing store to init_worker().
7686
rollout_id: The rollout ID to add the spans to.
7787
attempt_id: The attempt ID to add the spans to.
7888
"""
@@ -82,7 +92,6 @@ def _trace_context_sync(
8292
self,
8393
name: Optional[str] = None,
8494
*,
85-
store: Optional[LightningStore] = None,
8695
rollout_id: Optional[str] = None,
8796
attempt_id: Optional[str] = None,
8897
) -> ContextManager[Any]:
@@ -141,19 +150,22 @@ def get_langchain_handler(self) -> Optional[BaseCallbackHandler]: # type: ignor
141150
return None
142151

143152
@contextmanager
144-
def lifespan(self):
153+
def lifespan(self, store: Optional[LightningStore] = None):
145154
"""A context manager to manage the lifespan of the tracer.
146155
147156
This can be used to set up and tear down any necessary resources
148157
for the tracer, useful for debugging purposes.
158+
159+
Args:
160+
store: The store to add the spans to. If it's provided, traces will be added to the store when tracing.
149161
"""
150162
has_init = False
151163
has_init_worker = False
152164
try:
153165
self.init()
154166
has_init = True
155167

156-
self.init_worker(0)
168+
self.init_worker(0, store)
157169
has_init_worker = True
158170

159171
yield

agentlightning/tracer/http.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919
TraceState,
2020
)
2121

22+
from agentlightning.store import LightningStore
23+
2224
from .base import Tracer
2325

2426
logger = logging.getLogger(__name__)
@@ -68,14 +70,15 @@ def __init__(
6870
self.subprocess_mode = subprocess_mode
6971
self.subprocess_timeout = subprocess_timeout
7072

71-
def init_worker(self, worker_id: int) -> None:
73+
def init_worker(self, worker_id: int, store: Optional[LightningStore] = None) -> None:
7274
"""
7375
Initialize the tracer in a worker process.
7476
7577
Args:
7678
worker_id: The ID of the worker process.
79+
store: The store to add the spans to.
7780
"""
78-
super().init_worker(worker_id)
81+
super().init_worker(worker_id, store)
7982
logger.info(f"[Worker {worker_id}] HttpTracer initialized.")
8083

8184
@asynccontextmanager

agentlightning/tracer/otel.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import logging
77
import threading
8+
import warnings
89
from contextlib import asynccontextmanager
910
from typing import Any, AsyncGenerator, Awaitable, List, Optional
1011

@@ -42,8 +43,8 @@ def __init__(self):
4243
self._otlp_span_exporter: Optional[LightningStoreOTLPExporter] = None
4344
self._initialized: bool = False
4445

45-
def init_worker(self, worker_id: int):
46-
super().init_worker(worker_id)
46+
def init_worker(self, worker_id: int, store: Optional[LightningStore] = None):
47+
super().init_worker(worker_id, store)
4748
self._initialize_tracer_provider(worker_id)
4849

4950
def _initialize_tracer_provider(self, worker_id: int):
@@ -92,7 +93,18 @@ async def trace_context(
9293
if not self._lightning_span_processor:
9394
raise RuntimeError("LightningSpanProcessor is not initialized. Call init_worker() first.")
9495

95-
if store is not None and rollout_id is not None and attempt_id is not None:
96+
if store is not None:
97+
warnings.warn(
98+
"store is deprecated in favor of init_worker(). It will be removed in the future.",
99+
DeprecationWarning,
100+
stacklevel=3,
101+
)
102+
else:
103+
store = self._store
104+
105+
if rollout_id is not None and attempt_id is not None:
106+
if store is None:
107+
raise ValueError("store is required to be initialized when rollout_id and attempt_id are provided")
96108
if store.capabilities.get("otlp_traces", False) is True:
97109
logger.debug(f"Tracing to LightningStore rollout_id={rollout_id}, attempt_id={attempt_id}")
98110
self._enable_native_otlp_exporter(store, rollout_id, attempt_id)
@@ -101,12 +113,12 @@ async def trace_context(
101113
ctx = self._lightning_span_processor.with_context(store=store, rollout_id=rollout_id, attempt_id=attempt_id)
102114
with ctx:
103115
yield trace_api.get_tracer(__name__, tracer_provider=self._tracer_provider)
104-
elif store is None and rollout_id is None and attempt_id is None:
116+
elif rollout_id is None and attempt_id is None:
105117
self._disable_native_otlp_exporter()
106118
with self._lightning_span_processor:
107119
yield trace_api.get_tracer(__name__, tracer_provider=self._tracer_provider)
108120
else:
109-
raise ValueError("store, rollout_id, and attempt_id must be either all provided or all None")
121+
raise ValueError("rollout_id and attempt_id must be either all provided or all None")
110122

111123
def get_last_trace(self) -> List[ReadableSpan]:
112124
"""

examples/minimal/write_traces.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ async def send_traces_via_otel(use_client: bool = False):
3434
store = LightningStoreClient("http://localhost:45993")
3535
rollout = await store.start_rollout(input={"origin": "write_traces_example"})
3636

37-
with tracer.lifespan():
37+
with tracer.lifespan(store):
3838
# Initialize the capture of one single trace for one single rollout
3939
async with tracer.trace_context(
4040
"trace-manual", store=store, rollout_id=rollout.rollout_id, attempt_id=rollout.attempt.attempt_id
@@ -89,10 +89,10 @@ async def send_traces_via_agentops(use_client: bool = False):
8989

9090
# Initialize the tracer lifespan
9191
# One lifespan can contain multiple traces
92-
with tracer.lifespan():
92+
with tracer.lifespan(store):
9393
# Initialize the capture of one single trace for one single rollout
9494
async with tracer.trace_context(
95-
"trace-1", store=store, rollout_id=rollout.rollout_id, attempt_id=rollout.attempt.attempt_id
95+
"trace-1", rollout_id=rollout.rollout_id, attempt_id=rollout.attempt.attempt_id
9696
):
9797
openai_client = AsyncOpenAI()
9898
response = await openai_client.chat.completions.create(

examples/tinker/tests/test_tinker_llm.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ async def test_tracer():
5656
try:
5757
tracer = AgentOpsTracer()
5858
tracer.init()
59-
tracer.init_worker(0)
59+
tracer.init_worker(worker_id=0, store=store)
6060

6161
# init tracer before llm_proxy to avoid tracer provider being not active.
6262
console.print("Starting LLM proxy...")
@@ -70,7 +70,7 @@ async def test_tracer():
7070
client = openai.OpenAI(base_url="http://localhost:4000/v1", api_key="dummy")
7171

7272
async with tracer.trace_context(
73-
name="test_llm", store=store, rollout_id=rollout.rollout_id, attempt_id=rollout.attempt.attempt_id
73+
name="test_llm", rollout_id=rollout.rollout_id, attempt_id=rollout.attempt.attempt_id
7474
):
7575
response = client.chat.completions.create(
7676
model=model_name,

tests/tracer/test_agentops.py

Lines changed: 24 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,11 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

33
import multiprocessing
4+
import sys
45
from typing import Any, Optional, Union
56

67
import agentops
8+
import pytest
79
from agentops.sdk.core import TraceContext
810
from opentelemetry.trace.status import StatusCode
911

@@ -20,7 +22,8 @@ def _func_without_exception():
2022
pass
2123

2224

23-
def test_trace_error_status_from_instance():
25+
@pytest.mark.parametrize("with_exception", [True, False])
26+
def test_trace_error_status_from_instance(with_exception: bool):
2427
"""
2528
Test that AgentOpsTracer correctly sets trace end state based on execution result.
2629
@@ -30,7 +33,7 @@ def test_trace_error_status_from_instance():
3033
"""
3134

3235
ctx = multiprocessing.get_context("spawn")
33-
proc = ctx.Process(target=_test_trace_error_status_from_instance_imp)
36+
proc = ctx.Process(target=_test_trace_error_status_from_instance_imp, args=(with_exception,))
3437
proc.start()
3538
proc.join(30.0) # On GPU server, the time is around 10 seconds.
3639

@@ -42,13 +45,19 @@ def test_trace_error_status_from_instance():
4245

4346
assert False, "Child process hung. Check test output for details."
4447

45-
assert proc.exitcode == 0, (
46-
f"Child process for test_trace_error_status_from_instance failed with exit code {proc.exitcode}. "
47-
"Check child traceback in test output."
48-
)
48+
if with_exception:
49+
assert proc.exitcode != 0, (
50+
f"Child process for test_trace_error_status_from_instance with exception exited with exit code {proc.exitcode}. "
51+
"Check child traceback in test output."
52+
)
53+
else:
54+
assert proc.exitcode == 0, (
55+
f"Child process for test_trace_error_status_from_instance without exception failed with exit code {proc.exitcode}. "
56+
"Check child traceback in test output."
57+
)
4958

5059

51-
def _test_trace_error_status_from_instance_imp():
60+
def _test_trace_error_status_from_instance_imp(with_exception: bool):
5261
captured_state = {}
5362
old_end_trace = agentops.end_trace
5463

@@ -65,12 +74,14 @@ def custom_end_trace(
6574
tracer.init_worker(0)
6675

6776
try:
68-
tracer.trace_run(_func_with_exception)
69-
assert captured_state["state"] == StatusCode.ERROR
70-
71-
tracer.trace_run(_func_without_exception)
72-
assert captured_state["state"] == StatusCode.OK
73-
77+
if with_exception:
78+
tracer.trace_run(_func_with_exception)
79+
if captured_state["state"] != StatusCode.ERROR:
80+
sys.exit(-1)
81+
else:
82+
tracer.trace_run(_func_without_exception)
83+
if captured_state["state"] != StatusCode.OK:
84+
sys.exit(-1)
7485
finally:
7586
agentops.end_trace = old_end_trace
7687
tracer.teardown_worker(0)

0 commit comments

Comments
 (0)