import torch
import torch.nn.functional as F
import torchvision
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib import cm
from sklearn.decomposition import PCA
from skimage.color import (
    rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb,
    rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb)

import utils.basic

EPS = 1e-6
|
|
|
def _convert(input_, type_): |
|
return { |
|
'float': input_.float(), |
|
'double': input_.double(), |
|
}.get(type_, input_) |
|
|
|
def _generic_transform_sk_3d(transform, in_type='', out_type=''): |
|
def apply_transform_individual(input_): |
|
device = input_.device |
|
input_ = input_.cpu() |
|
input_ = _convert(input_, in_type) |
|
|
|
input_ = input_.permute(1, 2, 0).detach().numpy() |
|
transformed = transform(input_) |
|
output = torch.from_numpy(transformed).float().permute(2, 0, 1) |
|
output = _convert(output, out_type) |
|
return output.to(device) |
|
|
|
def apply_transform(input_): |
|
to_stack = [] |
|
for image in input_: |
|
to_stack.append(apply_transform_individual(image)) |
|
return torch.stack(to_stack) |
|
return apply_transform |
|
|
|
hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb) |
|
|
|
def preprocess_color_tf(x): |
|
import tensorflow as tf |
|
return tf.cast(x,tf.float32) * 1./255 - 0.5 |
|
|
|
def preprocess_color(x): |
|
if isinstance(x, np.ndarray): |
|
return x.astype(np.float32) * 1./255 - 0.5 |
|
else: |
|
return x.float() * 1./255 - 0.5 |
|
|
|
def pca_embed(emb, keep, valid=None):
    # emb is B x C x H x W; valid (optional) is B x 1 x H x W
    # returns B x keep x H x W, PCA-reduced over the channel dimension
    emb = emb + EPS

    emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy()

    B, H, W, C = np.shape(emb)

    if valid is not None:
        valid = valid.cpu().detach().numpy().reshape((H*W))

    emb_reduced = list()

    for img in emb:
        if np.isnan(img).any():
            emb_reduced.append(np.zeros([H, W, keep]))
            continue

        pixels_kd = np.reshape(img, (H*W, C))

        if valid is not None:
            # fit PCA only on the valid pixels
            pixels_kd_pca = pixels_kd[valid > 0]
        else:
            pixels_kd_pca = pixels_kd

        P = PCA(keep)
        P.fit(pixels_kd_pca)

        if valid is not None:
            pixels3d = P.transform(pixels_kd) * valid[:, None]
        else:
            pixels3d = P.transform(pixels_kd)
|
|
|
out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32) |
|
if np.isnan(out_img).any(): |
|
emb_reduced.append(np.zeros([H, W, keep])) |
|
continue |
|
|
|
emb_reduced.append(out_img) |
|
|
|
emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32) |
|
|
|
return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2) |
|
|
|
def pca_embed_together(emb, keep): |
|
|
|
|
|
|
|
emb = emb + EPS |
|
|
|
emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() |
|
|
|
B, H, W, C = np.shape(emb) |
|
if np.isnan(emb).any(): |
|
return torch.zeros(B, keep, H, W) |
|
|
|
pixelskd = np.reshape(emb, (B*H*W, C)) |
|
P = PCA(keep) |
|
P.fit(pixelskd) |
|
pixels3d = P.transform(pixelskd) |
|
out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32) |
|
|
|
if np.isnan(out_img).any(): |
|
return torch.zeros(B, keep, H, W) |
|
|
|
return torch.from_numpy(out_img).permute(0, 3, 1, 2) |
|
|
|
def reduce_emb(emb, valid=None, inbound=None, together=False): |
|
|
|
|
|
|
|
S, C, H, W = list(emb.size()) |
|
keep = 3 |
|
|
|
if together: |
|
reduced_emb = pca_embed_together(emb, keep) |
|
else: |
|
reduced_emb = pca_embed(emb, keep, valid) |
|
|
|
reduced_emb = utils.basic.normalize(reduced_emb) - 0.5 |
|
if inbound is not None: |
|
emb_inbound = emb*inbound |
|
else: |
|
emb_inbound = None |
|
|
|
return reduced_emb, emb_inbound |
|
|
|
def get_feat_pca(feat, valid=None):
    # feat is B x C x H x W; returns a B x 3 x H x W PCA visualization
    B, C, H, W = list(feat.size())

    pca, _ = reduce_emb(feat, valid=valid, inbound=None, together=True)

    return pca
|
|
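# Example (sketch): reducing a learned feature map to 3 channels for display.
# The shapes below are made up for illustration; any B x C x H x W float
# tensor works. The result lies roughly in [-0.5, 0.5], matching the range
# produced by preprocess_color, so back2color (defined below) can display it.
#
#   feat = torch.randn(2, 64, 32, 32)     # e.g. a backbone feature map
#   feat_vis = get_feat_pca(feat)          # 2 x 3 x 32 x 32
#   rgb_vis = back2color(feat_vis)         # uint8, ready for summ_rgb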
|
def gif_and_tile(ims, just_gif=False): |
|
S = len(ims) |
|
|
|
|
|
|
|
|
|
gif = torch.stack(ims, dim=1) |
|
if just_gif: |
|
return gif |
|
til = torch.cat(ims, dim=2) |
|
til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1) |
|
im = torch.cat([gif, til], dim=3) |
|
return im |
|
|
|
def back2color(i, blacken_zeros=False): |
|
if blacken_zeros: |
|
const = torch.tensor([-0.5]) |
|
i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i) |
|
return back2color(i) |
|
else: |
|
return ((i+0.5)*255).type(torch.ByteTensor) |
|
|
|
def convert_occ_to_height(occ, reduce_axis=3): |
|
B, C, D, H, W = list(occ.shape) |
|
assert(C==1) |
|
|
|
|
|
|
|
G = list(occ.shape)[reduce_axis] |
|
values = torch.linspace(float(G), 1.0, steps=G, dtype=torch.float32, device=occ.device) |
|
if reduce_axis==2: |
|
|
|
values = values.view(1, 1, G, 1, 1) |
|
elif reduce_axis==3: |
|
|
|
values = values.view(1, 1, 1, G, 1) |
|
elif reduce_axis==4: |
|
|
|
values = values.view(1, 1, 1, 1, G) |
|
else: |
|
assert(False) |
|
values = torch.max(occ*values, dim=reduce_axis)[0]/float(G) |
|
|
|
return values |
|
|
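# Example (sketch): turning a synthetic occupancy volume into a BEV height map.
# Shapes are illustrative; occ is expected to be B x 1 x D x H x W with values
# in {0, 1}, and reduce_axis=3 collapses the vertical (H) axis.
#
#   occ = (torch.rand(1, 1, 16, 16, 16) > 0.9).float()
#   height_bev = convert_occ_to_height(occ, reduce_axis=3)   # 1 x 1 x 16 x 16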
|
def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False): |
|
|
|
|
|
|
|
B, N, Y, X = list(grid_xs.shape) |
|
|
|
mu_x = xy[:,:,0].clone() |
|
mu_y = xy[:,:,1].clone() |
|
|
|
x_valid = (mu_x>-0.5) & (mu_x<float(X+0.5)) |
|
y_valid = (mu_y>-0.5) & (mu_y<float(Y+0.5)) |
|
not_valid = ~(x_valid & y_valid) |
|
|
|
mu_x[not_valid] = -10000 |
|
mu_y[not_valid] = -10000 |
|
|
|
mu_x = mu_x.reshape(B, N, 1, 1).repeat(1, 1, Y, X) |
|
mu_y = mu_y.reshape(B, N, 1, 1).repeat(1, 1, Y, X) |
|
|
|
sigma_sq = sigma*sigma |
|
|
|
sq_diff_x = (grid_xs - mu_x)**2 |
|
sq_diff_y = (grid_ys - mu_y)**2 |
|
|
|
    # standard 2D Gaussian: normalization constant times the exponential term
    term1 = 1. / (2. * np.pi * sigma_sq)
    term2 = torch.exp(-(sq_diff_x + sq_diff_y) / (2. * sigma_sq))
    gauss = term1 * term2
|
|
|
if norm: |
|
|
|
gauss_ = gauss.reshape(B*N, Y, X) |
|
gauss_ = utils.basic.normalize(gauss_) |
|
gauss = gauss_.reshape(B, N, Y, X) |
|
|
|
return gauss |
|
|
|
def xy2heatmaps(xy, Y, X, sigma=30.0, norm=True): |
|
|
|
|
|
B, N, D = list(xy.shape) |
|
assert(D==2) |
|
|
|
device = xy.device |
|
|
|
grid_y, grid_x = utils.basic.meshgrid2d(B, Y, X, device=device) |
|
|
|
grid_xs = grid_x.unsqueeze(1).repeat(1, N, 1, 1) |
|
grid_ys = grid_y.unsqueeze(1).repeat(1, N, 1, 1) |
|
heat = xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=norm) |
|
return heat |
|
|
|
def draw_circles_at_xy(xy, Y, X, sigma=12.5, round=False): |
|
B, N, D = list(xy.shape) |
|
assert(D==2) |
|
prior = xy2heatmaps(xy, Y, X, sigma=sigma) |
|
|
|
if round: |
|
prior = (prior > 0.5).float() |
|
return prior |
|
|
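# Example (sketch): building Gaussian heatmaps for a few keypoints.
# xy holds pixel coordinates in (x, y) order; the values here are made up.
#
#   xy = torch.tensor([[[16.0, 24.0], [40.0, 8.0]]])        # B=1, N=2, D=2
#   heat = xy2heatmaps(xy, Y=64, X=64, sigma=4.0)            # 1 x 2 x 64 x 64
#   mask = draw_circles_at_xy(xy, Y=64, X=64, sigma=4.0, round=True)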
|
def seq2color(im, norm=True, colormap='coolwarm'): |
|
B, S, H, W = list(im.shape) |
|
|
|
|
|
|
|
mask = torch.max(im, dim=1, keepdim=True)[0] |
|
|
|
|
|
    # later timesteps get larger coefficients, so the max over S keeps the latest
    coeffs = np.linspace(1.0, float(S), S).astype(np.float32)/float(S)
    coeffs = torch.from_numpy(coeffs).float().to(im.device)
|
coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W) |
|
|
|
im = im * coeffs |
|
|
|
|
|
im = torch.max(im, dim=1, keepdim=True)[0] |
|
|
|
out = [] |
|
for b in range(B): |
|
im_ = im[b] |
|
|
|
im_ = im_.detach().cpu().numpy() |
|
im_ = np.squeeze(im_) |
|
|
|
if colormap=='coolwarm': |
|
im_ = cm.coolwarm(im_)[:, :, :3] |
|
elif colormap=='PiYG': |
|
im_ = cm.PiYG(im_)[:, :, :3] |
|
elif colormap=='winter': |
|
im_ = cm.winter(im_)[:, :, :3] |
|
elif colormap=='spring': |
|
im_ = cm.spring(im_)[:, :, :3] |
|
elif colormap=='onediff': |
|
im_ = np.reshape(im_, (-1)) |
|
im0_ = cm.spring(im_)[:, :3] |
|
im1_ = cm.winter(im_)[:, :3] |
|
im1_[im_==1/float(S)] = im0_[im_==1/float(S)] |
|
im_ = np.reshape(im1_, (H, W, 3)) |
|
else: |
|
assert(False) |
|
|
|
        im_ = np.transpose(im_, [2, 0, 1])
        im_ = torch.from_numpy(im_).float().to(im.device)
        out.append(im_)
|
out = torch.stack(out, dim=0) |
|
|
|
|
|
out = out*mask |
|
|
|
|
|
|
|
out = out - 0.5 |
|
|
|
return out |
|
|
|
def colorize(d): |
|
|
|
|
|
if d.ndim==2: |
|
d = d.unsqueeze(dim=0) |
|
else: |
|
assert(d.ndim==3) |
|
|
|
|
|
color_map = cm.get_cmap('inferno') |
|
|
|
|
|
|
|
C,H,W = d.shape |
|
assert(C==1) |
|
d = d.reshape(-1) |
|
d = d.detach().cpu().numpy() |
|
|
|
    color = np.array(color_map(d)) * 255
    color = np.reshape(color[:, :3], [H*W, 3])
    color = torch.from_numpy(color).permute(1, 0).reshape(3, H, W)

    return color
|
|
|
|
|
def oned2inferno(d, norm=True, do_colorize=False): |
|
|
|
|
|
|
|
if d.ndim==3: |
|
d = d.unsqueeze(dim=1) |
|
|
|
B, C, H, W = list(d.shape) |
|
assert(C==1) |
|
|
|
if norm: |
|
d = utils.basic.normalize(d) |
|
|
|
if do_colorize: |
|
rgb = torch.zeros(B, 3, H, W) |
|
for b in list(range(B)): |
|
rgb[b] = colorize(d[b]) |
|
else: |
|
rgb = d.repeat(1, 3, 1, 1)*255.0 |
|
|
|
rgb = rgb.type(torch.ByteTensor) |
|
|
|
|
|
|
|
|
|
return rgb |
|
|
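# Example (sketch): colorizing a single-channel map (e.g. depth) with inferno.
# The input is any B x 1 x H x W (or B x H x W) float tensor; norm=True
# rescales it to [0, 1] before the colormap is applied.
#
#   depth = torch.rand(1, 1, 48, 64)
#   depth_vis = oned2inferno(depth, norm=True, do_colorize=True)   # uint8, 1 x 3 x 48 x 64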
|
def oned2gray(d, norm=True): |
|
|
|
|
|
|
|
if d.ndim==3: |
|
d = d.unsqueeze(dim=1) |
|
|
|
B, C, H, W = list(d.shape) |
|
assert(C==1) |
|
|
|
if norm: |
|
d = utils.basic.normalize(d) |
|
|
|
rgb = d.repeat(1,3,1,1) |
|
rgb = (255.0*rgb).type(torch.ByteTensor) |
|
return rgb |
|
|
|
|
|
def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20): |
|
|
|
rgb = vis.detach().cpu().numpy()[0] |
|
rgb = np.transpose(rgb, [1, 2, 0]) |
|
rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) |
|
color = (255, 255, 255) |
|
|
|
|
|
frame_str = utils.basic.strnum(frame_id) |
|
|
|
text_color_bg = (0,0,0) |
|
font = cv2.FONT_HERSHEY_SIMPLEX |
|
text_size, _ = cv2.getTextSize(frame_str, font, scale, 1) |
|
text_w, text_h = text_size |
|
cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1) |
|
|
|
cv2.putText( |
|
rgb, |
|
frame_str, |
|
(left, top), |
|
font, |
|
scale, |
|
color, |
|
1) |
|
rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) |
|
vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) |
|
return vis |
|
|
|
COLORMAP_FILE = "./utils/bremm.png" |
|
class ColorMap2d: |
|
def __init__(self, filename=None): |
|
self._colormap_file = filename or COLORMAP_FILE |
|
self._img = plt.imread(self._colormap_file) |
|
|
|
self._height = self._img.shape[0] |
|
self._width = self._img.shape[1] |
|
|
|
def __call__(self, X): |
|
assert len(X.shape) == 2 |
|
output = np.zeros((X.shape[0], 3)) |
|
for i in range(X.shape[0]): |
|
x, y = X[i, :] |
|
xp = int((self._width-1) * x) |
|
yp = int((self._height-1) * y) |
|
xp = np.clip(xp, 0, self._width-1) |
|
yp = np.clip(yp, 0, self._height-1) |
|
output[i, :] = self._img[yp, xp] |
|
return output |
|
|
|
def get_n_colors(N, sequential=False): |
|
label_colors = [] |
|
for ii in range(N): |
|
if sequential: |
|
rgb = cm.winter(ii/(N-1)) |
|
rgb = (np.array(rgb) * 255).astype(np.uint8)[:3] |
|
else: |
|
rgb = np.zeros(3) |
|
while np.sum(rgb) < 128: |
|
rgb = np.random.randint(0,256,3) |
|
label_colors.append(rgb) |
|
return label_colors |
|
|
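# Example (sketch): sampling distinct colors for instance labels.
#
#   colors = get_n_colors(5)                       # five random bright-ish RGB triples
#   colors_seq = get_n_colors(5, sequential=True)  # evenly spaced along cm.winter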
|
class Summ_writer(object):
    # thin wrapper around a tensorboard SummaryWriter: rate-limits logging via
    # log_freq/scalar_freq and converts tensors into displayable images and gifs
    def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):
        self.writer = writer
|
self.global_step = global_step |
|
self.log_freq = log_freq |
|
self.fps = fps |
|
self.just_gif = just_gif |
|
self.maxwidth = 10000 |
|
self.save_this = (self.global_step % self.log_freq == 0) |
|
self.scalar_freq = max(scalar_freq,1) |
|
|
|
|
|
def summ_gif(self, name, tensor, blacken_zeros=False): |
|
|
|
|
|
assert tensor.dtype in {torch.uint8,torch.float32} |
|
shape = list(tensor.shape) |
|
|
|
if tensor.dtype == torch.float32: |
|
tensor = back2color(tensor, blacken_zeros=blacken_zeros) |
|
|
|
video_to_write = tensor[0:1] |
|
|
|
S = video_to_write.shape[1] |
|
if S==1: |
|
|
|
self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step) |
|
else: |
|
self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step) |
|
|
|
return video_to_write |
|
|
|
def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1): |
|
B, C, H, W = list(rgb.shape) |
|
assert(C==3) |
|
B2, N, D = list(boxlist.shape) |
|
assert(B2==B) |
|
assert(D==4) |
|
|
|
rgb = back2color(rgb) |
|
        if scores is None:
            scores = torch.ones(B2, N).float()
        if tids is None:
            tids = torch.arange(N).reshape(1, N).repeat(B2, 1).long()
|
|
|
out = self.draw_boxlist2d_on_image_py( |
|
rgb[0].cpu().detach().numpy(), |
|
boxlist[0].cpu().detach().numpy(), |
|
scores[0].cpu().detach().numpy(), |
|
tids[0].cpu().detach().numpy(), |
|
linewidth=linewidth) |
|
out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1) |
|
out = torch.unsqueeze(out, dim=0) |
|
out = preprocess_color(out) |
|
out = torch.reshape(out, [1, C, H, W]) |
|
return out |
|
|
|
def draw_boxlist2d_on_image_py(self, rgb, boxlist, scores, tids, linewidth=1): |
|
|
|
|
|
|
|
|
|
|
|
|
|
rgb = np.transpose(rgb, [1, 2, 0]) |
|
|
|
|
|
rgb = rgb.astype(np.uint8).copy() |
|
|
|
|
|
H, W, C = rgb.shape |
|
assert(C==3) |
|
N, D = boxlist.shape |
|
assert(D==4) |
|
|
|
|
|
|
|
color_map = cm.get_cmap('Accent') |
|
color_map = color_map.colors |
|
|
|
|
|
|
|
for ind, box in enumerate(boxlist): |
|
|
|
if not np.isclose(scores[ind], 0.0): |
|
|
|
ymin, xmin, ymax, xmax = box |
|
|
|
|
|
|
|
|
|
|
|
|
|
color_id = tids[ind] |
|
color = color_map[color_id] |
|
color = np.array(color)*255.0 |
|
color = color.round() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
                if scores[ind] < 1.0:
                    cv2.putText(rgb,
                                '%.2f' % (scores[ind]),
                                (int(xmin), int(ymin)),
                                cv2.FONT_HERSHEY_SIMPLEX,
                                0.5,
                                color,
                                1)
|
|
|
|
|
xmin = np.clip(int(xmin), 0, W-1) |
|
xmax = np.clip(int(xmax), 0, W-1) |
|
ymin = np.clip(int(ymin), 0, H-1) |
|
ymax = np.clip(int(ymax), 0, H-1) |
|
|
|
cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_AA) |
|
cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_AA) |
|
cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_AA) |
|
cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_AA) |
|
|
|
|
|
return rgb |
|
|
|
def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, only_return=False, linewidth=2): |
|
B, C, H, W = list(rgb.shape) |
|
boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth) |
|
return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, only_return=only_return) |
|
|
|
def summ_rgbs(self, name, ims, frame_ids=None, blacken_zeros=False, only_return=False): |
|
if self.save_this: |
|
|
|
ims = gif_and_tile(ims, just_gif=self.just_gif) |
|
vis = ims |
|
|
|
assert vis.dtype in {torch.uint8,torch.float32} |
|
|
|
if vis.dtype == torch.float32: |
|
vis = back2color(vis, blacken_zeros) |
|
|
|
B, S, C, H, W = list(vis.shape) |
|
|
|
if frame_ids is not None: |
|
assert(len(frame_ids)==S) |
|
for s in range(S): |
|
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s]) |
|
|
|
if int(W) > self.maxwidth: |
|
vis = vis[:,:,:,:self.maxwidth] |
|
|
|
if only_return: |
|
return vis |
|
else: |
|
return self.summ_gif(name, vis, blacken_zeros) |
|
|
|
def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, only_return=False, halfres=False): |
|
if self.save_this: |
|
assert ims.dtype in {torch.uint8,torch.float32} |
|
|
|
if ims.dtype == torch.float32: |
|
ims = back2color(ims, blacken_zeros) |
|
|
|
|
|
vis = ims[0:1] |
|
B, C, H, W = list(vis.shape) |
|
|
|
if halfres: |
|
vis = F.interpolate(vis, scale_factor=0.5) |
|
|
|
if frame_id is not None: |
|
vis = draw_frame_id_on_vis(vis, frame_id) |
|
|
|
if int(W) > self.maxwidth: |
|
vis = vis[:,:,:,:self.maxwidth] |
|
|
|
if only_return: |
|
return vis |
|
else: |
|
return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros) |
|
|
|
def flow2color(self, flow, clip=50.0): |
|
""" |
|
:param flow: Optical flow tensor. |
|
:return: RGB image normalized between 0 and 1. |
|
""" |
|
|
|
|
|
|
|
B, C, H, W = list(flow.size()) |
|
|
|
flow = flow.clone().detach() |
|
|
|
abs_image = torch.abs(flow) |
|
flow_mean = abs_image.mean(dim=[1,2,3]) |
|
flow_std = abs_image.std(dim=[1,2,3]) |
|
|
|
if clip: |
|
flow = torch.clamp(flow, -clip, clip)/clip |
|
else: |
|
|
|
flow_max = flow_mean + flow_std*2 + 1e-10 |
|
for b in range(B): |
|
flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1) |
|
|
|
radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) |
|
radius_clipped = torch.clamp(radius, 0.0, 1.0) |
|
|
|
angle = torch.atan2(flow[:, 1:], flow[:, 0:1]) / np.pi |
|
|
|
hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0) |
|
saturation = torch.ones_like(hue) * 0.75 |
|
value = radius_clipped |
|
hsv = torch.cat([hue, saturation, value], dim=1) |
|
|
|
|
|
flow = hsv_to_rgb(hsv) |
|
flow = (flow*255.0).type(torch.ByteTensor) |
|
return flow |
|
|
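    # Example (sketch): visualizing a synthetic flow field through this writer.
    # "summ" below is an assumed Summ_writer instance; flow is B x 2 x H x W in
    # pixels, and clip=0.0 lets flow2color pick a per-batch scale from the flow
    # statistics instead of a fixed clamp.
    #
    #   flow = torch.randn(1, 2, 64, 64) * 10.0
    #   summ.summ_flow('inputs/flow', flow, clip=0.0)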
|
def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None): |
|
|
|
if self.save_this: |
|
return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id) |
|
else: |
|
return None |
|
|
|
def summ_oneds(self, name, ims, frame_ids=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False): |
|
if self.save_this: |
|
if bev: |
|
B, C, H, _, W = list(ims[0].shape) |
|
if reduce_max: |
|
ims = [torch.max(im, dim=3)[0] for im in ims] |
|
else: |
|
ims = [torch.mean(im, dim=3) for im in ims] |
|
elif fro: |
|
B, C, _, H, W = list(ims[0].shape) |
|
if reduce_max: |
|
ims = [torch.max(im, dim=2)[0] for im in ims] |
|
else: |
|
ims = [torch.mean(im, dim=2) for im in ims] |
|
|
|
|
|
if len(ims) != 1: |
|
im = gif_and_tile(ims, just_gif=self.just_gif) |
|
else: |
|
im = torch.stack(ims, dim=1) |
|
|
|
B, S, C, H, W = list(im.shape) |
|
|
|
if logvis and max_val: |
|
max_val = np.log(max_val) |
|
im = torch.log(torch.clamp(im, 0)+1.0) |
|
im = torch.clamp(im, 0, max_val) |
|
im = im/max_val |
|
norm = False |
|
elif max_val: |
|
im = torch.clamp(im, 0, max_val) |
|
im = im/max_val |
|
norm = False |
|
|
|
if norm: |
|
|
|
|
|
im = utils.basic.normalize(im) |
|
|
|
im = im.view(B*S, C, H, W) |
|
vis = oned2inferno(im, norm=norm, do_colorize=do_colorize) |
|
vis = vis.view(B, S, 3, H, W) |
|
|
|
if frame_ids is not None: |
|
assert(len(frame_ids)==S) |
|
for s in range(S): |
|
vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s]) |
|
|
|
if W > self.maxwidth: |
|
vis = vis[...,:self.maxwidth] |
|
|
|
if only_return: |
|
return vis |
|
else: |
|
self.summ_gif(name, vis) |
|
|
|
def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, only_return=False): |
|
if self.save_this: |
|
|
|
if bev: |
|
B, C, H, _, W = list(im.shape) |
|
if max_along_y: |
|
im = torch.max(im, dim=3)[0] |
|
else: |
|
im = torch.mean(im, dim=3) |
|
elif fro: |
|
B, C, _, H, W = list(im.shape) |
|
if max_along_y: |
|
im = torch.max(im, dim=2)[0] |
|
else: |
|
im = torch.mean(im, dim=2) |
|
else: |
|
B, C, H, W = list(im.shape) |
|
|
|
im = im[0:1] |
|
assert(C==1) |
|
|
|
if logvis and max_val: |
|
max_val = np.log(max_val) |
|
im = torch.log(im) |
|
im = torch.clamp(im, 0, max_val) |
|
im = im/max_val |
|
norm = False |
|
elif max_val: |
|
im = torch.clamp(im, 0, max_val)/max_val |
|
norm = False |
|
|
|
vis = oned2inferno(im, norm=norm) |
|
if W > self.maxwidth: |
|
vis = vis[...,:self.maxwidth] |
|
return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, only_return=only_return) |
|
|
|
def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None): |
|
if self.save_this: |
|
if valids is not None: |
|
valids = torch.stack(valids, dim=1) |
|
|
|
feats = torch.stack(feats, dim=1) |
|
|
|
|
|
if feats.ndim==6: |
|
|
|
|
|
if fro: |
|
reduce_dim = 3 |
|
else: |
|
reduce_dim = 4 |
|
|
|
if valids is None: |
|
feats = torch.mean(feats, dim=reduce_dim) |
|
else: |
|
valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1) |
|
feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim) |
|
|
|
B, S, C, D, W = list(feats.size()) |
|
|
|
if not pca: |
|
|
|
                feats = torch.mean(torch.abs(feats), dim=2, keepdim=True)
|
|
|
feats = torch.unbind(feats, dim=1) |
|
return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids) |
|
|
|
else: |
|
__p = lambda x: utils.basic.pack_seqdim(x, B) |
|
__u = lambda x: utils.basic.unpack_seqdim(x, B) |
|
|
|
feats_ = __p(feats) |
|
|
|
if valids is None: |
|
feats_pca_ = get_feat_pca(feats_) |
|
                else:
                    valids_ = __p(valids)
                    feats_pca_ = get_feat_pca(feats_, valids_)
|
|
|
feats_pca = __u(feats_pca_) |
|
|
|
return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids) |
|
|
|
def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None): |
|
if self.save_this: |
|
if feat.ndim==5: |
|
|
|
if bev: |
|
reduce_axis = 3 |
|
elif fro: |
|
reduce_axis = 2 |
|
else: |
|
|
|
reduce_axis = 3 |
|
|
|
if valid is None: |
|
feat = torch.mean(feat, dim=reduce_axis) |
|
else: |
|
valid = valid.repeat(1, feat.size()[1], 1, 1, 1) |
|
feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis) |
|
|
|
B, C, D, W = list(feat.shape) |
|
|
|
if not pca: |
|
                feat = torch.mean(torch.abs(feat), dim=1, keepdim=True)
|
|
|
return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id) |
|
else: |
|
feat_pca = get_feat_pca(feat, valid) |
|
return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id) |
|
|
|
def summ_scalar(self, name, value): |
|
        if torch.is_tensor(value):
            value = value.detach().cpu().numpy()
|
if not np.isnan(value): |
|
if (self.log_freq == 1): |
|
self.writer.add_scalar(name, value, global_step=self.global_step) |
|
elif self.save_this or np.mod(self.global_step, self.scalar_freq)==0: |
|
self.writer.add_scalar(name, value, global_step=self.global_step) |
|
|
|
def summ_seg(self, name, seg, only_return=False, frame_id=None, colormap='tab20', label_colors=None): |
|
if not self.save_this: |
|
return |
|
|
|
B,H,W = seg.shape |
|
|
|
if label_colors is None: |
|
custom_label_colors = False |
|
|
|
label_colors = cm.get_cmap(colormap).colors |
|
label_colors = [[int(i*255) for i in l] for l in label_colors] |
|
else: |
|
custom_label_colors = True |
|
|
r = torch.zeros_like(seg,dtype=torch.uint8) |
|
g = torch.zeros_like(seg,dtype=torch.uint8) |
|
b = torch.zeros_like(seg,dtype=torch.uint8) |
|
|
|
for label in range(0,len(label_colors)): |
|
if (not custom_label_colors): |
|
label_ = label % 20 |
|
else: |
|
label_ = label |
|
|
|
idx = (seg == label+1) |
|
r[idx] = label_colors[label_][0] |
|
g[idx] = label_colors[label_][1] |
|
b[idx] = label_colors[label_][2] |
|
|
|
rgb = torch.stack([r,g,b],axis=1) |
|
return self.summ_rgb(name,rgb,only_return=only_return, frame_id=frame_id) |
|
|
|
def summ_pts_on_rgb(self, name, trajs, rgb, valids=None, frame_id=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1): |
|
|
|
|
|
B, C, H, W = rgb.shape |
|
B, S, N, D = trajs.shape |
|
|
|
rgb = rgb[0] |
|
trajs = trajs[0] |
|
if valids is None: |
|
valids = torch.ones_like(trajs[:,:,0]) |
|
else: |
|
valids = valids[0] |
|
|
|
|
|
|
|
rgb = back2color(rgb).detach().cpu().numpy() |
|
rgb = np.transpose(rgb, [1, 2, 0]) |
|
|
|
trajs = trajs.long().detach().cpu().numpy() |
|
valids = valids.long().detach().cpu().numpy() |
|
|
|
rgb = rgb.astype(np.uint8).copy() |
|
|
|
for i in range(N): |
|
if cmap=='onediff' and i==0: |
|
cmap_ = 'spring' |
|
elif cmap=='onediff': |
|
cmap_ = 'winter' |
|
else: |
|
cmap_ = cmap |
|
traj = trajs[:,i] |
|
valid = valids[:,i] |
|
|
|
            color_map = cm.get_cmap(cmap_)
            color = np.array(color_map(i)[:3]) * 255
|
for s in range(S): |
|
if valid[s]: |
|
cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1) |
|
rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0) |
|
rgb = preprocess_color(rgb) |
|
return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) |
|
|
|
def summ_pts_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1): |
|
|
|
|
|
B, S, C, H, W = rgbs.shape |
|
B, S2, N, D = trajs.shape |
|
assert(S==S2) |
|
|
|
rgbs = rgbs[0] |
|
trajs = trajs[0] |
|
if valids is None: |
|
valids = torch.ones_like(trajs[:,:,0]) |
|
else: |
|
valids = valids[0] |
|
|
|
|
|
|
|
rgbs_color = [] |
|
for rgb in rgbs: |
|
rgb = back2color(rgb).detach().cpu().numpy() |
|
rgb = np.transpose(rgb, [1, 2, 0]) |
|
rgbs_color.append(rgb) |
|
|
|
trajs = trajs.long().detach().cpu().numpy() |
|
valids = valids.long().detach().cpu().numpy() |
|
|
|
rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color] |
|
|
|
for i in range(N): |
|
traj = trajs[:,i] |
|
valid = valids[:,i] |
|
|
|
color_map = cm.get_cmap(cmap) |
|
color = np.array(color_map(0)[:3]) * 255 |
|
for s in range(S): |
|
if valid[s]: |
|
cv2.circle(rgbs_color[s], (traj[s,0], traj[s,1]), linewidth, color, -1) |
|
rgbs = [] |
|
for rgb in rgbs_color: |
|
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) |
|
rgbs.append(preprocess_color(rgb)) |
|
|
|
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) |
|
|
|
|
|
def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=False, cmap='coolwarm', vals=None, linewidth=1): |
|
|
|
|
|
B, S, C, H, W = rgbs.shape |
|
B, S2, N, D = trajs.shape |
|
assert(S==S2) |
|
|
|
rgbs = rgbs[0] |
|
trajs = trajs[0] |
|
if valids is None: |
|
valids = torch.ones_like(trajs[:,:,0]) |
|
else: |
|
valids = valids[0] |
|
|
|
|
|
|
|
|
|
if vals is not None: |
|
vals = vals[0] |
|
|
|
|
|
rgbs_color = [] |
|
for rgb in rgbs: |
|
rgb = back2color(rgb).detach().cpu().numpy() |
|
rgb = np.transpose(rgb, [1, 2, 0]) |
|
rgbs_color.append(rgb) |
|
|
|
for i in range(N): |
|
if cmap=='onediff' and i==0: |
|
cmap_ = 'spring' |
|
elif cmap=='onediff': |
|
cmap_ = 'winter' |
|
else: |
|
cmap_ = cmap |
|
traj = trajs[:,i].long().detach().cpu().numpy() |
|
valid = valids[:,i].long().detach().cpu().numpy() |
|
|
|
|
|
|
|
|
|
if vals is not None: |
|
|
|
val = vals[i].float().detach().cpu().numpy() |
|
|
|
else: |
|
val = None |
|
|
|
for t in range(S): |
|
if valid[t]: |
|
|
|
traj_seq = traj[max(t-8,0):t+1] |
|
val_seq = np.linspace(0,1,len(traj_seq)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) |
|
|
|
|
|
for i in range(N): |
|
if cmap=='onediff' and i==0: |
|
cmap_ = 'spring' |
|
elif cmap=='onediff': |
|
cmap_ = 'winter' |
|
else: |
|
cmap_ = cmap |
|
traj = trajs[:,i] |
|
|
|
vis = torch.ones_like(traj[:,0]) |
|
valid = valids[:,i] |
|
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=0, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) |
|
|
|
rgbs = [] |
|
for rgb in rgbs_color: |
|
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) |
|
rgbs.append(preprocess_color(rgb)) |
|
|
|
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) |
|
|
|
|
|
def summ_traj2ds_on_rgbs_py(self, name, trajs, rgbs_color, valids=None, frame_ids=None, only_return=False, show_dots=False, cmap='coolwarm', vals=None, linewidth=1): |
|
|
|
|
|
|
|
B, S, N, D = trajs.shape |
|
|
|
|
|
|
|
trajs = trajs[0] |
|
if valids is None: |
|
valids = torch.ones_like(trajs[:,:,0]) |
|
else: |
|
valids = valids[0] |
|
|
|
|
|
|
|
|
|
if vals is not None: |
|
vals = vals[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(N): |
|
if cmap=='onediff' and i==0: |
|
cmap_ = 'spring' |
|
elif cmap=='onediff': |
|
cmap_ = 'winter' |
|
else: |
|
cmap_ = cmap |
|
traj = trajs[:,i].long().detach().cpu().numpy() |
|
valid = valids[:,i].long().detach().cpu().numpy() |
|
|
|
|
|
|
|
|
|
if vals is not None: |
|
|
|
val = vals[i].float().detach().cpu().numpy() |
|
|
|
else: |
|
val = None |
|
|
|
for t in range(S): |
|
|
|
|
|
traj_seq = traj[max(t-8,0):t+1] |
|
val_seq = np.linspace(0,1,len(traj_seq)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) |
|
|
|
|
|
for i in range(N): |
|
if cmap=='onediff' and i==0: |
|
cmap_ = 'spring' |
|
elif cmap=='onediff': |
|
cmap_ = 'winter' |
|
else: |
|
cmap_ = cmap |
|
traj = trajs[:,i] |
|
|
|
vis = torch.ones_like(traj[:,0]) |
|
valid = valids[:,i] |
|
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=0, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) |
|
|
|
rgbs = [] |
|
for rgb in rgbs_color: |
|
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) |
|
rgbs.append(preprocess_color(rgb)) |
|
|
|
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) |
|
|
|
|
|
def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap=None, linewidth=1): |
|
|
|
|
|
B, S, C, H, W = rgbs.shape |
|
B, S2, N, D = trajs.shape |
|
assert(S==S2) |
|
|
|
rgbs = rgbs[0] |
|
trajs = trajs[0] |
|
visibles = visibles[0] |
|
if valids is None: |
|
valids = torch.ones_like(trajs[:,:,0]) |
|
else: |
|
valids = valids[0] |
|
|
|
|
|
|
|
rgbs_color = [] |
|
for rgb in rgbs: |
|
rgb = back2color(rgb).detach().cpu().numpy() |
|
rgb = np.transpose(rgb, [1, 2, 0]) |
|
rgbs_color.append(rgb) |
|
|
|
trajs = trajs.long().detach().cpu().numpy() |
|
visibles = visibles.float().detach().cpu().numpy() |
|
valids = valids.long().detach().cpu().numpy() |
|
|
|
for i in range(N): |
|
if cmap=='onediff' and i==0: |
|
cmap_ = 'spring' |
|
elif cmap=='onediff': |
|
cmap_ = 'winter' |
|
else: |
|
cmap_ = cmap |
|
traj = trajs[:,i] |
|
vis = visibles[:,i] |
|
valid = valids[:,i] |
|
rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) |
|
|
|
for i in range(N): |
|
if cmap=='onediff' and i==0: |
|
cmap_ = 'spring' |
|
elif cmap=='onediff': |
|
cmap_ = 'winter' |
|
else: |
|
cmap_ = cmap |
|
traj = trajs[:,i] |
|
vis = visibles[:,i] |
|
valid = valids[:,i] |
|
if valid[0]: |
|
rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None, linewidth=linewidth) |
|
|
|
rgbs = [] |
|
for rgb in rgbs_color: |
|
rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) |
|
rgbs.append(preprocess_color(rgb)) |
|
|
|
return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) |
|
|
|
def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=False, show_lines=True, frame_id=None, only_return=False, cmap='coolwarm', linewidth=1): |
|
|
|
|
|
B, C, H, W = rgb.shape |
|
B, S, N, D = trajs.shape |
|
|
|
rgb = rgb[0] |
|
trajs = trajs[0] |
|
|
|
if valids is None: |
|
valids = torch.ones_like(trajs[:,:,0]) |
|
else: |
|
valids = valids[0] |
|
|
|
rgb_color = back2color(rgb).detach().cpu().numpy() |
|
rgb_color = np.transpose(rgb_color, [1, 2, 0]) |
|
|
|
|
|
        # distance-based coloring is available via maxdist, but time-based
        # coloring is used here, so the quantile computation stays disabled
        # norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1))
        # maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy()
        maxdist = None
|
trajs = trajs.long().detach().cpu().numpy() |
|
valids = valids.long().detach().cpu().numpy() |
|
|
|
for i in range(N): |
|
if cmap=='onediff' and i==0: |
|
cmap_ = 'spring' |
|
elif cmap=='onediff': |
|
cmap_ = 'winter' |
|
else: |
|
cmap_ = cmap |
|
traj = trajs[:,i] |
|
valid = valids[:,i] |
|
if valid[0]==1: |
|
traj = traj[valid>0] |
|
rgb_color = self.draw_traj_on_image_py( |
|
rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth) |
|
|
|
rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0) |
|
rgb = preprocess_color(rgb_color) |
|
return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) |
|
|
|
def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None): |
|
|
|
|
|
|
|
|
|
H, W, C = rgb.shape |
|
assert(C==3) |
|
|
|
rgb = rgb.astype(np.uint8).copy() |
|
|
|
S1, D = traj.shape |
|
assert(D==2) |
|
|
|
color_map = cm.get_cmap(cmap) |
|
S1, D = traj.shape |
|
|
|
        for s in range(S1):
            if val is not None:
                color = np.array(color_map(val[s])[:3]) * 255
            else:
                if maxdist is not None:
                    # color by normalized distance from the first point
                    dist = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0, 1)
                    color = np.array(color_map(dist)[:3]) * 255
                else:
                    # color by time
                    color = np.array(color_map((s)/max(1, float(S-2)))[:3]) * 255
|
|
|
if show_lines and s<(S1-1): |
|
cv2.line(rgb, |
|
(int(traj[s,0]), int(traj[s,1])), |
|
(int(traj[s+1,0]), int(traj[s+1,1])), |
|
color, |
|
linewidth, |
|
cv2.LINE_AA) |
|
if show_dots: |
|
cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, np.array(color_map(1)[:3])*255, -1) |
|
|
return rgb |
|
|
|
|
|
|
|
def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None): |
|
|
|
|
|
|
|
H, W, C = rgbs[0].shape |
|
assert(C==3) |
|
|
|
rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] |
|
|
|
S1, D = traj.shape |
|
assert(D==2) |
|
|
|
x = int(np.clip(traj[0,0], 0, W-1)) |
|
y = int(np.clip(traj[0,1], 0, H-1)) |
|
color = rgbs[0][y,x] |
|
color = (int(color[0]),int(color[1]),int(color[2])) |
|
        for s in range(S):
            # draw the trajectory up to (and including) frame s
            cv2.polylines(rgbs[s],
                          [traj[:s+1].astype(np.int32)],
                          False,
                          color,
                          linewidth,
                          cv2.LINE_AA)
        return rgbs
|
|
|
def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None): |
|
|
|
|
|
|
|
H, W, C = rgb.shape |
|
assert(C==3) |
|
|
|
rgb = rgb.astype(np.uint8).copy() |
|
|
|
N, D = xy.shape |
|
assert(D==2) |
|
|
|
|
|
xy = xy.astype(np.float32) |
|
xy[:,0] = np.clip(xy[:,0], 0, W-1) |
|
xy[:,1] = np.clip(xy[:,1], 0, H-1) |
|
xy = xy.astype(np.int32) |
|
|
|
|
|
|
|
if colors is None: |
|
colors = get_n_colors(N) |
|
|
|
for n in range(N): |
|
color = colors[n] |
|
|
|
|
|
color = (int(color[0]),int(color[1]),int(color[2])) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cv2.circle(rgb, (xy[n,0], xy[n,1]), linewidth, color, 3) |
|
|
|
|
|
|
|
return rgb |
|
|
|
def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None): |
|
|
|
|
|
|
|
H, W, C = rgbs[0].shape |
|
assert(C==3) |
|
|
|
rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] |
|
|
|
S1, D = traj.shape |
|
assert(D==2) |
|
|
|
if cmap is None: |
|
bremm = ColorMap2d() |
|
traj_ = traj[0:1].astype(np.float32) |
|
traj_[:,0] /= float(W) |
|
traj_[:,1] /= float(H) |
|
color = bremm(traj_) |
|
|
|
color = (color[0]*255).astype(np.uint8) |
|
|
|
color = (int(color[2]),int(color[1]),int(color[0])) |
|
|
|
for s in range(S1): |
|
if cmap is not None: |
|
color_map = cm.get_cmap(cmap) |
|
|
|
color = np.array(color_map((s+1)/max(1,float(S-1)))[:3]) * 255 |
|
|
|
|
|
|
|
|
|
|
|
cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, color, -1) |
|
|
|
|
|
|
|
|
|
return rgbs |
|
|
|
def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, only_return=False, show_circ=False, trajs_g=None, is_g=False): |
|
B, S, N, D = trajs_e.shape |
|
assert(N==1) |
|
assert(D==2) |
|
|
|
rgbs_vis = [] |
|
n = 0 |
|
pad_amount = 100 |
|
trajs_e_py = trajs_e[0].detach().cpu().numpy() |
|
|
|
trajs_e_py = trajs_e_py + pad_amount |
|
|
|
if trajs_g is not None: |
|
trajs_g_py = trajs_g[0].detach().cpu().numpy() |
|
trajs_g_py = trajs_g_py + pad_amount |
|
|
|
for s in range(S): |
|
rgb = rgbs[0,s].detach().cpu().numpy() |
|
|
|
rgb = np.transpose(rgb,(1,2,0)) |
|
|
|
rgb = np.pad(rgb, ((pad_amount,pad_amount),(pad_amount,pad_amount),(0,0))) |
|
|
|
H, W, C = rgb.shape |
|
|
|
if trajs_g is not None: |
|
xy_g = trajs_g_py[s,n] |
|
xy_g[0] = np.clip(xy_g[0], pad_amount, W-pad_amount) |
|
xy_g[1] = np.clip(xy_g[1], pad_amount, H-pad_amount) |
|
rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3) |
|
|
|
xy_e = trajs_e_py[s,n] |
|
xy_e[0] = np.clip(xy_e[0], pad_amount, W-pad_amount) |
|
xy_e[1] = np.clip(xy_e[1], pad_amount, H-pad_amount) |
|
|
|
if show_circ: |
|
if is_g: |
|
rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3) |
|
else: |
|
rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,0,255)], linewidth=2, radius=3) |
|
|
|
|
|
xmin = int(xy_e[0])-pad_amount//2 |
|
xmax = int(xy_e[0])+pad_amount//2 |
|
ymin = int(xy_e[1])-pad_amount//2 |
|
ymax = int(xy_e[1])+pad_amount//2 |
|
|
|
rgb_ = rgb[ymin:ymax, xmin:xmax] |
|
|
|
H_, W_ = rgb_.shape[:2] |
|
|
|
|
|
            assert H_ > 0 and W_ > 0, 'summ_traj_as_crops: got an empty crop'
|
|
|
rgb_ = rgb_.transpose(2,0,1) |
|
rgb_ = torch.from_numpy(rgb_) |
|
|
|
rgbs_vis.append(rgb_) |
|
|
|
|
|
nrow = int(np.sqrt(S)*1.5) |
|
grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0) |
|
|
|
return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, only_return=only_return) |
|
|
|
def summ_occ(self, name, occ, reduce_axes=[3], bev=False, fro=False, pro=False, frame_id=None, only_return=False): |
|
if self.save_this: |
|
B, C, D, H, W = list(occ.shape) |
|
if bev: |
|
reduce_axes = [3] |
|
elif fro: |
|
reduce_axes = [2] |
|
elif pro: |
|
reduce_axes = [4] |
|
for reduce_axis in reduce_axes: |
|
height = convert_occ_to_height(occ, reduce_axis=reduce_axis) |
|
if reduce_axis == reduce_axes[-1]: |
|
return self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return) |
|
else: |
|
self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return) |
|
|
|
def erode2d(im, times=1, device='cuda'):
    # morphological erosion of a binary mask via convolution with a 3x3 kernel
    weights2d = torch.ones(1, 1, 3, 3, device=device)
    for time in range(times):
        im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1)
    return im

def dilate2d(im, times=1, device='cuda', mode='square'):
    # morphological dilation of a binary mask; mode='cross' zeroes the kernel corners
    weights2d = torch.ones(1, 1, 3, 3, device=device)
    if mode=='cross':
        weights2d[:,:,0,0] = 0.0
        weights2d[:,:,0,2] = 0.0
        weights2d[:,:,2,0] = 0.0
        weights2d[:,:,2,2] = 0.0
    for time in range(times):
        im = F.conv2d(im, weights2d, padding=1).clamp(0, 1)
    return im
|
|
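if __name__ == "__main__":
    # Minimal smoke test (sketch): log one RGB frame and one depth-like map.
    # This is an illustrative assumption, not part of the library's API: it
    # assumes tensorboard is installed, the log directory name is arbitrary,
    # and it relies on the same utils.basic helpers used throughout this file.
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('logs_improc_demo')
    summ = Summ_writer(writer, global_step=0, log_freq=1)

    # a random color image in the [-0.5, 0.5] range expected by summ_rgb
    rgb = preprocess_color(torch.randint(0, 256, (1, 3, 64, 64)))
    summ.summ_rgb('demo/rgb', rgb)

    # a random single-channel map, normalized and colorized by summ_oned
    depth = torch.rand(1, 1, 64, 64)
    summ.summ_oned('demo/depth', depth, norm=True)

    writer.close()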
|
|
|
|