"""Binary HUMAN/MACHINE image classifier (timm ResNet-50, one logit) and inference helpers."""
import functools
import logging

import pytorch_lightning as pl
import timm
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import roc_auc_score
from torchmetrics import Accuracy, Recall
from torchvision.transforms import v2

# NOTE: runs at import time -- reconfigures the root logger (force=True) and
# truncates training.log (filemode='w').
logging.basicConfig(filename='training.log', filemode='w', level=logging.INFO, force=True)

CHECKPOINT = "models/image_classifier/image-classifier-step=8008-val_loss=0.11.ckpt"


class ImageClassifier(pl.LightningModule):
    """ResNet-50 with a single output logit, trained with BCE plus an optional logit penalty.

    Args:
        lmd: weight of the ``(logits ** 2).mean()`` penalty added in ``training_step``.
        pretrained: load timm's pretrained ResNet-50 weights. Pass ``False`` when
            every weight is about to be overwritten anyway (e.g. ``load_from_checkpoint``).
    """

    def __init__(self, lmd=0, pretrained=True):
        super().__init__()
        self.model = timm.create_model('resnet50', pretrained=pretrained, num_classes=1)
        self.accuracy = Accuracy(task='binary', threshold=0.5)
        self.recall = Recall(task='binary', threshold=0.5)
        # Per-batch validation outputs; consumed and reset in on_validation_epoch_end.
        self.validation_outputs = []
        self.lmd = lmd

    def forward(self, x):
        """Return raw logits, shape (B, 1) since the classifier has num_classes=1."""
        return self.model(x)

    def training_step(self, batch):
        """Return BCE loss plus ``lmd`` times the mean squared logit for one batch.

        ``batch`` is unpacked as ``(images, labels, _)``; labels are presumably
        shape (B,) with 0/1 values -- confirm against the dataset.
        """
        images, labels, _ = batch
        # squeeze(-1) drops only the single-logit dim: (B, 1) -> (B,). A bare
        # squeeze() would also collapse a batch of size 1 to a 0-d tensor, which
        # no longer matches the (1,)-shaped labels inside BCE.
        outputs = self.forward(images).squeeze(-1)
        logging.debug(f"Shape of outputs (training): {outputs.shape}")
        logging.debug(f"Shape of labels (training): {labels.shape}")
        loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        logging.info(f"Training Step - ERM loss: {loss.item()}")
        loss = loss + self.lmd * (outputs ** 2).mean()  # SD loss penalty
        # NOTE: this logs the total (ERM + penalty) loss, not the penalty alone.
        logging.info(f"Training Step - SD loss: {loss.item()}")
        return loss

    def validation_step(self, batch):
        """Log per-batch val loss/accuracy/recall and stash preds/labels for epoch-level AUC."""
        images, labels, _ = batch
        # squeeze(-1) keeps the batch dim, so a final batch of size 1 is evaluated
        # normally instead of being skipped.
        outputs = self.forward(images).squeeze(-1)
        logging.debug(f"Shape of outputs (validation): {outputs.shape}")
        logging.debug(f"Shape of labels (validation): {labels.shape}")
        loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        preds = torch.sigmoid(outputs)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc', self.accuracy(preds, labels.int()), prog_bar=True, sync_dist=True)
        self.log('val_recall', self.recall(preds, labels.int()), prog_bar=True, sync_dist=True)
        output = {"val_loss": loss, "preds": preds, "labels": labels}
        self.validation_outputs.append(output)
        logging.info(f"Validation Step - Batch loss: {loss.item()}")
        return output

    def predict_step(self, batch):
        """Return ``(sigmoid scores of shape (B,), labels, domains)`` for one batch."""
        images, label, domain = batch
        outputs = self.forward(images).squeeze(-1)
        preds = torch.sigmoid(outputs)
        return preds, label, domain

    def on_validation_epoch_end(self):
        """Compute and log ROC-AUC over all validation outputs collected this epoch."""
        # Take the collected outputs and reset the buffer up front so that every
        # exit path -- including the early returns below -- leaves it empty.
        # Otherwise stale batches (e.g. from the sanity check) leak into the next
        # epoch's AUC and the list grows without bound.
        outputs, self.validation_outputs = self.validation_outputs, []
        if not outputs:
            logging.warning("No outputs in validation step to process")
            return
        preds = torch.cat([x['preds'] for x in outputs])
        labels = torch.cat([x['labels'] for x in outputs])
        # ROC-AUC is undefined when only one class is present.
        if labels.unique().size(0) == 1:
            logging.warning("Only one class in validation step")
            return
        # Cast to float32: numpy has no bfloat16, so mixed-precision preds
        # would otherwise fail to convert.
        auc_score = roc_auc_score(labels.cpu().numpy(), preds.float().cpu().numpy())
        self.log('val_auc', float(auc_score), prog_bar=True, sync_dist=True)
        logging.info(f"Validation Epoch End - AUC score: {auc_score}")

    def configure_optimizers(self):
        """Adam over the backbone parameters with a fixed learning rate of 5e-4."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0005)
        return optimizer


def load_image(image_path, transform=None):
    """Open ``image_path`` as an RGB PIL image and apply ``transform`` if given."""
    # The context manager closes the underlying file; convert() returns a new,
    # fully loaded image, so the result stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    if transform is not None:
        image = transform(image)
    return image


def predict_single_image(image_path, model, transform=None):
    """Return the model's sigmoid score (a Python float) for a single image file.

    ``transform`` must turn the PIL image into a (C, H, W) tensor; a batch
    dimension is added here. The model is moved to CUDA when available and put
    in eval mode (both mutate ``model``).
    """
    image = load_image(image_path, transform)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    image = image.to(device)
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0)
        output = model(image).squeeze()
        prediction = torch.sigmoid(output).item()
    return prediction


@functools.lru_cache(maxsize=1)
def _load_detector():
    """Load the classifier from CHECKPOINT once and reuse it on later calls."""
    # pretrained=False: load_from_checkpoint forwards the kwarg to __init__ and
    # then overwrites every weight from the checkpoint, so fetching ImageNet
    # weights would be wasted work (and a network call).
    return ImageClassifier.load_from_checkpoint(CHECKPOINT, pretrained=False)


def image_generation_detection(image_path):
    """Classify one image as HUMAN or MACHINE.

    Returns:
        ``(label, confidence)``: ``label`` is ``"HUMAN"`` or ``"MACHINE"``;
        ``confidence`` is a percentage rounded to 2 decimals.
    """
    model = _load_detector()
    # NOTE(review): only ToTensor + center crop, no resize/normalization --
    # confirm this matches the preprocessing used during training.
    transform = v2.Compose([
        transforms.ToTensor(),
        v2.CenterCrop((256, 256)),
    ])
    prediction = predict_single_image(image_path, model, transform)
    # Sigmoid scores at or below 0.2 are treated as human-made.
    image_prediction_label = "HUMAN" if prediction <= 0.2 else "MACHINE"
    # Distance from the 0.2 threshold, offset by 0.5 and capped at 1.
    image_confidence = min(1, 0.5 + abs(prediction - 0.2))
    image_confidence = round(image_confidence * 100, 2)
    return image_prediction_label, image_confidence


if __name__ == "__main__":
    image_path = "path_to_your_image.jpg"  # Replace with your image path
    image_prediction_label, image_confidence = image_generation_detection(
        image_path,
    )
    print(f"{image_prediction_label} (confidence = {image_confidence}%)")