Skip to content

Commit f790f41

Browse files
authored
feat(elevenlabs): report STT audio duration via RECOGNITION_USAGE events (#4953)
1 parent 63ac13e commit f790f41

File tree

2 files changed

+57
-0
lines changed
  • livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs

2 files changed

+57
-0
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import time
2+
from collections.abc import Callable
3+
from typing import Generic, TypeVar
4+
5+
T = TypeVar("T")


class PeriodicCollector(Generic[T]):
    """Accumulate pushed values and report the running total at most once per interval.

    Values are combined with in-place ``+=``; the callback fires from ``push``
    once the configured interval has elapsed since the last report, or
    immediately when ``flush`` is called and there is something to report.
    """

    def __init__(self, callback: Callable[[T], None], *, duration: float) -> None:
        """Initialize the collector.

        Args:
            callback: Invoked with the accumulated total when a report is due.
            duration: Minimum number of seconds between callback invocations.
        """
        self._callback = callback
        self._duration = duration
        # Monotonic clock so interval checks are immune to wall-clock jumps.
        self._last_flush_time = time.monotonic()
        self._total: Optional[T] = None

    def push(self, value: T) -> None:
        """Fold *value* into the running total; report if the interval elapsed."""
        current = self._total
        if current is None:
            self._total = value
        else:
            # In-place add keeps the original accumulation semantics for
            # mutable totals as well as plain numbers.
            current += value  # type: ignore
            self._total = current
        elapsed = time.monotonic() - self._last_flush_time
        if elapsed >= self._duration:
            self.flush()

    def flush(self) -> None:
        """Report the accumulated total, if any, and restart the interval timer."""
        if self._total is not None:
            self._callback(self._total)
            self._total = None
        # Timer resets even when there was nothing to report.
        self._last_flush_time = time.monotonic()

livekit-plugins/livekit-plugins-elevenlabs/livekit/plugins/elevenlabs/stt.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from livekit.agents.utils import AudioBuffer, http_context, is_given
4141
from livekit.agents.voice.io import TimedString
4242

43+
from ._utils import PeriodicCollector
4344
from .log import logger
4445
from .models import STTRealtimeSampleRates
4546

@@ -327,6 +328,10 @@ def __init__(
327328
self._session = http_session
328329
self._reconnect_event = asyncio.Event()
329330
self._speaking = False # Track if we're currently in a speech segment
331+
self._audio_duration_collector = PeriodicCollector(
332+
callback=self._on_audio_duration_report,
333+
duration=5.0,
334+
)
330335

331336
def update_options(
332337
self,
@@ -337,6 +342,14 @@ def update_options(
337342
self._opts.server_vad = server_vad
338343
self._reconnect_event.set()
339344

345+
    def _on_audio_duration_report(self, duration: float) -> None:
        """Emit a RECOGNITION_USAGE speech event for *duration* seconds of audio.

        Used as the callback of the periodic audio-duration collector; wraps
        the accumulated duration in a ``stt.SpeechEvent`` and pushes it onto
        the stream's event channel without blocking.

        Args:
            duration: Accumulated audio duration, in seconds, since the last report.
        """
        usage_event = stt.SpeechEvent(
            type=stt.SpeechEventType.RECOGNITION_USAGE,
            alternatives=[],
            recognition_usage=stt.RecognitionUsage(audio_duration=duration),
        )
        self._event_ch.send_nowait(usage_event)
352+
340353
async def _run(self) -> None:
341354
"""Run the streaming transcription session"""
342355
closing_ws = False
@@ -361,15 +374,18 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
361374
samples_per_channel=samples_50ms,
362375
)
363376

377+
has_ended = False
364378
async for data in self._input_ch:
365379
# Write audio bytes to buffer and get 50ms frames
366380
frames: list[rtc.AudioFrame] = []
367381
if isinstance(data, rtc.AudioFrame):
368382
frames.extend(audio_bstream.write(data.data.tobytes()))
369383
elif isinstance(data, self._FlushSentinel):
370384
frames.extend(audio_bstream.flush())
385+
has_ended = True
371386

372387
for frame in frames:
388+
self._audio_duration_collector.push(frame.duration)
373389
audio_b64 = base64.b64encode(frame.data.tobytes()).decode("utf-8")
374390
await ws.send_str(
375391
json.dumps(
@@ -382,6 +398,10 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
382398
)
383399
)
384400

401+
if has_ended:
402+
self._audio_duration_collector.flush()
403+
has_ended = False
404+
385405
closing_ws = True
386406

387407
@utils.log_exceptions(logger=logger)

0 commit comments

Comments
 (0)