smart / app.py
Shakir60's picture
Update app.py
3db0aec verified
raw
history blame
11.4 kB
```python
import streamlit as st
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
import time
import gc
from knowledge_base import KNOWLEDGE_BASE, DAMAGE_TYPES
from rag_utils import RAGSystem
# ---------------------------------------------------------------------------
# Limits
# ---------------------------------------------------------------------------
MAX_FILE_SIZE = 5 * 2 ** 20  # largest accepted upload, in bytes (5 MB)
MAX_IMAGE_SIZE = 1024  # longest allowed image side in pixels before downscaling

# ---------------------------------------------------------------------------
# Process-wide caches, filled lazily (MODEL/PROCESSOR in main(),
# RAG_SYSTEM in init_rag_system())
# ---------------------------------------------------------------------------
MODEL = PROCESSOR = RAG_SYSTEM = None
def cleanup_memory():
    """Force a garbage-collection pass and, if CUDA exists, drop cached GPU blocks."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def init_session_state():
    """Seed per-session defaults, leaving values from earlier reruns untouched."""
    # A fresh list is built on every call, so sessions never share one history list.
    session_defaults = (("history", []), ("dark_mode", False))
    for state_key, default_value in session_defaults:
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value
@st.cache_resource(show_spinner="Loading AI model...")
def _load_model_cached():
    """Load the ViT model and processor; memoised by Streamlit only on success.

    Exceptions are deliberately allowed to propagate: ``st.cache_resource``
    stores return values, so a failure raised here is not cached and the next
    call retries the load.

    Returns:
        tuple: ``(model, processor)``.
    """
    model_name = "google/vit-base-patch16-224"
    # NOTE(review): this checkpoint ships its own classification head. With
    # ``num_labels`` set and ``ignore_mismatched_sizes=True``, transformers
    # replaces that head with a freshly initialised one whenever the label
    # count differs, so the damage predictions are untrained until fine-tuned
    # weights are loaded instead. Confirm this is intended.
    model = ViTForImageClassification.from_pretrained(
        model_name,
        num_labels=len(DAMAGE_TYPES),
        ignore_mismatched_sizes=True,
        device_map="auto",  # may place the model on a GPU; needs `accelerate`
    )
    processor = ViTImageProcessor.from_pretrained(model_name)
    return model, processor


def load_model():
    """Return the cached ``(model, processor)`` pair, or ``(None, None)`` on failure.

    Any loading error is reported with ``st.error``. Because the failure is not
    cached (see ``_load_model_cached``), a later rerun or page refresh retries.

    Returns:
        tuple: ``(model, processor)`` on success, ``(None, None)`` otherwise.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None


# Keep the Streamlit cache-clearing handle reachable through the public name.
load_model.clear = _load_model_cached.clear
def init_rag_system():
    """Create the process-wide RAG system once and load the knowledge base into it.

    The ``RAG_SYSTEM`` global is assigned only after
    ``initialize_knowledge_base`` returns. If initialisation raises, the global
    stays ``None`` and the next call retries, rather than leaving a
    half-initialised system cached.
    """
    global RAG_SYSTEM
    if RAG_SYSTEM is not None:
        return
    rag_system = RAGSystem()
    rag_system.initialize_knowledge_base(KNOWLEDGE_BASE)
    RAG_SYSTEM = rag_system
def validate_image(image):
    """Warn (without blocking) about images that get downscaled or are not JPEG/PNG.

    Args:
        image: PIL image to inspect. Only ``size`` and ``format`` are read.
    """
    # Same rule preprocess_image() uses to decide on resizing, so the warning
    # matches what actually happens (the old area test disagreed with it).
    if max(image.size) > MAX_IMAGE_SIZE:
        st.warning("Large image detected. The image will be resized for better performance.")
    # PIL sets ``format`` to None for images created in memory (e.g. by
    # ``resize``). There is no source format to judge then, so don't warn.
    if image.format is not None and image.format not in ('JPEG', 'PNG'):
        st.warning("Image format not optimal. Consider using JPEG or PNG for better performance.")
def preprocess_image(uploaded_file):
    """Open an uploaded image and shrink it so its longest side is at most MAX_IMAGE_SIZE.

    Args:
        uploaded_file: File-like object holding the image bytes (e.g. a
            Streamlit ``UploadedFile``).

    Returns:
        The (possibly resized) PIL image, or ``None`` if it could not be opened
        or decoded. The error is reported with ``st.error``.
    """
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy. Decode now so corrupt or truncated files are
        # reported here, and the image no longer depends on the file object.
        image.load()
        longest_side = max(image.size)
        if longest_side > MAX_IMAGE_SIZE:
            ratio = MAX_IMAGE_SIZE / longest_side
            # Clamp to 1 px: very thin images would otherwise get a zero-sized
            # dimension, which makes resize() fail.
            new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)
        return image
    except Exception as e:
        st.error(f"Error processing image: {str(e)}")
        return None
def analyze_damage(image, model, processor):
    """Run the classifier on one image and return its class probabilities.

    Args:
        image: PIL image; converted to RGB before preprocessing.
        model: Classifier whose output exposes ``logits`` (a transformers model).
        processor: Image processor that turns the image into model inputs.

    Returns:
        A 1-D CPU tensor of softmax probabilities for the single image, or
        ``None`` if a ``RuntimeError`` (including CUDA out-of-memory) occurred.
        The error is shown with ``st.error``. Other exceptions propagate.
    """
    try:
        with torch.no_grad():
            rgb_image = image.convert('RGB')
            batch = processor(images=rgb_image, return_tensors="pt")
            # The model may live on a GPU (it is loaded with device_map="auto").
            # Inputs must be on the same device or the forward pass raises.
            batch = {name: tensor.to(model.device) for name, tensor in batch.items()}
            outputs = model(**batch)
            # Batch of one: take row 0 and hand back a CPU tensor for display code.
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu()
        cleanup_memory()
        return probs
    except RuntimeError as e:
        if "out of memory" in str(e):
            cleanup_memory()
            st.error("Out of memory. Please try with a smaller image.")
        else:
            st.error(f"Error analyzing image: {str(e)}")
        return None
def get_custom_css():
    """Return the app's custom CSS wrapped in a ``<style>`` element.

    Intended to be injected with ``st.markdown(..., unsafe_allow_html=True)``.
    The rules style the main container, progress bars, and ``.damage-card`` /
    ``.damage-header`` elements.

    NOTE(review): the ``.dark-mode`` rules only match elements that carry the
    ``dark-mode`` class. No code in this file visibly adds that class, so
    confirm these rules are reachable.
    """
    # The returned text is emitted verbatim into the page; edit it as CSS.
    return """
<style>
.main {
padding: 2rem;
}
.stProgress > div > div > div > div {
background-image: linear-gradient(to right, var(--progress-color, #ff6b6b), var(--progress-color-end, #f06595));
}
.damage-card {
padding: 1.5rem;
border-radius: 0.5rem;
background: var(--card-bg, #f8f9fa);
margin-bottom: 1rem;
border: 1px solid var(--border-color, #dee2e6);
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.damage-header {
font-size: 1.25rem;
font-weight: bold;
margin-bottom: 1rem;
color: var(--text-color, #212529);
}
.dark-mode {
background-color: #1a1a1a;
color: #ffffff;
}
.dark-mode .damage-card {
background: #2d2d2d;
border-color: #404040;
}
</style>
"""
def display_header():
    """Render the centred application title and tagline as raw HTML."""
    # The title emoji had been mis-encoded (mojibake); restored to U+1F3D7 U+FE0F.
    banner_html = """
<div style='text-align: center; padding: 1rem;'>
<h1>🏗️ Structural Damage Analyzer Pro</h1>
<p style='font-size: 1.2rem;'>Advanced AI-powered structural damage assessment tool</p>
</div>
"""
    st.markdown(banner_html, unsafe_allow_html=True)
def display_enhanced_analysis(damage_type, confidence):
    """Render RAG-backed details for one damage type, plus a follow-up question box.

    Args:
        damage_type: Damage-type name forwarded to ``RAG_SYSTEM.get_enhanced_analysis``.
        confidence: Model confidence for this type. display_analysis_results
            passes it as a percentage (0-100).

    Any exception raised while building the view (including an uninitialised
    ``RAG_SYSTEM``) is reported with ``st.error``.
    """
    try:
        enhanced_info = RAG_SYSTEM.get_enhanced_analysis(damage_type, confidence)
        st.markdown("### 🔍 Enhanced Analysis")
        with st.expander("📚 Technical Details", expanded=True):
            for detail in enhanced_info["technical_details"]:
                st.markdown(detail)
        with st.expander("⚠️ Safety Considerations"):
            for safety in enhanced_info["safety_considerations"]:
                st.warning(safety)
        with st.expander("👷 Expert Recommendations"):
            for rec in enhanced_info["expert_recommendations"]:
                st.info(rec)
        # This function runs once per detected damage type. Without a distinct
        # key, two identical text_input widgets would trigger Streamlit's
        # duplicate-widget error.
        custom_query = st.text_input(
            "Ask specific questions about this damage type:",
            placeholder="E.g., What are the long-term implications of this damage?",
            key=f"custom_query_{damage_type}",
        )
        query = custom_query.strip()
        if query:
            custom_results = RAG_SYSTEM.get_enhanced_analysis(
                damage_type,
                confidence,
                custom_query=query,
            )
            st.markdown("### 💡 Custom Query Results")
            for category, results in custom_results.items():
                if results:
                    st.markdown(f"**{category.replace('_', ' ').title()}:**")
                    for result in results:
                        st.markdown(result)
    except Exception as e:
        st.error(f"Error generating enhanced analysis: {str(e)}")
def _render_damage_tabs(cases):
    """Render the Details / Repairs / Actions tabs for one damage type's knowledge-base cases."""
    details_tab, repairs_tab, actions_tab = st.tabs(["📋 Details", "🔧 Repairs", "⚠️ Actions"])
    # Repair, cost and action guidance comes from the first case only.
    primary = cases[0]
    with details_tab:
        for case in cases:
            st.markdown(
                f"- **Severity:** {case['severity']}\n"
                f"- **Description:** {case['description']}\n"
                f"- **Location:** {case['location']}\n"
                f"- **Required Expertise:** {case['required_expertise']}"
            )
    with repairs_tab:
        for step in primary['repair_method']:
            st.markdown(f"✓ {step}")
        st.info(f"**Estimated Cost:** {primary['estimated_cost']}")
        st.info(f"**Timeframe:** {primary['timeframe']}")
    with actions_tab:
        st.warning("**Immediate Actions Required:**")
        st.markdown(primary['immediate_action'])
        st.success("**Prevention Measures:**")
        st.markdown(primary['prevention'])


def display_analysis_results(predictions, analysis_time, threshold=15):
    """Render every damage type whose confidence exceeds ``threshold`` percent.

    Args:
        predictions: Iterable of per-class probabilities in [0, 1], index-aligned
            with ``DAMAGE_TYPES``.
        analysis_time: Inference duration in seconds, shown to the user.
        threshold: Minimum confidence, in percent, for a type to be reported.
            Defaults to 15, the previously hard-coded cut-off.
    """
    st.markdown("### 📊 Analysis Results")
    st.markdown(f"*Analysis completed in {analysis_time:.2f} seconds*")
    detected = False
    for idx, prob in enumerate(predictions):
        confidence = float(prob) * 100
        if confidence <= threshold:
            continue
        detected = True
        damage_info = DAMAGE_TYPES[idx]
        damage_type = damage_info['name']
        try:
            cases = KNOWLEDGE_BASE[damage_type]
        except KeyError:
            cases = []
        with st.expander(f"{damage_type.replace('_', ' ').title()} - {confidence:.1f}%", expanded=True):
            # NOTE(review): this <style> rule targets every progress bar on the
            # page, so with several detections the last colour applies to all.
            st.markdown(
                f"""
<style>
.stProgress > div > div > div > div {{
background-color: {damage_info['color']} !important;
}}
</style>
""",
                unsafe_allow_html=True,
            )
            # Clamp against float rounding; st.progress rejects values above 1.
            st.progress(min(confidence / 100, 1.0))
            if cases:
                _render_damage_tabs(cases)
            else:
                st.info("No knowledge-base entries are available for this damage type.")
        # Rendered outside the expander above: display_enhanced_analysis creates
        # its own expanders, and Streamlit does not allow nested expanders.
        display_enhanced_analysis(damage_type, confidence)
    if not detected:
        st.info("No significant structural damage detected. Regular maintenance recommended.")
def _render_sidebar():
    """Draw the settings toggle and the five most recent analysis-history entries."""
    with st.sidebar:
        st.markdown("### ⚙️ Settings")
        # NOTE(review): dark_mode is stored, but nothing in this file visibly
        # applies the `.dark-mode` CSS class, so the toggle may have no effect.
        st.session_state.dark_mode = st.toggle("Dark Mode", st.session_state.dark_mode)
        st.markdown("### 📖 Analysis History")
        for item in st.session_state.history[-5:]:
            st.markdown(f"- {item}")


def _ensure_model_loaded():
    """Fill the MODEL/PROCESSOR globals; return False (after reporting) if loading failed."""
    global MODEL, PROCESSOR
    if MODEL is None or PROCESSOR is None:
        with st.spinner("Loading AI model..."):
            MODEL, PROCESSOR = load_model()
    if MODEL is None:
        st.error("Failed to load model. Please refresh the page.")
        return False
    return True


def _process_upload(uploaded_file):
    """Check, display and analyse one uploaded image. Problems are reported in the UI."""
    try:
        if uploaded_file.size > MAX_FILE_SIZE:
            max_mb = MAX_FILE_SIZE // (1024 * 1024)
            st.error(f"File size too large. Please upload an image smaller than {max_mb}MB.")
            return
        image = preprocess_image(uploaded_file)
        if image is None:
            return
        validate_image(image)
        col1, col2 = st.columns([1, 1])
        with col1:
            st.image(image, caption="Uploaded Structure", use_container_width=True)
        with col2:
            with st.spinner("🔍 Analyzing damage..."):
                start_time = time.perf_counter()
                predictions = analyze_damage(image, MODEL, PROCESSOR)
                analysis_time = time.perf_counter() - start_time
            if predictions is not None:
                display_analysis_results(predictions, analysis_time)
                entry = f"Analyzed image: {uploaded_file.name}"
                history = st.session_state.history
                # Streamlit reruns the whole script on every widget interaction
                # (e.g. typing a custom query). Don't record the same image again
                # on each of those reruns.
                if not history or history[-1] != entry:
                    history.append(entry)
    except Exception as e:
        cleanup_memory()
        st.error(f"Error processing image: {str(e)}")
        st.info("Please try uploading a different image.")


def _render_footer():
    """Draw the horizontal rule and the disclaimer footer."""
    st.markdown("---")
    st.markdown(
        """
<div style='text-align: center'>
<p>🏗️ Structural Damage Analyzer Pro | Built with Streamlit & Transformers</p>
<p style='font-size: 0.8rem;'>For professional use only. Always consult with a structural engineer.</p>
</div>
""",
        unsafe_allow_html=True,
    )


def main():
    """Streamlit entry point: page setup, sidebar, model loading, upload analysis, footer."""
    # Streamlit documents set_page_config as the first command on a page.
    st.set_page_config(
        page_title="Structural Damage Analyzer Pro",
        page_icon="🏗️",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    init_session_state()
    st.markdown(get_custom_css(), unsafe_allow_html=True)
    _render_sidebar()
    display_header()
    if _ensure_model_loaded():
        init_rag_system()
        uploaded_file = st.file_uploader(
            "Drag and drop or click to upload an image",
            type=['jpg', 'jpeg', 'png'],
            help="Supported formats: JPG, JPEG, PNG",
        )
        if uploaded_file:
            _process_upload(uploaded_file)
    # Always rendered, including after model-load or upload errors.
    _render_footer()
# `streamlit run` executes this module with __name__ set to '__main__',
# so the app starts both under Streamlit and via `python app.py`.
if __name__ == '__main__':
    main()
```