Spaces:
Running
Running
import argparse | |
import torch.nn | |
import torchvision.transforms as transforms | |
from PIL import Image | |
from .CNN.networks.resnet import resnet50 | |
def predict_cnn(image, model_path, crop=None):
    """Return the model's probability that *image* is synthetic.

    Builds a single-logit ResNet-50, loads the weights stored under the
    ``"model"`` key of the checkpoint at *model_path*, preprocesses the
    image (optional center crop, tensor conversion, normalization), and
    returns ``sigmoid(logit)`` for that single image.

    Args:
        image: A PIL image (anything with a ``convert("RGB")`` method);
            it is converted to RGB before preprocessing.
        model_path: Path to a checkpoint file whose ``"model"`` entry is a
            state dict for ``resnet50(num_classes=1)``.
        crop: Optional size for ``transforms.CenterCrop`` applied before
            tensor conversion. ``None`` (the default) skips cropping.

    Returns:
        float: The sigmoid of the model's single output logit, in [0, 1].
    """
    model = resnet50(num_classes=1)
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code; only load trusted files (newer torch versions also
    # accept weights_only=True to restrict this).
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict["model"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Preprocessing: optional center crop, then tensor conversion and
    # normalization with the standard ImageNet per-channel mean/std values.
    if crop is not None:
        trans_init = [transforms.CenterCrop(crop)]
        print("Cropping to [%i]" % crop)
    else:
        # No crop requested: start the pipeline with no extra transforms.
        trans_init = []
    trans = transforms.Compose(
        trans_init
        + [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ],
    )
    image = trans(image.convert("RGB"))

    with torch.no_grad():
        # Add a batch dimension and move the input onto the model's device;
        # leaving it on the CPU fails when the model was moved to CUDA.
        in_tens = image.unsqueeze(0).to(device)
        prob = model(in_tens).sigmoid().item()
    return prob
if __name__ == "__main__":
    # Command-line entry point: score a single image and print the result.
    # NOTE(review): this module uses a package-relative import
    # (``from .CNN.networks.resnet import resnet50``), so it presumably has to
    # be launched as ``python -m <package>.<module>`` rather than as a plain
    # script — confirm the intended invocation.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # NOTE(review): the default below looks like a directory name, while
    # Image.open needs a single image file — pass -f explicitly or fix the
    # default.
    parser.add_argument(
        "-f",
        "--file",
        default="examples_realfakedir",
        help="path to the image file to classify",
    )
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        default="weights/blur_jpg_prob0.5.pth",
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "-c",
        "--crop",
        type=int,
        default=None,
        help="by default, do not crop. specify crop size",
    )
    opt = parser.parse_args()
    # The context manager closes the image file once prediction is done;
    # predict_cnn calls convert("RGB"), which reads the pixel data first.
    with Image.open(opt.file) as img:
        prob = predict_cnn(img, opt.model_path, crop=opt.crop)
    print(f"probability of being synthetic: {prob * 100:.2f}%")