AliArshad commited on
Commit
08ce85d
·
verified ·
1 Parent(s): 686c3b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -71
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
3
  from PIL import Image
4
  import torch
5
  from typing import Tuple, Optional, Dict, Any
6
  from dataclasses import dataclass
7
  import random
8
- from datetime import datetime, timedelta
9
 
10
  @dataclass
11
  class PatientMetadata:
@@ -21,7 +20,6 @@ class PatientMetadata:
21
  class AnalysisResult:
22
  has_tumor: bool
23
  tumor_size: str
24
- confidence: float
25
  metadata: PatientMetadata
26
 
27
  class BreastSinogramAnalyzer:
@@ -51,18 +49,15 @@ class BreastSinogramAnalyzer:
51
  )
52
 
53
  def _init_llm(self) -> None:
54
- """Initialize the language model for report generation."""
55
- print("Loading language model pipeline...")
56
- self.pipe = pipeline(
57
- "text-generation",
58
- model="Qwen/QwQ-32B-Preview",
59
- torch_dtype=torch.float16,
60
- device_map="auto",
61
- model_kwargs={
62
- "load_in_4bit": False,
63
- "bnb_4bit_compute_dtype": torch.float16,
64
- }
65
  )
 
66
 
67
  def _generate_synthetic_metadata(self) -> PatientMetadata:
68
  """Generate realistic patient metadata for breast cancer screening."""
@@ -96,7 +91,6 @@ class BreastSinogramAnalyzer:
96
  @torch.no_grad()
97
  def _analyze_image(self, image: Image.Image) -> AnalysisResult:
98
  """Perform abnormality detection and size measurement."""
99
- # Generate metadata
100
  metadata = self._generate_synthetic_metadata()
101
 
102
  # Detect abnormality
@@ -104,61 +98,73 @@ class BreastSinogramAnalyzer:
104
  tumor_outputs = self.tumor_detector(**tumor_inputs)
105
  tumor_probs = tumor_outputs.logits.softmax(dim=-1)[0].cpu()
106
  has_tumor = tumor_probs[1] > tumor_probs[0]
107
- confidence = float(tumor_probs[1] if has_tumor else tumor_probs[0])
108
 
109
- # Measure size
110
  size_inputs = self.size_processor(image, return_tensors="pt").to(self.device)
111
  size_outputs = self.size_detector(**size_inputs)
112
  size_pred = size_outputs.logits.softmax(dim=-1)[0].cpu()
113
  sizes = ["no-tumor", "0.5", "1.0", "1.5"]
114
  tumor_size = sizes[size_pred.argmax().item()]
115
 
116
- return AnalysisResult(has_tumor, tumor_size, confidence, metadata)
117
 
118
  def _generate_medical_report(self, analysis: AnalysisResult) -> str:
119
- """Generate a simplified medical report."""
120
- prompt = f"""<|system|>You are a radiologist providing clear and concise medical reports.</s>
121
- <|user|>Generate a brief medical report for this microwave breast imaging scan:
122
-
123
- Findings:
124
- - {'Abnormal' if analysis.has_tumor else 'Normal'} dielectric properties
125
- - Size: {analysis.tumor_size} cm
126
- - Confidence: {analysis.confidence:.2%}
127
- - Patient age: {analysis.metadata.age}
 
 
 
 
 
 
 
 
128
  - Risk factors: {', '.join([
129
- 'family history' if analysis.metadata.family_history else '',
130
- analysis.metadata.smoking_status.lower(),
131
- 'hormone therapy' if analysis.metadata.hormone_therapy else ''
132
  ]).strip(', ')}
133
 
134
- Provide:
135
- 1. One sentence interpreting the findings
136
- 2. One clear management recommendation</s>
137
- <|assistant|>"""
138
-
139
- try:
140
- response = self.pipe(
141
- prompt,
 
 
 
 
 
 
 
 
142
  max_new_tokens=128,
143
  temperature=0.3,
144
  top_p=0.9,
145
  repetition_penalty=1.1,
146
- do_sample=True,
147
- num_return_sequences=1
148
- )[0]["generated_text"]
149
-
150
- # Extract assistant's response
151
- if "<|assistant|>" in response:
152
- report = response.split("<|assistant|>")[-1].strip()
153
- else:
154
- report = response[len(prompt):].strip()
155
-
156
- # Simple validation
157
- if len(report.split()) >= 10:
158
- return f"""INTERPRETATION AND RECOMMENDATION:
159
- {report}"""
160
 
161
- print("Report too short, using fallback")
162
  return self._generate_fallback_report(analysis)
163
 
164
  except Exception as e:
