Shakir60 commited on
Commit
7fa4ec7
·
verified ·
1 Parent(s): 7406e96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -272
app.py CHANGED
@@ -1,313 +1,98 @@
1
  import streamlit as st
2
- import os
3
- from groq import Groq
4
- from transformers import ViTForImageClassification, ViTImageProcessor
5
- from sentence_transformers import SentenceTransformer
6
- from PIL import Image
7
  import torch
 
8
  import numpy as np
9
- from typing import List, Dict, Tuple
10
- import faiss
11
- import json
12
- import cv2
13
- import logging
14
- from datetime import datetime
15
  import matplotlib.pyplot as plt
 
16
 
17
  # Setup logging
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
- # Ensure CUDA is available or use CPU
22
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
-
24
- @st.cache_resource
25
- def load_vit_model():
26
- """Load and cache the ViT model"""
27
- try:
28
- model = ViTForImageClassification.from_pretrained(
29
- "google/vit-base-patch16-224",
30
- num_labels=3, # Number of defect classes
31
- ignore_mismatched_sizes=True
32
- )
33
- return model.to(DEVICE)
34
- except Exception as e:
35
- logger.error(f"Error loading ViT model: {e}")
36
- return None
37
-
38
- @st.cache_resource
39
- def load_vit_processor():
40
- """Load and cache the ViT processor"""
41
- try:
42
- return ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
43
- except Exception as e:
44
- logger.error(f"Error loading ViT processor: {e}")
45
- return None
46
-
47
- class RAGSystem:
48
  def __init__(self):
49
- self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
50
- self.knowledge_base = self.load_knowledge_base()
51
- self.vector_store = self.create_vector_store()
52
- self.query_history = []
53
-
54
- def load_knowledge_base(self) -> List[Dict]:
55
- """Load and preprocess knowledge base"""
56
- kb = {
57
- "spalling": [
58
- {
59
- "severity": "Critical",
60
- "description": "Severe concrete spalling with exposed reinforcement",
61
- "repair_method": "Remove deteriorated concrete, clean reinforcement",
62
- "estimated_cost": "Very High ($15,000+)",
63
- "immediate_action": "Evacuate area, install support"
64
- }
65
- ],
66
- "structural_cracks": [
67
- {
68
- "severity": "High",
69
- "description": "Active structural cracks >5mm width",
70
- "repair_method": "Structural analysis, epoxy injection",
71
- "estimated_cost": "High ($10,000-$20,000)",
72
- "immediate_action": "Install crack monitors"
73
- }
74
- ]
75
- }
76
-
77
- documents = []
78
- for category, items in kb.items():
79
- for item in items:
80
- doc_text = f"Category: {category}\n"
81
- for key, value in item.items():
82
- doc_text += f"{key}: {value}\n"
83
- documents.append({"text": doc_text, "metadata": {"category": category}})
84
 
85
- return documents
86
-
87
- def create_vector_store(self):
88
- """Create FAISS vector store"""
89
- texts = [doc["text"] for doc in self.knowledge_base]
90
- embeddings = self.embedding_model.encode(texts)
91
- dimension = embeddings.shape[1]
92
- index = faiss.IndexFlatL2(dimension)
93
- index.add(np.array(embeddings).astype('float32'))
94
- return index
95
-
96
- def get_relevant_context(self, query: str, k: int = 3) -> str:
97
- """Retrieve relevant context based on query"""
98
  try:
99
- query_embedding = self.embedding_model.encode([query])
100
- D, I = self.vector_store.search(np.array(query_embedding).astype('float32'), k)
101
- context = "\n\n".join([self.knowledge_base[i]["text"] for i in I[0]])
102
-
103
- self.query_history.append({
104
- "timestamp": datetime.now().isoformat(),
105
- "query": query
106
- })
107
-
108
- return context
109
  except Exception as e:
110
- logger.error(f"Error retrieving context: {e}")
111
- return ""
112
-
113
- class ImageAnalyzer:
114
- def __init__(self):
115
- self.device = DEVICE
116
- self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
117
- self.model = load_vit_model()
118
- self.processor = load_vit_processor()
119
- self.history = []
120
 
121
- def preprocess_image(self, image: Image.Image) -> torch.Tensor:
122
- """Preprocess image for model input"""
123
  try:
124
  # Ensure image is RGB
125
  if image.mode != 'RGB':
126
  image = image.convert('RGB')
127
 
128
- # Process image using ViT processor
129
  inputs = self.processor(images=image, return_tensors="pt")
