# app.py
# (Scraped from a Hugging Face Spaces page; the Space status read "Sleeping".)
import streamlit as st | |
from transformers import ViTForImageClassification, ViTImageProcessor | |
from PIL import Image | |
import torch | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
import pandas as pd | |
import os | |
from pathlib import Path | |
import json | |
# Classifier output index -> damage label and its risk tier.
# The index order must match the classification head, whose size is
# derived from len(DAMAGE_TYPES) when the vision model is loaded.
DAMAGE_TYPES = {
    label_index: {"name": label_name, "risk": risk_level}
    for label_index, (label_name, risk_level) in enumerate(
        (
            ("spalling", "High"),
            ("reinforcement_corrosion", "Critical"),
            ("structural_crack", "High"),
            ("dampness", "Medium"),
            ("no_damage", "Low"),
        )
    )
}
@st.cache_resource
def load_models():
    """Load the image classifier, its preprocessor and the text embedder.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` makes these heavyweight objects load once per
    server process instead of on every rerun.

    Returns:
        tuple ``(vision_model, processor, embedding_model)``.
    """
    vision_model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        num_labels=len(DAMAGE_TYPES),
        # The checkpoint's classification head has a different number of
        # outputs, so it is discarded and a new head of size
        # len(DAMAGE_TYPES) is created with freshly initialised weights.
        # NOTE(review): until fine-tuned weights are loaded here, the damage
        # probabilities produced by this model are not meaningful.
        ignore_mismatched_sizes=True,
    )
    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    return vision_model, processor, embedding_model
class DamageKnowledgeBase:
    """Similarity search over stored repair cases, grouped by damage type.

    ``kb_data`` maps a damage-type name to a list of case dicts; each case
    needs ``description`` and ``repair_method`` keys, which form the text
    that gets embedded. Every case is embedded with ``embedding_model`` and
    added to a FAISS L2 index whose row ``i`` corresponds to
    ``_flatten_cases(kb_data)[i]``.
    """

    # On-disk cache locations, relative to the working directory.
    KB_PATH = Path("data/knowledge_base.json")
    EMBEDDINGS_PATH = Path("data/embeddings.pt")

    def __init__(self, embedding_model):
        """Store the text encoder and load (or create) the knowledge base.

        Args:
            embedding_model: object with an ``encode(list_of_str)`` method
                returning a 2-D float array (e.g. a SentenceTransformer).
        """
        self.embedding_model = embedding_model
        self.load_knowledge_base()

    @staticmethod
    def _flatten_cases(kb_data):
        """Return ``[(damage_type, case), ...]`` in index-row order."""
        return [
            (damage_type, case)
            for damage_type, cases in kb_data.items()
            for case in cases
        ]

    @staticmethod
    def _case_text(damage_type, case):
        """Return the text that is embedded for a single case."""
        return f"{damage_type} {case['description']} {case['repair_method']}"

    def _encode_cases(self):
        """Embed every case in ``self.kb_data``, one row per flattened case."""
        texts = [
            self._case_text(damage_type, case)
            for damage_type, case in self._flatten_cases(self.kb_data)
        ]
        return self.embedding_model.encode(texts)

    def _build_index(self, embeddings):
        """(Re)create ``self.index`` from an array of embeddings.

        ``self.index`` is set to ``None`` when there is nothing to index.
        """
        if len(embeddings) == 0:
            self.index = None
            return
        vectors = embeddings.astype("float32")  # FAISS expects float32 input
        self.index = faiss.IndexFlatL2(vectors.shape[1])
        self.index.add(vectors)

    def _save(self, embeddings):
        """Persist ``kb_data`` as JSON and the embeddings as a torch tensor."""
        self.KB_PATH.parent.mkdir(parents=True, exist_ok=True)
        with open(self.KB_PATH, 'w', encoding="utf-8") as f:
            json.dump(self.kb_data, f)
        torch.save(torch.tensor(embeddings), self.EMBEDDINGS_PATH)

    def load_knowledge_base(self):
        """Load cases and embeddings from disk, or build the sample set.

        Embeddings are recomputed (and re-saved) when the embeddings file is
        missing or its row count disagrees with the cases in the JSON file.
        This keeps search results mapped to the right case.
        """
        if not self.KB_PATH.exists():
            self.initialize_knowledge_base()
            return

        with open(self.KB_PATH, 'r', encoding="utf-8") as f:
            self.kb_data = json.load(f)

        expected_rows = len(self._flatten_cases(self.kb_data))
        embeddings = None
        if self.EMBEDDINGS_PATH.exists():
            embeddings = torch.load(self.EMBEDDINGS_PATH, weights_only=True).numpy()
        if embeddings is None or len(embeddings) != expected_rows:
            embeddings = self._encode_cases()
            self._save(embeddings)
        self._build_index(embeddings)

    def initialize_knowledge_base(self):
        """Create a small sample knowledge base, index it and save it."""
        self.kb_data = {
            'spalling': [
                {
                    'description': 'Severe concrete spalling on column surface',
                    'severity': 'High',
                    'repair_method': 'Remove damaged concrete, clean reinforcement, apply repair mortar',
                    'estimated_cost': 'High',
                    'timeframe': '2-3 weeks',
                    'similar_cases': ['case_123', 'case_456']
                }
            ],
            # TODO: add sample cases for the other damage types.
        }
        embeddings = self._encode_cases()
        self._build_index(embeddings)
        self._save(embeddings)

    def query(self, damage_type, confidence, k=3):
        """Return up to ``k`` stored cases most similar to ``damage_type``.

        Args:
            damage_type: damage label (e.g. ``"spalling"``) used as query text.
            confidence: currently unused; kept for interface compatibility.
            k: maximum number of cases to return.

        Returns:
            List of case dicts, nearest first. Empty if the knowledge base
            holds no cases.
        """
        if self.index is None or self.index.ntotal == 0 or k <= 0:
            return []
        query_embedding = self.embedding_model.encode(
            [f"damage type: {damage_type}"]
        ).astype("float32")
        # Never ask for more neighbours than stored vectors; FAISS pads the
        # missing slots with -1, which would otherwise index from the end.
        _, neighbour_ids = self.index.search(query_embedding, min(k, self.index.ntotal))
        flat_cases = self._flatten_cases(self.kb_data)
        return [
            flat_cases[row][1]
            for row in neighbour_ids[0]
            if 0 <= row < len(flat_cases)
        ]
