# pytorch / pages / 25_Deployment.py
# Author: eaglelandsonce — "Create 25_Deployment.py" (commit 003dc9d, verified, 2.19 kB)
import os

import streamlit as st
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.models as models
def save_model(path='resnet18.pth'):
    """Build an ImageNet-pretrained ResNet18 and save its state dict to disk.

    Args:
        path: Destination file for the state dict. Defaults to
            ``'resnet18.pth'``, the file ``load_model`` reads.
    """
    # NOTE: torchvision >= 0.13 deprecates ``pretrained=True`` in favour of
    # ``weights=models.ResNet18_Weights.DEFAULT``; kept here for compatibility
    # with older torchvision installs.
    model = models.resnet18(pretrained=True)
    torch.save(model.state_dict(), path)


# Streamlit re-executes this entire script on every widget interaction, so an
# unconditional call would rebuild the pretrained model and rewrite the weights
# file on each rerun. Only create the file when it does not exist yet.
if not os.path.exists('resnet18.pth'):
    save_model()
def load_model(path='resnet18.pth'):
    """Rebuild ResNet18 from a saved state dict and switch it to eval mode.

    Args:
        path: State-dict file to load, as written by ``save_model``.
            Defaults to ``'resnet18.pth'``.

    Returns:
        A ResNet18 module with the loaded weights, on the CPU, in eval mode.
    """
    model = models.resnet18()
    # map_location='cpu' lets a state dict saved from a GPU model load on a
    # CPU-only host; the caller moves the model to CUDA when it is available.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
def _build_preprocess():
    """Return the resize / center-crop / normalize pipeline for ResNet18 input."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel ImageNet mean/std used by torchvision's pretrained models.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def _load_categories(path="imagenet_classes.txt"):
    """Read class labels from *path*, one per line, stripped of whitespace."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f]


def _top_k(model, image, k=5):
    """Classify an RGB PIL *image* with *model*.

    Returns:
        ``(values, indices)`` from ``torch.topk`` over the softmax
        probabilities of the single image, limited to at most *k* entries.
    """
    # unsqueeze(0) adds the batch dimension the model expects.
    input_batch = _build_preprocess()(image).unsqueeze(0)
    # Keep the input on the same device as the model.
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')
    with torch.no_grad():
        output = model(input_batch)
    # The model returns unnormalized scores; softmax turns them into probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    return torch.topk(probabilities, min(k, probabilities.numel()))


def main():
    """Streamlit page: upload an image and show ResNet18's top-5 predictions."""
    st.title("Image Classification with ResNet18")
    # Accept both common JPEG extensions and PNG, not only ".jpg".
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    model = load_model()
    # Grayscale, palette, RGBA or CMYK uploads would otherwise produce a
    # tensor whose channel count does not match the 3-value Normalize stats
    # or the model's 3-channel input, so force RGB for inference only.
    top_prob, top_catid = _top_k(model, image.convert("RGB"))

    categories = _load_categories()
    for prob, catid in zip(top_prob, top_catid):
        st.write(categories[catid.item()], prob.item())


if __name__ == "__main__":
    main()