"""ViT-based image classifier with optional checkpoint loading."""

import os
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.init as init

from .vit import ViT


class ViTLargeClassifier(nn.Module):
    """Classifier that wraps a ``ViT`` backbone with a fixed large configuration.

    The backbone is built with ``dim=1024``, ``depth=24``, ``heads=16``,
    ``mlp_dim=4096`` and dropout / embedding dropout of ``0.1``.

    On construction the model tries to load weights from the default
    checkpoint path (see :meth:`load`). If no checkpoint file is found, the
    ``nn.Linear`` and ``nn.LayerNorm`` submodules are re-initialized instead.

    Args:
        num_classes: Passed to ``ViT`` as ``num_classes``.
        image_size: Passed to ``ViT`` as ``image_size``.
        patch_size: Passed to ``ViT`` as ``patch_size``.
    """

    def __init__(self, num_classes: int = 14, image_size=224, patch_size=16):
        super().__init__()

        # Initialize the ViT model.
        self.vit = ViT(
            image_size=image_size,
            patch_size=patch_size,
            num_classes=num_classes,
            dim=1024,
            depth=24,
            heads=16,
            mlp_dim=4096,
            dropout=0.1,
            emb_dropout=0.1,
        )

        # Initialize weights: prefer a saved checkpoint, otherwise fall back
        # to explicit initialization.
        if not self.load():
            self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialize every ``nn.Linear`` and ``nn.LayerNorm`` submodule.

        Linear weights get Xavier-normal values and zero biases; LayerNorm
        gets unit weights and zero biases. Parameters that are ``None``
        (e.g. a layer built without bias / affine parameters) are skipped.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                if m.weight is not None:
                    init.ones_(m.weight)
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x):
        """Return the output of the wrapped ``ViT`` for input ``x``.

        NOTE(review): the expected shape of ``x`` is defined by ``ViT``,
        which is not visible here -- confirm against ``.vit``.
        """
        return self.vit(x)

    def load(self, filename: Optional[str] = None) -> bool:
        """Load a state dict from ``filename`` into this module.

        Args:
            filename: Path to the checkpoint file. When ``None``, defaults to
                ``best_pth/ViTLargeClassifier.pth`` next to this source file.

        Returns:
            ``True`` if the checkpoint was loaded, ``False`` if ``filename``
            does not point to an existing file.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise for an
            unreadable file or a state dict that does not match the model.
        """
        if filename is None:
            current_work_dir = os.path.dirname(__file__)
            filename = os.path.join(
                current_work_dir, "best_pth", "ViTLargeClassifier.pth"
            )

        # isfile (not exists) so that a directory at this path is reported
        # as "missing" instead of failing inside torch.load.
        if not os.path.isfile(filename):
            print("Model file does not exist.")
            return False

        # map_location="cpu" lets a checkpoint saved from GPU tensors load on
        # a CPU-only host; load_state_dict copies the values into the existing
        # parameters, so the module's current device is unchanged.
        # NOTE(security): torch.load unpickles the file. Only load trusted
        # checkpoints, or pass weights_only=True where the installed torch
        # version supports it.
        state_dict = torch.load(filename, map_location="cpu")
        self.load_state_dict(state_dict)
        print("Model loaded successfully.")
        return True