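# news_verification/src/images/Diffusion/diffusion_model_classifier.py
#
# The lines below are web file-viewer residue captured together with the
# source; they are kept as comments so the file remains valid Python:
#   pmkhanh7890's picture
#   1st
#   22e1b62
#   raw
#   history blame
#   7.34 kB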
import argparse
import logging
import os
import pandas as pd
import pytorch_lightning as pl
import timm
import torch
import torchvision.transforms as transforms
from data_split import *
from dataloader import *
from PIL import Image
from pytorch_lightning.callbacks import (
EarlyStopping,
ModelCheckpoint,
)
from sklearn.metrics import roc_auc_score
from torchmetrics import (
Accuracy,
Recall,
)
from utils_sampling import *
# Route INFO-and-above log records to "training.log".
# filemode="w" truncates the file on every run; force=True replaces any
# handlers that were already attached to the root logger.
logging.basicConfig(
    level=logging.INFO,
    filename="training.log",
    filemode="w",
    force=True,
)
class ImageClassifier(pl.LightningModule):
    """Binary image classifier: a timm ResNet-50 producing a single logit.

    Training minimises binary cross-entropy on the logits plus an optional
    penalty ``lmd * mean(logit ** 2)`` (labelled "SD loss" in the logs;
    presumably spectral decoupling -- confirm with the author).

    Every batch is unpacked as a 3-tuple ``(images, labels, extra)``. The
    third element is ignored for training/validation and passed through
    unchanged by ``predict_step`` (where it is named ``domain``).
    """

    def __init__(self, lmd=0):
        """Build the backbone and metrics.

        Args:
            lmd: Weight of the squared-logit penalty added to the training
                loss. ``0`` disables the penalty.
        """
        super().__init__()
        self.model = timm.create_model(
            "resnet50", pretrained=True, num_classes=1
        )
        self.accuracy = Accuracy(task="binary", threshold=0.5)
        self.recall = Recall(task="binary", threshold=0.5)
        # Per-batch validation results, consumed and cleared in
        # on_validation_epoch_end.
        self.validation_outputs = []
        self.lmd = lmd

    def forward(self, x):
        """Return the raw logits of the backbone for the image batch ``x``."""
        return self.model(x)

    def training_step(self, batch):
        """Compute the training loss (BCE-with-logits plus optional penalty)."""
        images, labels, _ = batch
        # squeeze(-1) drops only the trailing logit dimension. A bare
        # squeeze() would also drop the batch dimension for a single-sample
        # batch, making the output 0-d and mismatching the labels' shape.
        outputs = self.forward(images).squeeze(-1)
        print(f"Shape of outputs (training): {outputs.shape}")
        print(f"Shape of labels (training): {labels.shape}")
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            outputs, labels.float()
        )
        logging.info(f"Training Step - ERM loss: {loss.item()}")
        loss = loss + self.lmd * (outputs**2).mean()  # SD loss penalty
        logging.info(f"Training Step - SD loss: {loss.item()}")
        return loss

    def validation_step(self, batch):
        """Log loss/accuracy/recall for one batch and buffer it for the AUC.

        Returns:
            dict with keys ``"val_loss"``, ``"preds"`` (sigmoid
            probabilities) and ``"labels"``.
        """
        images, labels, _ = batch
        # Keep the batch dimension even for a single-sample batch so such
        # batches are evaluated instead of being skipped.
        outputs = self.forward(images).squeeze(-1)
        print(f"Shape of outputs (validation): {outputs.shape}")
        print(f"Shape of labels (validation): {labels.shape}")
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            outputs, labels.float()
        )
        preds = torch.sigmoid(outputs)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log(
            "val_acc",
            self.accuracy(preds, labels.int()),
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "val_recall",
            self.recall(preds, labels.int()),
            prog_bar=True,
            sync_dist=True,
        )
        output = {"val_loss": loss, "preds": preds, "labels": labels}
        self.validation_outputs.append(output)
        logging.info(f"Validation Step - Batch loss: {loss.item()}")
        return output

    def predict_step(self, batch):
        """Return ``(probabilities, label, domain)`` for one batch.

        The probabilities are always 1-D (one entry per sample) so the
        per-batch results can be concatenated by the caller.
        """
        images, label, domain = batch
        outputs = self.forward(images).squeeze(-1)
        preds = torch.sigmoid(outputs)
        return preds, label, domain

    def on_validation_epoch_end(self):
        """Compute and log ROC-AUC over all batches buffered this epoch."""
        collected = self.validation_outputs
        # Reset the buffer up front so that every exit path (including the
        # early returns below) starts the next validation run empty;
        # otherwise stale batches would accumulate across epochs.
        self.validation_outputs = []
        if not collected:
            logging.warning("No outputs in validation step to process")
            return
        preds = torch.cat([x["preds"] for x in collected])
        labels = torch.cat([x["labels"] for x in collected])
        # ROC-AUC is undefined when only one class is present.
        if labels.unique().size(0) == 1:
            logging.warning("Only one class in validation step")
            return
        auc_score = roc_auc_score(labels.cpu(), preds.cpu())
        self.log("val_auc", auc_score, prog_bar=True, sync_dist=True)
        logging.info(f"Validation Epoch End - AUC score: {auc_score}")

    def configure_optimizers(self):
        """Return an Adam optimizer over the backbone parameters (lr=5e-4)."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0005)
        return optimizer
# Keep the three checkpoints with the lowest monitored "val_loss" under
# ./model_checkpoints/, considering a save every 1001 training steps.
checkpoint_callback = ModelCheckpoint(
    dirpath="./model_checkpoints/",
    filename="image-classifier-{step}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
    every_n_train_steps=1001,
    enable_version_counter=True,
)

# Stop training once "val_loss" has failed to improve for 4 consecutive
# validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=4,
)
def load_image(image_path, transform=None):
    """Load an image from disk as RGB and optionally transform it.

    Args:
        image_path: Path of the image file to open.
        transform: Optional callable applied to the RGB ``PIL.Image``; its
            return value is returned as-is.

    Returns:
        The RGB ``PIL.Image`` when no transform is given, otherwise
        whatever ``transform`` returns.
    """
    # Image.open is lazy and keeps the underlying file open. convert()
    # decodes the pixels into a new image, so the file handle can be
    # released as soon as the conversion is done.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    if transform:
        image = transform(image)
    return image
def predict_single_image(image_path, model, transform=None):
    """Run ``model`` on one image file and return the sigmoid probability.

    The model is moved to CUDA when available (otherwise CPU) and switched
    to eval mode as a side effect. The raw logit is printed before the
    probability is returned.

    Args:
        image_path: Path of the image to classify.
        model: Callable module mapping an image batch to a single logit.
        transform: Callable passed to ``load_image``. Its result must
            support ``.to(device)`` and ``.unsqueeze(0)``, i.e. it has to
            yield a tensor.

    Returns:
        The sigmoid of the model output as a Python float.
    """
    tensor = load_image(image_path, transform)
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target)
    tensor = tensor.to(target)
    model.eval()
    with torch.no_grad():
        # Add the batch dimension the model expects, then drop all
        # singleton dimensions from the result to get a scalar logit.
        logit = model(tensor.unsqueeze(0)).squeeze()
        print(logit)
        return torch.sigmoid(logit).item()
# Command-line interface. Without --predict / --predict_image the script
# trains a model.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--ckpt_path", required=False, help="checkpoint to continue from"
)
# Boolean mode switches, registered in the order they appear in --help.
for _flag, _help in (
    ("--predict", "predict on test set"),
    ("--reset", "reset training"),
    ("--predict_image", "predict the class of a single image"),
):
    parser.add_argument(_flag, help=_help, action="store_true")
parser.add_argument(
    "--image_path",
    type=str,
    required=False,
    help="path to the image to predict",
)
args = parser.parse_args()
# Domain ids used to build the training and validation dataloaders, and the
# weight of the squared-logit penalty for a freshly created model.
train_domains = [0, 1, 4]
val_domains = [0, 1, 4]
lmd_value = 0

# Fail fast with a usage message instead of an obscure error deep inside
# checkpoint loading / image opening when a required option is missing.
if (args.predict or args.predict_image or args.reset) and not args.ckpt_path:
    parser.error(
        "--ckpt_path is required with --predict, --predict_image or --reset"
    )
if args.predict_image and not args.predict and not args.image_path:
    parser.error("--image_path is required with --predict_image")

if args.predict:
    # Batch prediction over the test split of all five domains; results are
    # written to outputs/preds-<checkpoint file name>.csv.
    test_dl = load_dataloader(
        [0, 1, 2, 3, 4], "test", batch_size=128, num_workers=1
    )
    model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
    trainer = pl.Trainer()
    predictions = trainer.predict(model, dataloaders=test_dl)
    preds, labels, domains = zip(*predictions)
    # reshape(-1) guards against a 0-d prediction tensor (e.g. from a
    # single-sample batch), which torch.cat cannot concatenate.
    preds = torch.cat([p.reshape(-1) for p in preds]).cpu().numpy()
    labels = torch.cat(labels).cpu().numpy()
    domains = torch.cat(domains).cpu().numpy()
    print(preds.shape, labels.shape, domains.shape)
    df = pd.DataFrame({"preds": preds, "labels": labels, "domains": domains})
    # basename handles the platform's path separators, not only "/".
    filename = "preds-" + os.path.basename(args.ckpt_path)
    # to_csv does not create missing directories.
    os.makedirs("outputs", exist_ok=True)
    df.to_csv(f"outputs/{filename}.csv", index=False)
elif args.predict_image:
    # Single-image prediction.
    image_path = args.image_path
    model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
    # Define the transformations for the image
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),  # Image size expected by ResNet50
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    prediction = predict_single_image(image_path, model, transform)
    print("prediction", prediction)
    # Output the prediction
    # NOTE(review): the 0.001 cut-off is far below the 0.5 threshold used by
    # the validation metrics -- confirm it is intentional.
    print(
        f"Prediction for {image_path}: {'Human' if prediction <= 0.001 else 'Generated'}"
    )
else:
    # Training.
    train_dl = load_dataloader(
        train_domains, "train", batch_size=128, num_workers=4
    )
    logging.info("Training dataloader loaded")
    val_dl = load_dataloader(val_domains, "val", batch_size=128, num_workers=4)
    logging.info("Validation dataloader loaded")
    if args.reset:
        # --reset: start from the checkpoint's weights, but do not hand the
        # checkpoint to fit(), so the trainer state starts from scratch.
        model = ImageClassifier.load_from_checkpoint(args.ckpt_path)
    else:
        model = ImageClassifier(lmd=lmd_value)
    # Step-based schedule: validate every 1000 training steps and stop after
    # at most 20000 steps (or earlier via early stopping).
    trainer = pl.Trainer(
        callbacks=[checkpoint_callback, early_stop_callback],
        max_steps=20000,
        val_check_interval=1000,
        check_val_every_n_epoch=None,
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_dl,
        val_dataloaders=val_dl,
        # Without --reset, resume the full training state from --ckpt_path
        # when one was given.
        ckpt_path=args.ckpt_path if not args.reset else None,
    )