# Provenance (scraped from the hosting page; kept as comments so the file parses):
#   author: pmkhanh7890
#   commit: da7dbd0 - "complete the 1st version of GUI"
#   size:   4.95 kB
import functools
import logging

import pytorch_lightning as pl
import timm
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics import roc_auc_score
from torchmetrics import Accuracy, Recall
from torchvision.transforms import v2
# NOTE(review): this runs at import time. filemode="w" truncates
# training.log on every import, and force=True removes and closes any
# root-logger handlers the importing application had already configured.
logging.basicConfig(
    filename="training.log",
    filemode="w",
    level=logging.INFO,
    force=True,
)

# Weights file that image_generation_detection() loads the classifier from.
CHECKPOINT = "models/image_classifier/image-classifier-step=8008-val_loss=0.11.ckpt"
class ImageClassifier(pl.LightningModule):
    """Binary image classifier on a timm ResNet-50 with a single output logit.

    ``torch.sigmoid`` of the logit is used as the score for label 1, matching
    the ``binary_cross_entropy_with_logits`` targets and the 0.5-threshold
    binary Accuracy/Recall metrics.

    Args:
        lmd: Weight of the "SD" penalty ``lmd * mean(logit ** 2)`` added to
            the BCE training loss (presumably spectral decoupling). ``0``
            (the default) disables it.
    """

    def __init__(self, lmd=0):
        super().__init__()
        # num_classes=1 -> the head emits one logit per image: (batch, 1).
        self.model = timm.create_model('resnet50', pretrained=True, num_classes=1)
        self.accuracy = Accuracy(task='binary', threshold=0.5)
        self.recall = Recall(task='binary', threshold=0.5)
        # Per-batch validation results, consumed by on_validation_epoch_end.
        self.validation_outputs = []
        self.lmd = lmd

    def forward(self, x):
        """Return the raw logits produced by the backbone for batch ``x``."""
        return self.model(x)

    def _logits(self, images):
        """Return logits as a 1-D ``(batch,)`` tensor, even for batch size 1."""
        # squeeze(-1) drops only the class dim. A bare squeeze() also drops
        # the batch dim when batch == 1, giving a 0-d tensor that mismatches
        # the (1,) label tensor in the loss and cannot be torch.cat'd later.
        return self.forward(images).squeeze(-1)

    def training_step(self, batch):
        """Return BCE-with-logits loss plus the ``lmd``-weighted logit penalty.

        ``batch`` is unpacked as ``(images, labels, _)``; labels are
        presumably a ``(batch,)`` tensor of 0/1 values (TODO confirm).
        """
        images, labels, _ = batch
        outputs = self._logits(images)
        logging.debug("Shape of outputs (training): %s", tuple(outputs.shape))
        logging.debug("Shape of labels (training): %s", tuple(labels.shape))
        erm_loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        logging.info(f"Training Step - ERM loss: {erm_loss.item()}")
        # SD loss penalty: discourages large-magnitude logits.
        loss = erm_loss + self.lmd * (outputs ** 2).mean()
        # Logged value is the total (ERM + penalty) loss.
        logging.info(f"Training Step - SD loss: {loss.item()}")
        return loss

    def validation_step(self, batch):
        """Log per-batch loss/accuracy/recall and buffer predictions for AUC."""
        images, labels, _ = batch
        outputs = self._logits(images)
        logging.debug("Shape of outputs (validation): %s", tuple(outputs.shape))
        logging.debug("Shape of labels (validation): %s", tuple(labels.shape))
        loss = F.binary_cross_entropy_with_logits(outputs, labels.float())
        preds = torch.sigmoid(outputs)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        self.log('val_acc', self.accuracy(preds, labels.int()), prog_bar=True, sync_dist=True)
        self.log('val_recall', self.recall(preds, labels.int()), prog_bar=True, sync_dist=True)
        output = {"val_loss": loss, "preds": preds.detach(), "labels": labels}
        self.validation_outputs.append(output)
        logging.info(f"Validation Step - Batch loss: {loss.item()}")
        return output

    def predict_step(self, batch):
        """Return ``(probabilities, label, domain)``; probabilities are ``(batch,)``."""
        images, label, domain = batch
        preds = torch.sigmoid(self._logits(images))
        return preds, label, domain

    def on_validation_epoch_end(self):
        """Compute and log ROC-AUC over all predictions buffered this epoch."""
        # Take ownership of the buffer and reset it first, so every exit path
        # (including the early returns below) starts the next epoch empty.
        outputs, self.validation_outputs = self.validation_outputs, []
        if not outputs:
            logging.warning("No outputs in validation step to process")
            return
        preds = torch.cat([x['preds'] for x in outputs])
        labels = torch.cat([x['labels'] for x in outputs])
        if labels.unique().size(0) == 1:
            # ROC-AUC is undefined when only one class is present.
            logging.warning("Only one class in validation step")
            return
        auc_score = float(roc_auc_score(
            labels.cpu().numpy(),
            preds.detach().float().cpu().numpy(),
        ))
        self.log('val_auc', auc_score, prog_bar=True, sync_dist=True)
        logging.info(f"Validation Epoch End - AUC score: {auc_score}")

    def configure_optimizers(self):
        """Return Adam (lr=5e-4) over the backbone parameters."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0005)
        return optimizer
def load_image(image_path, transform=None):
    """Open ``image_path`` as an RGB PIL image and optionally transform it.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).
        transform: Optional callable applied to the RGB image.

    Returns:
        ``transform(image)`` when a transform is given, otherwise the RGB
        ``PIL.Image.Image``.
    """
    # Image.open is lazy and holds the underlying file open. convert() returns
    # an independent, fully loaded image, so the source can be closed here
    # instead of leaking a handle until garbage collection.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    if transform is not None:
        image = transform(image)
    return image
def predict_single_image(image_path, model, transform=None):
    """Score a single image file with ``model`` and return its sigmoid output.

    Side effects: ``model`` is moved to CUDA when available (else CPU) and is
    switched to eval mode.

    Args:
        image_path: Path of the image, forwarded to ``load_image``.
        model: Callable whose output for the one-image batch must have exactly
            one element (``.item()`` is called on it).
        transform: Forwarded to ``load_image``; its result must be a tensor,
            since ``.to(device)`` and ``.unsqueeze(0)`` are called on it.

    Returns:
        ``sigmoid(logit)`` as a Python float.
    """
    pixels = load_image(image_path, transform)
    use_cuda = torch.cuda.is_available()
    target = torch.device("cuda" if use_cuda else "cpu")
    model.to(target)
    pixels = pixels.to(target)
    model.eval()
    with torch.no_grad():
        # Add a leading batch dimension of size 1.
        batch = pixels.unsqueeze(0)
        logit = model(batch).squeeze()
        return torch.sigmoid(logit).item()
@functools.lru_cache(maxsize=1)
def _load_detector(checkpoint):
    """Load the classifier from ``checkpoint`` once and reuse it on later calls."""
    # Loading rebuilds the timm backbone and reads the checkpoint from disk;
    # doing that per image made every call pay the full model-load cost.
    return ImageClassifier.load_from_checkpoint(checkpoint)


def image_generation_detection(image_path):
    """Classify the image at ``image_path`` as human-made or machine-generated.

    Returns:
        ``(label, confidence)``. ``label`` is ``"HUMAN"`` when the sigmoid
        score is <= 0.2, else ``"MACHINE"``. ``confidence`` is a percentage
        rounded to 2 decimals.
    """
    model = _load_detector(CHECKPOINT)
    transform = v2.Compose([
        transforms.ToTensor(),
        v2.CenterCrop((256, 256)),
    ])
    prediction = predict_single_image(image_path, model, transform)
    # The decision threshold is 0.2, not 0.5.
    if prediction <= 0.2:
        image_prediction_label = "HUMAN"
    else:
        image_prediction_label = "MACHINE"
    # Confidence grows from 0.5 with distance from the threshold, capped at 1.
    # NOTE(review): for prediction in [0, 1], HUMAN confidence tops out at 70%.
    image_confidence = min(1, 0.5 + abs(prediction - 0.2))
    image_confidence = round(image_confidence * 100, 2)
    return image_prediction_label, image_confidence
if __name__ == "__main__":
    import sys

    # Use the first CLI argument as the image path, else the placeholder.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "path_to_your_image.jpg"  # Replace with your image path
    image_prediction_label, image_confidence = image_generation_detection(
        image_path,
    )
    # Previously the result was computed and discarded; report it.
    print(f"{image_prediction_label} (confidence = {image_confidence}%)")