import spaces
import torch
import torchaudio
import gradio as gr
import pyaudio
import wave
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor, AutoModelForCausalLM, AutoTokenizer
# NOTE: OpenVoiceV2 is not shipped with transformers; this import assumes a custom wrapper exposing these classes.
from transformers import OpenVoiceV2Processor, OpenVoiceV2
# Load the ASR model and processor (Whisper is a seq2seq model, so use the ConditionalGeneration class)
processor_asr = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model_asr = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")

# Load the text-to-text model and tokenizer (Llama 3 is a causal LM, not a seq2seq model)
text_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# Load the TTS model
tts_processor = OpenVoiceV2Processor.from_pretrained("myshell-ai/OpenVoiceV2")
tts_model = OpenVoiceV2.from_pretrained("myshell-ai/OpenVoiceV2")
# ASR function: transcribe an audio file to text with Whisper
@spaces.GPU()
def transcribe(audio):
    waveform, sample_rate = torchaudio.load(audio)
    if waveform.size(0) > 1:
        # Downmix to mono
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != 16000:
        # Whisper expects 16 kHz audio
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
    inputs = processor_asr(waveform.squeeze(0).numpy(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        predicted_ids = model_asr.generate(inputs.input_features)
    transcription = processor_asr.batch_decode(predicted_ids, skip_special_tokens=True)
    return transcription[0]
# Text-to-text function: generate a response with the language model
@spaces.GPU(duration=300)
def generate_response(text):
    # The Llama tokenizer has no padding token by default, so tokenize a single prompt without padding
    inputs = tokenizer(text, return_tensors="pt")
    outputs = text_model.generate(**inputs, max_new_tokens=256)
    # Strip the prompt tokens so only the newly generated text is returned
    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return response
# TTS function: synthesize speech from text
# NOTE: this assumes the OpenVoiceV2 wrapper exposes a Tacotron-style inference API; adjust to the actual OpenVoice interface.
@spaces.GPU(duration=300)
def synthesize_speech(text):
    inputs = tts_processor(text, return_tensors="pt")
    with torch.no_grad():
        mel_outputs, mel_outputs_postnet, _, alignments = tts_model.inference(inputs.input_ids)
        audio = tts_model.infer(mel_outputs_postnet)
    return audio
# Real-time processing function: wake-word detection -> ASR -> LLM -> TTS -> playback
@spaces.GPU(duration=300)
def real_time_pipeline():
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000, input=True, frames_per_buffer=1024)

    wake_word = "hello mate"
    wake_word_detected = False

    print("Listening for wake word...")
    try:
        while True:
            # Record roughly 2 seconds of audio from the microphone
            frames = []
            for _ in range(0, int(16000 / 1024 * 2)):
                data = stream.read(1024)
                frames.append(data)
            audio_data = np.frombuffer(b''.join(frames), dtype=np.int16)  # raw samples (not used further here)

            # Save the audio to a temporary file for ASR
            wf = wave.open("temp.wav", 'wb')
            wf.setnchannels(1)
            wf.setsampwidth(p.get_sample_size(pyaudio.paInt16))
            wf.setframerate(16000)
            wf.writeframes(b''.join(frames))
            wf.close()

            # Step 1: Transcribe audio to text and check for the wake word
            transcription = transcribe("temp.wav").lower()
            if wake_word in transcription:
                wake_word_detected = True
                print("Wake word detected. Processing audio...")

            while wake_word_detected:
                # Record roughly 2 seconds of audio from the microphone
                frames = []
                for _ in range(0, int(16000 / 1024 * 2)):
                    data = stream.read(1024)
                    frames.append(data)
                audio_data = np.frombuffer(b''.join(frames), dtype=np.int16)  # raw samples (not used further here)

                # Save the audio to a temporary file for ASR
                wf = wave.open("temp.wav", 'wb')
                wf.setnchannels(1)
                wf.setsampwidth(p.get_sample_size(pyaudio.paInt16))
                wf.setframerate(16000)
                wf.writeframes(b''.join(frames))
                wf.close()

                # Step 1: Transcribe audio to text
                transcription = transcribe("temp.wav")
                # Step 2: Generate a response with the text-to-text model
                response = generate_response(transcription)
                # Step 3: Synthesize speech from the response text
                synthesized_audio = synthesize_speech(response)

                # Save the synthesized audio to a temporary file
                output_path = "output.wav"
                torchaudio.save(output_path, synthesized_audio.squeeze(1), 22050)

                # Play the synthesized audio
                wf = wave.open(output_path, 'rb')
                stream_out = p.open(format=p.get_format_from_width(wf.getsampwidth()),
                                    channels=wf.getnchannels(),
                                    rate=wf.getframerate(),
                                    output=True)
                data = wf.readframes(1024)
                while data:
                    stream_out.write(data)
                    data = wf.readframes(1024)
                stream_out.stop_stream()
                stream_out.close()
                wf.close()
    except KeyboardInterrupt:
        print("Stopping...")
    finally:
        stream.stop_stream()
        stream.close()
        p.terminate()
# Gradio interface (no inputs/outputs: the pipeline runs as a blocking loop on the host's microphone)
gr_interface = gr.Interface(
    fn=real_time_pipeline,
    inputs=None,
    outputs=None,
    live=True,
    title="Real-Time Audio-to-Audio Model",
    description="ASR + text-to-text model + TTS with a human-like voice and emotions"
)

gr_interface.launch(inline=False)