import os
import torch
import torchvision
from torch import nn
from torchvision import transforms
from PIL import Image

class CFG:
    DEVICE = 'cpu'
    NUM_DEVICES = torch.cuda.device_count()
    NUM_WORKERS = os.cpu_count()
    NUM_CLASSES = 4
    EPOCHS = 16
    BATCH_SIZE = 32
    LR = 0.001
    APPLY_SHUFFLE = True
    SEED = 768
    HEIGHT = 224
    WIDTH = 224
    CHANNELS = 3
    IMAGE_SIZE = (224, 224, 3)
    

class VisionTransformerModel(nn.Module):
    def __init__(self, backbone_model, name='vision-transformer',
                 num_classes=CFG.NUM_CLASSES, device=CFG.DEVICE):
        super().__init__()

        self.backbone_model = backbone_model
        self.device = device
        self.num_classes = num_classes
        self.name = name

        # The torchvision backbone keeps its 1000-way ImageNet head, so the
        # custom classifier maps those 1000 logits down to num_classes
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=1000, out_features=256, bias=True),
            nn.GELU(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=256, out_features=num_classes, bias=False)
        ).to(device)

    def forward(self, image):
        vit_output = self.backbone_model(image)
        return self.classifier(vit_output)
    

def get_vit_l32_model(device: torch.device = CFG.DEVICE) -> nn.Module:
    # Set the manual seeds
    torch.manual_seed(CFG.SEED)
    torch.cuda.manual_seed(CFG.SEED)

    # Load pretrained ViT-L/32 weights and move the model to the target device
    model_weights = torchvision.models.ViT_L_32_Weights.DEFAULT
    model = torchvision.models.vit_l_32(weights=model_weights).to(device)

    # Freeze backbone parameters so only the classifier head remains trainable
    for param in model.parameters():
        param.requires_grad = False

    return model

# Get ViT-L/32 backbone
vit_backbone = get_vit_l32_model(CFG.DEVICE)

vit_params = {
    'backbone_model'    : vit_backbone,
    'name'              : 'ViT-L-32',
    'device'            : CFG.DEVICE
}

# Generate Model
vit_model = VisionTransformerModel(**vit_params)

vit_model.load_state_dict(
    torch.load('vit_model.pth', map_location=torch.device('cpu'))
)
vit_model.eval()  # disable dropout so inference is deterministic


# Define the image transformation
transform = transforms.Compose([
    transforms.Resize((CFG.HEIGHT, CFG.WIDTH)),
    transforms.ToTensor()
])
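
# Note (assumption, not confirmed by this checkpoint): torchvision's pretrained
# weights also bundle a canonical preprocessing preset (resize, center-crop,
# ImageNet normalization), exposed via the weights enum. If 'vit_model.pth'
# was fine-tuned with that preset rather than the plain transform above, the
# sketch below would reproduce it:
#
#   canonical_transform = torchvision.models.ViT_L_32_Weights.DEFAULT.transforms()
#
# Which transform is correct depends on how the checkpoint was trained.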


def predict(image_path):
    # Force 3-channel RGB so grayscale or RGBA inputs match CFG.CHANNELS
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image)
    input_batch = input_tensor.unsqueeze(0).to(CFG.DEVICE)  # Add batch dimension

    # Run inference without tracking gradients
    with torch.no_grad():
        output = vit_model(input_batch)

    # Raw, unnormalized logits of shape (1, CFG.NUM_CLASSES)
    return output
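

# Minimal usage sketch: convert the raw logits from predict() into a class
# index and confidence. 'example.jpg' is a placeholder path, and the mapping
# from index to label depends on the dataset this checkpoint was trained on.
if __name__ == '__main__':
    logits = predict('example.jpg')
    probs = torch.softmax(logits, dim=1)
    pred_idx = probs.argmax(dim=1).item()
    print(f'Predicted class {pred_idx} (p={probs[0, pred_idx].item():.3f})')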