Spaces:
Sleeping
Sleeping
import os | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler | |
from sklearn.model_selection import train_test_split | |
def load_dataset(folder_path, max_images_per_class=60, allowed_classes=None):
    """Load up to ``max_images_per_class`` RGB images for each class folder.

    Every immediate sub-directory of ``folder_path`` is treated as one class
    and its ``.png`` / ``.jpg`` / ``.jpeg`` files (case-insensitive) as that
    class's images.

    Args:
        folder_path: Directory whose sub-directories hold the class images.
        max_images_per_class: Maximum number of images read per class.
        allowed_classes: Optional collection of class names to keep. When
            given, the result follows its iteration order and names without
            a matching sub-directory are left out. When ``None``, every
            sub-directory is loaded, in sorted name order.

    Returns:
        A dict mapping each class name to a list of NumPy arrays, one per
        image, each converted to RGB before conversion to an array.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')

    # os.listdir() returns entries in arbitrary order; sorting keeps both the
    # class order and the subset of files kept under the per-class cap
    # reproducible between runs and machines.
    available = sorted(
        name for name in os.listdir(folder_path)
        if os.path.isdir(os.path.join(folder_path, name))
    )
    if allowed_classes is None:
        class_names = available
    else:
        present = set(available)
        class_names = [name for name in allowed_classes if name in present]

    dataset = {}
    for class_name in class_names:
        class_path = os.path.join(folder_path, class_name)
        images = []
        for file_name in sorted(os.listdir(class_path)):
            if len(images) >= max_images_per_class:
                break
            if not file_name.lower().endswith(image_extensions):
                continue
            # The context manager closes the underlying file as soon as the
            # pixels have been copied into the array, so file handles do not
            # accumulate while scanning large folders.
            with Image.open(os.path.join(class_path, file_name)) as img:
                images.append(np.array(img.convert('RGB')))
        dataset[class_name] = images
    return dataset
class AnimeDataset(Dataset):
    """Flat, index-addressable view over a ``{class_name: [image arrays]}`` dict.

    Labels are the positions of the class names inside ``self.classes``.
    """

    def __init__(self, images, transform=None, classes=None):
        """Flatten ``images`` into parallel ``self.images`` / ``self.labels`` lists.

        Args:
            images: Mapping from class name to a list of image arrays.
            transform: Optional callable applied to each PIL image on access.
            classes: Optional class-name order used to assign labels; when
                omitted (or empty) the key order of ``images`` is used.
                Names absent from ``images`` contribute no samples.
        """
        self.transform = transform
        self.classes = classes if classes else list(images.keys())
        self.images, self.labels = [], []
        for class_index, name in enumerate(self.classes):
            samples = images.get(name, [])
            self.images += samples
            self.labels += [class_index] * len(samples)

    def __len__(self):
        """Return the total number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The stored array is wrapped in a PIL image first and then passed
        through ``self.transform`` when one was supplied.
        """
        sample = Image.fromarray(self.images[idx])
        target = self.labels[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target
class AnimeCNN(nn.Module):
    """Two-stage convolutional classifier for square RGB images.

    ``features`` applies two Conv -> BatchNorm -> ReLU -> MaxPool -> Dropout
    stages (3 -> 32 -> 64 channels); ``classifier`` maps the flattened
    feature map through a 256-unit hidden layer to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        img_size: Side length, in pixels, of the square input images. The
            default of 64 gives the 64*16*16 flattened size the classifier
            was originally hard-coded for.

    Raises:
        ValueError: If ``img_size`` is smaller than 4, because the two
            pooling stages would leave no spatial positions.
    """

    def __init__(self, num_classes=4, img_size=64):
        super().__init__()
        if img_size < 4:
            raise ValueError(f"img_size must be at least 4, got {img_size}")

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25)
        )

        # The padded 3x3 convolutions keep the spatial size; each 2x2 max-pool
        # halves it (rounding down), so the side shrinks to img_size // 4.
        feature_side = img_size // 4
        self.classifier = nn.Sequential(
            nn.Linear(64 * feature_side * feature_side, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images shaped (N, 3, H, W)."""
        x = self.features(x)
        # flatten() copies when needed, so it also accepts non-contiguous
        # activations on which view() would raise.
        x = x.flatten(1)
        return self.classifier(x)
def main():
    """Train ``AnimeCNN`` on the images under ``dataset/`` and save its weights.

    Loads the configured classes, makes a stratified 80/20 train/validation
    split, trains with Adam and cross-entropy for ``NUM_EPOCHS`` epochs while
    printing per-epoch metrics, and writes the state dict to ``model.pth``.

    Raises:
        ValueError: If no images were found for the configured classes.
    """
    SEED = 42
    CLASSES = ["usada_pekora", "aisaka_taiga", "megumin", "minato_aqua"]
    IMG_SIZE = 64  # AnimeCNN's default classifier is sized for 64x64 inputs.
    BATCH_SIZE = 16
    NUM_EPOCHS = 15

    torch.manual_seed(SEED)
    np.random.seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only
    # machine requesting it just produces a warning.
    use_cuda = device.type == "cuda"

    dataset = load_dataset("dataset", allowed_classes=CLASSES)
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    anime_dataset = AnimeDataset(dataset, transform=transform, classes=CLASSES)
    if len(anime_dataset) == 0:
        raise ValueError(
            "No images found under 'dataset' for the configured classes"
        )

    indices = list(range(len(anime_dataset)))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=0.2,
        random_state=SEED,
        stratify=anime_dataset.labels
    )

    train_loader = DataLoader(
        anime_dataset,
        batch_size=BATCH_SIZE,
        sampler=SubsetRandomSampler(train_indices),
        pin_memory=use_cuda,
        # BatchNorm1d cannot compute batch statistics from a single sample
        # in training mode, so drop a trailing batch of exactly one image.
        drop_last=len(train_indices) % BATCH_SIZE == 1
    )
    val_loader = DataLoader(
        anime_dataset,
        batch_size=40,
        sampler=SubsetRandomSampler(val_indices),
        pin_memory=use_cuda
    )

    model = AnimeCNN(num_classes=len(CLASSES)).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=0.001,
        weight_decay=1e-4
    )
    criterion = nn.CrossEntropyLoss()

    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss = 0.0
        train_seen = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device, non_blocking=use_cuda)
            labels = labels.to(device, non_blocking=use_cuda)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # The criterion returns a per-batch mean; weight it by the batch
            # size so a smaller final batch does not skew the epoch average.
            train_loss += loss.item() * labels.size(0)
            train_seen += labels.size(0)

        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device, non_blocking=use_cuda)
                labels = labels.to(device, non_blocking=use_cuda)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * labels.size(0)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # max(..., 1) guards the degenerate case where a split is empty.
        train_loss /= max(train_seen, 1)
        val_loss /= max(total, 1)
        val_acc = 100 * correct / max(total, 1)
        print(f"Epoch {epoch+1:02d} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Accuracy: {val_acc:.2f}%")

    # Save CPU tensors so the checkpoint loads on machines without a GPU,
    # and only report success once the file has actually been written.
    torch.save(model.cpu().state_dict(), "model.pth")
    print("Model saved as model.pth")


if __name__ == "__main__":
    main()