Shakir60 commited on
Commit
d604335
Β·
verified Β·
1 Parent(s): 3b1af89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -99
app.py CHANGED
@@ -4,162 +4,270 @@ from PIL import Image
4
  import torch
5
  import time
6
  import gc
7
- import logging
8
  from knowledge_base import KNOWLEDGE_BASE, DAMAGE_TYPES
9
  from rag_utils import RAGSystem
10
-
11
- # Configure logging
12
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13
 
14
  # Constants
15
  MAX_FILE_SIZE = 5 * 1024 * 1024 # 5MB
16
  MAX_IMAGE_SIZE = 1024 # Maximum dimension for images
 
 
 
 
 
17
 
18
- # Cache the model and RAG system globally
19
- MODEL = None
20
- PROCESSOR = None
21
- RAG_SYSTEM = None
 
 
 
22
 
23
- # Cleanup function for memory
24
  def cleanup_memory():
25
  """Clean up memory and GPU cache"""
26
  gc.collect()
27
  if torch.cuda.is_available():
28
  torch.cuda.empty_cache()
29
 
30
- # Session state initialization
31
  @st.cache_resource(show_spinner="Loading AI model...")
32
  def load_model():
33
- """Load and cache the model and processor"""
34
  try:
35
- model_name = "google/vit-base-patch16-224"
36
- processor = ViTImageProcessor.from_pretrained(model_name)
37
- device = "cuda" if torch.cuda.is_available() else "cpu"
38
-
 
 
 
 
 
 
 
39
  model = ViTForImageClassification.from_pretrained(
40
- model_name,
41
  num_labels=len(DAMAGE_TYPES),
42
  ignore_mismatched_sizes=True,
 
 
43
  ).to(device)
44
-
45
- model.eval()
46
- logging.info("Model loaded successfully.")
47
  return model, processor
48
  except Exception as e:
49
- logging.error(f"Failed to load model: {str(e)}")
50
- st.error("Error loading model. Please restart the app.")
 
51
  return None, None
52
 
53
- # Initialize RAG system
54
- @st.cache_resource
55
  def init_rag_system():
56
- global RAG_SYSTEM
57
- try:
58
- RAG_SYSTEM = RAGSystem()
59
- RAG_SYSTEM.initialize_knowledge_base(KNOWLEDGE_BASE)
60
- logging.info("RAG system initialized successfully.")
61
- except Exception as e:
62
- logging.error(f"Failed to initialize RAG system: {str(e)}")
63
- st.error("Error initializing knowledge base.")
64
 
65
- # Image validation
66
- def validate_image(image):
67
- if image.size[0] * image.size[1] > 1024 * 1024:
68
- st.warning("Large image detected. Resizing for better performance.")
69
- if image.format not in ['JPEG', 'PNG']:
70
- st.warning("Non-optimal image format. Use JPEG or PNG.")
71
-
72
- # Image preprocessing
73
- def preprocess_image(uploaded_file):
74
  try:
75
- image = Image.open(uploaded_file)
 
 
 
 
76
  if max(image.size) > MAX_IMAGE_SIZE:
77
  ratio = MAX_IMAGE_SIZE / max(image.size)
78
  new_size = tuple([int(dim * ratio) for dim in image.size])
79
  image = image.resize(new_size, Image.Resampling.LANCZOS)
 
80
  return image
81
  except Exception as e:
82
- logging.error(f"Error processing image: {str(e)}")
83
- st.error("Image processing error.")
84
  return None
85
 
86
- # Damage analysis
87
  def analyze_damage(image, model, processor):
 
88
  try:
89
  device = next(model.parameters()).device
90
  with torch.no_grad():
91
- image = image.convert('RGB')
92
  inputs = processor(images=image, return_tensors="pt")
93
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
 
94
  outputs = model(**inputs)
95
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
 
 
96
  cleanup_memory()
97
  return probs.cpu()
 
 
 
 
 
 
 
 
 
 
98
  except Exception as e:
99
- logging.error(f"Error analyzing image: {str(e)}")
100
- st.error("Image analysis failed.")
101
  return None
102
 
103
- # Display enhanced analysis
104
- def display_enhanced_analysis(damage_type, confidence):
105
  try:
