Shakir60 commited on
Commit
2ceb6a1
·
verified ·
1 Parent(s): 18cc344

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -17
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  import torch
3
  from PIL import Image
4
  import numpy as np
5
- from transformers import ViTForImageClassification, ViTImageProcessor
6
  from sentence_transformers import SentenceTransformer
7
  import matplotlib.pyplot as plt
8
  import logging
@@ -100,24 +100,48 @@ class ImageAnalyzer:
100
  self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
101
 
102
  try:
 
 
103
  self.model = ViTForImageClassification.from_pretrained(
104
  "google/vit-base-patch16-224",
105
- num_labels=len(self.defect_classes)
 
106
  ).to(self.device)
107
- self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
 
 
 
 
 
 
 
108
  except Exception as e:
109
  logger.error(f"Model initialization error: {e}")
110
  self.model = None
111
- self.processor = None
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def analyze_image(self, image):
 
114
  try:
115
- # Ensure image is RGB
116
- if image.mode != 'RGB':
117
- image = image.convert('RGB')
118
 
119
- # Process image
120
- inputs = self.processor(images=image, return_tensors="pt")
 
 
 
121
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
122
 
123
  # Get predictions
@@ -126,10 +150,24 @@ class ImageAnalyzer:
126
 
127
  # Get probabilities
128
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
129
- return {self.defect_classes[i]: float(probs[i]) for i in range(len(self.defect_classes))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  except Exception as e:
132
- logger.error(f"Analysis error: {e}")
133
  return None
134
 
135
  def get_groq_response(query: str, context: str) -> str:
@@ -170,11 +208,15 @@ def main():
170
 
171
  st.title("🏗️ Construction Defect Analyzer")
172
 
173
- # Initialize systems
174
- if 'analyzer' not in st.session_state:
175
- st.session_state.analyzer = ImageAnalyzer()
176
- if 'rag_system' not in st.session_state:
177
- st.session_state.rag_system = RAGSystem()
 
 
 
 
178
 
179
  # Create two columns
180
  col1, col2 = st.columns([1, 1])
@@ -241,7 +283,7 @@ def main():
241
  with st.expander("View retrieved information"):
242
  st.text(context)
243
 
244
- # Sidebar for information
245
  with st.sidebar:
246
  st.header("About")
247
  st.write("""
@@ -259,6 +301,12 @@ def main():
259
  st.success("Groq API: Connected")
260
  else:
261
  st.error("Groq API: Not configured")
 
 
 
 
 
 
262
 
263
  if __name__ == "__main__":
264
  main()
 
2
  import torch
3
  from PIL import Image
4
  import numpy as np
5
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
6
  from sentence_transformers import SentenceTransformer
7
  import matplotlib.pyplot as plt
8
  import logging
 
100
  self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
101
 
102
  try:
103
+ # Use feature extractor instead of processor
104
+ self.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
105
  self.model = ViTForImageClassification.from_pretrained(
106
  "google/vit-base-patch16-224",
107
+ num_labels=len(self.defect_classes),
108
+ ignore_mismatched_sizes=True
109
  ).to(self.device)
110
+
111
+ # Initialize the model weights for our specific classes
112
+ with torch.no_grad():
113
+ self.model.classifier = torch.nn.Linear(
114
+ in_features=self.model.classifier.in_features,
115
+ out_features=len(self.defect_classes)
116
+ )
117
+
118
  except Exception as e:
119
  logger.error(f"Model initialization error: {e}")
120
  self.model = None
121
+ self.feature_extractor = None
122
+
123
+ def preprocess_image(self, image):
124
+ """Preprocess image for model input"""
125
+ if image.mode != 'RGB':
126
+ image = image.convert('RGB')
127
+
128
+ # Resize image to expected size
129
+ width, height = 224, 224
130
+ image = image.resize((width, height), Image.Resampling.LANCZOS)
131
+
132
+ return image
133
 
134
  def analyze_image(self, image):
135
+ """Analyze image for defects"""
136
  try:
137
+ # Preprocess image
138
+ processed_image = self.preprocess_image(image)
 
139
 
140
+ # Extract features
141
+ inputs = self.feature_extractor(
142
+ images=processed_image,
143
+ return_tensors="pt"
144
+ )
145
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
146
 
147
  # Get predictions
 
150
 
151
  # Get probabilities
152
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
153
+
154
+ # Add confidence threshold
155
+ confidence_threshold = 0.3
156
+ results = {
157
+ self.defect_classes[i]: float(probs[i])
158
+ for i in range(len(self.defect_classes))
159
+ if float(probs[i]) > confidence_threshold
160
+ }
161
+
162
+ # If no defects meet threshold, return the highest probability one
163
+ if not results:
164
+ max_idx = torch.argmax(probs)
165
+ results = {self.defect_classes[int(max_idx)]: float(probs[max_idx])}
166
+
167
+ return results
168
 
169
  except Exception as e:
170
+ logger.error(f"Analysis error: {str(e)}")
171
  return None
172
 
173
  def get_groq_response(query: str, context: str) -> str:
 
208
 
209
  st.title("🏗️ Construction Defect Analyzer")
210
 
211
+ # Initialize systems with error handling
212
+ try:
213
+ if 'analyzer' not in st.session_state:
214
+ st.session_state.analyzer = ImageAnalyzer()
215
+ if 'rag_system' not in st.session_state:
216
+ st.session_state.rag_system = RAGSystem()
217
+ except Exception as e:
218
+ st.error(f"Error initializing systems: {str(e)}")
219
+ return
220
 
221
  # Create two columns
222
  col1, col2 = st.columns([1, 1])
 
283
  with st.expander("View retrieved information"):
284
  st.text(context)
285
 
286
+ # Sidebar for information and settings
287
  with st.sidebar:
288
  st.header("About")
289
  st.write("""
 
301
  st.success("Groq API: Connected")
302
  else:
303
  st.error("Groq API: Not configured")
304
+
305
+ # Add settings section
306
+ st.subheader("Settings")
307
+ if st.button("Clear Session"):
308
+ st.session_state.clear()
309
+ st.success("Session cleared!")
310
 
311
  if __name__ == "__main__":
312
  main()