import os
import wave

import gradio as gr
import numpy as np
import pyaudio
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
)

from openvoice import se_extractor
from openvoice.api import ToneColorConverter
from melo.api import TTS

# Select device and dtype (fp16 only makes sense on GPU)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load ASR model and processor
asr_model_id = "openai/whisper-large-v3"
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    asr_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
).to(device)
asr_processor = AutoProcessor.from_pretrained(asr_model_id)

asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=asr_processor.tokenizer,
    feature_extractor=asr_processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps=True,
    torch_dtype=torch_dtype,
    device=device,
)

# Load text-to-text model and tokenizer
# Llama 3 is a decoder-only model, so AutoModelForCausalLM is the correct class.
text_model_id = "meta-llama/Meta-Llama-3-8B"
text_model = AutoModelForCausalLM.from_pretrained(
    text_model_id, torch_dtype=torch_dtype
).to(device)
text_tokenizer = AutoTokenizer.from_pretrained(text_model_id)

# Load tone-color converter checkpoints for voice cloning
tts_converter_ckpt = "checkpoints_v2/converter"
tts_output_dir = "outputs_v2"
os.makedirs(tts_output_dir, exist_ok=True)

tts_converter = ToneColorConverter(f"{tts_converter_ckpt}/config.json")
tts_converter.load_ckpt(f"{tts_converter_ckpt}/checkpoint.pth")

# Reference recording of the voice to clone
reference_speaker = "resources/example_reference.mp3"
target_se, _ = se_extractor.get_se(reference_speaker, tts_converter, vad=False)


def process_audio(input_audio):
    # Transcribe the input audio with Whisper
    asr_result = asr_pipeline(input_audio)["text"]

    # Generate a text response with the language model
    input_ids = text_tokenizer(asr_result, return_tensors="pt").input_ids.to(device)
    generated_ids = text_model.generate(input_ids, max_length=512)
    response_text = text_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Synthesize speech for the response, then convert it to the target voice
    tts_model = TTS(language="EN", device=device)
    speaker_id = list(tts_model.hps.data.spk2id.values())[0]
    tts_model.tts_to_file(response_text, speaker_id, f"{tts_output_dir}/tmp.wav")

    save_path = f"{tts_output_dir}/output_v2.wav"
    source_se = torch.load(
        "checkpoints_v2/base_speakers/ses/english-american.pth", map_location=device
    )
    tts_converter.convert(
        audio_src_path=f"{tts_output_dir}/tmp.wav",
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message="@MyShell",
    )
    return save_path


def real_time_audio_processing():
    """Record from the microphone until the signal crosses a simple amplitude threshold."""
    p = pyaudio.PyAudio()
    stream = p.open(
        format=pyaudio.paInt16,
        channels=1,
        rate=16000,
        input=True,
        frames_per_buffer=1024,
    )
    frames = []
    print("Listening...")

    try:
        while True:
            data = stream.read(1024)
            frames.append(data)
            audio_data = np.frombuffer(data, dtype=np.int16)
            if np.max(audio_data) > 3000:  # Simple VAD threshold
                wf = wave.open("input_audio.wav", "wb")
                wf.setnchannels(1)
                wf.setsampwidth(p.get_sample_size(pyaudio.paInt16))
                wf.setframerate(16000)
                wf.writeframes(b"".join(frames))
                wf.close()
                return "input_audio.wav"
    finally:
        stream.stop_stream()
        stream.close()
        p.terminate()


# Gradio interface
@spaces.GPU(duration=300)
def main():
    input_audio_path = real_time_audio_processing()
    if input_audio_path:
        return process_audio(input_audio_path)


iface = gr.Interface(
    fn=main,
    inputs=None,
    outputs=gr.Audio(type="filepath"),
    live=True,
)

iface.launch()