Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitignore +5 -0
- README.md +34 -8
- app.py +39 -0
- data/.gitkeep +0 -0
- environment.yml +0 -0
- models/model.safetensors +3 -0
- models/snapshots/.gitkeep +0 -0
- requirements.txt +3 -0
- train.ipynb +711 -0
- vgg16.py +36 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data/*
|
2 |
+
!data/.gitkeep
|
3 |
+
models/snapshots/*
|
4 |
+
!models/snapshots/.gitkeep
|
5 |
+
wandb
|
README.md
CHANGED
@@ -1,12 +1,38 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji: 👀
|
4 |
-
colorFrom: purple
|
5 |
-
colorTo: blue
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 4.19.1
|
8 |
app_file: app.py
|
9 |
-
|
|
|
10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
|
|
|
|
1 |
---
|
2 |
+
title: bread-or-dog
|
|
|
|
|
|
|
|
|
|
|
3 |
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 4.19.0
|
6 |
---
|
7 |
+
# Fine-Tuning VGG16 to detect bread or dog
|
8 |
+
|
9 |
+
See <>.
|
10 |
+
|
11 |
+
## Prerequisites
|
12 |
+
|
13 |
+
```powershell
|
14 |
+
conda env create -f environment.yml
|
15 |
+
conda activate fine-tuning-vgg16-bread-or-dog
|
16 |
+
Invoke-WebRequest -Uri https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json -OutFile ./data/imagenet-simple-labels.json
|
17 |
+
```
|
18 |
+
|
19 |
+
## Demo
|
20 |
+
|
21 |
+
```powershell
|
22 |
+
# Before Fine-Tuning
|
23 |
+
python vgg16.py
|
24 |
+
|
25 |
+
# After Fine-Tunin
|
26 |
+
python app.py
|
27 |
+
```
|
28 |
+
|
29 |
+
## Deploy
|
30 |
+
|
31 |
+
```powershell
|
32 |
+
gradio deploy
|
33 |
+
```
|
34 |
+
|
35 |
+
## References
|
36 |
|
37 |
+
- [PyTorchの学習済みモデルで画像分類(VGG, ResNetなど) | note.nkmk.me](https://note.nkmk.me/python-pytorch-pretrained-models-image-classification/)
|
38 |
+
- [PyTorchによるファインチューニングの実装 - 機械学習ともろもろ](https://venoda.hatenablog.com/entry/2020/10/18/014516)
|
app.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from safetensors import safe_open
|
5 |
+
from torchvision import models, transforms
|
6 |
+
|
7 |
+
labels = ["bread", "dog"]
|
8 |
+
|
9 |
+
model = models.vgg16(pretrained=True)
|
10 |
+
|
11 |
+
model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=2)
|
12 |
+
|
13 |
+
model_save_path = "models/vgg16_epoch10.safetensors"
|
14 |
+
tensors = {}
|
15 |
+
with safe_open(model_save_path, framework="pt", device="cpu") as f:
|
16 |
+
for key in f.keys():
|
17 |
+
tensors[key] = f.get_tensor(key)
|
18 |
+
|
19 |
+
model.load_state_dict(tensors, strict=False)
|
20 |
+
model.eval()
|
21 |
+
|
22 |
+
preprocess = transforms.Compose([
|
23 |
+
transforms.Resize((224, 224)), # Resize all images to 224x224
|
24 |
+
transforms.ToTensor(), # Convert images to PyTorch tensors
|
25 |
+
])
|
26 |
+
|
27 |
+
def classify_image(input_image: Image):
|
28 |
+
img_t = preprocess(input_image)
|
29 |
+
batch_t = torch.unsqueeze(img_t, 0)
|
30 |
+
|
31 |
+
with torch.no_grad():
|
32 |
+
output = model(batch_t)
|
33 |
+
|
34 |
+
probabilities = torch.nn.functional.softmax(output, dim=1)
|
35 |
+
label_to_prob = {labels[i]: prob for i, prob in enumerate(probabilities[0])}
|
36 |
+
return label_to_prob
|
37 |
+
|
38 |
+
demo = gr.Interface(fn=classify_image, inputs=gr.Image(type='pil'), outputs='label')
|
39 |
+
demo.launch()
|
data/.gitkeep
ADDED
File without changes
|
environment.yml
ADDED
Binary file (20.1 kB). View file
|
|
models/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f6cf14bb076755b13494dcb1556be809b5adfe4254748799b8acff762f4336bf
|
3 |
+
size 537077840
|
models/snapshots/.gitkeep
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
safetensors
|
3 |
+
torchvision
|
train.ipynb
ADDED
@@ -0,0 +1,711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"## Collect images"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "markdown",
|
12 |
+
"metadata": {},
|
13 |
+
"source": [
|
14 |
+
"Save images from below websites with Firefox.\n",
|
15 |
+
"- https://www.dreamstime.com/photos-images/corgi-butt.html\n",
|
16 |
+
"- https://www.pinterest.com/I_love_Corgi/corgi-butt/"
|
17 |
+
]
|
18 |
+
},
|
19 |
+
{
|
20 |
+
"cell_type": "code",
|
21 |
+
"execution_count": 1,
|
22 |
+
"metadata": {},
|
23 |
+
"outputs": [
|
24 |
+
{
|
25 |
+
"name": "stderr",
|
26 |
+
"output_type": "stream",
|
27 |
+
"text": [
|
28 |
+
"c:\\Users\\hiroga\\miniconda3\\envs\\fine-tuning-vgg16-bread-or-dog\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
29 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
30 |
+
]
|
31 |
+
}
|
32 |
+
],
|
33 |
+
"source": [
|
34 |
+
"from datasets import load_dataset\n",
|
35 |
+
"\n",
|
36 |
+
"stanford_dogs_dataset = load_dataset(\"Alanox/stanford-dogs\", split=\"full\", trust_remote_code=True)\n",
|
37 |
+
"# OR !kaggle datasets download -d jessicali9530/stanford-dogs-dataset -p \"data\" -q"
|
38 |
+
]
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"cell_type": "code",
|
42 |
+
"execution_count": 2,
|
43 |
+
"metadata": {},
|
44 |
+
"outputs": [
|
45 |
+
{
|
46 |
+
"data": {
|
47 |
+
"text/plain": [
|
48 |
+
"Dataset({\n",
|
49 |
+
" features: ['name', 'annotations', 'target', 'image'],\n",
|
50 |
+
" num_rows: 20580\n",
|
51 |
+
"})"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
"execution_count": 2,
|
55 |
+
"metadata": {},
|
56 |
+
"output_type": "execute_result"
|
57 |
+
}
|
58 |
+
],
|
59 |
+
"source": [
|
60 |
+
"stanford_dogs_dataset"
|
61 |
+
]
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"cell_type": "code",
|
65 |
+
"execution_count": 3,
|
66 |
+
"metadata": {},
|
67 |
+
"outputs": [],
|
68 |
+
"source": [
|
69 |
+
"from datasets import load_dataset\n",
|
70 |
+
"\n",
|
71 |
+
"bread_dataset = load_dataset(\"imagefolder\", data_dir=\"data/images.cv_fg0xp9w733695pvws1a4yh/data\")"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "code",
|
76 |
+
"execution_count": 4,
|
77 |
+
"metadata": {},
|
78 |
+
"outputs": [
|
79 |
+
{
|
80 |
+
"data": {
|
81 |
+
"text/plain": [
|
82 |
+
"DatasetDict({\n",
|
83 |
+
" train: Dataset({\n",
|
84 |
+
" features: ['image', 'label'],\n",
|
85 |
+
" num_rows: 1478\n",
|
86 |
+
" })\n",
|
87 |
+
" validation: Dataset({\n",
|
88 |
+
" features: ['image', 'label'],\n",
|
89 |
+
" num_rows: 240\n",
|
90 |
+
" })\n",
|
91 |
+
" test: Dataset({\n",
|
92 |
+
" features: ['image', 'label'],\n",
|
93 |
+
" num_rows: 738\n",
|
94 |
+
" })\n",
|
95 |
+
"})"
|
96 |
+
]
|
97 |
+
},
|
98 |
+
"execution_count": 4,
|
99 |
+
"metadata": {},
|
100 |
+
"output_type": "execute_result"
|
101 |
+
}
|
102 |
+
],
|
103 |
+
"source": [
|
104 |
+
"bread_dataset"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": 5,
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [
|
112 |
+
{
|
113 |
+
"name": "stdout",
|
114 |
+
"output_type": "stream",
|
115 |
+
"text": [
|
116 |
+
"['Bedlington Terrier', 'Clumber', 'Bluetick', 'German Short Haired Pointer', 'Labrador Retriever', 'Bernese Mountain Dog', 'Saluki', 'German Shepherd', 'Komondor', 'Kuvasz', 'Weimaraner', 'Great Pyrenees', 'Rottweiler', 'Pekinese', 'Gordon Setter', 'Tibetan Terrier', 'Soft Coated Wheaten Terrier', 'Brittany Spaniel', 'Leonberg', 'English Foxhound', 'Collie', 'Basset', 'Wire Haired Fox Terrier', 'Norwegian Elkhound', 'Chesapeake Bay Retriever', 'Cardigan', 'Borzoi', 'Border Collie', 'Malamute', 'Australian Terrier', 'Silky Terrier', 'Affenpinscher', 'Pomeranian', 'American Staffordshire Terrier', 'Otterhound', 'Staffordshire Bullterrier', 'West Highland White Terrier', 'Boston Bull', 'Redbone', 'Irish Water Spaniel', 'Giant Schnauzer', 'Flat Coated Retriever', 'Norwich Terrier', 'Dhole', 'Airedale', 'Miniature Poodle', 'Malinois', 'Sealyham Terrier', 'Cairn', 'Eskimo Dog', 'Siberian Husky', 'Papillon', 'Greater Swiss Mountain Dog', 'Sussex Spaniel', 'African Hunting Dog', 'Pembroke', 'Dingo', 'Appenzeller', 'Irish Setter', 'Kelpie', 'Brabancon Griffon', 'Groenendael', 'Norfolk Terrier', 'Lakeland Terrier', 'Italian Greyhound', 'Great Dane', 'Yorkshire Terrier', 'Miniature Schnauzer', 'Dandie Dinmont', 'Maltese Dog', 'Border Terrier', 'Rhodesian Ridgeback', 'Blenheim Spaniel', 'Miniature Pinscher', 'Japanese Spaniel', 'Afghan Hound', 'Toy Poodle', 'Old English Sheepdog', 'Doberman', 'Golden Retriever', 'Samoyed', 'Standard Schnauzer', 'Ibizan Hound', 'Mexican Hairless', 'Bouvier Des Flandres', 'Shih Tzu', 'Irish Terrier', 'Standard Poodle', 'Cocker Spaniel', 'Pug', 'Walker Hound', 'Bull Mastiff', 'Toy Terrier', 'Chihuahua', 'Beagle', 'Newfoundland', 'Black And Tan Coonhound', 'Welsh Springer Spaniel', 'Kerry Blue Terrier', 'French Bulldog', 'Tibetan Mastiff', 'English Setter', 'Boxer', 'Curly Coated Retriever', 'Irish Wolfhound', 'Shetland Sheepdog', 'Briard', 'Bloodhound', 'Saint Bernard', 'Whippet', 'Basenji', 'English Springer', 'Scotch Terrier', 'Entlebucher', 'Scottish Deerhound', 'Lhasa', 'Vizsla', 'Keeshond', 'Schipperke', 'Chow']\n"
|
117 |
+
]
|
118 |
+
}
|
119 |
+
],
|
120 |
+
"source": [
|
121 |
+
"unique_targets = stanford_dogs_dataset.unique('target')\n",
|
122 |
+
"print(unique_targets)"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "code",
|
127 |
+
"execution_count": 6,
|
128 |
+
"metadata": {},
|
129 |
+
"outputs": [
|
130 |
+
{
|
131 |
+
"name": "stdout",
|
132 |
+
"output_type": "stream",
|
133 |
+
"text": [
|
134 |
+
"181\n"
|
135 |
+
]
|
136 |
+
}
|
137 |
+
],
|
138 |
+
"source": [
|
139 |
+
"# もしコーギー(Pembroke)だけで数百件あればそれを使い、なければ犬の画像すべてを使う\n",
|
140 |
+
"pembroke_count = sum(target == 'Pembroke' for target in stanford_dogs_dataset['target'])\n",
|
141 |
+
"print(pembroke_count)"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"cell_type": "code",
|
146 |
+
"execution_count": 7,
|
147 |
+
"metadata": {},
|
148 |
+
"outputs": [],
|
149 |
+
"source": [
|
150 |
+
"# Add a new column 'dog_or_bread' to the stanford_dogs_dataset and bread_dataset\n",
|
151 |
+
"bread_dataset = bread_dataset.map(lambda example: {'bread_or_dog': 0})\n",
|
152 |
+
"stanford_dogs_dataset = stanford_dogs_dataset.map(lambda example: {'bread_or_dog': 1})"
|
153 |
+
]
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "code",
|
157 |
+
"execution_count": 8,
|
158 |
+
"metadata": {},
|
159 |
+
"outputs": [
|
160 |
+
{
|
161 |
+
"data": {
|
162 |
+
"text/plain": [
|
163 |
+
"Dataset({\n",
|
164 |
+
" features: ['name', 'annotations', 'target', 'image', 'bread_or_dog'],\n",
|
165 |
+
" num_rows: 20580\n",
|
166 |
+
"})"
|
167 |
+
]
|
168 |
+
},
|
169 |
+
"execution_count": 8,
|
170 |
+
"metadata": {},
|
171 |
+
"output_type": "execute_result"
|
172 |
+
}
|
173 |
+
],
|
174 |
+
"source": [
|
175 |
+
"stanford_dogs_dataset"
|
176 |
+
]
|
177 |
+
},
|
178 |
+
{
|
179 |
+
"cell_type": "code",
|
180 |
+
"execution_count": 9,
|
181 |
+
"metadata": {},
|
182 |
+
"outputs": [
|
183 |
+
{
|
184 |
+
"data": {
|
185 |
+
"text/plain": [
|
186 |
+
"DatasetDict({\n",
|
187 |
+
" train: Dataset({\n",
|
188 |
+
" features: ['name', 'annotations', 'target', 'image', 'bread_or_dog'],\n",
|
189 |
+
" num_rows: 16464\n",
|
190 |
+
" })\n",
|
191 |
+
" test: Dataset({\n",
|
192 |
+
" features: ['name', 'annotations', 'target', 'image', 'bread_or_dog'],\n",
|
193 |
+
" num_rows: 2058\n",
|
194 |
+
" })\n",
|
195 |
+
" validation: Dataset({\n",
|
196 |
+
" features: ['name', 'annotations', 'target', 'image', 'bread_or_dog'],\n",
|
197 |
+
" num_rows: 2058\n",
|
198 |
+
" })\n",
|
199 |
+
"})"
|
200 |
+
]
|
201 |
+
},
|
202 |
+
"execution_count": 9,
|
203 |
+
"metadata": {},
|
204 |
+
"output_type": "execute_result"
|
205 |
+
}
|
206 |
+
],
|
207 |
+
"source": [
|
208 |
+
"from datasets import DatasetDict\n",
|
209 |
+
"\n",
|
210 |
+
"train_test_dataset = stanford_dogs_dataset.train_test_split(test_size=0.2)\n",
|
211 |
+
"test_valid_dataset = train_test_dataset[\"test\"].train_test_split(test_size=0.5)\n",
|
212 |
+
"stanford_dogs_dataset_dict = DatasetDict({\n",
|
213 |
+
" \"train\": train_test_dataset[\"train\"],\n",
|
214 |
+
" \"test\": test_valid_dataset[\"train\"],\n",
|
215 |
+
" \"validation\": test_valid_dataset[\"test\"]\n",
|
216 |
+
"})\n",
|
217 |
+
"stanford_dogs_dataset_dict"
|
218 |
+
]
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"cell_type": "code",
|
222 |
+
"execution_count": 10,
|
223 |
+
"metadata": {},
|
224 |
+
"outputs": [
|
225 |
+
{
|
226 |
+
"name": "stdout",
|
227 |
+
"output_type": "stream",
|
228 |
+
"text": [
|
229 |
+
"DatasetDict({\n",
|
230 |
+
" train: Dataset({\n",
|
231 |
+
" features: ['name', 'annotations', 'target', 'image', 'bread_or_dog', 'label'],\n",
|
232 |
+
" num_rows: 17942\n",
|
233 |
+
" })\n",
|
234 |
+
" validation: Dataset({\n",
|
235 |
+
" features: ['name', 'annotations', 'target', 'image', 'bread_or_dog', 'label'],\n",
|
236 |
+
" num_rows: 2298\n",
|
237 |
+
" })\n",
|
238 |
+
" test: Dataset({\n",
|
239 |
+
" features: ['name', 'annotations', 'target', 'image', 'bread_or_dog', 'label'],\n",
|
240 |
+
" num_rows: 2796\n",
|
241 |
+
" })\n",
|
242 |
+
"})\n"
|
243 |
+
]
|
244 |
+
}
|
245 |
+
],
|
246 |
+
"source": [
|
247 |
+
"from datasets import concatenate_datasets\n",
|
248 |
+
"\n",
|
249 |
+
"# Concatenate the datasets for each split\n",
|
250 |
+
"merged_train_dataset = concatenate_datasets([stanford_dogs_dataset_dict['train'], bread_dataset['train']])\n",
|
251 |
+
"merged_validation_dataset = concatenate_datasets([stanford_dogs_dataset_dict['validation'], bread_dataset['validation']])\n",
|
252 |
+
"merged_test_dataset = concatenate_datasets([stanford_dogs_dataset_dict['test'], bread_dataset['test']])\n",
|
253 |
+
"merged_dataset = DatasetDict({\n",
|
254 |
+
" \"train\": merged_train_dataset,\n",
|
255 |
+
" \"validation\": merged_validation_dataset,\n",
|
256 |
+
" \"test\": merged_test_dataset\n",
|
257 |
+
"})\n",
|
258 |
+
"\n",
|
259 |
+
"print(merged_dataset)"
|
260 |
+
]
|
261 |
+
},
|
262 |
+
{
|
263 |
+
"cell_type": "markdown",
|
264 |
+
"metadata": {},
|
265 |
+
"source": [
|
266 |
+
"## Inspect model"
|
267 |
+
]
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"cell_type": "code",
|
271 |
+
"execution_count": 11,
|
272 |
+
"metadata": {},
|
273 |
+
"outputs": [
|
274 |
+
{
|
275 |
+
"name": "stdout",
|
276 |
+
"output_type": "stream",
|
277 |
+
"text": [
|
278 |
+
"VGG(\n",
|
279 |
+
" (features): Sequential(\n",
|
280 |
+
" (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
281 |
+
" (1): ReLU(inplace=True)\n",
|
282 |
+
" (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
283 |
+
" (3): ReLU(inplace=True)\n",
|
284 |
+
" (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
|
285 |
+
" (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
286 |
+
" (6): ReLU(inplace=True)\n",
|
287 |
+
" (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
288 |
+
" (8): ReLU(inplace=True)\n",
|
289 |
+
" (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
|
290 |
+
" (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
291 |
+
" (11): ReLU(inplace=True)\n",
|
292 |
+
" (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
293 |
+
" (13): ReLU(inplace=True)\n",
|
294 |
+
" (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
295 |
+
" (15): ReLU(inplace=True)\n",
|
296 |
+
" (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
|
297 |
+
" (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
298 |
+
" (18): ReLU(inplace=True)\n",
|
299 |
+
" (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
300 |
+
" (20): ReLU(inplace=True)\n",
|
301 |
+
" (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
302 |
+
" (22): ReLU(inplace=True)\n",
|
303 |
+
" (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
|
304 |
+
" (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
305 |
+
" (25): ReLU(inplace=True)\n",
|
306 |
+
" (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
307 |
+
" (27): ReLU(inplace=True)\n",
|
308 |
+
" (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
|
309 |
+
" (29): ReLU(inplace=True)\n",
|
310 |
+
" (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
|
311 |
+
" )\n",
|
312 |
+
" (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
|
313 |
+
" (classifier): Sequential(\n",
|
314 |
+
" (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
|
315 |
+
" (1): ReLU(inplace=True)\n",
|
316 |
+
" (2): Dropout(p=0.5, inplace=False)\n",
|
317 |
+
" (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
|
318 |
+
" (4): ReLU(inplace=True)\n",
|
319 |
+
" (5): Dropout(p=0.5, inplace=False)\n",
|
320 |
+
" (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
|
321 |
+
" )\n",
|
322 |
+
")\n"
|
323 |
+
]
|
324 |
+
}
|
325 |
+
],
|
326 |
+
"source": [
|
327 |
+
"from torchvision import models\n",
|
328 |
+
"\n",
|
329 |
+
"model = models.vgg16()\n",
|
330 |
+
"print(model)"
|
331 |
+
]
|
332 |
+
},
|
333 |
+
{
|
334 |
+
"cell_type": "code",
|
335 |
+
"execution_count": 12,
|
336 |
+
"metadata": {},
|
337 |
+
"outputs": [
|
338 |
+
{
|
339 |
+
"name": "stdout",
|
340 |
+
"output_type": "stream",
|
341 |
+
"text": [
|
342 |
+
"name='features.0.weight'\n",
|
343 |
+
"name='features.0.bias'\n",
|
344 |
+
"name='features.2.weight'\n",
|
345 |
+
"name='features.2.bias'\n",
|
346 |
+
"name='features.5.weight'\n",
|
347 |
+
"name='features.5.bias'\n",
|
348 |
+
"name='features.7.weight'\n",
|
349 |
+
"name='features.7.bias'\n",
|
350 |
+
"name='features.10.weight'\n",
|
351 |
+
"name='features.10.bias'\n",
|
352 |
+
"name='features.12.weight'\n",
|
353 |
+
"name='features.12.bias'\n",
|
354 |
+
"name='features.14.weight'\n",
|
355 |
+
"name='features.14.bias'\n",
|
356 |
+
"name='features.17.weight'\n",
|
357 |
+
"name='features.17.bias'\n",
|
358 |
+
"name='features.19.weight'\n",
|
359 |
+
"name='features.19.bias'\n",
|
360 |
+
"name='features.21.weight'\n",
|
361 |
+
"name='features.21.bias'\n",
|
362 |
+
"name='features.24.weight'\n",
|
363 |
+
"name='features.24.bias'\n",
|
364 |
+
"name='features.26.weight'\n",
|
365 |
+
"name='features.26.bias'\n",
|
366 |
+
"name='features.28.weight'\n",
|
367 |
+
"name='features.28.bias'\n",
|
368 |
+
"name='classifier.0.weight'\n",
|
369 |
+
"name='classifier.0.bias'\n",
|
370 |
+
"name='classifier.3.weight'\n",
|
371 |
+
"name='classifier.3.bias'\n",
|
372 |
+
"name='classifier.6.weight'\n",
|
373 |
+
"name='classifier.6.bias'\n"
|
374 |
+
]
|
375 |
+
}
|
376 |
+
],
|
377 |
+
"source": [
|
378 |
+
"for name, _param in model.named_parameters():\n",
|
379 |
+
" print(f\"{name=}\")"
|
380 |
+
]
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"cell_type": "markdown",
|
384 |
+
"metadata": {},
|
385 |
+
"source": [
|
386 |
+
"## Fine Tuning"
|
387 |
+
]
|
388 |
+
},
|
389 |
+
{
|
390 |
+
"cell_type": "code",
|
391 |
+
"execution_count": 18,
|
392 |
+
"metadata": {},
|
393 |
+
"outputs": [],
|
394 |
+
"source": [
|
395 |
+
"import torch\n",
|
396 |
+
"import wandb\n",
|
397 |
+
"\n",
|
398 |
+
"def train(model, criterion, optimizer, dataloaders_dict, num_epochs, device):\n",
|
399 |
+
" model.to(device)\n",
|
400 |
+
"\n",
|
401 |
+
" for epoch in range(num_epochs):\n",
|
402 |
+
" print('Epoch {}/{}'.format(epoch + 1, num_epochs))\n",
|
403 |
+
" print('-------------')\n",
|
404 |
+
" \n",
|
405 |
+
" for phase in ['train', 'validation']:\n",
|
406 |
+
" if phase == 'train':\n",
|
407 |
+
" model.train()\n",
|
408 |
+
" else:\n",
|
409 |
+
" model.eval()\n",
|
410 |
+
"\n",
|
411 |
+
" epoch_loss = 0.0\n",
|
412 |
+
" epoch_corrects = 0\n",
|
413 |
+
" \n",
|
414 |
+
" for batch in dataloaders_dict[phase]:\n",
|
415 |
+
" images, labels = batch[\"image\"], batch[\"bread_or_dog\"]\n",
|
416 |
+
" images, labels = images.to(device), labels.to(device)\n",
|
417 |
+
"\n",
|
418 |
+
" optimizer.zero_grad()\n",
|
419 |
+
" \n",
|
420 |
+
" # 学習時のみ勾配を計算させる設定にする\n",
|
421 |
+
" with torch.set_grad_enabled(phase == 'train'):\n",
|
422 |
+
" outputs = model(images)\n",
|
423 |
+
" \n",
|
424 |
+
" # 損失を計算\n",
|
425 |
+
" loss = criterion(outputs, labels)\n",
|
426 |
+
" \n",
|
427 |
+
" # ラベルを予測\n",
|
428 |
+
" _, preds = torch.max(outputs, 1)\n",
|
429 |
+
"\n",
|
430 |
+
" if phase == 'train':\n",
|
431 |
+
" loss.backward()\n",
|
432 |
+
" optimizer.step()\n",
|
433 |
+
"\n",
|
434 |
+
" # イテレーション結果の計算\n",
|
435 |
+
" # lossの合計を更新\n",
|
436 |
+
" # PyTorchの仕様上各バッチ内での平均のlossが計算される。\n",
|
437 |
+
" # データ数を掛けることで平均から合計に変換をしている。\n",
|
438 |
+
" # 損失和は「全データの損失/データ数」で計算されるため、\n",
|
439 |
+
" # 平均のままだと損失和を求めることができないため。\n",
|
440 |
+
" epoch_loss += loss.item() * images.size(0)\n",
|
441 |
+
" \n",
|
442 |
+
" # 正解数の合計を更新\n",
|
443 |
+
" epoch_corrects += torch.sum(preds == labels.data)\n",
|
444 |
+
"\n",
|
445 |
+
" # epochごとのlossと正解率を表示\n",
|
446 |
+
" epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)\n",
|
447 |
+
" epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)\n",
|
448 |
+
"\n",
|
449 |
+
" log = {\n",
|
450 |
+
" \"epoch\": epoch +1,\n",
|
451 |
+
" \"phase\": phase,\n",
|
452 |
+
" f\"{phase}_loss\": epoch_loss,\n",
|
453 |
+
" f\"{phase}_acc\": epoch_acc,\n",
|
454 |
+
" }\n",
|
455 |
+
" print(log)\n",
|
456 |
+
" wandb.log(log)"
|
457 |
+
]
|
458 |
+
},
|
459 |
+
{
|
460 |
+
"cell_type": "code",
|
461 |
+
"execution_count": 20,
|
462 |
+
"metadata": {},
|
463 |
+
"outputs": [
|
464 |
+
{
|
465 |
+
"data": {
|
466 |
+
"text/html": [
|
467 |
+
"Finishing last run (ID:nw4ysgpu) before initializing another..."
|
468 |
+
],
|
469 |
+
"text/plain": [
|
470 |
+
"<IPython.core.display.HTML object>"
|
471 |
+
]
|
472 |
+
},
|
473 |
+
"metadata": {},
|
474 |
+
"output_type": "display_data"
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"data": {
|
478 |
+
"text/html": [
|
479 |
+
"<style>\n",
|
480 |
+
" table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
|
481 |
+
" .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
|
482 |
+
" .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
|
483 |
+
" </style>\n",
|
484 |
+
"<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▂▂▃▃▃▃▄▄▅▅▆▆▆▆▇▇██</td></tr><tr><td>train_acc</td><td>▁▂▂▂▅▇████</td></tr><tr><td>train_loss</td><td>█▆▅▃▃▂▂▁▁▁</td></tr><tr><td>validation_acc</td><td>▁▁▁▄▇▇▆██▇</td></tr><tr><td>validation_loss</td><td>█▆▅▃▂▂▂▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>10</td></tr><tr><td>phase</td><td>validation</td></tr><tr><td>train_acc</td><td>0.98473</td></tr><tr><td>train_loss</td><td>0.04198</td></tr><tr><td>validation_acc</td><td>0.97998</td></tr><tr><td>validation_loss</td><td>0.05055</td></tr></table><br/></div></div>"
|
485 |
+
],
|
486 |
+
"text/plain": [
|
487 |
+
"<IPython.core.display.HTML object>"
|
488 |
+
]
|
489 |
+
},
|
490 |
+
"metadata": {},
|
491 |
+
"output_type": "display_data"
|
492 |
+
},
|
493 |
+
{
|
494 |
+
"data": {
|
495 |
+
"text/html": [
|
496 |
+
" View run <strong style=\"color:#cdcd00\">resplendent-festival-20</strong> at: <a href='https://wandb.ai/hiroga/bread-or-dog/runs/nw4ysgpu' target=\"_blank\">https://wandb.ai/hiroga/bread-or-dog/runs/nw4ysgpu</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
497 |
+
],
|
498 |
+
"text/plain": [
|
499 |
+
"<IPython.core.display.HTML object>"
|
500 |
+
]
|
501 |
+
},
|
502 |
+
"metadata": {},
|
503 |
+
"output_type": "display_data"
|
504 |
+
},
|
505 |
+
{
|
506 |
+
"data": {
|
507 |
+
"text/html": [
|
508 |
+
"Find logs at: <code>.\\wandb\\run-20240217_062447-nw4ysgpu\\logs</code>"
|
509 |
+
],
|
510 |
+
"text/plain": [
|
511 |
+
"<IPython.core.display.HTML object>"
|
512 |
+
]
|
513 |
+
},
|
514 |
+
"metadata": {},
|
515 |
+
"output_type": "display_data"
|
516 |
+
},
|
517 |
+
{
|
518 |
+
"data": {
|
519 |
+
"text/html": [
|
520 |
+
"Successfully finished last run (ID:nw4ysgpu). Initializing new run:<br/>"
|
521 |
+
],
|
522 |
+
"text/plain": [
|
523 |
+
"<IPython.core.display.HTML object>"
|
524 |
+
]
|
525 |
+
},
|
526 |
+
"metadata": {},
|
527 |
+
"output_type": "display_data"
|
528 |
+
},
|
529 |
+
{
|
530 |
+
"data": {
|
531 |
+
"text/html": [
|
532 |
+
"Tracking run with wandb version 0.16.3"
|
533 |
+
],
|
534 |
+
"text/plain": [
|
535 |
+
"<IPython.core.display.HTML object>"
|
536 |
+
]
|
537 |
+
},
|
538 |
+
"metadata": {},
|
539 |
+
"output_type": "display_data"
|
540 |
+
},
|
541 |
+
{
|
542 |
+
"data": {
|
543 |
+
"text/html": [
|
544 |
+
"Run data is saved locally in <code>c:\\Users\\hiroga\\Documents\\GitHub\\til\\computer-science\\machine-learning\\_src\\fine-tuning-vgg16-bread-or-dog\\wandb\\run-20240217_081017-vy00m3yf</code>"
|
545 |
+
],
|
546 |
+
"text/plain": [
|
547 |
+
"<IPython.core.display.HTML object>"
|
548 |
+
]
|
549 |
+
},
|
550 |
+
"metadata": {},
|
551 |
+
"output_type": "display_data"
|
552 |
+
},
|
553 |
+
{
|
554 |
+
"data": {
|
555 |
+
"text/html": [
|
556 |
+
"Syncing run <strong><a href='https://wandb.ai/hiroga/bread-or-dog/runs/vy00m3yf' target=\"_blank\">abundant-festival-21</a></strong> to <a href='https://wandb.ai/hiroga/bread-or-dog' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
557 |
+
],
|
558 |
+
"text/plain": [
|
559 |
+
"<IPython.core.display.HTML object>"
|
560 |
+
]
|
561 |
+
},
|
562 |
+
"metadata": {},
|
563 |
+
"output_type": "display_data"
|
564 |
+
},
|
565 |
+
{
|
566 |
+
"data": {
|
567 |
+
"text/html": [
|
568 |
+
" View project at <a href='https://wandb.ai/hiroga/bread-or-dog' target=\"_blank\">https://wandb.ai/hiroga/bread-or-dog</a>"
|
569 |
+
],
|
570 |
+
"text/plain": [
|
571 |
+
"<IPython.core.display.HTML object>"
|
572 |
+
]
|
573 |
+
},
|
574 |
+
"metadata": {},
|
575 |
+
"output_type": "display_data"
|
576 |
+
},
|
577 |
+
{
|
578 |
+
"data": {
|
579 |
+
"text/html": [
|
580 |
+
" View run at <a href='https://wandb.ai/hiroga/bread-or-dog/runs/vy00m3yf' target=\"_blank\">https://wandb.ai/hiroga/bread-or-dog/runs/vy00m3yf</a>"
|
581 |
+
],
|
582 |
+
"text/plain": [
|
583 |
+
"<IPython.core.display.HTML object>"
|
584 |
+
]
|
585 |
+
},
|
586 |
+
"metadata": {},
|
587 |
+
"output_type": "display_data"
|
588 |
+
},
|
589 |
+
{
|
590 |
+
"name": "stdout",
|
591 |
+
"output_type": "stream",
|
592 |
+
"text": [
|
593 |
+
"Epoch 1/10\n",
|
594 |
+
"-------------\n",
|
595 |
+
"{'epoch': 1, 'phase': 'train', 'train_loss': 0.07768695316224603, 'train_acc': tensor(0.9642, device='cuda:0', dtype=torch.float64)}\n",
|
596 |
+
"{'epoch': 1, 'phase': 'validation', 'validation_loss': 0.07376273388806584, 'validation_acc': tensor(0.9752, device='cuda:0', dtype=torch.float64)}\n",
|
597 |
+
"Epoch 2/10\n",
|
598 |
+
"-------------\n",
|
599 |
+
"{'epoch': 2, 'phase': 'train', 'train_loss': 0.039798663440396634, 'train_acc': tensor(0.9868, device='cuda:0', dtype=torch.float64)}\n",
|
600 |
+
"{'epoch': 2, 'phase': 'validation', 'validation_loss': 0.07450082683805061, 'validation_acc': tensor(0.9704, device='cuda:0', dtype=torch.float64)}\n",
|
601 |
+
"Epoch 3/10\n",
|
602 |
+
"-------------\n",
|
603 |
+
"{'epoch': 3, 'phase': 'train', 'train_loss': 0.038719125189200225, 'train_acc': tensor(0.9869, device='cuda:0', dtype=torch.float64)}\n",
|
604 |
+
"{'epoch': 3, 'phase': 'validation', 'validation_loss': 0.08719978122943582, 'validation_acc': tensor(0.9608, device='cuda:0', dtype=torch.float64)}\n",
|
605 |
+
"Epoch 4/10\n",
|
606 |
+
"-------------\n"
|
607 |
+
]
|
608 |
+
}
|
609 |
+
],
|
610 |
+
"source": [
|
611 |
+
"import torch\n",
|
612 |
+
"import wandb\n",
|
613 |
+
"from safetensors.torch import save_file\n",
|
614 |
+
"from torchvision import transforms\n",
|
615 |
+
"from torch.utils.data import DataLoader\n",
|
616 |
+
"\n",
|
617 |
+
"model_name = \"vgg16\"\n",
|
618 |
+
"model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=2)\n",
|
619 |
+
"\n",
|
620 |
+
"features = [param for name, param in model.named_parameters() if \"features\" in name]\n",
|
621 |
+
"classifier = [param for name, param in model.named_parameters() if \"classifier.0\" in name or \"classifier.3\" in name]\n",
|
622 |
+
"last_classifier = [param for name, param in model.named_parameters() if \"classifier.6\" in name]\n",
|
623 |
+
"param_groups = [\n",
|
624 |
+
" {'params': features, 'lr': 1e-4},\n",
|
625 |
+
" {'params': classifier, 'lr': 5e-4},\n",
|
626 |
+
" {'params': last_classifier, 'lr': 1e-3},\n",
|
627 |
+
"]\n",
|
628 |
+
"momentum = 0.9\n",
|
629 |
+
"\n",
|
630 |
+
"batch_size = 64\n",
|
631 |
+
"\n",
|
632 |
+
"# torchvision の datasets とは違い、transforms をそのままセットすれば良いわけではないので留意。\n",
|
633 |
+
"composed = transforms.Compose([\n",
|
634 |
+
" transforms.Resize((224, 224)), # Resize all images to 224x224\n",
|
635 |
+
" transforms.ToTensor(), # Convert images to PyTorch tensors\n",
|
636 |
+
"])\n",
|
637 |
+
"def transform(batch):\n",
|
638 |
+
" tensors = [composed(img) for img in batch['image']]\n",
|
639 |
+
" return {\"image\": tensors, \"bread_or_dog\": batch[\"bread_or_dog\"]}\n",
|
640 |
+
"\n",
|
641 |
+
"merged_dataset['train'].set_transform(transform, [\"image\", \"bread_or_dog\"])\n",
|
642 |
+
"merged_dataset['validation'].set_transform(transform, [\"image\", \"bread_or_dog\"])\n",
|
643 |
+
"\n",
|
644 |
+
"# Assuming that the datasets 'train' and 'validation' are available in the dataloaders_dict\n",
|
645 |
+
"train_dataloader = DataLoader(merged_dataset['train'], batch_size=batch_size, shuffle=True)\n",
|
646 |
+
"valid_dataloader = DataLoader(merged_dataset['validation'], batch_size=batch_size, shuffle=False)\n",
|
647 |
+
"dataloaders_dict = {\n",
|
648 |
+
" \"train\": train_dataloader,\n",
|
649 |
+
" \"validation\": valid_dataloader\n",
|
650 |
+
"}\n",
|
651 |
+
"\n",
|
652 |
+
"num_epochs = 10\n",
|
653 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
654 |
+
"\n",
|
655 |
+
"wandb.init(\n",
|
656 |
+
" project=\"bread-or-dog\",\n",
|
657 |
+
" config={\n",
|
658 |
+
" \"model_name\": model_name,\n",
|
659 |
+
" \"architecture\": \"CNN\",\n",
|
660 |
+
" \"dataset\": [\"Alanox/stanford-dogs\", \"images.cv_fg0xp9w733695pvws1a4yh\"],\n",
|
661 |
+
" \"param_groups\": param_groups,\n",
|
662 |
+
" \"num_epoch\": num_epochs,\n",
|
663 |
+
" \"momentum\": momentum,\n",
|
664 |
+
" \"device\": device\n",
|
665 |
+
" }\n",
|
666 |
+
")\n",
|
667 |
+
"\n",
|
668 |
+
"criterion = torch.nn.CrossEntropyLoss()\n",
|
669 |
+
"optimizer = torch.optim.SGD(param_groups, momentum=momentum)\n",
|
670 |
+
"\n",
|
671 |
+
"\n",
|
672 |
+
"train(model, criterion, optimizer, dataloaders_dict, num_epochs=num_epochs, device=device)\n",
|
673 |
+
"\n",
|
674 |
+
"save_file(model.state_dict(), f\"models/snapshots/{model_name}_epoch{num_epochs}.safetensors\")\n",
|
675 |
+
"\n",
|
676 |
+
"wandb.log_artifact(model)\n",
|
677 |
+
"\n",
|
678 |
+
"model.to_onnx()\n",
|
679 |
+
"wandb.save(\"model.onnx\")"
|
680 |
+
]
|
681 |
+
},
|
682 |
+
{
|
683 |
+
"cell_type": "code",
|
684 |
+
"execution_count": null,
|
685 |
+
"metadata": {},
|
686 |
+
"outputs": [],
|
687 |
+
"source": []
|
688 |
+
}
|
689 |
+
],
|
690 |
+
"metadata": {
|
691 |
+
"kernelspec": {
|
692 |
+
"display_name": "fune-tuning-vgg16-bread-or-dog",
|
693 |
+
"language": "python",
|
694 |
+
"name": "python3"
|
695 |
+
},
|
696 |
+
"language_info": {
|
697 |
+
"codemirror_mode": {
|
698 |
+
"name": "ipython",
|
699 |
+
"version": 3
|
700 |
+
},
|
701 |
+
"file_extension": ".py",
|
702 |
+
"mimetype": "text/x-python",
|
703 |
+
"name": "python",
|
704 |
+
"nbconvert_exporter": "python",
|
705 |
+
"pygments_lexer": "ipython3",
|
706 |
+
"version": "3.12.1"
|
707 |
+
}
|
708 |
+
},
|
709 |
+
"nbformat": 4,
|
710 |
+
"nbformat_minor": 2
|
711 |
+
}
|
vgg16.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
from torchvision import models, transforms
|
6 |
+
|
7 |
+
with open("data/imagenet-simple-labels.json") as f:
|
8 |
+
labels = json.load(f)
|
9 |
+
|
10 |
+
model = models.vgg16(pretrained=True)
|
11 |
+
model.eval() # 推論モードに設定
|
12 |
+
|
13 |
+
preprocess = transforms.Compose(
|
14 |
+
[
|
15 |
+
transforms.Resize(256),
|
16 |
+
transforms.CenterCrop(224),
|
17 |
+
transforms.ToTensor(),
|
18 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
19 |
+
]
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def classify_image(input_image: Image):
|
24 |
+
img_t = preprocess(input_image)
|
25 |
+
batch_t = torch.unsqueeze(img_t, 0)
|
26 |
+
|
27 |
+
with torch.no_grad():
|
28 |
+
output = model(batch_t)
|
29 |
+
|
30 |
+
probabilities = torch.nn.functional.softmax(output, dim=1)
|
31 |
+
label_to_prob = {labels[i]: prob for i, prob in enumerate(probabilities[0])}
|
32 |
+
return label_to_prob
|
33 |
+
|
34 |
+
|
35 |
+
demo = gr.Interface(fn=classify_image, inputs=gr.Image(type="pil"), outputs="label")
|
36 |
+
demo.launch()
|