File size: 1,721 Bytes
22e1b62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse

import torch.nn
import torchvision.transforms as transforms
from PIL import Image

from .CNN.networks.resnet import resnet50


def predict_cnn(image, model_path, crop=None):
    """Return the probability that ``image`` is synthetic (CNN-generated).

    Builds a single-logit ``resnet50`` classifier, loads its weights from
    ``model_path``, runs one forward pass on ``image`` and returns the
    sigmoid of the output logit.

    Args:
        image: Either a PIL image, which is converted to RGB, optionally
            center-cropped, turned into a tensor and normalized with the
            ImageNet mean/std; or a ``torch.Tensor`` that is assumed to be
            already preprocessed (presumably shape (C, H, W) -- TODO confirm)
            and is only center-cropped when ``crop`` is given.
        model_path: Path to a checkpoint whose ``"model"`` entry holds the
            network's state dict.
        crop: Optional size passed to ``transforms.CenterCrop``. ``None``
            (the default) disables cropping.

    Returns:
        The sigmoid of the model's single output logit, as a Python float
        in [0, 1].
    """
    model = resnet50(num_classes=1)
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    # Mapping to CPU first lets checkpoints saved on a GPU load anywhere.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict["model"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    if crop is not None:
        print("Cropping to [%i]" % crop)

    if isinstance(image, torch.Tensor):
        # Caller already produced a tensor: leave it as-is apart from the
        # optional crop.
        in_tens = transforms.CenterCrop(crop)(image) if crop is not None else image
    else:
        # Only the crop is optional; ToTensor + Normalize must always run so
        # the model receives a normalized tensor rather than a PIL image.
        trans_init = [transforms.CenterCrop(crop)] if crop is not None else []
        trans = transforms.Compose(
            trans_init
            + [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ],
        )
        in_tens = trans(image.convert("RGB"))

    with torch.no_grad():
        # Add a batch dimension and put the input on the same device as the
        # model; otherwise inference fails whenever CUDA is available.
        in_tens = in_tens.unsqueeze(0).to(device)
        prob = model(in_tens).sigmoid().item()

    return prob


if __name__ == "__main__":
    # NOTE: this module uses a relative import (``from .CNN...``), so it must
    # be run as part of its package (``python -m <package>.<module>``); running
    # the file directly raises ImportError.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # NOTE(review): the default looks like a directory name, but Image.open
    # needs a single image file -- confirm the intended default path.
    parser.add_argument(
        "-f",
        "--file",
        default="examples_realfakedir",
        help="path of the image to classify",
    )
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        default="weights/blur_jpg_prob0.5.pth",
        help="path of the resnet50 checkpoint to load",
    )
    parser.add_argument(
        "-c",
        "--crop",
        type=int,
        default=None,
        help="by default, do not crop. specify crop size",
    )

    opt = parser.parse_args()
    # Use the image as a context manager so its file handle is released
    # once inference is done.
    with Image.open(opt.file) as img:
        prob = predict_cnn(img, opt.model_path, crop=opt.crop)
    print(f"probability of being synthetic: {prob * 100:.2f}%")