import argparse

import torch
import torch.nn
import torchvision.transforms as transforms
from PIL import Image

from .CNN.networks.resnet import resnet50


def predict_cnn(image, model_path, crop=None):
    """Return the model's probability that *image* is synthetic.

    Builds a single-output ``resnet50`` classifier, loads its weights from
    *model_path*, preprocesses *image* and runs one forward pass.

    Args:
        image: A PIL image. It is converted to RGB before preprocessing.
        model_path: Path to a checkpoint readable by ``torch.load``. The
            ``"model"`` entry of the loaded object is used as the network's
            state dict.
        crop: Optional size forwarded to ``transforms.CenterCrop``. When
            ``None`` (the default) the image is not cropped.

    Returns:
        A Python float: the sigmoid of the model's single output logit.
    """
    model = resnet50(num_classes=1)
    # Load on CPU first so checkpoints saved on GPU also open on CPU-only hosts.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict["model"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Start with an empty pipeline prefix so the no-crop path still builds a
    # valid transform list.
    trans_init = []
    if crop is not None:
        trans_init.append(transforms.CenterCrop(crop))
        print("Cropping to [%i]" % crop)
    trans = transforms.Compose(
        trans_init
        + [
            transforms.ToTensor(),
            # NOTE(review): presumably the ImageNet statistics the weights
            # were trained with — confirm against the training pipeline.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ],
    )
    image = trans(image.convert("RGB"))

    with torch.no_grad():
        # Add a batch dimension and move the input to the model's device;
        # otherwise inference fails whenever CUDA is available.
        in_tens = image.unsqueeze(0).to(device)
        prob = model(in_tens).sigmoid().item()

    return prob


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("-f", "--file", default="examples_realfakedir")
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        default="weights/blur_jpg_prob0.5.pth",
    )
    parser.add_argument(
        "-c",
        "--crop",
        type=int,
        default=None,
        help="by default, do not crop. specify crop size",
    )
    opt = parser.parse_args()

    # Close the underlying file handle once inference is done.
    with Image.open(opt.file) as img:
        prob = predict_cnn(img, opt.model_path, crop=opt.crop)

    print(f"probability of being synthetic: {prob * 100:.2f}%")