Skip to content

Commit 2b046a3

Browse files
authored
fix(amd): reset timer for late stt transcript (#5637)
1 parent 706814e commit 2b046a3

3 files changed

Lines changed: 285 additions & 0 deletions

File tree

livekit-agents/livekit/agents/voice/amd/classifier.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def __init__(
126126
self._classify_task: asyncio.Task[None] | None = None
127127
self._no_speech_timer: asyncio.TimerHandle | None = None
128128
self._silence_timer: asyncio.TimerHandle | None = None
129+
self._silence_timer_trigger: Literal["short_speech", "long_speech"] | None = None
129130
self._detection_timeout_timer: asyncio.TimerHandle | None = None
130131

131132
self._verdict_result: AMDResult | None = None
@@ -174,6 +175,7 @@ def on_user_speech_started(self) -> None:
174175
if self._silence_timer is not None:
175176
self._silence_timer.cancel()
176177
self._silence_timer = None
178+
self._silence_timer_trigger = None
177179
if self._no_speech_timer is not None:
178180
self._no_speech_timer.cancel()
179181
self._no_speech_timer = None
@@ -193,6 +195,7 @@ def on_user_speech_ended(self, silence_duration: float) -> None:
193195
if self._silence_timer is not None:
194196
self._silence_timer.cancel()
195197
self._silence_timer = None
198+
self._silence_timer_trigger = None
196199
if not self._transcript:
197200
self._silence_timer = asyncio.get_running_loop().call_later(
198201
max(0, self._human_silence_threshold - silence_duration),
@@ -203,6 +206,7 @@ def on_user_speech_ended(self, silence_duration: float) -> None:
203206
speech_duration=speech_duration,
204207
),
205208
)
209+
self._silence_timer_trigger = "short_speech"
206210
else:
207211
self._silence_timer = asyncio.get_running_loop().call_later(
208212
max(0, self._machine_silence_threshold - silence_duration),
@@ -211,6 +215,7 @@ def on_user_speech_ended(self, silence_duration: float) -> None:
211215
speech_duration=speech_duration,
212216
),
213217
)
218+
self._silence_timer_trigger = "long_speech"
214219
return
215220

216221
if self._classify_task is None:
@@ -219,10 +224,12 @@ def on_user_speech_ended(self, silence_duration: float) -> None:
219224
if self._silence_timer is not None:
220225
self._silence_timer.cancel()
221226
self._silence_timer = None
227+
self._silence_timer_trigger = None
222228
self._silence_timer = asyncio.get_running_loop().call_later(
223229
max(0, self._machine_silence_threshold - silence_duration),
224230
functools.partial(self._silence_timer_callback, speech_duration=speech_duration),
225231
)
232+
self._silence_timer_trigger = "long_speech"
226233

