"""Streamlit app that classifies an uploaded flower photo.

The ``Resnet50Flower102`` model weights are loaded from ``model.pth`` and the
predicted class index is mapped to a label via the ``Name`` column of
``flowerdata.csv`` (row position == class index).
"""
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from resnet import Resnet50Flower102
import pandas as pd
from dataloader import transform

st.title("Flower Image Classification")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Resnet50Flower102(device)
model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)))
# Inference only: put dropout / batch-norm layers into evaluation mode once.
model.eval()

flowers_data = pd.read_csv("flowerdata.csv")

uploaded_file = st.file_uploader("Choose your file", type=["jpg", "png", "jpeg"])

# Evaluation-time preprocessing: fixed 224x224 resize, then ImageNet mean/std
# normalization over exactly three (RGB) channels.
transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

if uploaded_file is not None:
    try:
        # Normalize above expects 3 channels; PNG uploads may be RGBA,
        # palette or grayscale, so force RGB. convert() also forces the
        # lazy decode, so unreadable files fail here rather than later.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:  # covers PIL.UnidentifiedImageError and truncated files
        st.error("Could not read the uploaded file as an image.")
    else:
        # ToTensor already yields float32; add a batch dimension -> (1, 3, 224, 224).
        img = transform_val(image).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(img)

        # Index of the highest-scoring class for the single image in the batch.
        class_idx = logits.argmax(dim=1)[0].item()
        # Positional lookup: the CSV row order is assumed to match class indices.
        flower_name = flowers_data["Name"].iloc[class_idx]

        st.header("Input Image")
        st.image(image=image, use_column_width=True)
        st.write("##", flower_name)