AliArshad commited on
Commit
d6d7239
·
verified ·
1 Parent(s): 0c2c6fc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -0
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import torch
5
+ from typing import Tuple, Optional, Dict, Any
6
+ from dataclasses import dataclass
7
+ import random
8
+ from datetime import datetime, timedelta
9
+ import os
10
+ from qwen_agent.agents import Assistant
11
+
12
+ @dataclass
13
+ class PatientMetadata:
14
+ age: int
15
+ smoking_status: str
16
+ family_history: bool
17
+ menopause_status: str
18
+ previous_mammogram: bool
19
+ breast_density: str
20
+ hormone_therapy: bool
21
+
22
+ @dataclass
23
+ class AnalysisResult:
24
+ has_tumor: bool
25
+ tumor_size: str
26
+ confidence: float
27
+ metadata: PatientMetadata
28
+
29
+ class BreastSinogramAnalyzer:
30
+ def __init__(self):
31
+ """Initialize the analyzer with required models."""
32
+ print("Initializing system...")
33
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ print(f"Using device: {self.device}")
35
+
36
+ self._init_vision_models()
37
+ self._init_llm()
38
+ print("Initialization complete!")
39
+
40
+ def _init_vision_models(self) -> None:
41
+ """Initialize vision models for abnormality detection and size measurement."""
42
+ print("Loading detection models...")
43
+ self.tumor_detector = AutoModelForImageClassification.from_pretrained(
44
+ "SIATCN/vit_tumor_classifier"
45
+ ).to(self.device).eval()
46
+ self.tumor_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")
47
+
48
+ self.size_detector = AutoModelForImageClassification.from_pretrained(
49
+ "SIATCN/vit_tumor_radius_detection_finetuned"
50
+ ).to(self.device).eval()
51
+ self.size_processor = AutoImageProcessor.from_pretrained(
52
+ "SIATCN/vit_tumor_radius_detection_finetuned"
53
+ )
54
+
55
+ def _init_llm(self) -> None:
56
+ """Initialize the Qwen model for report generation."""
57
+ print("Loading language model...")
58
+ self.agent = Assistant(
59
+ llm={
60
+ 'model': os.environ.get("MODELNAME"),
61
+ 'generate_cfg': {
62
+ 'max_input_tokens': 32768,
63
+ 'max_retries': 10,
64
+ 'temperature': float(os.environ.get("T", 0.001)),
65
+ 'repetition_penalty': float(os.environ.get("R", 1.0)),
66
+ "top_k": int(os.environ.get("K", 20)),
67
+ "top_p": float(os.environ.get("P", 0.8)),
68
+ }
69
+ },
70
+ name='QwQ-32B-preview',
71
+ description='Medical report generation model based on QwQ-32B-Preview',
72
+ system_message='You are an experienced radiologist providing clear and concise medical reports. You should think step-by-step and be precise in your analysis.',
73
+ rag_cfg={'max_ref_token': 32768, 'rag_searchers': []},
74
+ )
75
+
76
+ def _generate_synthetic_metadata(self) -> PatientMetadata:
77
+ """Generate realistic patient metadata for breast cancer screening."""
78
+ age = random.randint(40, 75)
79
+ smoking_status = random.choice(["Never Smoker", "Former Smoker", "Current Smoker"])
80
+ family_history = random.choice([True, False])
81
+ menopause_status = "Post-menopausal" if age > 50 else "Pre-menopausal"
82
+ previous_mammogram = random.choice([True, False])
83
+ breast_density = random.choice(["A: Almost entirely fatty",
84
+ "B: Scattered fibroglandular",
85
+ "C: Heterogeneously dense",
86
+ "D: Extremely dense"])
87
+ hormone_therapy = random.choice([True, False])
88
+
89
+ return PatientMetadata(
90
+ age=age,
91
+ smoking_status=smoking_status,
92
+ family_history=family_history,
93
+ menopause_status=menopause_status,
94
+ previous_mammogram=previous_mammogram,
95
+ breast_density=breast_density,
96
+ hormone_therapy=hormone_therapy
97
+ )
98
+
99
+ def _process_image(self, image: Image.Image) -> Image.Image:
100
+ """Process input image for model consumption."""
101
+ if image.mode != 'RGB':
102
+ image = image.convert('RGB')
103
+ return image.resize((224, 224))
104
+
105
+ @torch.no_grad()
106
+ def _analyze_image(self, image: Image.Image) -> AnalysisResult:
107
+ """Perform abnormality detection and size measurement."""
108
+ # Generate metadata
109
+ metadata = self._generate_synthetic_metadata()
110
+
111
+ # Detect abnormality
112
+ tumor_inputs = self.tumor_processor(image, return_tensors="pt").to(self.device)
113
+ tumor_outputs = self.tumor_detector(**tumor_inputs)
114
+ tumor_probs = tumor_outputs.logits.softmax(dim=-1)[0].cpu()
115
+ has_tumor = tumor_probs[1] > tumor_probs[0]
116
+ confidence = float(tumor_probs[1] if has_tumor else tumor_probs[0])
117
+
118
+ # Measure size
119
+ size_inputs = self.size_processor(image, return_tensors="pt").to(self.device)
120
+ size_outputs = self.size_detector(**size_inputs)
121
+ size_pred = size_outputs.logits.softmax(dim=-1)[0].cpu()
122
+ sizes = ["no-tumor", "0.5", "1.0", "1.5"]
123
+ tumor_size = sizes[size_pred.argmax().item()]
124
+
125
+ return AnalysisResult(has_tumor, tumor_size, confidence, metadata)
126
+
127
+ def _generate_medical_report(self, analysis: AnalysisResult) -> str:
128
+ """Generate a medical report using Qwen model."""
129
+ prompt = f"""Generate a brief medical report for this microwave breast imaging scan:
130
+
131
+ Findings:
132
+ - {'Abnormal' if analysis.has_tumor else 'Normal'} dielectric properties
133
+ - Size: {analysis.tumor_size} cm
134
+ - Confidence: {analysis.confidence:.2%}
135
+ - Patient age: {analysis.metadata.age}
136
+ - Risk factors: {', '.join([
137
+ 'family history' if analysis.metadata.family_history else '',
138
+ analysis.metadata.smoking_status.lower(),
139
+ 'hormone therapy' if analysis.metadata.hormone_therapy else ''
140
+ ]).strip(', ')}
141
+
142
+ Provide:
143
+ 1. One sentence interpreting the findings
144
+ 2. One clear management recommendation"""
145
+
146
+ try:
147
+ response = self.agent.chat(prompt)
148
+ if len(response.split()) >= 10:
149
+ return f"""INTERPRETATION AND RECOMMENDATION:
150
+ {response}"""
151
+
152
+ print("Report too short, using fallback")
153
+ return self._generate_fallback_report(analysis)
154
+
155
+ except Exception as e:
156
+ print(f"Error in report generation: {str(e)}")
157
+ return self._generate_fallback_report(analysis)
158
+
159
+ def _generate_fallback_report(self, analysis: AnalysisResult) -> str:
160
+ """Generate a simple fallback report."""
161
+ if analysis.has_tumor:
162
+ return f"""INTERPRETATION AND RECOMMENDATION:
163
+ Microwave imaging reveals abnormal dielectric properties measuring {analysis.tumor_size} cm with {analysis.confidence:.1%} confidence level.
164
+
165
+ {'Immediate conventional imaging and clinical correlation recommended.' if analysis.tumor_size in ['1.0', '1.5'] else 'Follow-up imaging recommended in 6 months.'}"""
166
+ else:
167
+ return f"""INTERPRETATION AND RECOMMENDATION:
168
+ Microwave imaging shows normal dielectric properties with {analysis.confidence:.1%} confidence level.
169
+
170
+ Routine screening recommended per standard protocol."""
171
+
172
+ def analyze(self, image: Image.Image) -> str:
173
+ """Main analysis pipeline."""
174
+ try:
175
+ processed_image = self._process_image(image)
176
+ analysis = self._analyze_image(processed_image)
177
+ report = self._generate_medical_report(analysis)
178
+
179
+ return f"""MICROWAVE IMAGING ANALYSIS:
180
+ • Detection: {'Positive' if analysis.has_tumor else 'Negative'}
181
+ • Size: {analysis.tumor_size} cm
182
+
183
+
184
+ PATIENT INFO:
185
+ • Age: {analysis.metadata.age} years
186
+ • Risk Factors: {', '.join([
187
+ 'family history' if analysis.metadata.family_history else '',
188
+ analysis.metadata.smoking_status.lower(),
189
+ 'hormone therapy' if analysis.metadata.hormone_therapy else '',
190
+ ]).strip(', ')}
191
+
192
+ REPORT:
193
+ {report}"""
194
+ except Exception as e:
195
+ return f"Error during analysis: {str(e)}"
196
+
197
+ def create_interface() -> gr.Interface:
198
+ """Create the Gradio interface."""
199
+ analyzer = BreastSinogramAnalyzer()
200
+
201
+ interface = gr.Interface(
202
+ fn=analyzer.analyze,
203
+ inputs=[
204
+ gr.Image(type="pil", label="Upload Breast Microwave Image")
205
+ ],
206
+ outputs=[
207
+ gr.Textbox(label="Analysis Results", lines=20)
208
+ ],
209
+ title="Breast Cancer Microwave Imaging Analysis System",
210
+ description="Upload a breast microwave image for comprehensive analysis and medical assessment.",
211
+ )
212
+
213
+ return interface
214
+
215
+ if __name__ == "__main__":
216
+ print("Starting application...")
217
+ interface = create_interface()
218
+ interface.launch(debug=True, share=True)