# my-cool-model / run / Inference.py
# Hosting-page header (crapthings): "Upload folder using huggingface_hub", revision f7f604d
import os
import cv2
import sys
import tqdm
import torch
import argparse
import numpy as np
from PIL import Image
# Put the parent of this script's directory on the import path; presumably
# this is what lets the `lib`, `utils` and `data` imports that follow resolve.
filepath = os.path.dirname(os.path.abspath(__file__))
repopath = os.path.dirname(filepath)
sys.path.append(repopath)
from lib import *
from utils.misc import *
from data.dataloader import *
from data.custom_transforms import *
# Turn TensorFloat-32 off for both cuDNN and CUDA matmul kernels so the
# corresponding ops are not run with TF32's reduced precision.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
def _args():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with the attributes ``config``, ``source``,
        ``dest``, ``type``, ``gpu``, ``jit`` and ``verbose``.

    Raises:
        SystemExit: raised by argparse when the arguments are invalid, which
            now includes a missing ``--source`` (it is dereferenced
            unconditionally by ``inference``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, default='configs/InSPyReNet_SwinB.yaml',
                        help='path of the configuration file')
    parser.add_argument('--source', '-s', type=str, required=True,
                        help='numeric webcam index, a directory, or a single image/video file')
    parser.add_argument('--dest', '-d', type=str, default=None,
                        help='output directory (overrides the default under "results")')
    parser.add_argument('--type', '-t', type=str, default='map',
                        help='output style: map, rgba, green, blur, overlay, debug, '
                             'or the path of a .jpg/.jpeg/.png background image')
    parser.add_argument('--gpu', '-g', action='store_true', default=False,
                        help='run the model on CUDA')
    parser.add_argument('--jit', '-j', action='store_true', default=False,
                        help='use a TorchScript trace of the model (cached as jit.pt)')
    parser.add_argument('--verbose', '-v', action='store_true', default=False,
                        help='show a progress bar')
    return parser.parse_args()
def get_format(source):
    """Classify a collection of file names as images or videos.

    Only the (case-insensitive) file extension of each name is inspected.
    The names are consumed in a single pass, so one-shot iterables such as
    generators are classified correctly.

    Args:
        source: iterable of file names or paths (strings).

    Returns:
        'Image' if at least one name ends with .jpg/.jpeg/.png and none ends
        with a video extension; 'Video' if at least one name ends with
        .mp4/.avi/.mov and none ends with an image extension; '' when both
        kinds are present or when neither is.
    """
    image_exts = ('.jpg', '.png', '.jpeg')
    video_exts = ('.mp4', '.avi', '.mov')

    has_image = False
    has_video = False
    for name in source:
        lowered = name.lower()
        if lowered.endswith(image_exts):
            has_image = True
        elif lowered.endswith(video_exts):
            has_video = True
        if has_image and has_video:
            # A mix of images and videos is ambiguous: report no format.
            return ''

    if has_image:
        return 'Image'
    if has_video:
        return 'Video'
    return ''
def _load_model(opt, args):
    """Build the configured network, load its checkpoint and optionally trace it.

    The weights are read from ``latest.pth`` inside
    ``opt.Test.Checkpoint.checkpoint_dir``. With ``args.jit`` a TorchScript
    trace is created and cached as ``jit.pt`` in the same directory, or loaded
    from there when it already exists.
    """
    ckpt_dir = opt.Test.Checkpoint.checkpoint_dir

    # NOTE(review): eval() executes whatever string the config supplies as the
    # model name; only run this with configuration files you trust.
    model = eval(opt.Model.name)(**opt.Model)
    state = torch.load(os.path.join(ckpt_dir, 'latest.pth'), map_location=torch.device('cpu'))
    model.load_state_dict(state, strict=True)

    if args.gpu:
        model = model.cuda()
    model.eval()

    if args.jit:
        jit_path = os.path.join(ckpt_dir, 'jit.pt')
        # Trace/load on the device the model actually runs on, instead of
        # unconditionally on CUDA.
        device = torch.device('cuda' if args.gpu else 'cpu')
        if not os.path.isfile(jit_path):
            model = Simplify(model)
            example = torch.rand(1, 3, *opt.Test.Dataset.transforms.static_resize.size, device=device)
            model = torch.jit.trace(model, example, strict=False)
            torch.jit.save(model, jit_path)
        else:
            del model
            # NOTE(review): a trace cached on another device may still carry
            # device-specific constants; delete jit.pt to re-trace if needed.
            model = torch.jit.load(jit_path, map_location=device)
    return model


def _resolve_source(args):
    """Return ``(_format, save_dir)`` for ``args.source`` and ``args.dest``.

    ``_format`` is 'Webcam', 'Image' or 'Video'; ``save_dir`` is None when no
    output directory applies (webcam without ``--dest``).

    Raises:
        ValueError: if the source is neither a numeric id nor an existing
            path, or its files are not exclusively images or videos.
    """
    save_dir = None
    _format = None

    if args.source.isnumeric():
        _format = 'Webcam'
    elif os.path.isdir(args.source):
        # normpath drops a trailing separator so the directory name is kept.
        dir_name = os.path.basename(os.path.normpath(args.source))
        save_dir = os.path.join('results', dir_name)
        _format = get_format(os.listdir(args.source))
    elif os.path.isfile(args.source):
        save_dir = 'results'
        _format = get_format([args.source])

    if args.dest is not None:
        save_dir = args.dest

    if not _format:
        raise ValueError(
            'Unsupported source {!r}: expected a webcam index, an image/video file, '
            'or a directory containing only images or only videos'.format(args.source))
    return _format, save_dir


