import torch


class DinoVisionClassifier(torch.nn.Module):
    """Classification head stacked on a DINOv2-style vision backbone.

    The backbone is called on the input, its ``norm`` submodule is (by
    default) applied to the result, and a small MLP head
    ``Linear -> ReLU -> Dropout -> Linear`` maps the features to class logits.
    """

    def __init__(
        self,
        dinov2,
        num_classes=5,
        *,
        in_features=None,
        hidden_dim=64,
        dropout=0.2,
        apply_norm=True,
    ):
        """Build the classifier.

        Args:
            dinov2: Backbone module. The last dimension of its output must
                equal ``in_features``; when ``apply_norm`` is true it must also
                expose a ``norm`` submodule.
            num_classes: Number of output logits.
            in_features: Width of the backbone features. When ``None``, taken
                from ``dinov2.embed_dim`` if that attribute exists, otherwise
                384 (the width previously hard-coded; presumably a ViT-S/14
                backbone -- confirm).
            hidden_dim: Width of the head's hidden layer.
            dropout: Dropout probability applied after the hidden activation.
            apply_norm: Whether ``forward`` applies ``dinov2.norm`` to the
                backbone output before the head.
        """
        super().__init__()
        if in_features is None:
            # Infer the feature width instead of hard-coding it so backbones
            # of other widths fit; 384 preserves the historical default.
            in_features = getattr(dinov2, "embed_dim", 384)
        self.transformer = dinov2
        self.apply_norm = apply_norm
        # Submodule layout (and therefore state_dict keys) is unchanged.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits for ``x``.

        ``x`` is passed straight to the backbone, so its expected shape is
        whatever ``dinov2`` accepts (presumably an image batch -- confirm).
        The result has ``num_classes`` entries in its last dimension.
        """
        features = self.transformer(x)
        if self.apply_norm:
            # NOTE(review): if ``dinov2`` is the torch.hub DINOv2 model, its
            # forward already returns the LayerNorm'd CLS token, so this
            # normalizes a second time. Kept on by default so existing trained
            # weights behave identically; consider apply_norm=False for new runs.
            features = self.transformer.norm(features)
        return self.classifier(features)