Shakir60 committed
Commit 3db0aec · verified · 1 Parent(s): b13a7d4

Update app.py

Files changed (1)
app.py +299 -247
app.py CHANGED
@@ -1,273 +1,325 @@
 
  import streamlit as st
- from transformers import (
-     AutoModelForImageClassification,
-     AutoImageProcessor,
-     ViTForImageClassification,
-     ResNetForImageClassification
- )
  import torch
- import numpy as np
- from PIL import Image, ImageDraw
- import cv2
- from langchain import FAISS
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.chains import RetrievalQA
- from langchain.llms import HuggingFacePipeline
- import json
- import os
- from concurrent.futures import ThreadPoolExecutor
- import pandas as pd

- class DefectMeasurement:
-     """Handle defect measurements and severity estimation"""
-
-     @staticmethod
-     def measure_defect(image, defect_type):
-         """Measure defect dimensions using computer vision"""
-         img_array = np.array(image)
-         gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
-
-         if defect_type == "Crack":
-             # Crack width measurement
-             blur = cv2.GaussianBlur(gray, (3,3), 0)
-             edges = cv2.Canny(blur, 100, 200)
-             lines = cv2.HoughLinesP(edges, 1, np.pi/180, 50, minLineLength=100, maxLineGap=10)
-
-             if lines is not None:
-                 max_length = 0
-                 for line in lines:
-                     x1, y1, x2, y2 = line[0]
-                     length = np.sqrt((x2-x1)**2 + (y2-y1)**2)
-                     max_length = max(max_length, length)
-                 return {"length": max_length, "unit": "pixels"}
-
-         elif defect_type in ["Spalling", "Exposed_Bars"]:
-             # Area measurement
-             thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]
-             contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-             if contours:
-                 max_area = max(cv2.contourArea(cnt) for cnt in contours)
-                 return {"area": max_area, "unit": "square pixels"}
-
          return None

- class MultiModelAnalyzer:
-     """Handle multiple pre-trained models for defect detection"""
-
-     def __init__(self):
-         self.models = {
-             "CODEBRIM-ViT": "chanwooong/codebrim-vit-base",
-             "Concrete-Defect-ResNet": "nlp-waseda/concrete-defect-resnet",
-             "Bridge-Damage-ViT": "microsoft/bridge-damage-vit-base"
-         }
-         self.loaded_models = {}
-         self.loaded_processors = {}
-
-     @st.cache_resource
-     def load_model(self, model_name):
-         """Load specific model and processor"""
-         try:
-             if "vit" in model_name.lower():
-                 model = ViTForImageClassification.from_pretrained(self.models[model_name])
-             else:
-                 model = ResNetForImageClassification.from_pretrained(self.models[model_name])
-             processor = AutoImageProcessor.from_pretrained(self.models[model_name])
-             return model, processor
-         except Exception as e:
-             st.error(f"Error loading {model_name}: {str(e)}")
-             return None, None
-
-     def analyze_with_all_models(self, image):
-         """Run analysis with all available models"""
-         results = {}
-         for model_name in self.models.keys():
-             if model_name not in self.loaded_models:
-                 self.loaded_models[model_name], self.loaded_processors[model_name] = self.load_model(model_name)
-
-             if self.loaded_models[model_name] is not None:
-                 try:
-                     inputs = self.loaded_processors[model_name](images=image, return_tensors="pt")
-                     outputs = self.loaded_models[model_name](**inputs)
-                     probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
-                     results[model_name] = probs
-                 except Exception as e:
-                     st.error(f"Error analyzing with {model_name}: {str(e)}")
-
-         return results

- class EnhancedRAGSystem:
-     """Enhanced RAG system with comprehensive construction knowledge"""
-
-     def __init__(self):
-         self.knowledge_sources = {
-             "ACI_318": "concrete_design_requirements.json",
-             "ASTM": "testing_standards.json",
-             "repair_guidelines": "repair_methods.json",
-             "case_studies": "defect_cases.json"
-         }
-         self.embeddings = None
-         self.vectorstore = None
-         self.qa_chain = None
-
-     def load_knowledge_base(self):
-         """Load and combine multiple knowledge sources"""
-         combined_knowledge = []
-         for source, filename in self.knowledge_sources.items():
-             try:
-                 with open(f"knowledge_base/{filename}", 'r') as f:
-                     knowledge = json.load(f)
-                     for item in knowledge:
-                         item['source'] = source
-                     combined_knowledge.extend(knowledge)
-             except Exception as e:
-                 st.warning(f"Could not load {source}: {str(e)}")
-
-         return combined_knowledge
-
-     def init_rag(self):
-         """Initialize enhanced RAG system"""
-         try:
-             self.embeddings = HuggingFaceEmbeddings(
-                 model_name="sentence-transformers/all-mpnet-base-v2"
-             )
-
-             knowledge_base = self.load_knowledge_base()
-             texts = [
-                 f"{item['defect_type']} ({item['source']})\n" +
-                 f"Description: {item['description']}\n" +
-                 f"Repair: {item['repair_methods']}\n" +
-                 f"Standards: {item['applicable_standards']}\n" +
-                 f"Cases: {item['related_cases']}"
-                 for item in knowledge_base
-             ]
-
-             self.vectorstore = FAISS.from_texts(texts, self.embeddings)
-
-             self.qa_chain = RetrievalQA.from_chain_type(
-                 llm=HuggingFacePipeline.from_model_id(
-                     model_id="google/flan-t5-large",
-                     task="text2text-generation",
-                     model_kwargs={"temperature": 0.7}
-                 ),
-                 chain_type="stuff",
-                 retriever=self.vectorstore.as_retriever(
-                     search_kwargs={"k": 5}
-                 )
              )
-
-             return True
-         except Exception as e:
-             st.error(f"Error initializing RAG system: {str(e)}")
-             return False

- class ConstructionDefectAnalyzer:
-     """Main application class"""
-
-     def __init__(self):
-         self.multi_model = MultiModelAnalyzer()
-         self.rag_system = EnhancedRAGSystem()
-         self.defect_measurement = DefectMeasurement()
-
-     def analyze_multiple_images(self, images):
-         """Analyze multiple images in parallel"""
-         results = []
-         with ThreadPoolExecutor() as executor:
-             futures = []
-             for img in images:
-                 future = executor.submit(self.analyze_single_image, img)
-                 futures.append(future)
-
-             for future in futures:
-                 result = future.result()
-                 results.append(result)
-
-         return results
-
-     def analyze_single_image(self, image):
-         """Analyze a single image with all features"""
-         model_results = self.multi_model.analyze_with_all_models(image)
-         measurements = {}
-         recommendations = {}
-
-         # Get measurements for detected defects
-         for model_name, predictions in model_results.items():
-             for idx, prob in enumerate(predictions):
-                 if prob > 0.15:  # Confidence threshold
-                     defect_type = self.get_defect_type(model_name, idx)
-                     measurements[defect_type] = self.defect_measurement.measure_defect(image, defect_type)
-
-                     # Get RAG recommendations
-                     if self.rag_system.qa_chain:
-                         query = self.generate_rag_query(defect_type, measurements.get(defect_type))
-                         recommendations[defect_type] = self.rag_system.qa_chain.run(query)
-
-         return {
-             "model_results": model_results,
-             "measurements": measurements,
-             "recommendations": recommendations
-         }
-
-     @staticmethod
-     def generate_rag_query(defect_type, measurement):
-         """Generate detailed query for RAG system"""
-         query = f"What are the recommended repairs, safety measures, and applicable standards for {defect_type}"
-         if measurement:
-             if "length" in measurement:
-                 query += f" with length {measurement['length']} {measurement['unit']}"
-             elif "area" in measurement:
-                 query += f" with affected area {measurement['area']} {measurement['unit']}"
-         return query + "?"

  def main():
-     st.set_page_config(page_title="Advanced Construction Defect Analyzer", layout="wide")
-
-     analyzer = ConstructionDefectAnalyzer()
-
-     st.title("🏗️ Advanced Construction Defect Analyzer")
-
-     # Multiple image upload
-     uploaded_files = st.file_uploader(
-         "Upload construction images for analysis",
          type=['jpg', 'jpeg', 'png'],
-         accept_multiple_files=True
      )
-
-     if uploaded_files:
-         images = [Image.open(file).convert('RGB') for file in uploaded_files]
-
-         with st.spinner("Analyzing images..."):
-             results = analyzer.analyze_multiple_images(images)
-
-         for idx, (image, result) in enumerate(zip(images, results)):
-             st.markdown(f"### Analysis Results - Image {idx + 1}")
-
-             col1, col2 = st.columns([1, 2])
-
-             with col1:
-                 st.image(image, caption=f"Image {idx + 1}", use_column_width=True)
-
-             with col2:
-                 # Display model comparison
-                 st.markdown("#### Model Predictions")
-                 for model_name, predictions in result['model_results'].items():
-                     st.markdown(f"**{model_name}:**")
-                     for i, prob in enumerate(predictions):
-                         if prob > 0.15:
-                             defect_type = analyzer.get_defect_type(model_name, i)
-                             st.progress(float(prob))
-                             st.markdown(f"{defect_type}: {float(prob)*100:.1f}%")
-
-                             # Display measurements
-                             if defect_type in result['measurements']:
-                                 st.markdown("**Measurements:**")
-                                 st.json(result['measurements'][defect_type])
-
-                             # Display recommendations
-                             if defect_type in result['recommendations']:
-                                 with st.expander("📋 Detailed Analysis"):
-                                     st.markdown(result['recommendations'][defect_type])
-
  if __name__ == "__main__":
      main()
-

  import streamlit as st
+ from transformers import ViTForImageClassification, ViTImageProcessor
+ from PIL import Image
  import torch
+ import time
+ import gc
+ from knowledge_base import KNOWLEDGE_BASE, DAMAGE_TYPES
+ from rag_utils import RAGSystem

+ # Constants
+ MAX_FILE_SIZE = 5 * 1024 * 1024  # 5MB
+ MAX_IMAGE_SIZE = 1024  # Maximum dimension for images
+
+ # Cache the model and RAG system globally
+ MODEL = None
+ PROCESSOR = None
+ RAG_SYSTEM = None
+
+ def cleanup_memory():
+     """Clean up memory and GPU cache"""
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+ def init_session_state():
+     """Initialize session state variables"""
+     if 'history' not in st.session_state:
+         st.session_state.history = []
+     if 'dark_mode' not in st.session_state:
+         st.session_state.dark_mode = False
+
+ @st.cache_resource(show_spinner="Loading AI model...")
+ def load_model():
+     """Load and cache the model and processor"""
+     try:
+         model_name = "google/vit-base-patch16-224"
+         model = ViTForImageClassification.from_pretrained(
+             model_name,
+             num_labels=len(DAMAGE_TYPES),
+             ignore_mismatched_sizes=True,
+             device_map="auto"
+         )
+         processor = ViTImageProcessor.from_pretrained(model_name)
+         return model, processor
+     except Exception as e:
+         st.error(f"Error loading model: {str(e)}")
+         return None, None
+
+ def init_rag_system():
+     """Initialize the RAG system with knowledge base"""
+     global RAG_SYSTEM
+     if RAG_SYSTEM is None:
+         RAG_SYSTEM = RAGSystem()
+         RAG_SYSTEM.initialize_knowledge_base(KNOWLEDGE_BASE)
+
+ def validate_image(image):
+     """Validate image size and format"""
+     if image.size[0] * image.size[1] > 1024 * 1024:
+         st.warning("Large image detected. The image will be resized for better performance.")
+     if image.format not in ['JPEG', 'PNG']:
+         st.warning("Image format not optimal. Consider using JPEG or PNG for better performance.")
+
+ def preprocess_image(uploaded_file):
+     """Preprocess and validate uploaded image"""
+     try:
+         image = Image.open(uploaded_file)
+         # Resize if image is too large
+         if max(image.size) > MAX_IMAGE_SIZE:
+             ratio = MAX_IMAGE_SIZE / max(image.size)
+             new_size = tuple([int(dim * ratio) for dim in image.size])
+             image = image.resize(new_size, Image.Resampling.LANCZOS)
+         return image
+     except Exception as e:
+         st.error(f"Error processing image: {str(e)}")
          return None

+ def analyze_damage(image, model, processor):
+     """Analyze structural damage in the image"""
+     try:
+         with torch.no_grad():
+             image = image.convert('RGB')
+             inputs = processor(images=image, return_tensors="pt")
+             outputs = model(**inputs)
+             probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
+             cleanup_memory()
+             return probs
+     except RuntimeError as e:
+         if "out of memory" in str(e):
+             cleanup_memory()
+             st.error("Out of memory. Please try with a smaller image.")
+         else:
+             st.error(f"Error analyzing image: {str(e)}")
+         return None

+ def get_custom_css():
+     """Return custom CSS styles"""
+     return """
+     <style>
+     .main {
+         padding: 2rem;
+     }
+     .stProgress > div > div > div > div {
+         background-image: linear-gradient(to right, var(--progress-color, #ff6b6b), var(--progress-color-end, #f06595));
+     }
+     .damage-card {
+         padding: 1.5rem;
+         border-radius: 0.5rem;
+         background: var(--card-bg, #f8f9fa);
+         margin-bottom: 1rem;
+         border: 1px solid var(--border-color, #dee2e6);
+         box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+     }
+     .damage-header {
+         font-size: 1.25rem;
+         font-weight: bold;
+         margin-bottom: 1rem;
+         color: var(--text-color, #212529);
+     }
+     .dark-mode {
+         background-color: #1a1a1a;
+         color: #ffffff;
+     }
+     .dark-mode .damage-card {
+         background: #2d2d2d;
+         border-color: #404040;
+     }
+     </style>
+     """
+
+ def display_header():
+     """Display application header"""
+     st.markdown(
+         """
+         <div style='text-align: center; padding: 1rem;'>
+             <h1>🏗️ Structural Damage Analyzer Pro</h1>
+             <p style='font-size: 1.2rem;'>Advanced AI-powered structural damage assessment tool</p>
+         </div>
+         """,
+         unsafe_allow_html=True
+     )
+
+ def display_enhanced_analysis(damage_type, confidence):
+     """Display enhanced analysis from RAG system"""
+     try:
+         enhanced_info = RAG_SYSTEM.get_enhanced_analysis(damage_type, confidence)
+
+         st.markdown("### 🔍 Enhanced Analysis")
+
+         with st.expander("📚 Technical Details", expanded=True):
+             for detail in enhanced_info["technical_details"]:
+                 st.markdown(detail)
+
+         with st.expander("⚠️ Safety Considerations"):
+             for safety in enhanced_info["safety_considerations"]:
+                 st.warning(safety)
+
+         with st.expander("👷 Expert Recommendations"):
+             for rec in enhanced_info["expert_recommendations"]:
+                 st.info(rec)
+
+         custom_query = st.text_input(
+             "Ask specific questions about this damage type:",
+             placeholder="E.g., What are the long-term implications of this damage?"
+         )
+
+         if custom_query:
+             custom_results = RAG_SYSTEM.get_enhanced_analysis(
+                 damage_type,
+                 confidence,
+                 custom_query=custom_query
              )
+             st.markdown("### 💡 Custom Query Results")
+             for category, results in custom_results.items():
+                 if results:
+                     st.markdown(f"**{category.replace('_', ' ').title()}:**")
+                     for result in results:
+                         st.markdown(result)
+
+     except Exception as e:
+         st.error(f"Error generating enhanced analysis: {str(e)}")

+ def display_analysis_results(predictions, analysis_time):
+     """Display analysis results with damage details"""
+     st.markdown("### 📊 Analysis Results")
+     st.markdown(f"*Analysis completed in {analysis_time:.2f} seconds*")
+
+     detected = False
+     for idx, prob in enumerate(predictions):
+         confidence = float(prob) * 100
+         if confidence > 15:
+             detected = True
+             damage_type = DAMAGE_TYPES[idx]['name']
+             cases = KNOWLEDGE_BASE[damage_type]
+
+             with st.expander(f"{damage_type.replace('_', ' ').title()} - {confidence:.1f}%", expanded=True):
+                 st.markdown(
+                     f"""
+                     <style>
+                     .stProgress > div > div > div > div {{
+                         background-color: {DAMAGE_TYPES[idx]['color']} !important;
+                     }}
+                     </style>
+                     """,
+                     unsafe_allow_html=True
+                 )
+                 st.progress(confidence / 100)
+
+                 tabs = st.tabs(["📋 Details", "🔧 Repairs", "⚠️ Actions"])
+
+                 with tabs[0]:
+                     for case in cases:
+                         st.markdown(f"""
+                         - **Severity:** {case['severity']}
+                         - **Description:** {case['description']}
+                         - **Location:** {case['location']}
+                         - **Required Expertise:** {case['required_expertise']}
+                         """)
+
+                 with tabs[1]:
+                     for step in cases[0]['repair_method']:
+                         st.markdown(f"✓ {step}")
+                     st.info(f"**Estimated Cost:** {cases[0]['estimated_cost']}")
+                     st.info(f"**Timeframe:** {cases[0]['timeframe']}")
+
+                 with tabs[2]:
+                     st.warning("**Immediate Actions Required:**")
+                     st.markdown(cases[0]['immediate_action'])
+                     st.success("**Prevention Measures:**")
+                     st.markdown(cases[0]['prevention'])
+
+             # Display enhanced analysis
+             display_enhanced_analysis(damage_type, confidence)
+
+     if not detected:
+         st.info("No significant structural damage detected. Regular maintenance recommended.")

  def main():
+     """Main application function"""
+     init_session_state()
+     st.set_page_config(
+         page_title="Structural Damage Analyzer Pro",
+         page_icon="🏗️",
+         layout="wide",
+         initial_sidebar_state="expanded"
+     )
+
+     st.markdown(get_custom_css(), unsafe_allow_html=True)
+
+     # Sidebar
+     with st.sidebar:
+         st.markdown("### ⚙️ Settings")
+         st.session_state.dark_mode = st.toggle("Dark Mode", st.session_state.dark_mode)
+         st.markdown("### 📖 Analysis History")
+         if st.session_state.history:
+             for item in st.session_state.history[-5:]:
+                 st.markdown(f"- {item}")
+
+     display_header()
+
+     # Load model and initialize RAG system
+     global MODEL, PROCESSOR
+     if MODEL is None or PROCESSOR is None:
+         with st.spinner("Loading AI model..."):
+             MODEL, PROCESSOR = load_model()
+             if MODEL is None:
+                 st.error("Failed to load model. Please refresh the page.")
+                 return
+
+     init_rag_system()
+
+     # File upload
+     uploaded_file = st.file_uploader(
+         "Drag and drop or click to upload an image",
          type=['jpg', 'jpeg', 'png'],
+         help="Supported formats: JPG, JPEG, PNG"
      )
+
+     if uploaded_file:
+         try:
+             if uploaded_file.size > MAX_FILE_SIZE:
+                 st.error("File size too large. Please upload an image smaller than 5MB.")
+                 return
+
+             image = preprocess_image(uploaded_file)
+             if image is None:
+                 return
+
+             validate_image(image)
+
+             col1, col2 = st.columns([1, 1])
+
+             with col1:
+                 st.image(image, caption="Uploaded Structure", use_container_width=True)
+
+             with col2:
+                 with st.spinner("🔍 Analyzing damage..."):
+                     start_time = time.time()
+                     predictions = analyze_damage(image, MODEL, PROCESSOR)
+                     analysis_time = time.time() - start_time
+
+                 if predictions is not None:
+                     display_analysis_results(predictions, analysis_time)
+                     st.session_state.history.append(f"Analyzed image: {uploaded_file.name}")
+
+         except Exception as e:
+             cleanup_memory()
+             st.error(f"Error processing image: {str(e)}")
+             st.info("Please try uploading a different image.")
+
+     # Footer
+     st.markdown("---")
+     st.markdown(
+         """
+         <div style='text-align: center'>
+             <p>🏗️ Structural Damage Analyzer Pro | Built with Streamlit & Transformers</p>
+             <p style='font-size: 0.8rem;'>For professional use only. Always consult with a structural engineer.</p>
+         </div>
+         """,
+         unsafe_allow_html=True
+     )

  if __name__ == "__main__":
      main()
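The updated app.py imports `KNOWLEDGE_BASE` and `DAMAGE_TYPES` from `knowledge_base` and `RAGSystem` from `rag_utils`, neither of which is included in this commit. The sketch below is inferred only from how those names are used above (e.g. `DAMAGE_TYPES[idx]['name']`, `KNOWLEDGE_BASE[damage_type][0]['repair_method']`, `RAG_SYSTEM.get_enhanced_analysis(damage_type, confidence, custom_query=...)`); the field values and method bodies are placeholders, not the actual module contents.

```python
# Hypothetical sketch of knowledge_base.py / rag_utils.py, inferred from the calls in app.py.
# Not part of this commit; names and shapes are assumptions, values are placeholders.

# DAMAGE_TYPES is indexed by the model's class id and read as [idx]['name'] / [idx]['color'].
DAMAGE_TYPES = {
    0: {"name": "spalling", "color": "#ff6b6b"},
    1: {"name": "structural_crack", "color": "#f06595"},
    # ... one entry per output class of the classifier
}

# KNOWLEDGE_BASE maps a damage-type name to a list of case dicts whose fields
# are read in display_analysis_results().
KNOWLEDGE_BASE = {
    "spalling": [
        {
            "severity": "Moderate",
            "description": "Surface concrete breaking away and exposing aggregate.",
            "location": "Example location",
            "required_expertise": "Structural engineer",
            "repair_method": ["Remove loose concrete", "Apply repair mortar"],
            "estimated_cost": "Depends on affected area",
            "timeframe": "1-2 weeks",
            "immediate_action": "Cordon off the affected area.",
            "prevention": "Maintain protective coatings and drainage.",
        }
    ],
}

class RAGSystem:
    """Interface app.py relies on: initialize_knowledge_base() and get_enhanced_analysis()."""

    def initialize_knowledge_base(self, knowledge_base):
        # Index the knowledge base (e.g. embed the case texts for retrieval).
        ...

    def get_enhanced_analysis(self, damage_type, confidence, custom_query=None):
        # Return grouped retrieval results for the damage type (and optional custom query).
        return {
            "technical_details": [],
            "safety_considerations": [],
            "expert_recommendations": [],
        }
```

With stubs of this shape in place, the app above runs end to end; the real modules would populate the lists from retrieved documents rather than returning empty ones.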