import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
# Custom transformation to handle palette images
def convert_to_rgba(image):
    # Check if the image mode is 'P' (palette mode)
    if image.mode == 'P':
        # Note: the pretrained ConvNeXt expects 3-channel input, so if this
        # transform is enabled in the pipeline below, follow it with a
        # .convert('RGB') (or convert directly to 'RGB' here) before the
        # tensor conversion and normalization.
        image = image.convert('RGBA')
    return image
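# Example (a minimal sketch using a synthetic palette image; requires PIL):
#   from PIL import Image
#   palette_img = Image.new('P', (8, 8))
#   print(convert_to_rgba(palette_img).mode)  # -> 'RGBA'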
def create_model(num_classes: int = 120, seed: int = 42):
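    """Create a ConvNeXt-Tiny feature extractor with a new classification head.

    Returns the model (backbone frozen, new head trainable) together with the
    transform pipeline expected by the pretrained weights.
    """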
    # 1. Download the default weights
    weights = torchvision.models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
    # 2. Setup transforms
    default_transforms = weights.transforms()
    custom_transforms = transforms.Compose([
        # transforms.RandomHorizontalFlip(p=0.5),  # Randomly flip images horizontally
        # transforms.Lambda(convert_to_rgba),  # Apply RGBA conversion if necessary
        # transforms.RandomRotation(degrees=10),  # Randomly rotate images by up to 10 degrees
        # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Color jitter
    ])
    # 3. Combine the custom augmentations with ConvNeXt's default transforms.
    # Note: weights.transforms() already resizes to 236, center-crops to
    # 224x224, converts PIL images to tensors, and normalizes, so no separate
    # Resize/ToTensor step is needed here.
    combined_transforms = transforms.Compose([
        custom_transforms,   # First, apply the custom augmentations
        default_transforms,  # Then apply the default resize, crop, and normalization
    ])
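    # Applying combined_transforms to a PIL image yields a normalized
    # (3, 224, 224) float tensor ready for the model.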
    # 4. Create a model and apply the default weights
    model = torchvision.models.convnext_tiny(weights=weights)
    # 5. Freeze the base layers in the model (this stops all pretrained layers from training)
    for parameter in model.parameters():
        parameter.requires_grad = False
    # 6. Set the seed so the new classifier head initializes reproducibly
    torch.manual_seed(seed)
    # 7. Replace the classifier head so it outputs num_classes logits.
    # The new layers are created after the freeze, so they remain trainable.
    model.classifier = nn.Sequential(
        nn.LayerNorm([768, 1, 1], eps=1e-06, elementwise_affine=True),  # Normalize over the 768 channels of the 1x1 pooled feature map
        nn.Flatten(start_dim=1),  # Flatten to (batch_size, 768)
        nn.Linear(in_features=768, out_features=num_classes, bias=True)
    )
    return model, combined_transforms
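if __name__ == '__main__':
    # Minimal smoke test (a sketch; the blank synthetic image stands in for real data)
    from PIL import Image
    model, transform = create_model(num_classes=120)
    model.eval()
    dummy_img = Image.new('RGB', (300, 300))   # hypothetical 300x300 input
    batch = transform(dummy_img).unsqueeze(0)  # -> (1, 3, 224, 224)
    with torch.no_grad():
        logits = model(batch)
    print(logits.shape)  # expected: torch.Size([1, 120])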