# RealTime / app.py
import os
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# Llama 3 ships a fast BPE tokenizer that the SentencePiece-based LlamaTokenizer
# cannot load, so the Auto* classes are used instead.
from transformers import AutoModelForCausalLM, AutoTokenizer
from openvoice import se_extractor
from openvoice.api import BaseSpeakerTTS, ToneColorConverter
import gradio as gr
import spaces
# Device setup (device was previously undefined but referenced below)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
# Whisper setup
whisper_model_id = "openai/whisper-large-v3"
whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
whisper_model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True,
)
whisper_processor = AutoProcessor.from_pretrained(whisper_model_id)
whisper_pipe = pipeline(
"automatic-speech-recognition",
model=whisper_model,
tokenizer=whisper_processor.tokenizer,
feature_extractor=whisper_processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
return_timestamps=True,
    torch_dtype=torch_dtype,
    device=device,
)
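# chunk_length_s=30 enables chunked long-form decoding: audio is split into 30 s
# windows, transcribed in batches of 16, and the pieces are stitched back together.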
# Llama-3-8B setup (gated repo: requires accepting the license and an HF access token)
llama_model_id = "meta-llama/Meta-Llama-3-8B"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
llama_model = AutoModelForCausalLM.from_pretrained(llama_model_id, torch_dtype=torch_dtype).to(device)
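# Note: Meta-Llama-3-8B is the base (non-instruct) checkpoint, so generate() will
# continue the transcribed text rather than answer it; the -Instruct variant is the
# usual choice when a conversational reply is wanted.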
# OpenVoice setup (base-speaker TTS + tone-color converter; these checkpoint
# paths and the BaseSpeakerTTS class follow the V1 release layout)
ckpt_base = 'checkpoints/base_speakers/EN'
ckpt_converter = 'checkpoints/converter'
output_dir = 'outputs'
base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)
base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
os.makedirs(output_dir, exist_ok=True)
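# Pre-computed speaker embedding of the default English base voice, used as the
# source side of the tone-color conversion.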
source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)
def process_audio(input_audio):
    """Run the full pipeline: Whisper ASR -> Llama text generation -> OpenVoice TTS."""
# ASR with Whisper
whisper_result = whisper_pipe(input_audio)["text"]
# Text generation with LLaMa
inputs = llama_tokenizer(whisper_result, return_tensors="pt").to(device)
outputs = llama_model.generate(**inputs, max_new_tokens=50)
generated_text = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
# TTS with OpenVoiceV2
reference_speaker = 'resources/example_reference.mp3'
target_se, _ = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
save_path = f'{output_dir}/output_en_default.wav'
src_path = f'{output_dir}/tmp.wav'
base_speaker_tts.tts(generated_text, src_path, speaker='default', language='English', speed=1.0)
tone_color_converter.convert(
audio_src_path=src_path,
src_se=source_se,
tgt_se=target_se,
output_path=save_path,
message="@MyShell"
)
return save_path
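
# Hypothetical local smoke test (assumes a recording exists at 'resources/sample.wav'):
#   reply_wav = process_audio('resources/sample.wav')
#   print(reply_wav)  # -> 'outputs/output_en_default.wav'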
@spaces.GPU()
def real_time_processing(input_audio):
return process_audio(input_audio)
# Gradio interface
iface = gr.Interface(
fn=real_time_processing,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs=gr.Audio(type="filepath"),
live=True,
    title="ASR + Text Generation + TTS with Whisper, Llama-3-8B, and OpenVoiceV2",
    description="Real-time pipeline: Whisper for ASR, Llama-3-8B for text generation, and OpenVoiceV2 for TTS."
)
if __name__ == "__main__":
iface.launch()