AliArshad commited on
Commit
9b86949
·
verified ·
1 Parent(s): 1d0d380

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +371 -0
app.py CHANGED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoModelForCausalLM, AutoTokenizer
3
+ from PIL import Image
4
+ import torch
5
+ from typing import Tuple, Optional, Dict, Any
6
+ from dataclasses import dataclass
7
+ import random
8
+ import tempfile
9
+ import webbrowser
10
+ import os
11
+ from datetime import datetime
12
+
13
+ @dataclass
14
+ class PatientMetadata:
15
+ age: int
16
+ smoking_status: str
17
+ family_history: bool
18
+ menopause_status: str
19
+ previous_mammogram: bool
20
+ breast_density: str
21
+ hormone_therapy: bool
22
+
23
+ @dataclass
24
+ class AnalysisResult:
25
+ has_tumor: bool
26
+ tumor_size: str
27
+ metadata: PatientMetadata
28
+
29
+ class BreastSinogramAnalyzer:
30
+ def __init__(self):
31
+ """Initialize the analyzer with required models."""
32
+ print("Initializing system...")
33
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ print(f"Using device: {self.device}")
35
+
36
+ self._init_vision_models()
37
+ self._init_llm()
38
+ print("Initialization complete!")
39
+
40
+ def _init_vision_models(self) -> None:
41
+ """Initialize vision models for abnormality detection and size measurement."""
42
+ print("Loading detection models...")
43
+ self.tumor_detector = AutoModelForImageClassification.from_pretrained(
44
+ "SIATCN/vit_tumor_classifier"
45
+ ).to(self.device).eval()
46
+ self.tumor_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")
47
+
48
+ self.size_detector = AutoModelForImageClassification.from_pretrained(
49
+ "SIATCN/vit_tumor_radius_detection_finetuned"
50
+ ).to(self.device).eval()
51
+ self.size_processor = AutoImageProcessor.from_pretrained(
52
+ "SIATCN/vit_tumor_radius_detection_finetuned"
53
+ )
54
+
55
+ def _init_llm(self) -> None:
56
+ """Initialize the Qwen language model for report generation."""
57
+ print("Loading Qwen language model...")
58
+ self.model_name = "Qwen/QwQ-32B-Preview"
59
+ self.model = AutoModelForCausalLM.from_pretrained(
60
+ self.model_name,
61
+ torch_dtype="auto",
62
+ device_map="auto"
63
+ )
64
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
65
+
66
+ def _generate_synthetic_metadata(self) -> PatientMetadata:
67
+ """Generate realistic patient metadata for breast cancer screening."""
68
+ age = random.randint(40, 75)
69
+ smoking_status = random.choice(["Never Smoker", "Former Smoker", "Current Smoker"])
70
+ family_history = random.choice([True, False])
71
+ menopause_status = "Post-menopausal" if age > 50 else "Pre-menopausal"
72
+ previous_mammogram = random.choice([True, False])
73
+ breast_density = random.choice(["A: Almost entirely fatty",
74
+ "B: Scattered fibroglandular",
75
+ "C: Heterogeneously dense",
76
+ "D: Extremely dense"])
77
+ hormone_therapy = random.choice([True, False])
78
+
79
+ return PatientMetadata(
80
+ age=age,
81
+ smoking_status=smoking_status,
82
+ family_history=family_history,
83
+ menopause_status=menopause_status,
84
+ previous_mammogram=previous_mammogram,
85
+ breast_density=breast_density,
86
+ hormone_therapy=hormone_therapy
87
+ )
88
+
89
+ def _process_image(self, image: Image.Image) -> Image.Image:
90
+ """Process input image for model consumption."""
91
+ if image.mode != 'RGB':
92
+ image = image.convert('RGB')
93
+ return image.resize((224, 224))
94
+
95
+ @torch.no_grad()
96
+ def _analyze_image(self, image: Image.Image) -> AnalysisResult:
97
+ """Perform abnormality detection and size measurement."""
98
+ metadata = self._generate_synthetic_metadata()
99
+
100
+ # Detect abnormality
101
+ tumor_inputs = self.tumor_processor(image, return_tensors="pt").to(self.device)
102
+ tumor_outputs = self.tumor_detector(**tumor_inputs)
103
+ tumor_probs = tumor_outputs.logits.softmax(dim=-1)[0].cpu()
104
+ has_tumor = tumor_probs[1] > tumor_probs[0]
105
+
106
+ # Measure size if tumor detected
107
+ size_inputs = self.size_processor(image, return_tensors="pt").to(self.device)
108
+ size_outputs = self.size_detector(**size_inputs)
109
+ size_pred = size_outputs.logits.softmax(dim=-1)[0].cpu()
110
+ sizes = ["no-tumor", "0.5", "1.0", "1.5"]
111
+ tumor_size = sizes[size_pred.argmax().item()]
112
+
113
+ return AnalysisResult(has_tumor, tumor_size, metadata)
114
+
115
+ def _generate_medical_report(self, analysis: AnalysisResult) -> str:
116
+ """Generate a clear medical report using Qwen."""
117
+ try:
118
+ messages = [
119
+ {
120
+ "role": "system",
121
+ "content": "You are a radiologist providing clear and straightforward medical reports. Focus on clarity and actionable recommendations."
122
+ },
123
+ {
124
+ "role": "user",
125
+ "content": f"""Generate a clear medical report for this breast imaging scan:
126
+
127
+ Scan Results:
128
+ - Finding: {'Abnormal area detected' if analysis.has_tumor else 'No abnormalities detected'}
129
+ {f'- Size of abnormal area: {analysis.tumor_size} cm' if analysis.has_tumor else ''}
130
+
131
+ Patient Information:
132
+ - Age: {analysis.metadata.age} years
133
+ - Risk factors: {', '.join([
134
+ 'family history of breast cancer' if analysis.metadata.family_history else '',
135
+ f'{analysis.metadata.smoking_status.lower()}',
136
+ 'currently on hormone therapy' if analysis.metadata.hormone_therapy else ''
137
+ ]).strip(', ')}
138
+
139
+ Please provide:
140
+ 1. A clear interpretation of the findings
141
+ 2. A specific recommendation for next steps"""
142
+ }
143
+ ]
144
+
145
+ text = self.tokenizer.apply_chat_template(
146
+ messages,
147
+ tokenize=False,
148
+ add_generation_prompt=True
149
+ )
150
+
151
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
152
+
153
+ generated_ids = self.model.generate(
154
+ **model_inputs,
155
+ max_new_tokens=128,
156
+ temperature=0.3,
157
+ top_p=0.9,
158
+ repetition_penalty=1.1,
159
+ do_sample=True
160
+ )
161
+
162
+ generated_ids = [
163
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
164
+ ]
165
+
166
+ response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
167
+
168
+ if len(response.split()) >= 10:
169
+ return f"""FINDINGS AND RECOMMENDATIONS:
170
+ {response}"""
171
+
172
+ return self._generate_fallback_report(analysis)
173
+
174
+ except Exception as e:
175
+ print(f"Error in report generation: {str(e)}")
176
+ return self._generate_fallback_report(analysis)
177
+
178
+ def _generate_fallback_report(self, analysis: AnalysisResult) -> str:
179
+ """Generate a clear fallback report."""
180
+ if analysis.has_tumor:
181
+ return f"""FINDINGS AND RECOMMENDATIONS:
182
+
183
+ Finding: An abnormal area measuring {analysis.tumor_size} cm was detected during the scan.
184
+
185
+ Recommendation: {'An immediate follow-up with conventional mammogram and ultrasound is required.' if analysis.tumor_size in ['1.0', '1.5'] else 'A follow-up scan is recommended in 6 months.'}"""
186
+ else:
187
+ return """FINDINGS AND RECOMMENDATIONS:
188
+
189
+ Finding: No abnormal areas were detected during this scan.
190
+
191
+ Recommendation: Continue with routine screening as per standard guidelines."""
192
+
193
+ def _generate_print_preview(self, analysis_text: str, image: Image.Image) -> str:
194
+ """Generate an HTML print preview."""
195
+ temp_dir = tempfile.gettempdir()
196
+ temp_image_path = os.path.join(temp_dir, 'scan_image.png')
197
+ image.save(temp_image_path)
198
+
199
+ current_date = datetime.now().strftime("%B %d, %Y")
200
+
201
+ html_content = f"""
202
+ <!DOCTYPE html>
203
+ <html>
204
+ <head>
205
+ <title>Medical Imaging Report</title>
206
+ <style>
207
+ @media print {{
208
+ body {{
209
+ font-family: Arial, sans-serif;
210
+ line-height: 1.6;
211
+ padding: 20px;
212
+ max-width: 800px;
213
+ margin: 0 auto;
214
+ }}
215
+ .header {{
216
+ text-align: center;
217
+ margin-bottom: 30px;
218
+ border-bottom: 2px solid #000;
219
+ padding-bottom: 10px;
220
+ }}
221
+ .date {{
222
+ text-align: right;
223
+ margin-bottom: 20px;
224
+ }}
225
+ .content {{
226
+ margin-bottom: 30px;
227
+ }}
228
+ .scan-image {{
229
+ text-align: center;
230
+ margin: 20px 0;
231
+ }}
232
+ .scan-image img {{
233
+ max-width: 500px;
234
+ height: auto;
235
+ }}
236
+ .footer {{
237
+ margin-top: 50px;
238
+ border-top: 1px solid #000;
239
+ padding-top: 20px;
240
+ }}
241
+ @page {{
242
+ size: A4;
243
+ margin: 2cm;
244
+ }}
245
+ .no-print {{
246
+ display: none;
247
+ }}
248
+ }}
249
+ /* Screen-only styles */
250
+ body {{
251
+ font-family: Arial, sans-serif;
252
+ line-height: 1.6;
253
+ padding: 20px;
254
+ max-width: 800px;
255
+ margin: 0 auto;
256
+ }}
257
+ .print-button {{
258
+ background-color: #007bff;
259
+ color: white;
260
+ padding: 10px 20px;
261
+ border: none;
262
+ border-radius: 5px;
263
+ cursor: pointer;
264
+ margin-bottom: 20px;
265
+ }}
266
+ .print-button:hover {{
267
+ background-color: #0056b3;
268
+ }}
269
+ </style>
270
+ </head>
271
+ <body>
272
+ <button onclick="window.print()" class="print-button no-print">Print Report</button>
273
+
274
+ <div class="header">
275
+ <h1>Medical Imaging Report</h1>
276
+ </div>
277
+
278
+ <div class="date">
279
+ Report Date: {current_date}
280
+ </div>
281
+
282
+ <div class="scan-image">
283
+ <img src="file://{temp_image_path}" alt="Scan Image">
284
+ </div>
285
+
286
+ <div class="content">
287
+ <pre style="white-space: pre-wrap; font-family: Arial, sans-serif;">{analysis_text}</pre>
288
+ </div>
289
+
290
+ <div class="footer">
291
+ <p>This report is generated by an automated analysis system and should be reviewed by a qualified healthcare professional.</p>
292
+ </div>
293
+ </body>
294
+ </html>
295
+ """
296
+
297
+ temp_html_path = os.path.join(temp_dir, 'report.html')
298
+ with open(temp_html_path, 'w', encoding='utf-8') as f:
299
+ f.write(html_content)
300
+
301
+ return temp_html_path
302
+
303
+ def analyze(self, image: Image.Image) -> Tuple[str, str]:
304
+ """Main analysis pipeline."""
305
+ try:
306
+ processed_image = self._process_image(image)
307
+ analysis = self._analyze_image(processed_image)
308
+ report = self._generate_medical_report(analysis)
309
+
310
+ analysis_text = f"""SCAN RESULTS:
311
+ {'⚠️ Abnormal area detected' if analysis.has_tumor else '✓ No abnormalities detected'}
312
+ {f'Size of abnormal area: {analysis.tumor_size} cm' if analysis.has_tumor else ''}
313
+
314
+ PATIENT INFORMATION:
315
+ • Age: {analysis.metadata.age} years
316
+ • Risk Factors: {', '.join([
317
+ 'family history of breast cancer' if analysis.metadata.family_history else '',
318
+ analysis.metadata.smoking_status.lower(),
319
+ 'currently on hormone therapy' if analysis.metadata.hormone_therapy else '',
320
+ ]).strip(', ')}
321
+
322
+ {report}"""
323
+
324
+ preview_path = self._generate_print_preview(analysis_text, image)
325
+
326
+ return analysis_text, preview_path
327
+ except Exception as e:
328
+ return f"Error during analysis: {str(e)}", ""
329
+
330
+ def open_print_preview(preview_path: str) -> None:
331
+ """Open the print preview in the default browser."""
332
+ if preview_path:
333
+ webbrowser.open(f'file://{preview_path}')
334
+ return None
335
+
336
+ def create_interface() -> gr.Blocks:
337
+ """Create the Gradio interface."""
338
+ analyzer = BreastSinogramAnalyzer()
339
+
340
+ with gr.Blocks() as interface:
341
+ gr.Markdown("# Breast Imaging Analysis System")
342
+ gr.Markdown("Upload a breast image for analysis and medical assessment.")
343
+
344
+ with gr.Row():
345
+ input_image = gr.Image(type="pil", label="Upload Breast Image for Analysis")
346
+
347
+ with gr.Row():
348
+ analyze_btn = gr.Button("Analyze Image", variant="primary")
349
+ print_btn = gr.Button("Open Print Preview")
350
+
351
+ output_text = gr.Textbox(label="Analysis Results", lines=20)
352
+ preview_path = gr.Textbox(visible=False)
353
+
354
+ analyze_btn.click(
355
+ fn=analyzer.analyze,
356
+ inputs=[input_image],
357
+ outputs=[output_text, preview_path]
358
+ )
359
+
360
+ print_btn.click(
361
+ fn=open_print_preview,
362
+ inputs=[preview_path],
363
+ outputs=None
364
+ )
365
+
366
+ return interface
367
+
368
+ if __name__ == "__main__":
369
+ print("Starting application...")
370
+ interface = create_interface()
371
+ interface.launch(debug=True, share=True)