|
import streamlit as st |
|
from transformers import ViTForImageClassification, ViTImageProcessor |
|
from PIL import Image |
|
import torch |
|
import time |
|
import gc |
|
import logging |
|
from knowledge_base import KNOWLEDGE_BASE, DAMAGE_TYPES |
|
from rag_utils import RAGSystem |
|
|
|
|
|
# Root logging: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Upload limits.
MAX_FILE_SIZE = 5 << 20  # bytes (5 MiB); larger uploads are rejected by main()
MAX_IMAGE_SIZE = 1024  # longest image side, in pixels, before downscaling kicks in

# Module-level handles; MODEL/PROCESSOR are assigned in main(),
# RAG_SYSTEM in init_rag_system().
MODEL = PROCESSOR = RAG_SYSTEM = None
|
|
|
|
|
def cleanup_memory():
    """Force a garbage-collection pass and release cached CUDA memory if a GPU exists."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
|
|
|
|
@st.cache_resource(show_spinner="Loading AI model...")
def _load_model_cached():
    """Load the ViT model and processor; Streamlit caches only a successful result.

    Any exception from ``from_pretrained`` propagates, and Streamlit does not
    cache raised exceptions, so a failed load is retried on the next call.

    Returns:
        (model, processor): the model is moved to CUDA when available and put in eval mode.
    """
    model_name = "google/vit-base-patch16-224"
    processor = ViTImageProcessor.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): ignore_mismatched_sizes lets the checkpoint's classifier
    # head be replaced by a freshly initialised len(DAMAGE_TYPES)-way head
    # (unless len(DAMAGE_TYPES) happens to match the checkpoint). Without
    # fine-tuned weights, the predicted probabilities are not meaningful —
    # confirm whether trained weights are meant to be loaded here.
    model = ViTForImageClassification.from_pretrained(
        model_name,
        num_labels=len(DAMAGE_TYPES),
        ignore_mismatched_sizes=True,
    ).to(device)

    model.eval()
    logging.info("Model loaded successfully.")
    return model, processor


def load_model():
    """Return the cached ``(model, processor)`` pair, or ``(None, None)`` on failure.

    Failures are logged with a traceback and reported in the UI. They are
    deliberately NOT cached: the original version returned ``(None, None)``
    from inside the cached function, which pinned a transient failure (e.g. a
    network error while downloading weights) for the lifetime of the process.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        logging.exception(f"Failed to load model: {str(e)}")
        st.error("Error loading model. Please restart the app.")
        return None, None
|
|
|
|
|
@st.cache_resource
def _build_rag_system():
    """Create a RAGSystem loaded with KNOWLEDGE_BASE; cached only on success.

    Exceptions propagate, so Streamlit does not cache a failed build.
    """
    rag_system = RAGSystem()
    rag_system.initialize_knowledge_base(KNOWLEDGE_BASE)
    logging.info("RAG system initialized successfully.")
    return rag_system


def init_rag_system():
    """Ensure the module-level ``RAG_SYSTEM`` global points at the cached RAG system.

    Streamlit re-executes the whole script on every rerun, which resets
    ``RAG_SYSTEM`` to None. The previous version set the global only as a
    side effect inside a cached function. On later reruns that body was
    skipped, so the global stayed None. Here the built object is cached and
    the global is re-assigned on every call.

    Returns:
        The RAG system, or None if initialization failed (an error is shown).
    """
    global RAG_SYSTEM
    try:
        RAG_SYSTEM = _build_rag_system()
    except Exception as e:
        logging.exception(f"Failed to initialize RAG system: {str(e)}")
        st.error("Error initializing knowledge base.")
    return RAG_SYSTEM
|
|
|
|
|
def validate_image(image, max_pixels=1024 * 1024):
    """Show non-blocking Streamlit warnings about an uploaded image.

    Args:
        image: A PIL image, or None (e.g. when preprocessing failed), in
            which case nothing is checked.
        max_pixels: Pixel count (width * height) above which a size warning
            is shown. Defaults to 1024 * 1024.
    """
    if image is None:
        # Nothing to inspect; the caller's failure path has already reported it.
        return

    width, height = image.size
    if width * height > max_pixels:
        st.warning("Large image detected. Resizing for better performance.")

    # PIL sets ``format`` to None for images produced in memory (for example
    # by ``Image.resize``). The source container format is unknown then, so
    # don't emit a misleading format warning.
    if image.format is not None and image.format not in ('JPEG', 'PNG'):
        st.warning("Non-optimal image format. Use JPEG or PNG.")
|
|
|
|
|
def preprocess_image(uploaded_file):
    """Open an uploaded file as a PIL image, downscaling it if it is too large.

    If the longest side exceeds MAX_IMAGE_SIZE, the image is resized
    (LANCZOS, aspect ratio preserved) so that side equals MAX_IMAGE_SIZE.

    Args:
        uploaded_file: A file-like object accepted by ``PIL.Image.open``.

    Returns:
        The (possibly resized) image, or None after logging and showing an
        error if the file cannot be decoded or resized.
    """
    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force the decode here so corrupt or truncated
        # files fail inside this try block rather than later in the UI.
        image.load()

        longest_side = max(image.size)
        if longest_side > MAX_IMAGE_SIZE:
            ratio = MAX_IMAGE_SIZE / longest_side
            # Clamp to 1 px so very thin images never scale to a zero dimension.
            new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
            source_format = image.format
            image = image.resize(new_size, Image.Resampling.LANCZOS)
            # resize() returns a new image whose ``format`` is None; carry the
            # original over so later format checks still see the real type.
            image.format = source_format

        return image

    except Exception as e:
        logging.exception(f"Error processing image: {str(e)}")
        st.error("Image processing error.")
        return None
