RathodHarish committed
Commit bc64286 · verified · Parent: 48d8761

Update app.py

Files changed (1)
  1. app.py +28 -6
app.py CHANGED
@@ -13,7 +13,6 @@ from tenacity import retry, stop_after_attempt, wait_fixed
 @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
 def load_whisper_model():
     try:
-        # Whisper for speech-to-text (English-only)
         model = pipeline(
             "automatic-speech-recognition",
             model="openai/whisper-tiny.en",
@@ -29,7 +28,6 @@ def load_whisper_model():
 @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
 def load_symptom_model():
     try:
-        # Symptom-2-Disease for health analysis
         model = pipeline(
             "text-classification",
             model="abhirajeshbhai/symptom-2-disease-net",
@@ -40,10 +38,22 @@ def load_symptom_model():
         return model
     except Exception as e:
         print(f"Failed to load Symptom-2-Disease model: {str(e)}")
-        raise
+        # Fallback to a generic model
+        try:
+            model = pipeline(
+                "text-classification",
+                model="distilbert-base-uncased",
+                device=-1
+            )
+            print("Fallback to distilbert-base-uncased model.")
+            return model
+        except Exception as fallback_e:
+            print(f"Fallback model failed: {str(fallback_e)}")
+            raise

 whisper = None
 symptom_classifier = None
+is_fallback_model = False

 try:
     whisper = load_whisper_model()
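The fallback branch above loads distilbert-base-uncased, which has no fine-tuned classification head for this task; the pipeline attaches a freshly initialized head, so it can only emit generic LABEL_0/LABEL_1 outputs with near-chance scores rather than disease names. A minimal sketch of how to inspect what the fallback actually returns (not part of this commit; assumes the same transformers pipeline API used in app.py):

    # Hypothetical check: see what the generic fallback model emits.
    # distilbert-base-uncased ships without a sequence-classification head,
    # so transformers initializes one at random and warns about it.
    from transformers import pipeline

    fallback = pipeline("text-classification", model="distilbert-base-uncased", device=-1)
    print(fallback("I have a fever and a cough"))
    # Expect something like [{'label': 'LABEL_0', 'score': 0.5...}], not a disease label.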
@@ -53,7 +63,9 @@ except Exception as e:
 try:
     symptom_classifier = load_symptom_model()
 except Exception as e:
-    print(f"Symptom-2-Disease model initialization failed after retries: {str(e)}")
+    print(f"Symptom model initialization failed after retries: {str(e)}")
+    symptom_classifier = None
+    is_fallback_model = True  # Track if fallback model is used

 def compute_file_hash(file_path):
     """Compute MD5 hash of a file to check uniqueness."""
@@ -79,7 +91,7 @@ def transcribe_audio(audio_file):
     temp_wav = f"/tmp/{os.path.basename(audio_file)}.wav"
     sf.write(temp_wav, audio, sr)

-    # Transcribe with beam search for accuracy
+    # Transcribe with beam search
     with torch.no_grad():
         result = whisper(temp_wav, generate_kwargs={"num_beams": 5})
     transcription = result.get("text", "").strip()
@@ -113,6 +125,9 @@ def analyze_symptoms(text):
     if result and isinstance(result, list) and len(result) > 0:
         prediction = result[0]["label"]
         score = result[0]["score"]
+        if is_fallback_model:
+            print("Warning: Using fallback model (distilbert-base-uncased). Results may be less accurate.")
+            prediction = f"{prediction} (using fallback model)"
         print(f"Health Prediction: {prediction}, Score: {score:.4f}")
         return prediction, score
     return "No health condition predicted", 0.0
@@ -140,6 +155,13 @@ def analyze_voice(audio_file):
     if "Error transcribing" in transcription:
         return transcription

+    # Check for medication-related queries
+    if "medicine" in transcription.lower() or "treatment" in transcription.lower():
+        feedback = "Error: This tool does not provide medication or treatment advice. Please describe symptoms only (e.g., 'I have a fever')."
+        feedback += f"\n\n**Debug Info**: Transcription = '{transcription}', File Hash = {file_hash}"
+        feedback += "\n**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice."
+        return feedback
+
     # Analyze symptoms
     prediction, score = analyze_symptoms(transcription)
     if "Error analyzing" in prediction:
@@ -182,7 +204,7 @@ iface = gr.Interface(
     inputs=gr.Audio(type="filepath", label="Record or Upload Voice"),
     outputs=gr.Textbox(label="Health Assessment Feedback"),
     title="Health Voice Analyzer",
-    description="Record or upload a voice sample describing symptoms for preliminary health assessment. Supports English (transcription), with symptom analysis in English. Use clear audio (WAV, 16kHz) describing symptoms like 'I have a cough.'"
+    description="Record or upload a voice sample describing symptoms (e.g., 'I have a fever') for preliminary health assessment. Supports English only. Use clear audio (WAV, 16kHz). Do not ask for medication or treatment advice."
 )

 if __name__ == "__main__":
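A minimal local smoke test of the changed flow (hypothetical, not part of this commit; assumes the Space is checked out locally with its dependencies installed and that a clear 16 kHz English WAV exists at the placeholder path):

    # Hypothetical smoke test: run the end-to-end analysis on a local recording.
    # "sample_symptoms.wav" is a placeholder path, not a file in this repo.
    from app import analyze_voice

    print(analyze_voice("sample_symptoms.wav"))
    # A recording that mentions "medicine" or "treatment" should now return the
    # refusal message added in this commit instead of a symptom prediction.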
 