import streamlit as st
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
import time
import gc
import os

from knowledge_base import KNOWLEDGE_BASE, DAMAGE_TYPES
from rag_utils import RAGSystem

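# Upload size limit (bytes), max image edge (px), ViT checkpoint, and local cache path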
MAX_FILE_SIZE = 5 * 1024 * 1024
MAX_IMAGE_SIZE = 1024
MODEL_NAME = "google/vit-base-patch16-224"
CACHE_DIR = "/tmp/model_cache"

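# Ensure the model cache directory exists before any download is attempted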
os.makedirs(CACHE_DIR, exist_ok=True)

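# Session-state placeholders so the model, processor, and RAG system persist across reruns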
if 'model' not in st.session_state:
    st.session_state.model = None
if 'processor' not in st.session_state:
    st.session_state.processor = None
if 'rag_system' not in st.session_state:
    st.session_state.rag_system = None


def cleanup_memory():
    """Clean up memory and GPU cache"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


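# st.cache_resource keeps one loaded model/processor instance per server process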
@st.cache_resource(show_spinner="Loading AI model...")
def load_model():
    """Load and cache the model and processor with error handling"""
    try:
        processor = ViTImageProcessor.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            local_files_only=False
        )

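        # Inference is pinned to CPU here rather than auto-detected from available hardware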
        device = "cpu"

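        # Re-initialize the classification head with len(DAMAGE_TYPES) outputs;
        # ignore_mismatched_sizes lets the pretrained 1000-class ImageNet head be dropped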
        model = ViTForImageClassification.from_pretrained(
            MODEL_NAME,
            num_labels=len(DAMAGE_TYPES),
            ignore_mismatched_sizes=True,
            cache_dir=CACHE_DIR,
            local_files_only=False
        ).to(device)

        model.eval()
        return model, processor
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.info("Attempting to reload model... Please wait.")
        cleanup_memory()
        return None, None


def init_rag_system():
    """Initialize RAG system with error handling"""
    if st.session_state.rag_system is None:
        try:
            st.session_state.rag_system = RAGSystem()
            st.session_state.rag_system.initialize_knowledge_base(KNOWLEDGE_BASE)
        except Exception as e:
            st.error(f"Error initializing RAG system: {str(e)}")
            st.session_state.rag_system = None


def process_image(image):
    """Process and validate image with enhanced error handling"""
    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if max(image.size) > MAX_IMAGE_SIZE:
            ratio = MAX_IMAGE_SIZE / max(image.size)
            new_size = tuple(int(dim * ratio) for dim in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)

        return image
    except Exception as e:
        st.error(f"Error processing image: {str(e)}")
        return None


def analyze_damage(image, model, processor):
    """Analyze structural damage with enhanced error handling and memory management"""
    try:
        device = next(model.parameters()).device
        with torch.no_grad():
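            # Preprocess the PIL image into model-ready tensors and move them to the model's device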
            inputs = processor(images=image, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]

        cleanup_memory()
        return probs.cpu()
    except RuntimeError as e:
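        # Out-of-memory: free caches and retry at the ViT's native 224x224 input size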
        if "out of memory" in str(e):
            cleanup_memory()
            st.error("Memory error. Processing with reduced image size...")
            image = image.resize((224, 224), Image.Resampling.LANCZOS)
            return analyze_damage(image, model, processor)
        else:
            st.error(f"Error during analysis: {str(e)}")
            return None
    except Exception as e:
        st.error(f"Unexpected error: {str(e)}")
        return None


def display_analysis_results(predictions, analysis_time):
    """Display analysis results with enhanced visualization and error handling"""
    try:
        st.markdown("### 📊 Analysis Results")
        st.markdown(f"*Analysis completed in {analysis_time:.2f} seconds*")

        detected = False
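        # Report every damage class whose confidence clears the 15% display threshold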
        for idx, prob in enumerate(predictions):
            confidence = float(prob) * 100
            if confidence > 15:
                detected = True
                damage_type = DAMAGE_TYPES[idx]['name']
                risk_level = DAMAGE_TYPES[idx]['risk']

                with st.expander(
                    f"🔍 {damage_type.replace('_', ' ').title()} - {confidence:.1f}% ({risk_level})",
                    expanded=True
                ):
                    st.progress(confidence / 100)

                    details_tab, repair_tab, action_tab = st.tabs([
                        "📋 Details", "🔧 Repair Plan", "⚠️ Actions Needed"
                    ])

                    with details_tab:
                        display_damage_details(damage_type, confidence)

                    with repair_tab:
                        display_repair_plan(damage_type)

                    with action_tab:
                        display_action_items(damage_type)

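                    # Append retrieval-augmented guidance when the RAG system is available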
                    if st.session_state.rag_system:
                        display_enhanced_analysis(damage_type, confidence)

        if not detected:
            st.success("No significant structural damage detected. Regular maintenance recommended.")

    except Exception as e:
        st.error(f"Error displaying results: {str(e)}")


def main():
    """Main application function with enhanced error handling and UI"""
    try:
        st.set_page_config(
            page_title="Structural Damage Analyzer Pro",
            page_icon="🏗️",
            layout="wide",
            initial_sidebar_state="expanded"
        )

        st.markdown(get_custom_css(), unsafe_allow_html=True)

        display_header()

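        # Load the ViT model and processor on first use and keep them in session state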
        if st.session_state.model is None or st.session_state.processor is None:
            with st.spinner("Initializing AI model..."):
                model, processor = load_model()
                if model is None:
                    st.error("Failed to initialize model. Please refresh the page.")
                    return
                st.session_state.model = model
                st.session_state.processor = processor

        init_rag_system()

        uploaded_file = st.file_uploader(
            "Upload structural image for analysis",
            type=['jpg', 'jpeg', 'png'],
            help="Maximum file size: 5MB"
        )

        if uploaded_file:
            process_uploaded_file(uploaded_file)

        display_footer()

    except Exception as e:
        st.error(f"Application error: {str(e)}")
        st.info("Please refresh the page and try again.")
        cleanup_memory()


def process_uploaded_file(uploaded_file):
    """Process uploaded file with comprehensive error handling"""
    try:
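        # Reject uploads over the 5 MB limit before decoding the image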
        if uploaded_file.size > MAX_FILE_SIZE:
            st.error("File too large. Please upload an image smaller than 5MB.")
            return

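        # Decode the upload and normalize it (RGB conversion, optional downscale)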
        image = Image.open(uploaded_file)
        processed_image = process_image(image)
        if processed_image is None:
            return

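        # Side-by-side layout: uploaded image on the left, analysis results on the right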
        col1, col2 = st.columns([1, 1])
        with col1:
            st.image(processed_image, caption="Uploaded Structure", use_column_width=True)

        with col2:
            with st.spinner("🔍 Analyzing structural damage..."):
                start_time = time.time()
                predictions = analyze_damage(
                    processed_image,
                    st.session_state.model,
                    st.session_state.processor
                )
                if predictions is not None:
                    analysis_time = time.time() - start_time
                    display_analysis_results(predictions, analysis_time)

    except Exception as e:
        st.error(f"Error processing upload: {str(e)}")
        cleanup_memory()


def get_custom_css():
    """Return custom CSS for enhanced UI"""
    return """
    <style>
    .main {
        padding: 1rem;
    }
    .stProgress > div > div > div > div {
        background-image: linear-gradient(to right, #ff6b6b, #f06595);
    }
    .damage-card {
        padding: 1rem;
        border-radius: 0.5rem;
        background: var(--background-color, #ffffff);
        margin-bottom: 1rem;
        border: 1px solid var(--border-color, #e0e0e0);
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    </style>
    """


if __name__ == "__main__":
    main()