@@ -166,17 +172,19 @@ Provide:
166
  return self._generate_fallback_report(analysis)
167
 
168
  def _generate_fallback_report(self, analysis: AnalysisResult) -> str:
169
- """Generate a simple fallback report."""
170
  if analysis.has_tumor:
171
- return f"""INTERPRETATION AND RECOMMENDATION:
172
- Microwave imaging reveals abnormal dielectric properties measuring {analysis.tumor_size} cm with {analysis.confidence:.1%} confidence level.
173
 
174
- {'Immediate conventional imaging and clinical correlation recommended.' if analysis.tumor_size in ['1.0', '1.5'] else 'Follow-up imaging recommended in 6 months.'}"""
 
 
175
  else:
176
- return f"""INTERPRETATION AND RECOMMENDATION:
177
- Microwave imaging shows normal dielectric properties with {analysis.confidence:.1%} confidence level.
 
178
 
179
- Routine screening recommended per standard protocol."""
180
 
181
  def analyze(self, image: Image.Image) -> str:
182
  """Main analysis pipeline."""
@@ -185,20 +193,18 @@ Routine screening recommended per standard protocol."""
185
  analysis = self._analyze_image(processed_image)
186
  report = self._generate_medical_report(analysis)
187
 
188
- return f"""MICROWAVE IMAGING ANALYSIS:
189
- Detection: {'Positive' if analysis.has_tumor else 'Negative'}
190
- Size: {analysis.tumor_size} cm
191
-
192
 
193
- PATIENT INFO:
194
  • Age: {analysis.metadata.age} years
195
  • Risk Factors: {', '.join([
196
- 'family history' if analysis.metadata.family_history else '',
197
  analysis.metadata.smoking_status.lower(),
198
- 'hormone therapy' if analysis.metadata.hormone_therapy else '',
199
  ]).strip(', ')}
200
 
201
- REPORT:
202
  {report}"""
203
  except Exception as e:
204
  return f"Error during analysis: {str(e)}"
@@ -210,13 +216,13 @@ def create_interface() -> gr.Interface:
210
  interface = gr.Interface(
211
  fn=analyzer.analyze,
212
  inputs=[
213
- gr.Image(type="pil", label="Upload Breast Microwave Image")
214
  ],
215
  outputs=[
216
  gr.Textbox(label="Analysis Results", lines=20)
217
  ],
218
- title="Breast Cancer Microwave Imaging Analysis System",
219
- description="Upload a breast microwave image for comprehensive analysis and medical assessment.",
220
  )
221
 
222
  return interface
 
1
  import gradio as gr
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoModelForCausalLM, AutoTokenizer
3
  from PIL import Image
4
  import torch
5
  from typing import Tuple, Optional, Dict, Any
6
  from dataclasses import dataclass
7
  import random
 
8
 
9
  @dataclass
10
  class PatientMetadata:
 
20
  class AnalysisResult:
21
  has_tumor: bool
22
  tumor_size: str
 
23
  metadata: PatientMetadata
24
 
25
  class BreastSinogramAnalyzer:
 
49
  )
50
 
51
  def _init_llm(self) -> None:
52
+ """Initialize the Qwen language model for report generation."""
53
+ print("Loading Qwen language model...")
54
+ self.model_name = "Qwen/QwQ-32B-Preview"
55
+ self.model = AutoModelForCausalLM.from_pretrained(
56
+ self.model_name,
57
+ torch_dtype="auto",
58
+ device_map="auto"
 
 
 
 
59
  )
60
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
61
 
62
  def _generate_synthetic_metadata(self) -> PatientMetadata:
63
  """Generate realistic patient metadata for breast cancer screening."""
 
91
  @torch.no_grad()
92
  def _analyze_image(self, image: Image.Image) -> AnalysisResult:
93
  """Perform abnormality detection and size measurement."""
 
94
  metadata = self._generate_synthetic_metadata()
95
 
96
  # Detect abnormality
 
98
  tumor_outputs = self.tumor_detector(**tumor_inputs)
99
  tumor_probs = tumor_outputs.logits.softmax(dim=-1)[0].cpu()
100
  has_tumor = tumor_probs[1] > tumor_probs[0]
 
101
 
102
+ # Measure size if tumor detected
103
  size_inputs = self.size_processor(image, return_tensors="pt").to(self.device)
104
  size_outputs = self.size_detector(**size_inputs)
105
  size_pred = size_outputs.logits.softmax(dim=-1)[0].cpu()
106
  sizes = ["no-tumor", "0.5", "1.0", "1.5"]
107
  tumor_size = sizes[size_pred.argmax().item()]
108
 
109
+ return AnalysisResult(has_tumor, tumor_size, metadata)
110
 
111
  def _generate_medical_report(self, analysis: AnalysisResult) -> str:
112
+ """Generate a clear medical report using Qwen."""
113
+ try:
114
+ messages = [
115
+ {
116
+ "role": "system",
117
+ "content": "You are a radiologist providing clear and straightforward medical reports. Focus on clarity and actionable recommendations."
118
+ },
119
+ {
120
+ "role": "user",
121
+ "content": f"""Generate a clear medical report for this breast imaging scan:
122
+
123
+ Scan Results:
124
+ - Finding: {'Abnormal area detected' if analysis.has_tumor else 'No abnormalities detected'}
125
+ {f'- Size of abnormal area: {analysis.tumor_size} cm' if analysis.has_tumor else ''}
126
+
127
+ Patient Information:
128
+ - Age: {analysis.metadata.age} years
129
  - Risk factors: {', '.join([
130
+ 'family history of breast cancer' if analysis.metadata.family_history else '',
131
+ f'{analysis.metadata.smoking_status.lower()}',
132
+ 'currently on hormone therapy' if analysis.metadata.hormone_therapy else ''
133
  ]).strip(', ')}
