walaa2022 committed
Commit 123160d · verified · 1 parent: 0a567ba

Upload qwen_vl_medical_ui.py

Files changed (1):
  1. qwen_vl_medical_ui.py  +472 -0
qwen_vl_medical_ui.py ADDED
@@ -0,0 +1,472 @@
+ #!/usr/bin/env python3
+ """
+ Qwen-VL Medical Image Analysis - Gradio Interface for Hugging Face Spaces
+ Optimized for Hugging Face deployment with efficient resource usage
+ """
+
+ import gradio as gr
+ import torch
+ # Qwen2.5-VL checkpoints require the Qwen2_5_VL* classes (transformers >= 4.49);
+ # Qwen2VLForConditionalGeneration only loads the older Qwen2-VL checkpoints.
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ from PIL import Image
+ import json
+ import time
+ import os
+ from typing import Dict, List, Optional, Tuple
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ # Global variables for model caching
+ MODEL = None
+ PROCESSOR = None
+ DEVICE = None
+
+ def get_device():
+     """Determine the best available device"""
+     if torch.cuda.is_available():
+         return "cuda"
+     else:
+         return "cpu"
+
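+ # (A possible extension, not exercised here: torch.backends.mps.is_available()
+ # could also be checked to prefer Apple-Silicon "mps" when running locally.)
+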
+ def load_model_cached(model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"):
+     """Load and cache model - optimized for HF Spaces"""
+     global MODEL, PROCESSOR, DEVICE
+
+     if MODEL is None:
+         print(f"Loading {model_name}...")
+         DEVICE = get_device()
+         print(f"Using device: {DEVICE}")
+
+         try:
+             # Load with memory optimization for HF Spaces
+             MODEL = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                 model_name,
+                 torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
+                 device_map="auto" if DEVICE == "cuda" else None,
+                 trust_remote_code=True,
+                 attn_implementation="eager",  # More stable for HF Spaces
+                 low_cpu_mem_usage=True
+             )
+
+             PROCESSOR = AutoProcessor.from_pretrained(
+                 model_name,
+                 trust_remote_code=True
+             )
+
+             if DEVICE == "cpu":
+                 MODEL = MODEL.to(DEVICE)
+
+             print("Model loaded successfully!")
+             return True
+
+         except Exception as e:
+             print(f"Error loading model: {str(e)}")
+             return False
+
+     return True
+
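+ # A lower-memory alternative (a sketch, assuming bitsandbytes is installed;
+ # untested here): 4-bit quantization shrinks the 3B model's footprint on
+ # small GPU Spaces.
+ #
+ #   from transformers import BitsAndBytesConfig
+ #   MODEL = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ #       model_name,
+ #       quantization_config=BitsAndBytesConfig(
+ #           load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
+ #       device_map="auto",
+ #   )
+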
+ def create_medical_prompt(clinical_data: Dict, analysis_type: str, focus_areas: str) -> str:
+     """Create optimized medical analysis prompt"""
+
+     base_prompt = """You are an expert medical AI assistant specializing in medical image analysis with training across radiology, pathology, dermatology, and clinical medicine.
+
+ **ANALYSIS INSTRUCTIONS:**
+ Provide a systematic, professional medical image analysis following these guidelines:
+ - Use clear medical terminology with explanations
+ - Structure your response with clear sections
+ - Be thorough but concise
+ - Always mention limitations and need for professional consultation
+ """
+
+     # Add clinical context if provided
+     clinical_context = ""
+     if clinical_data and any(v.strip() for v in clinical_data.values() if v):
+         clinical_context = "\n**CLINICAL CONTEXT:**\n"
+         context_items = []
+         if clinical_data.get("age"): context_items.append(f"Age: {clinical_data['age']}")
+         if clinical_data.get("gender"): context_items.append(f"Gender: {clinical_data['gender']}")
+         if clinical_data.get("symptoms"): context_items.append(f"Symptoms: {clinical_data['symptoms']}")
+         if clinical_data.get("history"): context_items.append(f"History: {clinical_data['history']}")
+         if clinical_data.get("medications"): context_items.append(f"Medications: {clinical_data['medications']}")
+
+         clinical_context += "\n".join(f"• {item}" for item in context_items) + "\n"
+
+     # Analysis type specific instructions
+     analysis_instructions = {
+         "Comprehensive": """
+ **PROVIDE COMPREHENSIVE ANALYSIS:**
+ 1. **Image Description**: Type, quality, anatomical structures visible
+ 2. **Clinical Findings**: Normal and abnormal observations
+ 3. **Interpretation**: Clinical significance and differential diagnosis
+ 4. **Recommendations**: Next steps and follow-up suggestions
+ 5. **Limitations**: What cannot be determined from this image alone
+ """,
+         "Quick Assessment": """
+ **PROVIDE QUICK ASSESSMENT:**
+ 1. **Key Findings**: Most important observations
+ 2. **Clinical Impression**: Primary diagnostic considerations
+ 3. **Urgent Concerns**: Any findings requiring immediate attention
+ 4. **Next Steps**: Essential recommendations
+ """,
+         "Educational": """
+ **PROVIDE EDUCATIONAL ANALYSIS:**
+ 1. **Learning Points**: Key educational aspects of this case
+ 2. **Normal vs Abnormal**: Clear explanation of findings
+ 3. **Clinical Correlation**: How image relates to symptoms/history
+ 4. **Teaching Insights**: Important concepts demonstrated
+ """
+     }
+
+     focus_instruction = ""
+     if focus_areas and focus_areas.strip():
+         focus_instruction = f"\n**SPECIAL FOCUS**: Pay particular attention to: {focus_areas}\n"
+
+     disclaimer = """
+ **MEDICAL DISCLAIMER**: This AI analysis is for educational purposes only. Always consult qualified healthcare professionals for medical diagnosis and treatment decisions.
+ """
+
+     return base_prompt + clinical_context + analysis_instructions.get(analysis_type, analysis_instructions["Comprehensive"]) + focus_instruction + disclaimer
+
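+ # A quick check of the prompt builder (a sketch; the clinical values below
+ # are made-up placeholders):
+ #
+ #   print(create_medical_prompt(
+ #       {"age": "45", "gender": "Male", "symptoms": "chest pain",
+ #        "history": "", "medications": ""},
+ #       analysis_type="Quick Assessment", focus_areas="cardiac"))
+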
+ def analyze_medical_image(
+     image: Image.Image,
+     age: str,
+     gender: str,
+     symptoms: str,
+     history: str,
+     medications: str,
+     analysis_type: str,
+     focus_areas: str,
+     progress=gr.Progress()
+ ) -> Tuple[str, str]:
+     """Main analysis function for Gradio interface"""
+
+     if image is None:
+         return "❌ Please upload an image first.", ""
+
+     # Load model if needed
+     progress(0.1, desc="Loading model...")
+     if not load_model_cached():
+         return "❌ Failed to load model. Please try again.", ""
+
+     try:
+         progress(0.3, desc="Preparing analysis...")
+
+         # Prepare clinical data
+         clinical_data = {
+             "age": age.strip(),
+             "gender": gender,
+             "symptoms": symptoms.strip(),
+             "history": history.strip(),
+             "medications": medications.strip()
+         }
+
+         # Create prompt
+         prompt = create_medical_prompt(clinical_data, analysis_type, focus_areas)
+
+         progress(0.5, desc="Processing image...")
+
+         # Prepare messages for model
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": image},
+                     {"type": "text", "text": prompt}
+                 ]
+             }
+         ]
+
+         # Process inputs
+         text = PROCESSOR.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         image_inputs, video_inputs = process_vision_info(messages)
+
+         inputs = PROCESSOR(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+
+         inputs = inputs.to(DEVICE)
+
+         progress(0.7, desc="Generating analysis...")
+
+         # Generate response with optimized parameters for HF Spaces
+         with torch.no_grad():
+             generated_ids = MODEL.generate(
+                 **inputs,
+                 max_new_tokens=1024,  # Reduced for faster processing
+                 do_sample=True,
+                 temperature=0.3,
+                 top_p=0.8,
+                 repetition_penalty=1.1,
+                 pad_token_id=PROCESSOR.tokenizer.eos_token_id,
+                 eos_token_id=PROCESSOR.tokenizer.eos_token_id,
+             )
+
+         generated_ids_trimmed = [
+             out_ids[len(in_ids):]
+             for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+         ]
+
+         response = PROCESSOR.batch_decode(
+             generated_ids_trimmed,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=False
+         )[0]
+
+         progress(1.0, desc="Analysis complete!")
+
+         # Create download content (time.gmtime() so the label "UTC" is accurate)
+         report_data = {
+             "timestamp": time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime()),
+             "model": "Qwen2.5-VL-3B-Instruct",
+             "analysis_type": analysis_type,
+             "clinical_data": clinical_data,
+             "focus_areas": focus_areas,
+             "analysis": response
+         }
+
+         download_content = json.dumps(report_data, indent=2)
+
+         return response, download_content
+
+     except Exception as e:
+         return f"❌ Analysis failed: {str(e)}\n\nPlease try again or contact support.", ""
+
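+ # Standalone smoke test (a sketch; it bypasses Gradio's progress tracker with
+ # a no-op callable and feeds a blank image, so the output is not meaningful):
+ #
+ #   class _NoProgress:
+ #       def __call__(self, *args, **kwargs): pass
+ #   text, report = analyze_medical_image(
+ #       Image.new("RGB", (512, 512)), "45", "Male", "", "", "",
+ #       "Quick Assessment", "", progress=_NoProgress())
+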
+ def create_interface():
+     """Create the Gradio interface"""
+
+     # Custom CSS for medical theme
+     css = """
+     .gradio-container {
+         max-width: 1200px !important;
+     }
+     .medical-header {
+         text-align: center;
+         color: #2c5aa0;
+         margin-bottom: 20px;
+     }
+     .clinical-section {
+         background-color: #f8f9fa;
+         padding: 15px;
+         border-radius: 8px;
+         margin: 10px 0;
+     }
+     """
+
+     with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Medical AI Analyzer") as interface:
+
+         # Header
+         gr.HTML("""
+         <div class="medical-header">
+             <h1>🏥 Medical Image AI Analyzer</h1>
+             <h3>Advanced Medical Image Analysis using Qwen-VL</h3>
+             <p><em>Upload medical images and provide clinical context for AI-powered analysis</em></p>
+         </div>
+         """)
+
+         with gr.Row():
+             # Left column - Inputs
+             with gr.Column(scale=1):
+                 gr.Markdown("## 📤 Upload Medical Image")
+
+                 image_input = gr.Image(
+                     type="pil",
+                     label="Medical Image",
+                     height=300,
+                     sources=["upload", "clipboard"]
+                 )
+
+                 gr.Markdown("*Supported: X-rays, CT, MRI, photographs, microscopy, etc.*")
+
+                 gr.Markdown("## 📋 Clinical Information")
+
+                 with gr.Group():
+                     with gr.Row():
+                         age_input = gr.Textbox(
+                             label="Patient Age",
+                             placeholder="e.g., 45 years",
+                             max_lines=1
+                         )
+                         gender_input = gr.Dropdown(
+                             choices=["", "Male", "Female", "Other"],
+                             label="Gender",
+                             value=""
+                         )
+
+                     symptoms_input = gr.Textbox(
+                         label="Chief Complaint / Symptoms",
+                         placeholder="e.g., Chest pain, shortness of breath for 3 days",
+                         lines=2
+                     )
+
+                     history_input = gr.Textbox(
+                         label="Medical History",
+                         placeholder="e.g., Hypertension, diabetes, previous surgeries",
+                         lines=2
+                     )
+
+                     medications_input = gr.Textbox(
+                         label="Current Medications",
+                         placeholder="e.g., Metformin, Lisinopril, Aspirin",
+                         lines=2
+                     )
+
+                 gr.Markdown("## ⚙️ Analysis Settings")
+
+                 analysis_type = gr.Radio(
+                     choices=["Comprehensive", "Quick Assessment", "Educational"],
+                     label="Analysis Type",
+                     value="Comprehensive",
+                     info="Choose the depth and focus of analysis"
+                 )
+
+                 focus_areas = gr.Textbox(
+                     label="Focus Areas (Optional)",
+                     placeholder="e.g., cardiac, pulmonary, neurological",
+                     info="Specific areas to emphasize in analysis"
+                 )
+
+                 analyze_btn = gr.Button(
+                     "🔬 Analyze Medical Image",
+                     variant="primary",
+                     size="lg"
+                 )
+
+             # Right column - Results
+             with gr.Column(scale=1):
+                 gr.Markdown("## 🤖 AI Analysis Results")
+
+                 analysis_output = gr.Textbox(
+                     label="Medical Analysis",
+                     lines=20,
+                     max_lines=30,
+                     show_copy_button=True,
+                     placeholder="Analysis results will appear here after processing..."
+                 )
+
+                 download_file = gr.File(
+                     label="📥 Download Analysis Report",
+                     visible=False
+                 )
+
+                 # Hidden component to store download content
+                 download_content = gr.Textbox(visible=False)
+
+         # Example section
+         with gr.Accordion("💡 Example Use Cases & Tips", open=False):
+             gr.Markdown("""
+             ### 🔍 **Supported Medical Images:**
+             - **Radiology**: X-rays, CT scans, MRI images, Ultrasound
+             - **Pathology**: Histological slides, Cytology specimens
+             - **Dermatology**: Skin lesions, Rashes, Clinical photos
+             - **Ophthalmology**: Fundus photos, OCT images
+             - **Clinical Photography**: Wound assessment, Physical findings
+
+             ### 📝 **Tips for Better Analysis:**
+             - **Provide clinical context**: Age, symptoms, and history improve accuracy
+             - **Use specific focus areas**: "cardiac silhouette, lung fields" vs just "chest"
+             - **Choose appropriate analysis type**: Comprehensive for complex cases, Quick for screening
+             - **High-quality images**: Clear, well-lit images produce better results
+
+             ### ⚠️ **Important Limitations:**
+             - This AI tool is for educational and research purposes only
+             - Always consult qualified healthcare professionals for medical decisions
+             - AI cannot replace clinical expertise and physical examination
+             - Results should be validated by medical professionals
+             """)
+
+         # Footer
+         gr.HTML("""
+         <div style="text-align: center; margin-top: 20px; padding: 15px; background-color: #fff3cd; border-radius: 8px;">
+             <strong>⚠️ Medical Disclaimer:</strong> This AI tool is for educational purposes only.
+             It should never replace professional medical diagnosis or treatment.
+             Always consult qualified healthcare providers for medical decisions.
+         </div>
+         """)
+
+         # Event handlers
+         def create_download_file(content):
+             if content:
+                 filename = f"medical_analysis_{int(time.time())}.json"
+                 with open(filename, "w") as f:
+                     f.write(content)
+                 return gr.File(value=filename, visible=True)
+             return gr.File(visible=False)
+
+         analyze_btn.click(
+             fn=analyze_medical_image,
+             inputs=[
+                 image_input, age_input, gender_input, symptoms_input,
+                 history_input, medications_input, analysis_type, focus_areas
+             ],
+             outputs=[analysis_output, download_content]
+         ).then(
+             fn=create_download_file,
+             inputs=[download_content],
+             outputs=[download_file]
+         )
+
+     return interface
+
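+ # On a busy Space, request queuing can be enabled before launch (a sketch;
+ # the queue size here is illustrative, not tuned):
+ #
+ #   interface.queue(max_size=16)
+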
+ if __name__ == "__main__":
+     # Pre-load model to reduce first-run latency
+     print("Initializing Medical AI Analyzer...")
+     load_model_cached()
+
+     # Create and launch interface
+     interface = create_interface()
+
+     # Launch with settings optimized for HF Spaces
+     interface.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True,
+         quiet=False
+     )
+
+ """
+ HUGGING FACE SPACE CONFIGURATION:
+
+ 1. Create requirements.txt:
+    torch
+    torchvision
+    transformers>=4.49.0
+    accelerate
+    pillow
+    gradio
+    qwen-vl-utils
+    spaces
+
+ 2. Create app.py (this file)
+
+ 3. Create README.md:
+    ---
+    title: Medical Image AI Analyzer
+    emoji: 🏥
+    colorFrom: blue
+    colorTo: green
+    sdk: gradio
+    sdk_version: 4.44.0
+    app_file: app.py
+    pinned: false
+    license: apache-2.0
+    ---
+
+ 4. Optional - Create .gitattributes:
+    *.safetensors filter=lfs diff=lfs merge=lfs -text
+    *.bin filter=lfs diff=lfs merge=lfs -text
+
+ DEPLOYMENT FEATURES:
+ ✅ Optimized for Hugging Face Spaces
+ ✅ Efficient memory usage
+ ✅ Progress indicators
+ ✅ Professional medical interface
+ ✅ Download analysis reports
+ ✅ Mobile-responsive design
+ ✅ Error handling and validation
+ ✅ Medical disclaimer compliance
+ """