benjaminStreltzin committed on
Commit
152bbff
·
verified ·
1 Parent(s): 49babfe

Update vit_model_test.py

Browse files
Files changed (1) hide show
  1. vit_model_test.py +38 -95
vit_model_test.py CHANGED
@@ -1,95 +1,38 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch.utils.data import Dataset, DataLoader
5
- from torchvision import transforms
6
- from transformers import ViTForImageClassification
7
- from PIL import Image
8
- import os
9
- import pandas as pd
10
-
11
-
12
-
13
-
14
class CustomDataset(Dataset):
    """Dataset that yields images whose file paths are listed in a DataFrame.

    Args:
        dataframe: Table whose first column holds the image file paths.
        transform: Optional callable applied to every loaded RGB image
            before it is returned.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows (one image per row) in the DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the image at row ``idx``, converted to RGB and transformed."""
        image_path = self.dataframe.iloc[idx, 0]  # Image path is in the first column

        # Image.open keeps the underlying file open; convert() returns an
        # independent copy, so the handle can be released right away.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

        # Compare against None so that a valid but falsy callable (for
        # example an empty container of transforms) is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image
30
-
31
-
32
if __name__ == "__main__":
    # Evaluation script: classify every image found in `datasets/` with the
    # fine-tuned two-class ViT and print a label and confidence per image.

    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the pre-trained ViT and replace its classification head with a
    # 2-output linear layer so the layout matches the trained checkpoint.
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
    model.classifier = nn.Linear(model.config.hidden_size, 2)

    # Load the trained weights. map_location remaps tensors that were saved
    # from a GPU so the checkpoint also loads on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))
    model.to(device)
    model.eval()

    # Define the image preprocessing pipeline
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Load the test dataset
    # TODO: receive the image from Gradio/Streamlit instead of a fixed folder.
    test_set = 'datasets/'

    # Sort the listing so the printed predictions line up with a stable file
    # order, and skip sub-directories, which cannot be opened as images.
    image_paths = [
        os.path.join(test_set, filename)
        for filename in sorted(os.listdir(test_set))
        if os.path.isfile(os.path.join(test_set, filename))
    ]
    dataset = pd.DataFrame({'image_path': image_paths})

    test_dataset = CustomDataset(dataset, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Evaluate the model
    confidences = []
    predicted_labels = []

    with torch.no_grad():
        for images in test_loader:
            images = images.to(device)
            logits = model(images).logits  # Extract logits from the output
            probabilities = F.softmax(logits, dim=1)
            confidences_per_image, predicted = torch.max(probabilities, 1)
            # tolist() yields plain Python numbers, which print cleanly
            # regardless of the installed NumPy version.
            predicted_labels.extend(predicted.cpu().tolist())
            confidences.extend(confidences_per_image.cpu().tolist())

    print(predicted_labels)
    print(confidences)

    confidence_percentages = [confidence * 100 for confidence in confidences]
    for label, confidence in zip(predicted_labels, confidence_percentages):
        print(f"Predicted label: {label}, Confidence: {confidence:.2f}%")
95
-
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ from transformers import ViTForImageClassification
6
+ from PIL import Image
7
+
8
class CustomModel:
    """Two-class ViT image classifier backed by a fine-tuned checkpoint.

    The constructor builds the model once; ``predict`` can then be called
    repeatedly on individual PIL images.

    Args:
        weights_path: Path of the trained ``state_dict`` file to load.
    """

    def __init__(self, weights_path: str = 'trained_model.pth'):
        # Check for GPU availability and fall back to the CPU otherwise.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the pre-trained ViT and replace its classification head with
        # a 2-output linear layer so it matches the trained checkpoint.
        self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
        self.model.classifier = nn.Linear(self.model.config.hidden_size, 2)

        # map_location remaps tensors saved from a GPU so the checkpoint
        # also loads when this process only has a CPU.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Define the image preprocessing pipeline
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def predict(self, image: "Image.Image"):
        """Classify a single image.

        Args:
            image: The PIL image to classify; any mode is accepted.

        Returns:
            A ``(predicted_label, confidence)`` tuple where
            ``predicted_label`` is the integer class index and
            ``confidence`` is its softmax probability as a percentage.
        """
        # Grayscale, palette or RGBA inputs would give the wrong channel
        # count after ToTensor, so normalise to 3-channel RGB first.
        image = image.convert('RGB')

        # Preprocess the image and add the batch dimension.
        batch = self.preprocess(image).unsqueeze(0).to(self.device)

        # Perform inference without tracking gradients.
        with torch.no_grad():
            logits = self.model(batch).logits
            probabilities = F.softmax(logits, dim=1)
            confidences, predicted = torch.max(probabilities, 1)

        predicted_label = predicted.item()
        confidence = confidences.item() * 100  # Convert to percentage

        return predicted_label, confidence