Shakir60 commited on
Commit
7406e96
·
verified ·
1 Parent(s): fd94e5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -27
app.py CHANGED
@@ -18,6 +18,32 @@ import matplotlib.pyplot as plt
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class RAGSystem:
22
  def __init__(self):
23
  self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
@@ -27,7 +53,6 @@ class RAGSystem:
27
 
28
  def load_knowledge_base(self) -> List[Dict]:
29
  """Load and preprocess knowledge base"""
30
- # Using a simplified version of your knowledge base
31
  kb = {
32
  "spalling": [
33
  {
@@ -75,7 +100,6 @@ class RAGSystem:
75
  D, I = self.vector_store.search(np.array(query_embedding).astype('float32'), k)
76
  context = "\n\n".join([self.knowledge_base[i]["text"] for i in I[0]])
77
 
78
- # Log query
79
  self.query_history.append({
80
  "timestamp": datetime.now().isoformat(),
81
  "query": query
@@ -88,26 +112,37 @@ class RAGSystem:
88
 
89
  class ImageAnalyzer:
90
  def __init__(self):
91
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
  self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
93
- self.model = self._initialize_model()
94
- self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
95
  self.history = []
96
 
97
- def _initialize_model(self):
98
- model = ViTForImageClassification.from_pretrained(
99
- "google/vit-base-patch16-224",
100
- num_labels=len(self.defect_classes),
101
- ignore_mismatched_sizes=True
102
- )
103
- return model.to(self.device)
 
 
 
 
 
 
104
 
105
  def analyze_image(self, image: Image.Image) -> Dict:
106
  """Analyze image for defects"""
107
  try:
 
 
 
108
  # Preprocess image
109
- inputs = self.processor(images=image, return_tensors="pt").to(self.device)
110
-
 
 
111
  # Get model predictions
112
  with torch.no_grad():
113
  outputs = self.model(**inputs)
@@ -123,7 +158,7 @@ class ImageAnalyzer:
123
  img_array = np.array(image)
124
  stats = {
125
  "mean_brightness": float(np.mean(img_array)),
126
- "image_size": image.size
127
  }
128
 
129
  result = {
@@ -182,29 +217,51 @@ def main():
182
  st.session_state.rag_system = RAGSystem()
183
  if 'image_analyzer' not in st.session_state:
184
  st.session_state.image_analyzer = ImageAnalyzer()
 
 
185
 
186
  # Create two columns
187
  col1, col2 = st.columns([1, 1])
188
 
189
  with col1:
 
190
  uploaded_file = st.file_uploader(
191
  "Upload a construction image",
192
- type=['jpg', 'jpeg', 'png']
 
193
  )
194
 
195
- if uploaded_file:
196
- image = Image.open(uploaded_file)
197
- st.image(image, caption="Uploaded Image", use_column_width=True)
198
-
199
- with st.spinner("Analyzing image..."):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  results = st.session_state.image_analyzer.analyze_image(image)
 
201
 
202
  if results:
 
 
 
203
  st.subheader("Detected Defects")
204
 
205
  # Create bar chart
206
  defect_probs = results["defect_probabilities"]
207
- fig, ax = plt.subplots()
208
  defects = list(defect_probs.keys())
209
  probs = list(defect_probs.values())
210
  ax.barh(defects, probs)
@@ -213,11 +270,18 @@ def main():
213
  st.pyplot(fig)
214
 
215
  # Show image statistics
216
- if st.checkbox("Show Image Details"):
217
  st.json(results["image_statistics"])
 
 
 
 
 
 
 
218
 
219
  with col2:
220
- st.subheader("Ask About Defects")
221
  user_query = st.text_input(
222
  "Enter your question about construction defects:",
223
  help="Example: What are the repair methods for severe spalling?"
@@ -228,11 +292,10 @@ def main():
228
  context = st.session_state.rag_system.get_relevant_context(user_query)
229
  response = get_groq_response(user_query, context)
230
 
231
- st.write("AI Response:")
232
  st.write(response)
233
 
234
- if st.checkbox("Show Retrieved Context"):
235
- st.write("Context Used:")
236
  st.text(context)
237
 
238
  # Sidebar for history
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
+ # Ensure CUDA is available or use CPU
22
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
+
24
+ @st.cache_resource
25
+ def load_vit_model():
26
+ """Load and cache the ViT model"""
27
+ try:
28
+ model = ViTForImageClassification.from_pretrained(
29
+ "google/vit-base-patch16-224",
30
+ num_labels=3, # Number of defect classes
31
+ ignore_mismatched_sizes=True
32
+ )
33
+ return model.to(DEVICE)
34
+ except Exception as e:
35
+ logger.error(f"Error loading ViT model: {e}")
36
+ return None
37
+
38
+ @st.cache_resource
39
+ def load_vit_processor():
40
+ """Load and cache the ViT processor"""
41
+ try:
42
+ return ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
43
+ except Exception as e:
44
+ logger.error(f"Error loading ViT processor: {e}")
45
+ return None
46
+
47
  class RAGSystem:
48
  def __init__(self):
49
  self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
53
 
54
  def load_knowledge_base(self) -> List[Dict]:
55
  """Load and preprocess knowledge base"""
 
56
  kb = {
57
  "spalling": [
58
  {
 
100
  D, I = self.vector_store.search(np.array(query_embedding).astype('float32'), k)
101
  context = "\n\n".join([self.knowledge_base[i]["text"] for i in I[0]])
102
 
 
103
  self.query_history.append({
104
  "timestamp": datetime.now().isoformat(),
105
  "query": query
 
112
 
113
  class ImageAnalyzer:
114
  def __init__(self):
115
+ self.device = DEVICE
116
  self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
117
+ self.model = load_vit_model()
118
+ self.processor = load_vit_processor()
119
  self.history = []
120
 
121
+ def preprocess_image(self, image: Image.Image) -> torch.Tensor:
122
+ """Preprocess image for model input"""
123
+ try:
124
+ # Ensure image is RGB
125
+ if image.mode != 'RGB':
126
+ image = image.convert('RGB')
127
+
128
+ # Process image using ViT processor
129
+ inputs = self.processor(images=image, return_tensors="pt")
130
+ return inputs.to(self.device)
131
+ except Exception as e:
132
+ logger.error(f"Image preprocessing error: {e}")
133
+ return None
134
 
135
  def analyze_image(self, image: Image.Image) -> Dict:
136
  """Analyze image for defects"""
137
  try:
138
+ if self.model is None or self.processor is None:
139
+ raise ValueError("Model or processor not properly initialized")
140
+
141
  # Preprocess image
142
+ inputs = self.preprocess_image(image)
143
+ if inputs is None:
144
+ raise ValueError("Image preprocessing failed")
145
+
146
  # Get model predictions
147
  with torch.no_grad():
148
  outputs = self.model(**inputs)
 
158
  img_array = np.array(image)
159
  stats = {
160
  "mean_brightness": float(np.mean(img_array)),
161
+ "image_size": f"{image.size[0]}x{image.size[1]}"
162
  }
163
 
164
  result = {
 
217
  st.session_state.rag_system = RAGSystem()
218
  if 'image_analyzer' not in st.session_state:
219
  st.session_state.image_analyzer = ImageAnalyzer()
220
+ if 'processed_images' not in st.session_state:
221
+ st.session_state.processed_images = {}
222
 
223
  # Create two columns
224
  col1, col2 = st.columns([1, 1])
225
 
226
  with col1:
227
+ st.header("Image Analysis")
228
  uploaded_file = st.file_uploader(
229
  "Upload a construction image",
230
+ type=['jpg', 'jpeg', 'png'],
231
+ key="image_uploader"
232
  )
233
 
234
+ if uploaded_file is not None:
235
+ try:
236
+ # Display upload progress
237
+ progress_bar = st.progress(0)
238
+ status_text = st.empty()
239
+
240
+ # Update progress
241
+ status_text.text("Loading image...")
242
+ progress_bar.progress(25)
243
+
244
+ # Load and display image
245
+ image = Image.open(uploaded_file)
246
+ st.image(image, caption="Uploaded Image", use_column_width=True)
247
+
248
+ # Update progress
249
+ status_text.text("Analyzing image...")
250
+ progress_bar.progress(50)
251
+
252
+ # Analyze image
253
  results = st.session_state.image_analyzer.analyze_image(image)
254
+ progress_bar.progress(75)
255
 
256
  if results:
257
+ status_text.text("Analysis complete!")
258
+ progress_bar.progress(100)
259
+
260
  st.subheader("Detected Defects")
261
 
262
  # Create bar chart
263
  defect_probs = results["defect_probabilities"]
264
+ fig, ax = plt.subplots(figsize=(8, 4))
265
  defects = list(defect_probs.keys())
266
  probs = list(defect_probs.values())
267
  ax.barh(defects, probs)
 
270
  st.pyplot(fig)
271
 
272
  # Show image statistics
273
+ with st.expander("Image Details"):
274
  st.json(results["image_statistics"])
275
+ else:
276
+ status_text.text("Analysis failed. Please try again.")
277
+ progress_bar.empty()
278
+
279
+ except Exception as e:
280
+ st.error(f"Error processing image: {str(e)}")
281
+ logger.error(f"Image processing error: {e}")
282
 
283
  with col2:
284
+ st.header("Ask About Defects")
285
  user_query = st.text_input(
286
  "Enter your question about construction defects:",
287
  help="Example: What are the repair methods for severe spalling?"
 
292
  context = st.session_state.rag_system.get_relevant_context(user_query)
293
  response = get_groq_response(user_query, context)
294
 
295
+ st.subheader("AI Response")
296
  st.write(response)
297
 
298
+ with st.expander("View Context Used"):
 
299
  st.text(context)
300
 
301
  # Sidebar for history