smart / app.py
Shakir60's picture
Update app.py
d9a6878 verified
raw
history blame
5.6 kB
import streamlit as st
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
import time
import gc
import logging
from knowledge_base import KNOWLEDGE_BASE, DAMAGE_TYPES
from rag_utils import RAGSystem
# Logging: timestamped, level-tagged records at INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Upload limits.
MAX_FILE_SIZE = 5 * 1024 ** 2  # bytes (5 MB) accepted from the uploader
MAX_IMAGE_SIZE = 1024  # longest image side, in pixels, before downscaling

# Module-level handles for the model, its processor and the RAG system.
# They start empty and are filled in by the loader/initialiser functions.
MODEL = PROCESSOR = RAG_SYSTEM = None
# Cleanup function for memory
def cleanup_memory():
    """Run the garbage collector and, when CUDA is present, empty its cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# Model loading (cached across Streamlit reruns)
@st.cache_resource(show_spinner="Loading AI model...")
def load_model():
    """Load the ViT classifier and its image processor.

    The model is moved to CUDA when available (CPU otherwise) and switched
    to eval mode.

    Returns:
        ``(model, processor)`` on success, ``(None, None)`` on any failure
        (the failure is logged and shown to the user).

    NOTE(review): the head is sized with ``num_labels=len(DAMAGE_TYPES)`` and
    ``ignore_mismatched_sizes=True``; if that differs from the checkpoint's
    label count the head is presumably freshly initialised rather than
    trained - confirm that fine-tuned weights are not expected here.
    """
    checkpoint = "google/vit-base-patch16-224"
    try:
        image_processor = ViTImageProcessor.from_pretrained(checkpoint)

        target_device = "cpu"
        if torch.cuda.is_available():
            target_device = "cuda"

        classifier = ViTForImageClassification.from_pretrained(
            checkpoint,
            num_labels=len(DAMAGE_TYPES),
            ignore_mismatched_sizes=True,
        )
        classifier = classifier.to(target_device)
        classifier.eval()
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Failed to load model: {str(e)}")
        st.error("Error loading model. Please restart the app.")
        return None, None
    return classifier, image_processor
# Initialize RAG system
@st.cache_resource
def _build_rag_system():
    """Create the RAG system and load the knowledge base into it.

    Cached by Streamlit, so the body runs only on a cache miss. Exceptions
    propagate to the caller and are therefore not cached.
    """
    rag = RAGSystem()
    rag.initialize_knowledge_base(KNOWLEDGE_BASE)
    logging.info("RAG system initialized successfully.")
    return rag


def init_rag_system():
    """Bind the module-level ``RAG_SYSTEM`` to the cached RAG instance.

    This wrapper is deliberately NOT cached. The script is re-executed on
    every rerun, which resets ``RAG_SYSTEM`` to ``None``; a cached function
    that only assigned the global as a side effect would be skipped on a
    cache hit and leave it unset. Here the cached builder returns the
    instance and the global is rebound on every call.

    Returns:
        The RAG system instance, or ``None`` if initialisation failed (the
        failure is logged and shown to the user).
    """
    global RAG_SYSTEM
    try:
        RAG_SYSTEM = _build_rag_system()
    except Exception as e:
        RAG_SYSTEM = None
        logging.error(f"Failed to initialize RAG system: {str(e)}")
        st.error("Error initializing knowledge base.")
    return RAG_SYSTEM
# Image validation
def validate_image(image, max_pixels=1024 * 1024):
    """Show advisory warnings about an image's pixel count and source format.

    This function only warns; it does not modify or resize the image.

    Args:
        image: An object exposing ``size`` (width, height) and ``format``,
            or ``None`` (e.g. when preprocessing failed), in which case
            nothing is checked.
        max_pixels: Pixel-count threshold above which the size warning is
            shown. Defaults to the previous hard-coded 1024 * 1024.
    """
    if image is None:
        return

    width, height = image.size
    if width * height > max_pixels:
        st.warning("Large image detected. Resizing for better performance.")

    # A format of None means the source format is unknown (Pillow sets it to
    # None on images produced by its own methods, such as resize), so there
    # is nothing to warn about in that case.
    source_format = image.format
    if source_format is not None and source_format not in ('JPEG', 'PNG'):
        st.warning("Non-optimal image format. Use JPEG or PNG.")
# Image preprocessing
def preprocess_image(uploaded_file):
    """Open an uploaded image and downscale it to fit ``MAX_IMAGE_SIZE``.

    Args:
        uploaded_file: File-like object accepted by ``Image.open``.

    Returns:
        The image, resized (aspect ratio preserved) when its longest side
        exceeds ``MAX_IMAGE_SIZE``; ``None`` if it cannot be opened, decoded
        or resized (the failure is logged and shown to the user).
    """
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy: force decoding now so a corrupt or truncated
        # file fails inside this try block instead of later in the pipeline.
        image.load()

        longest_side = max(image.size)
        if longest_side > MAX_IMAGE_SIZE:
            ratio = MAX_IMAGE_SIZE / longest_side
            # Clamp to 1 px so extreme aspect ratios never yield a 0-sized side.
            new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
            source_format = image.format
            image = image.resize(new_size, Image.Resampling.LANCZOS)
            # resize() returns a new image whose format is None; carry the
            # source format over so later format checks still see it.
            image.format = source_format
        return image
    except Exception as e:
        logging.error(f"Error processing image: {str(e)}")
        st.error("Image processing error.")
        return None
# Damage analysis
def analyze_damage(image, model, processor):
    """Classify an image and return the softmax scores of the first batch item.

    Args:
        image: Image object providing ``convert('RGB')``.
        model: Model exposing ``parameters()`` and returning an object with
            ``logits`` when called with the processor's tensors.
        processor: Callable taking ``images=`` and ``return_tensors=`` and
            returning a mapping of tensors.

    Returns:
        A CPU tensor of probabilities, or ``None`` if anything fails (the
        failure is logged and shown to the user).
    """
    try:
        # Run on whatever device the model's weights live on.
        target_device = next(model.parameters()).device
        with torch.no_grad():
            rgb_image = image.convert('RGB')
            encoded = processor(images=rgb_image, return_tensors="pt")
            batch = {
                name: tensor.to(target_device)
                for name, tensor in encoded.items()
            }
            logits = model(**batch).logits
            probs = torch.softmax(logits, dim=1)[0]
        cleanup_memory()
        return probs.cpu()
    except Exception as e:
        logging.error(f"Error analyzing image: {str(e)}")
        st.error("Image analysis failed.")
        return None
# Display enhanced analysis
def display_enhanced_analysis(damage_type, confidence):
    """Render the RAG-backed analysis for one damage type.

    Args:
        damage_type: Damage type name passed to the RAG system.
        confidence: Confidence value passed to the RAG system.

    The RAG result is read through the keys ``technical_details``,
    ``safety_considerations`` and ``expert_recommendations``, each of which
    is iterated. Any failure is logged and shown as a single error message.
    """
    if RAG_SYSTEM is None:
        # Fail with an explicit reason instead of an opaque AttributeError.
        logging.error("Failed to generate enhanced analysis: RAG system is not initialized")
        st.error("Error generating enhanced analysis.")
        return

    try:
        enhanced_info = RAG_SYSTEM.get_enhanced_analysis(damage_type, confidence)
        # Read every section up front so a malformed result fails before
        # anything is rendered, rather than leaving a half-drawn page.
        technical_details = enhanced_info["technical_details"]
        safety_considerations = enhanced_info["safety_considerations"]
        expert_recommendations = enhanced_info["expert_recommendations"]

        # The emoji below replace mojibake (UTF-8 bytes decoded with the
        # wrong codec) found in the original literals.
        st.markdown("### 🔍 Enhanced Analysis")
        with st.expander("📚 Technical Details", expanded=True):
            for detail in technical_details:
                st.markdown(detail)
        with st.expander("⚠️ Safety Considerations"):
            for safety in safety_considerations:
                st.warning(safety)
        with st.expander("👷 Expert Recommendations"):
            for rec in expert_recommendations:
                st.info(rec)
    except Exception as e:
        logging.error(f"Failed to generate enhanced analysis: {str(e)}")
        st.error("Error generating enhanced analysis.")
# Main function
def main():
    """Streamlit entry point: upload an image, classify it, show the analysis.

    Flow: configure the page, load the model and the RAG system, accept an
    upload (rejected above ``MAX_FILE_SIZE``), preprocess and display the
    image, run the classifier, then render the enhanced analysis for the
    highest-scoring damage type.
    """
    # The icon and title emoji replace mojibake in the original literals.
    st.set_page_config(
        page_title="Structural Damage Analyzer Pro",
        page_icon="🏗️",
        layout="wide"
    )
    st.title("🏗️ Structural Damage Analyzer Pro")

    # Load model and initialize RAG system
    global MODEL, PROCESSOR
    if MODEL is None or PROCESSOR is None:
        MODEL, PROCESSOR = load_model()
    init_rag_system()

    uploaded_file = st.file_uploader(
        "Upload an image for analysis (JPG, PNG)",
        type=['jpg', 'jpeg', 'png']
    )
    if not uploaded_file:
        return

    if uploaded_file.size > MAX_FILE_SIZE:
        st.error("File too large. Limit: 5MB.")
        return

    image = preprocess_image(uploaded_file)
    if image is None:
        # preprocess_image already reported the failure; validating or
        # displaying None would only raise.
        return

    validate_image(image)
    # NOTE(review): newer Streamlit releases reportedly deprecate
    # use_column_width; left unchanged to match the installed version - confirm.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Analyzing damage..."):
        start_time = time.time()
        predictions = analyze_damage(image, MODEL, PROCESSOR)
        analysis_time = time.time() - start_time

    if predictions is None:
        return

    st.markdown(f"*Analysis completed in {analysis_time:.2f} seconds*")
    # Report the highest-scoring class, not unconditionally class 0.
    # Assumes DAMAGE_TYPES can be indexed by class index 0..len-1 (the model
    # head is sized with len(DAMAGE_TYPES)) - TODO confirm.
    top_index = int(torch.argmax(predictions))
    confidence = float(predictions[top_index]) * 100
    display_enhanced_analysis(DAMAGE_TYPES[top_index]['name'], confidence)
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()