karthi311 commited on
Commit
beb5b30
·
verified ·
1 Parent(s): e041b4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -0
app.py CHANGED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from subprocess import Popen, PIPE
4
+ import torch
5
+ import gradio as gr
6
+ from pydub import AudioSegment
7
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
8
+ from transformers.pipelines.audio_utils import ffmpeg_read
9
+ from sentence_transformers import SentenceTransformer, util
10
+ import spacy
11
+
12
+ # Constants
13
+ MODEL_NAME = "openai/whisper-large-v3-turbo"
14
+ BATCH_SIZE = 8
15
+ FILE_LIMIT_MB = 1000
16
+ device = 0 if torch.cuda.is_available() else "cpu"
17
+
18
+ # Whisper pipeline
19
+ whisper_pipeline = pipeline(
20
+ task="automatic-speech-recognition",
21
+ model=MODEL_NAME,
22
+ chunk_length_s=30,
23
+ device=device,
24
+ )
25
+
26
+ # NLP model and other helpers
27
+ nlp = spacy.load("en_core_web_sm")
28
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
29
+
30
+ # Summarization model
31
+ summarizer_model_name = "Mahalingam/DistilBart-Med-Summary"
32
+ tokenizer = AutoTokenizer.from_pretrained(summarizer_model_name)
33
+ summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(summarizer_model_name)
34
+ summarizer = pipeline("summarization", model=summarizer_model, tokenizer=tokenizer)
35
+
36
+ # SOAP prompts and embeddings
37
+ soap_prompts = {
38
+ "subjective": "Personal reports, symptoms described by patients, or personal health concerns. Details reflecting individual symptoms or health descriptions.",
39
+ "objective": "Observable facts, clinical findings, professional observations, specific medical specialties, and diagnoses.",
40
+ "assessment": "Clinical assessments, expertise-based opinions on conditions, and significance of medical interventions. Focused on medical evaluations or patient condition summaries.",
41
+ "plan": "Future steps, recommendations for treatment, follow-up instructions, and healthcare management plans."
42
+ }
43
+ soap_embeddings = {section: embedder.encode(prompt, convert_to_tensor=True) for section, prompt in soap_prompts.items()}
44
+
45
+ # Convert MP4 to MP3
46
+ def convert_mp4_to_mp3(mp4_path, mp3_path):
47
+ try:
48
+ audio = AudioSegment.from_file(mp4_path, format="mp4")
49
+ audio.export(mp3_path, format="mp3")
50
+ except Exception as e:
51
+ raise RuntimeError(f"Error converting MP4 to MP3: {e}")
52
+
53
+ # Transcribe audio
54
+ def transcribe_audio(audio_path):
55
+ try:
56
+ inputs = ffmpeg_read(audio_path, whisper_pipeline.feature_extractor.sampling_rate)
57
+ inputs = {"array": inputs, "sampling_rate": whisper_pipeline.feature_extractor.sampling_rate}
58
+ result = whisper_pipeline(inputs, batch_size=BATCH_SIZE, return_timestamps=False)
59
+ return result["text"]
60
+ except Exception as e:
61
+ return f"Error during transcription: {e}"
62
+
63
+ # Classify the sentence to the correct SOAP section
64
+ def classify_sentence(sentence):
65
+ similarities = {section: util.pytorch_cos_sim(embedder.encode(sentence), soap_embeddings[section]) for section in soap_prompts.keys()}
66
+ return max(similarities, key=similarities.get)
67
+
68
+ # Summarize the section if it's too long
69
+ def summarize_section(section_text):
70
+ if len(section_text.split()) < 50:
71
+ return section_text
72
+ target_length = int(len(section_text.split()) * 0.50)
73
+ inputs = tokenizer.encode(section_text, return_tensors="pt", truncation=True, max_length=1024)
74
+ summary_ids = summarizer_model.generate(
75
+ inputs,
76
+ max_length=target_length,
77
+ min_length=int(target_length * 0.45),
78
+ length_penalty=1.0,
79
+ num_beams=4
80
+ )
81
+ return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
82
+
83
+ # Analyze the SOAP content and divide into sections
84
+ def soap_analysis(text):
85
+ doc = nlp(text)
86
+ soap_note = {section: "" for section in soap_prompts.keys()}
87
+
88
+ for sentence in doc.sents:
89
+ section = classify_sentence(sentence.text)
90
+ soap_note[section] += sentence.text + " "
91
+
92
+ # Summarize each section of the SOAP note
93
+ for section in soap_note:
94
+ soap_note[section] = summarize_section(soap_note[section].strip())
95
+
96
+ return format_soap_output(soap_note)
97
+
98
+ # Format the SOAP note output
99
+ def format_soap_output(soap_note):
100
+ return (
101
+ f"Subjective:\n{soap_note['subjective']}\n\n"
102
+ f"Objective:\n{soap_note['objective']}\n\n"
103
+ f"Assessment:\n{soap_note['assessment']}\n\n"
104
+ f"Plan:\n{soap_note['plan']}\n"
105
+ )
106
+
107
+ # Process file function for audio/video to SOAP
108
+ def process_file(file, user_prompt):
109
+ # Determine file type and convert if necessary
110
+ if file.name.endswith(".mp4"):
111
+ temp_mp3_path = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
112
+ try:
113
+ convert_mp4_to_mp3(file.name, temp_mp3_path)
114
+ audio_path = temp_mp3_path
115
+ except Exception as e:
116
+ return f"Error during MP4 to MP3 conversion: {e}", "", ""
117
+ else:
118
+ audio_path = file.name
119
+
120
+ # Transcribe audio
121
+ transcription = transcribe_audio(audio_path)
122
+ print("Transcribed Text: ", transcription)
123
+
124
+ # Perform SOAP analysis
125
+ soap_note = soap_analysis(transcription)
126
+ print("SOAP Notes: ", soap_note)
127
+
128
+ # # Generate template and JSON using LLaMA
129
+ # template_output = llama_query(user_prompt, soap_note)
130
+ # print("Template: ", template_output)
131
+
132
+ # json_output = llama_convert_to_json(template_output)
133
+
134
+ # Clean up temporary files
135
+ if file.name.endswith(".mp4"):
136
+ os.remove(temp_mp3_path)
137
+
138
+ return soap_note#, template_output, json_output
139
+
140
+ # Process text function for text input to SOAP
141
+ def process_text(text, user_prompt):
142
+ soap_note = soap_analysis(text)
143
+ print(soap_note)
144
+
145
+ # template_output = llama_query(user_prompt, soap_note)
146
+ # print(template_output)
147
+ # json_output = llama_convert_to_json(template_output)
148
+
149
+ return soap_note#, template_output, json_output
150
+
151
+ # # Llama query function
152
+ # def llama_query(user_prompt, soap_note, model="llama3.2"):
153
+ # combined_prompt = f"User Instructions:\n{user_prompt}\n\nContext:\n{soap_note}"
154
+ # try:
155
+ # process = Popen(['ollama', 'run', model], stdin=PIPE, stdout=PIPE, stderr=PIPE, text=True, encoding='utf-8')
156
+ # stdout, stderr = process.communicate(input=combined_prompt)
157
+ # if process.returncode != 0:
158
+ # return f"Error: {stderr.strip()}"
159
+ # return stdout.strip()
160
+ # except Exception as e:
161
+ # return f"Unexpected error: {str(e)}"
162
+
163
+ # # Convert the response to JSON format
164
+ # def llama_convert_to_json(template_output, model="llama3.2"):
165
+ # json_prompt = f"Convert the following template into a structured JSON format:\n\n{template_output}"
166
+ # try:
167
+ # process = Popen(['ollama', 'run', model], stdin=PIPE, stdout=PIPE, stderr=PIPE, text=True, encoding='utf-8')
168
+ # stdout, stderr = process.communicate(input=json_prompt)
169
+ # if process.returncode != 0:
170
+ # return f"Error: {stderr.strip()}"
171
+ # return stdout.strip() # Assuming the model outputs a valid JSON string
172
+ # except Exception as e:
173
+ # return f"Unexpected error: {str(e)}"
174
+
175
+ # Gradio interface
176
+ def launch_gradio():
177
+ with gr.Blocks(theme=gr.themes.Default()) as demo:
178
+ gr.Markdown("# Enhanced Video to SOAP Note Generator")
179
+
180
+ with gr.Tab("Audio/Video File to SOAP"):
181
+ gr.Interface(
182
+ fn=process_file,
183
+ inputs=[gr.File(label="Upload Audio/Video File"), gr.Textbox(label="Enter Prompt for Template", placeholder="Enter a detailed prompt...", lines=6)],
184
+ outputs=[
185
+ gr.Textbox(label="SOAP Note"),
186
+ # gr.Textbox(label="Generated Template from LLaMA"),
187
+ # gr.Textbox(label="JSON Output")
188
+ ],
189
+ )
190
+
191
+ with gr.Tab("Text Input to SOAP"):
192
+ gr.Interface(
193
+ fn=process_text,
194
+ inputs=[gr.Textbox(label="Enter Text", placeholder="Enter medical notes...", lines=6), gr.Textbox(label="Enter Prompt for Template", placeholder="Enter a detailed prompt...", lines=6)],
195
+ outputs=[
196
+ gr.Textbox(label="SOAP Note"),
197
+ # gr.Textbox(label="Generated Template from LLaMA"),
198
+ # gr.Textbox(label="JSON Output")
199
+ ],
200
+ )
201
+
202
+ demo.launch(share=True, debug=True)
203
+
204
+ # Run the Gradio app
205
+ if __name__ == "__main__":
206
+ launch_gradio()