Spaces:
Running
Running
File size: 4,951 Bytes
da7dbd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import functools
import logging

import pytorch_lightning as pl
import timm
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import roc_auc_score
from torchmetrics import Accuracy, Recall
from torchvision.transforms import v2
# Route INFO-and-above records to training.log. filemode='w' truncates the
# file every time this module is imported, and force=True discards any
# handlers already attached to the root logger before installing this one.
logging.basicConfig(
    level=logging.INFO,
    filename='training.log',
    force=True,
    filemode='w',
)

# Lightning checkpoint that image_generation_detection() loads for inference.
CHECKPOINT = "models/image_classifier/image-classifier-step=8008-val_loss=0.11.ckpt"
class ImageClassifier(pl.LightningModule):
    """Binary image classifier: a timm ResNet-50 that emits one logit per image.

    Batches are ``(images, labels, domain)`` triples. ``torch.sigmoid`` of the
    logit is treated as the positive-class probability. It is thresholded at
    0.5 for the accuracy/recall metrics, and an epoch-level ROC AUC is logged
    as well.

    Args:
        lmd: Weight of the ``mean(logit ** 2)`` penalty added to the training
            loss (labelled "SD" in the logs). 0 disables it.
    """

    def __init__(self, lmd=0):
        super().__init__()
        # Persist ``lmd`` in checkpoints so ``load_from_checkpoint`` restores
        # it instead of silently falling back to the default.
        self.save_hyperparameters()
        self.model = timm.create_model('resnet50', pretrained=True, num_classes=1)
        self.accuracy = Accuracy(task='binary', threshold=0.5)
        self.recall = Recall(task='binary', threshold=0.5)
        # Per-batch predictions/labels buffered for the epoch-end AUC.
        self.validation_outputs = []
        self.lmd = lmd

    def forward(self, x):
        """Return the raw backbone output (one logit per image)."""
        return self.model(x)

    def _logits(self, images):
        """Return the logits for ``images`` as a 1-D tensor of length batch.

        ``squeeze(-1)`` removes only the single-class axis. A bare
        ``squeeze()`` would also remove the batch axis when the batch holds a
        single image, producing a 0-d tensor that no longer matches the
        ``(1,)`` label tensor.
        """
        return self.forward(images).squeeze(-1)

    def training_step(self, batch):
        """Return BCE-with-logits loss plus ``lmd * mean(logit ** 2)``."""
        images, labels, _ = batch
        outputs = self._logits(images)
        logging.debug("Shape of outputs (training): %s", tuple(outputs.shape))
        logging.debug("Shape of labels (training): %s", tuple(labels.shape))
        erm_loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        logging.info(f"Training Step - ERM loss: {erm_loss.item()}")
        # Out-of-place add: avoids mutating a tensor that is part of the graph.
        loss = erm_loss + self.lmd * (outputs ** 2).mean()  # SD loss penalty
        # NOTE: this logs the combined loss (ERM + penalty), not the penalty alone.
        logging.info(f"Training Step - SD loss: {loss.item()}")
        return loss

    def validation_step(self, batch):
        """Log per-batch loss/accuracy/recall and buffer predictions for AUC.

        Batches of any size, including a single image, are scored.
        """
        images, labels, _ = batch
        outputs = self._logits(images)
        logging.debug("Shape of outputs (validation): %s", tuple(outputs.shape))
        logging.debug("Shape of labels (validation): %s", tuple(labels.shape))
        loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        preds = torch.sigmoid(outputs)
        targets = labels.int()
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc', self.accuracy(preds, targets), prog_bar=True, sync_dist=True)
        self.log('val_recall', self.recall(preds, targets), prog_bar=True, sync_dist=True)
        output = {"val_loss": loss, "preds": preds, "labels": labels}
        self.validation_outputs.append(output)
        logging.info(f"Validation Step - Batch loss: {loss.item()}")
        return output

    def predict_step(self, batch):
        """Return ``(probabilities, label, domain)`` for a prediction batch.

        Probabilities are 1-D (length batch) even for a single-image batch.
        """
        images, label, domain = batch
        preds = torch.sigmoid(self._logits(images))
        return preds, label, domain

    def on_validation_epoch_end(self):
        """Compute and log ROC AUC over the predictions buffered this epoch."""
        # Take the buffer and reset it up front so every exit path, including
        # the early returns below, leaves the next epoch starting empty.
        outputs, self.validation_outputs = self.validation_outputs, []
        if not outputs:
            logging.warning("No outputs in validation step to process")
            return
        preds = torch.cat([x['preds'] for x in outputs])
        labels = torch.cat([x['labels'] for x in outputs])
        if labels.unique().numel() < 2:
            # ROC AUC is undefined when only one class is present.
            logging.warning("Only one class in validation step")
            return
        # .float() first: NumPy has no equivalent for dtypes such as bfloat16.
        auc_score = float(
            roc_auc_score(labels.cpu().numpy(), preds.detach().float().cpu().numpy())
        )
        # NOTE(review): with multi-device validation each rank presumably
        # buffers only its own shard, so sync_dist would average per-rank AUCs
        # rather than compute one global AUC. Confirm if DDP is used.
        self.log('val_auc', auc_score, prog_bar=True, sync_dist=True)
        logging.info(f"Validation Epoch End - AUC score: {auc_score}")

    def configure_optimizers(self):
        """Return Adam over the backbone parameters with a fixed lr of 5e-4."""
        return torch.optim.Adam(self.model.parameters(), lr=0.0005)