106
- enhanced_info = RAG_SYSTEM.get_enhanced_analysis(damage_type, confidence)
107
- st.markdown("### πŸ” Enhanced Analysis")
108
 
109
- with st.expander("πŸ“š Technical Details", expanded=True):
110
- for detail in enhanced_info["technical_details"]:
111
- st.markdown(detail)
112
-
113
- with st.expander("⚠️ Safety Considerations"):
114
- for safety in enhanced_info["safety_considerations"]:
115
- st.warning(safety)
116
-
117
- with st.expander("πŸ‘· Expert Recommendations"):
118
- for rec in enhanced_info["expert_recommendations"]:
119
- st.info(rec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  except Exception as e:
121
- logging.error(f"Failed to generate enhanced analysis: {str(e)}")
122
- st.error("Error generating enhanced analysis.")
123
 
124
- # Main function
125
  def main():
126
- st.set_page_config(
127
- page_title="Structural Damage Analyzer Pro",
128
- page_icon="πŸ—οΈ",
129
- layout="wide"
130
- )
131
- st.title("πŸ—οΈ Structural Damage Analyzer Pro")
132
-
133
- # Load model and initialize RAG system
134
- global MODEL, PROCESSOR
135
- if MODEL is None or PROCESSOR is None:
136
- MODEL, PROCESSOR = load_model()
137
- init_rag_system()
138
-
139
- uploaded_file = st.file_uploader(
140
- "Upload an image for analysis (JPG, PNG)",
141
- type=['jpg', 'jpeg', 'png']
142
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- if uploaded_file:
 
 
 
145
  if uploaded_file.size > MAX_FILE_SIZE:
146
- st.error("File too large. Limit: 5MB.")
147
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- image = preprocess_image(uploaded_file)
150
- validate_image(image)
151
-
152
- st.image(image, caption="Uploaded Image", use_column_width=True)
153
-
154
- with st.spinner("Analyzing damage..."):
155
- start_time = time.time()
156
- predictions = analyze_damage(image, MODEL, PROCESSOR)
157
- analysis_time = time.time() - start_time
158
-
159
- if predictions is not None:
160
- st.markdown(f"*Analysis completed in {analysis_time:.2f} seconds*")
161
- confidence = float(predictions[0]) * 100
162
- display_enhanced_analysis(DAMAGE_TYPES[0]['name'], confidence)
 
 
 
 
 
 
163
 
164
  if __name__ == "__main__":
165
- main()
 
4
  import torch
5
  import time
6
  import gc
 
7
  from knowledge_base import KNOWLEDGE_BASE, DAMAGE_TYPES
8
  from rag_utils import RAGSystem
9
+ import os
 
 
10
 
11
  # Constants
12
  MAX_FILE_SIZE = 5 * 1024 * 1024 # 5MB
13
  MAX_IMAGE_SIZE = 1024 # Maximum dimension for images
14
+ MODEL_NAME = "google/vit-base-patch16-224"
15
+ CACHE_DIR = "/tmp/model_cache" # HF Spaces compatible cache directory
16
+
17
+ # Ensure cache directory exists
18
+ os.makedirs(CACHE_DIR, exist_ok=True)
19
 
20
+ # Initialize session state for caching
21
+ if 'model' not in st.session_state:
22
+ st.session_state.model = None
23
+ if 'processor' not in st.session_state:
24
+ st.session_state.processor = None
25
+ if 'rag_system' not in st.session_state:
26
+ st.session_state.rag_system = None
27
 
 
28
  def cleanup_memory():
29
  """Clean up memory and GPU cache"""
30
  gc.collect()
31
  if torch.cuda.is_available():
32
  torch.cuda.empty_cache()
33
 
 
34
  @st.cache_resource(show_spinner="Loading AI model...")
35
  def load_model():
36
+ """Load and cache the model and processor with error handling"""
37
  try:
38
+ # Initialize processor with cache directory
39
+ processor = ViTImageProcessor.from_pretrained(
40
+ MODEL_NAME,
41
+ cache_dir=CACHE_DIR,
42
+ local_files_only=False
43
+ )
44
+
45
+ # Determine device - prefer CPU on Hugging Face Spaces
46
+ device = "cpu" # Default to CPU for stability
47
+
48
+ # Load model with specific configuration
49
  model = ViTForImageClassification.from_pretrained(
50
+ MODEL_NAME,
51
  num_labels=len(DAMAGE_TYPES),
52
  ignore_mismatched_sizes=True,
53
+ cache_dir=CACHE_DIR,
54
+ local_files_only=False
55
  ).to(device)
56
+
57
+ model.eval() # Set to evaluation mode
 
58
  return model, processor
59
  except Exception as e:
60
+ st.error(f"Error loading model: {str(e)}")
61
+ st.info("Attempting to reload model... Please wait.")
62
+ cleanup_memory()
63
  return None, None
64
 
 
 
65
  def init_rag_system():
66
+ """Initialize RAG system with error handling"""
67
+ if st.session_state.rag_system is None:
68
+ try:
69
+ st.session_state.rag_system = RAGSystem()
70
+ st.session_state.rag_system.initialize_knowledge_base(KNOWLEDGE_BASE)
71
+ except Exception as e:
72
+ st.error(f"Error initializing RAG system: {str(e)}")
73
+ st.session_state.rag_system = None
74
 
75
+ def process_image(image):
76
+ """Process and validate image with enhanced error handling"""
 
 
 
 
 
 
 
77
  try:
78
+ # Convert to RGB if necessary
79
+ if image.mode != 'RGB':
80
+ image = image.convert('RGB')
81
+
82
+ # Resize if needed
83
  if max(image.size) > MAX_IMAGE_SIZE:
84
  ratio = MAX_IMAGE_SIZE / max(image.size)
85
  new_size = tuple([int(dim * ratio) for dim in image.size])
86
  image = image.resize(new_size, Image.Resampling.LANCZOS)
87
+
88
  return image
89
  except Exception as e:
90
+ st.error(f"Error processing image: {str(e)}")
 
91
  return None
92
 
 
93
  def analyze_damage(image, model, processor):
94
+ """Analyze structural damage with enhanced error handling and memory management"""
95
  try:
96
  device = next(model.parameters()).device
97
  with torch.no_grad():
98
+ # Process image
99
  inputs = processor(images=image, return_tensors="pt")
100
  inputs = {k: v.to(device) for k, v in inputs.items()}
101
+
102
+ # Run inference
103
  outputs = model(**inputs)
104
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
105
+
106
+ # Clean up
107
  cleanup_memory()
108
  return probs.cpu()
109
+ except RuntimeError as e:
110
+ if "out of memory" in str(e):
111
+ cleanup_memory()
112
+ st.error("Memory error. Processing with reduced image size...")
113
+ # Retry with smaller image
114
+ image = image.resize((224, 224), Image.Resampling.LANCZOS)
115
+ return analyze_damage(image, model, processor)
116
+ else:
117
+ st.error(f"Error during analysis: {str(e)}")
118
+ return None
119
  except Exception as e:
120
+ st.error(f"Unexpected error: {str(e)}")
 
121
  return None
122
 
123
+ def display_analysis_results(predictions, analysis_time):
124
+ """Display analysis results with enhanced visualization and error handling"""
125
  try:
126
+ st.markdown("### πŸ“Š Analysis Results")
127
+ st.markdown(f"*Analysis completed in {analysis_time:.2f} seconds*")
128
 
129
+ detected = False
130
+ for idx, prob in enumerate(predictions):
131
+ confidence = float(prob) * 100
132
+ if confidence > 15: # Threshold for displaying results
133
+ detected = True
134
+ damage_type = DAMAGE_TYPES[idx]['name']
135
+ risk_level = DAMAGE_TYPES[idx]['risk']
136
+
137
+ # Create expander with color-coded header
138
+ with st.expander(
139
+ f"πŸ” {damage_type.replace('_', ' ').title()} - {confidence:.1f}% ({risk_level})",
140
+ expanded=True
141
+ ):
142
+ # Display confidence bar
143
+ st.progress(confidence / 100)
144
+
145
+ # Create tabs for organized information
146
+ details_tab, repair_tab, action_tab = st.tabs([
147
+ "πŸ“‹ Details", "πŸ”§ Repair Plan", "⚠️ Actions Needed"
148
+ ])
149
+
150
+ with details_tab:
151
+ display_damage_details(damage_type, confidence)
152
+
153
+ with repair_tab:
154
+ display_repair_plan(damage_type)
155
+
156
+ with action_tab:
157
+ display_action_items(damage_type)
158
+
159
+ # Display enhanced analysis if RAG system is available
160
+ if st.session_state.rag_system:
161
+ display_enhanced_analysis(damage_type, confidence)
162
+
163
+ if not detected:
164
+ st.success("No significant structural damage detected. Regular maintenance recommended.")
165
+
166
  except Exception as e:
167
+ st.error(f"Error displaying results: {str(e)}")
 
168
 
 
169
  def main():
170
+ """Main application function with enhanced error handling and UI"""
171
+ try:
172
+ # Page configuration
173
+ st.set_page_config(
174
+ page_title="Structural Damage Analyzer Pro",
175
+ page_icon="πŸ—οΈ",
176
+ layout="wide",
177
+ initial_sidebar_state="expanded"
178
+ )
179
+
180
+ # Custom CSS
181
+ st.markdown(get_custom_css(), unsafe_allow_html=True)
182
+
183
+ # Header
184
+ display_header()
185
+
186
+ # Initialize systems
187
+ if st.session_state.model is None or st.session_state.processor is None:
188
+ with st.spinner("Initializing AI model..."):
189
+ model, processor = load_model()
190
+ if model is None:
191
+ st.error("Failed to initialize model. Please refresh the page.")
192
+ return
193
+ st.session_state.model = model
194
+ st.session_state.processor = processor
195
+
196
+ init_rag_system()
197
+
198
+ # File upload section
199
+ uploaded_file = st.file_uploader(
200
+ "Upload structural image for analysis",
201
+ type=['jpg', 'jpeg', 'png'],
202
+ help="Maximum file size: 5MB"
203
+ )
204
+
205
+ if uploaded_file:
206
+ process_uploaded_file(uploaded_file)
207
+
208
+ # Footer
209
+ display_footer()
210
+
211
+ except Exception as e:
212
+ st.error(f"Application error: {str(e)}")
213
+ st.info("Please refresh the page and try again.")
214
+ cleanup_memory()
215
 
216
+ def process_uploaded_file(uploaded_file):
217
+ """Process uploaded file with comprehensive error handling"""
218
+ try:
219
+ # Validate file size
220
  if uploaded_file.size > MAX_FILE_SIZE:
221
+ st.error("File too large. Please upload an image smaller than 5MB.")
222
  return
223
+
224
+ # Process image
225
+ image = Image.open(uploaded_file)
226
+ processed_image = process_image(image)
227
+ if processed_image is None:
228
+ return
229
+
230
+ # Display layout
231
+ col1, col2 = st.columns([1, 1])
232
+ with col1:
233
+ st.image(processed_image, caption="Uploaded Structure", use_column_width=True)
234
+
235
+ with col2:
236
+ with st.spinner("πŸ” Analyzing structural damage..."):
237
+ start_time = time.time()
238
+ predictions = analyze_damage(
239
+ processed_image,
240
+ st.session_state.model,
241
+ st.session_state.processor
242
+ )
243
+ if predictions is not None:
244
+ analysis_time = time.time() - start_time
245
+ display_analysis_results(predictions, analysis_time)
246
+
247
+ except Exception as e:
248
+ st.error(f"Error processing upload: {str(e)}")
249
+ cleanup_memory()
250
 
251
+ def get_custom_css():
252
+ """Return custom CSS for enhanced UI"""
253
+ return """
254
+ <style>
255
+ .main {
256
+ padding: 1rem;
257
+ }
258
+ .stProgress > div > div > div > div {
259
+ background-image: linear-gradient(to right, #ff6b6b, #f06595);
260
+ }
261
+ .damage-card {
262
+ padding: 1rem;
263
+ border-radius: 0.5rem;
264
+ background: var(--background-color, #ffffff);
265
+ margin-bottom: 1rem;
266
+ border: 1px solid var(--border-color, #e0e0e0);
267
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
268
+ }
269
+ </style>
270
+ """
271
 
272
  if __name__ == "__main__":
273
+ main()