yasserrmd commited on
Commit
ed05827
·
verified ·
1 Parent(s): 5f96e36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -113
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import os
 
 
2
  import gradio as gr
3
  import torch
4
  from transformers import AutoProcessor, Gemma3nForConditionalGeneration, TextIteratorStreamer
5
  from PIL import Image
6
- import threading
7
  import traceback
8
  import spaces
9
 
@@ -63,141 +65,136 @@ model = Gemma3nForConditionalGeneration.from_pretrained(
63
  ).to(DEVICE).eval()
64
  processor = AutoProcessor.from_pretrained(MODEL_ID)
65
 
 
 
66
  # -----------------------------
67
- # Inference (streaming) function
68
  # -----------------------------
69
  @spaces.GPU
70
  def analyze_ecg_stream(image: Image.Image):
71
- """
72
- Streams model output into the Gradio textbox.
73
- Yields incremental text chunks.
74
- """
75
  if image is None:
76
  yield "Please upload an ECG image."
77
  return
78
 
79
- # Build a multimodal chat-style message; rely on the model's chat template to inject image tokens.
80
  messages = [
81
- {
82
- "role": "user",
83
- "content": [
84
- {"type": "text", "text": CLINICAL_PROMPT},
85
- {"type": "image"},
86
- ],
87
- }
88
  ]
89
 
 
90
  try:
91
- # Try with chat template first (recommended for chat-tuned models)
92
  chat_text = processor.apply_chat_template(messages, add_generation_prompt=True)
 
 
 
 
 
93
 
94
- model_inputs = processor(
95
- text=chat_text,
96
- images=image,
97
- return_tensors="pt",
98
- )
99
- model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
100
-
101
- except Exception as e:
102
- # If the template or image-token count fails, fallback to a simple text+image pack.
103
- # This handles errors like:
104
- # "Number of images does not match number of special image tokens..."
105
- fallback_note = (
106
- "\n[Note] Falling back to a simpler prompt packing due to template/image token mismatch."
107
- )
108
- try:
109
- model_inputs = processor(
110
- text=CLINICAL_PROMPT,
111
- images=image,
112
- return_tensors="pt",
113
- )
114
- model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
115
- # Surface a short note at the start of the stream so user knows why
116
- yield fallback_note + "\n"
117
- except Exception as inner_e:
118
- err_msg = f"Input preparation failed:\n{repr(e)}\n{repr(inner_e)}"
119
- yield err_msg
120
- return
121
-
122
- # Prepare streamer
123
  streamer = TextIteratorStreamer(
124
- processor.tokenizer,
125
- skip_prompt=True,
126
- skip_special_tokens=True,
127
  )
128
 
129
- # Launch generation in a background thread
130
- generated_text = []
131
  def _generate():
132
  try:
133
  model.generate(
134
  **model_inputs,
135
  streamer=streamer,
136
- **GEN_KW
137
  )
138
- except Exception as gen_e:
139
- # Put traceback into the stream so the user sees it (useful during debugging)
140
- tb = traceback.format_exc()
141
- streamer.put("\n\n[Generation Error]\n" + str(gen_e) + "\n" + tb)
142
  finally:
143
  streamer.end()
144
 
145
- thread = threading.Thread(target=_generate)
146
- thread.start()
147
 
148
- # Collect incremental tokens and yield buffer
149
- buffer = ""
150
- for token in streamer:
151
- buffer += token
152
- # Stream into Gradio textbox
153
- yield buffer
154
 
155
  def reset():
156
  return None, ""
157
 
158
  # -----------------------------
159
- # Gradio UI
160
  # -----------------------------
