import torch import torch.nn as nn import torchvision from torchvision import transforms # Custom transformation to handle palette images def convert_to_rgba(image): # Check if the image mode is 'P' (palette mode) if image.mode == 'P': image = image.convert('RGBA') return image def create_model(num_classes: int = 120, seed: int = 42): # 1. Download the default weights weights = torchvision.models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1 # 2. Setup transforms default_transforms = weights.transforms() custom_transforms = transforms.Compose([ # transforms.RandomHorizontalFlip(p=0.5), # Randomly flip images horizontally # transforms.Lambda(convert_to_rgba), # Apply RGBA conversion if necessary # transforms.RandomRotation(degrees=10), # Randomly rotate images by up to 10 degrees # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # Color jitter ]) # 3. Combine custom and ViT's default transforms combined_transforms = transforms.Compose([ custom_transforms, # First, apply your custom augmentations transforms.Resize((224, 224)), # Resize to ConvNext's input size if needed (ConvNext expects 224x224) transforms.ToTensor(), # Convert image to Tensor default_transforms, # Apply default normalization (mean, std) ]) # 4. Create a model and apply the default weights model = torchvision.models.convnext_tiny(weights=weights) # 5. Freeze the base layers in the model (this will stop all layers from training) for parameters in model.parameters(): parameters.requires_grad = False # 6. Set seeds for reproducibility torch.manual_seed(seed) # 7. Modify the number of output layers (add a dropout layer for regularization) model.classifier = nn.Sequential( nn.LayerNorm([768, 1, 1], eps=1e-06, elementwise_affine=True), # Apply LayerNorm on the channel dimension (768) nn.Flatten(start_dim=1), # Flatten the tensor from dimension 1 onwards (batch size remains intact) nn.Linear(in_features=768, out_features=num_classes, bias=True) ) return model, combined_transforms