134
 
135
+ Please provide:
136
+ 1. A clear interpretation of the findings
137
+ 2. A specific recommendation for next steps"""
138
+ }
139
+ ]
140
+
141
+ text = self.tokenizer.apply_chat_template(
142
+ messages,
143
+ tokenize=False,
144
+ add_generation_prompt=True
145
+ )
146
+
147
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
148
+
149
+ generated_ids = self.model.generate(
150
+ **model_inputs,
151
  max_new_tokens=128,
152
  temperature=0.3,
153
  top_p=0.9,
154
  repetition_penalty=1.1,
155
+ do_sample=True
156
+ )
157
+
158
+ generated_ids = [
159
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
160
+ ]
161
+
162
+ response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
163
+
164
+ if len(response.split()) >= 10:
165
+ return f"""FINDINGS AND RECOMMENDATIONS:
166
+ {response}"""
 
 
167
 
 
168
  return self._generate_fallback_report(analysis)
169
 
170
  except Exception as e:
 
172
  return self._generate_fallback_report(analysis)
173
 
174
  def _generate_fallback_report(self, analysis: AnalysisResult) -> str:
175
+ """Generate a clear fallback report."""
176
  if analysis.has_tumor:
177
+ return f"""FINDINGS AND RECOMMENDATIONS:
 
178
 
179
+ Finding: An abnormal area measuring {analysis.tumor_size} cm was detected during the scan.
180
+
181
+ Recommendation: {'An immediate follow-up with conventional mammogram and ultrasound is required.' if analysis.tumor_size in ['1.0', '1.5'] else 'A follow-up scan is recommended in 6 months.'}"""
182
  else:
183
+ return """FINDINGS AND RECOMMENDATIONS:
184
+
185
+ Finding: No abnormal areas were detected during this scan.
186
 
187
+ Recommendation: Continue with routine screening as per standard guidelines."""
188
 
189
  def analyze(self, image: Image.Image) -> str:
190
  """Main analysis pipeline."""
 
193
  analysis = self._analyze_image(processed_image)
194
  report = self._generate_medical_report(analysis)
195
 
196
+ return f"""SCAN RESULTS:
197
+ {'⚠️ Abnormal area detected' if analysis.has_tumor else '✓ No abnormalities detected'}
198
+ {f'Size of abnormal area: {analysis.tumor_size} cm' if analysis.has_tumor else ''}
 
199
 
200
+ PATIENT INFORMATION:
201
  • Age: {analysis.metadata.age} years
202
  • Risk Factors: {', '.join([
203
+ 'family history of breast cancer' if analysis.metadata.family_history else '',
204
  analysis.metadata.smoking_status.lower(),
205
+ 'currently on hormone therapy' if analysis.metadata.hormone_therapy else '',
206
  ]).strip(', ')}
207
 
 
208
  {report}"""
209
  except Exception as e:
210
  return f"Error during analysis: {str(e)}"
 
216
  interface = gr.Interface(
217
  fn=analyzer.analyze,
218
  inputs=[
219
+ gr.Image(type="pil", label="Upload Breast Image for Analysis")
220
  ],
221
  outputs=[
222
  gr.Textbox(label="Analysis Results", lines=20)
223
  ],
224
+ title="Breast Imaging Analysis System",
225
+ description="Upload a breast image for analysis and medical assessment.",
226
  )
227
 
228
  return interface