crapthings's picture
Upload folder using huggingface_hub
f7f604d
raw
history blame
2.88 kB
import re
import os
import yaml
import cv2
import argparse
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as ed
class Simplify(nn.Module):
    """Adapt a dict-in/dict-out model to a plain ``Tensor -> Tensor`` module.

    The wrapped model is called as ``model({'image': x})`` and the ``'pred'``
    entry of its output is returned.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def cuda(self, device=None):
        """Move the wrapped model to CUDA and return ``self``.

        Args:
            device: Optional CUDA device forwarded to ``model.cuda``. When
                ``None`` (the default), ``model.cuda()`` is called with no
                arguments, exactly as before.

        Returns:
            This wrapper, so calls can be chained like ``nn.Module.cuda``.
        """
        # Only forward `device` when given, so wrapped objects whose cuda()
        # takes no arguments keep working.
        if device is None:
            self.model = self.model.cuda()
        else:
            self.model = self.model.cuda(device)
        return self

    def forward(self, x):
        """Run the wrapped model on ``{'image': x}`` and return its ``'pred'`` entry."""
        out = self.model({'image': x})
        return out['pred']
def parse_args(argv=None):
    """Parse command-line options and attach distributed-launch information.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``config``, ``resume``, ``verbose`` and
        ``debug``, plus:

        * ``local_rank``: ``int(LOCAL_RANK)`` from the environment, or ``-1``
          when ``LOCAL_RANK`` is unset.
        * ``device_num``: ``1`` when ``LOCAL_RANK`` is unset. Otherwise it is
          the number of non-empty entries in ``CUDA_VISIBLE_DEVICES``, or
          ``torch.cuda.device_count()`` when that variable is unset.

    Note:
        Also calls ``warnings.simplefilter("ignore")``, silencing Python
        warnings process-wide.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, default='configs/InSPyReNet_SwinB.yaml')
    parser.add_argument('--resume', '-r', action='store_true', default=False)
    parser.add_argument('--verbose', '-v', action='store_true', default=False)
    parser.add_argument('--debug', '-d', action='store_true', default=False)
    args = parser.parse_args(argv)

    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")

    if local_rank == -1:
        device_num = 1
    elif visible_devices is None:
        device_num = torch.cuda.device_count()
    else:
        # Only the number of entries matters, so count them rather than
        # int()-parsing them. CUDA_VISIBLE_DEVICES may hold GPU UUIDs
        # ("GPU-...") and may be empty (no visible GPUs), both of which
        # made int() raise ValueError.
        device_num = sum(1 for dev in visible_devices.split(',') if dev.strip())

    args.device_num = device_num
    args.local_rank = local_rank
    warnings.simplefilter("ignore")
    return args
def sort(x):
    """Return the strings in ``x`` sorted in natural ("human") order.

    Each string is split into alternating text and ASCII-digit runs. Digit
    runs compare numerically and text runs compare case-insensitively, so
    ``['a10', 'a2']`` sorts as ``['a2', 'a10']``.

    Args:
        x: Iterable of strings.

    Returns:
        A new sorted list; ``x`` itself is not modified.
    """
    def natural_key(text):
        # re.split with a capturing group always places the digit runs at
        # odd indices. Convert by position rather than with str.isdigit():
        # isdigit() is also true for characters such as '²' that int()
        # rejects, and would put ints at text positions, breaking comparisons.
        parts = re.split(r'([0-9]+)', text)
        return [int(part) if i % 2 else part.lower() for i, part in enumerate(parts)]

    return sorted(x, key=natural_key)
def load_config(config_dir, easy=True):
    """Load a YAML configuration file.

    Args:
        config_dir: Path of the YAML *file* to read (despite the name, it is
            passed straight to ``open``).
        easy: When exactly ``True``, wrap the parsed result in ``EasyDict``
            for attribute-style access.

    Returns:
        The parsed YAML document, wrapped in ``EasyDict`` if ``easy is True``.
    """
    # A context manager closes the file as soon as parsing finishes instead
    # of leaking the handle until garbage collection.
    # NOTE: FullLoader can build more than plain data types; only load
    # trusted configuration files with it.
    with open(config_dir, encoding='utf-8') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    if easy is True:
        cfg = ed(cfg)
    return cfg
def to_cuda(sample, device=None):
    """Move every tensor value of ``sample`` to CUDA, in place.

    Non-tensor values are left untouched.

    Args:
        sample: Mutable mapping whose values may include ``torch.Tensor``.
        device: Optional CUDA device forwarded to ``Tensor.cuda``. ``None``
            (the default) uses the current CUDA device, as before.

    Returns:
        The same ``sample`` object, for call chaining.
    """
    # Replacing values of existing keys does not change the dict's size, so
    # assigning while iterating items() is safe.
    for key, value in sample.items():
        # isinstance (not type(...) ==) so Tensor subclasses such as
        # nn.Parameter are moved too.
        if isinstance(value, torch.Tensor):
            sample[key] = value.cuda(device)
    return sample
def to_numpy(pred, shape):
    """Resize a prediction tensor to ``shape`` and return it as a NumPy array.

    Args:
        pred: Tensor accepted by ``F.interpolate`` in ``'bilinear'`` mode,
            i.e. 4-D ``(N, C, H, W)``.
        shape: Target spatial size ``(H, W)``.

    Returns:
        ``numpy.ndarray`` on the CPU with every size-1 dimension squeezed
        out. For example, ``(1, 1, H, W)`` becomes ``(H, W)``.
    """
    pred = F.interpolate(pred, shape, mode='bilinear', align_corners=True)
    # .detach() is the supported replacement for the legacy `.data`: it also
    # yields a grad-free view that .numpy() accepts.
    return pred.detach().cpu().numpy().squeeze()
def debug_tile(deblist, size=(100, 100), activation=None):
    """Tile debug maps into a single RGB ``uint8`` mosaic.

    Each map is min-max normalised to 0..255, converted from gray to RGB and
    resized. Each inner sequence of ``deblist`` is stacked vertically into a
    column, and the columns are placed side by side.

    Args:
        deblist: Sequence of columns. Each column is a sequence of tensors
            whose squeezed form must be 2-D (required by ``COLOR_GRAY2RGB``).
        size: ``dsize`` passed to ``cv2.resize``, i.e. ``(width, height)``.
        activation: Optional callable applied to each tensor before
            normalisation.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with 3 channels. NOTE: ``np.hstack``
        needs every column to have the same total height, so columns are
        presumably expected to have equal lengths.
    """
    columns = []
    for debs in deblist:
        tiles = []
        for deb in debs:
            if activation is not None:
                deb = activation(deb)
            # Work in float64 so integer/bool maps neither overflow nor
            # fail on subtraction.
            log = deb.detach().cpu().numpy().squeeze().astype(np.float64)
            lo = log.min()
            span = log.max() - lo
            if span > 0:
                log = (log - lo) / span * 255
            else:
                # Constant (or NaN) map: min-max scaling would divide by
                # zero and produce NaNs, which become garbage after the
                # uint8 cast, so render it black instead.
                log = np.zeros_like(log)
            log = log.astype(np.uint8)
            log = cv2.cvtColor(log, cv2.COLOR_GRAY2RGB)
            log = cv2.resize(log, size)
            tiles.append(log)
        columns.append(np.vstack(tiles))
    return np.hstack(columns)
if __name__ == "__main__":
    # Build a random float tensor of shape (4, 3, 576, 576) with values in [0, 1).
    # NOTE(review): `x` is not used in the lines visible here; the rest of
    # this entry point may be missing from the upload. Confirm.
    x = torch.rand(4, 3, 576, 576)