def analyze_damage(image, model, processor):
    """Classify ``image`` and return per-class probabilities.

    Args:
        image: PIL image. It is converted to RGB first so that greyscale,
            palette or alpha-channel uploads become 3-channel input.
        model: image-classification model returning an object with ``.logits``.
        processor: image processor matching ``model``.

    Returns:
        1-D tensor of softmax probabilities for the first (only) image in
        the batch, one entry per model label.
    """
    rgb_image = image.convert('RGB')
    inputs = processor(images=rgb_image, return_tensors="pt")
    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        logits = model(**inputs).logits
    return torch.nn.functional.softmax(logits, dim=-1)[0]
@st.cache_resource
def _get_knowledge_base(_embedding_model):
    """Return a DamageKnowledgeBase built once per Streamlit server process.

    The leading underscore on ``_embedding_model`` tells ``st.cache_resource``
    not to hash that argument, so the model object is passed through as-is.
    """
    return DamageKnowledgeBase(_embedding_model)


def main():
    """Streamlit page: classify an uploaded structure photo and show similar cases.

    Every class whose probability exceeds ``min_confidence`` percent is listed
    with its risk level from ``DAMAGE_TYPES``. For each such class, the
    similar cases returned by the knowledge base are also shown.
    """
    from PIL import UnidentifiedImageError

    # Percentage a class probability must exceed to be reported.
    min_confidence = 15

    st.title("Advanced Structural Damage Assessment Tool")
    vision_model, processor, embedding_model = load_models()
    # Cached so the index is not rebuilt from disk on every Streamlit rerun.
    kb = _get_knowledge_base(embedding_model)

    uploaded_file = st.file_uploader("Upload structural image", type=['jpg', 'jpeg', 'png'])
    if not uploaded_file:
        return

    try:
        image = Image.open(uploaded_file)
    except UnidentifiedImageError:
        st.error("The uploaded file could not be read as an image.")
        return

    # NOTE(review): `use_column_width` is deprecated in newer Streamlit
    # releases; switch to its replacement once the pinned version supports it.
    st.image(image, caption="Uploaded Structure", use_column_width=True)
    with st.spinner("Analyzing..."):
        predictions = analyze_damage(image, vision_model, processor)

    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Damage Assessment")
        detected_damages = []
        for idx, prob in enumerate(predictions):
            confidence = float(prob) * 100
            if confidence <= min_confidence:
                continue
            damage_type = DAMAGE_TYPES[idx]['name']
            detected_damages.append((damage_type, confidence))
            st.write(f"**{damage_type.replace('_', ' ').title()}**")
            st.progress(confidence / 100)
            st.write(f"Confidence: {confidence:.1f}%")
            st.write(f"Risk Level: {DAMAGE_TYPES[idx]['risk']}")

    with col2:
        st.subheader("Similar Cases & Recommendations")
        for damage_type, confidence in detected_damages:
            similar_cases = kb.query(damage_type, confidence)
            st.write(f"**{damage_type.replace('_', ' ').title()}:**")
            if not similar_cases:
                st.write("No similar cases found in the knowledge base.")
            for case in similar_cases:
                # .get: cases may come from a hand-edited JSON file.
                with st.expander(f"Similar Case - {case.get('severity', 'Unknown')} Severity"):
                    st.write(f"Description: {case.get('description', 'N/A')}")
                    st.write(f"Repair Method: {case.get('repair_method', 'N/A')}")
                    st.write(f"Estimated Cost: {case.get('estimated_cost', 'N/A')}")
                    st.write(f"Timeframe: {case.get('timeframe', 'N/A')}")


if __name__ == "__main__":
    main()