|
import os |
|
import cv2 |
|
import sys |
|
import tqdm |
|
import torch |
|
import argparse |
|
|
|
import numpy as np |
|
|
|
from PIL import Image |
|
|
|
# Make the repository root (the parent of this script's directory) importable
# so that the `lib`, `utils` and `data` imports in this file resolve.
filepath = os.path.dirname(os.path.abspath(__file__))
repopath = os.path.dirname(filepath)
sys.path.append(repopath)
|
|
|
from lib import * |
|
from utils.misc import * |
|
from data.dataloader import * |
|
from data.custom_transforms import * |
|
|
|
# Switch off TensorFloat-32 for both cuDNN and CUDA matmul kernels.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
|
|
|
def _args():
    """Parse the command-line options of the inference script.

    Flags:
        --config / -c:  path of the config file (default
                        'configs/InSPyReNet_SwinB.yaml').
        --source / -s:  input to process; no default.
        --dest / -d:    output directory override (default None).
        --type / -t:    output rendering mode (default 'map').
        --gpu / -g:     boolean switch, off by default.
        --jit / -j:     boolean switch, off by default.
        --verbose / -v: boolean switch, off by default.

    Returns:
        argparse.Namespace: the parsed values of ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    # String-valued options: (long flag, short flag, extra keyword arguments).
    valued_options = (
        ('--config', '-c', {'default': 'configs/InSPyReNet_SwinB.yaml'}),
        ('--source', '-s', {}),
        ('--dest', '-d', {'default': None}),
        ('--type', '-t', {'default': 'map'}),
    )
    for long_flag, short_flag, extra in valued_options:
        parser.add_argument(long_flag, short_flag, type=str, **extra)

    # Boolean switches, all disabled unless given on the command line.
    for long_flag, short_flag in (('--gpu', '-g'), ('--jit', '-j'), ('--verbose', '-v')):
        parser.add_argument(long_flag, short_flag, action='store_true', default=False)

    return parser.parse_args()
|
|
|
def get_format(source):
    """Classify a collection of file names by media kind.

    Args:
        source: iterable of file names (only the extension is inspected,
            case-insensitively).

    Returns:
        str: 'Image' when only .jpg/.jpeg/.png names are present, 'Video'
        when only .mp4/.avi/.mov names are present, and '' when both kinds
        are mixed or neither kind is found.
    """
    image_exts = ('.jpg', '.png', '.jpeg')
    video_exts = ('.mp4', '.avi', '.mov')

    image_total = sum(1 for name in source if name.lower().endswith(image_exts))
    video_total = sum(1 for name in source if name.lower().endswith(video_exts))

    # A mix of images and videos has no single format.
    if image_total and video_total:
        return ''
    if image_total:
        return 'Image'
    if video_total:
        return 'Video'
    return ''
|
|
|
def _load_model(opt, args):
    """Build the configured network, load its checkpoint and optionally JIT it.

    Reads 'latest.pth' from ``opt.Test.Checkpoint.checkpoint_dir``. With
    ``args.jit`` a traced copy is cached next to it as 'jit.pt' and reused on
    later runs.
    """
    checkpoint_dir = opt.Test.Checkpoint.checkpoint_dir

    # NOTE(review): the model class is resolved with eval() on a name taken
    # from the config file -- only run configs you trust.
    model = eval(opt.Model.name)(**opt.Model)
    model.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, 'latest.pth'), map_location=torch.device('cpu')),
        strict=True)

    if args.gpu is True:
        model = model.cuda()
    model.eval()

    if args.jit is True:
        jit_path = os.path.join(checkpoint_dir, 'jit.pt')
        if os.path.isfile(jit_path) is False:
            model = Simplify(model)
            dummy = torch.rand(1, 3, *opt.Test.Dataset.transforms.static_resize.size)
            # Trace on the same device as the model; an unconditional .cuda()
            # breaks CPU-only runs.
            if args.gpu is True:
                dummy = dummy.cuda()
            model = torch.jit.trace(model, dummy, strict=False)
            torch.jit.save(model, jit_path)
        else:
            del model  # drop the eager model before loading the traced one
            # Place the cached trace on the requested device regardless of the
            # device it was saved from.
            target = torch.device('cuda' if args.gpu is True else 'cpu')
            model = torch.jit.load(jit_path, map_location=target)

    return model


def _resolve_source(args):
    """Work out the output directory and input format for ``args.source``.

    Returns:
        tuple: ``(save_dir, _format)`` where ``_format`` is 'Webcam', 'Image'
        or 'Video' and ``save_dir`` is None when nothing is written to disk
        by default (webcam without ``--dest``).

    Raises:
        ValueError: if no source was given, or it is neither a numeric webcam
            index nor a path holding only images or only videos.
    """
    if args.source is None:
        raise ValueError('no input given: pass --source <webcam index | file | directory>')

    save_dir = None
    _format = None

    if args.source.isnumeric() is True:
        _format = 'Webcam'
    elif os.path.isdir(args.source):
        # normpath strips a trailing separator so 'imgs/' still maps to
        # 'results/imgs' instead of plain 'results/'.
        folder = os.path.basename(os.path.normpath(args.source))
        save_dir = os.path.join('results', folder)
        _format = get_format(os.listdir(args.source))
    elif os.path.isfile(args.source):
        save_dir = 'results'
        _format = get_format([args.source])

    if args.dest is not None:
        save_dir = args.dest

    if not _format:
        raise ValueError(
            'unsupported source %r: expected a webcam index, or a file/directory '
            'containing only images or only videos' % (args.source,))

    return save_dir, _format


def _compose(img, pred, mode, background=None):
    """Render one output frame from the input image and the predicted map.

    Args:
        img: input image array (assumed H x W x 3 -- TODO confirm against the
            loaders).
        pred: prediction map returned by ``to_numpy`` (assumed H x W floats
            in [0, 1] -- TODO confirm).
        mode: rendering mode: 'map', 'rgba', 'green', 'blur' or 'overlay'.
        background: replacement background already resized to ``img``; when
            given (and ``mode`` is none of the above) the foreground is
            composited over it.

    Returns:
        The rendered array, or ``img`` unchanged for an unrecognised mode.
    """
    tint = [120, 255, 155]
    alpha = pred[..., np.newaxis]

    if mode == 'map':
        return (np.stack([pred] * 3, axis=-1) * 255).astype(np.uint8)

    if mode == 'rgba':
        r, g, b = cv2.split(img)
        return cv2.merge([r, g, b, (pred * 255).astype(np.uint8)])

    if mode == 'green':
        bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * tint
        return img * alpha + bg * (1 - alpha)

    if mode == 'blur':
        return img * alpha + cv2.GaussianBlur(img, (0, 0), 15) * (1 - alpha)

    if mode == 'overlay':
        bg = (np.stack([np.ones_like(pred)] * 3, axis=-1) * tint + img) // 2
        blended = bg * alpha + img * (1 - alpha)
        # Outline the thresholded mask in the tint colour.
        border = cv2.Canny(((pred > .5) * 255).astype(np.uint8), 50, 100)
        blended[border != 0] = tint
        return blended

    if background is not None:
        return img * alpha + background * (1 - alpha)

    return img


def inference(opt, args):
    """Run the model over ``args.source`` and write or show the results.

    Args:
        opt: loaded configuration (model definition, checkpoint directory,
            test transforms, debug keys).
        args: parsed command-line namespace from ``_args``.

    The source may be a webcam index, a directory, or a single file. Image
    results are saved as PNG files, video results are written to an .mp4
    file, and webcam results are shown in a window. ``args.type`` selects the
    rendering: 'map', 'rgba', 'green', 'blur', 'overlay', 'debug', or the
    path of a .jpg/.jpeg/.png file used as a replacement background.

    Raises:
        ValueError: if the source is missing or of an unsupported kind.
        FileNotFoundError: if the background image cannot be read.
    """
    model = _load_model(opt, args)
    save_dir, _format = _resolve_source(args)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    # _format was validated above, so this resolves to ImageLoader,
    # VideoLoader or WebcamLoader from the module namespace.
    sample_list = globals()[_format + 'Loader'](args.source, opt.Test.Dataset.transforms)

    if args.verbose is True:
        samples = tqdm.tqdm(sample_list, desc='Inference', total=len(sample_list),
                            position=0, leave=False,
                            bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
    else:
        samples = sample_list

    use_background_file = args.type.lower().endswith(('.jpg', '.jpeg', '.png'))
    background_src = None  # background as loaded from disk (RGB)
    background = None      # background resized to the current frame
    writer = None

    try:
        for sample in samples:
            if _format == 'Video' and writer is None:
                writer = cv2.VideoWriter(os.path.join(save_dir, sample['name'] + '.mp4'),
                                         cv2.VideoWriter_fourcc(*'mp4v'),
                                         sample_list.fps, sample['shape'][::-1])
                # Extend the progress bar by this video's frame count; the
                # bare loader used in non-verbose mode may have no `total`.
                if hasattr(samples, 'total'):
                    samples.total += int(sample_list.cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # A sample without an image closes the current video file.
            if _format == 'Video' and sample['image'] is None:
                if writer is not None:
                    writer.release()
                    writer = None
                continue

            if args.gpu is True:
                sample = to_cuda(sample)

            with torch.no_grad():
                # The traced model takes the image tensor, the eager model
                # takes the whole sample.
                if args.jit is True:
                    out = model(sample['image'])
                else:
                    out = model(sample)

            pred = to_numpy(out['pred'], sample['shape'])
            img = np.array(sample['original'])

            if use_background_file:
                if background_src is None:
                    loaded = cv2.imread(args.type)
                    if loaded is None:
                        raise FileNotFoundError('cannot read background image: ' + args.type)
                    background_src = cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB)
                # Re-fit whenever the frame size changes (e.g. a folder of
                # differently sized images).
                if background is None or background.shape[:2] != img.shape[:2]:
                    background = cv2.resize(background_src, img.shape[:2][::-1])
                img = _compose(img, pred, args.type, background)
            elif args.type == 'debug':
                debs = []
                for k in opt.Train.Debug.keys:
                    debs.extend(out[k])
                for i, j in enumerate(debs):
                    log = torch.sigmoid(j).cpu().detach().numpy().squeeze()
                    span = log.max() - log.min()
                    # A constant map would divide by zero; emit it as black.
                    log = (log - log.min()) / span if span > 0 else np.zeros_like(log)
                    log = (log * 255).astype(np.uint8)
                    log = cv2.cvtColor(log, cv2.COLOR_GRAY2RGB)
                    log = cv2.resize(log, img.shape[:2][::-1])
                    Image.fromarray(log).save(
                        os.path.join(save_dir, sample['name'] + '_' + str(i) + '.png'))
            else:
                img = _compose(img, pred, args.type)

            img = img.astype(np.uint8)

            if _format == 'Image':
                Image.fromarray(img).save(os.path.join(save_dir, sample['name'] + '.png'))
            elif _format == 'Video' and writer is not None:
                # Swap the channel order for OpenCV before writing.
                writer.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            elif _format == 'Webcam':
                # NOTE(review): cv2.imshow only repaints when the GUI event
                # loop is pumped (cv2.waitKey / cv2.pollKey); presumably the
                # webcam loader does that -- confirm.
                cv2.imshow('InSPyReNet', img)
    finally:
        # Finalise the video file even if the stream ends without a sentinel
        # sample or an error interrupts the loop.
        if writer is not None:
            writer.release()
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, load the config file they
    # name, then run inference with both.
    args = _args()
    inference(load_config(args.config), args)
|
|