|
import streamlit as st |
|
import os |
|
from groq import Groq |
|
from transformers import ViTForImageClassification, ViTImageProcessor |
|
from sentence_transformers import SentenceTransformer |
|
from PIL import Image |
|
import torch |
|
import numpy as np |
|
from typing import List, Dict, Tuple, Optional, Any |
|
import faiss |
|
import json |
|
import torchvision.transforms.functional as TF |
|
from torchvision import transforms |
|
import cv2 |
|
import pandas as pd |
|
from datetime import datetime |
|
import logging |
|
|
|
|
|
# Module-level logger used throughout this file; the root logger is set up
# at INFO so this module's info/warning/error messages are shown.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
class ConfigManager:
    """Manages configuration settings for the application.

    Settings come from a JSON file layered over ``DEFAULT_CONFIG``; nested
    sections are merged key by key so a partial file keeps the other defaults.
    """

    DEFAULT_CONFIG = {
        "model_settings": {
            "vit_model": "google/vit-base-patch16-224",
            "sentence_transformer": "all-MiniLM-L6-v2",
            "groq_model": "llama-3.3-70b-versatile"
        },
        "analysis_settings": {
            "confidence_threshold": 0.5,
            "max_defects": 3,
            "heatmap_intensity": 0.7
        },
        "rag_settings": {
            "num_relevant_docs": 3,
            "similarity_threshold": 0.75
        }
    }

    @staticmethod
    def _merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new dict: *base* recursively updated with *override*.

        When both sides hold a dict for the same key, the dicts are merged
        instead of replaced, so ``{"model_settings": {"groq_model": ...}}``
        keeps the remaining ``model_settings`` defaults. Neither input is mutated.
        """
        import copy

        merged = copy.deepcopy(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = ConfigManager._merge(merged[key], value)
            else:
                merged[key] = copy.deepcopy(value)
        return merged

    @staticmethod
    def load_config(path: str = 'config.json'):
        """Load configuration with fallback to defaults.

        Args:
            path: JSON file to read; relative paths resolve against the current
                working directory. Defaults to ``config.json``.

        Returns:
            A fresh dict (safe to mutate without touching ``DEFAULT_CONFIG``)
            holding the defaults deep-merged with the file's top-level object.
            Defaults alone are returned when the file is missing, unreadable,
            not valid JSON, or not a JSON object; a warning is logged in the
            last three cases.
        """
        try:
            if os.path.exists(path):
                with open(path, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
                if not isinstance(loaded, dict):
                    raise ValueError(
                        f"expected a JSON object, got {type(loaded).__name__}"
                    )
                return ConfigManager._merge(ConfigManager.DEFAULT_CONFIG, loaded)
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            logger.warning(f"Error loading config: {e}")
        return ConfigManager._merge(ConfigManager.DEFAULT_CONFIG, {})


config = ConfigManager.load_config()
|
|
|
class ImageAnalyzer:
    """Detect construction defects with a ViT classifier plus OpenCV heuristics.

    Model names are read from the module-level ``config["model_settings"]`` and
    analysis knobs from ``config["analysis_settings"]``.
    """

    # Upper bound on retained results: each entry holds full-resolution numpy
    # masks, so an unbounded list grows memory with every analysis.
    max_history = 50

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config["model_settings"]
        self.analysis_config = config["analysis_settings"]
        # Label order of the classification head: logit i -> defect_classes[i].
        self.defect_classes = [
            "spalling", "reinforcement_corrosion", "structural_cracks",
            "water_damage", "surface_deterioration", "alkali_silica_reaction",
            "concrete_delamination", "honeycomb", "scaling",
            "efflorescence", "joint_deterioration", "carbonation",
        ]
        # Set before loading so the object stays usable (detect_defects just
        # returns None) when model loading fails.
        self.model = None
        self.processor = None
        self.transforms = None
        self.history = []
        self.initialize_models()

    def initialize_models(self) -> bool:
        """Load the ViT classifier, its image processor and the transform pipeline.

        Returns:
            True on success; False if any step failed (the error is logged with
            a traceback and ``self.model`` stays None).
        """
        try:
            # NOTE: ignore_mismatched_sizes=True lets the checkpoint load even
            # though its classifier head does not have len(defect_classes)
            # outputs; that head is newly initialised, so the class
            # probabilities are not meaningful until the model is fine-tuned.
            model = ViTForImageClassification.from_pretrained(
                self.config["vit_model"],
                num_labels=len(self.defect_classes),
                ignore_mismatched_sizes=True,
            ).to(self.device)
            model.eval()
            self.processor = ViTImageProcessor.from_pretrained(self.config["vit_model"])
            self.transforms = self._setup_transforms()
            self.model = model
            return True
        except Exception as e:
            logger.exception("Model initialization error: %s", e)
            self.model = None
            return False

    def _setup_transforms(self):
        """Build the deterministic inference pipeline: resize, tensorize, normalize.

        Target size and normalization statistics come from the loaded
        processor when available, so inputs match the checkpoint's own
        preprocessing; 224x224 and the ImageNet statistics are fallbacks.
        """
        processor = getattr(self, "processor", None)
        mean = getattr(processor, "image_mean", None) or [0.485, 0.456, 0.406]
        std = getattr(processor, "image_std", None) or [0.229, 0.224, 0.225]
        size = getattr(processor, "size", None)
        if isinstance(size, dict) and "height" in size and "width" in size:
            target_hw = (size["height"], size["width"])
        else:
            target_hw = (224, 224)
        # No random augmentation here: at inference it makes predictions
        # non-deterministic, and torchvision's ColorJitter/sharpness ops clamp
        # float tensors to [0, 1], which corrupts an already-normalized tensor.
        return transforms.Compose([
            transforms.Resize(target_hw),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def preprocess_image(self, image: Image.Image) -> Optional[Dict[str, Any]]:
        """Compute heuristic statistics/masks and the model input for *image*.

        Returns:
            Dict with ``model_input`` (batch of one, on ``self.device``),
            ``stats``, ``edges`` (Canny output) and ``rust_mask`` (HSV in-range
            mask); None if preprocessing failed (the error is logged).
        """
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')

            img_array = np.array(image)
            stats = {
                "mean_brightness": np.mean(img_array),
                "std_brightness": np.std(img_array),
                "size": image.size,
                "aspect_ratio": image.size[0] / image.size[1],
            }

            # Edge density feeds the crack heuristic in analyze_image_statistics.
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 100, 200)
            stats["edge_density"] = np.mean(edges > 0)

            # OpenCV 8-bit hue spans 0-179; hue 0-30 with moderate saturation
            # and value selects red/orange/yellow tones as a rough rust proxy.
            hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV)
            rust_mask = cv2.inRange(hsv, np.array([0, 50, 50]), np.array([30, 255, 255]))
            stats["rust_percentage"] = np.mean(rust_mask > 0)

            model_input = self.transforms(image).unsqueeze(0).to(self.device)

            return {
                "model_input": model_input,
                "stats": stats,
                "edges": edges,
                "rust_mask": rust_mask,
            }
        except Exception as e:
            logger.exception("Preprocessing error: %s", e)
            return None

    @staticmethod
    def _cls_attention_grid(attentions) -> Optional[torch.Tensor]:
        """Return the last layer's head-averaged [CLS]->patch attention as a square grid.

        ``attentions`` is the per-layer tuple a ViT model returns with
        ``output_attentions=True``; each entry is (batch, heads, seq, seq) and
        token 0 is [CLS]. Returns None when attentions are missing or the patch
        tokens do not form a square grid.
        """
        if not attentions:
            return None
        last = attentions[-1]
        if last.dim() != 4:
            return None
        cls_to_patches = last.mean(dim=1)[0, 0, 1:]
        num_patches = cls_to_patches.numel()
        side = int(round(float(np.sqrt(num_patches))))
        if side == 0 or side * side != num_patches:
            return None
        return cls_to_patches.reshape(side, side)

    def detect_defects(self, image: Image.Image) -> Optional[Dict[str, Any]]:
        """Classify *image* and collect heuristic analyses.

        Returns:
            Dict with ``defect_probabilities`` (class name -> softmax prob),
            ``heatmap`` (RGB array or None), ``image_statistics``,
            ``additional_analysis``, ``edge_detection``, ``rust_detection`` and
            ``timestamp``; None if the model is unavailable or any step failed.
            Successful results are also appended to ``self.history``.
        """
        try:
            if getattr(self, "model", None) is None:
                logger.error("Defect detection skipped: model is not initialized.")
                return None

            proc_data = self.preprocess_image(image)
            if proc_data is None:
                logger.error("Image preprocessing failed.")
                return None

            with torch.no_grad():
                # Attentions are only returned when requested; otherwise
                # outputs.attentions is None and no heatmap can be built.
                outputs = self.model(proc_data["model_input"], output_attentions=True)

            probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
            defect_probs = {
                name: float(prob)
                for name, prob in zip(self.defect_classes, probabilities.tolist())
            }

            patch_grid = self._cls_attention_grid(getattr(outputs, "attentions", None))
            heatmap = (
                self.generate_heatmap(patch_grid, image.size)
                if patch_grid is not None else None
            )

            additional_analysis = self.analyze_image_statistics(proc_data["stats"])

            result = {
                "defect_probabilities": defect_probs,
                "heatmap": heatmap,
                "image_statistics": proc_data["stats"],
                "additional_analysis": additional_analysis,
                "edge_detection": proc_data["edges"],
                "rust_detection": proc_data["rust_mask"],
                "timestamp": datetime.now().isoformat(),
            }

            self.history.append(result)
            # Keep only the newest results.
            if len(self.history) > self.max_history:
                del self.history[:-self.max_history]

            return result
        except Exception as e:
            logger.exception("Defect detection error: %s", e)
            return None

    def analyze_image_statistics(self, stats: Dict) -> Dict[str, Any]:
        """Turn raw image statistics into human-readable warnings/indicators.

        Only keys whose condition fires are present in the returned dict.
        """
        analysis = {}

        if stats["mean_brightness"] < 50:
            analysis["lighting_condition"] = "Poor lighting - may affect accuracy"
        elif stats["mean_brightness"] > 200:
            analysis["lighting_condition"] = "Overexposed - may affect accuracy"

        if stats["edge_density"] > 0.1:
            analysis["crack_likelihood"] = "High crack probability based on edge detection"

        if stats["rust_percentage"] > 0.05:
            analysis["corrosion_indicator"] = "Possible corrosion detected"

        return analysis

    def generate_heatmap(self, attention_weights: torch.Tensor, image_size: Tuple[int, int]) -> Optional[np.ndarray]:
        """Render a 2-D attention map as a colour heatmap sized like the image.

        Args:
            attention_weights: 2-D tensor, e.g. the [CLS]->patch grid.
            image_size: (width, height), as given by PIL ``Image.size``.

        Returns:
            uint8 array of shape (height, width, 3) in RGB order, or None if
            *attention_weights* is None or rendering failed.
        """
        try:
            if attention_weights is None:
                return None

            heatmap = attention_weights.detach().cpu().numpy().astype(np.float32)
            # cv2.resize takes dsize as (width, height), matching PIL's order.
            heatmap = cv2.resize(heatmap, tuple(image_size))

            heatmap = np.maximum(heatmap, 0)
            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)

            # Gamma < 1 brightens weak activations, gamma > 1 suppresses them.
            gamma = self.analysis_config["heatmap_intensity"]
            heatmap = np.power(heatmap, gamma)

            heatmap = (heatmap * 255).astype(np.uint8)
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
            # applyColorMap produces BGR; st.image (the consumer) expects RGB.
            return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        except Exception as e:
            logger.exception("Heatmap generation error: %s", e)
            return None
|
class RAGSystem:
    """Basic RAG System for storing and retrieving documents.

    Texts are embedded with the configured SentenceTransformer and stored in an
    exact (flat) L2 FAISS index; ``knowledge_base[i]`` holds the text whose
    vector has FAISS id ``i``.
    """

    def __init__(self):
        self.embedding_model = SentenceTransformer(config["model_settings"]["sentence_transformer"])
        # Size the index from the model rather than hard-coding 384, so a
        # different sentence_transformer in config does not break index.add().
        # Fall back to 384 if the model cannot report its dimension.
        dim = self.embedding_model.get_sentence_embedding_dimension() or 384
        self.vector_store = faiss.IndexFlatL2(dim)
        self.knowledge_base = []

    def add_documents(self, docs: List[str]):
        """Embed *docs* and append them to the vector store and knowledge base.

        An empty list is a no-op.
        """
        if not docs:
            return
        embeddings = np.asarray(self.embedding_model.encode(docs), dtype='float32')
        self.vector_store.add(embeddings)
        self.knowledge_base.extend({"text": doc} for doc in docs)

    def search(self, query: str, k: int = 3) -> List[str]:
        """Return up to *k* stored texts nearest to *query*, nearest first.

        Returns an empty list when the store is empty or ``k <= 0``.
        """
        if k <= 0 or self.vector_store.ntotal == 0:
            return []
        query_embedding = np.asarray(self.embedding_model.encode([query]), dtype='float32')
        _, ids = self.vector_store.search(query_embedding, k)
        # FAISS pads with id -1 when fewer than k vectors are stored; in Python
        # -1 would silently select the last document, so skip it.
        return [self.knowledge_base[i]["text"] for i in ids[0] if i >= 0]
|
|
|
|
|
class EnhancedRAGSystem(RAGSystem):
    """RAG store that records every query and filters hits by a distance cutoff.

    Retrieval settings come from the module-level ``config["rag_settings"]``.
    """

    def __init__(self):
        super().__init__()
        self.config = config["rag_settings"]
        # One {"timestamp", "query"} entry per get_relevant_context call.
        self.query_history = []

    def get_relevant_context(self, query: str, k: int = None) -> str:
        """Return stored texts close to *query*, joined by blank lines.

        Args:
            query: free-text question to embed and search with.
            k: neighbours to fetch; defaults to ``rag_settings.num_relevant_docs``.

        Returns:
            Matching texts (nearest first) joined with ``"\\n\\n"``, or ``""``
            when the store is empty, ``k <= 0``, or no hit passes the cutoff.
        """
        if k is None:
            k = self.config["num_relevant_docs"]

        self.query_history.append({
            "timestamp": datetime.now().isoformat(),
            "query": query
        })

        if k <= 0 or self.vector_store.ntotal == 0:
            return ""

        query_embedding = self.embedding_model.encode([query])
        distances, ids = self.vector_store.search(np.array(query_embedding).astype('float32'), k)

        # The parent builds a faiss.IndexFlatL2, which reports squared L2
        # distances: despite its name, "similarity_threshold" is a maximum
        # distance (smaller = more similar). Id -1 is FAISS padding used when
        # fewer than k vectors are stored.
        max_distance = self.config["similarity_threshold"]
        relevant_docs = [
            self.knowledge_base[i]["text"]
            for i, dist in zip(ids[0], distances[0])
            if i >= 0 and dist < max_distance
        ]

        return "\n\n".join(relevant_docs)
|
|
|
def get_groq_response(query: str, context: str) -> str:
    """Ask the configured Groq chat model to answer *query*, grounded in *context*.

    Reads the API key from the GROQ_API_KEY environment variable and the model
    name from ``config["model_settings"]["groq_model"]``.

    Returns:
        The model's answer text, or a human-readable message when the key is
        missing or the request fails (the failure is logged).
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        return "Groq API key not configured: set the GROQ_API_KEY environment variable."

    system_prompt = (
        "You are an expert in construction defects and concrete durability. "
        "Answer the user's question, using the provided reference context when it is relevant."
    )
    user_content = f"Context:\n{context}\n\nQuestion: {query}" if context else query

    try:
        client = Groq(api_key=api_key)
        completion = client.chat.completions.create(
            model=config["model_settings"]["groq_model"],
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
        )
        return completion.choices[0].message.content or ""
    except Exception as e:
        logger.exception("Groq request failed: %s", e)
        return f"Sorry, the AI assistant request failed: {e}"


