Shakir60 commited on
Commit
f5b37b3
·
verified ·
1 Parent(s): 2c28066

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -93
app.py CHANGED
@@ -11,6 +11,7 @@ from typing import List, Dict
11
  from datetime import datetime
12
  from groq import Groq
13
  import os
 
14
 
15
  # Setup logging
16
  logging.basicConfig(level=logging.INFO)
@@ -18,13 +19,34 @@ logger = logging.getLogger(__name__)
18
 
19
  class RAGSystem:
20
  def __init__(self):
21
- self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
22
- self.knowledge_base = self.load_knowledge_base()
23
- self.vector_store = self.create_vector_store()
24
- self.query_history = []
25
 
26
- def load_knowledge_base(self) -> List[Dict]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  """Load and preprocess knowledge base"""
 
28
  kb = {
29
  "spalling": [
30
  {
@@ -77,18 +99,13 @@ class RAGSystem:
77
  index.add(np.array(embeddings).astype('float32'))
78
  return index
79
 
 
80
  def get_relevant_context(self, query: str, k: int = 2) -> str:
81
  """Retrieve relevant context based on query"""
82
  try:
83
  query_embedding = self.embedding_model.encode([query])
84
  D, I = self.vector_store.search(np.array(query_embedding).astype('float32'), k)
85
  context = "\n\n".join([self.knowledge_base[i]["text"] for i in I[0]])
86
-
87
- self.query_history.append({
88
- "timestamp": datetime.now().isoformat(),
89
- "query": query
90
- })
91
-
92
  return context
93
  except Exception as e:
94
  logger.error(f"Error retrieving context: {e}")
@@ -96,62 +113,73 @@ class RAGSystem:
96
 
97
  class ImageAnalyzer:
98
  def __init__(self):
99
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
100
  self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
101
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  try:
103
- # Use feature extractor instead of processor
104
- self.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
105
- self.model = ViTForImageClassification.from_pretrained(
106
  "google/vit-base-patch16-224",
107
  num_labels=len(self.defect_classes),
108
  ignore_mismatched_sizes=True
109
  ).to(self.device)
110
 
111
- # Initialize the model weights for our specific classes
112
  with torch.no_grad():
113
- self.model.classifier = torch.nn.Linear(
114
- in_features=self.model.classifier.in_features,
115
  out_features=len(self.defect_classes)
116
  )
117
-
118
  except Exception as e:
119
  logger.error(f"Model initialization error: {e}")
120
- self.model = None
121
- self.feature_extractor = None
122
 
123
- def preprocess_image(self, image):
 
124
  """Preprocess image for model input"""
125
- if image.mode != 'RGB':
126
- image = image.convert('RGB')
127
-
128
- # Resize image to expected size
129
- width, height = 224, 224
130
- image = image.resize((width, height), Image.Resampling.LANCZOS)
131
-
132
- return image
 
 
 
133
 
134
  def analyze_image(self, image):
135
  """Analyze image for defects"""
136
  try:
137
- # Preprocess image
138
- processed_image = self.preprocess_image(image)
139
-
140
- # Extract features
141
  inputs = self.feature_extractor(
142
- images=processed_image,
143
  return_tensors="pt"
144
  )
145
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
146
 
147
- # Get predictions
148
  with torch.no_grad():
149
  outputs = self.model(**inputs)
150
 
151
- # Get probabilities
152
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
153
 
154
- # Add confidence threshold
155
  confidence_threshold = 0.3
156
  results = {
157
  self.defect_classes[i]: float(probs[i])
@@ -159,7 +187,6 @@ class ImageAnalyzer:
159
  if float(probs[i]) > confidence_threshold
160
  }
161
 
162
- # If no defects meet threshold, return the highest probability one
163
  if not results:
164
  max_idx = torch.argmax(probs)
165
  results = {self.defect_classes[int(max_idx)]: float(probs[max_idx])}
@@ -170,9 +197,13 @@ class ImageAnalyzer:
170
  logger.error(f"Analysis error: {str(e)}")
171
  return None
172
 
 
173
  def get_groq_response(query: str, context: str) -> str:
174
- """Get response from Groq LLM"""
175
  try:
 
 
 
176
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
177
 
178
  prompt = f"""Based on the following context about construction defects, answer the question.
@@ -197,7 +228,7 @@ def get_groq_response(query: str, context: str) -> str:
197
  return response.choices[0].message.content
198
  except Exception as e:
199
  logger.error(f"Groq API error: {e}")
200
- return f"Error: Unable to get response from AI model. Please check your API key and try again."
201
 
202
  def main():
203
  st.set_page_config(
@@ -208,56 +239,61 @@ def main():
208
 
209
  st.title("🏗️ Construction Defect Analyzer")
210
 
211
- # Initialize systems with error handling
212
- try:
213
- if 'analyzer' not in st.session_state:
214
- st.session_state.analyzer = ImageAnalyzer()
215
- if 'rag_system' not in st.session_state:
216
- st.session_state.rag_system = RAGSystem()
217
- except Exception as e:
218
- st.error(f"Error initializing systems: {str(e)}")
219
- return
220
 
221
- # Create two columns
222
  col1, col2 = st.columns([1, 1])
223
 
224
  with col1:
225
  st.subheader("Image Analysis")
226
- uploaded_file = st.file_uploader("Upload a construction image for analysis", type=["jpg", "jpeg", "png"])
 
 
 
 
227
 
228
  if uploaded_file is not None:
229
  try:
230
- # Read and display image
231
- image = Image.open(uploaded_file)
232
- st.image(image, caption='Uploaded Image', use_column_width=True)
233
 
234
- # Analyze image
235
- with st.spinner('Analyzing image...'):
236
- results = st.session_state.analyzer.analyze_image(image)
237
-
238
- if results:
239
- st.success('Analysis complete!')
240
 
241
- # Display results
242
- st.subheader("Detected Defects")
 
 
 
243
 
244
- # Create bar chart
245
- fig, ax = plt.subplots(figsize=(8, 4))
246
- defects = list(results.keys())
247
- probs = list(results.values())
248
- ax.barh(defects, probs)
249
- ax.set_xlim(0, 1)
250
- plt.tight_layout()
251
- st.pyplot(fig)
252
-
253
- # Get most likely defect
254
- most_likely_defect = max(results.items(), key=lambda x: x[1])[0]
255
- st.info(f"Most likely defect: {most_likely_defect}")
 
 
 
 
 
256
  else:
257
- st.error("Analysis failed. Please try again.")
258
-
259
  except Exception as e:
260
- st.error(f"Error: {str(e)}")
261
  logger.error(f"Process error: {e}")
262
 
263
  with col2:
@@ -272,18 +308,21 @@ def main():
272
  # Get context from RAG system
273
  context = st.session_state.rag_system.get_relevant_context(user_query)
274
 
275
- # Get response from Groq
276
- response = get_groq_response(user_query, context)
277
-
278
- # Display response
279
- st.write("Answer:")
280
- st.write(response)
281
-
282
- # Option to view context
283
- with st.expander("View retrieved information"):
284
- st.text(context)
 
 
 
 
285
 
286
- # Sidebar for information and settings
287
  with st.sidebar:
288
  st.header("About")
289
  st.write("""
@@ -304,9 +343,9 @@ def main():
304
 
305
  # Add settings section
306
  st.subheader("Settings")
307
- if st.button("Clear Session"):
308
- st.session_state.clear()
309
- st.success("Session cleared!")
310
 
311
  if __name__ == "__main__":
312
  main()
 
11
  from datetime import datetime
12
  from groq import Groq
13
  import os
14
+ from functools import lru_cache
15
 
16
  # Setup logging
17
  logging.basicConfig(level=logging.INFO)
 
19
 
20
  class RAGSystem:
21
  def __init__(self):
22
+ # Load models only when needed
23
+ self._embedding_model = None
24
+ self._vector_store = None
25
+ self._knowledge_base = None
26
 
27
+ @property
28
+ def embedding_model(self):
29
+ if self._embedding_model is None:
30
+ self._embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
31
+ return self._embedding_model
32
+
33
+ @property
34
+ def knowledge_base(self):
35
+ if self._knowledge_base is None:
36
+ self._knowledge_base = self.load_knowledge_base()
37
+ return self._knowledge_base
38
+
39
+ @property
40
+ def vector_store(self):
41
+ if self._vector_store is None:
42
+ self._vector_store = self.create_vector_store()
43
+ return self._vector_store
44
+
45
+ @staticmethod
46
+ @lru_cache(maxsize=1) # Cache the knowledge base
47
+ def load_knowledge_base() -> List[Dict]:
48
  """Load and preprocess knowledge base"""
49
+ # Your existing knowledge base code...
50
  kb = {
51
  "spalling": [
52
  {
 
99
  index.add(np.array(embeddings).astype('float32'))
100
  return index
101
 
102
+ @lru_cache(maxsize=32) # Cache recent query results
103
  def get_relevant_context(self, query: str, k: int = 2) -> str:
104
  """Retrieve relevant context based on query"""
105
  try:
106
  query_embedding = self.embedding_model.encode([query])
107
  D, I = self.vector_store.search(np.array(query_embedding).astype('float32'), k)
108
  context = "\n\n".join([self.knowledge_base[i]["text"] for i in I[0]])
 
 
 
 
 
 
109
  return context
110
  except Exception as e:
111
  logger.error(f"Error retrieving context: {e}")
 
113
 
114
  class ImageAnalyzer:
115
  def __init__(self):
116
+ self.device = "cpu" # Force CPU usage for better compatibility
117
  self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
118
+ self._model = None
119
+ self._feature_extractor = None
120
+
121
+ @property
122
+ def model(self):
123
+ if self._model is None:
124
+ self._model = self._load_model()
125
+ return self._model
126
+
127
+ @property
128
+ def feature_extractor(self):
129
+ if self._feature_extractor is None:
130
+ self._feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
131
+ return self._feature_extractor
132
+
133
+ def _load_model(self):
134
  try:
135
+ model = ViTForImageClassification.from_pretrained(
 
 
136
  "google/vit-base-patch16-224",
137
  num_labels=len(self.defect_classes),
138
  ignore_mismatched_sizes=True
139
  ).to(self.device)
140
 
 
141
  with torch.no_grad():
142
+ model.classifier = torch.nn.Linear(
143
+ in_features=model.classifier.in_features,
144
  out_features=len(self.defect_classes)
145
  )
146
+ return model
147
  except Exception as e:
148
  logger.error(f"Model initialization error: {e}")
149
+ return None
 
150
 
151
+ @st.cache_data # Cache preprocessed images
152
+ def preprocess_image(self, image_bytes):
153
  """Preprocess image for model input"""
154
+ try:
155
+ image = Image.open(image_bytes)
156
+ if image.mode != 'RGB':
157
+ image = image.convert('RGB')
158
+
159
+ width, height = 224, 224
160
+ image = image.resize((width, height), Image.Resampling.LANCZOS)
161
+ return image
162
+ except Exception as e:
163
+ logger.error(f"Image preprocessing error: {e}")
164
+ return None
165
 
166
  def analyze_image(self, image):
167
  """Analyze image for defects"""
168
  try:
169
+ if self.model is None:
170
+ raise ValueError("Model not properly initialized")
171
+
 
172
  inputs = self.feature_extractor(
173
+ images=image,
174
  return_tensors="pt"
175
  )
176
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
177
 
 
178
  with torch.no_grad():
179
  outputs = self.model(**inputs)
180
 
 
181
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
182
 
 
183
  confidence_threshold = 0.3
184
  results = {
185
  self.defect_classes[i]: float(probs[i])
 
187
  if float(probs[i]) > confidence_threshold
188
  }
189
 
 
190
  if not results:
191
  max_idx = torch.argmax(probs)
192
  results = {self.defect_classes[int(max_idx)]: float(probs[max_idx])}
 
197
  logger.error(f"Analysis error: {str(e)}")
198
  return None
199
 
200
+ @st.cache_data
201
  def get_groq_response(query: str, context: str) -> str:
202
+ """Get response from Groq LLM with caching"""
203
  try:
204
+ if not os.getenv("GROQ_API_KEY"):
205
+ return "Error: Groq API key not configured"
206
+
207
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
208
 
209
  prompt = f"""Based on the following context about construction defects, answer the question.
 
228
  return response.choices[0].message.content
229
  except Exception as e:
230
  logger.error(f"Groq API error: {e}")
231
+ return f"Error: Unable to get response from AI model. Please try again."
232
 
233
  def main():
234
  st.set_page_config(
 
239
 
240
  st.title("🏗️ Construction Defect Analyzer")
241
 
242
+ # Initialize systems in session state if not present
243
+ if 'analyzer' not in st.session_state:
244
+ st.session_state.analyzer = ImageAnalyzer()
245
+ if 'rag_system' not in st.session_state:
246
+ st.session_state.rag_system = RAGSystem()
 
 
 
 
247
 
 
248
  col1, col2 = st.columns([1, 1])
249
 
250
  with col1:
251
  st.subheader("Image Analysis")
252
+ uploaded_file = st.file_uploader(
253
+ "Upload a construction image for analysis",
254
+ type=["jpg", "jpeg", "png"],
255
+ key="image_uploader" # Add key for proper state management
256
+ )
257
 
258
  if uploaded_file is not None:
259
  try:
260
+ # Create a placeholder for the image
261
+ image_placeholder = st.empty()
 
262
 
263
+ # Process image with progress indicator
264
+ with st.spinner('Processing image...'):
265
+ processed_image = st.session_state.analyzer.preprocess_image(uploaded_file)
266
+ if processed_image:
267
+ image_placeholder.image(processed_image, caption='Uploaded Image', use_column_width=True)
 
268
 
269
+ # Analyze image with progress bar
270
+ progress_bar = st.progress(0)
271
+ with st.spinner('Analyzing defects...'):
272
+ results = st.session_state.analyzer.analyze_image(processed_image)
273
+ progress_bar.progress(100)
274
 
275
+ if results:
276
+ st.success('Analysis complete!')
277
+
278
+ # Display results
279
+ st.subheader("Detected Defects")
280
+ fig, ax = plt.subplots(figsize=(8, 4))
281
+ defects = list(results.keys())
282
+ probs = list(results.values())
283
+ ax.barh(defects, probs)
284
+ ax.set_xlim(0, 1)
285
+ plt.tight_layout()
286
+ st.pyplot(fig)
287
+
288
+ most_likely_defect = max(results.items(), key=lambda x: x[1])[0]
289
+ st.info(f"Most likely defect: {most_likely_defect}")
290
+ else:
291
+ st.warning("No defects detected or analysis failed. Please try another image.")
292
  else:
293
+ st.error("Failed to process image. Please try another one.")
294
+
295
  except Exception as e:
296
+ st.error(f"Error processing image: {str(e)}")
297
  logger.error(f"Process error: {e}")
298
 
299
  with col2:
 
308
  # Get context from RAG system
309
  context = st.session_state.rag_system.get_relevant_context(user_query)
310
 
311
+ if context:
312
+ # Get response from Groq
313
+ response = get_groq_response(user_query, context)
314
+
315
+ if not response.startswith("Error"):
316
+ st.write("Answer:")
317
+ st.markdown(response)
318
+ else:
319
+ st.error(response)
320
+
321
+ with st.expander("View retrieved information"):
322
+ st.text(context)
323
+ else:
324
+ st.error("Could not find relevant information. Please try rephrasing your question.")
325
 
 
326
  with st.sidebar:
327
  st.header("About")
328
  st.write("""
 
343
 
344
  # Add settings section
345
  st.subheader("Settings")
346
+ if st.button("Clear Cache"):
347
+ st.cache_data.clear()
348
+ st.success("Cache cleared!")
349
 
350
  if __name__ == "__main__":
351
  main()