{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image classification with a small CNN\n",
    "\n",
    "Trains a small convolutional network on images read from `./data/train`, measures accuracy on `./data/test`, and exports the trained model as TorchScript."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import libraries\n",
    "import torch \n",
    "from torchvision import datasets, transforms \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import ImageFolder\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing applied to every image (train and test alike).\n",
    "IMAGE_SIZE = 224\n",
    "# Per-channel mean/std commonly used with ImageNet-trained torchvision models.\n",
    "NORM_MEAN = (0.485, 0.456, 0.406)\n",
    "NORM_STD = (0.229, 0.224, 0.225)\n",
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(NORM_MEAN, NORM_STD),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the datasets; ImageFolder expects one sub-directory per class.\n",
    "train_dataset = ImageFolder('./data/train', transform=transform)\n",
    "test_dataset = ImageFolder('./data/test', transform=transform)\n",
    "\n",
    "# Labels are indices into `classes`, so both splits must agree on the class list.\n",
    "assert train_dataset.classes == test_dataset.classes, 'train/test class folders differ'\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make cnn model\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    \"\"\"Small CNN: two conv + max-pool stages followed by three linear layers.\n",
    "\n",
    "    Args:\n",
    "        num_classes: Number of output logits (default 3).\n",
    "        image_size: Side length of the square input images (default 224).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If image_size is too small to survive both conv/pool stages.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=3, image_size=224):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "\n",
    "        # Each stage is an unpadded 5x5 conv (size - 4) followed by a 2x2 pool (// 2).\n",
    "        side = image_size\n",
    "        for _ in range(2):\n",
    "            side = (side - 4) // 2\n",
    "        if side < 1:\n",
    "            raise ValueError('image_size {} is too small for this network'.format(image_size))\n",
    "\n",
    "        # For the default 224 input: 224 -> 110 -> 53, i.e. 16 * 53 * 53 features.\n",
    "        self.fc1 = nn.Linear(16 * side * side, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return raw class logits of shape (batch, num_classes).\"\"\"\n",
    "        # ReLU after every conv / hidden linear layer: without a nonlinearity the\n",
    "        # stacked linear layers reduce to a single affine map.\n",
    "        x = self.pool(torch.relu(self.conv1(x)))\n",
    "        x = self.pool(torch.relu(self.conv2(x)))\n",
    "        # Keep the batch dimension fixed so a wrong input size fails loudly in fc1\n",
    "        # instead of being folded into the batch dimension.\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        x = torch.relu(self.fc2(x))\n",
    "        return self.fc3(x)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration and data loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration.\n",
    "SEED = 42\n",
    "batch_size = 8\n",
    "\n",
    "# Seed PyTorch's global RNG so weight initialisation and shuffling repeat across runs.\n",
    "torch.manual_seed(SEED)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "# Accuracy does not depend on sample order, so keep evaluation deterministic.\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n"
   ]
  },
{
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model, loss and optimizer.\n",
    "model = CNN()\n",
    "loss_function = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Train the model\n",
    "\n",
    "num_epochs = 10\n",
    "\n",
    "model.train()\n",
    "for epoch in range(num_epochs):\n",
    "    running_loss = 0.0\n",
    "    for images, labels in train_loader:\n",
    "        outputs = model(images)\n",
    "        loss = loss_function(outputs, labels)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Weight by batch size so a smaller final batch does not skew the mean.\n",
    "        running_loss += loss.item() * labels.size(0)\n",
    "\n",
    "    # Report the mean loss over the whole epoch rather than a single batch.\n",
    "    epoch_loss = running_loss / len(train_loader.dataset)\n",
    "    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, epoch_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#iterate over the test data \n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "model.eval()\n",
    "# No gradients are needed for inference; skipping them saves memory and time.\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader:\n",
    "        outputs = model(images)\n",
    "\n",
    "        # Predicted class = index of the largest logit.\n",
    "        predicted = outputs.argmax(dim=1)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "        total += labels.size(0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#calculate the accuracy\n",
    "accuracy = 100 * correct / total\n",
    "print('Accuracy: {}%'.format(accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note on the previously recorded run.** The outputs stored with the earlier version of this notebook showed the first-batch training loss falling from 1.0981 (epoch 1) to 0.0012 (epoch 10), while test accuracy was 53.33%. That gap suggests the model was overfitting the training set. Re-run the notebook to regenerate the outputs for the current code."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs":
[],
   "source": [
    "# Export the trained model as TorchScript.\n",
    "import os\n",
    "\n",
    "model_path = './models/cat_dog_cnn.pt'\n",
    "# Saving fails if the target directory is missing, so create it first.\n",
    "os.makedirs(os.path.dirname(model_path), exist_ok=True)\n",
    "\n",
    "# Switch to inference mode before export.\n",
    "model.eval()\n",
    "model_scripted = torch.jit.script(model)\n",
    "model_scripted.save(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}