Devakumar868 committed on
Commit 501663c · verified · 1 Parent(s): 59edf66

Update app.py

Files changed (1)
  1. app.py +99 -56
app.py CHANGED
@@ -3,84 +3,127 @@ import gradio as gr
  import torch
  import numpy as np
  from transformers import pipeline
- from diffusers import DiffusionPipeline
  from pyannote.audio import Pipeline as PyannotePipeline
  from dia.model import Dia
  from dac.utils import load_model as load_dac_model
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch

- #-- Configuration
- HF_TOKEN = os.environ["HF_TOKEN"]  # Gated model access[2]
- device_map = "auto"  # Distribute models on 4×L4 GPUs[3]

- #-- 1. Descript Audio Codec (RVQ)
- rvq = load_dac_model(tag="latest", model_type="44khz")  # RVQ encoder/decoder[4]
  rvq.eval()
- if torch.cuda.is_available(): rvq = rvq.to("cuda")

- #-- 2. Voice Activity Detection via Pyannote
  vad_pipe = PyannotePipeline.from_pretrained(
      "pyannote/voice-activity-detection",
      use_auth_token=HF_TOKEN
- )  # Proper gated VAD load[2]

- #-- 3. Ultravox ASR+LLM Pipeline
  ultravox_pipe = pipeline(
      model="fixie-ai/ultravox-v0_4",
      trust_remote_code=True,
      device_map=device_map,
      torch_dtype=torch.float16
- )  # Custom speech pipeline[2]
-
- #-- 4. Audio Diffusion Model (Prosody)
- diff_pipe = DiffusionPipeline.from_pretrained(
-     "teticio/audio-diffusion-instrumental-hiphop-256",
-     torch_dtype=torch.float16
- ).to("cuda")  # Diffusers-based load[2]
-
- #-- 5. Dia TTS Model Sharded Across GPUs
- dia = Dia.from_pretrained(
-     "nari-labs/Dia-1.6B",
-     device_map=device_map,
-     torch_dtype=torch.float16,
-     trust_remote_code=True
- )  # Auto-sharding in Transformers[2]

- #-- Inference Function
- def process_audio(audio):
-     sr, arr = audio
-     arr = arr.numpy() if torch.is_tensor(arr) else arr
-
-     # VAD segmentation
-     _ = vad_pipe({"waveform": torch.tensor(arr).unsqueeze(0), "sample_rate": sr})
-
-     # RVQ encode/decode
-     x = torch.tensor(arr).unsqueeze(0).to("cuda")
-     codes = rvq.encode(x)
-     decoded = rvq.decode(codes).squeeze().cpu().numpy()

-     # Ultravox ASR text
-     ultra_out = ultravox_pipe({"array": decoded, "sampling_rate": sr})
-     text = ultra_out.get("text", "")

-     # Diffusion-based prosody enhancement
-     pros = diff_pipe(raw_audio=decoded)["audios"][0]

-     # Dia TTS synthesis
-     tts = dia.generate(f"[emotion:neutral] {text}")
-     tts_np = tts.squeeze().cpu().numpy()
-     tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95 if tts_np.size else tts_np
-
-     return (sr, tts_np), text

- #-- Gradio UI
  with gr.Blocks(title="Maya AI 📈") as demo:
-     gr.Markdown("## Maya-AI: Supernatural Conversational Agent")
-     audio_in = gr.Audio(source="microphone", type="numpy", label="Your Voice")
-     send_btn = gr.Button("Send")
-     audio_out = gr.Audio(label="AI Response")
-     text_out = gr.Textbox(label="Generated Text")
-     send_btn.click(process_audio, inputs=audio_in, outputs=[audio_out, text_out])

  if __name__ == "__main__":
      demo.launch()
 
  import torch
  import numpy as np
  from transformers import pipeline
  from pyannote.audio import Pipeline as PyannotePipeline
  from dia.model import Dia
  from dac.utils import load_model as load_dac_model

+ # Environment setup
+ HF_TOKEN = os.environ["HF_TOKEN"]
+ device_map = "auto"

+ print("Loading models...")
+
+ # 1. Load RVQ Codec
+ print("Loading RVQ Codec...")
+ rvq = load_dac_model(tag="latest", model_type="44khz")
  rvq.eval()
+ if torch.cuda.is_available():
+     rvq = rvq.to("cuda")
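+ # Note: the "44khz" DAC codec operates at 44.1 kHz; audio captured at other sample
+ # rates is not resampled before it is encoded in process_audio below.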

+ # 2. Load VAD Pipeline
+ print("Loading VAD...")
  vad_pipe = PyannotePipeline.from_pretrained(
      "pyannote/voice-activity-detection",
      use_auth_token=HF_TOKEN
+ )
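+ # pyannote/voice-activity-detection is a gated repo: HF_TOKEN must belong to an
+ # account that has accepted the model's terms on the Hugging Face Hub.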

+ # 3. Load Ultravox Pipeline
+ print("Loading Ultravox...")
  ultravox_pipe = pipeline(
      model="fixie-ai/ultravox-v0_4",
      trust_remote_code=True,
      device_map=device_map,
      torch_dtype=torch.float16
+ )
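+ # trust_remote_code=True pulls Ultravox's custom pipeline code from the Hub, and
+ # device_map="auto" lets accelerate spread the weights across the available GPUs.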

+ # 4. Skip Audio Diffusion (causing UNet mismatch)
+ print("Skipping Audio Diffusion due to compatibility issues...")
+ diff_pipe = None
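+ # With diff_pipe set to None, process_audio passes the decoded RVQ audio through unchanged.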

+ # 5. Load Dia TTS (correct method based on current API)
+ print("Loading Dia TTS...")
+ dia = Dia.from_pretrained("nari-labs/Dia-1.6B")

+ print("All models loaded successfully!")
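+ # Note: Dia's model card drives generation with [S1]/[S2] speaker tags; the
+ # "[emotion:neutral]" prefix used in process_audio below is this app's own convention.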

+ def process_audio(audio):
+     try:
+         if audio is None:
+             return None, "No audio input provided"
+
+         sr, array = audio
+
+         # Ensure numpy array
+         if torch.is_tensor(array):
+             array = array.numpy()
+
+         # VAD processing
+         try:
+             vad_result = vad_pipe({"waveform": torch.tensor(array).unsqueeze(0), "sample_rate": sr})
+         except Exception as e:
+             print(f"VAD processing error: {e}")
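+         # vad_result is currently unused; the call only verifies that VAD runs on the input.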
+
+         # RVQ encode/decode
+         audio_tensor = torch.tensor(array).unsqueeze(0)
+         if torch.cuda.is_available():
+             audio_tensor = audio_tensor.to("cuda")
+         codes = rvq.encode(audio_tensor)
+         decoded = rvq.decode(codes).squeeze().cpu().numpy()
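+         # Assumption: rvq.encode/rvq.decode round-trip this (1, T) tensor as-is; depending on
+         # the installed DAC version, a float32 (batch, channels, time) input and a preprocess
+         # step may be required before encoding.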
+
+         # Ultravox ASR + LLM
+         ultra_out = ultravox_pipe({"array": decoded, "sampling_rate": sr})
+         text = ultra_out.get("text", "I understand your audio input.")
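+         # The output format comes from Ultravox's remote pipeline code; .get() falls back
+         # to a stock reply if no "text" key is returned.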
+
+         # Skip diffusion processing due to compatibility issues
+         prosody_audio = decoded
+
+         # Dia TTS generation
+         tts_output = dia.generate(f"[emotion:neutral] {text}")
+
+         # Convert to numpy and normalize
+         if torch.is_tensor(tts_output):
+             tts_np = tts_output.squeeze().cpu().numpy()
+         else:
+             tts_np = np.array(tts_output)
+
+         # Normalize audio output
+         if len(tts_np) > 0:
+             tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95
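+         # Peak-normalize to 0.95 of full scale so the synthesized reply cannot clip.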
+
+         return (sr, tts_np), text
+
+     except Exception as e:
+         print(f"Error in process_audio: {e}")
+         return None, f"Processing error: {str(e)}"

+ # Gradio Interface
  with gr.Blocks(title="Maya AI 📈") as demo:
+     gr.Markdown("# Maya-AI: Supernatural Conversational Agent")
+     gr.Markdown("Record audio to interact with the AI agent that understands emotions and responds naturally.")
+
+     with gr.Row():
+         with gr.Column():
+             audio_in = gr.Audio(
+                 source="microphone",
+                 type="numpy",
+                 label="Record Your Voice"
+             )
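+             # Note: source="microphone" is the Gradio 3.x argument; Gradio 4.x renamed it
+             # to sources=["microphone"].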
+             send_btn = gr.Button("Send", variant="primary")
+
+         with gr.Column():
+             audio_out = gr.Audio(label="AI Response")
+             text_out = gr.Textbox(
+                 label="Generated Text",
+                 lines=3,
+                 placeholder="AI response will appear here..."
+             )
+
+     # Event handler
+     send_btn.click(
+         fn=process_audio,
+         inputs=audio_in,
+         outputs=[audio_out, text_out]
+     )

  if __name__ == "__main__":
      demo.launch()