mirudev commited on
Commit
f8de7c6
·
verified ·
1 Parent(s): f050d30

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ dataset.zip
13
+ dataset/
14
+ *.png
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Git Config -global Credential.helper Store
3
- emoji: 🏢
4
- colorFrom: yellow
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 5.13.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: git_config_-global_credential.helper_store
3
+ app_file: client.py
 
 
4
  sdk: gradio
5
  sdk_version: 5.13.0
 
 
6
  ---
 
 
client.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+
6
+
7
+ class AnimeCNN(nn.Module):
8
+ def __init__(self, num_classes=4):
9
+ super().__init__()
10
+ self.features = nn.Sequential(
11
+ nn.Conv2d(3, 32, 3, padding=1),
12
+ nn.BatchNorm2d(32),
13
+ nn.ReLU(),
14
+ nn.MaxPool2d(2, 2),
15
+ nn.Dropout(0.25),
16
+
17
+ nn.Conv2d(32, 64, 3, padding=1),
18
+ nn.BatchNorm2d(64),
19
+ nn.ReLU(),
20
+ nn.MaxPool2d(2, 2),
21
+ nn.Dropout(0.25)
22
+ )
23
+
24
+ self.classifier = nn.Sequential(
25
+ nn.Linear(64*16*16, 256),
26
+ nn.BatchNorm1d(256),
27
+ nn.ReLU(),
28
+ nn.Dropout(0.5),
29
+ nn.Linear(256, num_classes)
30
+ )
31
+
32
+ def forward(self, x):
33
+ x = self.features(x)
34
+ x = x.view(x.size(0), -1)
35
+ x = self.classifier(x)
36
+ return x
37
+
38
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
39
+ model = AnimeCNN()
40
+ model.load_state_dict(torch.load('model.pth', map_location=device, weights_only=True))
41
+ model.eval()
42
+
43
+ classes = ["usada_pekora", "aisaka_taiga", "megumin", "minato_aqua"]
44
+
45
+ transform = transforms.Compose([
46
+ transforms.Resize((64, 64)),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
49
+ ])
50
+
51
+ def predict(image):
52
+ image = transform(image).unsqueeze(0)
53
+
54
+ with torch.no_grad():
55
+ outputs = model(image)
56
+ probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
57
+
58
+ confidences = {classes[i]: float(probabilities[i]) for i in range(4)}
59
+ return confidences
60
+
61
+ interface = gr.Interface(
62
+ fn=predict,
63
+ inputs=gr.Image(type="pil", label="入力画像"),
64
+ outputs=gr.Label(num_top_classes=4, label="予測結果"),
65
+ title="アニメキャラクター分類器",
66
+ description="うさだぺこら・逢坂大河・めぐみん・湊あくあの画像を分類します。画像をアップロードしてください。",
67
+ examples=[
68
+ ["examples/usada_pekora.jpg"],
69
+ ["examples/aisaka_taiga.jpg"],
70
+ ["examples/megumin.jpg"],
71
+ ["examples/minato_aqua.jpg"]
72
+ ],
73
+ )
74
+
75
+ interface.launch(server_name="0.0.0.0", server_port=7860, share=True)
examples/aisaka_taiga.jpg ADDED
examples/megumin.jpg ADDED
examples/minato_aqua.jpg ADDED
examples/usada_pekora.jpg ADDED
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7ad84f9ff21d15b7761b69237efa3ce31f0324779a17680b9ebf8901dedbb14
3
+ size 16871835
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "cnn-anime-classification"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "gradio>=5.13.0",
9
+ "matplotlib>=3.10.0",
10
+ "numpy>=2.2.2",
11
+ "pillow>=11.1.0",
12
+ "scikit-learn>=1.6.1",
13
+ ]
requirements.txt ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ aiofiles==23.2.1
3
+ annotated-types==0.7.0
4
+ anyio==4.8.0
5
+ certifi==2024.12.14
6
+ charset-normalizer==3.4.1
7
+ click==8.1.8
8
+ contourpy==1.3.1
9
+ cycler==0.12.1
10
+ fastapi==0.115.7
11
+ ffmpy==0.5.0
12
+ filelock==3.17.0
13
+ fonttools==4.55.4
14
+ fsspec==2024.12.0
15
+ gradio==5.13.0
16
+ gradio-client==1.6.0
17
+ h11==0.14.0
18
+ httpcore==1.0.7
19
+ httpx==0.28.1
20
+ huggingface-hub==0.27.1
21
+ idna==3.10
22
+ jinja2==3.1.5
23
+ joblib==1.4.2
24
+ kiwisolver==1.4.8
25
+ markdown-it-py==3.0.0
26
+ markupsafe==2.1.5
27
+ matplotlib==3.10.0
28
+ mdurl==0.1.2
29
+ mpmath==1.3.0
30
+ networkx==3.2.1
31
+ numpy==2.2.2
32
+ orjson==3.10.15
33
+ packaging==24.2
34
+ pandas==2.2.3
35
+ pillow==11.1.0
36
+ pydantic==2.10.5
37
+ pydantic-core==2.27.2
38
+ pydub==0.25.1
39
+ pygments==2.19.1
40
+ pyparsing==3.2.1
41
+ python-dateutil==2.9.0.post0
42
+ python-multipart==0.0.20
43
+ pytz==2024.2
44
+ pyyaml==6.0.2
45
+ requests==2.32.3
46
+ rich==13.9.4
47
+ ruff==0.9.2
48
+ safehttpx==0.1.6
49
+ scikit-learn==1.6.1
50
+ scipy==1.15.1
51
+ semantic-version==2.10.0
52
+ setuptools==70.0.0
53
+ shellingham==1.5.4
54
+ six==1.17.0
55
+ sniffio==1.3.1
56
+ starlette==0.45.2
57
+ sympy==1.13.1
58
+ threadpoolctl==3.5.0
59
+ tomlkit==0.13.2
60
+ torch==2.5.1
61
+ torchvision==0.20.1
62
+ tqdm==4.67.1
63
+ typer==0.15.1
64
+ typing-extensions==4.12.2
65
+ tzdata==2025.1
66
+ urllib3==2.3.0
67
+ uvicorn==0.34.0
68
+ websockets==14.2
69
+
train.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ import torchvision.transforms as transforms
8
+ from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
9
+ from sklearn.model_selection import train_test_split
10
+
11
+ def load_dataset(folder_path, max_images_per_class=60, allowed_classes=None):
12
+ dataset = {}
13
+
14
+ class_names = [
15
+ name for name in os.listdir(folder_path)
16
+ if os.path.isdir(os.path.join(folder_path, name)) and
17
+ (allowed_classes is None or name in allowed_classes)
18
+ ]
19
+
20
+ if allowed_classes:
21
+ class_names = [cls for cls in allowed_classes if cls in class_names]
22
+
23
+ for class_name in class_names:
24
+ class_path = os.path.join(folder_path, class_name)
25
+ images = []
26
+
27
+ for file_name in os.listdir(class_path):
28
+ if len(images) >= max_images_per_class:
29
+ break
30
+ if file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
31
+ img_path = os.path.join(class_path, file_name)
32
+ img = Image.open(img_path).convert('RGB')
33
+ images.append(np.array(img))
34
+
35
+ dataset[class_name] = images
36
+
37
+ return dataset
38
+
39
+ class AnimeDataset(Dataset):
40
+ def __init__(self, images, transform=None, classes=None):
41
+ self.images = []
42
+ self.labels = []
43
+ self.transform = transform
44
+ self.classes = classes or list(images.keys())
45
+
46
+ for label, class_name in enumerate(self.classes):
47
+ class_images = images.get(class_name, [])
48
+ self.images.extend(class_images)
49
+ self.labels.extend([label] * len(class_images))
50
+
51
+ def __len__(self):
52
+ return len(self.images)
53
+
54
+ def __getitem__(self, idx):
55
+ image = Image.fromarray(self.images[idx])
56
+ label = self.labels[idx]
57
+
58
+ if self.transform:
59
+ image = self.transform(image)
60
+
61
+ return image, label
62
+
63
+ class AnimeCNN(nn.Module):
64
+ def __init__(self, num_classes=4):
65
+ super().__init__()
66
+ self.features = nn.Sequential(
67
+ nn.Conv2d(3, 32, 3, padding=1),
68
+ nn.BatchNorm2d(32),
69
+ nn.ReLU(),
70
+ nn.MaxPool2d(2, 2),
71
+ nn.Dropout(0.25),
72
+
73
+ nn.Conv2d(32, 64, 3, padding=1),
74
+ nn.BatchNorm2d(64),
75
+ nn.ReLU(),
76
+ nn.MaxPool2d(2, 2),
77
+ nn.Dropout(0.25)
78
+ )
79
+
80
+ self.classifier = nn.Sequential(
81
+ nn.Linear(64*16*16, 256),
82
+ nn.BatchNorm1d(256),
83
+ nn.ReLU(),
84
+ nn.Dropout(0.5),
85
+ nn.Linear(256, num_classes)
86
+ )
87
+
88
+ def forward(self, x):
89
+ x = self.features(x)
90
+ x = x.view(x.size(0), -1)
91
+ x = self.classifier(x)
92
+ return x
93
+
94
+ def main():
95
+ SEED = 42
96
+ CLASSES = ["usada_pekora", "aisaka_taiga", "megumin", "minato_aqua"]
97
+ IMG_SIZE = 64
98
+ BATCH_SIZE = 16
99
+ NUM_EPOCHS = 15
100
+
101
+ torch.manual_seed(SEED)
102
+ np.random.seed(SEED)
103
+
104
+ dataset = load_dataset("dataset", allowed_classes=CLASSES)
105
+
106
+ transform = transforms.Compose([
107
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
108
+ transforms.ToTensor(),
109
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
110
+ ])
111
+
112
+ anime_dataset = AnimeDataset(dataset, transform=transform, classes=CLASSES)
113
+
114
+ indices = list(range(len(anime_dataset)))
115
+ train_indices, val_indices = train_test_split(
116
+ indices,
117
+ test_size=0.2,
118
+ random_state=SEED,
119
+ stratify=anime_dataset.labels
120
+ )
121
+
122
+ train_loader = DataLoader(
123
+ anime_dataset,
124
+ batch_size=BATCH_SIZE,
125
+ sampler=SubsetRandomSampler(train_indices),
126
+ pin_memory=True
127
+ )
128
+
129
+ val_loader = DataLoader(
130
+ anime_dataset,
131
+ batch_size=40,
132
+ sampler=SubsetRandomSampler(val_indices),
133
+ pin_memory=True
134
+ )
135
+
136
+ model = AnimeCNN(num_classes=len(CLASSES))
137
+
138
+ optimizer = optim.Adam(
139
+ model.parameters(),
140
+ lr=0.001,
141
+ weight_decay=1e-4
142
+ )
143
+
144
+ criterion = nn.CrossEntropyLoss()
145
+
146
+ for epoch in range(NUM_EPOCHS):
147
+ model.train()
148
+ train_loss = 0.0
149
+
150
+ for inputs, labels in train_loader:
151
+ optimizer.zero_grad()
152
+ outputs = model(inputs)
153
+ loss = criterion(outputs, labels)
154
+ loss.backward()
155
+ optimizer.step()
156
+ train_loss += loss.item()
157
+
158
+ model.eval()
159
+ val_loss = 0.0
160
+ correct = 0
161
+ total = 0
162
+
163
+ with torch.no_grad():
164
+ for inputs, labels in val_loader:
165
+ outputs = model(inputs)
166
+ loss = criterion(outputs, labels)
167
+ val_loss += loss.item()
168
+
169
+ _, predicted = torch.max(outputs, 1)
170
+ total += labels.size(0)
171
+ correct += (predicted == labels).sum().item()
172
+
173
+ train_loss /= len(train_loader)
174
+ val_loss /= len(val_loader)
175
+ val_acc = 100 * correct / total
176
+
177
+ print(f"Epoch {epoch+1:02d} | "
178
+ f"Train Loss: {train_loss:.4f} | "
179
+ f"Val Loss: {val_loss:.4f} | "
180
+ f"Accuracy: {val_acc:.2f}%")
181
+
182
+ print("Model saved as model.pth")
183
+ torch.save(model.state_dict(), "model.pth")
184
+
185
+ if __name__ == "__main__":
186
+ main()
uv.lock ADDED
The diff for this file is too large to render. See raw diff