# Spaces:
# Sleeping
# Sleeping
import argparse
import os
import random
from glob import escape, glob
from os.path import join

import cv2
import torch
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch.transforms import ToTensorV2

from model.common import freeze_weights
from model.network import Recce
# Presumably works around Intel OpenMP aborting when two copies of its runtime
# get loaded into one process (e.g. by torch and another library) — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Seed the Python and torch (CPU + every CUDA device) RNGs so repeated runs
# are reproducible.
seed = 0
random.seed(seed)
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed)
del _seed_rng
parser = argparse.ArgumentParser(
    description="This code helps you use a trained model to do inference."
)

# (long flag, short flag, help) for the optional string path options, all of
# which default to None. The pairwise exclusivity their help text describes
# is checked by the code that consumes `args`, not by argparse.
_PATH_OPTIONS = (
    ("--weight", "-w",
     "Specify the path to the model weight (the state dict file). "
     "Do not use this argument when '--bin' is set."),
    ("--bin", "-b",
     "Specify the path to the model bin which ends up with '.bin' "
     "(which is generated by the trainer of this project). "
     "Do not use this argument when '--weight' is set."),
    ("--image", "-i",
     "Specify the path to the input image. "
     "Do not use this argument when '--image_folder' is set."),
    ("--image_folder", "-f",
     "Specify the directory to evaluate all the images. "
     "Do not use this argument when '--image' is set."),
)
for _flag, _short, _help in _PATH_OPTIONS:
    parser.add_argument(_flag, _short, type=str, default=None, help=_help)
del _flag, _short, _help

parser.add_argument("--device", "-d", type=str, default="cpu",
                    help="Specify the device to load the model. Default: 'cpu'.")
parser.add_argument("--image_size", "-s", type=int, default=299,
                    help="Specify the spatial size of the input image(s). Default: 299.")
parser.add_argument("--visualize", "-v", action="store_true", default=False,
                    help="Visualize images.")
def preprocess(file_path, image_size=None):
    """Read one image from disk and convert it into a model-ready batch.

    Args:
        file_path: Path of the image to load with ``cv2.imread``.
        image_size: Height and width the image is resized to. When None,
            falls back to ``args.image_size`` (the ``--image_size`` option
            held in the module-level ``args``).

    Returns:
        The transformed image with a leading batch axis of size 1 added by
        ``unsqueeze(0)``. Presumably shaped (1, 3, image_size, image_size)
        after ``ToTensorV2`` — confirm against the albumentations version.

    Raises:
        ValueError: If OpenCV cannot read ``file_path`` (missing file or
            unsupported format).
    """
    if image_size is None:
        image_size = args.image_size
    img = cv2.imread(file_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # without this check cvtColor below fails with an opaque cv2.error.
        raise ValueError(f"Could not read image file '{file_path}'.")
    # OpenCV decodes to BGR; the model pipeline works on RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    transform = Compose([
        Resize(height=image_size, width=image_size),
        # mean = std = 0.5 per channel; presumably maps pixels into [-1, 1]
        # given albumentations' default max_pixel_value of 255 — confirm.
        Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ToTensorV2(),
    ])
    return transform(image=img)["image"].unsqueeze(0)
def prepare_data():
    """Collect and preprocess the input image(s) selected on the command line.

    Reads the module-level ``args``: exactly one of ``--image`` or
    ``--image_folder`` must be set. For a folder, every ``*.jpg`` and
    ``*.png`` file directly inside it (non-recursive) is used, sorted by path
    so the processing order is deterministic.

    Returns:
        A ``(paths, images)`` tuple of equal-length lists: the image file
        paths and the matching tensors produced by ``preprocess``.

    Raises:
        ValueError: If both or neither of the two options are set, or if the
            folder contains no ``.jpg`` / ``.png`` files.
    """
    if args.image and args.image_folder:
        raise ValueError("Only one of '--image' or '--image_folder' can be set.")
    if args.image:
        paths = [args.image]
    elif args.image_folder:
        # Use the folder path verbatim. Stripping text after the last '.'
        # (as this code used to) turns "images" or "./images" into "" and
        # globs the filesystem root instead. escape() keeps characters like
        # '[' in the folder name from being read as glob wildcards.
        folder = escape(args.image_folder)
        paths = sorted(glob(join(folder, "*.jpg")) + glob(join(folder, "*.png")))
        if not paths:
            # Fail here with a clear message rather than letting the mean
            # computation divide by zero later.
            raise ValueError(f"No '.jpg' or '.png' images found in '{args.image_folder}'.")
    else:
        raise ValueError("Neither of '--image' nor '--image_folder' is set. Please specify either "
                         "one of these two arguments to load input image(s) properly.")
    images = [preprocess(path) for path in paths]
    return paths, images
