Upload folder using huggingface_hub
- .gitignore +14 -0
- .python-version +1 -0
- README.md +2 -8
- client.py +75 -0
- examples/aisaka_taiga.jpg +0 -0
- examples/megumin.jpg +0 -0
- examples/minato_aqua.jpg +0 -0
- examples/usada_pekora.jpg +0 -0
- model.pth +3 -0
- pyproject.toml +13 -0
- requirements.txt +69 -0
- train.py +186 -0
- uv.lock +0 -0
.gitignore
ADDED
@@ -0,0 +1,14 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# Virtual environments
+.venv
+
+dataset.zip
+dataset/
+*.png
.python-version
ADDED
@@ -0,0 +1 @@
+3.12
README.md
CHANGED
@@ -1,12 +1,6 @@
 ---
-title:
-
-colorFrom: yellow
-colorTo: indigo
+title: git_config_-global_credential.helper_store
+app_file: client.py
 sdk: gradio
 sdk_version: 5.13.0
-app_file: app.py
-pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
client.py
ADDED
@@ -0,0 +1,75 @@
+import gradio as gr
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+
+
+class AnimeCNN(nn.Module):
+    def __init__(self, num_classes=4):
+        super().__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 32, 3, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),
+
+            nn.Conv2d(32, 64, 3, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25)
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(64*16*16, 256),
+            nn.BatchNorm1d(256),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+            nn.Linear(256, num_classes)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+model = AnimeCNN()
+model.load_state_dict(torch.load('model.pth', map_location=device, weights_only=True))
+model.eval()
+
+classes = ["usada_pekora", "aisaka_taiga", "megumin", "minato_aqua"]
+
+transform = transforms.Compose([
+    transforms.Resize((64, 64)),
+    transforms.ToTensor(),
+    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+])
+
+def predict(image):
+    image = transform(image).unsqueeze(0)
+
+    with torch.no_grad():
+        outputs = model(image)
+        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
+
+    confidences = {classes[i]: float(probabilities[i]) for i in range(4)}
+    return confidences
+
+interface = gr.Interface(
+    fn=predict,
+    inputs=gr.Image(type="pil", label="入力画像"),
+    outputs=gr.Label(num_top_classes=4, label="予測結果"),
+    title="アニメキャラクター分類器",
+    description="うさだぺこら・逢坂大河・めぐみん・湊あくあの画像を分類します。画像をアップロードしてください。",
+    examples=[
+        ["examples/usada_pekora.jpg"],
+        ["examples/aisaka_taiga.jpg"],
+        ["examples/megumin.jpg"],
+        ["examples/minato_aqua.jpg"]
+    ],
+)
+
+interface.launch(server_name="0.0.0.0", server_port=7860, share=True)
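client.py loads the trained weights and serves the classifier as a Gradio Interface on port 7860 (the UI strings are Japanese: "入力画像" = input image, "予測結果" = prediction result; the title and description name the four characters). As a minimal sketch, not part of this commit, the running app can also be queried programmatically with gradio_client, which requirements.txt below pins as gradio-client==1.6.0; the URL mirrors the launch() call above, and api_name="/predict" is Gradio's default endpoint name for a single-function Interface:

    from gradio_client import Client, handle_file

    # Connect to the server started by client.py (a Space id works here too).
    client = Client("http://127.0.0.1:7860")

    # Submit one of the bundled example images; a gr.Label output comes back
    # as a dict with the predicted label and per-class confidences.
    result = client.predict(handle_file("examples/megumin.jpg"), api_name="/predict")
    print(result)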
examples/aisaka_taiga.jpg
ADDED
examples/megumin.jpg
ADDED
examples/minato_aqua.jpg
ADDED
examples/usada_pekora.jpg
ADDED
model.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7ad84f9ff21d15b7761b69237efa3ce31f0324779a17680b9ebf8901dedbb14
+size 16871835
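model.pth is stored with Git LFS, so the pointer above records only the object's sha256 and size (about 16 MB). A small sketch of verifying the local file against the oid in the pointer, assuming the actual weights have been pulled (e.g. via git lfs pull):

    import hashlib

    # An LFS oid is the sha256 of the object's contents; compare it against
    # the hash recorded in the pointer file above.
    with open("model.pth", "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()

    assert digest == "a7ad84f9ff21d15b7761b69237efa3ce31f0324779a17680b9ebf8901dedbb14"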
pyproject.toml
ADDED
@@ -0,0 +1,13 @@
+[project]
+name = "cnn-anime-classification"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "gradio>=5.13.0",
+    "matplotlib>=3.10.0",
+    "numpy>=2.2.2",
+    "pillow>=11.1.0",
+    "scikit-learn>=1.6.1",
+]
requirements.txt
ADDED
@@ -0,0 +1,69 @@
+
+aiofiles==23.2.1
+annotated-types==0.7.0
+anyio==4.8.0
+certifi==2024.12.14
+charset-normalizer==3.4.1
+click==8.1.8
+contourpy==1.3.1
+cycler==0.12.1
+fastapi==0.115.7
+ffmpy==0.5.0
+filelock==3.17.0
+fonttools==4.55.4
+fsspec==2024.12.0
+gradio==5.13.0
+gradio-client==1.6.0
+h11==0.14.0
+httpcore==1.0.7
+httpx==0.28.1
+huggingface-hub==0.27.1
+idna==3.10
+jinja2==3.1.5
+joblib==1.4.2
+kiwisolver==1.4.8
+markdown-it-py==3.0.0
+markupsafe==2.1.5
+matplotlib==3.10.0
+mdurl==0.1.2
+mpmath==1.3.0
+networkx==3.2.1
+numpy==2.2.2
+orjson==3.10.15
+packaging==24.2
+pandas==2.2.3
+pillow==11.1.0
+pydantic==2.10.5
+pydantic-core==2.27.2
+pydub==0.25.1
+pygments==2.19.1
+pyparsing==3.2.1
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytz==2024.2
+pyyaml==6.0.2
+requests==2.32.3
+rich==13.9.4
+ruff==0.9.2
+safehttpx==0.1.6
+scikit-learn==1.6.1
+scipy==1.15.1
+semantic-version==2.10.0
+setuptools==70.0.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+starlette==0.45.2
+sympy==1.13.1
+threadpoolctl==3.5.0
+tomlkit==0.13.2
+torch==2.5.1
+torchvision==0.20.1
+tqdm==4.67.1
+typer==0.15.1
+typing-extensions==4.12.2
+tzdata==2025.1
+urllib3==2.3.0
+uvicorn==0.34.0
+websockets==14.2
+
train.py
ADDED
@@ -0,0 +1,186 @@
+import os
+import numpy as np
+from PIL import Image
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
+from sklearn.model_selection import train_test_split
+
+def load_dataset(folder_path, max_images_per_class=60, allowed_classes=None):
+    dataset = {}
+
+    class_names = [
+        name for name in os.listdir(folder_path)
+        if os.path.isdir(os.path.join(folder_path, name)) and
+        (allowed_classes is None or name in allowed_classes)
+    ]
+
+    if allowed_classes:
+        class_names = [cls for cls in allowed_classes if cls in class_names]
+
+    for class_name in class_names:
+        class_path = os.path.join(folder_path, class_name)
+        images = []
+
+        for file_name in os.listdir(class_path):
+            if len(images) >= max_images_per_class:
+                break
+            if file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
+                img_path = os.path.join(class_path, file_name)
+                img = Image.open(img_path).convert('RGB')
+                images.append(np.array(img))
+
+        dataset[class_name] = images
+
+    return dataset
+
+class AnimeDataset(Dataset):
+    def __init__(self, images, transform=None, classes=None):
+        self.images = []
+        self.labels = []
+        self.transform = transform
+        self.classes = classes or list(images.keys())
+
+        for label, class_name in enumerate(self.classes):
+            class_images = images.get(class_name, [])
+            self.images.extend(class_images)
+            self.labels.extend([label] * len(class_images))
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        image = Image.fromarray(self.images[idx])
+        label = self.labels[idx]
+
+        if self.transform:
+            image = self.transform(image)
+
+        return image, label
+
+class AnimeCNN(nn.Module):
+    def __init__(self, num_classes=4):
+        super().__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 32, 3, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),
+
+            nn.Conv2d(32, 64, 3, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25)
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Linear(64*16*16, 256),
+            nn.BatchNorm1d(256),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+            nn.Linear(256, num_classes)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+def main():
+    SEED = 42
+    CLASSES = ["usada_pekora", "aisaka_taiga", "megumin", "minato_aqua"]
+    IMG_SIZE = 64
+    BATCH_SIZE = 16
+    NUM_EPOCHS = 15
+
+    torch.manual_seed(SEED)
+    np.random.seed(SEED)
+
+    dataset = load_dataset("dataset", allowed_classes=CLASSES)
+
+    transform = transforms.Compose([
+        transforms.Resize((IMG_SIZE, IMG_SIZE)),
+        transforms.ToTensor(),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    ])
+
+    anime_dataset = AnimeDataset(dataset, transform=transform, classes=CLASSES)
+
+    indices = list(range(len(anime_dataset)))
+    train_indices, val_indices = train_test_split(
+        indices,
+        test_size=0.2,
+        random_state=SEED,
+        stratify=anime_dataset.labels
+    )
+
+    train_loader = DataLoader(
+        anime_dataset,
+        batch_size=BATCH_SIZE,
+        sampler=SubsetRandomSampler(train_indices),
+        pin_memory=True
+    )
+
+    val_loader = DataLoader(
+        anime_dataset,
+        batch_size=40,
+        sampler=SubsetRandomSampler(val_indices),
+        pin_memory=True
+    )
+
+    model = AnimeCNN(num_classes=len(CLASSES))
+
+    optimizer = optim.Adam(
+        model.parameters(),
+        lr=0.001,
+        weight_decay=1e-4
+    )
+
+    criterion = nn.CrossEntropyLoss()
+
+    for epoch in range(NUM_EPOCHS):
+        model.train()
+        train_loss = 0.0
+
+        for inputs, labels in train_loader:
+            optimizer.zero_grad()
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+            train_loss += loss.item()
+
+        model.eval()
+        val_loss = 0.0
+        correct = 0
+        total = 0
+
+        with torch.no_grad():
+            for inputs, labels in val_loader:
+                outputs = model(inputs)
+                loss = criterion(outputs, labels)
+                val_loss += loss.item()
+
+                _, predicted = torch.max(outputs, 1)
+                total += labels.size(0)
+                correct += (predicted == labels).sum().item()
+
+        train_loss /= len(train_loader)
+        val_loss /= len(val_loader)
+        val_acc = 100 * correct / total
+
+        print(f"Epoch {epoch+1:02d} | "
+              f"Train Loss: {train_loss:.4f} | "
+              f"Val Loss: {val_loss:.4f} | "
+              f"Accuracy: {val_acc:.2f}%")
+
+    print("Model saved as model.pth")
+    torch.save(model.state_dict(), "model.pth")
+
+if __name__ == "__main__":
+    main()
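train.py trains the same AnimeCNN defined in client.py: two conv blocks each halve the 64x64 input with MaxPool2d(2, 2), so the classifier's first layer sees 64 channels at 16x16 spatial resolution, hence nn.Linear(64*16*16, 256). load_dataset reads a dataset/ folder (git-ignored above) with one subdirectory per class name in CLASSES; a sketch of the assumed layout, with placeholder file names:

    dataset/
        usada_pekora/
            0001.jpg
            ...
        aisaka_taiga/
        megumin/
        minato_aqua/

Each class contributes at most max_images_per_class=60 images, and train_test_split holds out a stratified 20% of the combined set for validation.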
uv.lock
ADDED
The diff for this file is too large to render.