
Commit fc83be5

add VibeVoice-Realtime
1 parent e81395c commit fc83be5

39 files changed (+8190, -6 lines)

Figures/VibeVoice_Realtime.png

121 KB

README.md

Lines changed: 7 additions & 6 deletions
@@ -1,6 +1,6 @@
 <div align="center">
 
-## 🎙️ VibeVoice: A Frontier Open-Source Voice AI
+## 🎙️ VibeVoice: Frontier Open-Source Voice AI
 [![Project Page](https://img.shields.io/badge/Project-Page-blue?logo=microsoft)](https://microsoft.github.io/VibeVoice)
 [![Hugging Face](https://img.shields.io/badge/HuggingFace-Collection-orange?logo=huggingface)](https://huggingface.co/collections/microsoft/vibevoice-68a2ef24a875c44be47b034f)
 [![Technical Report](https://img.shields.io/badge/Technical-Report-red?logo=adobeacrobatreader)](https://arxiv.org/pdf/2508.19205)

@@ -23,11 +23,12 @@
 <img src="https://img.shields.io/badge/Status-New-brightgreen?style=flat" alt="New" />
 <img src="https://img.shields.io/badge/Feature-Realtime_TTS-blue?style=flat&logo=soundcharts" alt="Realtime TTS" />
 
-<strong>2025-12-03: 📣 We open-sourced <strong>VibeVoice‑Realtime‑0.5B</strong>, a real‑time text‑to‑speech model that supports streaming text input.</strong>
+<strong>2025-12-03: 📣 We open-sourced <a href="docs/vibevoice-realtime-0.5b.md"><strong>VibeVoice‑Realtime‑0.5B</strong></a>, a real‑time text‑to‑speech model that supports streaming text input and robust long-form speech generation.</strong>
 <br>
-<a href="https://github.com/user-attachments/assets/c4fb9be1-e721-41c7-9260-5890b49c1a19" target="_blank">▶️ Watch demo video</a>
-&nbsp;&nbsp;
-<a href="https://github.com/user-attachments/assets/9aa8ab3c-681d-4a02-b9ea-3f54ffd180b2" target="_blank">🎧 Listen to generated example</a>
+
+https://github.com/user-attachments/assets/0901d274-f6ae-46ef-a0fd-3c4fba4f76dc
+
+> (Launch your own realtime demo via the websocket example in [Usage](docs/vibevoice-realtime-0.5b.md#usage-1-launch-real-time-websocket-demo)).
 
 </div>

@@ -41,7 +42,7 @@ VibeVoice is a novel framework designed for generating **expressive**, **long-form**, **multi-speaker** conversational audio, such as podcasts, from text.
 VibeVoice currently includes two model variants:
 
 - **Long-form multi-speaker model**: Synthesizes conversational/single-speaker speech up to **90 minutes** with up to **4 distinct speakers**, surpassing the typical 1–2 speaker limits of many prior models.
-- **Realtime streaming TTS model**: Produces initial audible speech in ~**300 ms** and supports **streaming text input** for single-speaker **realtime** speech generation; designed for low-latency generation.
+- **[Realtime streaming TTS model](docs/vibevoice-realtime-0.5b.md)**: Produces initial audible speech in ~**300 ms** and supports **streaming text input** for single-speaker **real-time** speech generation; designed for low-latency generation.
 
 A core innovation of VibeVoice is its use of continuous speech tokenizers (Acoustic and Semantic) operating at an ultra-low frame rate of 7.5 Hz. These tokenizers efficiently preserve audio fidelity while significantly boosting computational efficiency for processing long sequences. VibeVoice employs a [next-token diffusion](https://arxiv.org/abs/2412.08635) framework, leveraging a Large Language Model (LLM) to understand textual context and dialogue flow, and a diffusion head to generate high-fidelity acoustic details.
 
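As a rough sanity check of the figures quoted above (7.5 Hz tokenizer frame rate, 90-minute audio, 64K context window), the acoustic-frame budget works out as follows. This sketch is illustrative only and is not part of the commit:

# Back-of-the-envelope check of the numbers in the README hunk above
# (7.5 Hz tokenizer frame rate, 90-minute audio, 64K context window).
frame_rate_hz = 7.5
minutes = 90

acoustic_frames = frame_rate_hz * minutes * 60      # 7.5 * 90 * 60 = 40,500 frames
print(f"{acoustic_frames:,.0f} speech frames")       # ~40.5K, comfortably under a 64K context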

Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
import argparse
import os
import re
import traceback
from typing import List, Tuple, Union, Dict, Any
import time
import torch
import copy

from vibevoice.modular.modeling_vibevoice_streaming_inference import VibeVoiceStreamingForConditionalGenerationInference
from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor
from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger(__name__)


class VoiceMapper:
    """Maps speaker names to voice file paths"""

    def __init__(self):
        self.setup_voice_presets()

        # change name according to our preset voice file
        new_dict = {}
        for name, path in self.voice_presets.items():
            if '_' in name:
                name = name.split('_')[0]
            if '-' in name:
                name = name.split('-')[-1]
            new_dict[name] = path
        self.voice_presets.update(new_dict)
        # print(list(self.voice_presets.keys()))

    def setup_voice_presets(self):
        """Setup voice presets by scanning the voices directory."""
        voices_dir = os.path.join(os.path.dirname(__file__), "voices/streaming_model")

        # Check if voices directory exists
        if not os.path.exists(voices_dir):
            print(f"Warning: Voices directory not found at {voices_dir}")
            self.voice_presets = {}
            self.available_voices = {}
            return

        # Scan for all voice files in the voices directory
        self.voice_presets = {}

        # Get all .pt files in the voices directory
        pt_files = [f for f in os.listdir(voices_dir)
                    if f.lower().endswith('.pt') and os.path.isfile(os.path.join(voices_dir, f))]

        # Create dictionary with filename (without extension) as key
        for pt_file in pt_files:
            # Remove .pt extension to get the name
            name = os.path.splitext(pt_file)[0]
            # Create full path
            full_path = os.path.join(voices_dir, pt_file)
            self.voice_presets[name] = full_path

        # Sort the voice presets alphabetically by name for better UI
        self.voice_presets = dict(sorted(self.voice_presets.items()))

        # Filter out voices that don't exist (this is now redundant but kept for safety)
        self.available_voices = {
            name: path for name, path in self.voice_presets.items()
            if os.path.exists(path)
        }

        print(f"Found {len(self.available_voices)} voice files in {voices_dir}")
        print(f"Available voices: {', '.join(self.available_voices.keys())}")

    def get_voice_path(self, speaker_name: str) -> str:
        """Get voice file path for a given speaker name"""
        # First try exact match
        if speaker_name in self.voice_presets:
            return self.voice_presets[speaker_name]

        # Try partial matching (case insensitive)
        speaker_lower = speaker_name.lower()
        for preset_name, path in self.voice_presets.items():
            if preset_name.lower() in speaker_lower or speaker_lower in preset_name.lower():
                return path

        # Default to first voice if no match found
        default_voice = list(self.voice_presets.values())[0]
        print(f"Warning: No voice preset found for '{speaker_name}', using default voice: {default_voice}")
        return default_voice


def parse_args():
    parser = argparse.ArgumentParser(description="VibeVoiceStreaming Processor TXT Input Test")
    parser.add_argument(
        "--model_path",
        type=str,
        default="microsoft/VibeVoice-Realtime-0.5B",
        help="Path to the HuggingFace model directory",
    )
    parser.add_argument(
        "--txt_path",
        type=str,
        default="demo/text_examples/1p_vibevoice.txt",
        help="Path to the txt file containing the script",
    )
    parser.add_argument(
        "--speaker_name",
        type=str,
        default="Wayne",
        help="Single speaker name (e.g., --speaker_name Wayne)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./outputs",
        help="Directory to save output audio files",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
        help="Device for inference: cuda | mps | cpu",
    )
    parser.add_argument(
        "--cfg_scale",
        type=float,
        default=1.5,
        help="CFG (Classifier-Free Guidance) scale for generation (default: 1.5)",
    )

    return parser.parse_args()


def main():
    args = parse_args()

    # Normalize potential 'mpx' typo to 'mps'
    if args.device.lower() == "mpx":
        print("Note: device 'mpx' detected, treating it as 'mps'.")
        args.device = "mps"

    # Validate mps availability if requested
    if args.device == "mps" and not torch.backends.mps.is_available():
        print("Warning: MPS not available. Falling back to CPU.")
        args.device = "cpu"

    print(f"Using device: {args.device}")

    # Initialize voice mapper
    voice_mapper = VoiceMapper()

    # Check if txt file exists
    if not os.path.exists(args.txt_path):
        print(f"Error: txt file not found: {args.txt_path}")
        return

    # Read and parse txt file
    print(f"Reading script from: {args.txt_path}")
    with open(args.txt_path, 'r', encoding='utf-8') as f:
        scripts = f.read().strip()

    if not scripts:
        print("Error: No valid scripts found in the txt file")
        return

    full_script = scripts.replace("’", "'").replace('“', '"').replace('”', '"')

    print(f"Loading processor & model from {args.model_path}")
    processor = VibeVoiceStreamingProcessor.from_pretrained(args.model_path)

    # Decide dtype & attention implementation
    if args.device == "mps":
        load_dtype = torch.float32  # MPS requires float32
        attn_impl_primary = "sdpa"  # flash_attention_2 not supported on MPS
    elif args.device == "cuda":
        load_dtype = torch.bfloat16
        attn_impl_primary = "flash_attention_2"
    else:  # cpu
        load_dtype = torch.float32
        attn_impl_primary = "sdpa"
    print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

    # Load model with device-specific logic
    try:
        if args.device == "mps":
            model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                args.model_path,
                torch_dtype=load_dtype,
                attn_implementation=attn_impl_primary,
                device_map=None,  # load then move
            )
            model.to("mps")
        elif args.device == "cuda":
            model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                args.model_path,
                torch_dtype=load_dtype,
                device_map="cuda",
                attn_implementation=attn_impl_primary,
            )
        else:  # cpu
            model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                args.model_path,
                torch_dtype=load_dtype,
                device_map="cpu",
                attn_implementation=attn_impl_primary,
            )
    except Exception as e:
        if attn_impl_primary == 'flash_attention_2':
            print(f"[ERROR] : {type(e).__name__}: {e}")
            print(traceback.format_exc())
            print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
            model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                args.model_path,
                torch_dtype=load_dtype,
                device_map=(args.device if args.device in ("cuda", "cpu") else None),
                attn_implementation='sdpa'
            )
            if args.device == "mps":
                model.to("mps")
        else:
            raise e

    model.eval()
    model.set_ddpm_inference_steps(num_steps=5)

    if hasattr(model.model, 'language_model'):
        print(f"Language model attention: {model.model.language_model.config._attn_implementation}")

    target_device = args.device if args.device != "cpu" else "cpu"
    voice_sample = voice_mapper.get_voice_path(args.speaker_name)
    all_prefilled_outputs = torch.load(voice_sample, map_location=target_device, weights_only=False)

    # Prepare inputs for the model
    inputs = processor.process_input_with_cached_prompt(
        text=full_script,
        cached_prompt=all_prefilled_outputs,
        padding=True,
        return_tensors="pt",
        return_attention_mask=True,
    )

    # Move tensors to target device
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to(target_device)

    print(f"Starting generation with cfg_scale: {args.cfg_scale}")

    # Generate audio
    start_time = time.time()
    outputs = model.generate(
        **inputs,
        max_new_tokens=None,
        cfg_scale=args.cfg_scale,
        tokenizer=processor.tokenizer,
        generation_config={'do_sample': False},
        verbose=True,
        all_prefilled_outputs=copy.deepcopy(all_prefilled_outputs) if all_prefilled_outputs is not None else None,
    )
    generation_time = time.time() - start_time
    print(f"Generation time: {generation_time:.2f} seconds")

    # Calculate audio duration and additional metrics
    if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
        # Assuming 24kHz sample rate (common for speech synthesis)
        sample_rate = 24000
        audio_samples = outputs.speech_outputs[0].shape[-1] if len(outputs.speech_outputs[0].shape) > 0 else len(outputs.speech_outputs[0])
        audio_duration = audio_samples / sample_rate
        rtf = generation_time / audio_duration if audio_duration > 0 else float('inf')

        print(f"Generated audio duration: {audio_duration:.2f} seconds")
        print(f"RTF (Real Time Factor): {rtf:.2f}x")
    else:
        print("No audio output generated")

    # Calculate token metrics
    input_tokens = inputs['tts_text_ids'].shape[1]  # Number of input tokens
    output_tokens = outputs.sequences.shape[1]  # Total tokens (input + generated)
    generated_tokens = output_tokens - input_tokens - all_prefilled_outputs['tts_lm']['last_hidden_state'].size(1)

    print(f"Prefilling text tokens: {input_tokens}")
    print(f"Generated speech tokens: {generated_tokens}")
    print(f"Total tokens: {output_tokens}")

    # Save output (processor handles device internally)
    txt_filename = os.path.splitext(os.path.basename(args.txt_path))[0]
    output_path = os.path.join(args.output_dir, f"{txt_filename}_generated.wav")
    os.makedirs(args.output_dir, exist_ok=True)

    processor.save_audio(
        outputs.speech_outputs[0],  # First (and only) batch item
        output_path=output_path,
    )
    print(f"Saved output to {output_path}")

    # Print summary
    print("\n" + "="*50)
    print("GENERATION SUMMARY")
    print("="*50)
    print(f"Input file: {args.txt_path}")
    print(f"Output file: {output_path}")
    print(f"Speaker names: {args.speaker_name}")
    print(f"Prefilling text tokens: {input_tokens}")
    print(f"Generated speech tokens: {generated_tokens}")
    print(f"Total tokens: {output_tokens}")
    print(f"Generation time: {generation_time:.2f} seconds")
    print(f"Audio duration: {audio_duration:.2f} seconds")
    print(f"RTF (Real Time Factor): {rtf:.2f}x")
    print("="*50)


if __name__ == "__main__":
    main()
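A typical invocation of the demo script above, using the defaults its argparse block defines. This is a sketch for illustration, not part of the commit; the script's path is not shown in this view, so the path below is a placeholder:

# Hypothetical invocation of the file-inference demo added in this commit.
# "demo/realtime_file_inference.py" is a placeholder path for the script above.
import subprocess

subprocess.run([
    "python", "demo/realtime_file_inference.py",          # placeholder path
    "--model_path", "microsoft/VibeVoice-Realtime-0.5B",
    "--txt_path", "demo/text_examples/1p_vibevoice.txt",
    "--speaker_name", "Wayne",
    "--output_dir", "./outputs",
    "--cfg_scale", "1.5",
], check=True)

Note the device handling in the script: it selects bfloat16 with flash_attention_2 on CUDA, float32 with SDPA on MPS and CPU, and falls back to SDPA if loading with flash_attention_2 fails.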

demo/text_examples/1p_abs.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
Generating long-form, multi-speaker conversational audio like podcasts poses significant challenges for traditional Text-to-Speech (TTS) systems, particularly in scalability, speaker consistency, and natural turn-taking. This report presents VibeVoice, a novel model designed to synthesize long-form speech with multiple speakers by employing the next-token diffusion framework, a unified method for modeling continuous data by autoregressively generating latent vectors via diffusion.
A core component of our approach is the continuous speech tokenizers operating at an ultra-low frame rate of 7.5. This tokenizer effectively preserves audio fidelity while significantly boosting computational efficiency for processing long sequences. This enables VibeVoice to synthesize long-form speech for up to 90 minutes (in a 64K context window length) with up to 4 speakers, capturing the authentic conversational "vibe" and surpassing all known open-source and closed-source dialogue models (for example, Gemini 2.5 Pro Preview TTS). Code and checkpoint are available now.
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
VibeVoice is a novel framework designed for generating expressive, long-form, multi-speaker conversational audio, such as podcasts, from text. It addresses significant challenges in traditional Text-to-Speech (TTS) systems, particularly in scalability, speaker consistency, and natural turn-taking. A core innovation of VibeVoice is its use of continuous speech tokenizers operating at an ultra-low frame rate of 7.5 Hz. These tokenizers efficiently preserve audio fidelity while significantly boosting computational efficiency for processing long sequences. VibeVoice employs a next-token diffusion framework, leveraging a Large Language Model to understand textual context and dialogue flow, and a diffusion head to generate high-fidelity acoustic details. The model can synthesize speech up to 90 minutes long with up to 4 distinct speakers, surpassing the typical 1-2 speaker limits of many prior models.
