# smart / app.py
# Shakir60's picture
# Update app.py
# 8b2b9dd verified
import streamlit as st
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
import time
import gc
import logging
from knowledge_base import KNOWLEDGE_BASE, DAMAGE_TYPES, validate_knowledge_base
from rag_utils import RAGSystem
import structlog
from typing import Optional, Dict, Any
from functools import lru_cache
from dynaconf import Dynaconf
# Configure settings
# Values are read from settings.yaml and .secrets.yaml (environments=True turns
# on Dynaconf environment handling for those files). Every `settings.get` below
# falls back to its hard-coded default when the key is absent.
settings = Dynaconf(
    settings_files=['settings.yaml', '.secrets.yaml'],
    environments=True
)
# Configure logging
# NOTE(review): basicConfig configures the stdlib root logger, while `logger`
# is a structlog logger; whether structlog output goes through stdlib logging
# depends on structlog configuration that is not visible here.
logging.basicConfig(level=logging.INFO)
logger = structlog.get_logger()
# Constants
# Upload size limit in bytes (compared against the uploaded file's size).
MAX_FILE_SIZE = settings.get('max_file_size', 5 * 1024 * 1024) # 5MB default
# Longest allowed image side in pixels; larger images are downscaled.
MAX_IMAGE_SIZE = settings.get('max_image_size', 1024) # Maximum dimension default
# Lazily created resources; main() assigns them on first use.
MODEL: Optional[ViTForImageClassification] = None
PROCESSOR: Optional[ViTImageProcessor] = None
RAG_SYSTEM: Optional[RAGSystem] = None
def handle_exceptions(func):
    """Decorator for exception handling.

    Wraps *func* so that any ``Exception`` it raises is turned into a ``None``
    return value after three steps: memory is released with
    ``cleanup_memory``, the error is shown in the UI with ``st.error``, and
    it is logged with its traceback.

    Args:
        func: The callable to protect.

    Returns:
        A wrapper that returns ``func``'s result, or ``None`` if ``func``
        raised. ``functools.wraps`` copies ``__name__``, ``__doc__`` and the
        other metadata, and exposes the original as ``__wrapped__``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            cleanup_memory()
            st.error(f"Error in {func.__name__}: {str(e)}")
            logger.error(f"Error in {func.__name__}: {str(e)}", exc_info=True)
            return None
    return wrapper
def cleanup_memory():
    """Free unreachable Python objects, then release cached CUDA memory.

    The CUDA step is skipped when no CUDA device is available, so this is
    safe to call on CPU-only machines.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def init_session_state():
    """Seed ``st.session_state`` with defaults for keys not yet present.

    Keys that already exist are left untouched, so their values survive
    Streamlit reruns. The defaults (including a fresh empty list for the
    history) are rebuilt on every call.
    """
    session_defaults = {
        'history': [],
        'dark_mode': False,
        'analysis_count': 0,
    }
    for key, default_value in session_defaults.items():
        if key not in st.session_state:
            setattr(st.session_state, key, default_value)
@st.cache_resource(show_spinner="Loading AI model...", ttl=3600*24)
def _load_model_cached():
    """Load the ViT model and processor, raising on failure.

    ``st.cache_resource`` stores only successful return values. Raising here
    (instead of returning ``(None, None)``) therefore keeps a transient
    failure, such as a network error while downloading, from being cached
    for the whole 24 h TTL.
    """
    model_name = settings.get('model_name', "google/vit-base-patch16-224")
    processor = ViTImageProcessor.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): num_labels=len(DAMAGE_TYPES) together with
    # ignore_mismatched_sizes=True re-initialises the classifier head whenever
    # it differs from the checkpoint's (1000 ImageNet classes for the default
    # model). Predictions are then not meaningful until fine-tuned weights
    # are loaded. Confirm which checkpoint is intended.
    model = ViTForImageClassification.from_pretrained(
        model_name,
        num_labels=len(DAMAGE_TYPES),
        ignore_mismatched_sizes=True,
    ).to(device)
    model.eval()
    logger.info("Model loaded successfully", device=device, model_name=model_name)
    return model, processor


