File size: 8,682 Bytes
410fd66 a6dea81 410fd66 573cc21 fcdc0cf bc64286 fcdc0cf bc64286 68390a5 573cc21 fcdc0cf 573cc21 fcdc0cf 45a579f 573cc21 fcdc0cf 573cc21 bc64286 a6dea81 573cc21 fcdc0cf a6dea81 fcdc0cf 573cc21 fcdc0cf 573cc21 bc64286 fcdc0cf d2ad93f 573cc21 432d77e dd19451 fcdc0cf a6dea81 573cc21 a6dea81 573cc21 fcdc0cf a6dea81 432d77e fcdc0cf a6dea81 573cc21 bc64286 a6dea81 573cc21 410fd66 432d77e a6dea81 410fd66 a6dea81 410fd66 a6dea81 410fd66 bc64286 a6dea81 410fd66 a6dea81 410fd66 a6dea81 410fd66 a6dea81 410fd66 a6dea81 410fd66 a937006 432d77e a937006 410fd66 bc64286 410fd66 a6dea81 410fd66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
import hashlib
import os
import shutil
import tempfile
from datetime import datetime

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from tenacity import retry, stop_after_attempt, wait_fixed
from transformers import pipeline
# Initialize local models with retry logic
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def load_whisper_model():
    """Build the local Whisper speech-recognition pipeline on the CPU.

    The tenacity decorator re-runs this function up to 3 times, waiting 2 s
    between attempts, whenever it raises.

    Returns:
        The transformers "automatic-speech-recognition" pipeline for
        ``openai/whisper-tiny.en``.

    Raises:
        Exception: whatever ``pipeline`` raised, after logging it (re-raised so
        the retry decorator sees the failure).
    """
    asr_options = {
        "model": "openai/whisper-tiny.en",
        "device": -1,  # -1 selects the CPU; device=0 would use the first GPU
        "model_kwargs": {"use_safetensors": True},
    }
    try:
        asr = pipeline("automatic-speech-recognition", **asr_options)
        print("Whisper model loaded successfully.")
        return asr
    except Exception as err:
        print(f"Failed to load Whisper model: {str(err)}")
        raise
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def load_symptom_model():
    """Load the symptom-to-disease text classifier, falling back to a generic model.

    Tries ``abhirajeshbhai/symptom-2-disease-net`` first. If that load fails,
    loads ``distilbert-base-uncased`` instead. The module-level
    ``is_fallback_model`` flag records which one is in use, so that
    predictions from the fallback can be labelled as such. The tenacity
    decorator retries the whole sequence (3 attempts, 2 s apart) only when
    both loads fail.

    Returns:
        A transformers "text-classification" pipeline.

    Raises:
        Exception: the fallback's load error on each failed attempt. Once the
        retries are exhausted, tenacity raises its ``RetryError`` (the
        decorator does not set ``reraise=True``).
    """
    global is_fallback_model
    try:
        model = pipeline(
            "text-classification",
            model="abhirajeshbhai/symptom-2-disease-net",
            device=-1,  # CPU
            model_kwargs={"use_safetensors": True},
        )
        is_fallback_model = False
        print("Symptom-2-Disease model loaded successfully.")
        return model
    except Exception as e:
        print(f"Failed to load Symptom-2-Disease model: {str(e)}")

    # Fallback to a generic model. NOTE(review): distilbert-base-uncased is a
    # base checkpoint, so its classification head is presumably untrained and
    # its labels (e.g. LABEL_0) are not disease names -- hence the flag below.
    try:
        model = pipeline(
            "text-classification",
            model="distilbert-base-uncased",
            device=-1,
        )
    except Exception as fallback_e:
        print(f"Fallback model failed: {str(fallback_e)}")
        raise
    # Previously never set on this path, so fallback predictions were
    # presented as if they came from the real symptom model.
    is_fallback_model = True
    print("Fallback to distilbert-base-uncased model.")
    return model
# Module-level model handles. Each stays None if its loader ultimately fails,
# so the request handlers can report "model not loaded" instead of the app
# failing at import time.
whisper = None
symptom_classifier = None
# Whether the symptom classifier in use is the generic fallback model (read by
# analyze_symptoms to tag its predictions). It must only become True where the
# fallback model is actually loaded -- never when loading fails outright,
# because then there is no classifier at all.
is_fallback_model = False

try:
    whisper = load_whisper_model()
except Exception as e:
    print(f"Whisper model initialization failed after retries: {str(e)}")

try:
    symptom_classifier = load_symptom_model()
except Exception as e:
    print(f"Symptom model initialization failed after retries: {str(e)}")
    # symptom_classifier remains None. is_fallback_model is deliberately left
    # unchanged: no classifier (fallback or otherwise) was loaded.
