{ "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ "---\n", "title: 05 Concise Logistic Regression\n", "description: Concise implementation of logistic regression model for binary image classification.\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\"Colab\"" ] }, { "cell_type": "markdown", "metadata": { "id": "gC6qMkJooFub" }, "source": [ "## Concise Logistic Regression for Image Classification\n", "\n", "- Shows a concise implementation of logistic regression for image classification\n", "- Uses PyTorch" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tI49R1p0n-XM" }, "outputs": [], "source": [
"# imports: standard library first, then third-party packages (alphabetical)\n",
"import os\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torchvision\n",
"from torchvision import datasets, models, transforms\n",
"\n",
"%matplotlib inline\n",
"\n",
"# pick the first CUDA GPU when PyTorch can see one; otherwise run on the CPU\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "O92KeM06pJqc", "outputId": "322d8266-f005-4b17-f18e-3d7046cba4b8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2022-04-03 16:17:19-- https://download.pytorch.org/tutorial/hymenoptera_data.zip\n", "Resolving download.pytorch.org (download.pytorch.org)... 13.226.230.76, 13.226.230.24, 13.226.230.114, ...\n", "Connecting to download.pytorch.org (download.pytorch.org)|13.226.230.76|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 47286322 (45M) [application/zip]\n", "Saving to: ‘hymenoptera_data.zip’\n", "\n", "hymenoptera_data.zi 100%[===================>] 45.10M 25.3MB/s in 1.8s \n", "\n", "2022-04-03 16:17:21 (25.3 MB/s) - ‘hymenoptera_data.zip’ saved [47286322/47286322]\n", "\n", "Archive: hymenoptera_data.zip\n", " creating: hymenoptera_data/\n", " creating: hymenoptera_data/train/\n", " creating: hymenoptera_data/train/ants/\n", " inflating: hymenoptera_data/train/ants/0013035.jpg \n", " inflating: hymenoptera_data/train/ants/1030023514_aad5c608f9.jpg \n", " inflating: hymenoptera_data/train/ants/1095476100_3906d8afde.jpg \n", " inflating: hymenoptera_data/train/ants/1099452230_d1949d3250.jpg \n", " inflating: hymenoptera_data/train/ants/116570827_e9c126745d.jpg \n", " inflating: hymenoptera_data/train/ants/1225872729_6f0856588f.jpg \n", " inflating: hymenoptera_data/train/ants/1262877379_64fcada201.jpg \n", " inflating: hymenoptera_data/train/ants/1269756697_0bce92cdab.jpg \n", " inflating: hymenoptera_data/train/ants/1286984635_5119e80de1.jpg \n", " inflating: hymenoptera_data/train/ants/132478121_2a430adea2.jpg \n", " inflating: hymenoptera_data/train/ants/1360291657_dc248c5eea.jpg \n", " inflating: hymenoptera_data/train/ants/1368913450_e146e2fb6d.jpg \n", " inflating: hymenoptera_data/train/ants/1473187633_63ccaacea6.jpg \n", " inflating: hymenoptera_data/train/ants/148715752_302c84f5a4.jpg \n", " inflating: hymenoptera_data/train/ants/1489674356_09d48dde0a.jpg \n", " inflating: hymenoptera_data/train/ants/149244013_c529578289.jpg \n", " inflating: hymenoptera_data/train/ants/150801003_3390b73135.jpg \n", " inflating: hymenoptera_data/train/ants/150801171_cd86f17ed8.jpg \n", " inflating: hymenoptera_data/train/ants/154124431_65460430f2.jpg \n", " inflating: hymenoptera_data/train/ants/162603798_40b51f1654.jpg \n", " inflating: hymenoptera_data/train/ants/1660097129_384bf54490.jpg \n", " inflating: hymenoptera_data/train/ants/167890289_dd5ba923f3.jpg \n", " 
inflating: hymenoptera_data/train/ants/1693954099_46d4c20605.jpg \n", " inflating: hymenoptera_data/train/ants/175998972.jpg \n", " inflating: hymenoptera_data/train/ants/178538489_bec7649292.jpg \n", " inflating: hymenoptera_data/train/ants/1804095607_0341701e1c.jpg \n", " inflating: hymenoptera_data/train/ants/1808777855_2a895621d7.jpg \n", " inflating: hymenoptera_data/train/ants/188552436_605cc9b36b.jpg \n", " inflating: hymenoptera_data/train/ants/1917341202_d00a7f9af5.jpg \n", " inflating: hymenoptera_data/train/ants/1924473702_daa9aacdbe.jpg \n", " inflating: hymenoptera_data/train/ants/196057951_63bf063b92.jpg \n", " inflating: hymenoptera_data/train/ants/196757565_326437f5fe.jpg \n", " inflating: hymenoptera_data/train/ants/201558278_fe4caecc76.jpg \n", " inflating: hymenoptera_data/train/ants/201790779_527f4c0168.jpg \n", " inflating: hymenoptera_data/train/ants/2019439677_2db655d361.jpg \n", " inflating: hymenoptera_data/train/ants/207947948_3ab29d7207.jpg \n", " inflating: hymenoptera_data/train/ants/20935278_9190345f6b.jpg \n", " inflating: hymenoptera_data/train/ants/224655713_3956f7d39a.jpg \n", " inflating: hymenoptera_data/train/ants/2265824718_2c96f485da.jpg \n", " inflating: hymenoptera_data/train/ants/2265825502_fff99cfd2d.jpg \n", " inflating: hymenoptera_data/train/ants/226951206_d6bf946504.jpg \n", " inflating: hymenoptera_data/train/ants/2278278459_6b99605e50.jpg \n", " inflating: hymenoptera_data/train/ants/2288450226_a6e96e8fdf.jpg \n", " inflating: hymenoptera_data/train/ants/2288481644_83ff7e4572.jpg \n", " inflating: hymenoptera_data/train/ants/2292213964_ca51ce4bef.jpg \n", " inflating: hymenoptera_data/train/ants/24335309_c5ea483bb8.jpg \n", " inflating: hymenoptera_data/train/ants/245647475_9523dfd13e.jpg \n", " inflating: hymenoptera_data/train/ants/255434217_1b2b3fe0a4.jpg \n", " inflating: hymenoptera_data/train/ants/258217966_d9d90d18d3.jpg \n", " inflating: hymenoptera_data/train/ants/275429470_b2d7d9290b.jpg \n", " inflating: 
hymenoptera_data/train/ants/28847243_e79fe052cd.jpg \n", " inflating: hymenoptera_data/train/ants/318052216_84dff3f98a.jpg \n", " inflating: hymenoptera_data/train/ants/334167043_cbd1adaeb9.jpg \n", " inflating: hymenoptera_data/train/ants/339670531_94b75ae47a.jpg \n", " inflating: hymenoptera_data/train/ants/342438950_a3da61deab.jpg \n", " inflating: hymenoptera_data/train/ants/36439863_0bec9f554f.jpg \n", " inflating: hymenoptera_data/train/ants/374435068_7eee412ec4.jpg \n", " inflating: hymenoptera_data/train/ants/382971067_0bfd33afe0.jpg \n", " inflating: hymenoptera_data/train/ants/384191229_5779cf591b.jpg \n", " inflating: hymenoptera_data/train/ants/386190770_672743c9a7.jpg \n", " inflating: hymenoptera_data/train/ants/392382602_1b7bed32fa.jpg \n", " inflating: hymenoptera_data/train/ants/403746349_71384f5b58.jpg \n", " inflating: hymenoptera_data/train/ants/408393566_b5b694119b.jpg \n", " inflating: hymenoptera_data/train/ants/424119020_6d57481dab.jpg \n", " inflating: hymenoptera_data/train/ants/424873399_47658a91fb.jpg \n", " inflating: hymenoptera_data/train/ants/450057712_771b3bfc91.jpg \n", " inflating: hymenoptera_data/train/ants/45472593_bfd624f8dc.jpg \n", " inflating: hymenoptera_data/train/ants/459694881_ac657d3187.jpg \n", " inflating: hymenoptera_data/train/ants/460372577_f2f6a8c9fc.jpg \n", " inflating: hymenoptera_data/train/ants/460874319_0a45ab4d05.jpg \n", " inflating: hymenoptera_data/train/ants/466430434_4000737de9.jpg \n", " inflating: hymenoptera_data/train/ants/470127037_513711fd21.jpg \n", " inflating: hymenoptera_data/train/ants/474806473_ca6caab245.jpg \n", " inflating: hymenoptera_data/train/ants/475961153_b8c13fd405.jpg \n", " inflating: hymenoptera_data/train/ants/484293231_e53cfc0c89.jpg \n", " inflating: hymenoptera_data/train/ants/49375974_e28ba6f17e.jpg \n", " inflating: hymenoptera_data/train/ants/506249802_207cd979b4.jpg \n", " inflating: hymenoptera_data/train/ants/506249836_717b73f540.jpg \n", " inflating: 
hymenoptera_data/train/ants/512164029_c0a66b8498.jpg \n", " inflating: hymenoptera_data/train/ants/512863248_43c8ce579b.jpg \n", " inflating: hymenoptera_data/train/ants/518773929_734dbc5ff4.jpg \n", " inflating: hymenoptera_data/train/ants/522163566_fec115ca66.jpg \n", " inflating: hymenoptera_data/train/ants/522415432_2218f34bf8.jpg \n", " inflating: hymenoptera_data/train/ants/531979952_bde12b3bc0.jpg \n", " inflating: hymenoptera_data/train/ants/533848102_70a85ad6dd.jpg \n", " inflating: hymenoptera_data/train/ants/535522953_308353a07c.jpg \n", " inflating: hymenoptera_data/train/ants/540889389_48bb588b21.jpg \n", " inflating: hymenoptera_data/train/ants/541630764_dbd285d63c.jpg \n", " inflating: hymenoptera_data/train/ants/543417860_b14237f569.jpg \n", " inflating: hymenoptera_data/train/ants/560966032_988f4d7bc4.jpg \n", " inflating: hymenoptera_data/train/ants/5650366_e22b7e1065.jpg \n", " inflating: hymenoptera_data/train/ants/6240329_72c01e663e.jpg \n", " inflating: hymenoptera_data/train/ants/6240338_93729615ec.jpg \n", " inflating: hymenoptera_data/train/ants/649026570_e58656104b.jpg \n", " inflating: hymenoptera_data/train/ants/662541407_ff8db781e7.jpg \n", " inflating: hymenoptera_data/train/ants/67270775_e9fdf77e9d.jpg \n", " inflating: hymenoptera_data/train/ants/6743948_2b8c096dda.jpg \n", " inflating: hymenoptera_data/train/ants/684133190_35b62c0c1d.jpg \n", " inflating: hymenoptera_data/train/ants/69639610_95e0de17aa.jpg \n", " inflating: hymenoptera_data/train/ants/707895295_009cf23188.jpg \n", " inflating: hymenoptera_data/train/ants/7759525_1363d24e88.jpg \n", " inflating: hymenoptera_data/train/ants/795000156_a9900a4a71.jpg \n", " inflating: hymenoptera_data/train/ants/822537660_caf4ba5514.jpg \n", " inflating: hymenoptera_data/train/ants/82852639_52b7f7f5e3.jpg \n", " inflating: hymenoptera_data/train/ants/841049277_b28e58ad05.jpg \n", " inflating: hymenoptera_data/train/ants/886401651_f878e888cd.jpg \n", " inflating: 
hymenoptera_data/train/ants/892108839_f1aad4ca46.jpg \n", " inflating: hymenoptera_data/train/ants/938946700_ca1c669085.jpg \n", " inflating: hymenoptera_data/train/ants/957233405_25c1d1187b.jpg \n", " inflating: hymenoptera_data/train/ants/9715481_b3cb4114ff.jpg \n", " inflating: hymenoptera_data/train/ants/998118368_6ac1d91f81.jpg \n", " inflating: hymenoptera_data/train/ants/ant photos.jpg \n", " inflating: hymenoptera_data/train/ants/Ant_1.jpg \n", " inflating: hymenoptera_data/train/ants/army-ants-red-picture.jpg \n", " inflating: hymenoptera_data/train/ants/formica.jpeg \n", " inflating: hymenoptera_data/train/ants/hormiga_co_por.jpg \n", " inflating: hymenoptera_data/train/ants/imageNotFound.gif \n", " inflating: hymenoptera_data/train/ants/kurokusa.jpg \n", " inflating: hymenoptera_data/train/ants/MehdiabadiAnt2_600.jpg \n", " inflating: hymenoptera_data/train/ants/Nepenthes_rafflesiana_ant.jpg \n", " inflating: hymenoptera_data/train/ants/swiss-army-ant.jpg \n", " inflating: hymenoptera_data/train/ants/termite-vs-ant.jpg \n", " inflating: hymenoptera_data/train/ants/trap-jaw-ant-insect-bg.jpg \n", " inflating: hymenoptera_data/train/ants/VietnameseAntMimicSpider.jpg \n", " creating: hymenoptera_data/train/bees/\n", " inflating: hymenoptera_data/train/bees/1092977343_cb42b38d62.jpg \n", " inflating: hymenoptera_data/train/bees/1093831624_fb5fbe2308.jpg \n", " inflating: hymenoptera_data/train/bees/1097045929_1753d1c765.jpg \n", " inflating: hymenoptera_data/train/bees/1232245714_f862fbe385.jpg \n", " inflating: hymenoptera_data/train/bees/129236073_0985e91c7d.jpg \n", " inflating: hymenoptera_data/train/bees/1295655112_7813f37d21.jpg \n", " inflating: hymenoptera_data/train/bees/132511197_0b86ad0fff.jpg \n", " inflating: hymenoptera_data/train/bees/132826773_dbbcb117b9.jpg \n", " inflating: hymenoptera_data/train/bees/150013791_969d9a968b.jpg \n", " inflating: hymenoptera_data/train/bees/1508176360_2972117c9d.jpg \n", " inflating: 
hymenoptera_data/train/bees/154600396_53e1252e52.jpg \n", " inflating: hymenoptera_data/train/bees/16838648_415acd9e3f.jpg \n", " inflating: hymenoptera_data/train/bees/1691282715_0addfdf5e8.jpg \n", " inflating: hymenoptera_data/train/bees/17209602_fe5a5a746f.jpg \n", " inflating: hymenoptera_data/train/bees/174142798_e5ad6d76e0.jpg \n", " inflating: hymenoptera_data/train/bees/1799726602_8580867f71.jpg \n", " inflating: hymenoptera_data/train/bees/1807583459_4fe92b3133.jpg \n", " inflating: hymenoptera_data/train/bees/196430254_46bd129ae7.jpg \n", " inflating: hymenoptera_data/train/bees/196658222_3fffd79c67.jpg \n", " inflating: hymenoptera_data/train/bees/198508668_97d818b6c4.jpg \n", " inflating: hymenoptera_data/train/bees/2031225713_50ed499635.jpg \n", " inflating: hymenoptera_data/train/bees/2037437624_2d7bce461f.jpg \n", " inflating: hymenoptera_data/train/bees/2053200300_8911ef438a.jpg \n", " inflating: hymenoptera_data/train/bees/205835650_e6f2614bee.jpg \n", " inflating: hymenoptera_data/train/bees/208702903_42fb4d9748.jpg \n", " inflating: hymenoptera_data/train/bees/21399619_3e61e5bb6f.jpg \n", " inflating: hymenoptera_data/train/bees/2227611847_ec72d40403.jpg \n", " inflating: hymenoptera_data/train/bees/2321139806_d73d899e66.jpg \n", " inflating: hymenoptera_data/train/bees/2330918208_8074770c20.jpg \n", " inflating: hymenoptera_data/train/bees/2345177635_caf07159b3.jpg \n", " inflating: hymenoptera_data/train/bees/2358061370_9daabbd9ac.jpg \n", " inflating: hymenoptera_data/train/bees/2364597044_3c3e3fc391.jpg \n", " inflating: hymenoptera_data/train/bees/2384149906_2cd8b0b699.jpg \n", " inflating: hymenoptera_data/train/bees/2397446847_04ef3cd3e1.jpg \n", " inflating: hymenoptera_data/train/bees/2405441001_b06c36fa72.jpg \n", " inflating: hymenoptera_data/train/bees/2445215254_51698ff797.jpg \n", " inflating: hymenoptera_data/train/bees/2452236943_255bfd9e58.jpg \n", " inflating: hymenoptera_data/train/bees/2467959963_a7831e9ff0.jpg \n", " 
inflating: hymenoptera_data/train/bees/2470492904_837e97800d.jpg \n", " inflating: hymenoptera_data/train/bees/2477324698_3d4b1b1cab.jpg \n", " inflating: hymenoptera_data/train/bees/2477349551_e75c97cf4d.jpg \n", " inflating: hymenoptera_data/train/bees/2486729079_62df0920be.jpg \n", " inflating: hymenoptera_data/train/bees/2486746709_c43cec0e42.jpg \n", " inflating: hymenoptera_data/train/bees/2493379287_4100e1dacc.jpg \n", " inflating: hymenoptera_data/train/bees/2495722465_879acf9d85.jpg \n", " inflating: hymenoptera_data/train/bees/2528444139_fa728b0f5b.jpg \n", " inflating: hymenoptera_data/train/bees/2538361678_9da84b77e3.jpg \n", " inflating: hymenoptera_data/train/bees/2551813042_8a070aeb2b.jpg \n", " inflating: hymenoptera_data/train/bees/2580598377_a4caecdb54.jpg \n", " inflating: hymenoptera_data/train/bees/2601176055_8464e6aa71.jpg \n", " inflating: hymenoptera_data/train/bees/2610833167_79bf0bcae5.jpg \n", " inflating: hymenoptera_data/train/bees/2610838525_fe8e3cae47.jpg \n", " inflating: hymenoptera_data/train/bees/2617161745_fa3ebe85b4.jpg \n", " inflating: hymenoptera_data/train/bees/2625499656_e3415e374d.jpg \n", " inflating: hymenoptera_data/train/bees/2634617358_f32fd16bea.jpg \n", " inflating: hymenoptera_data/train/bees/2638074627_6b3ae746a0.jpg \n", " inflating: hymenoptera_data/train/bees/2645107662_b73a8595cc.jpg \n", " inflating: hymenoptera_data/train/bees/2651621464_a2fa8722eb.jpg \n", " inflating: hymenoptera_data/train/bees/2652877533_a564830cbf.jpg \n", " inflating: hymenoptera_data/train/bees/266644509_d30bb16a1b.jpg \n", " inflating: hymenoptera_data/train/bees/2683605182_9d2a0c66cf.jpg \n", " inflating: hymenoptera_data/train/bees/2704348794_eb5d5178c2.jpg \n", " inflating: hymenoptera_data/train/bees/2707440199_cd170bd512.jpg \n", " inflating: hymenoptera_data/train/bees/2710368626_cb42882dc8.jpg \n", " inflating: hymenoptera_data/train/bees/2722592222_258d473e17.jpg \n", " inflating: 
hymenoptera_data/train/bees/2728759455_ce9bb8cd7a.jpg \n", " inflating: hymenoptera_data/train/bees/2756397428_1d82a08807.jpg \n", " inflating: hymenoptera_data/train/bees/2765347790_da6cf6cb40.jpg \n", " inflating: hymenoptera_data/train/bees/2781170484_5d61835d63.jpg \n", " inflating: hymenoptera_data/train/bees/279113587_b4843db199.jpg \n", " inflating: hymenoptera_data/train/bees/2792000093_e8ae0718cf.jpg \n", " inflating: hymenoptera_data/train/bees/2801728106_833798c909.jpg \n", " inflating: hymenoptera_data/train/bees/2822388965_f6dca2a275.jpg \n", " inflating: hymenoptera_data/train/bees/2861002136_52c7c6f708.jpg \n", " inflating: hymenoptera_data/train/bees/2908916142_a7ac8b57a8.jpg \n", " inflating: hymenoptera_data/train/bees/29494643_e3410f0d37.jpg \n", " inflating: hymenoptera_data/train/bees/2959730355_416a18c63c.jpg \n", " inflating: hymenoptera_data/train/bees/2962405283_22718d9617.jpg \n", " inflating: hymenoptera_data/train/bees/3006264892_30e9cced70.jpg \n", " inflating: hymenoptera_data/train/bees/3030189811_01d095b793.jpg \n", " inflating: hymenoptera_data/train/bees/3030772428_8578335616.jpg \n", " inflating: hymenoptera_data/train/bees/3044402684_3853071a87.jpg \n", " inflating: hymenoptera_data/train/bees/3074585407_9854eb3153.jpg \n", " inflating: hymenoptera_data/train/bees/3079610310_ac2d0ae7bc.jpg \n", " inflating: hymenoptera_data/train/bees/3090975720_71f12e6de4.jpg \n", " inflating: hymenoptera_data/train/bees/3100226504_c0d4f1e3f1.jpg \n", " inflating: hymenoptera_data/train/bees/342758693_c56b89b6b6.jpg \n", " inflating: hymenoptera_data/train/bees/354167719_22dca13752.jpg \n", " inflating: hymenoptera_data/train/bees/359928878_b3b418c728.jpg \n", " inflating: hymenoptera_data/train/bees/365759866_b15700c59b.jpg \n", " inflating: hymenoptera_data/train/bees/36900412_92b81831ad.jpg \n", " inflating: hymenoptera_data/train/bees/39672681_1302d204d1.jpg \n", " inflating: hymenoptera_data/train/bees/39747887_42df2855ee.jpg \n", " 
inflating: hymenoptera_data/train/bees/421515404_e87569fd8b.jpg \n", " inflating: hymenoptera_data/train/bees/444532809_9e931e2279.jpg \n", " inflating: hymenoptera_data/train/bees/446296270_d9e8b93ecf.jpg \n", " inflating: hymenoptera_data/train/bees/452462677_7be43af8ff.jpg \n", " inflating: hymenoptera_data/train/bees/452462695_40a4e5b559.jpg \n", " inflating: hymenoptera_data/train/bees/457457145_5f86eb7e9c.jpg \n", " inflating: hymenoptera_data/train/bees/465133211_80e0c27f60.jpg \n", " inflating: hymenoptera_data/train/bees/469333327_358ba8fe8a.jpg \n", " inflating: hymenoptera_data/train/bees/472288710_2abee16fa0.jpg \n", " inflating: hymenoptera_data/train/bees/473618094_8ffdcab215.jpg \n", " inflating: hymenoptera_data/train/bees/476347960_52edd72b06.jpg \n", " inflating: hymenoptera_data/train/bees/478701318_bbd5e557b8.jpg \n", " inflating: hymenoptera_data/train/bees/507288830_f46e8d4cb2.jpg \n", " inflating: hymenoptera_data/train/bees/509247772_2db2d01374.jpg \n", " inflating: hymenoptera_data/train/bees/513545352_fd3e7c7c5d.jpg \n", " inflating: hymenoptera_data/train/bees/522104315_5d3cb2758e.jpg \n", " inflating: hymenoptera_data/train/bees/537309131_532bfa59ea.jpg \n", " inflating: hymenoptera_data/train/bees/586041248_3032e277a9.jpg \n", " inflating: hymenoptera_data/train/bees/760526046_547e8b381f.jpg \n", " inflating: hymenoptera_data/train/bees/760568592_45a52c847f.jpg \n", " inflating: hymenoptera_data/train/bees/774440991_63a4aa0cbe.jpg \n", " inflating: hymenoptera_data/train/bees/85112639_6e860b0469.jpg \n", " inflating: hymenoptera_data/train/bees/873076652_eb098dab2d.jpg \n", " inflating: hymenoptera_data/train/bees/90179376_abc234e5f4.jpg \n", " inflating: hymenoptera_data/train/bees/92663402_37f379e57a.jpg \n", " inflating: hymenoptera_data/train/bees/95238259_98470c5b10.jpg \n", " inflating: hymenoptera_data/train/bees/969455125_58c797ef17.jpg \n", " inflating: hymenoptera_data/train/bees/98391118_bdb1e80cce.jpg \n", " creating: 
hymenoptera_data/val/\n", " creating: hymenoptera_data/val/ants/\n", " inflating: hymenoptera_data/val/ants/10308379_1b6c72e180.jpg \n", " inflating: hymenoptera_data/val/ants/1053149811_f62a3410d3.jpg \n", " inflating: hymenoptera_data/val/ants/1073564163_225a64f170.jpg \n", " inflating: hymenoptera_data/val/ants/1119630822_cd325ea21a.jpg \n", " inflating: hymenoptera_data/val/ants/1124525276_816a07c17f.jpg \n", " inflating: hymenoptera_data/val/ants/11381045_b352a47d8c.jpg \n", " inflating: hymenoptera_data/val/ants/119785936_dd428e40c3.jpg \n", " inflating: hymenoptera_data/val/ants/1247887232_edcb61246c.jpg \n", " inflating: hymenoptera_data/val/ants/1262751255_c56c042b7b.jpg \n", " inflating: hymenoptera_data/val/ants/1337725712_2eb53cd742.jpg \n", " inflating: hymenoptera_data/val/ants/1358854066_5ad8015f7f.jpg \n", " inflating: hymenoptera_data/val/ants/1440002809_b268d9a66a.jpg \n", " inflating: hymenoptera_data/val/ants/147542264_79506478c2.jpg \n", " inflating: hymenoptera_data/val/ants/152286280_411648ec27.jpg \n", " inflating: hymenoptera_data/val/ants/153320619_2aeb5fa0ee.jpg \n", " inflating: hymenoptera_data/val/ants/153783656_85f9c3ac70.jpg \n", " inflating: hymenoptera_data/val/ants/157401988_d0564a9d02.jpg \n", " inflating: hymenoptera_data/val/ants/159515240_d5981e20d1.jpg \n", " inflating: hymenoptera_data/val/ants/161076144_124db762d6.jpg \n", " inflating: hymenoptera_data/val/ants/161292361_c16e0bf57a.jpg \n", " inflating: hymenoptera_data/val/ants/170652283_ecdaff5d1a.jpg \n", " inflating: hymenoptera_data/val/ants/17081114_79b9a27724.jpg \n", " inflating: hymenoptera_data/val/ants/172772109_d0a8e15fb0.jpg \n", " inflating: hymenoptera_data/val/ants/1743840368_b5ccda82b7.jpg \n", " inflating: hymenoptera_data/val/ants/181942028_961261ef48.jpg \n", " inflating: hymenoptera_data/val/ants/183260961_64ab754c97.jpg \n", " inflating: hymenoptera_data/val/ants/2039585088_c6f47c592e.jpg \n", " inflating: 
hymenoptera_data/val/ants/205398178_c395c5e460.jpg \n", " inflating: hymenoptera_data/val/ants/208072188_f293096296.jpg \n", " inflating: hymenoptera_data/val/ants/209615353_eeb38ba204.jpg \n", " inflating: hymenoptera_data/val/ants/2104709400_8831b4fc6f.jpg \n", " inflating: hymenoptera_data/val/ants/212100470_b485e7b7b9.jpg \n", " inflating: hymenoptera_data/val/ants/2127908701_d49dc83c97.jpg \n", " inflating: hymenoptera_data/val/ants/2191997003_379df31291.jpg \n", " inflating: hymenoptera_data/val/ants/2211974567_ee4606b493.jpg \n", " inflating: hymenoptera_data/val/ants/2219621907_47bc7cc6b0.jpg \n", " inflating: hymenoptera_data/val/ants/2238242353_52c82441df.jpg \n", " inflating: hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg \n", " inflating: hymenoptera_data/val/ants/239161491_86ac23b0a3.jpg \n", " inflating: hymenoptera_data/val/ants/263615709_cfb28f6b8e.jpg \n", " inflating: hymenoptera_data/val/ants/308196310_1db5ffa01b.jpg \n", " inflating: hymenoptera_data/val/ants/319494379_648fb5a1c6.jpg \n", " inflating: hymenoptera_data/val/ants/35558229_1fa4608a7a.jpg \n", " inflating: hymenoptera_data/val/ants/412436937_4c2378efc2.jpg \n", " inflating: hymenoptera_data/val/ants/436944325_d4925a38c7.jpg \n", " inflating: hymenoptera_data/val/ants/445356866_6cb3289067.jpg \n", " inflating: hymenoptera_data/val/ants/459442412_412fecf3fe.jpg \n", " inflating: hymenoptera_data/val/ants/470127071_8b8ee2bd74.jpg \n", " inflating: hymenoptera_data/val/ants/477437164_bc3e6e594a.jpg \n", " inflating: hymenoptera_data/val/ants/488272201_c5aa281348.jpg \n", " inflating: hymenoptera_data/val/ants/502717153_3e4865621a.jpg \n", " inflating: hymenoptera_data/val/ants/518746016_bcc28f8b5b.jpg \n", " inflating: hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg \n", " inflating: hymenoptera_data/val/ants/562589509_7e55469b97.jpg \n", " inflating: hymenoptera_data/val/ants/57264437_a19006872f.jpg \n", " inflating: hymenoptera_data/val/ants/573151833_ebbc274b77.jpg \n", " 
inflating: hymenoptera_data/val/ants/649407494_9b6bc4949f.jpg \n", " inflating: hymenoptera_data/val/ants/751649788_78dd7d16ce.jpg \n", " inflating: hymenoptera_data/val/ants/768870506_8f115d3d37.jpg \n", " inflating: hymenoptera_data/val/ants/800px-Meat_eater_ant_qeen_excavating_hole.jpg \n", " inflating: hymenoptera_data/val/ants/8124241_36b290d372.jpg \n", " inflating: hymenoptera_data/val/ants/8398478_50ef10c47a.jpg \n", " inflating: hymenoptera_data/val/ants/854534770_31f6156383.jpg \n", " inflating: hymenoptera_data/val/ants/892676922_4ab37dce07.jpg \n", " inflating: hymenoptera_data/val/ants/94999827_36895faade.jpg \n", " inflating: hymenoptera_data/val/ants/Ant-1818.jpg \n", " inflating: hymenoptera_data/val/ants/ants-devouring-remains-of-large-dead-insect-on-red-tile-in-Stellenbosch-South-Africa-closeup-1-DHD.jpg \n", " inflating: hymenoptera_data/val/ants/desert_ant.jpg \n", " inflating: hymenoptera_data/val/ants/F.pergan.28(f).jpg \n", " inflating: hymenoptera_data/val/ants/Hormiga.jpg \n", " creating: hymenoptera_data/val/bees/\n", " inflating: hymenoptera_data/val/bees/1032546534_06907fe3b3.jpg \n", " inflating: hymenoptera_data/val/bees/10870992_eebeeb3a12.jpg \n", " inflating: hymenoptera_data/val/bees/1181173278_23c36fac71.jpg \n", " inflating: hymenoptera_data/val/bees/1297972485_33266a18d9.jpg \n", " inflating: hymenoptera_data/val/bees/1328423762_f7a88a8451.jpg \n", " inflating: hymenoptera_data/val/bees/1355974687_1341c1face.jpg \n", " inflating: hymenoptera_data/val/bees/144098310_a4176fd54d.jpg \n", " inflating: hymenoptera_data/val/bees/1486120850_490388f84b.jpg \n", " inflating: hymenoptera_data/val/bees/149973093_da3c446268.jpg \n", " inflating: hymenoptera_data/val/bees/151594775_ee7dc17b60.jpg \n", " inflating: hymenoptera_data/val/bees/151603988_2c6f7d14c7.jpg \n", " inflating: hymenoptera_data/val/bees/1519368889_4270261ee3.jpg \n", " inflating: hymenoptera_data/val/bees/152789693_220b003452.jpg \n", " inflating: 
hymenoptera_data/val/bees/177677657_a38c97e572.jpg \n", " inflating: hymenoptera_data/val/bees/1799729694_0c40101071.jpg \n", " inflating: hymenoptera_data/val/bees/181171681_c5a1a82ded.jpg \n", " inflating: hymenoptera_data/val/bees/187130242_4593a4c610.jpg \n", " inflating: hymenoptera_data/val/bees/203868383_0fcbb48278.jpg \n", " inflating: hymenoptera_data/val/bees/2060668999_e11edb10d0.jpg \n", " inflating: hymenoptera_data/val/bees/2086294791_6f3789d8a6.jpg \n", " inflating: hymenoptera_data/val/bees/2103637821_8d26ee6b90.jpg \n", " inflating: hymenoptera_data/val/bees/2104135106_a65eede1de.jpg \n", " inflating: hymenoptera_data/val/bees/215512424_687e1e0821.jpg \n", " inflating: hymenoptera_data/val/bees/2173503984_9c6aaaa7e2.jpg \n", " inflating: hymenoptera_data/val/bees/220376539_20567395d8.jpg \n", " inflating: hymenoptera_data/val/bees/224841383_d050f5f510.jpg \n", " inflating: hymenoptera_data/val/bees/2321144482_f3785ba7b2.jpg \n", " inflating: hymenoptera_data/val/bees/238161922_55fa9a76ae.jpg \n", " inflating: hymenoptera_data/val/bees/2407809945_fb525ef54d.jpg \n", " inflating: hymenoptera_data/val/bees/2415414155_1916f03b42.jpg \n", " inflating: hymenoptera_data/val/bees/2438480600_40a1249879.jpg \n", " inflating: hymenoptera_data/val/bees/2444778727_4b781ac424.jpg \n", " inflating: hymenoptera_data/val/bees/2457841282_7867f16639.jpg \n", " inflating: hymenoptera_data/val/bees/2470492902_3572c90f75.jpg \n", " inflating: hymenoptera_data/val/bees/2478216347_535c8fe6d7.jpg \n", " inflating: hymenoptera_data/val/bees/2501530886_e20952b97d.jpg \n", " inflating: hymenoptera_data/val/bees/2506114833_90a41c5267.jpg \n", " inflating: hymenoptera_data/val/bees/2509402554_31821cb0b6.jpg \n", " inflating: hymenoptera_data/val/bees/2525379273_dcb26a516d.jpg \n", " inflating: hymenoptera_data/val/bees/26589803_5ba7000313.jpg \n", " inflating: hymenoptera_data/val/bees/2668391343_45e272cd07.jpg \n", " inflating: 
hymenoptera_data/val/bees/2670536155_c170f49cd0.jpg \n", " inflating: hymenoptera_data/val/bees/2685605303_9eed79d59d.jpg \n", " inflating: hymenoptera_data/val/bees/2702408468_d9ed795f4f.jpg \n", " inflating: hymenoptera_data/val/bees/2709775832_85b4b50a57.jpg \n", " inflating: hymenoptera_data/val/bees/2717418782_bd83307d9f.jpg \n", " inflating: hymenoptera_data/val/bees/272986700_d4d4bf8c4b.jpg \n", " inflating: hymenoptera_data/val/bees/2741763055_9a7bb00802.jpg \n", " inflating: hymenoptera_data/val/bees/2745389517_250a397f31.jpg \n", " inflating: hymenoptera_data/val/bees/2751836205_6f7b5eff30.jpg \n", " inflating: hymenoptera_data/val/bees/2782079948_8d4e94a826.jpg \n", " inflating: hymenoptera_data/val/bees/2809496124_5f25b5946a.jpg \n", " inflating: hymenoptera_data/val/bees/2815838190_0a9889d995.jpg \n", " inflating: hymenoptera_data/val/bees/2841437312_789699c740.jpg \n", " inflating: hymenoptera_data/val/bees/2883093452_7e3a1eb53f.jpg \n", " inflating: hymenoptera_data/val/bees/290082189_f66cb80bfc.jpg \n", " inflating: hymenoptera_data/val/bees/296565463_d07a7bed96.jpg \n", " inflating: hymenoptera_data/val/bees/3077452620_548c79fda0.jpg \n", " inflating: hymenoptera_data/val/bees/348291597_ee836fbb1a.jpg \n", " inflating: hymenoptera_data/val/bees/350436573_41f4ecb6c8.jpg \n", " inflating: hymenoptera_data/val/bees/353266603_d3eac7e9a0.jpg \n", " inflating: hymenoptera_data/val/bees/372228424_16da1f8884.jpg \n", " inflating: hymenoptera_data/val/bees/400262091_701c00031c.jpg \n", " inflating: hymenoptera_data/val/bees/416144384_961c326481.jpg \n", " inflating: hymenoptera_data/val/bees/44105569_16720a960c.jpg \n", " inflating: hymenoptera_data/val/bees/456097971_860949c4fc.jpg \n", " inflating: hymenoptera_data/val/bees/464594019_1b24a28bb1.jpg \n", " inflating: hymenoptera_data/val/bees/485743562_d8cc6b8f73.jpg \n", " inflating: hymenoptera_data/val/bees/540976476_844950623f.jpg \n", " inflating: hymenoptera_data/val/bees/54736755_c057723f64.jpg \n", 
" inflating: hymenoptera_data/val/bees/57459255_752774f1b2.jpg \n",
" inflating: hymenoptera_data/val/bees/576452297_897023f002.jpg \n",
" inflating: hymenoptera_data/val/bees/586474709_ae436da045.jpg \n",
" inflating: hymenoptera_data/val/bees/590318879_68cf112861.jpg \n",
" inflating: hymenoptera_data/val/bees/59798110_2b6a3c8031.jpg \n",
" inflating: hymenoptera_data/val/bees/603709866_a97c7cfc72.jpg \n",
" inflating: hymenoptera_data/val/bees/603711658_4c8cd2201e.jpg \n",
" inflating: hymenoptera_data/val/bees/65038344_52a45d090d.jpg \n",
" inflating: hymenoptera_data/val/bees/6a00d8341c630a53ef00e553d0beb18834-800wi.jpg \n",
" inflating: hymenoptera_data/val/bees/72100438_73de9f17af.jpg \n",
" inflating: hymenoptera_data/val/bees/759745145_e8bc776ec8.jpg \n",
" inflating: hymenoptera_data/val/bees/936182217_c4caa5222d.jpg \n",
" inflating: hymenoptera_data/val/bees/abeja.jpg \n" ] } ], "source": [
"# download the data: -nc skips the download if the zip already exists and\n",
"# -o -q extracts quietly without prompting, so this cell can be re-run safely\n",
"!wget -nc https://download.pytorch.org/tutorial/hymenoptera_data.zip\n",
"!unzip -o -q hymenoptera_data.zip"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "var371SKtNyx" }, "outputs": [], "source": [
"# create data loaders\n",
"\n",
"data_dir = 'hymenoptera_data'\n",
"\n",
"# custom transform to flatten the image tensors\n",
"class ReshapeTransform:\n",
"    \"\"\"Callable transform that reshapes a tensor to `new_size`.\n",
"\n",
"    With `new_size=(-1,)` it flattens an image tensor into a 1-D vector,\n",
"    the input format the linear model below expects.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, new_size):\n",
"        self.new_size = new_size\n",
"\n",
"    def __call__(self, img):\n",
"        return torch.reshape(img, self.new_size)\n",
"\n",
"# resize + center-crop every image to 224x224, convert it to a tensor and flatten it.\n",
"# No normalization or augmentation is applied, so train and val share one pipeline.\n",
"flatten_transform = transforms.Compose([\n",
"    transforms.Resize(224),\n",
"    transforms.CenterCrop(224),\n",
"    transforms.ToTensor(),\n",
"    ReshapeTransform((-1,)),  # flattens the data\n",
"])\n",
"data_transforms = {'train': flatten_transform, 'val': flatten_transform}\n",
"\n",
"# load the 
corresponding folders\n",
"image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),\n",
"                                          data_transforms[x])\n",
"                  for x in ['train', 'val']}\n",
"\n",
"# load each split as one full batch; we are not using minibatches here.\n",
"# NOTE: despite their names, train_dataset/test_dataset are DataLoaders.\n",
"train_dataset = torch.utils.data.DataLoader(image_datasets['train'],\n",
"                                            batch_size=len(image_datasets['train']),\n",
"                                            shuffle=True)\n",
"\n",
"# accuracy over a single full batch does not depend on sample order, so no shuffling\n",
"test_dataset = torch.utils.data.DataLoader(image_datasets['val'],\n",
"                                           batch_size=len(image_datasets['val']),\n",
"                                           shuffle=False)"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gc9G-ZTRulDD" }, "outputs": [], "source": [
"# build the LR model\n",
"class LR(nn.Module):\n",
"    \"\"\"Logistic regression: one linear unit followed by a sigmoid.\n",
"\n",
"    Inputs carry `dim` features in their last dimension; the output has a\n",
"    trailing dimension of size 1 holding the predicted probability of class 1.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, dim):\n",
"        super(LR, self).__init__()\n",
"        self.linear = nn.Linear(dim, 1)\n",
"        # all-zero parameters: every initial prediction is sigmoid(0) = 0.5\n",
"        nn.init.zeros_(self.linear.weight)\n",
"        nn.init.zeros_(self.linear.bias)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.linear(x)\n",
"        x = torch.sigmoid(x)  # probabilities, as nn.BCELoss in the config cell expects\n",
"        return x"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WfSUxBpL6BV1" }, "outputs": [], "source": [
"# predict function: accuracy (in %) of thresholded predictions against the labels\n",
"def predict(yhat, y):\n",
"    \"\"\"Return the accuracy, in percent, of `yhat` against `y` as a 0-dim tensor.\n",
"\n",
"    yhat: N predicted probabilities, e.g. shape (N, 1) as produced by LR.\n",
"    y: N ground-truth labels in {0, 1}.\n",
"    A probability <= 0.5 counts as class 0, anything else as class 1.\n",
"    The result is computed on y's device, so CPU and GPU tensors both work.\n",
"\n",
"    Raises ValueError if yhat and y hold different numbers of elements.\n",
"    \"\"\"\n",
"    probs = yhat.reshape(-1)  # unlike squeeze(), keeps a 1-element batch 1-D\n",
"    labels = y.reshape(-1)\n",
"    if probs.shape != labels.shape:\n",
"        raise ValueError(f\"yhat has {probs.numel()} elements but y has {labels.numel()}\")\n",
"    # vectorized threshold instead of a per-element Python loop\n",
"    y_prediction = torch.where(probs <= 0.5, torch.zeros_like(probs), torch.ones_like(probs))\n",
"    # match y's device (avoids a CPU/GPU mismatch) and the default float dtype\n",
"    y_prediction = y_prediction.to(device=labels.device, dtype=torch.get_default_dtype())\n",
"    return 100 - torch.mean(torch.abs(y_prediction - labels)) * 100"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LL5DrdjqxI7m" }, "outputs": [], "source": [
"# model config\n",
"# number of input features = length of one flattened training image\n",
"dim = train_dataset.dataset[0][0].shape[0]\n",
"\n",
"lrmodel = LR(dim).to(device)\n",
"criterion = nn.BCELoss()\n",
"optimizer = torch.optim.SGD(lrmodel.parameters(), lr=0.0001)"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "i3s0mxFq6LJ6", "outputId": "66126bae-bd85-46d2-b6e4-74f7e332b469" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"Cost after iteration 0: 0.6931472420692444 | Train Acc: 50.40983581542969 | Test Acc: 45.75163269042969\n", "Cost after iteration 10: 0.6691471338272095 | Train Acc: 64.3442611694336 | Test Acc: 54.24836730957031\n", "Cost after iteration 20: 0.6513183116912842 | Train Acc: 68.44261932373047 | Test Acc: 54.24836730957031\n", "Cost after iteration 30: 0.6367825269699097 | Train Acc: 68.03278350830078 | Test Acc: 54.24836730957031\n", "Cost after iteration 40: 0.6245337128639221 | Train Acc: 69.67213439941406 | Test Acc: 54.90196228027344\n", "Cost after iteration 50: 0.6139225959777832 | Train Acc: 70.90164184570312 | Test Acc: 56.20914840698242\n", "Cost after iteration 60: 0.6045235991477966 | Train Acc: 72.54098510742188 | Test Acc: 56.86274337768555\n", "Cost after iteration 70: 0.5960512161254883 | Train Acc: 74.18032836914062 | Test Acc: 57.51633834838867\n", "Cost after iteration 80: 0.5883085131645203 | Train Acc: 73.77049255371094 | Test Acc: 57.51633834838867\n", "Cost after iteration 90: 0.5811558365821838 | Train Acc: 74.59016418457031 | Test Acc: 58.1699333190918\n", "Cost after iteration 100: 0.5744911432266235 | Train Acc: 75.0 | Test Acc: 59.47712326049805\n", "Cost after iteration 110: 0.5682383179664612 | Train Acc: 75.40983581542969 | Test Acc: 60.13071823120117\n", "Cost after iteration 120: 0.5623383522033691 | Train Acc: 75.81967163085938 | Test Acc: 60.13071823120117\n", "Cost after iteration 130: 0.5567454099655151 | Train Acc: 75.81967163085938 | Test Acc: 59.47712326049805\n", "Cost after iteration 140: 0.5514224767684937 | Train Acc: 75.81967163085938 | Test Acc: 59.47712326049805\n", "Cost after iteration 150: 0.5463394522666931 | Train Acc: 76.22950744628906 | Test Acc: 58.82352828979492\n", "Cost after iteration 160: 0.5414711833000183 | Train Acc: 76.63934326171875 | Test Acc: 58.82352828979492\n", "Cost after iteration 170: 0.5367969274520874 | Train Acc: 77.04917907714844 | Test Acc: 58.82352828979492\n", "Cost after iteration 180: 
0.5322986841201782 | Train Acc: 77.04917907714844 | Test Acc: 58.82352828979492\n", "Cost after iteration 190: 0.5279611349105835 | Train Acc: 77.45901489257812 | Test Acc: 58.82352828979492\n" ] } ], "source": [ "# training the model\n", "NUM_ITERATIONS = 200\n", "LOG_EVERY = 10\n", "costs = []  # training loss (Python floats), recorded every LOG_EVERY iterations\n", "\n", "# the validation set is evaluated as one full batch: load it once instead of\n", "# re-reading every image on each iteration (assumes the 'val' transforms are\n", "# deterministic -- TODO confirm)\n", "test_x, test_y = next(iter(test_dataset))\n", "test_x = test_x.to(device)\n", "\n", "for iteration in range(NUM_ITERATIONS):\n", "    lrmodel.train()\n", "    # NOTE(review): this re-reads and re-transforms the whole training set on\n", "    # every iteration; hoist it above the loop if the 'train' transforms are\n", "    # deterministic\n", "    x, y = next(iter(train_dataset))\n", "\n", "    # forward\n", "    yhat = lrmodel(x.to(device))\n", "\n", "    # BCELoss needs float targets on the same device as the predictions;\n", "    # y.type(torch.FloatTensor) would always produce a CPU tensor\n", "    cost = criterion(yhat.squeeze(1), y.float().to(device))\n", "    train_pred = predict(yhat, y)\n", "\n", "    # backward\n", "    optimizer.zero_grad()\n", "    cost.backward()\n", "    optimizer.step()\n", "\n", "    # evaluate\n", "    lrmodel.eval()\n", "    with torch.no_grad():\n", "        yhat_test = lrmodel(test_x)\n", "    test_pred = predict(yhat_test, test_y)\n", "\n", "    if iteration % LOG_EVERY == 0:\n", "        # store a float, not the tensor, so the autograd graph is released\n", "        costs.append(cost.item())\n", "        print(\"Cost after iteration {}: {} | Train Acc: {} | Test Acc: {}\".format(\n", "            iteration, cost.item(), train_pred, test_pred))" ] }, { "cell_type": "markdown", "metadata": { "id": "W0Q8WUq9opWB" }, "source": [ "### References\n", "- [A Logistic Regression Model from Scratch](https://colab.research.google.com/drive/1iBoJ0kngkOthy7SgVaVQA1aHEROt5mra?usp=sharing)" ] } ], "metadata": { "colab": { "name": "Concise Logistic Regression.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 1 }