# app.py — "Update app.py" (commit e7c0374, verified) — Shakir60 — 14.1 kB
# (file-viewer page header: raw / history / blame)
import streamlit as st
import os
from groq import Groq
from transformers import ViTForImageClassification, ViTImageProcessor
from sentence_transformers import SentenceTransformer
from PIL import Image
import torch
import numpy as np
from typing import List, Dict, Tuple, Optional, Any
import faiss
import json
import torchvision.transforms.functional as TF
from torchvision import transforms
import cv2
import pandas as pd
from datetime import datetime
import logging
# Setup logging: configure the root logger at INFO level and create the
# module-level logger used by the classes and functions in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ConfigManager:
    """Manages configuration settings for the application.

    Settings are grouped into sections (``model_settings``,
    ``analysis_settings``, ``rag_settings``). Values found in a JSON config
    file override the defaults key by key, so a file that only sets one key
    of a section still inherits the remaining defaults of that section.
    """

    DEFAULT_CONFIG = {
        "model_settings": {
            "vit_model": "google/vit-base-patch16-224",
            "sentence_transformer": "all-MiniLM-L6-v2",
            "groq_model": "llama-3.3-70b-versatile"
        },
        "analysis_settings": {
            "confidence_threshold": 0.5,
            "max_defects": 3,
            "heatmap_intensity": 0.7
        },
        "rag_settings": {
            "num_relevant_docs": 3,
            "similarity_threshold": 0.75
        }
    }

    @staticmethod
    def _merge(defaults, overrides):
        """Return a new dict with ``overrides`` layered recursively on ``defaults``.

        Nested dicts are merged key by key; a non-dict override supplied for a
        dict-valued default is ignored so a malformed section cannot remove
        keys the rest of the application reads. Neither input is mutated.
        """
        merged = {}
        for key, default_value in defaults.items():
            if isinstance(default_value, dict):
                section = overrides.get(key)
                if not isinstance(section, dict):
                    section = {}
                merged[key] = ConfigManager._merge(default_value, section)
            else:
                merged[key] = overrides.get(key, default_value)
        # Keep keys that only exist in the override (unknown/extra settings).
        for key, value in overrides.items():
            if key not in defaults:
                merged[key] = value
        return merged

    @staticmethod
    def load_config(path='config.json'):
        """Load configuration with fallback to defaults.

        Args:
            path: Location of the JSON config file. Defaults to
                ``config.json`` in the current working directory.

        Returns:
            A fresh configuration dict. It is never the ``DEFAULT_CONFIG``
            object itself, so callers may mutate it without altering the
            defaults. A missing file silently yields the defaults; an
            unreadable or invalid file logs a warning and yields the defaults.
        """
        defaults = ConfigManager._merge(ConfigManager.DEFAULT_CONFIG, {})
        try:
            with open(path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except FileNotFoundError:
            # No config file is the normal case, not an error.
            return defaults
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and decoding errors.
            logger.warning(f"Error loading config: {e}")
            return defaults
        if not isinstance(loaded, dict):
            logger.warning("Error loading config: top-level JSON value must be an object")
            return defaults
        return ConfigManager._merge(ConfigManager.DEFAULT_CONFIG, loaded)
# Module-wide configuration, loaded once at import time; the analyzer, the RAG
# classes and main() all read their settings from this dict.
config = ConfigManager.load_config()
class ImageAnalyzer:
    """Analyse construction images with a ViT classifier plus OpenCV heuristics.

    The analyzer combines three sources of information for every image:
    class probabilities from a ViT image-classification model, an attention
    heatmap derived from the model's last layer, and simple image statistics
    (brightness, Canny edge density, share of rust-coloured pixels).
    Every successful analysis is appended to ``self.history``.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config["model_settings"]
        self.analysis_config = config["analysis_settings"]
        self.defect_classes = [
            "spalling", "reinforcement_corrosion", "structural_cracks",
            "water_damage", "surface_deterioration", "alkali_silica_reaction",
            "concrete_delamination", "honeycomb", "scaling",
            "efflorescence", "joint_deterioration", "carbonation"
        ]
        self.history = []
        # Pre-declare model attributes so a failed load is detectable instead of
        # surfacing later as an AttributeError.
        self.model = None
        self.processor = None
        self.transforms = None
        self.initialize_models()

    def initialize_models(self):
        """Initialize all required models.

        Returns:
            True when the model, processor and transform pipeline were created,
            False when loading failed (the error is logged).
        """
        try:
            # Initialize ViT model.
            # NOTE(review): num_labels is set to the number of defect classes and
            # ignore_mismatched_sizes=True is passed, so the classification head
            # is presumably re-initialised rather than loaded from the checkpoint
            # — confirm a fine-tuned checkpoint is configured before trusting
            # the probabilities.
            self.model = ViTForImageClassification.from_pretrained(
                self.config["vit_model"],
                num_labels=len(self.defect_classes),
                ignore_mismatched_sizes=True
            ).to(self.device)
            self.model.eval()  # inference only: keep dropout disabled
            # Initialize image processor
            self.processor = ViTImageProcessor.from_pretrained(self.config["vit_model"])
            # Initialize transformations pipeline
            self.transforms = self._setup_transforms()
            return True
        except Exception as e:
            logger.error(f"Model initialization error: {e}")
            return False

    def _setup_transforms(self):
        """Setup the deterministic image transformation pipeline used for inference.

        Random augmentations (sharpness / colour jitter) are deliberately not
        part of this pipeline: they belong to training and would make repeated
        analyses of the same image return different probabilities.
        """
        # NOTE(review): the loaded ViTImageProcessor is not used for
        # preprocessing; confirm these mean/std values match the checkpoint.
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def preprocess_image(self, image: "Image.Image") -> Optional[Dict[str, Any]]:
        """Enhanced image preprocessing with multiple analyses.

        Args:
            image: PIL image; converted to RGB when it is in another mode.

        Returns:
            Dict with ``model_input`` (batched tensor on ``self.device``),
            ``stats`` (brightness, size, aspect ratio, edge density, rust
            percentage), ``edges`` (Canny edge map) and ``rust_mask``; or
            ``None`` when any step fails (the error is logged).
        """
        try:
            # Convert to RGB if necessary
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # Basic image statistics
            img_array = np.array(image)
            width, height = image.size
            stats = {
                "mean_brightness": float(np.mean(img_array)),
                "std_brightness": float(np.std(img_array)),
                "size": image.size,
                "aspect_ratio": width / height
            }
            # Edge detection for crack analysis
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)
            stats["edge_density"] = float(np.mean(edges > 0))
            # Color analysis for rust detection: share of pixels whose HSV value
            # falls inside the range below.
            hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
            rust_mask = cv2.inRange(hsv, np.array([0, 50, 50]), np.array([30, 255, 255]))
            stats["rust_percentage"] = float(np.mean(rust_mask > 0))
            # Transform for model
            model_input = self.transforms(image).unsqueeze(0).to(self.device)
            return {
                "model_input": model_input,
                "stats": stats,
                "edges": edges,
                "rust_mask": rust_mask
            }
        except Exception as e:
            logger.error(f"Preprocessing error: {e}")
            return None

    def detect_defects(self, image: "Image.Image") -> Optional[Dict[str, Any]]:
        """Enhanced defect detection with multiple analysis methods.

        Args:
            image: PIL image to analyse.

        Returns:
            Dict with ``defect_probabilities``, ``heatmap`` (or ``None``),
            ``image_statistics``, ``additional_analysis``, ``edge_detection``,
            ``rust_detection`` and an ISO ``timestamp``; or ``None`` when the
            model is unavailable or any step fails (the error is logged).
        """
        try:
            if self.model is None:
                logger.error("Defect detection error: model is not initialized")
                return None
            # Preprocess image
            proc_data = self.preprocess_image(image)
            if proc_data is None:
                logger.error("Image preprocessing failed.")
                return None  # Early return if preprocessing failed
            # Model prediction. Attentions are only returned when explicitly
            # requested; without the flag the field exists but is None.
            with torch.no_grad():
                outputs = self.model(proc_data["model_input"], output_attentions=True)
            # Get probabilities for the single image in the batch
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
            # Convert to dictionary
            defect_probs = {
                name: float(prob)
                for name, prob in zip(self.defect_classes, probabilities)
            }
            # Generate attention heatmap from the last layer, averaged over heads
            attentions = getattr(outputs, "attentions", None)
            attention_weights = attentions[-1].mean(dim=1)[0] if attentions else None
            heatmap = (
                self.generate_heatmap(attention_weights, image.size)
                if attention_weights is not None else None
            )
            # Additional analysis based on image statistics
            additional_analysis = self.analyze_image_statistics(proc_data["stats"])
            # Combine all results
            result = {
                "defect_probabilities": defect_probs,
                "heatmap": heatmap,
                "image_statistics": proc_data["stats"],
                "additional_analysis": additional_analysis,
                "edge_detection": proc_data["edges"],
                "rust_detection": proc_data["rust_mask"],
                "timestamp": datetime.now().isoformat()
            }
            # Save to history
            self.history.append(result)
            return result
        except Exception as e:
            logger.error(f"Defect detection error: {e}")
            return None

    def analyze_image_statistics(self, stats: Dict) -> Dict[str, Any]:
        """Analyze image statistics for additional insights.

        Args:
            stats: Dict providing ``mean_brightness``, ``edge_density`` and
                ``rust_percentage``.

        Returns:
            Dict containing only the findings that apply; empty when none do.
        """
        analysis = {}
        # Brightness analysis
        if stats["mean_brightness"] < 50:
            analysis["lighting_condition"] = "Poor lighting - may affect accuracy"
        elif stats["mean_brightness"] > 200:
            analysis["lighting_condition"] = "Overexposed - may affect accuracy"
        # Edge density analysis
        if stats["edge_density"] > 0.1:
            analysis["crack_likelihood"] = "High crack probability based on edge detection"
        # Rust analysis
        if stats["rust_percentage"] > 0.05:
            analysis["corrosion_indicator"] = "Possible corrosion detected"
        return analysis

    @staticmethod
    def _attention_to_patch_grid(attention: np.ndarray) -> np.ndarray:
        """Reduce a square token-to-token attention matrix to a 2-D patch grid.

        When the matrix is (1 + s*s) x (1 + s*s) the first row is taken as the
        attention from the leading token to the s*s patches and reshaped to
        (s, s). Any other shape is returned unchanged.
        """
        if attention.ndim == 2 and attention.shape[0] == attention.shape[1]:
            num_patches = attention.shape[0] - 1
            side = int(round(num_patches ** 0.5))
            if side > 0 and side * side == num_patches:
                return np.ascontiguousarray(attention[0, 1:].reshape(side, side))
        return attention

    def generate_heatmap(self, attention_weights: torch.Tensor, image_size: Tuple[int, int]) -> Optional[np.ndarray]:
        """Generate enhanced attention heatmap.

        Args:
            attention_weights: Head-averaged attention for one image.
            image_size: Target ``(width, height)`` of the heatmap.

        Returns:
            RGB ``uint8`` colour image of the requested size, or ``None`` when
            no weights are given or generation fails (the error is logged).
        """
        try:
            if attention_weights is None:
                return None
            # Process attention weights
            attention = attention_weights.detach().cpu().numpy().astype(np.float32)
            attention = self._attention_to_patch_grid(attention)
            heatmap = cv2.resize(attention, image_size)
            # Enhanced normalization to [0, 1]; epsilon avoids division by zero
            heatmap = np.maximum(heatmap, 0)
            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
            # Apply gamma correction
            gamma = self.analysis_config["heatmap_intensity"]
            heatmap = np.power(heatmap, gamma)
            # Apply colormap
            heatmap = (heatmap * 255).astype(np.uint8)
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            # OpenCV colour maps are BGR; convert so RGB consumers show the
            # intended colours.
            return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        except Exception as e:
            logger.error(f"Heatmap generation error: {e}")
            return None
class RAGSystem:
    """Basic RAG System for storing and retrieving documents.

    Documents are embedded with a sentence-transformer model and stored in a
    flat L2 FAISS index; ``knowledge_base`` keeps the texts in insertion order
    so an index position maps directly to its document.
    """

    def __init__(self):
        self.embedding_model = SentenceTransformer(config["model_settings"]["sentence_transformer"])
        # Size the index from the model so a differently sized embedding model
        # can be configured; fall back to 384 (MiniLM) if the model does not
        # report a dimension.
        dimension = self.embedding_model.get_sentence_embedding_dimension() or 384
        self.vector_store = faiss.IndexFlatL2(dimension)
        self.knowledge_base = []

    def add_documents(self, docs: List[str]):
        """Add documents to the vector store.

        Args:
            docs: Texts to index. A single string is treated as one document;
                an empty collection is a no-op.
        """
        if isinstance(docs, str):
            docs = [docs]
        docs = list(docs)
        if not docs:
            return
        embeddings = np.asarray(self.embedding_model.encode(docs), dtype='float32')
        self.vector_store.add(embeddings)
        self.knowledge_base.extend({"text": doc} for doc in docs)

    def search(self, query: str, k: int = 3):
        """Retrieve similar documents for the query.

        Args:
            query: Text to search for.
            k: Maximum number of documents to return.

        Returns:
            Up to ``k`` document texts, nearest first; an empty list when
            nothing has been indexed or ``k`` is not positive.
        """
        if not self.knowledge_base or k <= 0:
            return []
        # Asking for more neighbours than stored makes the index pad results
        # with -1, which would otherwise index the last document.
        k = min(k, len(self.knowledge_base))
        query_embedding = np.asarray(self.embedding_model.encode([query]), dtype='float32')
        _, indices = self.vector_store.search(query_embedding, k)
        return [
            self.knowledge_base[i]["text"]
            for i in indices[0]
            if 0 <= i < len(self.knowledge_base)
        ]
class EnhancedRAGSystem(RAGSystem):
    """Enhanced RAG system with additional features.

    Adds a query log and distance-filtered context retrieval on top of the
    base document store.
    """

    def __init__(self):
        super().__init__()
        self.config = config["rag_settings"]
        self.query_history = []

    def get_relevant_context(self, query: str, k: Optional[int] = None) -> str:
        """Enhanced context retrieval with debugging info.

        Args:
            query: Text to search for; every call is recorded in
                ``query_history`` with a timestamp.
            k: Maximum number of documents to consider. Defaults to the
                ``num_relevant_docs`` setting.

        Returns:
            The matching document texts joined by blank lines, or an empty
            string when nothing is indexed or nothing passes the threshold.
        """
        if k is None:
            k = self.config["num_relevant_docs"]
        # Log query
        self.query_history.append({
            "timestamp": datetime.now().isoformat(),
            "query": query
        })
        # Nothing indexed: skip the embedding model and the index entirely.
        if not self.knowledge_base or k <= 0:
            return ""
        k = min(k, len(self.knowledge_base))
        # Generate query embedding
        query_embedding = self.embedding_model.encode([query])
        # Search for similar documents
        distances, indices = self.vector_store.search(
            np.array(query_embedding).astype('float32'), k
        )
        # The "similarity_threshold" setting is applied as an upper bound on the
        # distance reported by the vector store (smaller means closer).
        max_distance = self.config["similarity_threshold"]
        # Filter by threshold; negative indices mark missing neighbours.
        relevant_docs = [
            self.knowledge_base[i]["text"]
            for i, dist in zip(indices[0], distances[0])
            if i >= 0 and dist < max_distance
        ]
        return "\n\n".join(relevant_docs)
def get_groq_response(query: str, context: str) -> str:
    """Answer a user question with the configured Groq chat model.

    main() calls this function but the file did not define it, so the
    question path failed with a NameError. The signature and the string return
    value follow that call site.

    Args:
        query: The user's question.
        context: Retrieved reference text to ground the answer; may be empty.

    Returns:
        The model's answer, or a human-readable message when the API key is
        missing or the request fails (the error is logged).
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        return "Groq API key is not configured. Set the GROQ_API_KEY environment variable."
    if context:
        prompt = f"Context:\n{context}\n\nQuestion: {query}"
    else:
        prompt = f"Question: {query}"
    try:
        client = Groq(api_key=api_key)
        completion = client.chat.completions.create(
            model=config["model_settings"]["groq_model"],
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a construction defect analysis assistant. "
                        "Use the provided context when it is relevant."
                    ),
                },
                {"role": "user", "content": prompt},
            ],
        )
        return completion.choices[0].message.content
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        logger.error(f"Groq API error: {e}")
        return f"Unable to generate a response: {e}"


