Spaces:
Configuration error
Configuration error
import torch | |
import torch.nn as nn | |
import torchvision.models as models | |
class ResClassifier(nn.Module):
    """Two-block MLP classification head.

    Each hidden block is Linear -> BatchNorm1d -> ReLU -> Dropout. A final
    Linear layer projects to ``class_num`` logits.

    Args:
        class_num: Number of output logits. Defaults to 14.
        in_features: Width of the input feature vector. Defaults to 128, the
            width this head was originally hard-coded for.
        hidden_dim: Width of both hidden layers. Defaults to 64.
        dropout: Dropout probability in each hidden block. Defaults to 0.5,
            the ``nn.Dropout`` default the original relied on.

    The attribute names and ``nn.Sequential`` layouts (``fc1``, ``fc2``,
    ``fc3``) are unchanged, so previously saved ``state_dict``s still load.
    """

    def __init__(self, class_num=14, in_features=128, hidden_dim=64, dropout=0.5):
        super().__init__()
        # Submodules are created in the same order as before, so the global
        # RNG is consumed identically and seeded initialisation is unchanged.
        self.fc1 = self._hidden_block(in_features, hidden_dim, dropout)
        self.fc2 = self._hidden_block(hidden_dim, hidden_dim, dropout)
        self.fc3 = nn.Linear(hidden_dim, class_num)

    @staticmethod
    def _hidden_block(in_dim, out_dim, dropout):
        """Build one Linear -> BatchNorm1d -> ReLU -> Dropout block."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim, affine=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

    def forward(self, x):
        """Map a 2-D input ``(batch, in_features)`` to ``(batch, class_num)`` logits.

        BatchNorm1d is applied to the hidden activations, so in training mode
        the batch must contain more than one sample.
        """
        fc1_emb = self.fc1(x)
        fc2_emb = self.fc2(fc1_emb)
        return self.fc3(fc2_emb)
class CC_model(nn.Module):
    """ResNet-50 backbone followed by a single linear classification layer.

    Args:
        num_classes: Number of output logits. Defaults to 14.
        weights: Forwarded unchanged to ``torchvision.models.resnet50``.
            Defaults to ``'ResNet50_Weights.DEFAULT'`` (pretrained weights, as
            before). Pass ``None`` to build a randomly initialised backbone,
            e.g. for offline tests.
    """

    def __init__(self, num_classes=14, weights='ResNet50_Weights.DEFAULT'):
        super().__init__()
        self.num_classes = num_classes
        self.model_resnet = models.resnet50(weights=weights)

        # Swap ResNet's own final fc for Identity so the backbone returns the
        # features that fc used to consume; keep their width for the heads.
        num_ftrs = self.model_resnet.fc.in_features
        self.model_resnet.fc = nn.Identity()

        self.classification_fc = nn.Linear(num_ftrs, num_classes)

        # NOTE(review): dr, fc1 and fc2 are built but never used in forward().
        # They are kept so existing checkpoints, whose state_dict contains
        # their keys, still load with strict=True. Because they get no
        # gradients, DistributedDataParallel will likely need
        # find_unused_parameters=True. Presumably dr -> fc1/fc2 was meant as a
        # 128-d embedding path (ResClassifier's default input width) — confirm.
        self.dr = nn.Linear(num_ftrs, 128)
        self.fc1 = ResClassifier(num_classes)
        self.fc2 = ResClassifier(num_classes)

    def forward(self, x):
        """Return classification logits with last dimension ``num_classes``.

        Only the backbone and ``classification_fc`` take part; the auxiliary
        ``dr``/``fc1``/``fc2`` modules are not called here.
        """
        feature = self.model_resnet(x)
        return self.classification_fc(feature)