|
import streamlit as st |
|
from PIL import Image |
|
import torch |
|
import torch.nn as nn |
|
from torchvision import models, transforms |
|
from torch.utils.data import DataLoader |
|
from torchvision.datasets import ImageFolder |
|
|
|
|
|
# Number of output classes of the fine-tuned head. This must match the length
# of the class-name tuple used to decode predictions.
NUM_CLASSES = 15

# Fine-tuned checkpoint (a state_dict). The relative path is resolved against
# the current working directory the app is launched from.
CHECKPOINT_PATH = 'transfer_learning_resnet_15class.pth'


@st.cache_resource
def _load_model():
    """Build the ResNet-18 classifier and load the fine-tuned weights.

    Streamlit re-executes this script from the top on every widget
    interaction. The model is therefore cached with ``st.cache_resource``,
    so it is built and read from disk once per server process instead of on
    every upload.

    Returns:
        A ResNet-18 with a ``NUM_CLASSES``-way linear head, weights loaded
        from ``CHECKPOINT_PATH`` onto the CPU, parameters frozen, in eval mode.
    """
    # No pretrained weights are requested. load_state_dict below runs in
    # strict mode, so every parameter and buffer is overwritten by the
    # checkpoint anyway. Asking for IMAGENET1K_V1 here only triggered a
    # needless network download.
    net = models.resnet18(weights=None)
    net.fc = nn.Linear(net.fc.in_features, NUM_CLASSES)

    # NOTE: torch.load unpickles the file; only point it at a trusted checkpoint.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)

    # Inference only: no gradients needed, and BatchNorm layers must use
    # their running statistics rather than per-batch statistics.
    for param in net.parameters():
        param.requires_grad = False
    net.eval()
    return net


model = _load_model()
model_resnet = model  # alias kept for the original module-level name
|
|
|
|
|
# Preprocessing applied to every image before inference:
# 1. resize to a fixed 224x224 input,
# 2. convert the PIL image to a float tensor,
# 3. normalize each channel with the ImageNet per-channel mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Human-readable class names, indexed by the model's output position.
# NOTE(review): assumed to follow the same order as the class indices used
# during training — confirm against the training script.
classes = (
    'Bean',
    'Bitter_Gourd',
    'Bottle_Gourd',
    'Brinjal',
    'Broccoli',
    'Cabbage',
    'Capsicum',
    'Carrot',
    'Cauliflower',
    'Cucumber',
    'Papaya',
    'Potato',
    'Pumpkin',
    'Radish',
    'Tomato',
)
|
|
|
def predict(image):
    """Classify a PIL image with the loaded vegetable model.

    Args:
        image: A PIL image. Images in any mode other than RGB (e.g. RGBA
            PNGs or grayscale) are converted to RGB first. The normalization
            step expects exactly three channels and fails otherwise.

    Returns:
        A ``(label, confidence)`` tuple: the predicted name from ``classes``
        and the softmax probability of that class as a percentage (0-100).
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # ToTensor yields (C, H, W); the model expects a batch, so add a leading
    # batch dimension of size 1.
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    # Softmax over the single sample's logits, then take the top class.
    probabilities = torch.softmax(logits[0], dim=0)
    confidence, index = probabilities.max(dim=0)
    return classes[index.item()], confidence.item() * 100
|
|
|
# Streamlit page: title, upload widget, and prediction output.
st.title('Vegetable Classification for learning')

st.write('you can upload your image of veggies below')

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Pillow decodes lazily, so convert() is what actually reads the
        # pixel data; it must sit inside the try together with open().
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # unrecognized files and OSError for truncated/corrupt data. Show a
        # message instead of a raw traceback.
        st.error('Could not read the uploaded file as an image. Please upload a valid JPG or PNG file.')
    else:
        st.image(image, caption='Uploaded Image')

        label, confidence = predict(image)

        st.write(f'Predicted label: {label}, confidence: {confidence:.2f}%')