227234
def _set_verdict(self, result: AMDResult) -> None:
228235
self._verdict_result = result
@@ -250,6 +257,11 @@ def _silence_timer_callback(
250257
reason: NotGivenOr[str] = NOT_GIVEN,
251258
speech_duration: float | None = None,
252259
) -> None:
260+
if self._silence_timer:
261+
self._silence_timer.cancel()
262+
self._silence_timer = None
263+
self._silence_timer_trigger = None
264+
253265
if is_given(category) and is_given(reason) and self._verdict_result is None:
254266
self._set_verdict(
255267
AMDResult(
@@ -274,6 +286,23 @@ def push_text(self, text: str, source: str = "stt") -> None:
274286
if source != self._source:
275287
return
276288

289+
if self._silence_timer is not None and self._silence_timer_trigger == "short_speech":
290+
self._silence_timer.cancel()
291+
self._silence_timer = None
292+
self._silence_timer_trigger = None
293+
294+
# invariant: trigger == "short_speech" implies on_user_speech_ended ran
295+
assert self._speech_ended_at is not None
296+
remaining = (self._speech_ended_at + self._machine_silence_threshold) - time.time()
297+
self._silence_timer = asyncio.get_running_loop().call_later(
298+
max(0, remaining),
299+
functools.partial(
300+
self._silence_timer_callback,
301+
speech_duration=self.speech_duration,
302+
),
303+
)
304+
self._silence_timer_trigger = "long_speech"
305+
277306
if self._classify_task is None:
278307
self._classify_task = asyncio.create_task(self._classify_user_speech())
279308
if self._no_speech_timer is not None:
@@ -316,6 +345,9 @@ async def postpone_termination(seconds: float) -> str:
316345
self._extension_count += 1
317346
if self._silence_timer is not None:
318347
self._silence_timer.cancel()
348+
self._silence_timer = None
349+
self._silence_timer_trigger = None
350+
319351
loop = asyncio.get_running_loop()
320352

321353
def _on_postpone_elapsed() -> None:
@@ -332,6 +364,7 @@ def _on_postpone_elapsed() -> None:
332364
self._try_emit_result()
333365

334366
self._silence_timer = loop.call_later(clamped, _on_postpone_elapsed)
367+
self._silence_timer_trigger = "long_speech"
335368
return f"waiting {clamped:.1f}s for more audio"
336369

337370
@log_exceptions(logger=logger)
@@ -378,6 +411,7 @@ async def close(self) -> None:
378411
if self._silence_timer is not None:
379412
self._silence_timer.cancel()
380413
self._silence_timer = None
414+
self._silence_timer_trigger = None
381415
if self._detection_timeout_timer is not None:
382416
self._detection_timeout_timer.cancel()
383417
self._detection_timeout_timer = None

makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ unit-tests:
116116
tests/test_tool_search.py \
117117
tests/test_tool_proxy.py \
118118
tests/test_endpointing.py \
119+
tests/test_amd_classifier.py \
119120
tests/test_session_host.py
120121

121122
# ============================================

tests/test_amd_classifier.py

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
"""Tests for the AMD classifier silence-timer state machine.
2+
3+
Focuses on the trigger-tagged silence timer logic: pre-baked HUMAN timers for
4+
short greetings can be cancelled and replaced when a transcript arrives, while
5+
long-speech timers (and postpone-termination timers) are left alone.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import asyncio
11+
import time
12+
13+
from livekit.agents.llm import FunctionToolCall
14+
from livekit.agents.voice.amd.classifier import (
15+
AMDCategory,
16+
AMDResult,
17+
_AMDClassifier,
18+
)
19+
20+
from .fake_llm import FakeLLM, FakeLLMResponse
21+
22+
23+
def _make_classifier(
24+
llm: FakeLLM | None = None,
25+
*,
26+
human_speech_threshold: float = 2.5,
27+
human_silence_threshold: float = 0.1,
28+
machine_silence_threshold: float = 0.3,
29+
no_speech_threshold: float = 10.0,
30+
timeout: float = 10.0,
31+
) -> _AMDClassifier:
32+
return _AMDClassifier(
33+
llm or FakeLLM(),
34+
human_speech_threshold=human_speech_threshold,
35+
human_silence_threshold=human_silence_threshold,
36+
machine_silence_threshold=machine_silence_threshold,
37+
no_speech_threshold=no_speech_threshold,
38+
timeout=timeout,
39+
)
40+
41+
42+
class TestAMDClassifier:
43+
"""Tests for ``_AMDClassifier`` silence-timer behaviour."""
44+
45+
async def test_short_greeting_no_transcript_emits_pre_baked_human(self) -> None:
46+
"""Short utterance + no STT text => HUMAN/short_greeting verdict."""
47+
clf = _make_classifier(human_silence_threshold=0.1)
48+
clf.start()
49+
results: list[AMDResult] = []
50+
clf.on("amd_result", results.append)
51+
52+
clf.on_user_speech_started()
53+
await asyncio.sleep(0.05)
54+
clf.on_user_speech_ended(silence_duration=0.0)
55+
assert clf._silence_timer_trigger == "short_speech"
56+
assert clf._silence_timer is not None
57+
58+
await asyncio.sleep(0.2)
59+
60+
assert len(results) == 1
61+
assert results[0].category == AMDCategory.HUMAN
62+
assert results[0].reason == "short_greeting"
63+
assert clf._silence_timer is None
64+
assert clf._silence_timer_trigger is None
65+
assert clf._machine_silence_reached is True
66+
67+
await clf.close()
68+
69+
async def test_push_text_cancels_pre_baked_human_and_flips_trigger(self) -> None:
70+
"""A transcript arriving during the short_speech window must cancel the
71+
pre-baked HUMAN timer and replace it with a long_speech timer anchored at
72+
speech_ended + machine_silence_threshold."""
73+
clf = _make_classifier(human_silence_threshold=0.1, machine_silence_threshold=0.3)
74+
clf.start()
75+
results: list[AMDResult] = []
76+
clf.on("amd_result", results.append)
77+
78+
clf.on_user_speech_started()
79+
await asyncio.sleep(0.05)
80+
clf.on_user_speech_ended(silence_duration=0.0)
81+
assert clf._silence_timer_trigger == "short_speech"
82+
83+
clf.push_text("hello")
84+
assert clf._silence_timer_trigger == "long_speech"
85+
assert clf._silence_timer is not None
86+
87+
# Past the would-be HUMAN deadline (0.1s), well before machine deadline (0.3s).
88+
await asyncio.sleep(0.18)
89+
assert results == [], "pre-baked HUMAN must not fire after a transcript arrives"
90+
assert clf._machine_silence_reached is False
91+
92+
# Past the machine_silence deadline.
93+
await asyncio.sleep(0.2)
94+
assert clf._machine_silence_reached is True
95+
# No verdict was provided by the (empty) FakeLLM, so nothing emits yet.
96+
assert results == []
97+
98+
await clf.close()
99+
100+
async def test_push_text_replacement_timer_preserves_original_deadline(self) -> None:
101+
"""The replacement timer fires near speech_ended + machine_silence_threshold,
102+
not push_text + machine_silence_threshold."""
103+
clf = _make_classifier(human_silence_threshold=0.05, machine_silence_threshold=0.3)
104+
clf.start()
105+
106+
clf.on_user_speech_started()
107+
await asyncio.sleep(0.05)
108+
clf.on_user_speech_ended(silence_duration=0.0)
109+
t_end = clf._speech_ended_at
110+
assert t_end is not None
111+
112+
push_delay = 0.04 # under human_silence_threshold so trigger is still short_speech
113+
await asyncio.sleep(push_delay)
114+
clf.push_text("hello")
115+
assert clf._silence_timer_trigger == "long_speech"
116+
117+
expected_fire = t_end + 0.3
118+
deadline = expected_fire + 0.3
119+
while not clf._machine_silence_reached and time.time() < deadline:
120+
await asyncio.sleep(0.01)
121+
122+
fired_at = time.time()
123+
assert clf._machine_silence_reached
124+
# Allow generous slack for event-loop jitter. NOTE: with push_delay = 0.04 a
125+
# push_text-anchored timer would fire at ~0.34s, which is still inside the 0.45s
126+
# bound below, so this is a coarse upper bound rather than proof of the anchoring.
127+
assert fired_at - t_end < 0.3 + 0.15, (
128+
f"timer fired at {fired_at - t_end:.3f}s after t_end; "
129+
f"expected ~0.30s, never ~0.34s+ (push_text-anchored)"
130+
)
131+
132+
await clf.close()
133+
134+
async def test_long_speech_push_text_does_not_replace_timer(self) -> None:
135+
"""During the long_speech timer, push_text must leave the existing timer
136+
handle intact so the original machine-silence deadline (0.3s in this test) is not extended."""
137+
clf = _make_classifier(
138+
human_speech_threshold=0.1,
139+
machine_silence_threshold=0.3,
140+
)
141+
clf.start()
142+
143+
clf.on_user_speech_started()
144+
await asyncio.sleep(0.15)
145+
clf.on_user_speech_ended(silence_duration=0.0)
146+
assert clf._silence_timer_trigger == "long_speech"
147+
handle_before = clf._silence_timer
148+
assert handle_before is not None
149+
150+
clf.push_text("hello world")
151+
assert clf._silence_timer_trigger == "long_speech"
152+
assert clf._silence_timer is handle_before
153+
154+
await clf.close()
155+
156+
async def test_short_greeting_with_existing_transcript_uses_long_speech_trigger(
157+
self,
158+
) -> None:
159+
"""If a transcript is already present when speech ends (push_text before
160+
on_user_speech_ended), the short branch picks the long_speech trigger."""
161+
clf = _make_classifier(human_silence_threshold=0.1, machine_silence_threshold=0.3)
162+
clf.start()
163+
164+
clf.on_user_speech_started()
165+
await asyncio.sleep(0.05)
166+
clf.push_text("hi")
167+
clf.on_user_speech_ended(silence_duration=0.0)
168+
assert clf._silence_timer_trigger == "long_speech"
169+
handle_before = clf._silence_timer
170+
assert handle_before is not None
171+
172+
# A second transcript while in the long_speech window must not replace the timer.
173+
clf.push_text("there")
174+
assert clf._silence_timer is handle_before
175+
assert clf._silence_timer_trigger == "long_speech"
176+
177+
await clf.close()
178+
179+
async def test_on_user_speech_started_clears_trigger(self) -> None:
180+
"""on_user_speech_started cancels the silence timer and nulls the trigger."""
181+
clf = _make_classifier(human_silence_threshold=1.0)
182+
clf.start()
183+
184+
clf.on_user_speech_started()
185+
await asyncio.sleep(0.05)
186+
clf.on_user_speech_ended(silence_duration=0.0)
187+
assert clf._silence_timer is not None
188+
assert clf._silence_timer_trigger == "short_speech"
189+
190+
clf.on_user_speech_started()
191+
assert clf._silence_timer is None
192+
assert clf._silence_timer_trigger is None
193+
194+
await clf.close()
195+
196+
async def test_silence_callback_clears_trigger_on_fire(self) -> None:
197+
"""When the silence timer fires, both handle and trigger are nulled out."""
198+
clf = _make_classifier(human_silence_threshold=0.05)
199+
clf.start()
200+
201+
clf.on_user_speech_started()
202+
await asyncio.sleep(0.02)
203+
clf.on_user_speech_ended(silence_duration=0.0)
204+
assert clf._silence_timer_trigger == "short_speech"
205+
206+
await asyncio.sleep(0.12)
207+
208+
assert clf._silence_timer is None
209+
assert clf._silence_timer_trigger is None
210+
211+
await clf.close()
212+
213+
async def test_short_greeting_transcript_emits_llm_verdict(self) -> None:
214+
"""End-to-end: short greeting + transcript => LLM verdict emits at the
215+
machine_silence deadline (gated on both verdict and machine_silence_reached)."""
216+
llm = FakeLLM(
217+
fake_responses=[
218+
FakeLLMResponse(
219+
input="hello",
220+
content="",
221+
ttft=0.0,
222+
duration=0.05,
223+
tool_calls=[
224+
FunctionToolCall(
225+
name="save_prediction",
226+
arguments='{"label": "human"}',
227+
call_id="c1",
228+
)
229+
],
230+
)
231+
]
232+
)
233+
clf = _make_classifier(llm=llm, human_silence_threshold=0.1, machine_silence_threshold=0.3)
234+
clf.start()
235+
results: list[AMDResult] = []
236+
clf.on("amd_result", results.append)
237+
238+
clf.on_user_speech_started()
239+
await asyncio.sleep(0.05)
240+
clf.on_user_speech_ended(silence_duration=0.0)
241+
clf.push_text("hello")
242+
243+
await asyncio.wait_for(clf._verdict_ready.wait(), timeout=2.0)
244+
245+
assert len(results) == 1
246+
assert results[0].category == AMDCategory.HUMAN
247+
assert results[0].reason == "llm"
248+
assert results[0].transcript == "hello"
249+
250+
await clf.close()

0 commit comments

Comments
 (0)