|
|
|
|
|
def analyze_damage(image, model, processor):
    """Run the classifier on *image* and return class probabilities on the CPU.

    The image is converted to RGB, encoded by *processor*, and fed to *model*
    on whatever device the model's parameters live on. The logits are
    soft-maxed over dim 1, and the first row (presumably the only one, since a
    single image is encoded) is returned.

    Returns:
        A CPU tensor of probabilities, or None after logging and showing an
        error if any step fails.
    """
    try:
        target_device = next(model.parameters()).device
        with torch.no_grad():
            rgb_image = image.convert('RGB')
            encoded = processor(images=rgb_image, return_tensors="pt")
            batch = {}
            for name, tensor in encoded.items():
                batch[name] = tensor.to(target_device)
            logits = model(**batch).logits
            probabilities = logits.softmax(dim=1)[0]
            cleanup_memory()
            return probabilities.cpu()
    except Exception as exc:
        logging.error(f"Error analyzing image: {exc}")
        st.error("Image analysis failed.")
        return None
|
|
|
|
|
def display_enhanced_analysis(damage_type, confidence):
    """Render the RAG system's enhanced analysis for a predicted damage type.

    Args:
        damage_type: Damage-type name forwarded to
            ``RAG_SYSTEM.get_enhanced_analysis``.
        confidence: Prediction confidence, forwarded unchanged (main() passes
            a percentage, i.e. probability * 100).

    The returned mapping is expected to contain "technical_details",
    "safety_considerations" and "expert_recommendations" lists. Failures are
    logged and reported in the UI rather than raised.
    """
    if RAG_SYSTEM is None:
        # init_rag_system() failed or has not run; there is nothing to query.
        logging.error("Failed to generate enhanced analysis: RAG system is not initialized.")
        st.error("Error generating enhanced analysis.")
        return

    try:
        enhanced_info = RAG_SYSTEM.get_enhanced_analysis(damage_type, confidence)

        # Labels previously contained mojibake (UTF-8 emoji decoded as
        # ISO-8859-7); restored to real characters.
        st.markdown("### 🔍 Enhanced Analysis")

        with st.expander("📋 Technical Details", expanded=True):
            for detail in enhanced_info["technical_details"]:
                st.markdown(detail)

        with st.expander("⚠️ Safety Considerations"):
            for safety in enhanced_info["safety_considerations"]:
                st.warning(safety)

        with st.expander("👷 Expert Recommendations"):
            for rec in enhanced_info["expert_recommendations"]:
                st.info(rec)

    except Exception as e:
        logging.exception(f"Failed to generate enhanced analysis: {str(e)}")
        st.error("Error generating enhanced analysis.")
|
|
|
|
|
def main():
    """Streamlit entry point: load resources, accept an upload, classify it, and show results.

    Flow: page setup -> load model/processor and RAG system -> file upload
    (rejected above MAX_FILE_SIZE) -> preprocess and validate -> classify ->
    report the most probable damage type with its confidence (percent).
    """
    st.set_page_config(
        page_title="Structural Damage Analyzer Pro",
        page_icon="🏗️",
        layout="wide",
    )
    # Icon/title previously contained mojibake (UTF-8 decoded as ISO-8859-7).
    st.title("🏗️ Structural Damage Analyzer Pro")

    global MODEL, PROCESSOR
    if MODEL is None or PROCESSOR is None:
        MODEL, PROCESSOR = load_model()
    init_rag_system()

    if MODEL is None or PROCESSOR is None:
        # Without a model every analysis would fail; stop before offering an upload.
        st.warning("Image analysis is unavailable because the model failed to load.")
        return

    uploaded_file = st.file_uploader(
        "Upload an image for analysis (JPG, PNG)",
        type=['jpg', 'jpeg', 'png'],
    )
    if not uploaded_file:
        return

    if uploaded_file.size > MAX_FILE_SIZE:
        st.error("File too large. Limit: 5MB.")
        return

    image = preprocess_image(uploaded_file)
    if image is None:
        # preprocess_image() has already shown an error; nothing to analyze.
        return
    validate_image(image)

    # NOTE(review): recent Streamlit releases deprecate use_column_width in
    # favour of use_container_width; switch once the minimum supported
    # Streamlit version allows it.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Analyzing damage..."):
        # perf_counter is monotonic and intended for measuring durations.
        start_time = time.perf_counter()
        predictions = analyze_damage(image, MODEL, PROCESSOR)
        analysis_time = time.perf_counter() - start_time

    if predictions is None:
        return

    st.markdown(f"*Analysis completed in {analysis_time:.2f} seconds*")
    # Report the most probable class. The previous code always reported
    # DAMAGE_TYPES[0] with predictions[0], whatever the model predicted.
    best_index = int(predictions.argmax())
    confidence = float(predictions[best_index]) * 100
    display_enhanced_analysis(DAMAGE_TYPES[best_index]['name'], confidence)


if __name__ == "__main__":
    main()
|
|