import random

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageFile
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision.datasets import DatasetFolder
from tqdm import tqdm

# Allow PIL to load truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Fix random seeds for reproducibility.
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)

folder = "./datasets"
NUM_CLASSES = 14
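
# DatasetFolder expects one subfolder per class under each split, e.g.
#   ./datasets/train/labeled/<class_name>/*.jpg
# (the unlabeled and test splits still need at least one subfolder, even if
# the labels read from it are unused).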

# Augmentations for labeled (and pseudo-labeled) training images.
train_tfm = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.Lambda(lambda x: x.convert("RGB")),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        # Standard ImageNet normalization statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Deterministic preprocessing for validation, test, and pseudo-labeling passes.
test_tfm = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.Lambda(lambda x: x.convert("RGB")),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def get_dataset():
    # The labeled and unlabeled training sets intentionally have no transform:
    # they return raw PIL images, and train_tfm / test_tfm are applied per batch.
    train_set = DatasetFolder(
        folder + "/train/labeled",
        loader=lambda x: Image.open(x),
        extensions="jpg",
    )
    valid_set = DatasetFolder(
        folder + "/val",
        loader=lambda x: Image.open(x),
        extensions="jpg",
        transform=test_tfm,
    )
    unlabeled_set = DatasetFolder(
        folder + "/train/unlabeled",
        loader=lambda x: Image.open(x),
        extensions="jpg",
    )
    test_set = DatasetFolder(
        folder + "/test",
        loader=lambda x: Image.open(x),
        extensions="jpg",
        transform=test_tfm,
    )
    return train_set, valid_set, unlabeled_set, test_set


def train_collate_fn(batch):
    # Keep images as PIL objects (transforms are applied later, per batch) and
    # collect the labels into a tensor.
    data, labels = zip(*batch)
    labels = torch.tensor(labels)
    return data, labels


def test_collate_fn(batch):
    # Images are already transformed tensors here, so they can be stacked.
    data, labels = zip(*batch)
    data = torch.stack(data)
    labels = torch.tensor(labels)
    return data, labels


from utils import CustomDataset
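
# CustomDataset is defined in utils.py, which is not shown here. From its usage
# in update_dataset below, it is assumed to wrap parallel lists of PIL images
# and integer pseudo-labels and return (image, label) pairs, roughly:
#
#   class CustomDataset(Dataset):
#       def __init__(self, images, labels):
#           self.images = images
#           self.labels = labels
#
#       def __len__(self):
#           return len(self.images)
#
#       def __getitem__(self, idx):
#           return self.images[idx], self.labels[idx]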


def update_dataset(
    train_set, unlabeled_set, model, threshold, batch_size=128, num_workers=8
) -> Dataset:
    """
    This is the core function that generates a pseudo-labeled dataset from the
    unlabeled data using the given model.
    inputs:
        - train_set: the labeled training set
        - unlabeled_set: the unlabeled dataset to be pseudo-labeled
        - model: the trained model used to generate pseudo-labels
        - threshold: confidence threshold for accepting a pseudo-label
        - batch_size: batch size for the DataLoader
        - num_workers: number of workers for the DataLoader
    outputs:
        - new_set: the training set updated with confident pseudo-labeled samples
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model.eval()
    softmax = nn.Softmax(dim=-1)

    unlabeled_loader = DataLoader(
        unlabeled_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=train_collate_fn,
    )

    confident_samples = []
    confident_labels = []

    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(
            tqdm(unlabeled_loader, desc="Generating pseudo-labels")
        ):
            # Use the deterministic test transform when predicting pseudo-labels.
            new_images = torch.stack([test_tfm(img) for img in images])
            outputs = model(new_images.to(device))
            probabilities = softmax(outputs)
            max_probs, pseudo_labels = torch.max(probabilities, dim=1)

            # Keep only predictions whose confidence exceeds the threshold.
            for i, (prob, label) in enumerate(zip(max_probs, pseudo_labels)):
                if prob.item() > threshold:
                    # Index back into the (unshuffled) dataset to fetch the
                    # original PIL image.
                    idx = batch_idx * batch_size + i
                    if idx < len(unlabeled_set):
                        img, _ = unlabeled_set[idx]
                        confident_samples.append(img)
                        confident_labels.append(int(label.cpu()))

    if confident_samples:
        pseudo_set = CustomDataset(images=confident_samples, labels=confident_labels)
        new_set = ConcatDataset([train_set, pseudo_set])
        print(f"Added {len(confident_samples)} pseudo-labeled samples to training set")
    else:
        print("No confident pseudo-labels found.")
        new_set = train_set
    return new_set


from models import *
from torch import optim

device = "cuda:0" if torch.cuda.is_available() else "cpu"
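
# ResNet101Classifier comes from models.py (via the wildcard import above) and
# is not shown here. Given the ImageNet normalization used in the transforms,
# it is presumably a pretrained torchvision resnet101 with its final fully
# connected layer replaced to output NUM_CLASSES logits, along the lines of:
#
#   class ResNet101Classifier(nn.Module):
#       def __init__(self, num_classes):
#           super().__init__()
#           self.backbone = torchvision.models.resnet101(
#               weights=torchvision.models.ResNet101_Weights.DEFAULT
#           )
#           self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)
#
#       def forward(self, x):
#           return self.backbone(x)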

model = ResNet101Classifier(num_classes=NUM_CLASSES).to(device)

# Warm-start from a previously saved checkpoint.
model.load_state_dict(torch.load("best_model.pth", map_location=device))

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)

# Shrink the learning rate when the validation loss stops improving.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=2,
    threshold=0.001,
    threshold_mode="rel",
    cooldown=0,
    min_lr=0,
    eps=1e-08,
)

do_semi = True
batch_size = 128
num_workers = 8
train_set, valid_set, unlabeled_set, test_set = get_dataset()

train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=train_collate_fn,
)
valid_loader = DataLoader(
    dataset=valid_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=test_collate_fn,
)
test_loader = DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=test_collate_fn,
)

best_valid_loss = float("inf")
start_epoch = 0
epochs = 100
threshold = 0.8  # confidence threshold for accepting a pseudo-label
early_stop = False
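
# Main training loop. After the first quarter of the epochs, the unlabeled set
# is re-pseudo-labeled every other epoch with the current model and merged into
# the labeled training data.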
for epoch in range(start_epoch, epochs):
    if do_semi and epoch > epochs // 4 and epoch % 2 == 0:
        new_set = update_dataset(
            train_set=train_set,
            unlabeled_set=unlabeled_set,
            model=model,
            threshold=threshold,
            batch_size=batch_size,
            num_workers=num_workers,
        )
        train_loader = DataLoader(
            dataset=new_set,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=train_collate_fn,
        )

    # Training phase.
    model.train()

    train_loss = []
    train_accs = []

    for batch in tqdm(train_loader):
        imgs, labels = batch

        # The collate function returns raw PIL images, so apply the training
        # augmentations here.
        new_images = torch.stack([train_tfm(img) for img in imgs])

        logits = model(new_images.to(device))
        loss = criterion(logits, labels.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        train_loss.append(loss.item())
        train_accs.append(acc.item())

    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)

    print(
        f"[ Train | {epoch + 1:03d}/{epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}"
    )

    # Validation phase.
    model.eval()

    valid_loss = []
    valid_accs = []

    for batch in tqdm(valid_loader):
        imgs, labels = batch

        with torch.no_grad():
            logits = model(imgs.to(device))

        loss = criterion(logits, labels.to(device))
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        valid_loss.append(loss.item())
        valid_accs.append(acc.item())

    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)

    print(
        f"[ Valid | {epoch + 1:03d}/{epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}"
    )

    scheduler.step(metrics=valid_loss)

    # Keep the checkpoint with the lowest validation loss.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "best_model.pth")
        print(f"Model saved with loss: {valid_loss:.5f}, acc: {valid_acc:.5f}")
    elif early_stop:
        # Optional early stopping once the loss has clearly drifted away from
        # the best value in the second half of training.
        if epoch > epochs // 2 and valid_loss > best_valid_loss * 1.2:
            print("Early stopping")
            break


# Inference: reload the best checkpoint and predict on the test set.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()

predictions = []

for batch in tqdm(test_loader):
    imgs, labels = batch

    with torch.no_grad():
        logits = model(imgs.to(device))

    predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

# Write the predictions in the expected submission format.
with open("predict.csv", "w") as f:
    f.write("Id,Category\n")
    for i, pred in enumerate(predictions):
        f.write(f"{i},{pred}\n")
print("Predictions saved to predict.csv")