def inference(opt, args):
    """Run the model over the source and save or display the results.

    Args:
        opt: configuration object; this function reads ``opt.Model``,
            ``opt.Test.Checkpoint.checkpoint_dir``,
            ``opt.Test.Dataset.transforms`` and, for the 'debug' output type,
            ``opt.Train.Debug.keys``.
        args: parsed command-line options as returned by ``_args``.

    Raises:
        ValueError: if ``args.source`` cannot be mapped to a supported format.
        FileNotFoundError: if ``args.type`` names a background image that
            cannot be read.
    """
    model = _load_model(opt, args)

    _format, save_dir = _resolve_source(args)
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    # _format is one of a fixed set of names at this point, so the loader
    # class can be looked up by name without eval().
    loader_cls = globals()[_format + 'Loader']
    sample_list = loader_cls(args.source, opt.Test.Dataset.transforms)

    if args.verbose:
        samples = tqdm.tqdm(sample_list, desc='Inference', total=len(sample_list), position=0,
                            leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
    else:
        samples = sample_list

    writer = None
    background_src = None  # background image as read from disk (RGB)
    background = None      # background resized to the current frame size

    try:
        for sample in samples:
            if _format == 'Video' and sample['image'] is None:
                # A sample without an image closes the current output video.
                if writer is not None:
                    writer.release()
                    writer = None
                continue

            if _format == 'Video' and writer is None:
                writer = cv2.VideoWriter(os.path.join(save_dir, sample['name'] + '.mp4'),
                                         cv2.VideoWriter_fourcc(*'mp4v'),
                                         sample_list.fps, sample['shape'][::-1])
                if args.verbose:
                    # Only the tqdm wrapper has a `total` to extend.
                    samples.total += int(sample_list.cap.get(cv2.CAP_PROP_FRAME_COUNT))

            if args.gpu:
                sample = to_cuda(sample)

            with torch.no_grad():
                if args.jit:
                    out = model(sample['image'])
                else:
                    out = model(sample)

            pred = to_numpy(out['pred'], sample['shape'])
            img = np.array(sample['original'])
            alpha = pred[..., np.newaxis]
            mode = args.type

            if mode == 'map':
                img = (np.stack([pred] * 3, axis=-1) * 255).astype(np.uint8)
            elif mode == 'rgba':
                r, g, b = cv2.split(img)
                pred = (pred * 255).astype(np.uint8)
                img = cv2.merge([r, g, b, pred])
            elif mode == 'green':
                bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155]
                img = img * alpha + bg * (1 - alpha)
            elif mode == 'blur':
                img = img * alpha + cv2.GaussianBlur(img, (0, 0), 15) * (1 - alpha)
            elif mode == 'overlay':
                bg = (np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155] + img) // 2
                img = bg * alpha + img * (1 - alpha)
                border = cv2.Canny(((pred > .5) * 255).astype(np.uint8), 50, 100)
                img[border != 0] = [120, 255, 155]
            elif mode.lower().endswith(('.jpg', '.jpeg', '.png')):
                if background_src is None:
                    loaded = cv2.imread(mode)
                    if loaded is None:
                        raise FileNotFoundError('Cannot read background image: ' + mode)
                    background_src = cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB)
                # Resize only when the frame size changes (e.g. a folder of
                # differently sized images), always from the original file.
                if background is None or background.shape[:2] != img.shape[:2]:
                    background = cv2.resize(background_src, img.shape[:2][::-1])
                img = img * alpha + background * (1 - alpha)
            elif mode == 'debug':
                debs = []
                for k in opt.Train.Debug.keys:
                    debs.extend(out[k])
                for i, j in enumerate(debs):
                    log = torch.sigmoid(j).cpu().detach().numpy().squeeze()
                    span = log.max() - log.min()
                    # A constant map has zero span; avoid dividing by zero.
                    log = ((log - log.min()) / (span if span > 0 else 1) * 255).astype(np.uint8)
                    log = cv2.cvtColor(log, cv2.COLOR_GRAY2RGB)
                    log = cv2.resize(log, img.shape[:2][::-1])
                    Image.fromarray(log).save(os.path.join(save_dir, sample['name'] + '_' + str(i) + '.png'))

            img = img.astype(np.uint8)

            if _format == 'Image':
                Image.fromarray(img).save(os.path.join(save_dir, sample['name'] + '.png'))
            elif _format == 'Video' and writer is not None:
                # Same channel swap as before, named for the direction used here.
                writer.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
            elif _format == 'Webcam':
                # NOTE(review): no cv2.waitKey and no RGB->BGR swap here, unlike
                # the video path; confirm whether WebcamLoader handles these.
                cv2.imshow('InSPyReNet', img)
    finally:
        # Finalize the output file even if the loop ends or fails before an
        # end-of-video marker is seen.
        if writer is not None:
            writer.release()
if __name__ == "__main__":
    # Script entry point: parse the CLI options, load the configuration file
    # they point to, then run inference with both.
    args = _args()
    opt = load_config(args.config)
    inference(opt, args)