import os

import torch
import gradio as gr
import spaces
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer, pipeline
from transformers import LlamaForCausalLM
from openvoice import se_extractor
from openvoice.api import BaseSpeakerTTS, ToneColorConverter

# Device setup
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Whisper setup
whisper_model_id = "openai/whisper-large-v3"
whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    whisper_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
whisper_model.to(device)
whisper_processor = AutoProcessor.from_pretrained(whisper_model_id)
whisper_pipe = pipeline(
    "automatic-speech-recognition",
    model=whisper_model,
    tokenizer=whisper_processor.tokenizer,
    feature_extractor=whisper_processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps=True,
    torch_dtype=torch_dtype,
    device=device,
)

# LLaMa3-8B setup
# Llama 3 ships a fast (tiktoken-based) tokenizer, so AutoTokenizer is used rather than LlamaTokenizer.
llama_model_id = "meta-llama/Meta-Llama-3-8B"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
llama_model = LlamaForCausalLM.from_pretrained(llama_model_id, torch_dtype=torch_dtype).to(device)

# OpenVoiceV2 setup
ckpt_base = 'checkpoints/base_speakers/EN'
ckpt_converter = 'checkpoints/converter'
output_dir = 'outputs'

base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base}/config.json', device=device)
base_speaker_tts.load_ckpt(f'{ckpt_base}/checkpoint.pth')

tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')

os.makedirs(output_dir, exist_ok=True)
source_se = torch.load(f'{ckpt_base}/en_default_se.pth').to(device)


def process_audio(input_audio):
    # ASR with Whisper
    whisper_result = whisper_pipe(input_audio)["text"]

    # Text generation with LLaMa
    inputs = llama_tokenizer(whisper_result, return_tensors="pt").to(device)
    outputs = llama_model.generate(**inputs, max_new_tokens=50)
    generated_text = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # TTS with OpenVoiceV2: synthesize with the base speaker, then convert to the reference voice
    reference_speaker = 'resources/example_reference.mp3'
    target_se, _ = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)

    save_path = f'{output_dir}/output_en_default.wav'
    src_path = f'{output_dir}/tmp.wav'
    base_speaker_tts.tts(generated_text, src_path, speaker='default', language='English', speed=1.0)
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message="@MyShell"
    )
    return save_path


@spaces.GPU()
def real_time_processing(input_audio):
    return process_audio(input_audio)


# Gradio interface (Gradio 4.x: `sources` is a list, and audio outputs use type="filepath")
iface = gr.Interface(
    fn=real_time_processing,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs=gr.Audio(type="filepath"),
    live=True,
    title="ASR + Text-to-Text + TTS with Whisper, LLaMa3-8B, and OpenVoiceV2",
    description="Real-time processing using Whisper for ASR, LLaMa3-8B for text generation, and OpenVoiceV2 for TTS."
)

if __name__ == "__main__":
    iface.launch()