Spaces: Build error
import os
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# AutoTokenizer/AutoModelForCausalLM are required here: the Llama 3 repo ships only a
# fast tokenizer (tokenizer.json), so the SentencePiece-based LlamaTokenizer fails to load it.
from transformers import AutoModelForCausalLM, AutoTokenizer
from openvoice import se_extractor
from openvoice.api import BaseSpeakerTTS, ToneColorConverter
import gradio as gr
import spaces
# Device setup (`device` was referenced below but never defined)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32  # fp16 is not usable on CPU
# Whisper setup
whisper_model_id = "openai/whisper-large-v3"
whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    whisper_model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True,
).to(device)
whisper_processor = AutoProcessor.from_pretrained(whisper_model_id)
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model=whisper_model,
    tokenizer=whisper_processor.tokenizer,
    feature_extractor=whisper_processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps=True,
    torch_dtype=torch_dtype,
    device=device,
)
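# chunk_length_s=30 splits long recordings into 30 s windows so the pipeline can
# transcribe audio of arbitrary length; batch_size controls how many of those
# windows are decoded in parallel.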
# Llama 3 8B setup
llama_model_id = "meta-llama/Meta-Llama-3-8B"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
llama_model = AutoModelForCausalLM.from_pretrained(llama_model_id, torch_dtype=torch_dtype).to(device)
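# Note: this is the base (non-Instruct) checkpoint, which simply continues the
# transcribed text. For assistant-style replies, meta-llama/Meta-Llama-3-8B-Instruct
# with llama_tokenizer.apply_chat_template is usually the better fit.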
# OpenVoiceV2 setup
ckpt_base = 'checkpoints/base_speakers/EN'
ckpt_converter = 'checkpoints/converter'
output_dir = 'outputs'

base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)
base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')

os.makedirs(output_dir, exist_ok=True)
source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)
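# The `spaces` import suggests this Space targets ZeroGPU hardware, where GPU work
# must run inside a function decorated with @spaces.GPU (an assumption about the
# deployment target; drop the decorator when running on a dedicated GPU).
@spaces.GPU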
def process_audio(input_audio):
    # ASR with Whisper
    whisper_result = whisper_pipe(input_audio)["text"]

    # Text generation with Llama 3
    inputs = llama_tokenizer(whisper_result, return_tensors="pt").to(device)
    outputs = llama_model.generate(**inputs, max_new_tokens=50)
    generated_text = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # TTS with OpenVoiceV2: extract the target tone-color embedding from a reference
    # clip, synthesize with the base speaker, then convert the timbre. (Extracting the
    # embedding on every call is slow; since the reference clip never changes, it could
    # be hoisted to module level.)
    reference_speaker = 'resources/example_reference.mp3'
    target_se, _ = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
    save_path = f'{output_dir}/output_en_default.wav'
    src_path = f'{output_dir}/tmp.wav'
    base_speaker_tts.tts(generated_text, src_path, speaker='default', language='English', speed=1.0)
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message="@MyShell"
    )
    return save_path
def real_time_processing(input_audio):
    return process_audio(input_audio)
# Gradio interface
iface = gr.Interface(
    fn=real_time_processing,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),  # Gradio 4.x: `source=` became `sources=[...]`
    outputs=gr.Audio(type="filepath"),  # `type="file"` was removed in Gradio 4.x
    live=True,
    title="ASR + Text-to-Text + TTS with Whisper, Llama 3 8B, and OpenVoiceV2",
    description="Real-time processing using Whisper for ASR, Llama 3 8B for text generation, and OpenVoiceV2 for TTS."
)
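# Queue requests so the long-running ASR -> LLM -> TTS chain doesn't hit request
# timeouts on Spaces; an optional addition, not part of the original script.
iface.queue()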
if __name__ == "__main__":
    iface.launch()