eaglelandsonce commited on
Commit
003dc9d
·
verified ·
1 Parent(s): cbc79c2

Create 25_Deployment.py

Browse files
Files changed (1) hide show
  1. pages/25_Deployment.py +67 -0
pages/25_Deployment.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ import torchvision.transforms as transforms
5
+ import torchvision.models as models
6
+
7
+ # Save the model (this should be run only once, so it is placed here for completeness)
8
+ def save_model():
9
+ model = models.resnet18(pretrained=True)
10
+ torch.save(model.state_dict(), 'resnet18.pth')
11
+
12
+ # Call save_model to save the model
13
+ save_model()
14
+
15
+ # Load the model
16
+ def load_model():
17
+ model = models.resnet18()
18
+ model.load_state_dict(torch.load('resnet18.pth'))
19
+ model.eval()
20
+ return model
21
+
22
+ def main():
23
+ st.title("Image Classification with ResNet18")
24
+
25
+ # Upload an image
26
+ uploaded_file = st.file_uploader("Choose an image...", type="jpg")
27
+ if uploaded_file is not None:
28
+ image = Image.open(uploaded_file)
29
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
30
+ st.write("")
31
+ st.write("Classifying...")
32
+
33
+ # Load the model
34
+ model = load_model()
35
+
36
+ # Preprocess the image
37
+ preprocess = transforms.Compose([
38
+ transforms.Resize(256),
39
+ transforms.CenterCrop(224),
40
+ transforms.ToTensor(),
41
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
42
+ ])
43
+
44
+ input_tensor = preprocess(image)
45
+ input_batch = input_tensor.unsqueeze(0)
46
+
47
+ # Ensure the input is on the same device as the model
48
+ if torch.cuda.is_available():
49
+ input_batch = input_batch.to('cuda')
50
+ model.to('cuda')
51
+
52
+ with torch.no_grad():
53
+ output = model(input_batch)
54
+
55
+ # The output has unnormalized scores. To get probabilities, you can run a softmax on it.
56
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
57
+
58
+ # Print top 5 categories
59
+ with open("imagenet_classes.txt") as f:
60
+ categories = [line.strip() for line in f.readlines()]
61
+
62
+ top5_prob, top5_catid = torch.topk(probabilities, 5)
63
+ for i in range(top5_prob.size(0)):
64
+ st.write(categories[top5_catid[i]], top5_prob[i].item())
65
+
66
+ if __name__ == "__main__":
67
+ main()