Shakir60 commited on
Commit
fd94e5f
·
verified ·
1 Parent(s): f26cc53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -301
app.py CHANGED
@@ -6,308 +6,185 @@ from sentence_transformers import SentenceTransformer
6
  from PIL import Image
7
  import torch
8
  import numpy as np
9
- from typing import List, Dict, Tuple, Optional, Any
10
  import faiss
11
  import json
12
- import torchvision.transforms.functional as TF
13
- from torchvision import transforms
14
  import cv2
15
- import pandas as pd
16
- from datetime import datetime
17
  import logging
 
 
18
 
19
  # Setup logging
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
23
- class ConfigManager:
24
- """Manages configuration settings for the application"""
25
- DEFAULT_CONFIG = {
26
- "model_settings": {
27
- "vit_model": "google/vit-base-patch16-224",
28
- "sentence_transformer": "all-MiniLM-L6-v2",
29
- "groq_model": "llama-3.3-70b-versatile"
30
- },
31
- "analysis_settings": {
32
- "confidence_threshold": 0.5,
33
- "max_defects": 3,
34
- "heatmap_intensity": 0.7
35
- },
36
- "rag_settings": {
37
- "num_relevant_docs": 3,
38
- "similarity_threshold": 0.75
39
- }
40
- }
41
 
