Futuresony committed on
Commit 326e81a · verified · 1 Parent(s): 9f14aa7

Delete app.py

Files changed (1):
  1. app.py +0 -337
app.py DELETED
@@ -1,337 +0,0 @@
# app.py
import os
import tempfile
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Tuple, Optional

import gradio as gr
import numpy as np
import soundfile as sf
import torchaudio
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from gradio_client import Client
from ttsmms import download, TTS
from langdetect import detect

# ========================
# CONFIG - update as needed
# ========================
# Local ASR model (change to the correct HF repo id or a local path)
asr_model_name = "Futuresony/Future-sw_ASR-24-02-2025"

# Remote LLM Gradio Space
llm_space = "Futuresony/Mr.Events"
llm_api_name = "/chat"

# TTS languages
sw_lang_code = "swh"  # ttsmms language code for Swahili (adjust if needed)
en_lang_code = "eng"

# ========================
# LOAD MODELS / CLIENTS
# ========================
print("[INIT] Loading ASR processor & model...")
processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
asr_model.eval()

print("[INIT] Creating Gradio Client for LLM Space...")
llm_client = Client(llm_space)

print("[INIT] Downloading TTS models (this may take time)")
swahili_dir = download(sw_lang_code, "./data/swahili")
english_dir = download(en_lang_code, "./data/english")
swahili_tts = TTS(swahili_dir)
english_tts = TTS(english_dir)

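# All models above are loaded once at process start, so per-request latency is
# inference only. The ASR model stays on CPU as written; moving it to a GPU
# (e.g. asr_model.to("cuda") when one is available) is an optional tweak that
# this script does not do.
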
# ========================
# APP STATE
# ========================
@dataclass
class AppState:
    conversation: List[dict] = field(default_factory=list)
    last_transcription: Optional[str] = None
    last_reply: Optional[str] = None
    last_wav: Optional[str] = None

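# A fresh AppState instance is stored in gr.State() below, so each browser
# session keeps its own conversation history, last transcription, and last TTS
# output path.
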
# ========================
# UTIL: Safe LLM call
# ========================
def safe_predict(prompt: str, api_name: str = llm_api_name, timeout: int = 30) -> str:
    """
    Calls gradio_client.Client.predict() but defends against:
      - gradio_client JSON schema parsing errors
      - endpoints returning bool/list/tuple/dict
      - other exceptions
    Always returns a string (never bool or a non-iterable).
    """
    try:
        result = llm_client.predict(query=prompt, api_name=api_name)
        print(f"[LLM] raw result: {repr(result)} (type={type(result)})")
    except Exception as e:
        # If gradio_client fails (schema issues etc.), catch and return an error message
        print("[LLM] predict() raised an exception:")
        traceback.print_exc()
        return f"Error: could not contact LLM endpoint ({str(e)})"

    # Convert whatever we got into a string safely
    if isinstance(result, str):
        return result.strip()
    if isinstance(result, (list, tuple)):
        try:
            return " ".join(map(str, result)).strip()
        except Exception:
            return str(result)
    # For bool/dict/None/other -> stringify
    try:
        return str(result).strip()
    except Exception as e:
        print("[LLM] Failed to stringify result:", e)
        return "Error: LLM returned an unsupported type."

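# Usage sketch (hedged, illustrative prompt only): safe_predict always returns
# a str, so callers can append the reply to the conversation without any type
# checks, e.g.
#   reply = safe_predict("[system] You are a calorie assistant.\n[user] I ate rice and beans\n")
#   assert isinstance(reply, str)
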
# ========================
# ASR (Wav2Vec2) helpers
# ========================
def write_temp_wav_from_gr_numpy(audio_tuple: Tuple[int, np.ndarray]) -> str:
    """
    Gradio audio (type='numpy') yields (sample_rate, np_array).
    np_array shape: (n_samples, n_channels) or (n_samples,)
    We write it to a temporary WAV file with soundfile and return the path.
    """
    sr, array = audio_tuple
    if array is None:
        raise ValueError("Empty audio")
    # If stereo, convert to mono by averaging channels (keep the original dtype
    # so int16 input is still written as int16 by soundfile)
    if array.ndim == 2:
        array = array.mean(axis=1).astype(array.dtype)
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp_name = tmp.name
    tmp.close()
    sf.write(tmp_name, array, sr)
    return tmp_name

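# Note on the helper above: NamedTemporaryFile(delete=False) plus an explicit
# close() means the file can be reopened by path (torchaudio does so next);
# the caller in process_audio_start removes it once transcription is done.
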
def transcribe_wav_file(wav_path: str) -> str:
    """Load with torchaudio (resampling to 16 kHz if needed), then transcribe."""
    waveform, sr = torchaudio.load(wav_path)  # waveform: (channels, samples)
    # Convert to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = waveform.squeeze(0).numpy()
    # Resample if necessary
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        waveform = resampler(torch.from_numpy(waveform)).numpy()
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = asr_model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription

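# Quick local sanity check (hedged sketch; "sample.wav" is a hypothetical test
# clip, not part of this repo):
#   print(transcribe_wav_file("sample.wav"))
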
# ========================
# TTS helper
# ========================
def synthesize_text_to_wav(text: str) -> Optional[str]:
    """Detect the language and synthesize to ./output.wav (overwritten on each call)."""
    if not text:
        return None
    try:
        lang = detect(text)
    except Exception:
        lang = "en"
    wav_path = "./output.wav"
    try:
        if lang and lang.startswith("sw"):
            swahili_tts.synthesis(text, wav_path=wav_path)
        else:
            english_tts.synthesis(text, wav_path=wav_path)
        return wav_path
    except Exception as e:
        print("[TTS] synthesis failed:", e)
        traceback.print_exc()
        return None

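# Routing example (hedged): langdetect typically labels a Swahili sentence such
# as "Habari za leo" with "sw", which selects the Swahili voice; any text it
# cannot classify falls into the except branch above and defaults to English.
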
# ========================
# GRADIO EVENT HANDLERS (wired to the UI below)
# ========================
def process_audio_start(audio: Tuple[int, np.ndarray], state: AppState):
    """
    Called when recording stops (or from whichever event you wire it to).
    Transcribes the incoming audio and appends the user message to the conversation.
    Returns the updated state and the latest transcription (so the UI can show it).
    """
    try:
        if audio is None:
            return state, ""
        wav = write_temp_wav_from_gr_numpy(audio)
        transcription = transcribe_wav_file(wav)
        print(f"[ASR] transcription: {transcription!r}")
        state.last_transcription = transcription
        # Append the user message for context
        state.conversation.append({"role": "user", "content": transcription})
        # Clean up the temporary wav
        try:
            os.remove(wav)
        except Exception:
            pass
        return state, transcription
    except Exception as e:
        print("[ASR] error:", e)
        traceback.print_exc()
        return state, f"Error in transcription: {str(e)}"

def generate_reply_stop(state: AppState):
    """
    Called once a transcription is present in state (i.e. after stop_recording).
    Generates a reply with safe_predict, appends it to the conversation, synthesizes TTS,
    and returns the updated state, the chat history (for the Chatbot), and the output wav path.
    """
    try:
        # Build the messages for the LLM from state.conversation,
        # prefixed with the system prompt for the diet/calorie assistant.
        system_prompt = (
            "In conversation with the user, ask questions to estimate and provide (1) total calories, "
            "(2) protein, carbs, and fat in grams, (3) fiber and sugar content. Only ask one question at a time. "
            "Be conversational and natural."
        )
        messages = [{"role": "system", "content": system_prompt}] + state.conversation

        # The remote Space expects a plain-text `query`, so join the messages into a
        # single prompt. If your Space accepts structured messages, adapt accordingly.
        prompt_text = ""
        for m in messages:
            role = m.get("role", "user")
            content = m.get("content", "")
            prompt_text += f"[{role}] {content}\n"

        reply_text = safe_predict(prompt_text, api_name=llm_api_name)
        print("[LLM] reply:", reply_text)

        # Add the assistant reply to the conversation
        state.conversation.append({"role": "assistant", "content": reply_text})
        state.last_reply = reply_text

        # Synthesize the reply to a wav file (TTS)
        wav_path = synthesize_text_to_wav(reply_text)
        state.last_wav = wav_path

        # Build the chat history for gr.Chatbot, which expects a list of
        # (user_msg, bot_msg) pairs: walk the conversation and pair each user
        # message with the assistant message that follows it.
        pairs = []
        conv = state.conversation
        i = 0
        while i < len(conv):
            if conv[i]["role"] == "user":
                user = conv[i]["content"]
                # Look ahead for the assistant reply
                assistant = ""
                if i + 1 < len(conv) and conv[i + 1]["role"] == "assistant":
                    assistant = conv[i + 1]["content"]
                    i += 1
                pairs.append((user, assistant))
            i += 1

        return state, pairs, wav_path
    except Exception as e:
        print("[LLM/TTS] error:", e)
        traceback.print_exc()
        return state, [("error", f"Error generating reply: {str(e)}")], None

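# Pairing illustration for the converter above:
#   [{"role": "user", "content": u1}, {"role": "assistant", "content": a1},
#    {"role": "user", "content": u2}]
# becomes [(u1, a1), (u2, "")], which is the tuple shape gr.Chatbot renders.
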
# ========================
# CLIENT-SIDE VAD JS (embedded)
# ========================
custom_js = r"""
async function main() {
  // Load ONNX runtime and the VAD library dynamically
  const script1 = document.createElement("script");
  script1.src = "https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.js";
  document.head.appendChild(script1);

  const script2 = document.createElement("script");
  script2.onload = async () => {
    console.log("VAD loaded");
    var record = document.querySelector('.record-button');
    if (record) record.textContent = "Just Start Talking!";
    // Create MicVAD and auto-click the record/stop buttons
    try {
      const myvad = await vad.MicVAD.new({
        onSpeechStart: () => {
          var record = document.querySelector('.record-button');
          var player = document.querySelector('#streaming-out');
          if (record && (!player || player.paused)) {
            record.click();
          }
        },
        onSpeechEnd: () => {
          var stop = document.querySelector('.stop-button');
          if (stop) stop.click();
        }
      });
      myvad.start();
    } catch (e) {
      console.warn("VAD init failed:", e);
    }
  };
  script2.src = "https://cdn.jsdelivr.net/npm/@ricky0123/vad-web/dist/bundle.min.js";
  document.head.appendChild(script2);
}
main();
"""

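# The VAD script above auto-clicks Gradio's buttons: onSpeechStart presses
# ".record-button" and onSpeechEnd presses ".stop-button", so the
# stop_recording event wired below fires at the end of each detected utterance.
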
289
- # ========================
290
- # BUILD GRADIO UI
291
- # ========================
292
- with gr.Blocks(js=custom_js, title="ASR → LLM → TTS (Safe)") as demo:
293
- gr.Markdown("## Speak: ASR → LLM → TTS (defensive, production-friendly)")
294
-
295
- state = gr.State(AppState())
296
-
297
- with gr.Row():
298
- input_audio = gr.Audio(
299
- label="🎙 Speak (microphone)",
300
- source="microphone", # <-- Added source argument here
301
- type="numpy",
302
- streaming=False,
303
- show_label=True,
304
- )
305
-
306
- with gr.Row():
307
- transcription_out = gr.Textbox(label="Transcription", interactive=False)
308
- with gr.Row():
309
- chatbot = gr.Chatbot(label="Conversation")
310
- with gr.Row():
311
- output_audio = gr.Audio(label="Assistant speech (TTS)", type="filepath")
312
-
313
- # Wire events:
314
- # When recording starts/stops - process transcription and update UI
315
- input_audio.start_recording(
316
- fn=process_audio_start,
317
- inputs=[input_audio, state],
318
- outputs=[state, transcription_out],
319
- )
320
-
321
- # When recording stops - generate reply and update chatbot + audio output
322
- input_audio.stop_recording(
323
- fn=generate_reply_stop,
324
- inputs=[state],
325
- outputs=[state, chatbot, output_audio],
326
- )
327
-
328
- # Manual trigger button to generate reply if needed
329
- gen_btn = gr.Button("Generate reply (manual)")
330
- gen_btn.click(fn=generate_reply_stop, inputs=[state], outputs=[state, chatbot, output_audio])
331
-
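# Optional (hedged): if several users will talk to the Space at once, calling
# demo.queue() before launch() is a common addition; it is not enabled here.
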
# ========================
# LAUNCH
# ========================
if __name__ == "__main__":
    demo.launch()
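
# Suggested dependencies for this Space, taken from the imports above (exact
# versions are assumptions, not pinned here):
#   gradio, gradio_client, transformers, torch, torchaudio, soundfile,
#   ttsmms, langdetect, numpy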