"""Streamlit app that classifies an uploaded image with a pretrained ResNet18.

The ImageNet weights are downloaded once and stored in ``resnet18.pth``;
class names are read from ``imagenet_classes.txt`` (one label per line, in
model-output order).
"""

import os

import streamlit as st
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.models as models

MODEL_PATH = "resnet18.pth"
CLASSES_PATH = "imagenet_classes.txt"
TOP_K = 5

# Standard ImageNet preprocessing; built once instead of on every upload.
PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# st.cache_resource (newer Streamlit) keeps one model instance across script
# reruns; fall back to a no-op decorator when it is not available.
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


def _resnet18_weights_enum():
    """Return torchvision's ResNet18_Weights enum, or None on older torchvision."""
    return getattr(models, "ResNet18_Weights", None)


def save_model():
    """Download pretrained ResNet18 weights and save the state dict to MODEL_PATH."""
    weights_enum = _resnet18_weights_enum()
    if weights_enum is not None:
        # torchvision >= 0.13 replaced ``pretrained=True`` with the weights enum.
        model = models.resnet18(weights=weights_enum.DEFAULT)
    else:
        model = models.resnet18(pretrained=True)
    torch.save(model.state_dict(), MODEL_PATH)


# Streamlit re-executes this script on every interaction, so only download the
# weights when they are not already on disk (the "run only once" intent).
if not os.path.exists(MODEL_PATH):
    save_model()


@_cache_resource
def load_model():
    """Build ResNet18, load the saved weights and return it in eval mode."""
    model = models.resnet18()
    # map_location keeps loading working on CPU-only machines.
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
    model.eval()
    return model


def _load_categories():
    """Return the list of class labels, or None if no label source is available.

    Reads CLASSES_PATH; if that file is missing, falls back to the category
    names shipped with torchvision's ResNet18 weights metadata (when present).
    """
    try:
        with open(CLASSES_PATH, encoding="utf-8") as f:
            return [line.strip() for line in f]
    except FileNotFoundError:
        weights_enum = _resnet18_weights_enum()
        if weights_enum is not None:
            return list(weights_enum.DEFAULT.meta["categories"])
        return None


def main():
    """Render the upload widget and display the top-5 predictions for the image."""
    st.title("Image Classification with ResNet18")

    # Upload an image (common formats, not only ".jpg").
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # Force 3 channels: RGBA/palette PNGs, grayscale or CMYK images would
    # otherwise not match the 3-channel Normalize statistics.
    rgb_image = image.convert("RGB")

    model = load_model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # in-place for nn.Module; a no-op once already on device
    input_batch = PREPROCESS(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)
    # The output holds unnormalized scores; softmax turns them into probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    categories = _load_categories()
    if categories is None:
        st.error(f"Class label file '{CLASSES_PATH}' not found.")
        return

    # Show the top-k categories with their probabilities.
    top_prob, top_catid = torch.topk(probabilities, TOP_K)
    for prob, catid in zip(top_prob, top_catid):
        st.write(categories[catid.item()], prob.item())


if __name__ == "__main__":
    main()