import os
import tempfile

import gradio as gr
import spacy
import spacy.cli
import torch
from pydub import AudioSegment
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from transformers.pipelines.audio_utils import ffmpeg_read

# Download the spaCy English model used below for sentence segmentation.
spacy.cli.download("en_core_web_sm")

MODEL_NAME = "openai/whisper-large-v3-turbo"
BATCH_SIZE = 8
FILE_LIMIT_MB = 1000

# Use the first CUDA GPU if one is available; otherwise run on CPU.
device = 0 if torch.cuda.is_available() else "cpu"

# Whisper ASR pipeline used for transcription.
whisper_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
)
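# chunk_length_s=30 makes the pipeline split long recordings into overlapping
# 30-second windows, so arbitrarily long consultations can be transcribed;
# BATCH_SIZE controls how many of those windows are batched per forward pass
# during transcription.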
# spaCy handles sentence segmentation; the SentenceTransformer provides
# sentence embeddings for SOAP-section classification.
nlp = spacy.load("en_core_web_sm")
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Medical summarization model. summarize_section below calls
# summarizer_model.generate directly to control output length per SOAP
# section; the "summarization" pipeline wrapper is currently unused.
summarizer_model_name = "Mahalingam/DistilBart-Med-Summary"
tokenizer = AutoTokenizer.from_pretrained(summarizer_model_name)
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(summarizer_model_name)
summarizer = pipeline("summarization", model=summarizer_model, tokenizer=tokenizer)

# Short natural-language descriptions of each SOAP section, used as semantic
# anchors for embedding-based sentence classification.
soap_prompts = {
    "subjective": "Personal reports, symptoms described by patients, or personal health concerns. Details reflecting individual symptoms or health descriptions.",
    "objective": "Observable facts, clinical findings, professional observations, specific medical specialties, and diagnoses.",
    "assessment": "Clinical assessments, expertise-based opinions on conditions, and significance of medical interventions. Focused on medical evaluations or patient condition summaries.",
    "plan": "Future steps, recommendations for treatment, follow-up instructions, and healthcare management plans.",
}
soap_embeddings = {section: embedder.encode(prompt, convert_to_tensor=True) for section, prompt in soap_prompts.items()}
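# Each section description is embedded once up front; classify_sentence then
# only has to embed the incoming sentence and compare it against these four
# cached vectors with cosine similarity, picking the best-scoring section.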
def convert_mp4_to_mp3(mp4_path, mp3_path):
    """Extract the audio track from an MP4 file and save it as MP3."""
    try:
        audio = AudioSegment.from_file(mp4_path, format="mp4")
        audio.export(mp3_path, format="mp3")
    except Exception as e:
        raise RuntimeError(f"Error converting MP4 to MP3: {e}")
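# Note: pydub shells out to an ffmpeg (or avconv) binary to decode MP4
# containers, so ffmpeg must be installed and on PATH; the same applies to
# ffmpeg_read used during transcription below.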
def transcribe_audio(audio_path):
    """Transcribe an audio file to text with the Whisper pipeline."""
    try:
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        # ffmpeg_read expects the raw file bytes, not a path; it decodes them
        # into a float array at Whisper's expected sampling rate.
        sampling_rate = whisper_pipeline.feature_extractor.sampling_rate
        with open(audio_path, "rb") as f:
            audio_array = ffmpeg_read(f.read(), sampling_rate)
        inputs = {"array": audio_array, "sampling_rate": sampling_rate}

        result = whisper_pipeline(inputs, batch_size=BATCH_SIZE, return_timestamps=False)
        return result["text"]
    except Exception as e:
        return f"Error during transcription: {e}"


def classify_sentence(sentence):
    """Assign a sentence to the SOAP section whose prompt it is most similar to."""
    sentence_embedding = embedder.encode(sentence, convert_to_tensor=True)
    similarities = {section: util.pytorch_cos_sim(sentence_embedding, soap_embeddings[section]).item() for section in soap_prompts}
    return max(similarities, key=similarities.get)
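# Illustrative behaviour only: a sentence like "Patient reports a dull
# headache for three days" would typically be routed to "subjective", while
# "Blood pressure measured at 140/90" would typically land in "objective".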
def summarize_section(section_text):
    """Summarize one SOAP section, targeting roughly half its original length."""
    word_count = len(section_text.split())
    # Very short sections are returned verbatim instead of being summarized.
    if word_count < 50:
        return section_text
    target_length = int(word_count * 0.50)
    inputs = tokenizer.encode(section_text, return_tensors="pt", truncation=True, max_length=1024)
    summary_ids = summarizer_model.generate(
        inputs,
        max_length=target_length,
        min_length=int(target_length * 0.45),
        length_penalty=1.0,
        num_beams=4,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
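# Note that target_length is computed in words, but max_length/min_length are
# interpreted by generate() as token counts, so the "50% compression" target
# is an approximation rather than an exact word budget.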
def soap_analysis(text):
    """Split text into sentences, route each to a SOAP section, and summarize each section."""
    doc = nlp(text)
    soap_note = {section: "" for section in soap_prompts}

    # Assign every sentence to the best-matching SOAP section.
    for sentence in doc.sents:
        section = classify_sentence(sentence.text)
        soap_note[section] += sentence.text + " "

    # Condense each section so the final note stays readable.
    for section in soap_note:
        soap_note[section] = summarize_section(soap_note[section].strip())

    return format_soap_output(soap_note)


def format_soap_output(soap_note):
    return (
        f"Subjective:\n{soap_note['subjective']}\n\n"
        f"Objective:\n{soap_note['objective']}\n\n"
        f"Assessment:\n{soap_note['assessment']}\n\n"
        f"Plan:\n{soap_note['plan']}\n"
    )
def process_file(file, user_prompt):
    """Generate a SOAP note from an uploaded audio/video file (user_prompt is not used yet)."""
    # Reject uploads over the configured size limit.
    file_size_mb = os.path.getsize(file.name) / (1024 * 1024)
    if file_size_mb > FILE_LIMIT_MB:
        return f"File too large: {file_size_mb:.1f} MB (limit is {FILE_LIMIT_MB} MB)."

    # Video uploads are converted to a temporary MP3 before transcription.
    if file.name.endswith(".mp4"):
        temp_mp3_path = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
        try:
            convert_mp4_to_mp3(file.name, temp_mp3_path)
            audio_path = temp_mp3_path
        except Exception as e:
            return f"Error during MP4 to MP3 conversion: {e}"
    else:
        audio_path = file.name

    transcription = transcribe_audio(audio_path)
    print("Transcribed Text: ", transcription)

    soap_note = soap_analysis(transcription)
    print("SOAP Notes: ", soap_note)

    # Clean up the temporary MP3 created for video uploads.
    if file.name.endswith(".mp4"):
        os.remove(temp_mp3_path)

    return soap_note
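# Rough usage sketch outside Gradio (hypothetical file name): any object with
# a ".name" attribute pointing at a real audio/video file works here, which is
# what the gr.File component supplies.
#
#   from types import SimpleNamespace
#   note = process_file(SimpleNamespace(name="consultation.mp4"), "")
#   print(note)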
def process_text(text, user_prompt):
    """Generate a SOAP note directly from pasted text (user_prompt is not used yet)."""
    soap_note = soap_analysis(text)
    print(soap_note)
    return soap_note
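# Quick sanity check (illustrative input only):
#
#   sample = ("Patient reports chest tightness since yesterday. "
#             "Blood pressure is 150/95 and heart rate 92. "
#             "Likely hypertensive episode. "
#             "Start amlodipine and review in one week.")
#   print(process_text(sample, ""))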
def launch_gradio():
    """Build and launch the Gradio UI with file and text input tabs."""
    with gr.Blocks(theme=gr.themes.Default()) as demo:
        gr.Markdown("# Enhanced Video to SOAP Note Generator")

        with gr.Tab("Audio/Video File to SOAP"):
            gr.Interface(
                fn=process_file,
                inputs=[
                    gr.File(label="Upload Audio/Video File"),
                    gr.Textbox(label="Enter Prompt for Template", placeholder="Enter a detailed prompt...", lines=6),
                ],
                outputs=[gr.Textbox(label="SOAP Note")],
            )

        with gr.Tab("Text Input to SOAP"):
            gr.Interface(
                fn=process_text,
                inputs=[
                    gr.Textbox(label="Enter Text", placeholder="Enter medical notes...", lines=6),
                    gr.Textbox(label="Enter Prompt for Template", placeholder="Enter a detailed prompt...", lines=6),
                ],
                outputs=[gr.Textbox(label="SOAP Note")],
            )

    demo.launch(share=True, debug=True)
if __name__ == "__main__":
    launch_gradio()