130
- return inputs.to(self.device)
131
- except Exception as e:
132
- logger.error(f"Image preprocessing error: {e}")
133
- return None
134
-
135
- def analyze_image(self, image: Image.Image) -> Dict:
136
- """Analyze image for defects"""
137
- try:
138
- if self.model is None or self.processor is None:
139
- raise ValueError("Model or processor not properly initialized")
140
-
141
- # Preprocess image
142
- inputs = self.preprocess_image(image)
143
- if inputs is None:
144
- raise ValueError("Image preprocessing failed")
145
-
146
- # Get model predictions
147
  with torch.no_grad():
148
  outputs = self.model(**inputs)
149
 
150
- # Process results
151
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
152
- defect_probs = {
153
- self.defect_classes[i]: float(probabilities[0][i])
154
- for i in range(len(self.defect_classes))
155
- }
156
-
157
- # Basic image statistics
158
- img_array = np.array(image)
159
- stats = {
160
- "mean_brightness": float(np.mean(img_array)),
161
- "image_size": f"{image.size[0]}x{image.size[1]}"
162
- }
163
-
164
- result = {
165
- "defect_probabilities": defect_probs,
166
- "image_statistics": stats,
167
- "timestamp": datetime.now().isoformat()
168
- }
169
-
170
- self.history.append(result)
171
- return result
172
 
173
  except Exception as e:
174
- logger.error(f"Image analysis error: {e}")
175
  return None
176
 
177
- def get_groq_response(query: str, context: str) -> str:
178
- """Get response from Groq LLM"""
179
- try:
180
- client = Groq(api_key=os.getenv("GROQ_API_KEY"))
181
-
182
- prompt = f"""Based on the following context about construction defects, answer the question.
183
- Context: {context}
184
- Question: {query}
185
- Provide a detailed answer based on the context."""
186
-
187
- response = client.chat.completions.create(
188
- messages=[
189
- {
190
- "role": "system",
191
- "content": "You are a construction defect analysis expert."
192
- },
193
- {
194
- "role": "user",
195
- "content": prompt
196
- }
197
- ],
198
- model="llama2-70b-4096",
199
- temperature=0.7,
200
- )
201
- return response.choices[0].message.content
202
- except Exception as e:
203
- logger.error(f"Groq API error: {e}")
204
- return f"Error: Unable to get response from AI model. Please try again later."
205
-
206
  def main():
207
- st.set_page_config(
208
- page_title="Construction Defect Analyzer",
209
- page_icon="🏗️",
210
- layout="wide"
211
- )
212
-
213
- st.title("🏗️ Construction Defect Analyzer")
214
-
215
- # Initialize systems in session state
216
- if 'rag_system' not in st.session_state:
217
- st.session_state.rag_system = RAGSystem()
218
- if 'image_analyzer' not in st.session_state:
219
- st.session_state.image_analyzer = ImageAnalyzer()
220
- if 'processed_images' not in st.session_state:
221
- st.session_state.processed_images = {}
222
-
223
- # Create two columns
224
- col1, col2 = st.columns([1, 1])
225
-
226
- with col1:
227
- st.header("Image Analysis")
228
- uploaded_file = st.file_uploader(
229
- "Upload a construction image",
230
- type=['jpg', 'jpeg', 'png'],
231
- key="image_uploader"
232
- )
233
-
234
- if uploaded_file is not None:
235
- try:
236
- # Display upload progress
237
- progress_bar = st.progress(0)
238
- status_text = st.empty()
239
-
240
- # Update progress
241
- status_text.text("Loading image...")
242
- progress_bar.progress(25)
243
-
244
- # Load and display image
245
- image = Image.open(uploaded_file)
246
- st.image(image, caption="Uploaded Image", use_column_width=True)
247
-
248
- # Update progress
249
- status_text.text("Analyzing image...")
250
- progress_bar.progress(50)
251
-
252
- # Analyze image
253
- results = st.session_state.image_analyzer.analyze_image(image)
254
- progress_bar.progress(75)
255
 
256
  if results:
257
- status_text.text("Analysis complete!")
258
- progress_bar.progress(100)
259
 
260
- st.subheader("Detected Defects")
 
261
 
262
  # Create bar chart
263
- defect_probs = results["defect_probabilities"]
264
- fig, ax = plt.subplots(figsize=(8, 4))
265
- defects = list(defect_probs.keys())
266
- probs = list(defect_probs.values())
267
  ax.barh(defects, probs)
268
  ax.set_xlim(0, 1)
269
- ax.set_xlabel("Probability")
270
  st.pyplot(fig)
