DogVision / model.py
vapit's picture
add final code
37632e1
raw
history blame
2.21 kB
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
# Custom transformation to handle palette images
def convert_to_rgba(image):
    """Return *image* as RGBA when it is palette-based, otherwise unchanged.

    Args:
        image: An object exposing a ``mode`` attribute and a
            ``convert(mode)`` method (presumably a PIL image -- confirm
            against callers).

    Returns:
        The result of ``image.convert('RGBA')`` if ``image.mode`` is ``'P'``
        (palette mode); the very same ``image`` object in every other case.
    """
    # Only palette images are converted; every other mode passes through.
    return image.convert('RGBA') if image.mode == 'P' else image
def create_model(num_classes: int = 120, seed: int = 42):
    """Build a frozen, ImageNet-pretrained ConvNeXt-Tiny with a new head.

    Args:
        num_classes: Number of output features of the replacement linear
            layer at the end of the classifier.
        seed: Value handed to ``torch.manual_seed`` immediately before the
            replacement classifier is constructed, so its initialisation is
            reproducible.

    Returns:
        A ``(model, combined_transforms)`` tuple. ``model`` is the
        ConvNeXt-Tiny network whose pretrained parameters have
        ``requires_grad`` set to ``False`` and whose ``classifier`` is a
        freshly created LayerNorm -> Flatten -> Linear stack.
        ``combined_transforms`` is the preprocessing pipeline described
        below.
    """
    pretrained = torchvision.models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1

    # Slot for optional augmentations. It is intentionally empty right now
    # (flip / rotation / colour-jitter / RGBA conversion are all disabled),
    # which makes this stage a pass-through.
    augmentations = transforms.Compose([])

    # Preprocessing order: augmentations -> fixed 224x224 resize -> tensor
    # conversion -> the transforms shipped with the pretrained weights.
    # NOTE(review): the weights' own preset presumably resizes/crops as well;
    # confirm whether the explicit Resize here is redundant before changing it.
    preprocessing = transforms.Compose([
        augmentations,
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        pretrained.transforms(),
    ])

    model = torchvision.models.convnext_tiny(weights=pretrained)

    # Freeze every parameter of the pretrained network. The head created
    # below is built afterwards, so its parameters stay trainable.
    model.requires_grad_(False)

    # Seed right before creating the new layers so their random
    # initialisation is reproducible for a given ``seed``.
    torch.manual_seed(seed)

    head_layers = [
        # Normalises over the trailing (768, 1, 1) dimensions.
        # NOTE(review): assumes the classifier input is (batch, 768, 1, 1) -- confirm.
        nn.LayerNorm([768, 1, 1], eps=1e-06, elementwise_affine=True),
        # Collapse everything after the batch dimension into one axis.
        nn.Flatten(start_dim=1),
        nn.Linear(in_features=768, out_features=num_classes, bias=True),
    ]
    model.classifier = nn.Sequential(*head_layers)

    return model, preprocessing