|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.init as init |
|
from .vit import ViT |
|
|
|
class ViTLargeClassifier(nn.Module):
    """Image classifier wrapping a ViT-Large-sized ``ViT`` backbone.

    The backbone is built with dim=1024, depth=24, heads=16, mlp_dim=4096 and
    dropout / emb_dropout of 0.1. On construction the model first tries to
    restore weights from the default checkpoint (see :meth:`load`). Only if no
    checkpoint file is found are the Linear and LayerNorm layers freshly
    initialized.

    Args:
        num_classes: Number of output classes, forwarded to ``ViT``.
        image_size: Input image side length, forwarded to ``ViT``.
        patch_size: Patch side length, forwarded to ``ViT``.
    """

    def __init__(self, num_classes: int = 14, image_size=224, patch_size=16):
        super().__init__()

        self.vit = ViT(
            image_size=image_size,
            patch_size=patch_size,
            num_classes=num_classes,
            dim=1024,
            depth=24,
            heads=16,
            mlp_dim=4096,
            dropout=0.1,
            emb_dropout=0.1,
        )

        # Re-initialize only when there was nothing to restore; running the
        # init after a successful load would overwrite the loaded weights.
        if not self.load():
            self._init_weights()

    def _init_weights(self) -> None:
        """Init Linear layers Xavier-normal/zero-bias, LayerNorm to scale 1, shift 0."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                # LayerNorm(elementwise_affine=False) / bias=False leaves
                # these as None; initializing None would raise.
                if m.weight is not None:
                    init.ones_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x):
        """Run ``x`` through the ViT backbone and return its output.

        NOTE(review): presumably ``x`` is an image batch of shape
        (B, C, image_size, image_size) and the output is (B, num_classes)
        logits. TODO: confirm against ``ViT``.
        """
        return self.vit(x)

    def load(
        self,
        filename: "str | os.PathLike | None" = None,
        *,
        map_location="cpu",
        strict: bool = True,
    ) -> bool:
        """Load a saved ``state_dict`` into this model.

        Args:
            filename: Checkpoint path. Defaults to
                ``best_pth/ViTLargeClassifier.pth`` in the directory of this
                source file.
            map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
                checkpoints saved from a GPU also load on CPU-only machines.
                ``load_state_dict`` then copies the tensors onto whatever
                device the parameters already live on.
            strict: Forwarded to ``load_state_dict``.

        Returns:
            True if the checkpoint was loaded. False if ``filename`` is not an
            existing regular file.

        Raises:
            RuntimeError: raised by ``load_state_dict`` on shape mismatches,
                and on missing/unexpected keys when ``strict`` is True.
        """
        if filename is None:
            # Resolve relative to this module's location, not the process CWD.
            module_dir = os.path.dirname(__file__)
            filename = os.path.join(module_dir, "best_pth", "ViTLargeClassifier.pth")

        # isfile rather than exists: a directory path should be reported as
        # missing instead of crashing inside torch.load.
        if not os.path.isfile(filename):
            print("Model file does not exist.")
            return False

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(filename, map_location=map_location)
        self.load_state_dict(state_dict, strict=strict)
        print("Model loaded successfully.")
        return True
|
|