def main():
    """Streamlit entry point: image upload and defect analysis plus a RAG-backed Q&A box."""
    import hashlib  # only used to key the cached analysis of the uploaded file

    st.set_page_config(
        page_title="Enhanced Construction Defect Analyzer",
        page_icon="🏗️",
        layout="wide"
    )

    st.title("🏗️ Advanced Construction Defect Analysis System")

    # Model-heavy objects live in session_state so they are built once per
    # browser session instead of on every Streamlit rerun.
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = EnhancedRAGSystem()
    if 'image_analyzer' not in st.session_state:
        st.session_state.image_analyzer = ImageAnalyzer()

    with st.sidebar:
        st.header("Settings & History")
        show_debug = st.checkbox("Show Debug Information")
        confidence_threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            # float() so an integer in config.json does not clash with the float bounds.
            value=float(config["analysis_settings"]["confidence_threshold"])
        )

        if st.button("View Analysis History"):
            st.write("Recent Analyses:", st.session_state.image_analyzer.history[-5:])

    col1, col2 = st.columns([2, 3])

    with col1:
        uploaded_file = st.file_uploader(
            "Upload a construction image",
            type=['jpg', 'jpeg', 'png']
        )

        user_query = st.text_input(
            "Ask a question about construction defects:",
            help="Enter your question about specific defects or general construction issues"
        )

    with col2:
        if uploaded_file:
            try:
                image = Image.open(uploaded_file)
            except (OSError, ValueError) as e:
                # PIL's UnidentifiedImageError is an OSError subclass.
                st.error(f"Could not open the uploaded image: {e}")
                image = None

            if image is not None:
                tabs = st.tabs(["Original", "Analysis", "Details"])

                with tabs[0]:
                    st.image(image, caption="Uploaded Image")

                with tabs[1]:
                    # Any widget interaction reruns the script; reuse the result
                    # for the same file instead of re-running the model and
                    # re-appending to the analyzer's history.
                    file_key = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
                    cached = st.session_state.get("analysis_cache")
                    if cached is not None and cached[0] == file_key:
                        results = cached[1]
                    else:
                        with st.spinner("Analyzing image..."):
                            results = st.session_state.image_analyzer.detect_defects(image)
                        if results:
                            st.session_state.analysis_cache = (file_key, results)

                    if results:
                        defect_probs = results["defect_probabilities"]
                        significant_defects = {
                            k: v for k, v in defect_probs.items()
                            if v > confidence_threshold
                        }

                        if significant_defects:
                            st.subheader("Detected Defects")
                            # Native Streamlit chart; matplotlib is not imported by this app.
                            st.bar_chart(pd.Series(significant_defects, name="probability"))
                        else:
                            st.info("No defects above the confidence threshold.")

                        if results["heatmap"] is not None:
                            st.image(results["heatmap"], caption="Defect Attention Map")
                    else:
                        st.error("Image analysis failed; see the application logs for details.")

                with tabs[2]:
                    if results:
                        st.json(results["additional_analysis"])
                        if show_debug:
                            st.json(results["image_statistics"])

    if user_query:
        with st.spinner("Processing query..."):
            context = st.session_state.rag_system.get_relevant_context(user_query)
            response = get_groq_response(user_query, context)

        st.subheader("AI Assistant Response")
        st.write(response)

        if show_debug:
            st.subheader("Retrieved Context")
            st.text(context)


if __name__ == "__main__":
    main()