# news_verification / src / images / CNN_model_classifier.py
# Author: pmkhanh7890 — commit "1st" (22e1b62), 1.72 kB
import argparse
import torch.nn
import torchvision.transforms as transforms
from PIL import Image
from .CNN.networks.resnet import resnet50
def predict_cnn(image, model_path, crop=None):
    """Score a single image with the ResNet-50 CNN detector.

    Builds a single-logit ResNet-50, loads its weights from ``model_path``,
    preprocesses ``image`` and returns the sigmoid of the network's output.

    Args:
        image: A PIL image; it is converted to RGB before preprocessing.
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            ``"model"`` entry holds the network's state dict.
        crop: Optional size passed to ``transforms.CenterCrop``. ``None``
            (the default) disables cropping.

    Returns:
        float: Sigmoid of the single output logit, in [0, 1]. The CLI in
        this file reports it as the probability of the image being synthetic.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_cnn_model(model_path, device)
    trans = _build_cnn_transform(crop)

    # Add a batch dimension and move the input to the same device as the
    # model; leaving it on CPU fails as soon as the model lives on CUDA.
    in_tens = trans(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        prob = model(in_tens).sigmoid().item()
    return prob


def _load_cnn_model(model_path, device):
    """Build the single-logit ResNet-50, load its weights and move it to ``device``."""
    model = resnet50(num_classes=1)
    # Map to CPU first so checkpoints saved on a GPU still load on CPU-only hosts.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict["model"])
    model.to(device)
    model.eval()
    return model


def _build_cnn_transform(crop=None):
    """Return the preprocessing pipeline: optional center crop, to-tensor, mean/std normalization."""
    if crop is not None:
        trans_init = [transforms.CenterCrop(crop)]
        print("Cropping to [%i]" % crop)
    else:
        # No cropping requested: the pipeline starts directly at ToTensor.
        trans_init = []
    return transforms.Compose(
        trans_init
        + [
            transforms.ToTensor(),
            # Standard ImageNet channel means / standard deviations.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ],
    )
if __name__ == "__main__":
    # Command-line entry point: score one image file and print the result.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # NOTE(review): this default looks like a directory name, but Image.open
    # needs a single image file -- pass -f explicitly; confirm intended default.
    parser.add_argument(
        "-f",
        "--file",
        default="examples_realfakedir",
        help="path of the image file to classify",
    )
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        default="weights/blur_jpg_prob0.5.pth",
        help="path of the model checkpoint to load",
    )
    parser.add_argument(
        "-c",
        "--crop",
        type=int,
        default=None,
        help="by default, do not crop. specify crop size",
    )
    opt = parser.parse_args()
    # Use the image as a context manager so its file handle is closed
    # once scoring is done.
    with Image.open(opt.file) as img:
        prob = predict_cnn(img, opt.model_path, crop=opt.crop)
    print(f"probability of being synthetic: {prob * 100:.2f}%")