161
- with gr.Blocks(css="""
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  .disclaimer {
163
- padding: 12px 16px;
164
- border: 1px solid #b91c1c;
165
- background: #fef2f2;
166
- color: #7f1d1d;
167
- border-radius: 8px;
168
- font-weight: 600;
169
  }
170
- .footer-note {
171
- font-size: 12px;
172
- color: #374151;
173
  }
174
- .gr-button { background-color: #1e3a8a; color: #ffffff; }
175
- """) as demo:
176
- gr.Markdown("## 🩺 ECG Interpretation Assistant — Gemma-ECG-Vision")
177
- gr.HTML("""
178
- <div class="disclaimer">
179
- ⚠️ <strong>Important Medical Disclaimer:</strong> This tool is for <u>education and research</u> purposes only.
180
- It is <u>not</u> a medical device and must not be used for diagnosis or treatment.
181
- Always consult a licensed clinician for interpretation and clinical decisions.
182
- </div>
183
- """)
184
-
185
- with gr.Row():
186
- image_input = gr.Image(type="pil", label="Upload ECG Image", height=320)
187
- output_box = gr.Textbox(
188
- label="Generated ECG Report (Streaming)",
189
- lines=24,
190
- show_copy_button=True,
191
- autoscroll=True,
192
- )
193
 
 
194
  with gr.Row():
