yasserrmd commited on
Commit
3f1277b
·
verified ·
1 Parent(s): 0134b10

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -0
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import AutoProcessor, Gemma3nForConditionalGeneration, TextIteratorStreamer
5
+ from PIL import Image
6
+ import threading
7
+ import traceback
8
+ import spaces
9
+
10
+ # -----------------------------
11
+ # Config
12
+ # -----------------------------
13
+ MODEL_ID = "yasserrmd/GemmaECG-Vision"
14
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
+ DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32 # safe CPU dtype
16
+
17
+ # Generation defaults
18
+ GEN_KW = dict(
19
+ max_new_tokens=768,
20
+ do_sample=True,
21
+ temperature=1.0,
22
+ top_p=0.95,
23
+ top_k=64,
24
+ use_cache=True,
25
+ )
26
+
27
+ # Clinical prompt
28
+ CLINICAL_PROMPT = """You are a clinical assistant specialized in ECG interpretation. Given an ECG image, generate a concise, structured, and medically accurate report.
29
+
30
+ Use this exact format:
31
+
32
+ Rhythm:
33
+ PR Interval:
34
+ QRS Duration:
35
+ Axis:
36
+ Bundle Branch Blocks:
37
+ Atrial Abnormalities:
38
+ Ventricular Hypertrophy:
39
+ Q Wave or QS Complexes:
40
+ T Wave Abnormalities:
41
+ ST Segment Changes:
42
+ Final Impression:
43
+
44
+ Guidance:
45
+ - Confirm sinus rhythm only if consistent P waves precede each QRS.
46
+ - Describe PACs only if early, ectopic P waves are visible.
47
+ - Do not diagnose myocardial infarction solely based on QS complexes unless accompanied by other signs (e.g., ST elevation, reciprocal changes, poor R wave progression).
48
+ - Only mention axis deviation if QRS axis is clearly rightward (RAD) or leftward (LAD).
49
+ - Use terms like "suggestive of" or "possible" for uncertain findings.
50
+ - Avoid repetition and keep the report clinically focused.
51
+ - Do not include external references or source citations.
52
+ - Do not diagnose left bundle branch block unless QRS duration is ≥120 ms with typical morphology in leads I, V5, V6.
53
+ - Mark T wave changes in inferior leads as “nonspecific” unless clear ST elevation or reciprocal depression is present.
54
+
55
+ Your goal is to provide a structured ECG summary useful for a cardiologist or internal medicine physician.
56
+ """
57
+
58
+ # -----------------------------
59
+ # Load model & processor
60
+ # -----------------------------
61
+ model = Gemma3nForConditionalGeneration.from_pretrained(
62
+ MODEL_ID, torch_dtype=DTYPE
63
+ ).to(DEVICE).eval()
64
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
65
+
66
+ # -----------------------------
67
+ # Inference (streaming) function
68
+ # -----------------------------
69
+ @spaces.GPU
70
+ def analyze_ecg_stream(image: Image.Image):
71
+ """
72
+ Streams model output into the Gradio textbox.
73
+ Yields incremental text chunks.
74
+ """
75
+ if image is None:
76
+ yield "Please upload an ECG image."
77
+ return
78
+
79
+ # Build a multimodal chat-style message; rely on the model's chat template to inject image tokens.
80
+ messages = [
81
+ {
82
+ "role": "user",
83
+ "content": [
84
+ {"type": "text", "text": CLINICAL_PROMPT},
85
+ {"type": "image"},
86
+ ],
87
+ }
88
+ ]
89
+
90
+ try:
91
+ # Try with chat template first (recommended for chat-tuned models)
92
+ chat_text = processor.apply_chat_template(messages, add_generation_prompt=True)
93
+
94
+ model_inputs = processor(
95
+ text=chat_text,
96
+ images=image,
97
+ return_tensors="pt",
98
+ )
99
+ model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
100
+
101
+ except Exception as e:
102
+ # If the template or image-token count fails, fallback to a simple text+image pack.
103
+ # This handles errors like:
104
+ # "Number of images does not match number of special image tokens..."
105
+ fallback_note = (
106
+ "\n[Note] Falling back to a simpler prompt packing due to template/image token mismatch."
107
+ )
108
+ try:
109
+ model_inputs = processor(
110
+ text=CLINICAL_PROMPT,
111
+ images=image,
112
+ return_tensors="pt",
113
+ )
114
+ model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
115
+ # Surface a short note at the start of the stream so user knows why
116
+ yield fallback_note + "\n"
117
+ except Exception as inner_e:
118
+ err_msg = f"Input preparation failed:\n{repr(e)}\n{repr(inner_e)}"
119
+ yield err_msg
120
+ return
121
+
122
+ # Prepare streamer
123
+ streamer = TextIteratorStreamer(
124
+ processor.tokenizer,
125
+ skip_prompt=True,
126
+ skip_special_tokens=True,
127
+ )
128
+
129
+ # Launch generation in a background thread
130
+ generated_text = []
131
+ def _generate():
132
+ try:
133
+ model.generate(
134
+ **model_inputs,
135
+ streamer=streamer,
136
+ **GEN_KW
137
+ )
138
+ except Exception as gen_e:
139
+ # Put traceback into the stream so the user sees it (useful during debugging)
140
+ tb = traceback.format_exc()
141
+ streamer.put("\n\n[Generation Error]\n" + str(gen_e) + "\n" + tb)
142
+ finally:
143
+ streamer.end()
144
+
145
+ thread = threading.Thread(target=_generate)
146
+ thread.start()
147
+
148
+ # Collect incremental tokens and yield buffer
149
+ buffer = ""
150
+ for token in streamer:
151
+ buffer += token
152
+ # Stream into Gradio textbox
153
+ yield buffer
154
+
155
+ def reset():
156
+ return None, ""
157
+
158
+ # -----------------------------
159
+ # Gradio UI
160
+ # -----------------------------
161
+ with gr.Blocks(css="""
162
+ .disclaimer {
163
+ padding: 12px 16px;
164
+ border: 1px solid #b91c1c;
165
+ background: #fef2f2;
166
+ color: #7f1d1d;
167
+ border-radius: 8px;
168
+ font-weight: 600;
169
+ }
170
+ .footer-note {
171
+ font-size: 12px;
172
+ color: #374151;
173
+ }
174
+ .gr-button { background-color: #1e3a8a; color: #ffffff; }
175
+ """) as demo:
176
+ gr.Markdown("## 🩺 ECG Interpretation Assistant — Gemma-ECG-Vision")
177
+ gr.HTML("""
178
+ <div class="disclaimer">
179
+ ⚠️ <strong>Important Medical Disclaimer:</strong> This tool is for <u>education and research</u> purposes only.
180
+ It is <u>not</u> a medical device and must not be used for diagnosis or treatment.
181
+ Always consult a licensed clinician for interpretation and clinical decisions.
182
+ </div>
183
+ """)
184
+
185
+ with gr.Row():
186
+ image_input = gr.Image(type="pil", label="Upload ECG Image", height=320)
187
+ output_box = gr.Textbox(
188
+ label="Generated ECG Report (Streaming)",
189
+ lines=24,
190
+ show_copy_button=True,
191
+ autoscroll=True,
192
+ )
193
+
194
+ with gr.Row():
195
+ with gr.Column():
196
+ submit_btn = gr.Button("Generate Report", variant="primary")
197
+ with gr.Column():
198
+ reset_btn = gr.Button("Reset")
199
+
200
+ # Wire actions: analyze_ecg_stream yields partial strings for streaming
201
+ submit_btn.click(
202
+ fn=analyze_ecg_stream,
203
+ inputs=image_input,
204
+ outputs=output_box,
205
+ queue=True,
206
+ api_name="analyze_ecg",
207
+ )
208
+ reset_btn.click(fn=reset, outputs=[image_input, output_box])
209
+
210
+ gr.Markdown(
211
+ """
212
+ <div class="footer-note">
213
+ Model: <code>{model_id}</code> | Device: <code>{device}</code><br>
214
+ Tip: Larger images can improve recognition of fine waveform details (P waves, ST segments).
215
+ Ensure lead labels are visible when possible.
216
+ </div>
217
+ """.format(model_id=MODEL_ID, device=DEVICE)
218
+ )
219
+
220
+ # Enable queuing for proper streaming under concurrency
221
+ demo.queue(concurrency_count=2, max_size=16)
222
+ # In hosted notebooks, you can set share=True if needed
223
+ demo.launch(share=False, debug=True)