42
- @staticmethod
43
- def load_config():
44
- """Load configuration with fallback to defaults"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  try:
46
- if os.path.exists('config.json'):
47
- with open('config.json', 'r') as f:
48
- config = json.load(f)
49
- return {**ConfigManager.DEFAULT_CONFIG, **config}
 
 
 
 
 
 
 
50
  except Exception as e:
51
- logger.warning(f"Error loading config: {e}")
52
- return ConfigManager.DEFAULT_CONFIG
53
-
54
- config = ConfigManager.load_config()
55
 
56
  class ImageAnalyzer:
57
  def __init__(self):
58
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
- self.config = config["model_settings"]
60
- self.analysis_config = config["analysis_settings"]
61
- self.defect_classes = [
62
- "spalling", "reinforcement_corrosion", "structural_cracks",
63
- "water_damage", "surface_deterioration", "alkali_silica_reaction",
64
- "concrete_delamination", "honeycomb", "scaling",
65
- "efflorescence", "joint_deterioration", "carbonation"
66
- ]
67
- self.initialize_models()
68
  self.history = []
69
 
70
- def initialize_models(self):
71
- """Initialize all required models"""
72
- try:
73
- # Initialize ViT model
74
- self.model = ViTForImageClassification.from_pretrained(
75
- self.config["vit_model"],
76
- num_labels=len(self.defect_classes),
77
- ignore_mismatched_sizes=True
78
- ).to(self.device)
79
-
80
- # Initialize image processor
81
- self.processor = ViTImageProcessor.from_pretrained(self.config["vit_model"])
82
-
83
- # Initialize transformations pipeline
84
- self.transforms = self._setup_transforms()
85
-
86
- return True
87
- except Exception as e:
88
- logger.error(f"Model initialization error: {e}")
89
- return False
90
-
91
- def _setup_transforms(self):
92
- """Setup image transformation pipeline"""
93
- return transforms.Compose([
94
- transforms.Resize((224, 224)),
95
- transforms.ToTensor(),
96
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
97
- std=[0.229, 0.224, 0.225]),
98
- transforms.RandomAdjustSharpness(2),
99
- transforms.ColorJitter(brightness=0.2, contrast=0.2)
100
- ])
101
 
102
- def preprocess_image(self, image: Image.Image) -> Dict[str, Any]:
103
- """Enhanced image preprocessing with multiple analyses"""
104
  try:
105
- # Convert to RGB if necessary
106
- if image.mode != 'RGB':
107
- image = image.convert('RGB')
108
- # Basic image statistics
109
- img_array = np.array(image)
110
- stats = {
111
- "mean_brightness": np.mean(img_array),
112
- "std_brightness": np.std(img_array),
113
- "size": image.size,
114
- "aspect_ratio": image.size[0] / image.size[1]
115
- }
116
- # Edge detection for crack analysis
117
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
118
- edges = cv2.Canny(gray, 100, 200)
119
- stats["edge_density"] = np.mean(edges > 0)
120
- # Color analysis for rust detection
121
- hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
122
- rust_mask = cv2.inRange(hsv, np.array([0, 50, 50]), np.array([30, 255, 255]))
123
- stats["rust_percentage"] = np.mean(rust_mask > 0)
124
- # Transform for model
125
- model_input = self.transforms(image).unsqueeze(0).to(self.device)
126
- return {
127
- "model_input": model_input,
128
- "stats": stats,
129
- "edges": edges,
130
- "rust_mask": rust_mask
131
- }
132
- except Exception as e:
133
- logger.error(f"Preprocessing error: {e}")
134
- return None
135
- def detect_defects(self, image: Image.Image) -> Dict[str, Any]:
136
- """Enhanced defect detection with multiple analysis methods"""
137
- try:
138
- # Preprocess image
139
- proc_data = self.preprocess_image(image)
140
- if proc_data is None:
141
- logger.error("Image preprocessing failed.")
142
- return None # Early return if preprocessing failed
143
 
144
- # Model prediction
145
- with torch.no_grad():
146
- outputs = self.model(proc_data["model_input"])
147
-
148
- # Get probabilities
149
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
150
-
151
- # Convert to dictionary
152
- defect_probs = {
153
- self.defect_classes[i]: float(probabilities[0][i])
154
- for i in range(len(self.defect_classes))
155
- }
156
- # Generate attention heatmap
157
- attention_weights = outputs.attentions[-1].mean(dim=1)[0] if hasattr(outputs, 'attentions') else None
158
- heatmap = self.generate_heatmap(attention_weights, image.size) if attention_weights is not None else None
159
 
160
- # Additional analysis based on image statistics
161
- additional_analysis = self.analyze_image_statistics(proc_data["stats"])
 
 
 
 
162
 
163
- # Combine all results
164
- result = {
165
- "defect_probabilities": defect_probs,
166
- "heatmap": heatmap,
167
- "image_statistics": proc_data["stats"],
168
- "additional_analysis": additional_analysis,
169
- "edge_detection": proc_data["edges"],
170
- "rust_detection": proc_data["rust_mask"],
171
- "timestamp": datetime.now().isoformat()
172
- }
173
- # Save to history
174
- self.history.append(result)
175
- return result
176
- except Exception as e:
177
- logger.error(f"Defect detection error: {e}")
178
- return None
179
-
180
- def analyze_image_statistics(self, stats: Dict) -> Dict[str, Any]:
181
- """Analyze image statistics for additional insights"""
182
- analysis = {}
183
-
184
- # Brightness analysis
185
- if stats["mean_brightness"] < 50:
186
- analysis["lighting_condition"] = "Poor lighting - may affect accuracy"
187
- elif stats["mean_brightness"] > 200:
188
- analysis["lighting_condition"] = "Overexposed - may affect accuracy"
189
-
190
- # Edge density analysis
191
- if stats["edge_density"] > 0.1:
192
- analysis["crack_likelihood"] = "High crack probability based on edge detection"
193
-
194
- # Rust analysis
195
- if stats["rust_percentage"] > 0.05:
196
- analysis["corrosion_indicator"] = "Possible corrosion detected"
197
-
198
- return analysis
199
-
200
- def generate_heatmap(self, attention_weights: torch.Tensor, image_size: Tuple[int, int]) -> np.ndarray:
201
- """Generate enhanced attention heatmap"""
202
- try:
203
- if attention_weights is None:
204
- return None
205
-
206
- # Process attention weights
207
- heatmap = attention_weights.cpu().numpy()
208
- heatmap = cv2.resize(heatmap, image_size)
209
-
210
- # Enhanced normalization
211
- heatmap = np.maximum(heatmap, 0)
212
- heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
213
 
214
- # Apply gamma correction
215
- gamma = self.analysis_config["heatmap_intensity"]
216
- heatmap = np.power(heatmap, gamma)
 
 
217
 
218
- # Apply colormap
219
- heatmap = (heatmap * 255).astype(np.uint8)
220
- heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
221
 
222
- return heatmap
223
  except Exception as e:
224
- logger.error(f"Heatmap generation error: {e}")
225
  return None
226
- class RAGSystem:
227
- """Basic RAG System for storing and retrieving documents."""
228
- def __init__(self):
229
- self.embedding_model = SentenceTransformer(config["model_settings"]["sentence_transformer"])
230
- self.vector_store = faiss.IndexFlatL2(384) # 384-dim for MiniLM embeddings
231
- self.knowledge_base = []
232
 
233
- def add_documents(self, docs: List[str]):
234
- """Add documents to the vector store."""
235
- embeddings = self.embedding_model.encode(docs)
236
- self.vector_store.add(np.array(embeddings).astype('float32'))
237
- for doc in docs:
238
- self.knowledge_base.append({"text": doc})
239
-
240
- def search(self, query: str, k: int = 3):
241
- """Retrieve similar documents for the query."""
242
- query_embedding = self.embedding_model.encode([query])
243
- D, I = self.vector_store.search(np.array(query_embedding).astype('float32'), k)
244
- return [self.knowledge_base[i]["text"] for i in I[0]]
245
-
246
-
247
- class EnhancedRAGSystem(RAGSystem):
248
- """Enhanced RAG system with additional features"""
249
- def __init__(self):
250
- super().__init__()
251
- self.config = config["rag_settings"]
252
- self.query_history = []
253
-
254
- def get_relevant_context(self, query: str, k: int = None) -> str:
255
- """Enhanced context retrieval with debugging info"""
256
- if k is None:
257
- k = self.config["num_relevant_docs"]
258
-
259
- # Log query
260
- self.query_history.append({
261
- "timestamp": datetime.now().isoformat(),
262
- "query": query
263
- })
264
-
265
- # Generate query embedding
266
- query_embedding = self.embedding_model.encode([query])
267
-
268
- # Search for similar documents
269
- D, I = self.vector_store.search(np.array(query_embedding).astype('float32'), k)
270
 
271
- # Filter by similarity threshold
272
- relevant_docs = [
273
- self.knowledge_base[i]["text"]
274
- for i, dist in zip(I[0], D[0])
275
- if dist < self.config["similarity_threshold"]
276
- ]
277
-
278
- return "\n\n".join(relevant_docs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
  def main():
281
  st.set_page_config(
282
- page_title="Enhanced Construction Defect Analyzer",
283
  page_icon="🏗️",
284
  layout="wide"
285
  )
286
 
287
- st.title("🏗️ Advanced Construction Defect Analysis System")
288
 
289
- # Initialize systems
290
  if 'rag_system' not in st.session_state:
291
- st.session_state.rag_system = EnhancedRAGSystem()
292
  if 'image_analyzer' not in st.session_state:
293
  st.session_state.image_analyzer = ImageAnalyzer()
294
 
295
- # Sidebar for settings and history
296
- with st.sidebar:
297
- st.header("Settings & History")
298
- show_debug = st.checkbox("Show Debug Information")
299
- confidence_threshold = st.slider(
300
- "Confidence Threshold",
301
- min_value=0.0,
302
- max_value=1.0,
303
- value=config["analysis_settings"]["confidence_threshold"]
304
- )
305
-
306
- if st.button("View Analysis History"):
307
- st.write("Recent Analyses:", st.session_state.image_analyzer.history[-5:])
308
-
309
- # Main interface
310
- col1, col2 = st.columns([2, 3])
311
 
312
  with col1:
313
  uploaded_file = st.file_uploader(
@@ -315,61 +192,59 @@ def main():
315
  type=['jpg', 'jpeg', 'png']
316
  )
317
 
318
- user_query = st.text_input(
319
- "Ask a question about construction defects:",
320
- help="Enter your question about specific defects or general construction issues"
321
- )
322
-
323
- with col2:
324
  if uploaded_file:
325
  image = Image.open(uploaded_file)
 
326
 
327
- # Create tabs for different views
328
- tabs = st.tabs(["Original", "Analysis", "Details"])
329
-
330
- with tabs[0]:
331
- st.image(image, caption="Uploaded Image")
332
-
333
- with tabs[1]:
334
- with st.spinner("Analyzing image..."):
335
- results = st.session_state.image_analyzer.detect_defects(image)
336
-
337
- if results:
338
- # Show defect probabilities
339
- defect_probs = results["defect_probabilities"]
340
- significant_defects = {
341
- k: v for k, v in defect_probs.items()
342
- if v > confidence_threshold
343
- }
344
-
345
- if significant_defects:
346
- st.subheader("Detected Defects")
347
- fig = plt.figure(figsize=(10, 6))
348
- plt.barh(list(significant_defects.keys()),
349
- list(significant_defects.values()))
350
- st.pyplot(fig)
351
-
352
- # Show heatmap
353
- if results["heatmap"] is not None:
354
- st.image(results["heatmap"], caption="Defect Attention Map")
355
-
356
- with tabs[2]:
357
  if results:
358
- st.json(results["additional_analysis"])
359
- if show_debug:
 
 
 
 
 
 
 
 
 
 
 
 
360
  st.json(results["image_statistics"])
361
 
 
 
 
 
 
 
 
362
  if user_query:
363
  with st.spinner("Processing query..."):
364
  context = st.session_state.rag_system.get_relevant_context(user_query)
365
  response = get_groq_response(user_query, context)
366
 
367
- st.subheader("AI Assistant Response")
368
  st.write(response)
369
 
370
- if show_debug:
371
- st.subheader("Retrieved Context")
372
  st.text(context)
373
 
 
 
 
 
 
 
 
 
 
 
 
374
  if __name__ == "__main__":
375
  main()
 
6
  from PIL import Image
7
  import torch
8
  import numpy as np
9
+ from typing import List, Dict, Tuple
10
  import faiss
11
  import json
 
 
12
  import cv2
 
 
13
  import logging
14
+ from datetime import datetime
15
+ import matplotlib.pyplot as plt
16
 
17
  # Setup logging
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
+ class RAGSystem:
22
+ def __init__(self):
23
+ self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
24
+ self.knowledge_base = self.load_knowledge_base()
25
+ self.vector_store = self.create_vector_store()
26
+ self.query_history = []
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ def load_knowledge_base(self) -> List[Dict]:
29
+ """Load and preprocess knowledge base"""
30
+ # Using a simplified version of your knowledge base
31
+ kb = {
32
+ "spalling": [
33
+ {
34
+ "severity": "Critical",
35
+ "description": "Severe concrete spalling with exposed reinforcement",
36
+ "repair_method": "Remove deteriorated concrete, clean reinforcement",
37
+ "estimated_cost": "Very High ($15,000+)",
38
+ "immediate_action": "Evacuate area, install support"
39
+ }
40
+ ],
41
+ "structural_cracks": [
42
+ {
43
+ "severity": "High",
44
+ "description": "Active structural cracks >5mm width",
45
+ "repair_method": "Structural analysis, epoxy injection",
46
+ "estimated_cost": "High ($10,000-$20,000)",
47
+ "immediate_action": "Install crack monitors"
48
+ }
49
+ ]
50
+ }
51
+
52
+ documents = []
53
+ for category, items in kb.items():
54
+ for item in items:
55
+ doc_text = f"Category: {category}\n"
56
+ for key, value in item.items():
57
+ doc_text += f"{key}: {value}\n"
58
+ documents.append({"text": doc_text, "metadata": {"category": category}})
59
+
60
+ return documents
61
+
62
+ def create_vector_store(self):
63
+ """Create FAISS vector store"""
64
+ texts = [doc["text"] for doc in self.knowledge_base]
65
+ embeddings = self.embedding_model.encode(texts)
66
+ dimension = embeddings.shape[1]
67
+ index = faiss.IndexFlatL2(dimension)
68
+ index.add(np.array(embeddings).astype('float32'))
69
+ return index
70
+
71
+ def get_relevant_context(self, query: str, k: int = 3) -> str:
72
+ """Retrieve relevant context based on query"""
73
  try:
74
+ query_embedding = self.embedding_model.encode([query])
75
+ D, I = self.vector_store.search(np.array(query_embedding).astype('float32'), k)
76
+ context = "\n\n".join([self.knowledge_base[i]["text"] for i in I[0]])
77
+
78
+ # Log query
79
+ self.query_history.append({
80
+ "timestamp": datetime.now().isoformat(),
81
+ "query": query
82
+ })
83
+
84
+ return context
85
  except Exception as e:
86
+ logger.error(f"Error retrieving context: {e}")
87
+ return ""
 
 
88
 
89
  class ImageAnalyzer:
90
  def __init__(self):
91
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
+ self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
93
+ self.model = self._initialize_model()
94
+ self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
 
 
 
 
 
 
95
  self.history = []
96
 
97
+ def _initialize_model(self):
98
+ model = ViTForImageClassification.from_pretrained(
99
+ "google/vit-base-patch16-224",
100
+ num_labels=len(self.defect_classes),
101
+ ignore_mismatched_sizes=True
102
+ )
103
+ return model.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ def analyze_image(self, image: Image.Image) -> Dict:
106
+ """Analyze image for defects"""
107
  try:
108
+ # Preprocess image
109
+ inputs = self.processor(images=image, return_tensors="pt").to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ # Get model predictions
112
+ with torch.no_grad():
113
+ outputs = self.model(**inputs)
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+ # Process results
116
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
117
+ defect_probs = {
118
+ self.defect_classes[i]: float(probabilities[0][i])
119
+ for i in range(len(self.defect_classes))
120
+ }
121
 
122
+ # Basic image statistics
123
+ img_array = np.array(image)
124
+ stats = {
125
+ "mean_brightness": float(np.mean(img_array)),
126
+ "image_size": image.size
127
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ result = {
130
+ "defect_probabilities": defect_probs,
131
+ "image_statistics": stats,
132
+ "timestamp": datetime.now().isoformat()
133
+ }
134
 
135
+ self.history.append(result)
136
+ return result
 
137
 
 
138
  except Exception as e:
139
+ logger.error(f"Image analysis error: {e}")
140
  return None
 
 
 
 
 
 
141
 
142
+ def get_groq_response(query: str, context: str) -> str:
143
+ """Get response from Groq LLM"""
144
+ try:
145
+ client = Groq(api_key=os.getenv("GROQ_API_KEY"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ prompt = f"""Based on the following context about construction defects, answer the question.
148
+ Context: {context}
149
+ Question: {query}
150
+ Provide a detailed answer based on the context."""
151
+
152
+ response = client.chat.completions.create(
153
+ messages=[
154
+ {
155
+ "role": "system",
156
+ "content": "You are a construction defect analysis expert."
157
+ },
158
+ {
159
+ "role": "user",
160
+ "content": prompt
161
+ }
162
+ ],
163
+ model="llama2-70b-4096",
164
+ temperature=0.7,
165
+ )
166
+ return response.choices[0].message.content
167
+ except Exception as e:
168
+ logger.error(f"Groq API error: {e}")
169
+ return f"Error: Unable to get response from AI model. Please try again later."
170
 
171
  def main():
172
  st.set_page_config(
173
+ page_title="Construction Defect Analyzer",
174
  page_icon="🏗️",
175
  layout="wide"
176
  )
177
 
178
+ st.title("🏗️ Construction Defect Analyzer")
179
 
180
+ # Initialize systems in session state
181
  if 'rag_system' not in st.session_state:
182
+ st.session_state.rag_system = RAGSystem()
183
  if 'image_analyzer' not in st.session_state:
184
  st.session_state.image_analyzer = ImageAnalyzer()
185
 
186
+ # Create two columns
187
+ col1, col2 = st.columns([1, 1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  with col1:
190
  uploaded_file = st.file_uploader(
 
192
  type=['jpg', 'jpeg', 'png']
193
  )
194
 
 
 
 
 
 
 
195
  if uploaded_file:
196
  image = Image.open(uploaded_file)
197
+ st.image(image, caption="Uploaded Image", use_column_width=True)
198
 
199
+ with st.spinner("Analyzing image..."):
200
+ results = st.session_state.image_analyzer.analyze_image(image)
201
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  if results:
203
+ st.subheader("Detected Defects")
204
+
205
+ # Create bar chart
206
+ defect_probs = results["defect_probabilities"]
207
+ fig, ax = plt.subplots()
208
+ defects = list(defect_probs.keys())
209
+ probs = list(defect_probs.values())
210
+ ax.barh(defects, probs)
211
+ ax.set_xlim(0, 1)
212
+ ax.set_xlabel("Probability")
213
+ st.pyplot(fig)
214
+
215
+ # Show image statistics
216
+ if st.checkbox("Show Image Details"):
217
  st.json(results["image_statistics"])
218
 
219
+ with col2:
220
+ st.subheader("Ask About Defects")
221
+ user_query = st.text_input(
222
+ "Enter your question about construction defects:",
223
+ help="Example: What are the repair methods for severe spalling?"
224
+ )
225
+
226
  if user_query:
227
  with st.spinner("Processing query..."):
228
  context = st.session_state.rag_system.get_relevant_context(user_query)
229
  response = get_groq_response(user_query, context)
230
 
231
+ st.write("AI Response:")
232
  st.write(response)
233
 
234
+ if st.checkbox("Show Retrieved Context"):
235
+ st.write("Context Used:")
236
  st.text(context)
237
 
238
+ # Sidebar for history
239
+ with st.sidebar:
240
+ st.header("Analysis History")
241
+ if st.button("Show Recent Analyses"):
242
+ if st.session_state.image_analyzer.history:
243
+ for analysis in st.session_state.image_analyzer.history[-5:]:
244
+ st.write(f"Analysis from: {analysis['timestamp']}")
245
+ st.json(analysis["defect_probabilities"])
246
+ else:
247
+ st.write("No analyses yet")
248
+
249
  if __name__ == "__main__":
250
  main()