Futuresony committed (verified)
Commit 326df18 · 1 Parent(s): 727f54d

Update app.py

Files changed (1):
  1. app.py +54 -28

app.py CHANGED
@@ -2,31 +2,34 @@ import gradio as gr
 import torch
 import torchaudio
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
-# from huggingface_hub import InferenceClient # Removed
 from ttsmms import download, TTS
 from langdetect import detect
-from gradio_client import Client # Added
+from gradio_client import Client

+# =========================
 # Load ASR Model
+# =========================
 asr_model_name = "Futuresony/Future-sw_ASR-24-02-2025"
 processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
 asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)

-# Load Text Generation Model - Using Gradio Client
-# client = InferenceClient("unsloth/gemma-3-1b-it") # Removed
-llm_client = Client("Futuresony/Mr.Events") # Added
-
-# def format_prompt(user_input): # Removed
-# return f"{user_input}" # Removed
+# =========================
+# Load Text Generation Model via Gradio Client
+# =========================
+llm_client = Client("Futuresony/Mr.Events")

+# =========================
 # Load TTS Models
+# =========================
 swahili_dir = download("swh", "./data/swahili")
 english_dir = download("eng", "./data/english")

 swahili_tts = TTS(swahili_dir)
 english_tts = TTS(english_dir)

+# =========================
 # ASR Function
+# =========================
 def transcribe(audio_file):
     speech_array, sample_rate = torchaudio.load(audio_file)
     resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
@@ -38,39 +41,62 @@ def transcribe(audio_file):
     transcription = processor.batch_decode(predicted_ids)[0]
     return transcription

-# Text Generation Function - Using Gradio Client
+# =========================
+# Text Generation Function (Safe)
+# =========================
 def generate_text(prompt):
-    # formatted_prompt = format_prompt(prompt) # Removed
-    # response = client.text_generation(formatted_prompt, max_new_tokens=250, temperature=0.7, top_p=0.95) # Removed
-    print(f"Generating text for prompt (type: {type(prompt)}): {prompt}") # Debug print
-    result = llm_client.predict(query=prompt, api_name="/chat") # Added
-    print(f"Generated text result (type: {type(result)}): {result}") # Debug print
-    return result.strip() # Modified to return the result from the Gradio Client
+    print(f"[DEBUG] Generating text for prompt: {prompt} (type: {type(prompt)})")
+
+    result = llm_client.predict(query=prompt, api_name="/chat")
+    print(f"[DEBUG] /chat returned: {result} (type: {type(result)})")
+
+    # Ensure result is always a string
+    if not isinstance(result, str):
+        try:
+            result = " ".join(map(str, result)) if isinstance(result, (list, tuple)) else str(result)
+        except Exception as e:
+            print(f"[ERROR] Failed to convert result to string: {e}")
+            result = "Error: Unable to generate text."
+
+    return result.strip()

+# =========================
 # TTS Function
+# =========================
 def text_to_speech(text):
-    print(f"Converting text to speech (type: {type(text)}): {text}") # Debug print
+    print(f"[DEBUG] Converting text to speech: {text} (type: {type(text)})")
     lang = detect(text)
     wav_path = "./output.wav"
-    if lang == "sw":
-        swahili_tts.synthesis(text, wav_path=wav_path)
-    else:
-        english_tts.synthesis(text, wav_path=wav_path)
-    print(f"TTS output path (type: {type(wav_path)}): {wav_path}") # Debug print
+    try:
+        if lang == "sw":
+            swahili_tts.synthesis(text, wav_path=wav_path)
+        else:
+            english_tts.synthesis(text, wav_path=wav_path)
+    except Exception as e:
+        print(f"[ERROR] TTS synthesis failed: {e}")
+        return None
     return wav_path

+# =========================
 # Combined Processing Function
+# =========================
 def process_audio(audio):
-    print(f"Processing audio file (type: {type(audio)}): {audio}") # Debug print
+    print(f"[DEBUG] Processing audio: {audio} (type: {type(audio)})")
+
     transcription = transcribe(audio)
-    print(f"Transcription result (type: {type(transcription)}): {transcription}") # Debug print
+    print(f"[DEBUG] Transcription: {transcription}")
+
     generated_text = generate_text(transcription)
-    print(f"Generated text after function call (type: {type(generated_text)}): {generated_text}") # Debug print
-    speech = text_to_speech(generated_text)
-    print(f"Speech output after function call (type: {type(speech)}): {speech}") # Debug print
-    return transcription, generated_text, speech
+    print(f"[DEBUG] Generated Text: {generated_text}")
+
+    speech_path = text_to_speech(generated_text)
+    print(f"[DEBUG] Speech Path: {speech_path}")
+
+    return transcription, generated_text, speech_path

+# =========================
 # Gradio Interface
+# =========================
 with gr.Blocks() as demo:
     gr.Markdown("<p align='center' style='font-size: 20px;'>End-to-End ASR, Text Generation, and TTS</p>")
     gr.HTML("<center>Upload or record audio. The model will transcribe, generate a response, and read it out.</center>")
@@ -88,4 +114,4 @@ with gr.Blocks() as demo:
     )

 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
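Note: both hunks leave the middle of transcribe() (old lines 33-37) as unshown context, so the resample-and-infer step never appears above. The sketch below shows the usual Wav2Vec2 CTC pattern for that gap; it is illustrative only, not the committed lines, though the names match those defined in app.py.

import torch

# Illustrative sketch of the diff's unshown context (old lines 33-37);
# the actual committed lines are not visible in this diff.
speech = resampler(speech_array).squeeze()  # 16 kHz mono waveform
inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = asr_model(inputs.input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)  # greedy CTC decode
# transcribe() then continues with processor.batch_decode(predicted_ids)[0].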
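The main functional change is routing text generation through the Futuresony/Mr.Events Space via gradio_client instead of huggingface_hub's InferenceClient. Below is a minimal smoke test for that path, assuming the Space is running and exposes /chat with a query parameter exactly as the diff calls it; the Swahili prompt is a made-up example.

from gradio_client import Client

llm_client = Client("Futuresony/Mr.Events")
result = llm_client.predict(query="Habari yako?", api_name="/chat")  # hypothetical sample prompt

# Same string-coercion guard as the new generate_text(): predict()
# can return a tuple when the remote endpoint declares multiple outputs.
if not isinstance(result, str):
    result = " ".join(map(str, result)) if isinstance(result, (list, tuple)) else str(result)
print(result.strip())

Keeping the guard in any caller is prudent, since the return shape depends on how the remote Space declares its outputs.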