Shakir60 commited on
Commit
728272c
·
verified ·
1 Parent(s): 3c13553

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -16
app.py CHANGED
@@ -112,9 +112,10 @@ class RAGSystem:
112
  return ""
113
 
114
  class ImageAnalyzer:
115
- def __init__(self):
116
- self.device = "cpu" # Force CPU usage for better compatibility
117
  self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
 
118
  self._model = None
119
  self._feature_extractor = None
120
 
@@ -127,22 +128,60 @@ class ImageAnalyzer:
127
  @property
128
  def feature_extractor(self):
129
  if self._feature_extractor is None:
130
- self._feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
131
  return self._feature_extractor
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def _load_model(self):
134
  try:
135
- model = ViTForImageClassification.from_pretrained(
136
- "google/vit-base-patch16-224",
137
- num_labels=len(self.defect_classes),
138
- ignore_mismatched_sizes=True
139
- ).to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
 
141
  with torch.no_grad():
142
- model.classifier = torch.nn.Linear(
143
- in_features=model.classifier.in_features,
144
- out_features=len(self.defect_classes)
145
- )
 
 
 
146
  return model
147
  except Exception as e:
148
  logger.error(f"Model initialization error: {e}")
@@ -150,7 +189,7 @@ class ImageAnalyzer:
150
 
151
  def preprocess_image(self, image_bytes):
152
  """Preprocess image for model input"""
153
- return _cached_preprocess_image(image_bytes)
154
 
155
  def analyze_image(self, image):
156
  """Analyze image for defects"""
@@ -187,20 +226,26 @@ class ImageAnalyzer:
187
  return None
188
 
189
  @st.cache_data
190
- def _cached_preprocess_image(image_bytes):
191
  """Cached version of image preprocessing"""
192
  try:
193
  image = Image.open(image_bytes)
194
  if image.mode != 'RGB':
195
  image = image.convert('RGB')
196
 
197
- width, height = 224, 224
 
 
 
 
 
198
  image = image.resize((width, height), Image.Resampling.LANCZOS)
199
  return image
200
  except Exception as e:
201
  logger.error(f"Image preprocessing error: {e}")
202
  return None
203
-
 
204
  def get_groq_response(query: str, context: str) -> str:
205
  """Get response from Groq LLM with caching"""
206
  try:
 
112
  return ""
113
 
114
  class ImageAnalyzer:
115
+ def __init__(self, model_name="microsoft/swin-base-patch4-window7-224-in22k"):
116
+ self.device = "cpu"
117
  self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
118
+ self.model_name = model_name
119
  self._model = None
120
  self._feature_extractor = None
121
 
 
128
  @property
129
  def feature_extractor(self):
130
  if self._feature_extractor is None:
131
+ self._feature_extractor = self._load_feature_extractor()
132
  return self._feature_extractor
133
 
134
+ def _load_feature_extractor(self):
135
+ """Load the appropriate feature extractor based on model type"""
136
+ try:
137
+ if "swin" in self.model_name:
138
+ from transformers import AutoFeatureExtractor
139
+ return AutoFeatureExtractor.from_pretrained(self.model_name)
140
+ elif "convnext" in self.model_name:
141
+ from transformers import ConvNextFeatureExtractor
142
+ return ConvNextFeatureExtractor.from_pretrained(self.model_name)
143
+ else:
144
+ from transformers import ViTFeatureExtractor
145
+ return ViTFeatureExtractor.from_pretrained(self.model_name)
146
+ except Exception as e:
147
+ logger.error(f"Feature extractor initialization error: {e}")
148
+ return None
149
+
150
  def _load_model(self):
151
  try:
152
+ if "swin" in self.model_name:
153
+ from transformers import SwinForImageClassification
154
+ model = SwinForImageClassification.from_pretrained(
155
+ self.model_name,
156
+ num_labels=len(self.defect_classes),
157
+ ignore_mismatched_sizes=True
158
+ )
159
+ elif "convnext" in self.model_name:
160
+ from transformers import ConvNextForImageClassification
161
+ model = ConvNextForImageClassification.from_pretrained(
162
+ self.model_name,
163
+ num_labels=len(self.defect_classes),
164
+ ignore_mismatched_sizes=True
165
+ )
166
+ else:
167
+ from transformers import ViTForImageClassification
168
+ model = ViTForImageClassification.from_pretrained(
169
+ self.model_name,
170
+ num_labels=len(self.defect_classes),
171
+ ignore_mismatched_sizes=True
172
+ )
173
+
174
+ model = model.to(self.device)
175
 
176
+ # Reinitialize the classifier layer
177
  with torch.no_grad():
178
+ if hasattr(model, 'classifier'):
179
+ in_features = model.classifier.in_features
180
+ model.classifier = torch.nn.Linear(in_features, len(self.defect_classes))
181
+ elif hasattr(model, 'head'):
182
+ in_features = model.head.in_features
183
+ model.head = torch.nn.Linear(in_features, len(self.defect_classes))
184
+
185
  return model
186
  except Exception as e:
187
  logger.error(f"Model initialization error: {e}")
 
189
 
190
  def preprocess_image(self, image_bytes):
191
  """Preprocess image for model input"""
192
+ return _cached_preprocess_image(image_bytes, self.model_name)
193
 
194
  def analyze_image(self, image):
195
  """Analyze image for defects"""
 
226
  return None
227
 
228
  @st.cache_data
229
+ def _cached_preprocess_image(image_bytes, model_name):
230
  """Cached version of image preprocessing"""
231
  try:
232
  image = Image.open(image_bytes)
233
  if image.mode != 'RGB':
234
  image = image.convert('RGB')
235
 
236
+ # Adjust size based on model requirements
237
+ if "convnext" in model_name:
238
+ width, height = 384, 384
239
+ else:
240
+ width, height = 224, 224
241
+
242
  image = image.resize((width, height), Image.Resampling.LANCZOS)
243
  return image
244
  except Exception as e:
245
  logger.error(f"Image preprocessing error: {e}")
246
  return None
247
+
248
+ @st.cache_data
249
  def get_groq_response(query: str, context: str) -> str:
250
  """Get response from Groq LLM with caching"""
251
  try: