File size: 5,599 Bytes
15526fe
3db0aec
 
9814e5f
3db0aec
 
d9a6878
3db0aec
 
15526fe
d9a6878
 
 
3db0aec
 
 
 
 
 
 
 
 
d9a6878
3db0aec
 
 
 
 
 
d9a6878
3db0aec
 
 
 
 
7030681
 
d9a6878
3db0aec
 
 
 
7030681
d9a6878
7030681
d9a6878
3db0aec
 
d9a6878
 
3db0aec
 
d9a6878
 
3db0aec
 
d9a6878
3db0aec
 
d9a6878
 
 
 
3db0aec
d9a6878
3db0aec
 
d9a6878
3db0aec
d9a6878
3db0aec
d9a6878
3db0aec
 
 
 
 
 
 
 
 
d9a6878
 
dbd2162
60dab82
d9a6878
3db0aec
 
7030681
3db0aec
 
 
7030681
3db0aec
 
 
d9a6878
 
 
 
3db0aec
15526fe
d9a6878
3db0aec
 
 
 
dbd2162
3db0aec
 
 
d9a6878
3db0aec
 
 
d9a6878
3db0aec
 
 
 
d9a6878
 
15526fe
d9a6878
9814e5f
3db0aec
 
 
d9a6878
3db0aec
d9a6878
3db0aec
 
 
 
d9a6878
3db0aec
 
 
d9a6878
 
dbd2162
3db0aec
 
d9a6878
 
 
3db0aec
d9a6878
 
3db0aec
d9a6878
 
 
 
 
 
 
 
 
 
 
9814e5f
 
d9a6878
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import streamlit as st
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
import time
import gc
import logging
from knowledge_base import KNOWLEDGE_BASE, DAMAGE_TYPES
from rag_utils import RAGSystem

# Logging: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Upload / image limits.
MAX_FILE_SIZE = 5 * 1024 ** 2  # maximum accepted upload size in bytes (5 MB)
MAX_IMAGE_SIZE = 1024  # longest image side, in pixels, kept by preprocess_image

# Module-level handles, all starting out unset. Streamlit re-runs this script
# top to bottom on every interaction, so these are reset to None on each run;
# persistence across runs comes from the st.cache_resource-decorated loaders.
MODEL = PROCESSOR = RAG_SYSTEM = None

# Memory housekeeping
def cleanup_memory():
    """Force a garbage-collection pass, then empty PyTorch's CUDA cache when CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

# Model loading
@st.cache_resource(show_spinner="Loading AI model...")
def _load_model_cached():
    """Load the ViT processor and model once per server process.

    Exceptions are deliberately NOT caught here: ``st.cache_resource`` does
    not store a result when the function raises, so a failed load (e.g. a
    transient download error) is retried on the next rerun instead of being
    cached for every session.

    NOTE(review): the checkpoint's classifier head is replaced because
    ``num_labels=len(DAMAGE_TYPES)`` is combined with
    ``ignore_mismatched_sizes=True``. When the label count differs from the
    checkpoint's, transformers initialises the new head randomly, so the
    predictions are not meaningful until fine-tuned weights are loaded.
    Confirm whether a fine-tuned checkpoint is intended here.

    Returns:
        tuple: ``(model, processor)``, with the model on CUDA when available
        (otherwise CPU) and in evaluation mode.
    """
    model_name = "google/vit-base-patch16-224"
    processor = ViTImageProcessor.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = ViTForImageClassification.from_pretrained(
        model_name,
        num_labels=len(DAMAGE_TYPES),
        ignore_mismatched_sizes=True,
    ).to(device)

    model.eval()  # evaluation mode for inference
    logging.info("Model loaded successfully.")
    return model, processor


def load_model():
    """Return the cached ``(model, processor)`` pair, or ``(None, None)`` on failure.

    On failure the error is logged (with traceback) and shown in the UI.
    Because the cached loader raises rather than returning a sentinel, the
    failure itself is not cached and a later rerun retries the load.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        logging.exception("Failed to load model: %s", e)
        st.error("Error loading model. Please restart the app.")
        return None, None

# Initialize RAG system
@st.cache_resource
def _build_rag_system():
    """Construct the RAG system and load KNOWLEDGE_BASE into it, once per server process.

    Exceptions propagate so that ``st.cache_resource`` does not cache a
    failed build; the next rerun retries it.
    """
    rag_system = RAGSystem()
    rag_system.initialize_knowledge_base(KNOWLEDGE_BASE)
    logging.info("RAG system initialized successfully.")
    return rag_system


def init_rag_system():
    """Bind the module-level ``RAG_SYSTEM`` to the cached RAG instance and return it.

    This function is intentionally not cached itself. Streamlit re-runs the
    script top to bottom on every interaction, which resets ``RAG_SYSTEM`` to
    None, and a cached function's body is skipped on cache hits. If the
    global were assigned inside a cached function, it would stay None after
    the first run. Caching only the builder and assigning the global here on
    every call keeps ``RAG_SYSTEM`` populated on every rerun.

    Returns:
        The RAGSystem instance, or None if initialization failed (the error
        is logged and shown in the UI).
    """
    global RAG_SYSTEM
    try:
        RAG_SYSTEM = _build_rag_system()
    except Exception as e:
        logging.exception("Failed to initialize RAG system: %s", e)
        st.error("Error initializing knowledge base.")
        # Do not leave a partially-initialized instance behind.
        RAG_SYSTEM = None
    return RAG_SYSTEM

# Image validation
def validate_image(image):
    """Show non-fatal Streamlit warnings about an image's size and format.

    Warns when the pixel count (width * height) exceeds 1024 * 1024, and when
    ``image.format`` is anything other than 'JPEG' or 'PNG' (None included).
    Returns nothing; it never rejects the image.
    """
    width, height = image.size[0], image.size[1]
    pixel_count = width * height
    if pixel_count > 1024 * 1024:
        st.warning("Large image detected. Resizing for better performance.")

    preferred_formats = ('JPEG', 'PNG')
    if image.format not in preferred_formats:
        st.warning("Non-optimal image format. Use JPEG or PNG.")

# Image preprocessing
def preprocess_image(uploaded_file):
    """Open an uploaded image and downscale it so its longest side is at most MAX_IMAGE_SIZE.

    The aspect ratio is preserved. Images already within the limit are
    returned as opened (never upscaled).

    Args:
        uploaded_file: a file-like object accepted by ``PIL.Image.open``.

    Returns:
        The PIL image, or None if it could not be opened or resized (the
        error is logged and shown in the UI).
    """
    try:
        image = Image.open(uploaded_file)
        longest_side = max(image.size)
        if longest_side > MAX_IMAGE_SIZE:
            source_format = image.format
            ratio = MAX_IMAGE_SIZE / longest_side
            # max(1, ...) stops very thin images from rounding a side to 0,
            # which PIL rejects ("height and width must be > 0").
            new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)
            # Image.resize returns an image whose .format is None. Restore it
            # so later format checks (validate_image) see the real upload
            # format instead of warning about every resized image.
            image.format = source_format
        return image
    except Exception as e:
        logging.error(f"Error processing image: {str(e)}")
        st.error("Image processing error.")
        return None

# Damage analysis
def analyze_damage(image, model, processor):
    """Run *image* through *model* and return its class probabilities.

    The image is converted to RGB, preprocessed by *processor* into PyTorch
    tensors, moved to the model's device and classified without gradient
    tracking. The logits are softmax-normalised along dim 1, and the first
    row of the batch is returned.

    Returns:
        A CPU tensor of per-class probabilities, or None on any failure (the
        error is logged and shown in the UI).
    """
    try:
        target_device = next(model.parameters()).device
        with torch.no_grad():
            rgb_image = image.convert('RGB')
            batch = processor(images=rgb_image, return_tensors="pt")
            batch = {name: tensor.to(target_device) for name, tensor in batch.items()}
            logits = model(**batch).logits
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
            cleanup_memory()
            return probabilities.cpu()
    except Exception as e:
        logging.error(f"Error analyzing image: {str(e)}")
        st.error("Image analysis failed.")
        return None

# Display enhanced analysis
def display_enhanced_analysis(damage_type, confidence):
    """Render the RAG-generated analysis for *damage_type* in three expanders.

    Args:
        damage_type: name of the predicted damage class, passed to the RAG system.
        confidence: prediction confidence passed through to the RAG system
            (main() passes a percentage).

    Reads the module-level ``RAG_SYSTEM``. ``get_enhanced_analysis`` is
    expected to return a mapping with iterable "technical_details",
    "safety_considerations" and "expert_recommendations" entries. Failures
    are logged and shown as a Streamlit error rather than propagated.
    """
    if RAG_SYSTEM is None:
        # The knowledge base failed to initialise or was never initialised.
        # Report that directly instead of via an AttributeError on None.
        logging.error("Enhanced analysis skipped: RAG system is not initialized.")
        st.error("Error generating enhanced analysis.")
        return

    try:
        enhanced_info = RAG_SYSTEM.get_enhanced_analysis(damage_type, confidence)
        st.markdown("### 🔍 Enhanced Analysis")

        with st.expander("📚 Technical Details", expanded=True):
            for detail in enhanced_info["technical_details"]:
                st.markdown(detail)

        with st.expander("⚠️ Safety Considerations"):
            for safety in enhanced_info["safety_considerations"]:
                st.warning(safety)

        with st.expander("👷 Expert Recommendations"):
            for rec in enhanced_info["expert_recommendations"]:
                st.info(rec)
    except Exception as e:
        logging.exception("Failed to generate enhanced analysis: %s", e)
        st.error("Error generating enhanced analysis.")

# Main function
def main():
    """Render the analyzer page and process a single uploaded image.

    Steps:
      1. Load the model and the RAG knowledge base (both cached).
      2. Accept a JPG/PNG upload up to MAX_FILE_SIZE bytes.
      3. Downscale and validate the image.
      4. Classify it.
      5. Show the enhanced analysis for the most probable damage type.
    """
    st.set_page_config(
        page_title="Structural Damage Analyzer Pro",
        page_icon="🏗️",
        layout="wide"
    )
    st.title("🏗️ Structural Damage Analyzer Pro")

    # Load model and initialize RAG system
    global MODEL, PROCESSOR
    if MODEL is None or PROCESSOR is None:
        MODEL, PROCESSOR = load_model()
    init_rag_system()

    if MODEL is None or PROCESSOR is None:
        # load_model has already reported the failure; nothing can be analysed.
        return

    uploaded_file = st.file_uploader(
        "Upload an image for analysis (JPG, PNG)",
        type=['jpg', 'jpeg', 'png']
    )
    if uploaded_file is None:
        return

    if uploaded_file.size > MAX_FILE_SIZE:
        st.error("File too large. Limit: 5MB.")
        return

    image = preprocess_image(uploaded_file)
    if image is None:
        # preprocess_image has already shown an error; validate_image would
        # crash on None.
        return
    validate_image(image)

    st.image(image, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Analyzing damage..."):
        start_time = time.perf_counter()
        predictions = analyze_damage(image, MODEL, PROCESSOR)
        analysis_time = time.perf_counter() - start_time

        if predictions is not None:
            st.markdown(f"*Analysis completed in {analysis_time:.2f} seconds*")
            # Report the most probable class. The previous code always used
            # class 0 and its probability, ignoring the model's prediction.
            predicted_idx = int(torch.argmax(predictions))
            confidence = float(predictions[predicted_idx]) * 100
            display_enhanced_analysis(DAMAGE_TYPES[predicted_idx]['name'], confidence)

# Entry point. `streamlit run` executes this file with __name__ == "__main__",
# so main() runs on every (re)run of the script.
if __name__ == "__main__":
    main()