271
-
272
- # Show image statistics
273
- with st.expander("Image Details"):
274
- st.json(results["image_statistics"])
275
  else:
276
- status_text.text("Analysis failed. Please try again.")
277
- progress_bar.empty()
278
 
279
- except Exception as e:
280
- st.error(f"Error processing image: {str(e)}")
281
- logger.error(f"Image processing error: {e}")
282
-
283
- with col2:
284
- st.header("Ask About Defects")
285
- user_query = st.text_input(
286
- "Enter your question about construction defects:",
287
- help="Example: What are the repair methods for severe spalling?"
288
- )
289
-
290
- if user_query:
291
- with st.spinner("Processing query..."):
292
- context = st.session_state.rag_system.get_relevant_context(user_query)
293
- response = get_groq_response(user_query, context)
294
-
295
- st.subheader("AI Response")
296
- st.write(response)
297
-
298
- with st.expander("View Context Used"):
299
- st.text(context)
300
-
301
- # Sidebar for history
302
- with st.sidebar:
303
- st.header("Analysis History")
304
- if st.button("Show Recent Analyses"):
305
- if st.session_state.image_analyzer.history:
306
- for analysis in st.session_state.image_analyzer.history[-5:]:
307
- st.write(f"Analysis from: {analysis['timestamp']}")
308
- st.json(analysis["defect_probabilities"])
309
- else:
310
- st.write("No analyses yet")
311
 
312
  if __name__ == "__main__":
313
  main()
 
1
  import streamlit as st
 
 
 
 
 
2
  import torch
3
+ from PIL import Image
4
  import numpy as np
5
+ from transformers import ViTForImageClassification, ViTImageProcessor
 
 
 
 
 
6
  import matplotlib.pyplot as plt
7
+ import logging
8
 
9
  # Setup logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
+ class ImageAnalyzer:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def __init__(self):
15
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # Load model and processor
 
 
 
 
 
 
 
 
 
 
 
 
19
  try:
20
+ self.model = ViTForImageClassification.from_pretrained(
21
+ "google/vit-base-patch16-224",
22
+ num_labels=len(self.defect_classes)
23
+ ).to(self.device)
24
+ self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
 
 
 
 
 
25
  except Exception as e:
26
+ logger.error(f"Model initialization error: {e}")
27
+ self.model = None
28
+ self.processor = None
 
 
 
 
 
 
 
29
 
30
+ def analyze_image(self, image):
 
31
  try:
32
  # Ensure image is RGB
33
  if image.mode != 'RGB':
34
  image = image.convert('RGB')
35
 
36
+ # Process image
37
  inputs = self.processor(images=image, return_tensors="pt")
38
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
39
+
40
+ # Get predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  with torch.no_grad():
42
  outputs = self.model(**inputs)
43
 
44
+ # Get probabilities
45
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
46
+ return {self.defect_classes[i]: float(probs[i]) for i in range(len(self.defect_classes))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  except Exception as e:
49
+ logger.error(f"Analysis error: {e}")
50
  return None
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def main():
53
+ st.title("Construction Defect Analyzer")
54
+
55
+ # Initialize analyzer
56
+ if 'analyzer' not in st.session_state:
57
+ st.session_state.analyzer = ImageAnalyzer()
58
+
59
+ # File uploader
60
+ st.write("Upload a construction image for analysis")
61
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
62
+
63
+ if uploaded_file is not None:
64
+ try:
65
+ # Display confirmation
66
+ st.write("Image received. Processing...")
67
+
68
+ # Read and display image
69
+ image = Image.open(uploaded_file)
70
+ st.image(image, caption='Uploaded Image', use_column_width=True)
71
+
72
+ # Analyze image
73
+ with st.spinner('Analyzing image...'):
74
+ results = st.session_state.analyzer.analyze_image(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  if results:
77
+ st.success('Analysis complete!')
 
78
 
79
+ # Display results
80
+ st.subheader("Defect Probabilities")
81
 
82
  # Create bar chart
83
+ fig, ax = plt.subplots()
84
+ defects = list(results.keys())
85
+ probs = list(results.values())
 
86
  ax.barh(defects, probs)
87
  ax.set_xlim(0, 1)
88
+ plt.tight_layout()
89
  st.pyplot(fig)
 
 
 
 
90
  else:
91
+ st.error("Analysis failed. Please try again.")
 
92
 
93
+ except Exception as e:
94
+ st.error(f"Error: {str(e)}")
95
+ logger.error(f"Process error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  if __name__ == "__main__":
98
  main()