```python
import streamlit as st
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
import time
import gc
from knowledge_base import KNOWLEDGE_BASE, DAMAGE_TYPES
from rag_utils import RAGSystem
# --- Upload and preprocessing limits ---
# Upper bound on the uploaded file size in bytes (5 MB); compared against UploadedFile.size.
MAX_FILE_SIZE = 5 * 1024 ** 2
# Longest allowed image side in pixels; larger images are downscaled in preprocess_image().
MAX_IMAGE_SIZE = 1024
# --- Module-level handles populated lazily by main() / init_rag_system() ---
# NOTE: Streamlit re-executes this script in a fresh module namespace on every
# rerun, so these globals only live for a single script run; persistent
# caching across reruns comes from st.cache_resource.
MODEL = PROCESSOR = RAG_SYSTEM = None
def cleanup_memory():
    """Force a garbage-collection pass and, on CUDA machines, release cached GPU memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def init_session_state():
    """Seed per-session defaults the first time a browser session runs the script.

    Existing values are left untouched, so this is safe to call on every rerun.
    """
    # Factories (not shared objects) so each session gets its own fresh list.
    for state_key, make_default in (("history", list), ("dark_mode", lambda: False)):
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()
@st.cache_resource(show_spinner="Loading AI model...")
def _load_model_cached(model_name):
    """Load the ViT model and processor once per server process.

    Exceptions are deliberately allowed to propagate: st.cache_resource does not
    cache raised exceptions, so a failed load is retried on the next run instead
    of being remembered forever.
    """
    model = ViTForImageClassification.from_pretrained(
        model_name,
        num_labels=len(DAMAGE_TYPES),
        # NOTE(review): with ignore_mismatched_sizes=True, any checkpoint weights
        # whose shape differs (the classifier head, unless len(DAMAGE_TYPES)
        # matches the checkpoint's label count) are freshly initialized, so
        # predictions are not meaningful until fine-tuned weights are loaded.
        # TODO confirm which checkpoint is intended for production.
        ignore_mismatched_sizes=True,
        device_map="auto",
    )
    processor = ViTImageProcessor.from_pretrained(model_name)
    return model, processor


def load_model(model_name="google/vit-base-patch16-224"):
    """Return the cached ``(model, processor)`` pair, or ``(None, None)`` on failure.

    Args:
        model_name: Hugging Face model id or local path to load.

    Returns:
        ``(model, processor)`` on success. On failure, an error is shown in the UI
        and ``(None, None)`` is returned. The failure is *not* cached, so
        refreshing the page retries the load.
    """
    try:
        return _load_model_cached(model_name)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None
@st.cache_resource(show_spinner="Loading knowledge base...")
def _build_rag_system():
    """Build and populate the RAG system once per server process.

    Streamlit re-executes the script (with fresh module globals) on every rerun,
    so a plain module global would rebuild the knowledge base on every widget
    interaction. The cached object is shared by all sessions.
    NOTE(review): assumes RAGSystem queries are safe to share across sessions.
    TODO confirm.
    """
    rag = RAGSystem()
    rag.initialize_knowledge_base(KNOWLEDGE_BASE)
    return rag


def init_rag_system():
    """Ensure the module-global ``RAG_SYSTEM`` points at the initialized RAG system.

    The global is only assigned after initialization succeeds, so a failure does
    not leave a half-initialized object behind.
    """
    global RAG_SYSTEM
    if RAG_SYSTEM is None:
        RAG_SYSTEM = _build_rag_system()
def validate_image(image, max_pixels=1024 * 1024):
    """Show non-fatal UI warnings about an image's size or file format.

    Args:
        image: A PIL image (anything exposing ``.size`` and ``.format``).
        max_pixels: Pixel-count threshold above which a "large image" warning is
            shown. The default preserves the original 1024*1024 limit.

    NOTE(review): main() calls this after preprocess_image(), which has already
    capped the longest side at MAX_IMAGE_SIZE, so the size warning effectively
    never fires on that path.
    """
    width, height = image.size
    if width * height > max_pixels:
        st.warning("Large image detected. The image will be resized for better performance.")
    # Pillow sets ``format`` to None for images produced in memory (e.g. by
    # ``resize``); only warn when the source format is actually known.
    if image.format is not None and image.format not in ('JPEG', 'PNG'):
        st.warning("Image format not optimal. Consider using JPEG or PNG for better performance.")
def preprocess_image(uploaded_file):
    """Open an uploaded image and downscale it so its longest side is at most MAX_IMAGE_SIZE.

    Args:
        uploaded_file: A file-like object accepted by ``PIL.Image.open``.

    Returns:
        The (possibly resized) PIL image, or None after showing an error in the UI.
    """
    try:
        image = Image.open(uploaded_file)
        # Image.open only parses the header; decode now so truncated/corrupt
        # uploads are reported by this handler rather than later in the pipeline.
        image.load()
        longest_side = max(image.size)
        if longest_side > MAX_IMAGE_SIZE:
            ratio = MAX_IMAGE_SIZE / longest_side
            # Clamp to 1 px: extreme aspect ratios would otherwise round a side
            # down to 0, which Image.resize rejects.
            new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)
        return image
    except Exception as e:
        # UI boundary: any decoding/resizing failure becomes a user-facing error.
        st.error(f"Error processing image: {str(e)}")
        return None
def analyze_damage(image, model, processor):
    """Run the classifier on one image and return per-class probabilities.

    Args:
        image: PIL image; converted to RGB before preprocessing.
        model: Classification model whose output exposes ``.logits``.
        processor: Image processor that turns the image into model inputs.

    Returns:
        A 1-D CPU tensor of softmax probabilities for the first (only) image, or
        None after showing an error in the UI when inference raises RuntimeError.
        Other exceptions propagate to the caller.
    """
    try:
        with torch.no_grad():
            rgb_image = image.convert('RGB')
            inputs = processor(images=rgb_image, return_tensors="pt")
            # load_model uses device_map="auto", which may place the model on a
            # GPU; move the processor's tensors to the model's device to match.
            inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
            outputs = model(**inputs)
            # Return on CPU so callers can read values without device syncs.
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu()
        return probs
    except RuntimeError as e:
        if "out of memory" in str(e):
            st.error("Out of memory. Please try with a smaller image.")
        else:
            st.error(f"Error analyzing image: {str(e)}")
        return None
    finally:
        # Free intermediate tensors on every path, not only success/OOM.
        cleanup_memory()
def get_custom_css():
    """Return the markup injected at the top of the page for custom styling.

    The stylesheet is currently empty: the returned string is a single newline.
    """
    css_markup = "\n"
    return css_markup
def display_header():
    """Render the application title and tagline at the top of the page."""
    header_markup = """
🏗️ Structural Damage Analyzer Pro
Advanced AI-powered structural damage assessment tool
"""
    st.markdown(header_markup, unsafe_allow_html=True)
def display_enhanced_analysis(damage_type, confidence):
    """Render RAG-backed details and a free-text query box for one damage type.

    Args:
        damage_type: Damage type name, passed through to the RAG system.
        confidence: Detection confidence in percent, passed through to the RAG system.

    Errors from the RAG system are shown with st.error rather than raised.
    """
    if RAG_SYSTEM is None:
        st.error("Error generating enhanced analysis: knowledge base is not initialized.")
        return
    try:
        enhanced_info = RAG_SYSTEM.get_enhanced_analysis(damage_type, confidence)
        st.markdown("### 🔍 Enhanced Analysis")
        with st.expander("📚 Technical Details", expanded=True):
            for detail in enhanced_info["technical_details"]:
                st.markdown(detail)
        with st.expander("⚠️ Safety Considerations"):
            for safety in enhanced_info["safety_considerations"]:
                st.warning(safety)
        with st.expander("👷 Expert Recommendations"):
            for rec in enhanced_info["expert_recommendations"]:
                st.info(rec)
        # This function runs once per detected damage type. Identical unkeyed
        # text_inputs collide on Streamlit's auto-generated widget ID, so give
        # each one a key derived from the damage type.
        custom_query = st.text_input(
            "Ask specific questions about this damage type:",
            placeholder="E.g., What are the long-term implications of this damage?",
            key=f"custom_query_{damage_type}",
        )
        if custom_query:
            custom_results = RAG_SYSTEM.get_enhanced_analysis(
                damage_type,
                confidence,
                custom_query=custom_query
            )
            st.markdown("### 💡 Custom Query Results")
            for category, results in custom_results.items():
                if results:
                    st.markdown(f"**{category.replace('_', ' ').title()}:**")
                    for result in results:
                        st.markdown(result)
    except Exception as e:
        st.error(f"Error generating enhanced analysis: {str(e)}")
def _render_case_tabs(cases):
    """Render Details / Repairs / Actions tabs; Details lists every case, the others use the first."""
    details_tab, repairs_tab, actions_tab = st.tabs(["📋 Details", "🔧 Repairs", "⚠️ Actions"])
    with details_tab:
        for case in cases:
            st.markdown(f"""
- **Severity:** {case['severity']}
- **Description:** {case['description']}
- **Location:** {case['location']}
- **Required Expertise:** {case['required_expertise']}
""")
    primary = cases[0]
    with repairs_tab:
        for step in primary['repair_method']:
            st.markdown(f"✓ {step}")
        st.info(f"**Estimated Cost:** {primary['estimated_cost']}")
        st.info(f"**Timeframe:** {primary['timeframe']}")
    with actions_tab:
        st.warning("**Immediate Actions Required:**")
        st.markdown(primary['immediate_action'])
        st.success("**Prevention Measures:**")
        st.markdown(primary['prevention'])


def display_analysis_results(predictions, analysis_time, min_confidence=15):
    """Render a result panel for every damage type whose confidence exceeds a threshold.

    Args:
        predictions: Per-class probabilities (e.g. from analyze_damage), in the
            same order as DAMAGE_TYPES.
        analysis_time: Inference duration in seconds, shown to the user.
        min_confidence: Percentage a class must exceed to be reported. The
            default keeps the original 15% threshold.
    """
    st.markdown("### 📊 Analysis Results")
    st.markdown(f"*Analysis completed in {analysis_time:.2f} seconds*")
    detected = False
    for idx, prob in enumerate(predictions):
        confidence = float(prob) * 100
        if confidence <= min_confidence:
            continue
        detected = True
        damage_type = DAMAGE_TYPES[idx]['name']
        # A missing knowledge-base entry should not abort the whole results view.
        try:
            cases = KNOWLEDGE_BASE[damage_type]
        except KeyError:
            cases = []
        with st.expander(f"{damage_type.replace('_', ' ').title()} - {confidence:.1f}%", expanded=True):
            # NOTE(review): this HTML placeholder is empty in the current source;
            # confirm the intended markup.
            st.markdown("\n", unsafe_allow_html=True)
            # st.progress only accepts values in [0, 1]; guard against float round-off.
            st.progress(min(max(confidence / 100, 0.0), 1.0))
            if cases:
                _render_case_tabs(cases)
            else:
                st.info("No knowledge-base entries are available for this damage type.")
        # Rendered outside the expander above: display_enhanced_analysis opens
        # its own expanders, and Streamlit does not allow nested expanders.
        display_enhanced_analysis(damage_type, confidence)
    if not detected:
        st.info("No significant structural damage detected. Regular maintenance recommended.")
def _render_sidebar():
    """Render the settings toggle and the five most recent analysis-history entries."""
    with st.sidebar:
        st.markdown("### ⚙️ Settings")
        st.session_state.dark_mode = st.toggle("Dark Mode", st.session_state.dark_mode)
        st.markdown("### 📖 Analysis History")
        if st.session_state.history:
            for item in st.session_state.history[-5:]:
                st.markdown(f"- {item}")


def _record_history(entry):
    """Append *entry* to the session history unless it repeats the most recent entry.

    Streamlit reruns the whole script on every widget interaction (e.g. typing a
    custom query), which would otherwise log the same upload again on each rerun.
    """
    history = st.session_state.history
    if not history or history[-1] != entry:
        history.append(entry)


def _analyze_upload(uploaded_file):
    """Validate, display and analyze one uploaded image; returns early on invalid input."""
    if uploaded_file.size > MAX_FILE_SIZE:
        st.error("File size too large. Please upload an image smaller than 5MB.")
        return
    image = preprocess_image(uploaded_file)
    if image is None:
        return
    validate_image(image)
    col1, col2 = st.columns([1, 1])
    with col1:
        st.image(image, caption="Uploaded Structure", use_container_width=True)
    with col2:
        with st.spinner("🔍 Analyzing damage..."):
            # perf_counter is monotonic and intended for measuring durations.
            start_time = time.perf_counter()
            predictions = analyze_damage(image, MODEL, PROCESSOR)
            analysis_time = time.perf_counter() - start_time
        if predictions is not None:
            display_analysis_results(predictions, analysis_time)
            _record_history(f"Analyzed image: {uploaded_file.name}")


def _render_footer():
    """Render the disclaimer footer."""
    st.markdown("---")
    st.markdown(
        """
🏗️ Structural Damage Analyzer Pro | Built with Streamlit & Transformers
For professional use only. Always consult with a structural engineer.
""",
        unsafe_allow_html=True
    )


def main():
    """Build the page: settings sidebar, header, model loading, upload analysis, footer.

    The footer is rendered on every path, including invalid uploads and model
    load failures.
    """
    global MODEL, PROCESSOR
    # set_page_config must be the first Streamlit command of a script run.
    st.set_page_config(
        page_title="Structural Damage Analyzer Pro",
        page_icon="🏗️",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    init_session_state()
    st.markdown(get_custom_css(), unsafe_allow_html=True)
    _render_sidebar()
    display_header()
    if MODEL is None or PROCESSOR is None:
        # load_model is backed by st.cache_resource, which shows its own
        # "Loading AI model..." spinner; no extra spinner needed here.
        MODEL, PROCESSOR = load_model()
    if MODEL is None:
        st.error("Failed to load model. Please refresh the page.")
    else:
        init_rag_system()
        uploaded_file = st.file_uploader(
            "Drag and drop or click to upload an image",
            type=['jpg', 'jpeg', 'png'],
            help="Supported formats: JPG, JPEG, PNG"
        )
        if uploaded_file:
            try:
                _analyze_upload(uploaded_file)
            except Exception as e:
                # UI boundary: report any unexpected failure and free memory.
                cleanup_memory()
                st.error(f"Error processing image: {str(e)}")
                st.info("Please try uploading a different image.")
    _render_footer()


if __name__ == "__main__":
    # `streamlit run` executes this file as __main__ on every rerun.
    main()
```