{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "T4",
   "authorship_tag": "ABX9TyPSMRz+42399VfS0zVAPJAw"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mdTitle0001"
   },
   "source": [
    "# Frozen ResNet-152 classifier for six CIFAR-10 animal classes\n",
    "\n",
    "Train a small MLP head on top of a frozen, ImageNet-pretrained ResNet-152 to classify six CIFAR-10 classes (bird, cat, deer, dog, frog, horse).\n",
    "\n",
    "- Training data: the CIFAR-10 training split, filtered to the six classes.\n",
    "- Evaluation data: the held-out CIFAR-10 test split, filtered the same way.\n",
    "- Only the new classifier head is trained; the backbone's weights and BatchNorm statistics stay fixed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "N9Zst3gplupx"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "from torch.utils.data import DataLoader, Subset, TensorDataset\n",
    "from torchvision import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cfgConfig01"
   },
   "outputs": [],
   "source": [
    "# Configuration: everything a reader might want to tune\n",
    "SEED = 42\n",
    "DATA_ROOT = './data'\n",
    "BATCH_SIZE = 2048\n",
    "EPOCHS = 50\n",
    "LEARNING_RATE = 0.005\n",
    "HIDDEN_UNITS = 512\n",
    "MODEL_PATH = 'model.pth'\n",
    "\n",
    "# Channel statistics the ImageNet-pretrained torchvision weights were trained with\n",
    "IMAGENET_MEAN = (0.485, 0.456, 0.406)\n",
    "IMAGENET_STD = (0.229, 0.224, 0.225)\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mdData00001"
   },
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GuK20ZQ0lx_g"
   },
   "outputs": [],
   "source": [
    "# CIFAR-10 class indices of the six animal classes to keep\n",
    "selected_classes = {'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7}\n",
    "num_classes = len(selected_classes)\n",
    "\n",
    "# Remap the CIFAR-10 indices to a contiguous 0..num_classes-1 range (2->0, ..., 7->5),\n",
    "# derived from selected_classes so the two definitions cannot drift apart.\n",
    "class_mapping = {orig: new for new, orig in enumerate(selected_classes.values())}\n",
    "\n",
    "\n",
    "def get_filtered_dataset(train=True):\n",
    "    \"\"\"Load one CIFAR-10 split restricted to `selected_classes`.\n",
    "\n",
    "    Args:\n",
    "        train: load the training split if True, otherwise the test split.\n",
    "\n",
    "    Returns:\n",
    "        TensorDataset of (images, labels): normalized image tensors stacked along\n",
    "        dim 0, and labels remapped through `class_mapping`.\n",
    "    \"\"\"\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        # Normalize with the statistics the pretrained backbone expects\n",
    "        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),\n",
    "    ])\n",
    "    dataset = datasets.CIFAR10(root=DATA_ROOT, train=train, download=True, transform=transform)\n",
    "\n",
    "    # Filter on the raw targets first so unwanted images are never transformed\n",
    "    keep = [i for i, label in enumerate(dataset.targets) if label in class_mapping]\n",
    "    images = torch.stack([dataset[i][0] for i in keep])\n",
    "    labels = torch.tensor([class_mapping[dataset.targets[i]] for i in keep])\n",
    "    return TensorDataset(images, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Af80BS3Wl8-u"
   },
   "outputs": [],
   "source": [
    "# CIFAR-10 training split for training, test split for held-out evaluation\n",
    "train_dataset = get_filtered_dataset(train=True)\n",
    "val_dataset = get_filtered_dataset(train=False)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
    "len(train_dataset), len(val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "41bcqkvnpG3k"
   },
   "outputs": [],
   "source": [
    "# Sanity-check the remapping: each split must contain exactly the labels 0..num_classes-1\n",
    "expected_labels = list(range(num_classes))\n",
    "for split_name, split in [('train', train_dataset), ('validation', val_dataset)]:\n",
    "    unique_labels = split.tensors[1].unique().tolist()\n",
    "    print(f\"Unique labels in {split_name} dataset:\", unique_labels)\n",
    "    assert unique_labels == expected_labels, f\"{split_name}: {unique_labels} != {expected_labels}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mdModel0001"
   },
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SLqUZIGol-xL"
   },
   "outputs": [],
   "source": [
    "# ImageNet-pretrained ResNet-152 (IMAGENET1K_V1 = the weights `pretrained=True` used to load)\n",
    "model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)\n",
    "for param in model.parameters():\n",
    "    param.requires_grad = False  # Freeze feature extractor\n",
    "\n",
    "# Replace the classifier with an MLP head (one hidden layer) for the selected classes\n",
    "model.fc = nn.Sequential(\n",
    "    nn.Linear(model.fc.in_features, HIDDEN_UNITS),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(HIDDEN_UNITS, num_classes),\n",
    ")\n",
    "model.to(device)\n",
    "\n",
    "# Loss and optimizer: only the new head's parameters are optimized\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.fc.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mdTrain0001"
   },
   "source": [
    "## Training and evaluation helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dhDUoGsBmKBS"
   },
   "outputs": [],
   "source": [
    "def test_model(model, test_loader, verbose=True):\n",
    "    \"\"\"Evaluate `model` on `test_loader` without tracking gradients.\n",
    "\n",
    "    Uses the notebook-level `criterion` and `device`.\n",
    "\n",
    "    Args:\n",
    "        model: classifier returning one logit per class.\n",
    "        test_loader: DataLoader yielding (images, labels) batches.\n",
    "        verbose: print the loss and accuracy if True.\n",
    "\n",
    "    Returns:\n",
    "        (avg_loss, accuracy): mean per-sample loss and accuracy in percent.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `test_loader` yields no samples.\n",
    "    \"\"\"\n",
    "    model.eval()  # BatchNorm uses its running statistics\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    loss_sum = 0.0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in test_loader:\n",
    "            images, labels = images.to(device), labels.to(device)\n",
    "            outputs = model(images)\n",
    "            # criterion returns the batch mean; weight by batch size so a short final batch is not over-counted\n",
    "            loss_sum += criterion(outputs, labels).item() * labels.size(0)\n",
    "            correct += (outputs.argmax(dim=1) == labels).sum().item()\n",
    "            total += labels.size(0)\n",
    "\n",
    "    if total == 0:\n",
    "        raise ValueError(\"test_loader yielded no samples\")\n",
    "    avg_loss = loss_sum / total\n",
    "    accuracy = 100 * correct / total\n",
    "    if verbose:\n",
    "        print(f\"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%\")\n",
    "    return avg_loss, accuracy\n",
    "\n",
    "\n",
    "def train_model(model, train_loader, val_loader, epochs=10):\n",
    "    \"\"\"Train the classifier head, reporting validation metrics after every epoch.\n",
    "\n",
    "    Uses the notebook-level `optimizer`, `criterion` and `device`.\n",
    "\n",
    "    Args:\n",
    "        model: ResNet with a frozen backbone and a trainable `fc` head.\n",
    "        train_loader: DataLoader of training (images, labels) batches.\n",
    "        val_loader: DataLoader of held-out batches, evaluated after each epoch.\n",
    "        epochs: number of passes over `train_loader`.\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        # requires_grad=False freezes the backbone weights but not its BatchNorm running\n",
    "        # statistics, which update whenever the layer is in train mode. Keep the backbone\n",
    "        # in eval mode and put only the head in train mode.\n",
    "        model.eval()\n",
    "        model.fc.train()\n",
    "        loss_sum = 0.0\n",
    "        seen = 0\n",
    "        for images, labels in train_loader:\n",
    "            images, labels = images.to(device), labels.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            loss = criterion(model(images), labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            loss_sum += loss.item() * labels.size(0)\n",
    "            seen += labels.size(0)\n",
    "\n",
    "        val_loss, val_acc = test_model(model, val_loader, verbose=False)\n",
    "        print(f\"Epoch [{epoch+1}/{epochs}], Loss: {loss_sum / max(seen, 1):.4f}, \"\n",
    "              f\"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%\")\n",
    "    print(\"Training complete!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "trainRun001"
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "train_model(model, train_loader, val_loader, epochs=EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "I7_Da4L561-Y"
   },
   "outputs": [],
   "source": [
    "# Save the trained weights (run after training so the checkpoint matches the reported results)\n",
    "torch.save(model.state_dict(), MODEL_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mdEval00001"
   },
   "source": [
    "## Held-out evaluation\n",
    "\n",
    "Accuracy below is measured on the held-out CIFAR-10 test split (`val_loader`); evaluating on `train_loader` would report training accuracy instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2b2ZPAEnqDU1"
   },
   "outputs": [],
   "source": [
    "test_model(model, val_loader)"
   ]
  }
 ]
}