Spaces:
No application file
No application file
import streamlit as st | |
import torch.nn as nn | |
import torch | |
from torchvision import models, transforms | |
from PIL import Image | |
# Class labels, in output-index order: classify_image returns an index into
# this list and the model head is sized to len(CATEGORIES).
# NOTE(review): presumably this order must match the label order used when
# trained_model.pt was trained — confirm before reordering or extending.
CATEGORIES = [
    "AIHOLE",
    "BILLESHWAR_TEMPLE",
    "CHENNAKESHWARA_TEMPLE",
    "HAMPI_CHARIOT",
    "IBRAHIM_ROZA",
    "JAIN_BASADI",
    "KAMAL_BASTI",
    "KEDARESHWARA_TEMPLE",
    "KESHAVA_TEMPLE",
    "LOTUS_MAHAL",
]
# Side length (pixels) every input image is resized to before inference.
IMG_SIZE = 224
# Load the trained model.
# Streamlit re-executes this entire script on every widget interaction
# (including the "Classify" button click). Without caching, the ResNet-50 would
# be rebuilt and its weights re-read from disk on every rerun.
# st.cache_resource keeps one instance per server process instead.
@st.cache_resource
def _load_model(weights_path="trained_model.pt"):
    """Build a ResNet-50 classifier and load the fine-tuned weights from disk.

    Args:
        weights_path: path to a state dict saved for this architecture, i.e. a
            ResNet-50 whose final ``fc`` layer has ``len(CATEGORIES)`` outputs.

    Returns:
        The model on CPU, switched to evaluation mode.
    """
    # Architecture only: weights=None skips downloading ImageNet weights, since
    # they are immediately overwritten by the state dict below. This replaces
    # the deprecated `pretrained=False` argument.
    net = models.resnet50(weights=None)
    # Swap the 1000-way ImageNet head for one output per temple category.
    net.fc = nn.Linear(net.fc.in_features, len(CATEGORIES))
    # map_location lets weights saved on a GPU machine load on a CPU-only host.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.eval()  # inference mode for BatchNorm/Dropout layers
    return net


model = _load_model()
# Define the image transform applied to every upload before inference:
# resize to a fixed IMG_SIZE x IMG_SIZE square, convert to a float tensor,
# then normalize each of the three channels with the mean/std below.
# NOTE(review): these are the ImageNet statistics used by torchvision's
# ResNets; presumably the same values were used when training
# trained_model.pt — confirm.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
# Define the prediction function
def classify_image(image):
    """Predict the category index for a single PIL image.

    Args:
        image: a PIL image (e.g. from ``Image.open``). Any mode is accepted:
            RGBA, grayscale and palette images are converted to RGB first.

    Returns:
        int: index into ``CATEGORIES`` of the highest-scoring class.
    """
    # `transform` normalizes with exactly three per-channel mean/std values.
    # A PNG with alpha (RGBA, 4 channels) or a grayscale/palette image would
    # otherwise fail with a channel-mismatch error, so force 3-channel RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension: the model expects a batch of images.
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(batch)

    # Index of the largest logit in the single row of the batch.
    return logits.argmax(dim=1).item()
# Streamlit app
def main():
    """Render the upload -> preview -> classify UI.

    Shows a file uploader for JPG/PNG images. Once a readable image is
    uploaded, it is displayed, and pressing "Classify" writes the predicted
    category name. Unreadable uploads get an error message instead of a crash.
    """
    st.title("Temple Image Classification")

    # File uploader
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return  # nothing uploaded yet

    # A file with an image extension may still not be a decodable image.
    # PIL raises UnidentifiedImageError, an OSError subclass, for unknown
    # formats, and OSError for truncated data. Image.open is lazy, so load()
    # forces the decode here instead of inside st.image or classify_image.
    try:
        image = Image.open(uploaded_file)
        image.load()
    except OSError:
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.")
        return

    # NOTE(review): use_column_width is deprecated in newer Streamlit releases
    # in favour of use_container_width. It is kept for compatibility with the
    # Streamlit version this app targets — confirm before upgrading.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Classify image on button click
    if st.button("Classify"):
        prediction = classify_image(image)
        st.write(f"Predicted Category: {CATEGORIES[prediction]}")


if __name__ == "__main__":
    main()