Test / app.py
Shakir60's picture
Update app.py
83437a6 verified
raw
history blame
5.02 kB
import streamlit as st
import torch
from transformers import ViTForImageClassification, ViTImageProcessor, pipeline
from PIL import Image
import pandas as pd
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import os
# Page config
# Configure the browser tab (title, icon) and use the full-width layout.
# Must run before any other Streamlit call in the script.
st.set_page_config(
    page_title="Building Damage Analysis",
    # Building-construction emoji (U+1F3D7 + variation selector U+FE0F). Written as
    # escapes because the literal had been mis-decoded into mojibake ("πŸ—οΈ").
    page_icon="\U0001F3D7\ufe0f",
    layout="wide",
)
# Initialize models
@st.cache_resource
def load_models():
    """Load every model the app needs and hand them back as one tuple.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded objects
    instead of rebuilding them on each script rerun.

    Returns:
        tuple: ``(damage_model, processor, embeddings, llm)`` where
            * ``damage_model`` is the ``ViTForImageClassification`` model,
            * ``processor`` is the matching ``ViTImageProcessor``,
            * ``embeddings`` is the sentence-transformers embedding model later
              used to index the repair records,
            * ``llm`` is the ``HuggingFaceHub`` text model (flan-t5-large).
    """
    # NOTE(review): "microsoft/vit-base-patch16-224" looks like a generic
    # image-classification checkpoint rather than one fine-tuned on building
    # damage -- confirm before treating its outputs as damage probabilities.
    vit_checkpoint = "microsoft/vit-base-patch16-224"
    vision_model = ViTForImageClassification.from_pretrained(vit_checkpoint)
    image_processor = ViTImageProcessor.from_pretrained(vit_checkpoint)

    # Text-generation LLM that writes the repair analysis.
    text_model = HuggingFaceHub(
        repo_id="google/flan-t5-large",
        model_kwargs={"temperature": 0.7, "max_length": 512},
    )

    # Sentence embedding model used to build the FAISS retriever.
    sentence_embedder = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2'
    )

    return vision_model, image_processor, sentence_embedder, text_model
# Illustrative repair records that seed the retrieval index.
# Each record has a free-text description, a cost, and a damage-type label.
# In production these should be loaded from a real dataset, not hard-coded.
SAMPLE_DATA = [
    dict(
        repair_description="Major wall crack requiring structural repair. Steel plate reinforcement needed.",
        repair_cost=5000,
        damage_type="Wall Crack",
    ),
    dict(
        repair_description="Concrete beam damage with exposed rebar. Requires immediate attention.",
        repair_cost=7500,
        damage_type="Beam Damage",
    ),
    dict(
        repair_description="Foundation settling causing structural issues. Need underpinning.",
        repair_cost=15000,
        damage_type="Foundation Issue",
    ),
]
def setup_rag(embeddings, llm):
    """Build a retrieval-augmented QA chain over ``SAMPLE_DATA``.

    Each sample record becomes a ``Document`` whose text is the repair
    description followed by ``"Cost: $<amount>"``, with the cost and damage
    type stored as metadata. The documents are indexed with FAISS, and a
    ``RetrievalQA`` chain of type ``"stuff"`` inserts the 2 closest records
    into the prompt as ``{context}``. The caller's query is inserted as
    ``{question}``.

    Args:
        embeddings: Embedding model used to index the documents.
        llm: Language model that generates the final answer.

    Returns:
        The configured ``RetrievalQA`` chain.
    """
    corpus = []
    for record in SAMPLE_DATA:
        cost = record['repair_cost']
        corpus.append(
            Document(
                page_content=f"{record['repair_description']} Cost: ${cost}",
                metadata={'cost': cost, 'damage_type': record['damage_type']},
            )
        )

    # In-memory vector index; the retriever returns the top-2 matches.
    index = FAISS.from_documents(corpus, embeddings)
    retriever = index.as_retriever(search_kwargs={'k': 2})

    template_text = """
Analyze building damage and provide repair recommendations based on this context:
{context}
For damage type: {question}
Provide:
1. Damage assessment
2. Repair steps
3. Safety considerations
4. Estimated cost range
"""
    prompt = PromptTemplate(
        template=template_text,
        input_variables=["context", "question"],
    )

    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
    )
def process_image(image, model, processor):
    """Classify *image* and return the softmax probability of every class.

    Args:
        image: Image to classify. If it exposes a ``mode`` other than ``"RGB"``
            (e.g. an RGBA, palette or grayscale PIL image), it is converted to
            RGB first. ViT-style processors normalize per RGB channel and can
            reject images with a different channel count.
        model: Image-classification model; its output must expose ``.logits``.
        processor: Callable that turns ``images=...`` into keyword inputs for
            ``model``; it is called with ``return_tensors="pt"``.

    Returns:
        list[float]: Probabilities (summing to 1) for the first image in the
        batch, in the order of the model's output classes.
    """
    # PNG uploads are often RGBA/palette/grayscale; normalize to 3 channels.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    return probabilities[0].tolist()
def main():
    """Render the building-damage analysis page.

    Flow:
      1. On the first run of a session, load the models and build the RAG
         chain, then keep them in ``st.session_state``.
      2. Accept an uploaded JPG/PNG photo and display it.
      3. Classify the photo. For each label whose probability exceeds 0.2,
         show the score and an LLM-generated repair analysis.
    """
    # Building-construction emoji as escapes; the literal had been mis-decoded.
    st.title("\U0001F3D7\ufe0f Building Damage Detection & Analysis")
    st.markdown("""
Upload a photo of building damage for AI analysis and repair recommendations.
""")

    # Build models and the QA chain once per session.
    if 'models_loaded' not in st.session_state:
        with st.spinner('Loading AI models...'):
            damage_model, processor, embeddings, llm = load_models()
            qa_chain = setup_rag(embeddings, llm)
        st.session_state['models_loaded'] = True
        st.session_state['models'] = (damage_model, processor, qa_chain)
    damage_model, processor, qa_chain = st.session_state['models']

    uploaded_file = st.file_uploader("Upload building damage photo", type=["jpg", "jpeg", "png"])
    if not uploaded_file:
        return

    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)
    # PNG uploads may be RGBA/palette/grayscale; the classifier gets 3-channel RGB.
    rgb_image = image.convert("RGB")

    with st.spinner('Analyzing damage...'):
        predictions = process_image(rgb_image, damage_model, processor)

    # NOTE(review): zip() pairs these labels with the FIRST five model outputs.
    # Unless the classifier was trained with exactly these classes in this
    # order, the scores do not correspond to these damage types -- confirm the
    # model's label head (e.g. damage_model.config.id2label).
    damage_types = ["Wall Crack", "Beam Damage", "Foundation Issue",
                    "Roof Damage", "Structural Damage"]
    detection_threshold = 0.2

    st.subheader("Detected Damage Types")
    detected = [
        (damage_type, prob)
        for damage_type, prob in zip(damage_types, predictions)
        if prob > detection_threshold
    ]
    if not detected:
        st.info("No damage type exceeded the 20% confidence threshold.")

    for damage_type, prob in detected:
        st.metric(damage_type, f"{prob:.1%}")
        with st.spinner(f'Generating analysis for {damage_type}...'):
            try:
                analysis = qa_chain.invoke(damage_type)
            except Exception as err:  # UI boundary: report and keep rendering the rest
                st.error(f"Could not generate analysis for {damage_type}: {err}")
                continue
        st.markdown(f"### Analysis for {damage_type}")
        st.markdown(analysis['result'])
# Script entry point: render the page by calling main() when this module is
# executed as the main script.
if __name__ == "__main__":
    main()