"""VGG19 network with pretrained weights loaded into ``MODEL_19`` at import time."""
import torch
from torch import nn

KERNEL_SIZE = (3, 3)

# Pretrained state dict loaded into MODEL_19 below. The filename matches the
# torchvision VGG19 checkpoint naming — presumably that checkpoint; confirm.
WEIGHTS_PATH = "models/vgg19-dcbb9e9d.pth"

# Layout of the convolutional trunk. An int is the output channel count of a
# 3x3/stride-1/pad-1 conv followed by ReLU; "M" is a 2x2 max-pool. Expanding
# this list yields exactly the same module order (features.0, features.2, ...)
# as a hand-written Sequential, so checkpoint keys still match under strict=True.
_VGG19_CFG = (
    64, 64, "M",
    128, 128, "M",
    256, 256, 256, 256, "M",
    512, 512, 512, 512, "M",
    512, 512, 512, 512, "M",
)


def _make_features(cfg=_VGG19_CFG, in_channels: int = 3) -> nn.Sequential:
    """Build the conv/ReLU/max-pool trunk described by *cfg*.

    Args:
        cfg: Sequence of channel counts (int) and ``"M"`` pooling markers.
        in_channels: Channel count of the input to the first conv.

    Returns:
        An ``nn.Sequential`` with one Conv2d+ReLU pair per int entry and one
        ``MaxPool2d(2)`` per ``"M"`` entry, in order.
    """
    layers = []
    for entry in cfg:
        if entry == "M":
            layers.append(nn.MaxPool2d(2))
        else:
            layers.append(nn.Conv2d(in_channels, entry, KERNEL_SIZE, 1, 1))
            layers.append(nn.ReLU())
            in_channels = entry
    return nn.Sequential(*layers)


class VGG19(nn.Module):
    """VGG19: 16 conv layers in five pooled stages, then a 3-layer MLP classifier.

    NOTE: ``__call__`` is rebound to :meth:`embeddings`, so ``model(x)`` returns
    flattened trunk features as a numpy array. It does NOT run :meth:`forward`
    and bypasses ``nn.Module`` hooks. Call ``model.forward(x)`` explicitly to
    get classifier logits.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.features = _make_features()
        # Parameter-free, so state_dict keys are unaffected. Five 2x2 pools
        # reduce a 224x224 input to 7x7, where this is a no-op. For other input
        # sizes it resizes to the 7x7 grid the first Linear layer expects.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(7 * 7 * 512, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 1000),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return classifier logits of shape (N, 1000).

        Args:
            x: Image batch of shape (N, 3, H, W). The first conv takes 3 input
               channels.
        """
        x = self.features(x)
        x = self.avgpool(x)
        # The trunk yields (N, 512, 7, 7), but Linear needs (N, 25088).
        # Keep the batch dimension.
        x = torch.flatten(x, 1)
        return self.classifier(x)

    @torch.no_grad()
    def embeddings(self, x: torch.Tensor) -> "numpy.ndarray":
        """Return the conv-trunk activations of ``x`` as a 1-D numpy array.

        The entire output is flattened, *including* the batch dimension. A
        batch of N inputs therefore yields one vector of length N * 512 * h * w.
        Runs without autograd tracking and copies to CPU before conversion, so
        GPU inputs work too.
        """
        return self.features(x).flatten().cpu().numpy()

    # Deliberate: calling the model returns embeddings, not logits (see class docstring).
    __call__ = embeddings


MODEL_19 = VGG19()
# map_location="cpu" lets the checkpoint load regardless of the device it was
# saved from. torch.load unpickles, so only point this at a trusted file
# (on torch >= 1.13 consider weights_only=True).
MODEL_19.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"), strict=True)
# Inference use: disable the classifier's Dropout layers for forward().
MODEL_19.eval()

if __name__ == "__main__":
    print(MODEL_19.state_dict().keys())