VanguardAI committed on
Commit cab275d
1 Parent(s): b29b41c

Update app.py

Files changed (1):
  1. app.py +95 -110

app.py CHANGED
@@ -1,121 +1,106 @@
- import torch
- import torchaudio
  import gradio as gr
- import soundfile as sf
  import wave
  import numpy as np
- from transformers import WhisperForCTC, WhisperProcessor, AutoModelForSeq2SeqLM, AutoTokenizer
- from transformers import OpenVoiceV2Processor, OpenVoiceV2

  # Load ASR model and processor
- processor_asr = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
- model_asr = WhisperForCTC.from_pretrained("openai/whisper-large-v3")

  # Load text-to-text model and tokenizer
- text_model = AutoModelForSeq2SeqLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
-
- # Load TTS model
- tts_processor = OpenVoiceV2Processor.from_pretrained("myshell-ai/OpenVoiceV2")
- tts_model = OpenVoiceV2.from_pretrained("myshell-ai/OpenVoiceV2")
-
- @spaces.GPU()
- # ASR function
- def transcribe(audio):
-     waveform, sample_rate = torchaudio.load(audio)
-     inputs = processor_asr(waveform, sampling_rate=sample_rate, return_tensors="pt", padding=True)
-     with torch.no_grad():
-         logits = model_asr(inputs.input_values).logits
-     predicted_ids = torch.argmax(logits, dim=-1)
-     transcription = processor_asr.batch_decode(predicted_ids)
-     return transcription[0]
-
- @spaces.GPU()
- # Text-to-text function
- def generate_response(text):
-     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
-     outputs = text_model.generate(**inputs)
-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return response
-
- @spaces.GPU()
- # TTS function
- def synthesize_speech(text):
-     inputs = tts_processor(text, return_tensors="pt")
-     with torch.no_grad():
-         mel_outputs, mel_outputs_postnet, _, alignments = tts_model.inference(inputs.input_ids)
-     audio = tts_model.infer(mel_outputs_postnet)
-     return audio
-
- @spaces.GPU()
- # Real-time processing function
- def real_time_pipeline():
-     # Adjust this part to handle live recording using soundfile and play back using simpleaudio
-     import simpleaudio as sa
-     import tempfile
-     import time
-
-     wake_word = "hello mate"
-     wake_word_detected = False
-
-     print("Listening for wake word...")
-
-     with tempfile.NamedTemporaryFile(delete=False) as tmp_wav_file:
-         tmp_wav_path = tmp_wav_file.name
-
-     try:
-         while True:
-             # Capture audio here (this is a simplified example, you need actual audio capture logic)
-             time.sleep(2)  # Simulate 2 seconds of audio capture
-
-             # Save the captured audio to the temp file for ASR
-             data, sample_rate = sf.read(tmp_wav_path)
-             sf.write(tmp_wav_path, data, sample_rate)
-
-             # Step 1: Transcribe audio to text
-             transcription = transcribe(tmp_wav_path).lower()
-
-             if wake_word in transcription:
-                 wake_word_detected = True
-                 print("Wake word detected. Processing audio...")
-
-                 while wake_word_detected:
-                     # Capture audio here (this is a simplified example, you need actual audio capture logic)
-                     time.sleep(2)  # Simulate 2 seconds of audio capture
-
-                     # Save the captured audio to the temp file for ASR
-                     data, sample_rate = sf.read(tmp_wav_path)
-                     sf.write(tmp_wav_path, data, sample_rate)
-
-                     # Step 1: Transcribe audio to text
-                     transcription = transcribe(tmp_wav_path)
-
-                     # Step 2: Generate response using text-to-text model
-                     response = generate_response(transcription)
-
-                     # Step 3: Synthesize speech from text
-                     synthesized_audio = synthesize_speech(response)
-
-                     # Save the synthesized audio to a temporary file
-                     output_path = "output.wav"
-                     torchaudio.save(output_path, synthesized_audio.squeeze(1), 22050)
-
-                     # Play the synthesized audio using simpleaudio
-                     wave_obj = sa.WaveObject.from_wave_file(output_path)
-                     play_obj = wave_obj.play()
-                     play_obj.wait_done()
-     except KeyboardInterrupt:
-         print("Stopping...")
-
- # Gradio interface
- gr_interface = gr.Interface(
-     fn=real_time_pipeline,
-     inputs=None,
-     outputs=None,
-     live=True,
-     title="Real-Time Audio-to-Audio Model",
-     description="ASR + Text-to-Text Model + TTS with Human-like Voice and Emotions"
  )

-
- iface.launch(inline=False)

  import gradio as gr
+ import torch
+ import spaces
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
+ from datasets import load_dataset
+ from openvoice.api import ToneColorConverter
+ from openvoice import se_extractor
+ from melo.api import TTS
+ import pyaudio
  import wave
  import numpy as np

  # Load ASR model and processor
+ torch_dtype = torch.float16
+
+ asr_model_id = "openai/whisper-large-v3"
+ asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(asr_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ asr_processor = AutoProcessor.from_pretrained(asr_model_id)
+
+ asr_pipeline = pipeline(
+     "automatic-speech-recognition",
+     model=asr_model,
+     tokenizer=asr_processor.tokenizer,
+     feature_extractor=asr_processor.feature_extractor,
+     max_new_tokens=128,
+     chunk_length_s=30,
+     batch_size=16,
+     return_timestamps=True,
+     torch_dtype=torch_dtype,
+     device=device,
+ )

  # Load text-to-text model and tokenizer
+ text_model_id = "meta-llama/Meta-Llama-3-8B"
+ text_model = AutoModelForSeq2SeqLM.from_pretrained(text_model_id)
+ text_tokenizer = AutoTokenizer.from_pretrained(text_model_id)
+
+ # Load TTS model and vocoder
+ tts_converter_ckpt = 'checkpoints_v2/converter'
+ tts_output_dir = 'outputs_v2'
+ os.makedirs(tts_output_dir, exist_ok=True)
+
+ tts_converter = ToneColorConverter(f'{tts_converter_ckpt}/config.json')
+ tts_converter.load_ckpt(f'{tts_converter_ckpt}/checkpoint.pth')
+
+ reference_speaker = 'resources/example_reference.mp3'  # This is the voice you want to clone
+ target_se, _ = se_extractor.get_se(reference_speaker, tts_converter, vad=False)
+
+ def process_audio(input_audio):
+     # Perform ASR
+     asr_result = asr_pipeline(input_audio)["text"]
+
+     # Perform text-to-text processing
+     input_ids = text_tokenizer(asr_result, return_tensors="pt").input_ids.to(device)
+     generated_ids = text_model.generate(input_ids, max_length=512)
+     response_text = text_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+     # Perform TTS
+     tts_model = TTS(language='EN', device=device)
+     speaker_id = list(tts_model.hps.data.spk2id.values())[0]
+     tts_model.tts_to_file(response_text, speaker_id, f'{tts_output_dir}/tmp.wav')
+     save_path = f'{tts_output_dir}/output_v2.wav'

+     source_se = torch.load(f'checkpoints_v2/base_speakers/ses/english-american.pth', map_location=device)
+     tts_converter.convert(audio_src_path=f'{tts_output_dir}/tmp.wav', src_se=source_se, tgt_se=target_se, output_path=save_path, message="@MyShell")
+
+     return save_path
+
+ # Real-time audio processing
+
+ def real_time_audio_processing():
+     p = pyaudio.PyAudio()
+     stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000, input=True, frames_per_buffer=1024)

+     frames = []
+     print("Listening...")
+
+     while True:
+         data = stream.read(1024)
+         frames.append(data)
+         audio_data = np.frombuffer(data, dtype=np.int16)
+         if np.max(audio_data) > 3000:  # Simple VAD threshold
+             wf = wave.open("input_audio.wav", 'wb')
+             wf.setnchannels(1)
+             wf.setsampwidth(p.get_sample_size(pyaudio.paInt16))
+             wf.setframerate(16000)
+             wf.writeframes(b''.join(frames))
+             wf.close()
+             return "input_audio.wav"
+
+ # Gradio Interface
+ @spaces.GPU(duration=300)
+ def main():
+     input_audio_path = real_time_audio_processing()
+     if input_audio_path:
+         output_audio_path = process_audio(input_audio_path)
+         return output_audio_path
+
+ iface = gr.Interface(
+     fn=main,
+     inputs=None,
+     outputs=gr.Audio(type="filepath"),
+     live=True
  )

+ iface.launch()
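
Note: as committed, the new app.py passes device=device to the ASR pipeline and calls os.makedirs without defining device or importing os, so the script would fail at import time. A minimal sketch of the setup it appears to assume (not part of this commit; the "cuda:0" choice is an assumption to match the float16 Whisper weights):

import os

import torch

# Hypothetical setup the committed script relies on:
# use the GPU when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"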