{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torchvision.transforms as transforms\n", "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", "from PIL import Image\n", "from torch.utils.data import ConcatDataset, DataLoader, Dataset\n", "from torchvision.datasets import DatasetFolder\n", "from tqdm import tqdm\n", "from PIL import ImageFile\n", "\n", "ImageFile.LOAD_TRUNCATED_IMAGES = True\n", "random.seed(1234)\n", "np.random.seed(1234)\n", "torch.manual_seed(1234)\n", "\n", "folder = \"./datasets\"\n", "NUM_CLASSES = 14\n", "\n", "\n", "train_tfm = transforms.Compose(\n", " [\n", " # Resize the image to a fixed shape (height = width = 256)\n", " transforms.Resize((224, 224)),\n", " transforms.Lambda(lambda x: x.convert(\"RGB\")),\n", " # Random horizontal flip to increase robustness to object orientation\n", " transforms.RandomHorizontalFlip(),\n", " # Random rotation, a common transformation to handle rotated images\n", " transforms.RandomRotation(\n", " 20\n", " ), # Rotate image by a random angle between -20 and 20 degrees\n", " # Random cropping to simulate random scene zoom\n", " transforms.RandomResizedCrop(\n", " 224, scale=(0.8, 1.0)\n", " ), # Crop and resize to 224x224\n", " # Random color jitter to make the model robust to lighting changes\n", " # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),\n", " # Random affine transformation (translation, scaling, rotation)\n", " # transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.8, 1.2)),\n", " # Convert to tensor\n", " transforms.ToTensor(),\n", " # Normalize the image with mean and standard deviation for better convergence\n", " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", " ]\n", ")\n", "\n", "test_tfm = transforms.Compose(\n", " [\n", " # Resize to the fixed size\n", " transforms.Resize((224, 224)),\n", " transforms.Lambda(lambda x: x.convert(\"RGB\")),\n", " # Convert to tensor\n", " transforms.ToTensor(),\n", " # Normalize the image with mean and standard deviation (same as in training)\n", " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", " ]\n", ")\n", "\n", "\n", "def get_dataset():\n", " train_set = DatasetFolder(\n", " folder + \"/train/labeled\",\n", " loader=lambda x: Image.open(x),\n", " extensions=\"jpg\",\n", " # transform=train_tfm,\n", " )\n", " valid_set = DatasetFolder(\n", " folder + \"/val\",\n", " loader=lambda x: Image.open(x),\n", " extensions=\"jpg\",\n", " transform=test_tfm,\n", " )\n", " unlabeled_set = DatasetFolder(\n", " folder + \"/train/unlabeled\",\n", " loader=lambda x: Image.open(x),\n", " extensions=\"jpg\",\n", " # transform=train_tfm,\n", " )\n", " test_set = DatasetFolder(\n", " folder + \"/test\",\n", " loader=lambda x: Image.open(x),\n", " extensions=\"jpg\",\n", " transform=test_tfm,\n", " )\n", " return train_set, valid_set, unlabeled_set, test_set\n", "\n", "\n", "def train_collate_fn(batch):\n", " data, labels = zip(*batch)\n", " # data = torch.stack(data)\n", " labels = torch.tensor(labels)\n", " return data, labels\n", "\n", "\n", "def test_collate_fn(batch):\n", " data, labels = zip(*batch)\n", " data = torch.stack(data)\n", " labels = torch.tensor(labels)\n", " return data, labels\n", "\n", "\n", "from utils import CustomDataset\n", "\n", "\n", "def update_dataset(\n", " train_set, unlabeled_set, model, threshold, 
batch_size=128, num_workers=8\n", ") -> Dataset:\n", "    \"\"\"\n", "    This is the core function to generate a pseudo-labeled dataset using the given model.\n", "    inputs:\n", "    - train_set: The labeled training set\n", "    - unlabeled_set: The unlabeled dataset to be pseudo-labeled\n", "    - model: The trained model to generate pseudo-labels\n", "    - threshold: Confidence threshold for pseudo-labeling\n", "    - batch_size: Batch size for DataLoader\n", "    - num_workers: Number of workers for DataLoader\n", "    outputs:\n", "    - new_set: The updated dataset with pseudo-labels\n", "    \"\"\"\n", "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "    # Make sure the model is in eval mode.\n", "    model.eval()\n", "    # Define softmax function.\n", "    softmax = nn.Softmax(dim=-1)\n", "\n", "    # Create a dataloader for the unlabeled data\n", "    unlabeled_loader = DataLoader(\n", "        unlabeled_set,\n", "        batch_size=batch_size,\n", "        shuffle=False,\n", "        num_workers=num_workers,\n", "        pin_memory=True,\n", "        collate_fn=train_collate_fn,\n", "    )\n", "\n", "    # List to store the most confident predictions\n", "    confident_samples = []\n", "    confident_labels = []\n", "\n", "    with torch.no_grad():\n", "        for batch_idx, (images, _) in enumerate(\n", "            tqdm(unlabeled_loader, desc=\"Generating pseudo-labels\")\n", "        ):\n", "            # Apply test transform to each image in the batch\n", "            new_images = torch.stack([test_tfm(img) for img in images])\n", "\n", "            # Forward pass through the model\n", "            outputs = model(new_images.to(device))\n", "\n", "            # Apply softmax to get probabilities\n", "            probabilities = softmax(outputs)\n", "\n", "            # Get the maximum probability and corresponding class for each sample\n", "            max_probs, pseudo_labels = torch.max(probabilities, dim=1)\n", "\n", "            # For each sample in the batch, check confidence threshold\n", "            for i, (prob, label) in enumerate(zip(max_probs, pseudo_labels)):\n", "                # If the prediction is confident enough, add to confident set\n", "                if prob.item() > threshold:\n", "                    # Get the actual index in the unlabeled_set\n", "                    idx = batch_idx * batch_size + i\n", "                    if idx < len(unlabeled_set):\n", "                        img, _ = unlabeled_set[idx]\n", "                        confident_samples.append(img)\n", "                        confident_labels.append(int(label.cpu()))\n", "\n", "    # Create new dataset from the confident predictions\n", "    if confident_samples:\n", "        pseudo_set = CustomDataset(images=confident_samples, labels=confident_labels)\n", "\n", "        # Combine with existing labeled data\n", "        new_set = ConcatDataset([train_set, pseudo_set])\n", "\n", "        print(f\"Added {len(confident_samples)} pseudo-labeled samples to training set\")\n", "    else:\n", "        print(\"No confident pseudo-labels found.\")\n", "        new_set = train_set\n", "    return new_set" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from models import *  # noqa: F403\n", "from torch import optim\n", "\n", "# \"cuda\" only when GPUs are available.\n", "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "# Initialize a model, and put it on the device specified.\n", "# model = Classifier().to(device)\n", "# model = VGG16Classifier(num_classes=NUM_CLASSES).to(device)\n", "# model = ResNet50Classifier(num_classes=NUM_CLASSES).to(device)\n", "model = ResNet101Classifier(num_classes=NUM_CLASSES).to(device)\n", "# model = InceptionV3Classifier(num_classes=NUM_CLASSES).to(device)\n", "# model = ViTLargeClassifier(num_classes=NUM_CLASSES).to(device)\n", "\n", "# !!! 
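NOTE: \"best_model.pth\" is written by the training cell below; when training from\n", "# scratch, comment out the load_state_dict call so it does not fail on a missing file.\n", "# !!! 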
ONLY USE THIS FOR RESUME TRAINING !!!\n", "model.load_state_dict(torch.load(\"best_model.pth\"))\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "\n", "optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)\n", "\n", "scheduler = ReduceLROnPlateau(\n", " optimizer,\n", " mode=\"min\",\n", " factor=0.1,\n", " patience=2,\n", " threshold=0.001,\n", " threshold_mode=\"rel\",\n", " cooldown=0,\n", " min_lr=0,\n", " eps=1e-08,\n", ")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# TRAINING\n", "do_semi = True\n", "batch_size = 128\n", "num_workers = 8\n", "train_set, valid_set, unlabeled_set, test_set = get_dataset()\n", "\n", "train_loader = DataLoader(\n", " dataset=train_set,\n", " batch_size=batch_size,\n", " shuffle=True,\n", " num_workers=num_workers,\n", " collate_fn=train_collate_fn,\n", ")\n", "valid_loader = DataLoader(\n", " dataset=valid_set,\n", " batch_size=batch_size,\n", " shuffle=True,\n", " num_workers=num_workers,\n", " collate_fn=test_collate_fn,\n", ")\n", "test_loader = DataLoader(\n", " dataset=test_set,\n", " batch_size=batch_size,\n", " shuffle=False,\n", " num_workers=num_workers,\n", " collate_fn=test_collate_fn,\n", ")\n", "best_valid_loss = float(\"inf\")\n", "start_epoch = 100\n", "epochs = 100\n", "threshold = 0.8\n", "early_stop = False\n", "for epoch in range(start_epoch, epochs):\n", " if do_semi and epoch > epochs // 4 and epoch % 2 == 0:\n", " new_set = update_dataset(\n", " train_set=train_set,\n", " unlabeled_set=unlabeled_set,\n", " model=model,\n", " threshold=threshold,\n", " batch_size=batch_size,\n", " num_workers=num_workers,\n", " )\n", " train_loader = DataLoader(\n", " dataset=new_set,\n", " batch_size=batch_size,\n", " shuffle=True,\n", " num_workers=num_workers,\n", " pin_memory=True,\n", " collate_fn=train_collate_fn,\n", " )\n", "\n", " # ---------- Training ----------\n", " model.train()\n", "\n", " train_loss = []\n", " train_accs = []\n", "\n", " for batch in tqdm(train_loader):\n", " imgs, labels = batch\n", "\n", " new_images = torch.stack([train_tfm(img) for img in imgs])\n", "\n", " logits = model(new_images.to(device))\n", "\n", " loss = criterion(logits, labels.to(device))\n", "\n", " optimizer.zero_grad()\n", "\n", " loss.backward()\n", "\n", " optimizer.step()\n", "\n", " acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()\n", "\n", " train_loss.append(loss.item())\n", " train_accs.append(acc)\n", "\n", " train_loss = sum(train_loss) / len(train_loss)\n", " train_acc = sum(train_accs) / len(train_accs)\n", "\n", " print(\n", " f\"[ Train | {epoch + 1:03d}/{epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}\"\n", " )\n", " # ---------- Validation ----------\n", " model.eval()\n", "\n", " valid_loss = []\n", " valid_accs = []\n", "\n", " for batch in tqdm(valid_loader):\n", " imgs, labels = batch\n", "\n", " with torch.no_grad():\n", " logits = model(imgs.to(device))\n", "\n", " loss = criterion(logits, labels.to(device))\n", "\n", " acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()\n", "\n", " valid_loss.append(loss.item())\n", " valid_accs.append(acc)\n", "\n", " valid_loss = sum(valid_loss) / len(valid_loss)\n", " valid_acc = sum(valid_accs) / len(valid_accs)\n", "\n", " print(\n", " f\"[ Valid | {epoch + 1:03d}/{epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}\"\n", " )\n", "\n", " scheduler.step(metrics=valid_loss)\n", "\n", " if valid_loss < best_valid_loss:\n", " best_valid_loss = valid_loss\n", " 
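# Persist the best checkpoint; the \"LOAD BEST MODEL\" cell below reloads this file for testing.\n", "        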
torch.save(model.state_dict(), \"best_model.pth\")\n", " print(f\"Model saved with loss: {valid_loss:.5f}, acc: {valid_acc:.5f}\")\n", " elif early_stop:\n", " if epoch > epochs // 2 and valid_loss > best_valid_loss * 1.2:\n", " print(\"Early stopping\")\n", " break\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# torch.save(model.state_dict(), \"best_model.pth\")\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ResNet101Classifier(\n", " (resnet): ResNet(\n", " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (downsample): 
Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (3): Bottleneck(\n", " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " 
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (3): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (4): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (5): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (6): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (7): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (8): Bottleneck(\n", " (conv1): Conv2d(1024, 
256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (9): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (10): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (11): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (12): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (13): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): 
ReLU(inplace=True)\n", " )\n", " (14): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (15): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (16): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (17): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (18): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (19): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (20): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (21): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (22): Bottleneck(\n", " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): Bottleneck(\n", " (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): Bottleneck(\n", " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " (2): Bottleneck(\n", " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " )\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (fc): Linear(in_features=2048, out_features=14, bias=True)\n", " )\n", ")" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# LOAD BEST MODEL\n", "\n", "model.load_state_dict(torch.load(\"best_model.pth\"))\n", "model.eval()\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 33/33 [00:07<00:00, 4.14it/s]\n" ] } ], "source": [ "# TEST\n", "\n", "model.eval()\n", "\n", "# Initialize a list to store the predictions.\n", "predictions = []\n", "\n", "\n", "# Iterate the testing set by batches.\n", "for batch in tqdm(test_loader):\n", " imgs, labels = batch\n", "\n", " with torch.no_grad():\n", " logits = model(imgs.to(device))\n", "\n", " predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predictions saved to predict.csv\n" ] } ], "source": [ "# Save predictions into the file.\n", "with open(\"predict.csv\", \"w\") as f:\n", " # The first row must be \"Id, Category\"\n", " f.write(\"Id,Category\\n\")\n", "\n", " # For the rest of the rows, each image id corresponds to a predicted class.\n", " for i, pred in enumerate(predictions):\n", " f.write(f\"{i},{pred}\\n\")\n", "print(\"Predictions saved to predict.csv\")\n" ] } ], "metadata": { "kernelspec": { "display_name": ".venv (3.12.9)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 2 }