def compute_file_hash(file_path):
    """Return the hex MD5 digest of the file at *file_path*.

    The file is streamed in 4096-byte blocks, so large recordings are never
    held in memory at once. The digest serves as a content fingerprint in log
    and debug output; it is not meant for security.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def transcribe_audio(audio_file):
    """Transcribe an English speech recording with the local Whisper pipeline.

    The file is decoded and resampled to 16 kHz with librosa, then checked for
    minimum length and loudness. The samples are handed to Whisper in memory,
    with 5-beam search.

    Args:
        audio_file: Path to an audio file that librosa can decode.

    Returns:
        str: the stripped transcription on success. Failures are returned as
        user-facing strings rather than raised:

        - Messages starting with ``"Error"`` cover a missing model,
          too-short/too-quiet audio, a repetitive transcription, and any
          exception. Exceptions are prefixed ``"Error transcribing audio:"``.
        - ``"Transcription empty. ..."`` is returned when Whisper produces no
          text.
    """
    if not whisper:
        return "Error: Whisper model not loaded. Check logs for details or ensure sufficient compute resources."
    try:
        # Load and validate audio (librosa resamples to 16 kHz).
        audio, sr = librosa.load(audio_file, sr=16000)
        if len(audio) < sr // 10:  # shorter than 0.1 s (1600 samples)
            # The message previously claimed a 1-second minimum, which
            # contradicted the 0.1 s threshold actually enforced.
            return "Error: Audio too short. Please provide audio of at least 0.1 seconds."
        if np.max(np.abs(audio)) < 1e-4:  # peak amplitude indicates near-silence
            return "Error: Audio too quiet. Please provide clear audio describing symptoms in English."

        # Pass the already-decoded samples straight to the pipeline instead of
        # round-tripping through a temporary WAV. That temp file leaked whenever
        # transcription raised, and it used a fixed /tmp name that concurrent
        # requests could share.
        with torch.no_grad():
            result = whisper(
                {"raw": audio, "sampling_rate": sr},
                generate_kwargs={"num_beams": 5},  # beam search
            )
        transcription = result.get("text", "").strip()
        print(f"Transcription: {transcription}")

        if not transcription:
            return "Transcription empty. Please provide clear audio describing symptoms in English."
        # Whisper can loop on a phrase for noisy input. Treat "fewer than half
        # the words are distinct" as a sign of that.
        words = transcription.split()
        if len(words) > 5 and len(set(words)) < len(words) / 2:
            return "Error: Transcription appears repetitive. Please provide clear, non-repetitive audio describing symptoms."
        return transcription
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
def analyze_symptoms(text):
    """Classify a symptom description with the local text-classification model.

    Args:
        text: Free-text symptom description, normally the output of
            transcribe_audio.

    Returns:
        tuple[str, float]: ``(label, score)`` for the top prediction. Failures
        are reported in-band with a score of 0.0:

        - ``"Error: ..."`` when no classifier is loaded.
        - ``"No valid transcription for analysis."`` for empty or
          whitespace-only input, or a transcription failure message.
        - ``"No health condition predicted"`` when the pipeline returns nothing.
        - ``"Error analyzing symptoms: ..."`` if classification raises.

        When ``is_fallback_model`` is set, the label is suffixed with
        ``" (using fallback model)"``.
    """
    if not symptom_classifier:
        return "Error: Symptom-2-Disease model not loaded. Check logs for details or ensure sufficient compute resources.", 0.0
    try:
        # transcribe_audio reports failures as strings instead of raising.
        # Classifying one of those (e.g. "Error: Audio too short...") would
        # produce a made-up diagnosis, so reject every failure message, not
        # only "Error transcribing ...".
        if (
            not text
            or not text.strip()
            or "Error transcribing" in text
            or text.startswith(("Error:", "Transcription empty."))
        ):
            return "No valid transcription for analysis.", 0.0
        with torch.no_grad():
            result = symptom_classifier(text)
        if result and isinstance(result, list) and len(result) > 0:
            prediction = result[0]["label"]
            score = result[0]["score"]
            if is_fallback_model:
                print("Warning: Using fallback model (distilbert-base-uncased). Results may be less accurate.")
                prediction = f"{prediction} (using fallback model)"
            print(f"Health Prediction: {prediction}, Score: {score:.4f}")
            return prediction, score
        return "No health condition predicted", 0.0
    except Exception as e:
        return f"Error analyzing symptoms: {str(e)}", 0.0
# Prefixes of the messages transcribe_audio returns in place of a transcription.
_TRANSCRIPTION_FAILURE_PREFIXES = ("Error:", "Error transcribing", "Transcription empty.")
# Prefixes of the messages analyze_symptoms returns in place of a prediction.
_PREDICTION_FAILURE_PREFIXES = ("Error", "No valid transcription")
# Substrings (matched case-insensitively) marking a request for medication or
# treatment advice, which this tool refuses.
_TREATMENT_KEYWORDS = ("medicine", "medication", "treatment")
_DISCLAIMER = "\n**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice."


def _asks_for_treatment(transcription):
    """Return True if the transcription mentions medicine, medication or treatment."""
    lowered = transcription.lower()
    return any(keyword in lowered for keyword in _TREATMENT_KEYWORDS)


def _copy_to_unique_temp(audio_file):
    """Copy *audio_file* to a new uniquely named temp file and return its path.

    The name keeps a timestamp prefix and the original basename (and therefore
    its extension) as a suffix. The source file is never modified.
    """
    stamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
    fd, path = tempfile.mkstemp(prefix=f"{stamp}_", suffix=f"_{os.path.basename(audio_file)}")
    os.close(fd)
    try:
        shutil.copyfile(audio_file, path)
    except Exception:
        os.remove(path)  # don't leave an empty temp file behind
        raise
    return path


def analyze_voice(audio_file):
    """Turn a spoken symptom description into a text health-assessment report.

    Steps:

    1. Copy the recording to a private, uniquely named temp file.
    2. Log its MD5 hash and basic signal statistics.
    3. Transcribe it.
    4. Refuse medication/treatment requests.
    5. Classify the symptoms.
    6. Format the feedback with debug info and a disclaimer.

    The private copy is always deleted afterwards. The caller's file (an
    upload or a bundled sample) is never moved or removed.

    Args:
        audio_file: Path to the recording. May be None or empty (e.g. a form
            submitted without audio), which yields an error message.

    Returns:
        str: The feedback report, or a user-facing error message.
    """
    if not audio_file:
        return "Error: No audio provided. Please record or upload a voice sample."
    working_copy = None
    try:
        # Work on a uniquely named copy (avoids Gradio file reuse) instead of
        # renaming the input. os.rename into /tmp/gradio failed when that
        # directory was missing or on another filesystem, and it moved (and
        # later deleted) the bundled sample files.
        working_copy = _copy_to_unique_temp(audio_file)
        file_hash = compute_file_hash(working_copy)
        print(f"Processing audio file: {working_copy}, Hash: {file_hash}")
        # Decode once up front so unreadable files fail here, and log stats.
        audio, sr = librosa.load(working_copy, sr=16000)
        print(f"Audio shape: {audio.shape}, Sampling rate: {sr}, Duration: {len(audio)/sr:.2f}s, Mean: {np.mean(audio):.4f}, Std: {np.std(audio):.4f}")

        transcription = transcribe_audio(working_copy)
        # Catch every failure message, not only "Error transcribing ...";
        # otherwise e.g. "Error: Audio too short..." was classified as symptoms.
        if transcription.startswith(_TRANSCRIPTION_FAILURE_PREFIXES):
            return transcription

        if _asks_for_treatment(transcription):
            feedback = "Error: This tool does not provide medication or treatment advice. Please describe symptoms only (e.g., 'I have a fever')."
            feedback += f"\n\n**Debug Info**: Transcription = '{transcription}', File Hash = {file_hash}"
            return feedback + _DISCLAIMER

        prediction, score = analyze_symptoms(transcription)
        # Includes "Error: ... model not loaded" and "No valid transcription",
        # which were previously reported as a "Possible health condition".
        if prediction.startswith(_PREDICTION_FAILURE_PREFIXES):
            return prediction

        if prediction == "No health condition predicted":
            feedback = "No significant health indicators detected."
        else:
            feedback = f"Possible health condition: {prediction} (confidence: {score:.4f}). Consult a doctor."
        feedback += f"\n\n**Debug Info**: Transcription = '{transcription}', Prediction = {prediction}, Confidence = {score:.4f}, File Hash = {file_hash}"
        return feedback + _DISCLAIMER
    except Exception as e:
        return f"Error processing audio: {str(e)}"
    finally:
        # Runs on every exit path, so early returns and errors no longer leak
        # the temporary copy.
        if working_copy is not None:
            try:
                os.remove(working_copy)
                print(f"Deleted temporary audio file: {working_copy}")
            except OSError as cleanup_err:
                print(f"Failed to delete audio file: {str(cleanup_err)}")
def test_with_sample_audio():
    """Run analyze_voice over the bundled sample clips and collect the reports.

    A sample that does not exist on disk produces a "Sample not found" line
    instead of an exception.

    Returns:
        str: One report per sample, in order, joined with newlines.
    """
    sample_paths = ("audio_samples/sample.wav", "audio_samples/common_voice_en.wav")
    reports = [
        analyze_voice(path) if os.path.exists(path) else f"Sample not found: {path}"
        for path in sample_paths
    ]
    return "\n".join(reports)
# Gradio interface: one audio input, delivered to analyze_voice as a file path,
# mapped to a plain-text report.
iface = gr.Interface(
    title="Health Voice Analyzer",
    fn=analyze_voice,
    inputs=gr.Audio(label="Record or Upload Voice", type="filepath"),
    outputs=gr.Textbox(label="Health Assessment Feedback"),
    description=(
        "Record or upload a voice sample describing symptoms (e.g., 'I have a fever') "
        "for preliminary health assessment. Supports English only. "
        "Use clear audio (WAV, 16kHz). Do not ask for medication or treatment advice."
    ),
)
if __name__ == "__main__":
    # Smoke-test the pipeline on the bundled sample clips, then serve the UI on
    # all network interfaces, port 7860. (The stray " |" pasted after the
    # launch call was a syntax error.)
    print(test_with_sample_audio())
    iface.launch(server_name="0.0.0.0", server_port=7860)