Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import time
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")


class PeriodicCollector(Generic[T]):
    """Accumulate pushed values and periodically report the running total.

    Values are combined with ``+=``, so ``T`` must support in-place addition
    (e.g. ``float``, ``int``, ``str``, ``list``). The callback fires at most
    once per ``duration`` seconds, and only when at least one value has been
    pushed since the last flush.
    """

    def __init__(self, callback: Callable[[T], None], *, duration: float) -> None:
        """Create a collector that reports the accumulated total periodically.

        Args:
            callback: Function invoked with the accumulated total when the
                duration expires (or on an explicit ``flush()``).
            duration: Minimum time in seconds between callback invocations.
        """
        self._duration = duration
        self._callback = callback
        # Monotonic timestamp of the last report — immune to wall-clock jumps.
        self._last_flush_time = time.monotonic()
        # Running total; None means "nothing accumulated since the last flush".
        self._total: T | None = None

    def push(self, value: T) -> None:
        """Add a value to the accumulator, flushing if the duration has elapsed."""
        if self._total is None:
            self._total = value
        else:
            self._total += value  # type: ignore[operator]
        if time.monotonic() - self._last_flush_time >= self._duration:
            self.flush()

    def flush(self) -> None:
        """Invoke the callback with the current total, if any, then reset.

        The callback is called whenever a value has been pushed since the last
        flush (i.e. the total is non-``None`` — a pushed zero IS reported).
        The flush timestamp is reset even when there is nothing to report, so
        the next report is at least ``duration`` seconds away.
        """
        if self._total is not None:
            self._callback(self._total)
            self._total = None
        self._last_flush_time = time.monotonic()
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from livekit.agents.utils import AudioBuffer, http_context, is_given
from livekit.agents.voice.io import TimedString

from ._utils import PeriodicCollector
from .log import logger
from .models import STTRealtimeSampleRates

Expand Down Expand Up @@ -313,6 +314,10 @@ def __init__(
self._session = http_session
self._reconnect_event = asyncio.Event()
self._speaking = False # Track if we're currently in a speech segment
self._audio_duration_collector = PeriodicCollector(
callback=self._on_audio_duration_report,
duration=5.0,
)

def update_options(
self,
Expand All @@ -323,6 +328,14 @@ def update_options(
self._opts.server_vad = server_vad
self._reconnect_event.set()

def _on_audio_duration_report(self, duration: float) -> None:
    """Emit a RECOGNITION_USAGE speech event carrying *duration* seconds of audio."""
    self._event_ch.send_nowait(
        stt.SpeechEvent(
            type=stt.SpeechEventType.RECOGNITION_USAGE,
            alternatives=[],
            recognition_usage=stt.RecognitionUsage(audio_duration=duration),
        )
    )

async def _run(self) -> None:
"""Run the streaming transcription session"""
closing_ws = False
Expand All @@ -347,15 +360,18 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
samples_per_channel=samples_50ms,
)

has_ended = False
async for data in self._input_ch:
# Write audio bytes to buffer and get 50ms frames
frames: list[rtc.AudioFrame] = []
if isinstance(data, rtc.AudioFrame):
frames.extend(audio_bstream.write(data.data.tobytes()))
elif isinstance(data, self._FlushSentinel):
frames.extend(audio_bstream.flush())
has_ended = True

for frame in frames:
self._audio_duration_collector.push(frame.duration)
audio_b64 = base64.b64encode(frame.data.tobytes()).decode("utf-8")
await ws.send_str(
json.dumps(
Expand All @@ -368,6 +384,10 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
)
)

if has_ended:
self._audio_duration_collector.flush()
has_ended = False

closing_ws = True

@utils.log_exceptions(logger=logger)
Expand Down