Spaces: Build error
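The script wires Whisper ASR, Llama-3 text generation, and MeloTTS with OpenVoice tone-color conversion into a voice-to-voice loop. Note that the paste as posted cannot even import: os is used without being imported, and device is referenced before it is ever defined.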
import os

import gradio as gr
import torch
import spaces
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from openvoice.api import ToneColorConverter
from openvoice import se_extractor
from melo.api import TTS
import pyaudio
import wave
import numpy as np
# Device and dtype: half precision on GPU, full precision on CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load ASR model and processor
asr_model_id = "openai/whisper-large-v3"
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    asr_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
asr_processor = AutoProcessor.from_pretrained(asr_model_id)
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=asr_processor.tokenizer,
    feature_extractor=asr_processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps=True,
    torch_dtype=torch_dtype,
    device=device,
)
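# chunk_length_s=30 splits long recordings into 30-second windows and
# batch_size=16 transcribes those windows in parallel; this is the standard
# Transformers recipe for chunked long-form Whisper inference.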

# Load the text-generation model and tokenizer. Llama-3 is a decoder-only
# (causal) model, so AutoModelForCausalLM is the correct Auto class here.
text_model_id = "meta-llama/Meta-Llama-3-8B"
text_model = AutoModelForCausalLM.from_pretrained(text_model_id, torch_dtype=torch_dtype).to(device)
text_tokenizer = AutoTokenizer.from_pretrained(text_model_id)
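# Note: meta-llama/Meta-Llama-3-8B is a gated checkpoint (access must be granted
# on the Hub, and the Space needs an HF token); for chat-style replies the
# Meta-Llama-3-8B-Instruct variant is usually the better fit.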

# Load the OpenVoice tone-color converter and its checkpoint
tts_converter_ckpt = 'checkpoints_v2/converter'
tts_output_dir = 'outputs_v2'
os.makedirs(tts_output_dir, exist_ok=True)
tts_converter = ToneColorConverter(f'{tts_converter_ckpt}/config.json', device=device)
tts_converter.load_ckpt(f'{tts_converter_ckpt}/checkpoint.pth')
reference_speaker = 'resources/example_reference.mp3'  # This is the voice you want to clone
target_se, _ = se_extractor.get_se(reference_speaker, tts_converter, vad=False)
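# target_se is the tone-color embedding extracted from the reference clip;
# OpenVoice later converts the base speaker's TTS output toward this embedding,
# which is what makes the reply sound like the cloned voice.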

@spaces.GPU  # On ZeroGPU Spaces this is what actually allocates a GPU per call
def process_audio(input_audio):
    # Speech -> text
    asr_result = asr_pipeline(input_audio)["text"]

    # Text -> text: generate a reply, decoding only the newly generated tokens
    input_ids = text_tokenizer(asr_result, return_tensors="pt").input_ids.to(device)
    generated_ids = text_model.generate(input_ids, max_new_tokens=512)
    response_text = text_tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

    # Text -> speech with the MeloTTS base speaker
    tts_model = TTS(language='EN', device=device)
    speaker_id = list(tts_model.hps.data.spk2id.values())[0]
    tts_model.tts_to_file(response_text, speaker_id, f'{tts_output_dir}/tmp.wav')

    # Convert the base-speaker audio to the cloned tone color
    save_path = f'{tts_output_dir}/output_v2.wav'
    source_se = torch.load('checkpoints_v2/base_speakers/ses/english-american.pth', map_location=device)
    tts_converter.convert(audio_src_path=f'{tts_output_dir}/tmp.wav', src_se=source_se, tgt_se=target_se, output_path=save_path, message="@MyShell")
    return save_path
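# Example round trip (hypothetical file path):
#   reply_wav = process_audio("input_audio.wav")   # -> 'outputs_v2/output_v2.wav'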

# Microphone capture with a crude amplitude-based VAD: keep audio from the
# start, and stop once roughly a second of quiet follows detected speech
# (1024 frames at 16 kHz is ~64 ms per chunk, so 15 quiet chunks is ~1 s).
def real_time_audio_processing():
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000, input=True, frames_per_buffer=1024)
    frames = []
    speech_started = False
    silence_chunks = 0
    print("Listening...")
    while True:
        data = stream.read(1024)
        frames.append(data)
        audio_data = np.frombuffer(data, dtype=np.int16)
        if np.max(audio_data) > 3000:  # Simple VAD threshold
            speech_started = True
            silence_chunks = 0
        elif speech_started:
            silence_chunks += 1
            if silence_chunks > 15:
                break
    stream.stop_stream()
    stream.close()
    sample_width = p.get_sample_size(pyaudio.paInt16)
    p.terminate()
    wf = wave.open("input_audio.wav", 'wb')
    wf.setnchannels(1)
    wf.setsampwidth(sample_width)
    wf.setframerate(16000)
    wf.writeframes(b''.join(frames))
    wf.close()
    return "input_audio.wav"

# Gradio interface: each run records one utterance and returns the spoken reply
def main():
    input_audio_path = real_time_audio_processing()
    if input_audio_path:
        output_audio_path = process_audio(input_audio_path)
        return output_audio_path

iface = gr.Interface(
    fn=main,
    inputs=None,  # no inputs: the run button triggers a capture/reply cycle
    outputs=gr.Audio(type="filepath"),
)

if __name__ == "__main__":
    iface.launch()
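
Two Spaces-specific caveats worth checking if the build error persists: PyAudio only compiles when the PortAudio system library is present, so the Space needs a packages.txt listing portaudio19-dev alongside pyaudio in requirements.txt; and even then PyAudio opens the server's microphone, which Spaces hardware does not have, so a browser-side recorder (a microphone-sourced gr.Audio input) is the pattern that actually works once deployed.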