# Page header captured with this file (not part of the program):
# till-onethousand's picture
# trained model
# 7a1b8c6
# raw
# history blame
# 1.32 kB
import streamlit as st
from PIL import Image
from transformers import ViTForImageClassification
from config import UNTRAINED, labels, TRAINED
from utils import predict
@st.cache_resource
def _load_model(checkpoint):
    """Load a ViT image classifier configured for this app's label set.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the loaded model in memory, so each
    checkpoint is loaded once per server process instead of on every rerun.

    Args:
        checkpoint: Model identifier or path forwarded to ``from_pretrained``.

    Returns:
        A ``ViTForImageClassification`` with ``len(labels)`` output classes.
        Its ``id2label`` / ``label2id`` mappings are built from ``labels``.
    """
    return ViTForImageClassification.from_pretrained(
        checkpoint,
        num_labels=len(labels),
        id2label={str(i): c for i, c in enumerate(labels)},
        label2id={c: str(i) for i, c in enumerate(labels)},
    )


# NOTE(review): cached resources are shared across all user sessions. This
# assumes utils.predict does not mutate the model; confirm.
model_untrained = _load_model(UNTRAINED)
model_trained = _load_model(TRAINED)
def _render_model_column(column, heading, model, uploader_key):
    """Render one comparison column: uploader, image preview, and prediction.

    Args:
        column: Container returned by ``st.columns`` to draw into.
        heading: Markdown heading shown at the top of the column.
        model: Classifier passed to ``utils.predict``.
        uploader_key: Unique widget key for this column's file uploader.
            Streamlit derives widget IDs from their parameters. Two uploaders
            with the same label and no distinct key would collide and raise
            ``DuplicateWidgetID``.
    """
    with column:
        st.markdown(heading)
        uploaded = st.file_uploader("Upload a satellite image", key=uploader_key)
        if uploaded is None:
            # Nothing uploaded yet for this column; leave it empty.
            return
        image = Image.open(uploaded)
        column.image(image, use_container_width=True)
        label = predict(model, image)
        st.write(f"Predicted label: {label}")


st.title("Detect Hurricane Damage")
col1, col2 = st.columns(2)
_render_model_column(col1, "## Pre-Trained Model", model_untrained, "upload_pretrained")
_render_model_column(col2, "## Fine-Tuned Model", model_trained, "upload_finetuned")