def load_model():
    """Load and cache the model with daily refresh.

    The model name comes from the ``model_name`` setting and defaults to
    ``google/vit-base-patch16-224``.

    Returns:
        ``(model, processor)`` on success. On failure the error is logged and
        ``(None, None)`` is returned. Failures are not cached, so a later call
        (e.g. after a page refresh) retries the load.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        logger.error("Error loading model", error=str(e), exc_info=True)
        return None, None


# Keep `load_model.clear()` working as it did when load_model itself was the
# cached function.
load_model.clear = _load_model_cached.clear
def validate_upload(file) -> bool:
    """Validate uploaded file for security.

    Checks run in this order:

    1. a file was provided;
    2. its name ends in ``.jpg``, ``.jpeg`` or ``.png``;
    3. its size is at most ``MAX_FILE_SIZE`` bytes;
    4. its reported MIME type is JPEG or PNG.

    On the first failing check a message is shown with ``st.error`` and
    ``False`` is returned. The image bytes themselves are only checked later,
    when PIL opens the file.

    Args:
        file: The uploaded file object (or ``None``/falsy when nothing was
            uploaded). It must expose ``name``, ``size`` and ``type``.

    Returns:
        ``True`` if every check passes, otherwise ``False``.
    """
    if not file:
        return False
    # Match the dotted suffix so names such as "photojpg" or "filepng" are
    # rejected. A bare "jpg" suffix would accept them.
    allowed_extensions = ('.jpg', '.jpeg', '.png')
    if not file.name.lower().endswith(allowed_extensions):
        st.error("Invalid file type. Please upload a JPG or PNG image.")
        return False
    if file.size > MAX_FILE_SIZE:
        st.error(f"File too large. Maximum size is {MAX_FILE_SIZE/1024/1024:.1f}MB.")
        return False
    if file.type not in ('image/jpeg', 'image/png'):
        st.error("Invalid file content type.")
        return False
    return True
@handle_exceptions
def preprocess_image(uploaded_file) -> Optional[Image.Image]:
    """Open an uploaded image, apply its EXIF orientation, and cap its size.

    Steps:

    * The EXIF Orientation tag, if present, is applied so camera photos are
      displayed and analysed upright. This is best-effort: a malformed tag is
      logged and ignored.
    * If the longest side exceeds ``MAX_IMAGE_SIZE`` pixels, the image is
      downscaled with LANCZOS, keeping the aspect ratio.
    * Smaller images are not resized.

    Wrapped by ``handle_exceptions``: any other error (e.g. an unreadable
    file) is reported in the UI and ``None`` is returned.
    """
    from PIL import ImageOps

    image = Image.open(uploaded_file)
    try:
        image = ImageOps.exif_transpose(image)
    except Exception as exc:
        logger.warning("Could not apply EXIF orientation", error=str(exc))

    longest_side = max(image.size)
    if longest_side > MAX_IMAGE_SIZE:
        ratio = MAX_IMAGE_SIZE / longest_side
        # max(1, ...) stops extreme aspect ratios (e.g. 1x5000) from shrinking
        # a side to 0 px, which Image.resize rejects.
        new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    return image
@handle_exceptions
def analyze_damage(image: Image.Image, model: ViTForImageClassification,
                   processor: ViTImageProcessor) -> Optional[torch.Tensor]:
    """Analyze structural damage in the image.

    Runs *image* through *processor* and *model* and returns softmax
    probabilities. Progress is shown in the UI as three stages:
    Preprocessing, Analysis and Results Generation.

    Args:
        image: Input image; converted to RGB before preprocessing.
        model: Classification model. Inputs are moved to the device of its
            first parameter.
        processor: Image processor that turns the image into model inputs.

    Returns:
        A 1-D CPU tensor of probabilities (row 0 of the softmax over the
        model's logits), or ``None`` if a ``RuntimeError`` occurs. CUDA
        out-of-memory also frees cached memory and shows a dedicated message.
        Any other exception is handled by ``handle_exceptions``, which also
        yields ``None``.
    """
    progress_bar = st.progress(0)
    stages = ['Preprocessing', 'Analysis', 'Results Generation']

    def report_stage(number: int) -> None:
        # Advance the bar and announce the stage before doing its work.
        progress_bar.progress(number / len(stages))
        st.write(f"Stage {number}/{len(stages)}: {stages[number - 1]}")

    try:
        device = next(model.parameters()).device

        report_stage(1)  # Preprocessing
        inputs = processor(images=image.convert('RGB'), return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        report_stage(2)  # Analysis
        with torch.no_grad():
            outputs = model(**inputs)
        # A single image is processed, so keep only the first row.
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]

        report_stage(3)  # Results Generation
        return probs.cpu()
    except RuntimeError as e:
        if "out of memory" in str(e):
            cleanup_memory()
            st.error("Out of memory. Please try with a smaller image.")
        else:
            st.error(f"Error analyzing image: {str(e)}")
        logger.error("Error analyzing image", error=str(e), exc_info=True)
        return None
def generate_downloadable_report(analysis_results: Dict):
    """Generate a downloadable PDF report.

    Args:
        analysis_results: Mapping from damage type name to a dict with
            ``'confidence'`` (a number, rendered as a percentage with one
            decimal) and ``'recommendations'`` (an iterable of texts). This
            is the shape ``display_analysis_results`` builds.

    Returns:
        The PDF as ``bytes``, or ``None`` if reportlab is unavailable or the
        document could not be built. The error is logged in either case.
    """
    try:
        import io
        from xml.sax.saxutils import escape
        from reportlab.lib.pagesizes import letter
        from reportlab.lib.styles import getSampleStyleSheet
        from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer

        styles = getSampleStyleSheet()
        story = [
            Paragraph("Structural Damage Analysis Report", styles['Title']),
            Spacer(1, 12),
        ]
        for damage_type, details in analysis_results.items():
            # Paragraph parses its text as XML-like markup. Escape dynamic
            # text so a '<' or '&' in it cannot abort the whole build.
            story.append(Paragraph(f"Damage Type: {escape(str(damage_type))}", styles['Heading1']))
            story.append(Paragraph(f"Confidence: {details['confidence']:.1f}%", styles['Normal']))
            story.append(Paragraph("Recommendations:", styles['Heading2']))
            for rec in details['recommendations']:
                story.append(Paragraph(f"• {escape(str(rec))}", styles['Normal']))
            story.append(Spacer(1, 12))

        with io.BytesIO() as buffer:
            doc = SimpleDocTemplate(buffer, pagesize=letter)
            doc.build(story)
            return buffer.getvalue()
    except Exception as e:
        logger.error(f"Error generating report: {str(e)}", exc_info=True)
        return None
def display_analysis_results(predictions: torch.Tensor, analysis_time: float,
                             threshold: float = 15):
    """Display analysis results with damage details.

    A class is shown when its confidence (probability x 100) is strictly
    greater than *threshold* percent. Each shown class gets an expander with
    three tabs from ``RAG_SYSTEM.get_enhanced_analysis``: technical details,
    repair recommendations and safety considerations. If at least one class
    is shown, a PDF report download button is offered. Otherwise an
    informational message is shown.

    Args:
        predictions: Per-class probabilities. Position ``i`` is looked up as
            ``DAMAGE_TYPES[i]['name']``.
        analysis_time: Seconds the analysis took (display only).
        threshold: Minimum confidence, in percent, for a class to be shown.
            Defaults to 15, the previously hard-coded cut-off.

    Relies on the module-level ``RAG_SYSTEM`` having been initialised.
    """
    st.markdown("### 📊 Analysis Results")
    st.markdown(f"*Analysis completed in {analysis_time:.2f} seconds*")

    analysis_results = {}
    for idx, prob in enumerate(predictions):
        confidence = float(prob) * 100
        if confidence <= threshold:
            continue

        damage_type = DAMAGE_TYPES[idx]['name']
        label = f"{damage_type.replace('_', ' ').title()} - {confidence:.1f}%"
        with st.expander(label, expanded=True):
            st.progress(confidence / 100)
            # Get enhanced analysis from RAG system
            analysis = RAG_SYSTEM.get_enhanced_analysis(damage_type, confidence)
            details_tab, repairs_tab, safety_tab = st.tabs(
                ["📋 Details", "🔧 Repairs", "⚠️ Safety"]
            )
            with details_tab:
                for detail in analysis['technical_details']:
                    st.markdown(detail)
            with repairs_tab:
                for rec in analysis['expert_recommendations']:
                    st.markdown(rec)
            with safety_tab:
                for safety in analysis['safety_considerations']:
                    st.warning(safety)

        analysis_results[damage_type] = {
            'confidence': confidence,
            'recommendations': analysis['expert_recommendations']
        }

    if not analysis_results:
        st.info("No significant structural damage detected. Regular maintenance recommended.")
        return

    # Generate download button for report
    pdf_report = generate_downloadable_report(analysis_results)
    if pdf_report:
        st.download_button(
            label="Download Analysis Report",
            data=pdf_report,
            file_name="damage_analysis_report.pdf",
            mime="application/pdf"
        )
@st.cache_resource(show_spinner="Loading knowledge base...")
def _load_rag_system():
    """Build and initialise the RAG system once, reusing it across reruns.

    Streamlit reruns the whole script on every interaction, so module globals
    such as RAG_SYSTEM cannot be relied on to persist between reruns. Caching
    avoids rebuilding the knowledge base each time. If construction raises,
    nothing is cached and the next call retries.

    NOTE(review): the cached instance is shared by all sessions. Confirm that
    RAGSystem is safe for concurrent read-only use.
    """
    rag_system = RAGSystem()
    rag_system.initialize_knowledge_base(KNOWLEDGE_BASE)
    return rag_system


def _render_header() -> None:
    """Render the centred page title and tagline."""
    st.markdown(
        "<div style='text-align: center; padding: 1rem;'>"
        "<h1>🏗️ Structural Damage Analyzer Pro</h1>"
        "<p style='font-size: 1.2rem;'>Advanced AI-powered structural damage assessment tool</p>"
        "</div>",
        unsafe_allow_html=True,
    )


def _render_sidebar() -> None:
    """Render the settings toggle and the five most recent history entries."""
    with st.sidebar:
        st.markdown("### ⚙️ Settings")
        # NOTE(review): dark_mode is stored, but nothing visible here reads it.
        st.session_state.dark_mode = st.toggle("Dark Mode", st.session_state.dark_mode)
        st.markdown("### 📖 Analysis History")
        for item in st.session_state.history[-5:]:
            st.markdown(f"- {item}")


def _init_resources() -> bool:
    """Make MODEL, PROCESSOR and RAG_SYSTEM available.

    Returns:
        ``True`` when everything is ready. ``False`` after showing an error
        if the model fails to load, the knowledge base is invalid, or the
        RAG system cannot be built.
    """
    global MODEL, PROCESSOR, RAG_SYSTEM
    if MODEL is None or PROCESSOR is None:
        MODEL, PROCESSOR = load_model()
        if MODEL is None:
            st.error("Failed to load model. Please refresh the page.")
            return False

    # Validate the knowledge base before building anything from it.
    if not validate_knowledge_base():
        st.error("Knowledge base validation failed. Please check the logs.")
        return False

    if RAG_SYSTEM is None:
        # The global is assigned only after a successful build, so a failure
        # never leaves a half-initialised object behind.
        try:
            RAG_SYSTEM = _load_rag_system()
        except Exception as e:
            logger.error("Error initializing RAG system", error=str(e), exc_info=True)
            st.error("Failed to initialize the knowledge base. Please refresh the page.")
            return False
    return True


def _process_upload(uploaded_file) -> None:
    """Preview a validated upload and show its analysis.

    Errors are logged and reported in the UI instead of propagating.
    """
    try:
        image = preprocess_image(uploaded_file)
        if image is None:
            return  # preprocess_image has already reported the error

        col1, col2 = st.columns([1, 1])
        with col1:
            st.image(image, caption="Uploaded Structure", use_column_width=True)
        with col2:
            start_time = time.perf_counter()
            predictions = analyze_damage(image, MODEL, PROCESSOR)
            if predictions is None:
                return
            analysis_time = time.perf_counter() - start_time
            display_analysis_results(predictions, analysis_time)

            # Every widget interaction reruns the script and re-analyses the
            # still-uploaded file. Record history and count only once per
            # distinct upload, identified here by (name, size).
            analysis_key = (uploaded_file.name, uploaded_file.size)
            if st.session_state.get('last_analysis_key') != analysis_key:
                st.session_state.last_analysis_key = analysis_key
                st.session_state.history.append(f"Analyzed {uploaded_file.name}")
                st.session_state.analysis_count += 1
    except Exception as e:
        logger.error("Error in main processing loop", error=str(e), exc_info=True)
        cleanup_memory()
        st.error("An error occurred during processing. Please try again.")


def _render_footer() -> None:
    """Render the footer with the professional-use disclaimer."""
    st.markdown("---")
    st.markdown(
        "<div style='text-align: center'>"
        "<p>🏗️ Structural Damage Analyzer Pro | Built with Streamlit & Transformers</p>"
        "<p style='font-size: 0.8rem;'>For professional use only. Always consult with a structural engineer.</p>"
        "</div>",
        unsafe_allow_html=True,
    )


def main():
    """Main application function.

    Steps, in order:

    1. Configure the page and initialise session state.
    2. Render the header and sidebar.
    3. Ensure the model and RAG system are available.
    4. Analyse an uploaded image, if any.
    5. Render the footer. This is skipped only when resources fail to
       initialise.
    """
    # set_page_config must be the first Streamlit call of the run.
    st.set_page_config(
        page_title="Structural Damage Analyzer Pro",
        page_icon="🏗️",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    init_session_state()
    _render_header()
    _render_sidebar()

    if not _init_resources():
        return

    # File upload
    uploaded_file = st.file_uploader(
        "Upload an image for analysis",
        type=['jpg', 'jpeg', 'png'],
        help="Supported formats: JPG, JPEG, PNG"
    )
    if uploaded_file and validate_upload(uploaded_file):
        _process_upload(uploaded_file)

    _render_footer()


if __name__ == "__main__":
    main()