def main():
    """Render the Streamlit UI: image upload and analysis plus a Q&A box.

    The RAG system and the image analyzer are created once per session and
    kept in ``st.session_state`` so models are not reloaded on every rerun.
    """
    st.set_page_config(
        page_title="Enhanced Construction Defect Analyzer",
        page_icon="🏗️",
        layout="wide"
    )
    st.title("🏗️ Advanced Construction Defect Analysis System")
    # Initialize systems
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = EnhancedRAGSystem()
    if 'image_analyzer' not in st.session_state:
        st.session_state.image_analyzer = ImageAnalyzer()
    # Sidebar for settings and history
    with st.sidebar:
        st.header("Settings & History")
        show_debug = st.checkbox("Show Debug Information")
        confidence_threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=config["analysis_settings"]["confidence_threshold"]
        )
        if st.button("View Analysis History"):
            st.write("Recent Analyses:", st.session_state.image_analyzer.history[-5:])
    # Main interface
    col1, col2 = st.columns([2, 3])
    with col1:
        uploaded_file = st.file_uploader(
            "Upload a construction image",
            type=['jpg', 'jpeg', 'png']
        )
        user_query = st.text_input(
            "Ask a question about construction defects:",
            help="Enter your question about specific defects or general construction issues"
        )
    with col2:
        if uploaded_file:
            image = Image.open(uploaded_file)
            # Defined up front so the "Details" tab never sees an unbound name.
            results = None
            # Create tabs for different views
            tabs = st.tabs(["Original", "Analysis", "Details"])
            with tabs[0]:
                st.image(image, caption="Uploaded Image")
            with tabs[1]:
                with st.spinner("Analyzing image..."):
                    results = st.session_state.image_analyzer.detect_defects(image)
                if results:
                    # Show defect probabilities
                    defect_probs = results["defect_probabilities"]
                    significant_defects = {
                        name: prob for name, prob in defect_probs.items()
                        if prob > confidence_threshold
                    }
                    if significant_defects:
                        st.subheader("Detected Defects")
                        # Streamlit's native chart avoids the plotting module
                        # that this file never imported.
                        chart_data = pd.DataFrame.from_dict(
                            significant_defects,
                            orient="index",
                            columns=["probability"],
                        )
                        st.bar_chart(chart_data)
                    else:
                        st.info("No defects above the selected confidence threshold.")
                    # Show heatmap
                    if results["heatmap"] is not None:
                        st.image(results["heatmap"], caption="Defect Attention Map")
                else:
                    st.warning("Image analysis failed. See the application log for details.")
            with tabs[2]:
                if results:
                    st.json(results["additional_analysis"])
                    if show_debug:
                        st.json(results["image_statistics"])
        if user_query:
            with st.spinner("Processing query..."):
                context = st.session_state.rag_system.get_relevant_context(user_query)
                response = get_groq_response(user_query, context)
            st.subheader("AI Assistant Response")
            st.write(response)
            if show_debug:
                st.subheader("Retrieved Context")
                st.text(context)
# Entry point: run main() only when the module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()