def inference(model, images, paths, device, visualize=None):
    """Run the model on each image, print per-image results and return the mean.

    Args:
        model: Callable returning the logit for one image. Its output goes
            through ``torch.sigmoid`` and ``.item()``, so it must contain
            exactly one element per call.
        images: Sequence of input tensors, one per image.
        paths: File paths matching ``images``. They are used in the printed
            report and, when visualizing, to re-read the image with OpenCV.
        device: Device each image tensor is moved to before the forward pass.
        visualize: If true, show each image annotated with its prediction in
            an OpenCV window and wait for a key press. When None, defaults to
            ``args.visualize`` (the ``--visualize`` flag in the module-level
            ``args``).

    Returns:
        The mean fake probability over ``images`` (a float in [0, 1]).

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("No images to run inference on.")
    if visualize is None:
        visualize = args.visualize

    total = 0.0
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        for img, path in zip(images, paths):
            prob = torch.sigmoid(model(img.to(device))).item()
            fake = prob >= 0.5
            label = "fake" if fake else "real"
            total += prob
            print(f"path: {path} \t\t| fake probability: {prob:.4f} \t| "
                  f"prediction: {label}")
            if not visualize:
                continue
            canvas = cv2.imread(path)
            if canvas is None:
                print(f"Warning: could not re-read '{path}' for visualization; skipping.")
                continue
            # Colors are in OpenCV's BGR order: red for fake, blue for real.
            canvas = cv2.putText(canvas, f"p: {prob:.2f}, {label}",
                                 (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                 (0, 0, 255) if fake else (255, 0, 0), 2)
            cv2.imshow("image", canvas)
            cv2.waitKey(0)
            cv2.destroyWindow("image")
    return total / len(images)
def main():
    """Run the Recce detector on the images named on the command line.

    Reads the module-level ``args``, which the ``__main__`` guard sets from
    ``parser``. Weights come from exactly one of two options:

    * ``--weight``: a raw state dict.
    * ``--bin``: a trainer checkpoint whose state dict is stored under the
      ``"model"`` key.

    After loading, the function prints per-image and mean fake probabilities.

    Raises:
        ValueError: If both or neither of ``--weight`` / ``--bin`` are set,
            or if ``prepare_data`` rejects the image options.
    """
    print("Arguments:\n", args, end="\n\n")
    device = torch.device(args.device)

    # Validate the weight options before building the model so a usage error
    # fails immediately.
    if args.weight and args.bin:
        raise ValueError("Only one of '--weight' or '--bin' can be set.")
    if not (args.weight or args.bin):
        raise ValueError("Neither of '--weight' nor '--bin' is set. Please specify either "
                         "one of these two arguments to load model's weight properly.")

    # Single-logit output: inference() turns it into a fake probability.
    model = Recce(num_classes=1)
    # Weights are loaded onto the CPU first and the model is moved to
    # `device` afterwards.
    # NOTE(security): torch.load unpickles the file, which can run arbitrary
    # code. Only pass checkpoints from trusted sources.
    if args.weight:
        weights = torch.load(args.weight, map_location="cpu")
    else:
        weights = torch.load(args.bin, map_location="cpu")["model"]
    model.load_state_dict(weights)
    model = model.to(device)
    freeze_weights(model)
    model.eval()

    paths, images = prepare_data()
    print("Inference:")
    mean_pred = inference(model, images=images, paths=paths, device=device)
    print("Mean prediction:", mean_pred)
if __name__ == "__main__":
    # Bound at module scope on purpose: the functions in this module read
    # `args` as a global.
    args = parser.parse_args()
    main()