# Author: hasnanmr
# Last commit: 654e088 ("modify model")
import streamlit as st
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
# Number of vegetable classes the fine-tuned classifier head predicts; must
# match the length of the `classes` tuple and the saved checkpoint.
NUM_CLASSES = 15
# Fine-tuned ResNet-18 weights (a state_dict) shipped alongside this app.
WEIGHTS_PATH = 'transfer_learning_resnet_15class.pth'


@st.cache_resource
def load_model(weights_path=WEIGHTS_PATH, num_classes=NUM_CLASSES):
    """Build the ResNet-18 classifier and load the fine-tuned weights.

    Streamlit re-executes this whole script on every user interaction, so the
    model is cached with ``st.cache_resource`` and built once per server
    process instead of on every rerun.

    Args:
        weights_path: Path to a ``state_dict`` saved from the fine-tuned model.
        num_classes: Number of outputs of the replacement ``fc`` layer.

    Returns:
        The model, mapped to CPU and switched to evaluation mode.
    """
    # No pretrained weights are requested: load_state_dict below runs in strict
    # mode and overwrites every parameter and buffer, so fetching the ImageNet
    # weights first would only add a download and startup latency.
    net = models.resnet18(weights=None)
    # Freeze the backbone, mirroring the transfer-learning setup used for
    # training. Not needed for inference (predictions run under no_grad) but
    # kept so the model's parameter flags match the original setup.
    for param in net.parameters():
        param.requires_grad = False
    # Replace the classifier head so its output size matches our classes.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    net.eval()
    return net


model_resnet = load_model()
model = model_resnet
# Per-channel ImageNet mean / standard deviation, the normalization statistics
# associated with the ImageNet-pretrained ResNet backbone.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to each uploaded image before inference. It is meant
# to mirror the training pipeline — NOTE(review): the training code is not in
# this file; confirm the two stay in sync.
transform = transforms.Compose(
    [
        # Fixed 224x224 input; the aspect ratio is not preserved.
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Human-readable labels, indexed by the model's output position. The order
# must match the class indices used when the checkpoint was trained —
# presumably ImageFolder's alphabetical folder order; confirm against training.
classes = (
    'Bean',
    'Bitter_Gourd',
    'Bottle_Gourd',
    'Brinjal',
    'Broccoli',
    'Cabbage',
    'Capsicum',
    'Carrot',
    'Cauliflower',
    'Cucumber',
    'Papaya',
    'Potato',
    'Pumpkin',
    'Radish',
    'Tomato',
)
def predict(image):
    """Classify a single PIL image.

    Args:
        image: A ``PIL.Image.Image``. Images in a mode other than RGB
            (e.g. grayscale ``L``, ``RGBA``, palette ``P``) are converted to
            RGB first, because both the normalization step and the model
            expect exactly three channels.

    Returns:
        ``(label, confidence)``: ``label`` is the entry of ``classes`` with
        the highest softmax probability, and ``confidence`` is that
        probability expressed as a percentage in [0, 100].
    """
    # Normalize() uses 3-channel statistics; a 1- or 4-channel tensor would
    # fail to broadcast, so coerce the input to RGB here.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Add a leading batch dimension: the model consumes a batch of images.
    input_batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_batch)
    # Only one image is in the batch, so take row 0 before the softmax.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    max_value, predicted_class = torch.max(probabilities, 0)
    return classes[predicted_class.item()], max_value.item() * 100
# --- Streamlit page: upload an image, show it, and display the prediction. ---
st.title('Vegetable Classification for learning')
st.write('you can upload your image of veggies below')
# The `type` filter only checks the file extension, so the content may still
# fail to decode as an image.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # undecodable files and OSError for truncated ones; report it to the
        # user instead of surfacing a traceback.
        st.error('The uploaded file could not be read as an image. Please upload a valid JPG or PNG file.')
    else:
        st.image(image, caption='Uploaded Image')
        label, confidence = predict(image)
        st.write(f'Predicted label: {label}, confidence: {confidence:.2f}%')