def load_image(image_path, transform=None):
    """Open an image file, convert it to RGB, and optionally transform it.

    Args:
        image_path: Path of the image, passed to ``PIL.Image.open``.
        transform: Optional callable applied to the RGB ``PIL.Image``. With
            ``None`` the RGB image itself is returned.

    Returns:
        The RGB ``PIL.Image``, or whatever ``transform`` returns for it.
    """
    # The context manager closes the file handle even for lazily loaded or
    # multi-frame formats. convert() returns an independent, fully loaded
    # image, so using it after the file is closed is safe.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    if transform is not None:
        image = transform(image)
    return image
def predict_single_image(image_path, model, transform=None, device=None):
    """Return the model's sigmoid probability for one image file.

    Args:
        image_path: Path of the image to score (opened via ``load_image``).
        model: Module mapping a one-image batch to a single logit.
        transform: Callable that turns the RGB ``PIL.Image`` into a tensor,
            presumably shaped ``(C, H, W)`` (a batch axis is added here). It
            is required in practice, because the raw image is not a tensor.
        device: Device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        float: ``sigmoid(logit)`` for the image.

    Raises:
        TypeError: If ``transform`` does not produce a ``torch.Tensor``.

    Note:
        ``model`` is moved to ``device`` and switched to eval mode in place.
    """
    image = load_image(image_path, transform)
    if not isinstance(image, torch.Tensor):
        # Without this check a PIL image reaches ``.to(device)`` and fails with
        # an unhelpful AttributeError.
        raise TypeError(
            "predict_single_image needs a transform that returns a torch.Tensor; "
            f"got {type(image).__name__}"
        )
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    batch = image.unsqueeze(0).to(device)
    with torch.no_grad():
        logit = model(batch).squeeze()
    return torch.sigmoid(logit).item()
@functools.lru_cache(maxsize=4)
def _load_detector(checkpoint):
    """Load the ImageClassifier for ``checkpoint`` once and reuse it on later calls.

    The cached model is shared between calls. Edits to the checkpoint file on
    disk are not picked up until the process restarts.
    """
    return ImageClassifier.load_from_checkpoint(checkpoint)


def image_generation_detection(image_path, *, threshold=0.2, checkpoint=None):
    """Classify an image as human-made or machine-generated.

    Args:
        image_path: Path of the image file to score.
        threshold: Probability at or below which the image is labelled
            ``"HUMAN"``. The default of 0.2 is the original cut-off.
        checkpoint: Lightning checkpoint to load. ``None`` uses ``CHECKPOINT``.

    Returns:
        tuple[str, float]: ``("HUMAN" | "MACHINE", confidence)``, where
        ``confidence`` is a percentage rounded to two decimals.
    """
    model = _load_detector(CHECKPOINT if checkpoint is None else checkpoint)
    transform = v2.Compose([
        transforms.ToTensor(),
        v2.CenterCrop((256, 256)),
    ])
    prediction = predict_single_image(image_path, model, transform)
    image_prediction_label = "HUMAN" if prediction <= threshold else "MACHINE"
    # Heuristic confidence: 0.5 plus the distance from the threshold, capped at
    # 1.0. NOTE(review): because the prediction is a sigmoid output in [0, 1],
    # "HUMAN" confidence never exceeds 50 + 100 * threshold percent, while
    # "MACHINE" saturates at 100%. The scale is asymmetric by construction.
    image_confidence = min(1.0, 0.5 + abs(prediction - threshold))
    return image_prediction_label, round(image_confidence * 100, 2)
if __name__ == "__main__":
    import sys

    # Usage: python <this file> [IMAGE_PATH]. Without an argument the
    # placeholder below is used; replace it with your image path.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "path_to_your_image.jpg"
    image_prediction_label, image_confidence = image_generation_detection(
        image_path,
    )
    print(f"{image_prediction_label} (confidence: {image_confidence}%)")
|