KristofGaming39 commited on
Commit
eebb924
·
verified ·
1 Parent(s): 2004b17

Create codeforpredicting.py

Browse files
Files changed (1) hide show
  1. codeforpredicting.py +27 -0
codeforpredicting.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision.transforms import transforms
4
+ from PIL import Image
5
+
6
+ model = torchvision.models.resnet50(pretrained=False)
7
+ model.fc = torch.nn.Linear(in_features=2048, out_features=1)
8
+ model.load_state_dict(torch.load('/content/zero_shot_classification_model.pth'))
9
+ model.eval()
10
+
11
+ test_transform = transforms.Compose([
12
+ transforms.Resize((224, 224)),
13
+ transforms.ToTensor(),
14
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
15
+ ])
16
+
17
+ test_image_path = '/content/test_img.png' # Replace with your own test image path!
18
+ test_image = Image.open(test_image_path).convert('RGB')
19
+ test_image = test_transform(test_image)
20
+ test_image = test_image.unsqueeze(0)
21
+
22
+ with torch.no_grad():
23
+ prediction = model(test_image)
24
+
25
+ probability = torch.sigmoid(prediction).item()
26
+
27
+ print(f"The probability of the image being a Roblox character is: {probability*100}%")