195
- with gr.Column():
196
- submit_btn = gr.Button("Generate Report", variant="primary")
197
- with gr.Column():
198
- reset_btn = gr.Button("Reset")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- # Wire actions: analyze_ecg_stream yields partial strings for streaming
201
  submit_btn.click(
202
  fn=analyze_ecg_stream,
203
  inputs=image_input,
@@ -205,19 +202,20 @@ with gr.Blocks(css="""
205
  queue=True,
206
  api_name="analyze_ecg",
207
  )
208
- reset_btn.click(fn=reset, outputs=[image_input, output_box])
209
-
210
- gr.Markdown(
211
- """
212
- <div class="footer-note">
213
- Model: <code>{model_id}</code> | Device: <code>{device}</code><br>
214
- Tip: Larger images can improve recognition of fine waveform details (P waves, ST segments).
215
- Ensure lead labels are visible when possible.
216
- </div>
217
- """.format(model_id=MODEL_ID, device=DEVICE)
218
- )
219
-
220
- # Enable queuing for proper streaming under concurrency
221
- #demo.queue(concurrency_count=2, max_size=16)
222
- # In hosted notebooks, you can set share=True if needed
223
- demo.launch(share=False, debug=True)
 
 
1
  import os
2
+ import threading
3
+ import traceback
4
  import gradio as gr
5
  import torch
6
  from transformers import AutoProcessor, Gemma3nForConditionalGeneration, TextIteratorStreamer
7
  from PIL import Image
8
+ import inspect
9
  import traceback
10
  import spaces
11
 
 
65
  ).to(DEVICE).eval()
66
  processor = AutoProcessor.from_pretrained(MODEL_ID)
67
 
68
+
69
+
70
  # -----------------------------
71
+ # Streaming generator
72
  # -----------------------------
73
  @spaces.GPU
74
  def analyze_ecg_stream(image: Image.Image):
 
 
 
 
75
  if image is None:
76
  yield "Please upload an ECG image."
77
  return
78
 
 
79
  messages = [
80
+ {"role": "user", "content": [
81
+ {"type": "text", "text": CLINICAL_PROMPT},
82
+ {"type": "image"},
83
+ ]}
 
 
 
84
  ]
85
 
86
+ # Prepare inputs (try chat template; fallback to plain text+image)
87
  try:
 
88
  chat_text = processor.apply_chat_template(messages, add_generation_prompt=True)
89
+ model_inputs = processor(text=chat_text, images=image, return_tensors="pt")
90
+ except Exception:
91
+ # Fallback when the template/image token count mismatches
92
+ model_inputs = processor(text=CLINICAL_PROMPT, images=image, return_tensors="pt")
93
+ yield "[Note] Using fallback prompt packing.\n"
94
 
95
+ model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
96
+
97
+ # Streamer must use the tokenizer (not the processor)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  streamer = TextIteratorStreamer(
99
+ processor.tokenizer, skip_prompt=True, skip_special_tokens=True
 
 
100
  )
101
 
 
 
102
  def _generate():
103
  try:
104
  model.generate(
105
  **model_inputs,
106
  streamer=streamer,
107
+ **GEN_KW,
108
  )
109
+ except Exception as e:
110
+ streamer.put("\n\n[Generation Error]\n" + traceback.format_exc())
 
 
111
  finally:
112
  streamer.end()
113
 
114
+ t = threading.Thread(target=_generate, daemon=True)
115
+ t.start()
116
 
117
+ buf = ""
118
+ for piece in streamer:
119
+ buf += piece
120
+ yield buf
 
 
121
 
122
  def reset():
123
  return None, ""
124
 
125
  # -----------------------------
126
+ # UI
127
  # -----------------------------
128
+ theme = gr.themes.Soft(primary_hue="indigo", neutral_hue="slate")
129
+
130
+ custom_css = """
131
+ #app {
132
+ max-width: 1100px;
133
+ margin: 0 auto;
134
+ }
135
+ .header {
136
+ display:flex; align-items:center; justify-content:space-between;
137
+ padding: 16px 14px; border-radius: 14px;
138
+ background: linear-gradient(135deg, #1f2937 0%, #111827 100%);
139
+ color: #fff; box-shadow: 0 6px 20px rgba(0,0,0,0.25);
140
+ }
141
+ .brand { font-size: 18px; font-weight: 700; letter-spacing: 0.3px; }
142
  .disclaimer {
143
+ margin-top: 12px; padding: 12px 14px; border-radius: 12px;
144
+ background: #fef2f2; color:#7f1d1d; border:1px solid #fecaca; font-weight:600;
 
 
 
 
145
  }
146
+ .card {
147
+ background: #ffffff; border: 1px solid #e5e7eb; border-radius: 14px;
148
+ padding: 16px; box-shadow: 0 8px 18px rgba(17,24,39,0.06);
149
  }
150
+ footer {
151
+ font-size: 12px; color:#6b7280; margin-top: 8px;
152
+ }
153
+ .gr-button { background-color:#1e3a8a !important; color:#fff !important; }
154
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ with gr.Blocks(theme=theme, css=custom_css, elem_id="app") as demo:
157
  with gr.Row():
158
+ gr.HTML("""
159
+ <div class="header">
160
+ <div class="brand">🩺 ECG Interpretation Assistant</div>
161
+ <div>Gemma-ECG-Vision</div>
162
+ </div>
163
+ <div class="disclaimer">
164
+ ⚠️ <strong>Education & Research Only:</strong> This tool is not a medical device and must not be used for diagnosis or treatment.
165
+ Always consult a licensed clinician for interpretation and clinical decisions.
166
+ </div>
167
+ """)
168
+
169
+ with gr.Row(equal_height=True):
170
+ with gr.Column(scale=1):
171
+ with gr.Group(elem_classes="card"):
172
+ image_input = gr.Image(
173
+ type="pil", label="Upload ECG Image", height=360, show_label=True
174
+ )
175
+ with gr.Row():
176
+ submit_btn = gr.Button("Generate Report", variant="primary")
177
+ reset_btn = gr.Button("Reset")
178
+
179
+ with gr.Column(scale=2):
180
+ with gr.Group(elem_classes="card"):
181
+ output_box = gr.Textbox(
182
+ label="Generated ECG Report (Streaming)",
183
+ lines=26,
184
+ show_copy_button=True,
185
+ autoscroll=True,
186
+ placeholder="The model's report will appear here…",
187
+ )
188
+ gr.Markdown(
189
+ "Tip: Clear, high-resolution ECGs with visible lead labels improve P wave and ST-segment assessment."
190
+ )
191
+
192
+ gr.HTML(f"""
193
+ <footer>
194
+ Model: <code>{MODEL_ID}</code> | Device: <code>{DEVICE}</code>
195
+ </footer>
196
+ """)
197
 
 
198
  submit_btn.click(
199
  fn=analyze_ecg_stream,
200
  inputs=image_input,
 
202
  queue=True,
203
  api_name="analyze_ecg",
204
  )
205
+ reset_btn.click(reset, outputs=[image_input, output_box])
206
+
207
+ def queue_with_compat(demo, max_size=32, limit=4):
208
+ params = inspect.signature(gr.Blocks.queue).parameters
209
+ if "concurrency_count" in params:
210
+ # Older Gradio 3.x / early 4.x
211
+ return demo.queue(concurrency_count=limit, max_size=max_size)
212
+ elif "default_concurrency_limit" in params:
213
+ # Newer Gradio 4.x
214
+ return demo.queue(default_concurrency_limit=limit, max_size=max_size)
215
+ else:
216
+ # Fallback – no knobs exposed
217
+ return demo.queue()
218
+
219
+ # build your UI as before
220
+ queue_with_compat(demo, max_size=32, limit=4)
221
+ demo.launch(share=False, debug=True)