|
import os |
|
import numpy as np |
|
try: |
|
import cynetworkx as netx |
|
except ImportError: |
|
import networkx as netx |
|
import matplotlib.pyplot as plt |
|
from functools import partial |
|
from vispy import scene, io |
|
from vispy.scene import visuals |
|
from vispy.visuals.filters import Alpha |
|
import cv2 |
|
from moviepy.editor import ImageSequenceClip |
|
from skimage.transform import resize |
|
import time |
|
import copy |
|
import torch |
|
import os |
|
from utils import path_planning, open_small_mask, clean_far_edge, refine_depth_around_edge |
|
from utils import refine_color_around_edge, filter_irrelevant_edge_new, require_depth_edge, clean_far_edge_new |
|
from utils import create_placeholder, refresh_node, find_largest_rect |
|
from mesh_tools import get_depth_from_maps, get_map_from_ccs, get_edge_from_nodes, get_depth_from_nodes, get_rgb_from_nodes, crop_maps_by_size, convert2tensor, recursive_add_edge, update_info, filter_edge, relabel_node, depth_inpainting |
|
from mesh_tools import refresh_bord_depth, enlarge_border, fill_dummy_bord, extrapolate, fill_missing_node, incomplete_node, get_valid_size, dilate_valid_size, size_operation |
|
import transforms3d |
|
import random |
|
from functools import reduce |
|
|
|
def create_mesh(depth, image, int_mtx, config):
    """Build the initial layered-depth-image (LDI) graph for one view.

    Every input pixel ``(row, col)`` becomes the node
    ``(row + pad, col + pad, -depth[row, col])`` where
    ``pad = config['extrapolation_thickness']``. Each node carries ``color``,
    ``disp`` (``1 / -depth``), ``synthesis=False`` and an empty ``cc_id`` set,
    and is linked to its lower and right neighbours (a 4-connected grid).

    Args:
        depth: 2-D array indexed ``[row, col]``, same spatial size as ``image``.
        image: array of shape ``(H, W, C)``.
        int_mtx: 3x3 intrinsic matrix; rows 0 and 1 are multiplied by ``W`` and
            ``H`` to obtain the pixel-space matrix (presumably normalized
            intrinsics -- confirm with callers).
        config: mapping providing ``'extrapolation_thickness'``.

    Returns:
        ``(LDI, xy2depth, image, depth)``: the graph (FOV fields filled by
        ``calculate_fov``), a dict mapping padded ``(row, col)`` to the
        one-element list ``[-depth]``, and ``image``/``depth`` zero-padded by
        ``pad`` on every spatial side.
    """
    pad = config['extrapolation_thickness']
    H, W, C = image.shape
    LDI = netx.Graph(H=H + 2 * pad, W=W + 2 * pad, noext_H=H, noext_W=W, cam_param=int_mtx)
    graph = LDI.graph

    pix_mtx = int_mtx * np.array([[W], [H], [1.]])
    graph['cam_param_pix'], graph['cam_param_pix_inv'] = pix_mtx, np.linalg.inv(pix_mtx)
    disp = 1. / (-depth)
    graph['hoffset'] = graph['woffset'] = pad
    graph['bord_up'], graph['bord_down'] = pad, pad + H
    graph['bord_left'], graph['bord_right'] = pad, pad + W

    xy2depth = {}
    for row in range(H):
        for col in range(W):
            pixel = (row + pad, col + pad)
            node_depth = -depth[row, col]
            LDI.add_node(pixel + (node_depth,),
                         color=image[row, col],
                         disp=disp[row, col],
                         synthesis=False,
                         cc_id=set())
            xy2depth[pixel] = [node_depth]

    # Link every node to its lower and right neighbour when that neighbour is
    # still inside the original (non-extrapolated) image area.
    bottom, right = graph['bord_down'], graph['bord_right']
    for x, y, d in LDI.nodes:
        for ne_x, ne_y in ((x + 1, y), (x, y + 1)):
            if ne_x < bottom and ne_y < right:
                LDI.add_edge((ne_x, ne_y, xy2depth[(ne_x, ne_y)][0]), (x, y, d))

    LDI = calculate_fov(LDI)

    spatial_pad = ((pad, pad), (pad, pad))
    image = np.pad(image, pad_width=spatial_pad + ((0, 0),), mode='constant')
    depth = np.pad(depth, pad_width=spatial_pad, mode='constant')
    return LDI, xy2depth, image, depth
|
|
|
|
|
def tear_edges(mesh, threshold = 0.00025, xy2depth=None):
    """Cut edges that span a depth discontinuity and record the occlusion links.

    An edge is torn when the ``disp`` values of its endpoints differ by more
    than ``threshold``. For each torn edge, the endpoint with the smaller
    ``|depth|`` (``near``) gets the other endpoint appended to its ``'far'``
    list, and the other endpoint (``far``) gets ``near`` appended to its
    ``'near'`` list.

    A second pass removes an intact edge when the parallel edges directly on
    both sides of it were torn. For horizontal edges ``(x, y)-(x, y+1)`` those
    are the edges in rows ``x-1`` and ``x+1``. For vertical edges
    ``(x, y)-(x+1, y)`` they are the edges in columns ``y-1`` and ``y+1``.

    Args:
        mesh: LDI graph with ``(row, col, depth)`` nodes carrying ``disp``;
            ``mesh.graph`` must provide ``H``, ``W`` and the ``bord_*`` limits.
            Modified in place.
        threshold: disparity jump above which an edge is torn.
        xy2depth: mapping ``(row, col) -> [depth, ...]`` used to find the node
            at a pixel during the second pass. It is required whenever that
            pass finds candidates.

    Returns:
        The same ``mesh``.
    """
    H, W = mesh.graph['H'], mesh.graph['W']
    remove_horizon, remove_vertical = np.zeros((2, H, W))
    mesh_nodes = mesh.nodes

    remove_edge_list = []
    for edge in mesh.edges:
        # Written as ``not >`` so NaN disparities are skipped exactly as before.
        if not abs(mesh_nodes[edge[0]]['disp'] - mesh_nodes[edge[1]]['disp']) > threshold:
            continue
        remove_edge_list.append((edge[0], edge[1]))
        near, far = edge if abs(edge[0][2]) < abs(edge[1][2]) else edge[::-1]

        # ``list.append`` returns None, so create the list first and then append.
        # The former one-line ternary stored None and never kept any partner.
        if mesh_nodes[far].get('near') is None:
            mesh_nodes[far]['near'] = []
        mesh_nodes[far]['near'].append(near)
        if mesh_nodes[near].get('far') is None:
            mesh_nodes[near]['far'] = []
        mesh_nodes[near]['far'].append(far)

        # Remember where the torn edge was. A horizontal edge is indexed by its
        # row and left column; a vertical edge by its upper row and its column.
        if near[0] == far[0]:
            remove_horizon[near[0], min(near[1], far[1])] = 1
        elif near[1] == far[1]:
            remove_vertical[min(near[0], far[0]), near[1]] = 1
    mesh.remove_edges_from(remove_edge_list)

    bord = mesh.graph

    def prjto3d(x, y):
        """Return the node stored in ``xy2depth`` for pixel ``(x, y)``."""
        return (x, y, xy2depth[(x, y)][0])

    def node_existence(x, y):
        """Return whether the node for pixel ``(x, y)`` is still in ``mesh``."""
        return mesh.has_node(prjto3d(x, y))

    # A cell equals 2 here when both neighbours across are torn and it is not.
    # np.roll wraps around, so the border checks below discard wrapped hits.
    dang_horizon = np.where(np.roll(remove_horizon, 1, 0) + np.roll(remove_horizon, -1, 0) - remove_horizon == 2)
    dang_vertical = np.where(np.roll(remove_vertical, 1, 1) + np.roll(remove_vertical, -1, 1) - remove_vertical == 2)

    remove_edge_list = []
    for x, y in zip(dang_horizon[0], dang_horizon[1]):
        if bord['bord_up'] + 1 <= x < bord['bord_down'] - 1 and node_existence(x, y) and node_existence(x, y + 1):
            remove_edge_list.append((prjto3d(x, y), prjto3d(x, y + 1)))
    for x, y in zip(dang_vertical[0], dang_vertical[1]):
        if bord['bord_left'] + 1 <= y < bord['bord_right'] - 1 and node_existence(x, y) and node_existence(x + 1, y):
            remove_edge_list.append((prjto3d(x, y), prjto3d(x + 1, y)))
    mesh.remove_edges_from(remove_edge_list)

    return mesh
|
|
|
def calculate_fov(mesh):
    """Fill ``hFov``, ``vFov`` and ``aspect`` into ``mesh.graph``.

    The fields of view are ``2 * arctan(1 / (2 * f))`` using ``f = K[0, 0]``
    (horizontal) and ``f = K[1, 1]`` (vertical) of ``mesh.graph['cam_param']``.
    This formula presumes focal lengths normalized by image width/height --
    confirm with the caller. ``aspect`` is ``noext_H / noext_W``.

    Returns:
        The same ``mesh``.
    """
    intrinsics = mesh.graph['cam_param']
    focal_x, focal_y = intrinsics[0, 0], intrinsics[1, 1]
    mesh.graph.update(
        hFov=2 * np.arctan(1. / (2 * focal_x)),
        vFov=2 * np.arctan(1. / (2 * focal_y)),
        aspect=mesh.graph['noext_H'] / mesh.graph['noext_W'],
    )
    return mesh
|
|
|
def calculate_fov_FB(mesh):
    """Set a fixed FOV of 0.508015513 rad on the short image side of ``mesh``.

    ``aspect`` is ``H / W``. For portrait images (``H > W``) ``hFov`` gets the
    fixed value and ``vFov`` is derived from it. Otherwise ``vFov`` gets the
    fixed value and ``hFov`` is derived. The derivation scales the half-angle
    tangent by the aspect ratio.

    Returns:
        The same ``mesh``.
    """
    graph = mesh.graph
    graph['aspect'] = graph['H'] / graph['W']
    short_side_fov = 0.508015513
    if graph['H'] > graph['W']:
        graph['hFov'] = short_side_fov
        graph['vFov'] = 2.0 * np.arctan(np.tan(graph['hFov'] / 2.0) * graph['aspect'])
    else:
        graph['vFov'] = short_side_fov
        graph['hFov'] = 2.0 * np.arctan(np.tan(graph['vFov'] / 2.0) / graph['aspect'])
    return mesh
|
|
|
def reproject_3d_int_detail(sx, sy, z, k_00, k_02, k_11, k_12, w_offset, h_offset):
    """Back-project pixel ``(sx, sy)`` to 3-D using explicit inverse-intrinsic terms.

    The pixel centre (``+0.5``) minus the offsets is mapped with
    ``u = col * k_00 + k_02`` and ``v = row * k_11 + k_12``, then scaled by
    ``|z|``.

    Returns:
        ``[|z| * u, |z| * v, |z|]`` as a list.
    """
    scale = abs(z)
    u = (sy + 0.5 - w_offset) * k_00 + k_02
    v = (sx + 0.5 - h_offset) * k_11 + k_12
    return [scale * u, scale * v, scale]
|
|
|
def reproject_3d_int_detail_FB(sx, sy, z, w_offset, h_offset, mesh):
    """Back-project pixel ``(sx, sy)`` through a symmetric FOV camera.

    The tangents of the half FOVs are cached in ``mesh.graph['tan_hFov']`` and
    ``mesh.graph['tan_vFov']`` the first time they are needed. The pixel
    centre is mapped to ``[-1, 1]`` across ``W - 1`` / ``H - 1``, with +y up.
    The resulting ray ``[x * tan_h, y * tan_v, -1]`` is scaled by ``|z|``.

    Returns:
        A length-3 NumPy array.
    """
    graph = mesh.graph
    if graph.get('tan_hFov') is None:
        graph['tan_hFov'] = np.tan(graph['hFov'] / 2.)
    if graph.get('tan_vFov') is None:
        graph['tan_vFov'] = np.tan(graph['vFov'] / 2.)

    ndc_x = -1. + 2. * ((sy + 0.5 - w_offset) / (graph['W'] - 1))
    ndc_y = 1. - 2. * (sx + 0.5 - h_offset) / (graph['H'] - 1)
    direction = np.array([ndc_x * graph['tan_hFov'], ndc_y * graph['tan_vFov'], -1])
    return direction * np.abs(z)
|
|
|
|
|
def reproject_3d_int(sx, sy, z, mesh): |
|
k = mesh.graph['cam_param_pix_inv'].copy() |
|
if k[0, 2] > 0: |
|
k = np.linalg.inv(k) |
|
ray = np.dot(k, np.array([sy-mesh.graph['woffset'], sx-mesh.graph['hoffset'], 1]).reshape(3, 1)) |
|
|
|
point_3d = ray * np.abs(z) |
|
point_3d = point_3d.flatten() |
|
|
|
return point_3d |
|
|
|
def generate_init_node(mesh, config, min_node_in_cc):
    """Drop small connected components and index the surviving nodes by pixel.

    Components with fewer than ``min_node_in_cc`` nodes are removed from
    ``mesh``. Before removal, each removed node is deleted from the ``'near'``
    list of its ``'far'`` partners and from the ``'far'`` list of its
    ``'near'`` partners, where those partners still exist.

    Args:
        mesh: graph whose nodes are ``(row, col, depth)`` tuples carrying
            ``color`` and ``disp``. Modified in place.
        config: not used.
        min_node_in_cc: minimum component size to keep.

    Returns:
        ``(mesh, info_on_pix)``. ``info_on_pix[(row, col)]`` is the one-element
        list ``[{'depth', 'color', 'synthesis': False, 'disp'}]``. Components
        are visited largest first, so a later (smaller) kept component
        overwrites the entry for a pixel it shares.
    """
    node_attrs = mesh.nodes
    info_on_pix = {}
    doomed = []

    for component in sorted(netx.connected_components(mesh), key=len, reverse=True):
        if len(component) < min_node_in_cc:
            doomed.extend((px, py, pd) for px, py, pd in component)
            continue
        for px, py, pd in component:
            attrs = node_attrs[(px, py, pd)]
            info_on_pix[(px, py)] = [{'depth': pd,
                                      'color': attrs['color'],
                                      'synthesis': False,
                                      'disp': attrs['disp']}]

    # Scrub back-references first ('far' partners, then 'near' partners), and
    # only then delete the nodes.
    for node in doomed:
        for own_key, back_key in (('far', 'near'), ('near', 'far')):
            partners = node_attrs[node].get(own_key)
            for partner in ([] if partners is None else partners):
                if not mesh.has_node(partner):
                    continue
                back_refs = node_attrs[partner].get(back_key)
                if back_refs is not None and node in back_refs:
                    back_refs.remove(node)

    for node in doomed:
        mesh.remove_node(node)

    return mesh, info_on_pix
|
|
|
def get_neighbors(mesh, node):
    """Return the neighbours of ``node`` in ``mesh`` as a new list."""
    return list(mesh.neighbors(node))
|
|
|
def generate_face(mesh, info_on_pix, config):
    """Triangulate ``mesh`` into faces made of node ``cur_id`` values.

    Each node's neighbours are sorted into up/right/down/left buckets.
    Neighbours in the same row are left/right, by column; the others are
    up/down, by row. One triangle ``(b, self, a)`` is emitted for every pair
    taken from the adjacent buckets (up, right), (right, down), (down, left)
    and (left, up). This gives every triangle the same winding.

    Args:
        mesh: graph whose nodes are ``(row, col, ...)`` tuples carrying
            ``cur_id``.
        info_on_pix: not used.
        config: mapping. When ``config.get('save_ply') is True``, each face is
            the PLY line ``'3 b self a\\n'``; the ids must then be strings.
            Otherwise each face is the list ``[b, self, a]``.

    Returns:
        List of faces in node-iteration order.
    """
    ply_flag = config.get('save_ply')
    mesh_nodes = mesh.nodes
    faces = []

    def emit(id_b, id_self, id_a):
        """Append one triangle in the requested output format."""
        if ply_flag is True:
            faces.append(' '.join(['3', id_b, id_self, id_a]) + '\n')
        else:
            faces.append([id_b, id_self, id_a])

    corner_pairs = (('up', 'right'), ('right', 'down'), ('down', 'left'), ('left', 'up'))
    for node in mesh_nodes:
        self_id = mesh_nodes[node]['cur_id']
        by_dir = {'up': [], 'left': [], 'down': [], 'right': []}
        for ne_node in mesh.neighbors(node):
            entry = (ne_node, mesh_nodes[ne_node]['cur_id'])
            # Compare against the *centre* node. The previous
            # ``ne_node[1] == ne_node[1] - 1`` was always False, which filed
            # every left neighbour as right and every up neighbour as down.
            if ne_node[0] == node[0]:
                direction = 'left' if ne_node[1] == node[1] - 1 else 'right'
            else:
                direction = 'up' if ne_node[0] == node[0] - 1 else 'down'
            by_dir[direction].append(entry)
        for dir_a, dir_b in corner_pairs:
            for _, id_a in by_dir[dir_a]:
                for _, id_b in by_dir[dir_b]:
                    emit(id_b, self_id, id_a)

    return faces
|
|
|
def reassign_floating_island(mesh, info_on_pix, image, depth):
    """Give depth to in-border pixels that lost every node ("floating islands").

    (A) Mark in-border pixels that have no entry in ``info_on_pix``.
    (B) Label those pixels with 4-connected components.
    (C) For each island:
        (1) Collect cross-neighbour nodes that carry an ``edge_id`` key (even
            a None value counts). Keep the group whose edge id occurs most
            often; on a tie, the first one seen wins.
        (2) Seed a depth map with that group's nodes. Then sweep repeatedly,
            giving each island pixel the mean depth of its already-seeded,
            in-border cross neighbours.
        (3) For each filled pixel, add a node to ``mesh``, register it in
            ``info_on_pix``, write ``-depth`` into ``depth`` and link the node
            to the neighbours it used.

    Args:
        mesh: LDI graph; ``mesh.graph`` provides ``H``, ``W`` and ``bord_*``.
            Modified in place.
        info_on_pix: ``{(row, col): [layer_dict, ...]}``. Modified in place.
        image: colour array indexed ``[row, col]``.
        depth: 2-D array indexed ``[row, col]``. Modified in place.

    Returns:
        ``(mesh, info_on_pix, depth)``.
    """
    H, W = mesh.graph['H'], mesh.graph['W']
    mesh_nodes = mesh.nodes
    bord_up, bord_down = mesh.graph['bord_up'], mesh.graph['bord_down']
    bord_left, bord_right = mesh.graph['bord_left'], mesh.graph['bord_right']

    def is_inside(x, y):
        """Return whether pixel (x, y) lies inside the non-extrapolated border."""
        return bord_up <= x < bord_down and bord_left <= y < bord_right

    def get_cross_nes(x, y):
        """Return the four cross neighbours of pixel (x, y)."""
        return [(x + 1, y), (x - 1, y), (x, y - 1), (x, y + 1)]

    # (A) In-border pixels without any node.
    lost_map = np.zeros((H, W))
    for x in range(H):
        for y in range(W):
            if is_inside(x, y) and (x, y) not in info_on_pix:
                lost_map[x, y] = 1

    # (B) Label islands.
    _, label_lost_map = cv2.connectedComponents(lost_map.astype(np.uint8), connectivity=4)
    mask = np.zeros((H, W))
    mask[bord_up:bord_down, bord_left:bord_right] = 1
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the replacement.
    label_lost_map = (label_lost_map * mask).astype(int)

    # (C) Fill each island from its dominant surrounding depth edge.
    for i in range(1, label_lost_map.max() + 1):
        lost_xs, lost_ys = np.where(label_lost_map == i)
        surr_edge_ids = {}
        for lost_x, lost_y in zip(lost_xs, lost_ys):
            for ne in get_cross_nes(lost_x, lost_y):
                if ne not in info_on_pix:
                    continue
                for info in info_on_pix[ne]:
                    ne_node = (ne[0], ne[1], info['depth'])
                    if 'edge_id' in mesh_nodes[ne_node]:
                        surr_edge_ids.setdefault(mesh_nodes[ne_node]['edge_id'], []).append(ne_node)
        if not surr_edge_ids:
            continue
        _, edge_nodes = max(surr_edge_ids.items(), key=lambda item: len(item[1]))
        edge_depth_map = np.zeros((H, W))
        for node in edge_nodes:
            edge_depth_map[node[0], node[1]] = node[2]

        while True:
            lost_xs, lost_ys = np.where(label_lost_map == i)
            if lost_xs.shape[0] == 0:
                break
            progressed = False
            for lost_x, lost_y in zip(lost_xs, lost_ys):
                real_nes = [ne for ne in get_cross_nes(lost_x, lost_y)
                            if is_inside(ne[0], ne[1]) and edge_depth_map[ne[0], ne[1]] != 0]
                if not real_nes:
                    continue
                reassign_depth = np.mean([edge_depth_map[ne[0], ne[1]] for ne in real_nes])
                label_lost_map[lost_x, lost_y] = 0
                edge_depth_map[lost_x, lost_y] = reassign_depth
                depth[lost_x, lost_y] = -reassign_depth
                mesh.add_node((lost_x, lost_y, reassign_depth), color=image[lost_x, lost_y],
                              synthesis=False,
                              disp=1. / reassign_depth,
                              cc_id=set())
                info_on_pix[(lost_x, lost_y)] = [{'depth': reassign_depth,
                                                  'color': image[lost_x, lost_y],
                                                  'synthesis': False,
                                                  'disp': 1. / reassign_depth}]
                mesh.add_edges_from([((lost_x, lost_y, reassign_depth),
                                      (ne[0], ne[1], edge_depth_map[ne[0], ne[1]])) for ne in real_nes])
                progressed = True
            if not progressed:
                # The rest of the island touches no seeded pixel (for example,
                # every seed lies outside the border). Without this stop the
                # loop would repeat forever.
                break

    return mesh, info_on_pix, depth
|
|
|
def remove_node_feat(mesh, *feats):
    """Set every attribute named in ``feats`` to ``None`` on all nodes of ``mesh``.

    Nodes that lack one of the attributes get it, with the value ``None``.

    Returns:
        The same ``mesh``.
    """
    cleared = dict.fromkeys(feats)
    node_view = mesh.nodes
    for node in node_view:
        node_view[node].update(cleared)
    return mesh
|
|
|
def update_status(mesh, info_on_pix, depth=None):
    """Rebuild per-node ``near``/``far`` occlusion links from pixel adjacency.

    First, ``edge_id``, ``far`` and ``near`` are reset to ``None`` on every
    node that has them. Then each node with fewer than four mesh neighbours is
    examined. For every in-border cross-neighbour pixel it is *not* linked to,
    each layer of that pixel is recorded on the node: in ``near`` when the
    layer's ``|depth|`` is smaller than the node's, otherwise (ties included)
    in ``far``.

    If ``depth`` is given, it is synchronized with the first layer of every
    ``info_on_pix`` entry. Where the map disagrees, it receives ``|depth|`` and
    that layer's ``disp`` is refreshed to ``1 / depth``.

    Args:
        mesh: LDI graph with ``(row, col, depth)`` nodes; ``mesh.graph``
            provides the ``bord_*`` limits. Modified in place.
        info_on_pix: ``{(row, col): [layer_dict, ...]}`` with a ``'depth'`` key
            per layer.
        depth: optional 2-D array, updated in place.

    Returns:
        ``mesh`` when ``depth`` is None, otherwise ``(mesh, depth, info_on_pix)``.

    Raises:
        AssertionError: if ``info_on_pix`` lists a layer with no node in ``mesh``.
    """
    node_view = mesh.nodes
    for node_key in node_view:
        attrs = node_view[node_key]
        for ft in ('edge_id', 'far', 'near'):
            if ft in attrs:
                attrs[ft] = None

    bord_up, bord_down = mesh.graph['bord_up'], mesh.graph['bord_down']
    bord_left, bord_right = mesh.graph['bord_left'], mesh.graph['bord_right']

    for node_key in node_view:
        # ``mesh.degree`` is the public API for the neighbour count. It replaces
        # ``mesh.neighbors(node).__length_hint__()``.
        if mesh.degree(node_key) == 4:
            continue
        x, y = node_key[0], node_key[1]
        linked = {(ne[0], ne[1]) for ne in mesh.neighbors(node_key)}
        # Filtering tolerates links to pixels outside the border or missing
        # from info_on_pix; the former per-neighbour list.remove raised there.
        four_nes = [xx for xx in [(x + 1, y), (x - 1, y), (x, y - 1), (x, y + 1)]
                    if bord_up <= xx[0] < bord_down and bord_left <= xx[1] < bord_right
                    and xx in info_on_pix and xx not in linked]
        ind_node = node_view[node_key]
        for ne in four_nes:
            for info in info_on_pix[ne]:
                opposite = (ne[0], ne[1], info['depth'])
                assert mesh.has_node(opposite), "No node_key"
                role = 'near' if abs(node_key[2]) > abs(info['depth']) else 'far'
                if ind_node.get(role) is None:
                    ind_node[role] = []
                ind_node[role].append(opposite)

    if depth is None:
        return mesh

    for key, value in info_on_pix.items():
        top = value[0]
        if depth[key[0], key[1]] != abs(top['depth']):
            top['disp'] = 1. / top['depth']
            depth[key[0], key[1]] = abs(top['depth'])

    return mesh, depth, info_on_pix
|
|
|
def group_edges(LDI, config, image, remove_conflict_ordinal, spdb=False):

    '''
    Group the discontinuity pixels of ``LDI`` into depth edges.

    Returns ``(discont_ccs, LDI, discont_graph)``. ``discont_ccs`` is a list of
    node sets, one per depth edge. ``LDI`` is the input graph with
    ``'edge_id'`` written on every node of every depth edge. ``discont_graph``
    is the graph whose connected components form those edges. ``image`` is not
    read here. ``config['depth_threshold']`` bounds the diagonal-linking test.
    When ``remove_conflict_ordinal`` is true, edges that mix foreground and
    background pixels are split, and ``'near'``/``'far'`` entries may be popped
    from ``LDI`` nodes. ``spdb == True`` stops in pdb twice.

    Local helpers:
    (1) add_new_node(G, node) : add "node" to graph "G" unless already present
    (2) add_new_edge(G, node_a, node_b) : add edge "node_a--node_b" to graph "G" unless already present
    (3) exceed_thre(x, y, thre) : Check whether |x| - |y| exceeds threshold "thre"
    (4) key_exist(d, k) : Check if key "k" maps to a non-None value in dictionary "d"
    (5) comm_opp_bg(G, x, y) : Check if node "x" and "y" in graph "G" share an entry in their 'far' lists
    (6) comm_opp_fg(G, x, y) : Check if node "x" and "y" in graph "G" share an entry in their 'near' lists
    '''
    add_new_node = lambda G, node: None if G.has_node(node) else G.add_node(node)
    add_new_edge = lambda G, node_a, node_b: None if G.has_edge(node_a, node_b) else G.add_edge(node_a, node_b)
    # Signed difference of magnitudes, not an absolute difference.
    exceed_thre = lambda x, y, thre: (abs(x) - abs(y)) > thre
    key_exist = lambda d, k: d.get(k) is not None
    comm_opp_bg = lambda G, x, y: key_exist(G.nodes[x], 'far') and key_exist(G.nodes[y], 'far') and \
                                  not(set(G.nodes[x]['far']).isdisjoint(set(G.nodes[y]['far'])))
    comm_opp_fg = lambda G, x, y: key_exist(G.nodes[x], 'near') and key_exist(G.nodes[y], 'near') and \
                                  not(set(G.nodes[x]['near']).isdisjoint(set(G.nodes[y]['near'])))
    discont_graph = netx.Graph()
    '''
    (A) Skip the pixel at image boundary, we don't want to deal with them.
    (B) Identify discontinuity by the number of its neighbor(degree).
        If the degree < 4(up/right/buttom/left). We will go through following steps:
        (1) Add the discontinuity pixel "node" to graph "discont_graph".
        (2) Find "node"'s cross neighbor(up/right/buttom/left) "ne_node".
            - If the cross neighbor "ne_node" is a discontinuity pixel(degree("ne_node") < 4),
                (a) add it to graph "discont_graph" and build the connection between "ne_node" and "node".
                (b) label its cross neighbor as invalid pixels "inval_diag_candi" to avoid building
                    connection between original discontinuity pixel "node" and "inval_diag_candi".
            - Otherwise, find "ne_node"'s cross neighbors, called diagonal candidate "diag_candi".
                - The "diag_candi" is diagonal to the original discontinuity pixel "node".
                - If "diag_candi" exists, go to step(3).
        (3) A diagonal candidate "diag_candi" will be :
            - added to the "discont_graph" if its degree < 4.
            - connected to the original discontinuity pixel "node" if it satisfied either
                one of following criterion:
                (a) the difference of disparity between "diag_candi" and "node" is smaller than default threshold.
                (b) the "diag_candi" and "node" face the same opposite pixel. (See. function "tear_edges")
                (c) Both of "diag_candi" and "node" must_connect to each other. (See. function "combine_end_node")
    (C) Aggregate each connected part in "discont_graph" into "discont_ccs" (A.K.A. depth edge).
    '''
    for node in LDI.nodes:
        # Only pixels at least one step inside the bord_* limits are considered.
        if not(LDI.graph['bord_up'] + 1 <= node[0] <= LDI.graph['bord_down'] - 2 and \
               LDI.graph['bord_left'] + 1 <= node[1] <= LDI.graph['bord_right'] - 2):
            continue
        neighbors = [*LDI.neighbors(node)]
        if len(neighbors) < 4:
            add_new_node(discont_graph, node)
            diag_candi_anc, inval_diag_candi, discont_nes = set(), set(), set()
            for ne_node in neighbors:
                if len([*LDI.neighbors(ne_node)]) < 4:
                    add_new_node(discont_graph, ne_node)
                    add_new_edge(discont_graph, ne_node, node)
                    discont_nes.add(ne_node)
                else:
                    diag_candi_anc.add(ne_node)
            # Neighbours of discontinuous neighbours that fall in node's 3x3 window.
            inval_diag_candi = set([inval_diagonal for ne_node in discont_nes for inval_diagonal in LDI.neighbors(ne_node) if \
                                    abs(inval_diagonal[0] - node[0]) < 2 and abs(inval_diagonal[1] - node[1]) < 2])
            for ne_node in diag_candi_anc:
                # Pixels perpendicular to the node->ne_node step, i.e. diagonal to node.
                # NOTE(review): if ne_node shares neither row nor column with node,
                # diagonal_xys keeps its previous value (or is unbound); this relies
                # on LDI edges being 4-neighbour links.
                if ne_node[0] == node[0]:
                    diagonal_xys = [[ne_node[0] + 1, ne_node[1]], [ne_node[0] - 1, ne_node[1]]]
                elif ne_node[1] == node[1]:
                    diagonal_xys = [[ne_node[0], ne_node[1] + 1], [ne_node[0], ne_node[1] - 1]]
                for diag_candi in LDI.neighbors(ne_node):
                    if [diag_candi[0], diag_candi[1]] in diagonal_xys and LDI.degree(diag_candi) < 4:
                        if diag_candi not in inval_diag_candi:
                            # Criteria (a) and (b) of step (3) above.
                            if not exceed_thre(1./node[2], 1./diag_candi[2], config['depth_threshold']) or \
                               (comm_opp_bg(LDI, diag_candi, node) and comm_opp_fg(LDI, diag_candi, node)):
                                add_new_node(discont_graph, diag_candi)
                                add_new_edge(discont_graph, diag_candi, node)
                        # Criterion (c): mutual must_connect links are honoured even
                        # for candidates listed in inval_diag_candi.
                        if key_exist(LDI.nodes[diag_candi], 'must_connect') and node in LDI.nodes[diag_candi]['must_connect'] and \
                            key_exist(LDI.nodes[node], 'must_connect') and diag_candi in LDI.nodes[node]['must_connect']:
                            add_new_node(discont_graph, diag_candi)
                            add_new_edge(discont_graph, diag_candi, node)
    if spdb == True:
        import pdb; pdb.set_trace()
    discont_ccs = [*netx.connected_components(discont_graph)]
    '''
    In some corner case, a depth edge "discont_cc" will contain both
    foreground(FG) and background(BG) pixels. This violate the assumption that
    a depth edge can only composite by one type of pixel(FG or BG).
    We need to further divide this depth edge into several sub-part so that the
    assumption is satisfied.
    (A) A depth edge is invalid if both of its "far_flag"(BG) and
        "near_flag"(FG) are True.
    (B) If the depth edge is invalid, we need to do:
        (1) Find the role("oridinal") of each pixel on the depth edge.
            "-1" --> Its opposite pixels has smaller depth(near) than it.
                     It is a backgorund pixel.
            "+1" --> Its opposite pixels has larger depth(far) than it.
                     It is a foregorund pixel.
            "0"  --> Some of opposite pixels has larger depth(far) than it,
                     and some has smaller pixel than it.
                     It is an ambiguous pixel.
        (2) For each pixel "discont_node", check if its neigbhors' roles are consistent.
            - If not, break the connection between the neighbor "ne_node" that has a role
              different from "discont_node".
            - If yes, remove all the role that are inconsistent to its neighbors "ne_node".
        (3) Connected component analysis to re-identified those divided depth edge.
    (C) Aggregate each connected part in "discont_graph" into "discont_ccs" (A.K.A. depth edge).
    '''
    # NOTE(review): below, ordinal -1 comes from a non-None 'far' list and is
    # treated as the near side (ordinal <= 0). The description above calls -1
    # "background"; confirm which labelling is intended.
    if remove_conflict_ordinal:
        new_discont_ccs = []
        num_new_cc = 0  # unused
        for edge_id, discont_cc in enumerate(discont_ccs):
            near_flag = False
            far_flag = False
            for discont_node in discont_cc:
                # A node listing 'far' partners sits on the near side, and vice versa.
                near_flag = True if key_exist(LDI.nodes[discont_node], 'far') else near_flag
                far_flag = True if key_exist(LDI.nodes[discont_node], 'near') else far_flag
                if far_flag and near_flag:
                    break
            if far_flag and near_flag:
                # ordinal = (+1 if 'near' present) + (-1 if 'far' present), stored on discont_graph.
                for discont_node in discont_cc:
                    discont_graph.nodes[discont_node]['ordinal'] = \
                        np.array([key_exist(LDI.nodes[discont_node], 'far'),
                                  key_exist(LDI.nodes[discont_node], 'near')]) * \
                        np.array([-1, 1])
                    discont_graph.nodes[discont_node]['ordinal'] = \
                        np.sum(discont_graph.nodes[discont_node]['ordinal'])
                remove_nodes, remove_edges = [], []
                for discont_node in discont_cc:
                    ordinal_relation = np.sum([discont_graph.nodes[xx]['ordinal'] \
                                               for xx in discont_graph.neighbors(discont_node)])
                    near_side = discont_graph.nodes[discont_node]['ordinal'] <= 0
                    # Neighbours disagree (their ordinals do not all share one sign):
                    # cut links to neighbours lacking the role this node expects.
                    if abs(ordinal_relation) < len([*discont_graph.neighbors(discont_node)]):
                        remove_nodes.append(discont_node)
                        for ne_node in discont_graph.neighbors(discont_node):
                            remove_flag = (near_side and not(key_exist(LDI.nodes[ne_node], 'far'))) or \
                                          (not near_side and not(key_exist(LDI.nodes[ne_node], 'near')))
                            remove_edges += [(discont_node, ne_node)] if remove_flag else []
                    else:
                        # Neighbours agree: drop this node's role that contradicts them.
                        if near_side and key_exist(LDI.nodes[discont_node], 'near'):
                            LDI.nodes[discont_node].pop('near')
                        elif not(near_side) and key_exist(LDI.nodes[discont_node], 'far'):
                            LDI.nodes[discont_node].pop('far')
                discont_graph.remove_edges_from(remove_edges)
                sub_mesh = discont_graph.subgraph(list(discont_cc)).copy()
                sub_discont_ccs = [*netx.connected_components(sub_mesh)]
                # A singleton left over from a cut node that still lists 'far' partners.
                is_redun_near = lambda xx: len(xx) == 1 and xx[0] in remove_nodes and key_exist(LDI.nodes[xx[0]], 'far')
                for sub_discont_cc in sub_discont_ccs:
                    if is_redun_near(list(sub_discont_cc)):
                        LDI.nodes[list(sub_discont_cc)[0]].pop('far')
                    new_discont_ccs.append(sub_discont_cc)
            else:
                new_discont_ccs.append(discont_cc)
        discont_ccs = new_discont_ccs
        new_discont_ccs = None
    if spdb == True:
        import pdb; pdb.set_trace()

    # Edge ids are the positions in the returned discont_ccs list.
    for edge_id, edge_cc in enumerate(discont_ccs):
        for node in edge_cc:
            LDI.nodes[node]['edge_id'] = edge_id

    return discont_ccs, LDI, discont_graph
|
|
|
def combine_end_node(mesh, edge_mesh, edge_ccs, depth):
    """Join end points of neighbouring depth edges and mark them ``must_connect``.

    An end point is an ``edge_ccs`` node with exactly one neighbour in
    ``edge_mesh``. Three passes over a map of end points decide which are
    stitched:

    1. Drop an end point that is missing from ``mesh``, or that ``mesh``
       already links to every end point among its four cross neighbours.
    2. Drop an end point when a neighbouring end point's edge id was reached,
       through the ``near``/``far`` partners of ``mesh`` neighbours, from two
       or more nodes of its own edge.
    3. Link each remaining end point to every remaining cross-neighbour end
       point that exists in ``mesh``. Both nodes get the other added to their
       ``must_connect`` set, plus the other's ``edge_mesh`` neighbours that
       lie within one pixel of them.

    Args:
        mesh: LDI graph with ``(row, col, depth)`` nodes; ``mesh.graph``
            provides ``H`` and ``W``. Modified in place (edges and
            ``must_connect``).
        edge_mesh: graph over the discontinuity pixels (presumably the graph
            returned by ``group_edges``).
        edge_ccs: iterable of node collections, one per depth edge.
        depth: not used; kept for interface compatibility.

    Returns:
        The same ``mesh``.
    """
    import collections

    mesh_nodes = mesh.nodes
    H, W = mesh.graph['H'], mesh.graph['W']

    # For every edge, count how many of its nodes touch each other edge id
    # through the occlusion partners of their mesh neighbours.
    connect_dict = {}
    for edge_id, edge_cc in enumerate(edge_ccs):
        connect_info = []
        for edge_node in edge_cc:
            touched = set()
            for ne_node in mesh.neighbors(edge_node):
                for role in ('far', 'near'):
                    partners = mesh_nodes[ne_node].get(role)
                    if partners is None:
                        continue
                    for partner in partners:
                        if mesh.has_node(partner) and mesh_nodes[partner].get('edge_id') is not None:
                            touched.add(mesh_nodes[partner]['edge_id'])
            connect_info.extend(touched)
        connect_dict[edge_id] = collections.Counter(connect_info)

    # end_maps[row, col] holds the depth of the end point there (0 = none).
    end_maps = np.zeros((H, W))
    for edge_cc in edge_ccs:
        for edge_node in edge_cc:
            if len([*edge_mesh.neighbors(edge_node)]) == 1:
                end_maps[edge_node[0], edge_node[1]] = edge_node[2]

    def node_at(x, y):
        """Return the end-point node stored at pixel (x, y)."""
        return (x, y, end_maps[x, y])

    def end_cross_nes(x, y):
        """Return cross neighbours of (x, y) that currently hold an end point."""
        return [xx for xx in [(x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)]
                if 0 <= xx[0] < H and 0 <= xx[1] < W and end_maps[xx[0], xx[1]] != 0]

    # Pass 1: end points must exist and have at least one unlinked end-point neighbour.
    invalid_nodes = set()
    for px, py in zip(*np.where(end_maps != 0)):
        if not mesh.has_node(node_at(px, py)):
            invalid_nodes.add((px, py))
            continue
        four_nes = end_cross_nes(px, py)
        mesh_nes = [*mesh.neighbors(node_at(px, py))]
        linked = sum(1 for fx, fy in four_nes if node_at(fx, fy) in mesh_nes)
        if linked == len(four_nes):
            invalid_nodes.add((px, py))
    for ix, iy in invalid_nodes:
        end_maps[ix, iy] = 0

    # Pass 2: skip end points whose edges already touch the neighbour's edge repeatedly.
    invalid_nodes = set()
    for px, py in zip(*np.where(end_maps != 0)):
        self_id = mesh_nodes[node_at(px, py)].get('edge_id')
        if self_id is None:
            continue
        self_connect = connect_dict.get(self_id, {})
        for fx, fy in end_cross_nes(px, py):
            ne_id = mesh_nodes[node_at(fx, fy)].get('edge_id')
            if ne_id is None:
                continue
            count = self_connect.get(ne_id)
            if count is not None and count != 1:
                invalid_nodes.add((px, py))
    for ix, iy in invalid_nodes:
        end_maps[ix, iy] = 0

    def mark_must_connect(target, partner):
        """Record that ``target`` must connect to ``partner`` and its nearby edge neighbours."""
        attrs = mesh_nodes[target]
        if attrs.get('must_connect') is None:
            attrs['must_connect'] = set()
        attrs['must_connect'].add(partner)
        # abs() keeps this to target's 3x3 window, matching group_edges' check.
        # Without it, any neighbour above/left of target was admitted.
        attrs['must_connect'] |= set([xx for xx in [*edge_mesh.neighbors(partner)]
                                      if abs(xx[0] - target[0]) < 2 and abs(xx[1] - target[1]) < 2])

    # Pass 3: stitch the surviving end points to their end-point neighbours.
    for px, py in zip(*np.where(end_maps != 0)):
        node_b = node_at(px, py)
        for fx, fy in end_cross_nes(px, py):
            node_a = node_at(fx, fy)
            if not mesh.has_node(node_a):
                continue
            mesh.add_edge(node_a, node_b)
            mark_must_connect(node_b, node_a)
            mark_must_connect(node_a, node_b)

    return mesh
|
|
|
def remove_redundant_edge(mesh, edge_mesh, edge_ccs, info_on_pix, config, redundant_number=1000, invalid=False, spdb=False):
    """Re-stitch short depth edges that are not backed by a neighbouring edge.

    End points (nodes with one ``edge_mesh`` neighbour) of the original
    ``edge_ccs`` are mapped, and for each edge id the set of edge ids whose end
    points touch its end points is recorded. The edges are then filtered by
    ``filter_edge``. For every filtered edge with at most ``redundant_number``
    nodes (the node scan stops at the first node with three ``edge_mesh``
    neighbours):

    * an isolated node (no ``edge_mesh`` neighbour) is reconnected to its
      unlinked cross neighbours;
    * end points with no foreign filtered edge nearby are counted. When the
      count passes the threshold for the current ``invalid`` mode, every node
      of the edge is reconnected to unlinked cross neighbours that are on no
      filtered edge or on this same edge.

    A link is allowed when ``invalid`` is True, when the neighbour belongs to
    no (or a short) original edge, or when the two edges' end points touch.

    Args:
        mesh: LDI graph; ``mesh.graph`` provides ``H`` and ``W``. Modified in
            place.
        edge_mesh: graph over the discontinuity pixels.
        edge_ccs: original depth edges (node collections).
        info_on_pix: ``{(row, col): [layer_dict, ...]}``; the first layer's
            ``'depth'`` identifies the neighbour node at a pixel.
        config: forwarded to ``filter_edge``.
        redundant_number: edges longer than this are left alone.
        invalid: selects the stricter end-point test when True.
        spdb: when True, show the edge canvas and stop in pdb.

    Returns:
        The same ``mesh``.
    """
    H, W = mesh.graph['H'], mesh.graph['W']
    point_to_amount = {}
    point_to_id = {}
    end_maps = np.zeros((H, W)) - 1
    for edge_id, edge_cc in enumerate(edge_ccs):
        for edge_node in edge_cc:
            point_to_amount[edge_node] = len(edge_cc)
            point_to_id[edge_node] = edge_id
            if edge_mesh.has_node(edge_node) is True and len([*edge_mesh.neighbors(edge_node)]) == 1:
                end_maps[edge_node[0], edge_node[1]] = edge_id

    point_to_adjoint = {}
    for px, py in zip(*np.where(end_maps > -1)):
        # Bounds-checked: index px + 1 used to raise on the last row/column,
        # and px - 1 silently wrapped to the opposite side of the map.
        adjoint_edges = set([end_maps[x, y] for x, y in [(px + 1, py), (px - 1, py), (px, py + 1), (px, py - 1)]
                             if 0 <= x < H and 0 <= y < W and end_maps[x, y] != -1])
        own_id = end_maps[px, py]
        point_to_adjoint[own_id] = point_to_adjoint.get(own_id, set()) | adjoint_edges

    valid_edge_ccs = filter_edge(mesh, edge_ccs, config, invalid=invalid)
    edge_canvas = np.zeros((H, W)) - 1
    for edge_id, edge_cc in enumerate(valid_edge_ccs):
        for edge_node in edge_cc:
            edge_canvas[edge_node[0], edge_node[1]] = edge_id
    if spdb is True:
        plt.imshow(edge_canvas); plt.show()
        import pdb; pdb.set_trace()

    def may_link(ne, node):
        """Return whether ``node`` may be joined to its neighbour ``ne``."""
        if invalid is True:
            return True
        if point_to_amount.get(ne) is None or point_to_amount[ne] < redundant_number:
            return True
        # .get replaces a bare ``except`` that dropped into pdb on a KeyError
        # when ``node`` was not an original edge point.
        return point_to_id[ne] in point_to_adjoint.get(point_to_id.get(node), set())

    def unlinked_cross_nes(node, own_edge_id=None):
        """Return cross-neighbour nodes (top layer) of ``node`` not yet linked to it.

        When ``own_edge_id`` is given, keep only pixels that are on no filtered
        edge or on that edge.
        """
        hx, hy = node[0], node[1]
        result = []
        for x, y in [(hx + 1, hy), (hx - 1, hy), (hx, hy + 1), (hx, hy - 1)]:
            layers = info_on_pix.get((x, y))
            if layers is None:
                continue
            ne = (x, y, layers[0]['depth'])
            if mesh.has_edge(node, ne):
                continue
            if own_edge_id is not None and not (edge_canvas[x, y] == -1 or edge_canvas[x, y] == own_edge_id):
                continue
            result.append(ne)
        return result

    def foreign_edge_pixels(coords, own_edge_id):
        """Return pixels in ``coords`` that have nodes and lie on a filtered edge other than ``own_edge_id``."""
        return [(x, y) for x, y in coords
                if info_on_pix.get((x, y)) is not None and edge_canvas[x, y] != -1 and edge_canvas[x, y] != own_edge_id]

    for edge_id, edge_cc in enumerate(valid_edge_ccs):
        if len(edge_cc) > redundant_number:
            continue
        end_number = 0
        eight_end_number = 0
        db_eight_end_number = 0
        for edge_node in edge_cc:
            num_ne = len([*edge_mesh.neighbors(edge_node)])
            if num_ne == 3:
                break
            hx, hy = edge_node[0], edge_node[1]
            if num_ne == 1:
                four = [(hx + 1, hy), (hx - 1, hy), (hx, hy + 1), (hx, hy - 1)]
                eight = four + [(hx + 1, hy + 1), (hx - 1, hy - 1), (hx - 1, hy + 1), (hx + 1, hy - 1)]
                if invalid is False:
                    if len(foreign_edge_pixels(eight, edge_id)) == 0:
                        end_number += 1
                if invalid is True:
                    four_nes = foreign_edge_pixels(four, edge_id)
                    eight_nes = foreign_edge_pixels(eight, edge_id)
                    db_eight = [(x, y) for x in range(hx - 2, hx + 3) for y in range(hy - 2, hy + 3) if (x, y) != (hx, hy)]
                    db_eight_nes = foreign_edge_pixels(db_eight, edge_id)
                    if len(four_nes) == 0 or len(eight_nes) == 0:
                        end_number += 1
                    if len(eight_nes) == 0:
                        eight_end_number += 1
                    if len(db_eight_nes) == 0:
                        db_eight_end_number += 1
            elif num_ne == 0:
                for ne in unlinked_cross_nes(edge_node):
                    if may_link(ne, edge_node):
                        mesh.add_edge(edge_node, ne)
        if (invalid is not True and end_number >= 1) or \
                (invalid is True and end_number >= 2 and eight_end_number >= 1 and db_eight_end_number >= 1):
            for edge_node in edge_cc:
                for ne in unlinked_cross_nes(edge_node, edge_id):
                    if may_link(ne, edge_node):
                        mesh.add_edge(edge_node, ne)

    return mesh
|
|
|
def judge_dangle(mark, mesh, node):
    """Flag ``node`` in ``mark`` when it looks like a dangling mesh vertex.

    Nodes within one pixel of the mesh border (``mesh.graph['H']`` /
    ``mesh.graph['W']``) are left alone, and neighbours lying on the outermost
    rows/columns are not counted.  With ``k`` remaining neighbours:

    * ``k == 0`` -> ``mark[row, col] = 1``
    * ``k == 1`` -> ``mark[row, col] = 2``
    * ``k == 2`` and the two neighbours differ by more than one pixel in row
      or column -> ``mark[row, col] = 3``
    * otherwise ``mark`` is not touched.

    Args:
        mark: 2-D array indexed as ``mark[row, col]``; modified in place.
        mesh: graph exposing ``graph['H']``, ``graph['W']`` and ``neighbors``.
        node: mesh node whose first two items are ``(row, col)``.

    Returns:
        The same ``mark`` object.
    """
    height, width = mesh.graph['H'], mesh.graph['W']
    row, col = node[0], node[1]

    inside = 1 <= row < height - 1 and 1 <= col < width - 1
    if not inside:
        return mark

    interior = [nb for nb in mesh.neighbors(node)
                if 0 < nb[0] < height - 1 and 0 < nb[1] < width - 1]
    count = len(interior)

    if count <= 1:
        mark[row, col] = count + 1
    elif count == 2:
        first, second = interior
        # The two neighbours do not touch each other -> the node bridges a gap.
        if max(abs(first[0] - second[0]), abs(first[1] - second[1])) > 1:
            mark[row, col] = 3
    return mark
|
|
|
# Pixel offsets of the 4- and 8-neighbourhoods, in the order used throughout remove_dangling.
_FOUR_OFFSETS = ((1, 0), (-1, 0), (0, 1), (0, -1))
_EIGHT_OFFSETS = _FOUR_OFFSETS + ((1, 1), (-1, -1), (-1, 1), (1, -1))


def _pixel_nodes(info_on_pix, hx, hy, offsets):
    """List the ``(x, y, depth)`` nodes of the pixels ``(hx, hy) + offset``.

    Pixels missing from ``info_on_pix`` are skipped.  The depth is taken from
    the first entry stored for the pixel.  The result follows ``offsets`` order.
    """
    nodes = []
    for dx, dy in offsets:
        x, y = hx + dx, hy + dy
        entries = info_on_pix.get((x, y))
        if entries is not None:
            nodes.append((x, y, entries[0]['depth']))
    return nodes


def remove_dangling(mesh, edge_ccs, edge_mesh, info_on_pix, image, depth, config):
    """Reconnect / re-depth dangling depth-edge vertices of ``mesh``.

    Passes:

    1. Edge components consisting of a single node are either connected to
       all their 4-neighbours, or, when a connected group of at least two
       4-neighbours exists, moved to that group's mean depth and attached to it.
    2. Every edge node inside the ``bord_*`` window of ``mesh.graph`` is scored
       in ``mark``: +1 with no mesh neighbour, +2 with one, +3 with two
       neighbours more than one pixel apart.
    3. Nodes scored 1 are linked to their 4-neighbours; nodes scored 2 are
       re-attached to the best neighbouring component.  Both get the mean
       depth of the neighbours they were linked to.
    4. Chains of nodes scored 3 with degree 2 are re-attached at their ends;
       isolated ones are linked to their 4-neighbours in the same component.

    Args:
        mesh: graph whose nodes are ``(row, col, depth)`` tuples; ``mesh.graph``
            must hold ``'H'``, ``'W'``, ``'bord_up'``, ``'bord_down'``,
            ``'bord_left'`` and ``'bord_right'``.
        edge_ccs: list of sets of edge nodes; single-node entries may be
            replaced in place by the re-depthed node.
        edge_mesh: passed through ``update_info`` / ``recursive_add_edge``.
        info_on_pix: dict ``(row, col) -> list``; the first entry's ``'depth'``
            (and, in pass 1, ``'disp'``) are read and updated.
        image: unused.
        depth: 2-D array; receives ``abs`` of every recomputed node depth.
        config: unused.

    Returns:
        ``(mesh, info_on_pix, edge_mesh, depth, mark)``.
    """
    # Pass 1: single-node edge components.
    tmp_edge_ccs = copy.deepcopy(edge_ccs)
    for edge_cc_id, valid_edge_cc in enumerate(tmp_edge_ccs):
        if len(valid_edge_cc) != 1:
            continue
        single_edge_node = next(iter(valid_edge_cc))
        hx, hy, hz = single_edge_node
        eight_nes = set(_pixel_nodes(info_on_pix, hx, hy, _EIGHT_OFFSETS))
        four_nes = _pixel_nodes(info_on_pix, hx, hy, _FOUR_OFFSETS)
        sub_mesh = mesh.subgraph(eight_nes).copy()
        # Restrict each connected group of the 8-neighbourhood to its 4-neighbour members.
        four_ccs = [{cc_node for cc_node in _cc if abs(cc_node[0] - hx) + abs(cc_node[1] - hy) < 2}
                    for _cc in netx.connected_components(sub_mesh)]
        # Biggest group wins; ties go to the group whose depths are closest to hz.
        largest_cc = sorted(four_ccs, key=lambda x: (len(x), -np.sum([abs(xx[2] - hz) for xx in x])))[-1]
        if len(largest_cc) < 2:
            for ne in four_nes:
                mesh.add_edge(single_edge_node, ne)
        else:
            mesh.remove_edges_from([(single_edge_node, ne) for ne in mesh.neighbors(single_edge_node)])
            new_depth = np.mean([xx[2] for xx in largest_cc])
            info_on_pix[(hx, hy)][0]['depth'] = new_depth
            info_on_pix[(hx, hy)][0]['disp'] = 1. / new_depth
            new_node = (hx, hy, new_depth)
            # `mesh.nodes[n]` works on networkx >= 2.0; the `mesh.node` alias was removed in 2.4.
            mesh = refresh_node(single_edge_node, mesh.nodes[single_edge_node], new_node, dict(), mesh)
            edge_ccs[edge_cc_id] = set([new_node])
            for ne in largest_cc:
                mesh.add_edge(new_node, ne)

    # Pass 2: score edge nodes inside the border window.
    mark = np.zeros((mesh.graph['H'], mesh.graph['W']))
    for edge_cc in edge_ccs:
        for edge_node in edge_cc:
            if not (mesh.graph['bord_up'] <= edge_node[0] < mesh.graph['bord_down'] - 1) or \
               not (mesh.graph['bord_left'] <= edge_node[1] < mesh.graph['bord_right'] - 1):
                continue
            # NOTE: unlike judge_dangle, every mesh neighbour is counted here,
            # including neighbours outside the border window.
            mesh_neighbors = [*mesh.neighbors(edge_node)]
            if len(mesh_neighbors) >= 3:
                continue
            elif len(mesh_neighbors) <= 1:
                mark[edge_node[0], edge_node[1]] += (len(mesh_neighbors) + 1)
            else:
                dan_ne_node_a, dan_ne_node_b = mesh_neighbors[0], mesh_neighbors[1]
                if abs(dan_ne_node_a[0] - dan_ne_node_b[0]) > 1 or \
                   abs(dan_ne_node_a[1] - dan_ne_node_b[1]) > 1:
                    mark[edge_node[0], edge_node[1]] += 3

    # Pass 3: nodes with zero (mark 1) or one (mark 2) neighbour.
    mxs, mys = np.where(mark == 1)
    conn_0_nodes = [(x, y, info_on_pix[(x, y)][0]['depth']) for x, y in zip(mxs, mys)
                    if mesh.has_node((x, y, info_on_pix[(x, y)][0]['depth']))]
    mxs, mys = np.where(mark == 2)
    conn_1_nodes = [(x, y, info_on_pix[(x, y)][0]['depth']) for x, y in zip(mxs, mys)
                    if mesh.has_node((x, y, info_on_pix[(x, y)][0]['depth']))]
    for node in conn_0_nodes:
        four_nes = _pixel_nodes(info_on_pix, node[0], node[1], _FOUR_OFFSETS)
        depth_sum, depth_cnt = 0, 0
        for ne in four_nes:
            mesh.add_edge(node, ne)
            # Average the depths of the 4-neighbours just connected.
            depth_sum += ne[2]
            depth_cnt += 1
        re_depth = depth_sum / depth_cnt if depth_cnt > 0 else node[2]
        mapping_dict = {node: (node[0], node[1], re_depth)}
        info_on_pix, mesh, edge_mesh = update_info(mapping_dict, info_on_pix, mesh, edge_mesh)
        depth[node[0], node[1]] = abs(re_depth)
        mark[node[0], node[1]] = 0
    for node in conn_1_nodes:
        eight_nes = set(_pixel_nodes(info_on_pix, node[0], node[1], _EIGHT_OFFSETS))
        # Exclude pixels already reachable in two hops through the current neighbours.
        self_nes = set([ne2 for ne1 in mesh.neighbors(node) for ne2 in mesh.neighbors(ne1) if ne2 in eight_nes])
        eight_nes = [*(eight_nes - self_nes)]
        sub_mesh = mesh.subgraph(eight_nes).copy()
        ccs = netx.connected_components(sub_mesh)
        # Biggest component wins; ties go to the one spatially closest to the node.
        largest_cc = sorted(ccs, key=lambda x: (len(x), -np.sum([abs(xx[0] - node[0]) + abs(xx[1] - node[1]) for xx in x])))[-1]
        mesh.remove_edges_from([(xx, node) for xx in mesh.neighbors(node)])
        depth_sum, depth_cnt = 0, 0
        for cc_node in largest_cc:
            if cc_node[0] == node[0] and cc_node[1] == node[1]:
                continue
            depth_sum += cc_node[2]
            depth_cnt += 1
            if abs(cc_node[0] - node[0]) + abs(cc_node[1] - node[1]) < 2:
                mesh.add_edge(cc_node, node)
        re_depth = depth_sum / depth_cnt if depth_cnt > 0 else node[2]
        renode = (node[0], node[1], re_depth)
        mapping_dict = {node: renode}
        info_on_pix, mesh, edge_mesh = update_info(mapping_dict, info_on_pix, mesh, edge_mesh)
        depth[node[0], node[1]] = abs(re_depth)
        mark[node[0], node[1]] = 0
        edge_mesh, mesh, mark, info_on_pix = recursive_add_edge(edge_mesh, mesh, info_on_pix, renode, mark)

    # Pass 4: degree-2 nodes whose two neighbours do not touch (mark 3).
    mxs, mys = np.where(mark == 3)
    conn_2_nodes = [(x, y, info_on_pix[(x, y)][0]['depth']) for x, y in zip(mxs, mys)
                    if mesh.has_node((x, y, info_on_pix[(x, y)][0]['depth'])) and
                    mesh.degree((x, y, info_on_pix[(x, y)][0]['depth'])) == 2]
    sub_mesh = mesh.subgraph(conn_2_nodes).copy()
    ccs = netx.connected_components(sub_mesh)
    for cc in ccs:
        # Chain ends: nodes with a single neighbour inside the chain.
        candidate_nodes = [xx for xx in cc if sub_mesh.degree(xx) == 1]
        for node in candidate_nodes:
            if not mesh.has_node(node):
                continue
            ne_node = [xx for xx in mesh.neighbors(node) if xx not in cc][0]
            eight_nes = set(xx for xx in _pixel_nodes(info_on_pix, node[0], node[1], _EIGHT_OFFSETS) if xx not in cc)
            ne_sub_mesh = mesh.subgraph(eight_nes).copy()
            ne_cc = next((comp for comp in netx.connected_components(ne_sub_mesh) if ne_node in comp), None)
            if ne_cc is None:
                # ne_node is not among the neighbour-pixel nodes collected above;
                # skip this chain end rather than stopping in a debugger.
                continue
            largest_cc = [xx for xx in ne_cc if abs(xx[0] - node[0]) + abs(xx[1] - node[1]) == 1]
            mesh.remove_edges_from([(xx, node) for xx in mesh.neighbors(node)])
            depth_sum, depth_cnt = 0, 0
            for cc_node in largest_cc:
                depth_sum += cc_node[2]
                depth_cnt += 1
                mesh.add_edge(cc_node, node)
            re_depth = depth_sum / depth_cnt if depth_cnt > 0 else node[2]
            renode = (node[0], node[1], re_depth)
            mapping_dict = {node: renode}
            info_on_pix, mesh, edge_mesh = update_info(mapping_dict, info_on_pix, mesh, edge_mesh)
            depth[node[0], node[1]] = abs(re_depth)
            mark[node[0], node[1]] = 0
            edge_mesh, mesh, mark, info_on_pix = recursive_add_edge(edge_mesh, mesh, info_on_pix, renode, mark)
            break
        if len(cc) == 1:
            node = next(iter(cc))
            nine_nes = set(xx for xx in _pixel_nodes(info_on_pix, node[0], node[1], ((0, 0),) + _EIGHT_OFFSETS)
                           if mesh.has_node(xx))
            ne_sub_mesh = mesh.subgraph(nine_nes).copy()
            for ne_cc in netx.connected_components(ne_sub_mesh):
                if node in ne_cc:
                    depth_sum, depth_cnt = 0, 0
                    for ne in ne_cc:
                        if abs(ne[0] - node[0]) + abs(ne[1] - node[1]) == 1:
                            mesh.add_edge(node, ne)
                            depth_sum += ne[2]
                            depth_cnt += 1
                    re_depth = depth_sum / depth_cnt if depth_cnt > 0 else node[2]
                    mapping_dict = {node: (node[0], node[1], re_depth)}
                    info_on_pix, mesh, edge_mesh = update_info(mapping_dict, info_on_pix, mesh, edge_mesh)
                    depth[node[0], node[1]] = abs(re_depth)
                    mark[node[0], node[1]] = 0

    return mesh, info_on_pix, edge_mesh, depth, mark
|
|
|
def context_and_holes(mesh, edge_ccs, config, specific_edge_id, specific_edge_loc, depth_feat_model,
                      connect_points_ccs=None, inpaint_iter=0, filter_edge=False, vis_edge_id=None):
    """Grow, for every depth-edge component, a mask region and a context region.

    Context is seeded from the ``'far'`` nodes of each component and the mask
    from the component itself.  Both are then grown ring by ring along mesh
    edges for ``max(background_thickness, context_thickness)`` iterations,
    avoiding pixels claimed by other regions.  When ``inpaint_iter == 0`` the
    companion far/near nodes are filtered with ``depth_inpainting``, and a
    second phase grows ``extend_context_ccs`` inside the contexts of touching
    edges.  Finally the innermost context rings are moved to ``erode_context_ccs``.

    Args:
        mesh: graph whose nodes are ``(row, col, depth)`` tuples; ``mesh.graph``
            must hold ``'H'`` and ``'W'``.  Node attributes read: ``'far'``,
            ``'near'``, ``'edge_id'``, ``'inpaint_id'``, ``'connect_point_id'``,
            ``'connect_point_exception'``.
        edge_ccs: list of sets of edge nodes, one set per edge component.  Not modified.
        config: reads ``background_thickness``/``context_thickness`` (first
            pass) or their ``_2`` variants (later passes), plus ``log_depth``
            and ``depth_edge_dilate``.
        specific_edge_id: if non-empty, only these edge ids are processed.
        specific_edge_loc: unused.
        depth_feat_model: forwarded to ``depth_inpainting`` when ``inpaint_iter == 0``.
        connect_points_ccs: optional node collections indexed by a mask node's
            ``'connect_point_id'``.
        inpaint_iter: 0 for the first pass; any other value selects the later-pass settings.
        filter_edge: unused.
        vis_edge_id: debugging aid.  In the first pass it plots the maps of that
            edge and enters ``pdb``.

    Returns:
        ``(context_ccs, mask_ccs, broken_mask_ccs, edge_ccs, erode_context_ccs,
        invalid_extend_edge_ccs, edge_maps, extend_context_ccs, extend_edge_ccs,
        extend_erode_context_ccs)``.  ``broken_mask_ccs`` is one empty set per
        component (never filled here).  ``edge_ccs`` is the input list.
        ``edge_maps`` is an ``(H, W)`` array holding the component index of
        each edge pixel and -1 elsewhere.
    """
    edge_maps = np.zeros((mesh.graph['H'], mesh.graph['W'])) - 1
    for edge_id, edge_cc in enumerate(edge_ccs):
        for edge_node in edge_cc:
            edge_maps[edge_node[0], edge_node[1]] = edge_id

    context_ccs = [set() for _ in range(len(edge_ccs))]
    extend_context_ccs = [set() for _ in range(len(edge_ccs))]
    extend_erode_context_ccs = [set() for _ in range(len(edge_ccs))]
    extend_edge_ccs = [set() for _ in range(len(edge_ccs))]
    accomp_extend_context_ccs = [set() for _ in range(len(edge_ccs))]
    erode_context_ccs = [set() for _ in range(len(edge_ccs))]
    broken_mask_ccs = [set() for _ in range(len(edge_ccs))]
    invalid_extend_edge_ccs = [set() for _ in range(len(edge_ccs))]
    intouched_ccs = [set() for _ in range(len(edge_ccs))]
    if inpaint_iter == 0:
        background_thickness = config['background_thickness']
        context_thickness = config['context_thickness']
    else:
        background_thickness = config['background_thickness_2']
        context_thickness = config['context_thickness_2']

    mesh_nodes = mesh.nodes
    # Seed each component's context with the 'far' nodes of its edge pixels.
    for edge_id, edge_cc in enumerate(edge_ccs):
        if context_thickness == 0 or (len(specific_edge_id) > 0 and edge_id not in specific_edge_id):
            continue
        edge_group = {}
        for edge_node in edge_cc:
            far_nodes = mesh_nodes[edge_node].get('far')
            if far_nodes is None:
                continue
            for far_node in far_nodes:
                if far_node in edge_cc:
                    continue
                context_ccs[edge_id].add(far_node)
                if mesh_nodes[far_node].get('edge_id') is not None:
                    if edge_group.get(mesh_nodes[far_node]['edge_id']) is None:
                        edge_group[mesh_nodes[far_node]['edge_id']] = set()
                    edge_group[mesh_nodes[far_node]['edge_id']].add(far_node)
        # For components longer than two pixels, drop far nodes that are the
        # only representative of their own edge component.
        if len(edge_cc) > 2:
            for edge_key in [*edge_group.keys()]:
                if len(edge_group[edge_key]) == 1:
                    context_ccs[edge_id].remove([*edge_group[edge_key]][0])
    # First pass only: 'near' nodes that belong to edges having a context are
    # blocked ("intouched") for this component.
    for edge_id, edge_cc in enumerate(edge_ccs):
        if inpaint_iter != 0:
            continue
        tmp_intouched_nodes = set()
        for edge_node in edge_cc:
            raw_intouched_nodes = set(mesh_nodes[edge_node].get('near')) if mesh_nodes[edge_node].get('near') is not None else set()
            tmp_intouched_nodes |= {xx for xx in raw_intouched_nodes if mesh_nodes[xx].get('edge_id') is not None and
                                    len(context_ccs[mesh_nodes[xx].get('edge_id')]) > 0}
        intouched_ccs[edge_id] |= tmp_intouched_nodes
        tmp_intouched_nodes = None
    mask_ccs = copy.deepcopy(edge_ccs)
    forbidden_len = 3
    # NOTE(review): ones of shape (H-3, W-3) padded by 3 on every side yields an
    # (H+3, W+3) map, so only the top/left 3-pixel band inside the image is
    # False; the bottom/right band is not forbidden.  Possibly (H-6, W-6) was
    # intended -- confirm before changing, it alters the grown regions.
    forbidden_map = np.ones((mesh.graph['H'] - forbidden_len, mesh.graph['W'] - forbidden_len))
    forbidden_map = np.pad(forbidden_map, ((forbidden_len, forbidden_len), (forbidden_len, forbidden_len)), mode='constant').astype(bool)
    # Ring index at which background / context growth starts recording
    # intersections with the other region (fixed values).
    passive_background = 10
    passive_context = 1

    for edge_id, edge_cc in enumerate(edge_ccs):
        cur_mask_cc = []
        cur_context_cc = []
        cur_accomp_near_cc = []
        cur_invalid_extend_edge_cc = []
        cur_comp_far_cc = []
        tmp_erode = []  # context rings, innermost first
        if len(context_ccs[edge_id]) == 0 or (len(specific_edge_id) > 0 and edge_id not in specific_edge_id):
            continue
        for i in range(max(background_thickness, context_thickness)):
            if i == 0:
                tmp_mask_nodes = copy.deepcopy(mask_ccs[edge_id])
                tmp_intersect_nodes = []
                tmp_intersect_context_nodes = []
                mask_map = np.zeros((mesh.graph['H'], mesh.graph['W']), dtype=bool)
                context_depth = np.zeros((mesh.graph['H'], mesh.graph['W']))
                comp_cnt_depth = np.zeros((mesh.graph['H'], mesh.graph['W']))
                connect_map = np.zeros((mesh.graph['H'], mesh.graph['W']))
                for node in tmp_mask_nodes:
                    mask_map[node[0], node[1]] = True
                    # Mean |depth| of the node's far companions.
                    depth_count = 0
                    if mesh_nodes[node].get('far') is not None:
                        for comp_cnt_node in mesh_nodes[node]['far']:
                            comp_cnt_depth[node[0], node[1]] += abs(comp_cnt_node[2])
                            depth_count += 1
                    if depth_count > 0:
                        comp_cnt_depth[node[0], node[1]] = comp_cnt_depth[node[0], node[1]] / depth_count
                    connect_node = []
                    if mesh_nodes[node].get('connect_point_id') is not None:
                        connect_node.append(mesh_nodes[node]['connect_point_id'])
                    connect_point_id = np.bincount(connect_node).argmax() if len(connect_node) > 0 else -1
                    if connect_point_id > -1 and connect_points_ccs is not None:
                        for xx in connect_points_ccs[connect_point_id]:
                            if connect_map[xx[0], xx[1]] == 0:
                                connect_map[xx[0], xx[1]] = xx[2]
                    if mesh_nodes[node].get('connect_point_exception') is not None:
                        for xx in mesh_nodes[node]['connect_point_exception']:
                            if connect_map[xx[0], xx[1]] == 0:
                                connect_map[xx[0], xx[1]] = xx[2]
                tmp_context_nodes = [*context_ccs[edge_id]]
                tmp_erode.append([*context_ccs[edge_id]])
                context_map = np.zeros((mesh.graph['H'], mesh.graph['W']), dtype=bool)
                for node in tmp_context_nodes:
                    context_map[node[0], node[1]] = True
                    context_depth[node[0], node[1]] = node[2]
                context_map[mask_map == True] = False
                tmp_intouched_nodes = [*intouched_ccs[edge_id]]
                intouched_map = np.zeros((mesh.graph['H'], mesh.graph['W']), dtype=bool)
                for node in tmp_intouched_nodes:
                    intouched_map[node[0], node[1]] = True
                intouched_map[mask_map == True] = False
                tmp_redundant_nodes = set()
                tmp_noncont_nodes = set()
                noncont_map = np.zeros((mesh.graph['H'], mesh.graph['W']), dtype=bool)
                intersect_map = np.zeros((mesh.graph['H'], mesh.graph['W']), dtype=bool)
                intersect_context_map = np.zeros((mesh.graph['H'], mesh.graph['W']), dtype=bool)
            if i > passive_background and inpaint_iter == 0:
                new_tmp_intersect_nodes = []
                for node in tmp_intersect_nodes:
                    nes = mesh.neighbors(node)
                    for ne in nes:
                        if bool(context_map[ne[0], ne[1]]) is False and \
                           bool(mask_map[ne[0], ne[1]]) is False and \
                           bool(forbidden_map[ne[0], ne[1]]) is True and \
                           bool(intouched_map[ne[0], ne[1]]) is False and \
                           bool(intersect_map[ne[0], ne[1]]) is False and \
                           bool(intersect_context_map[ne[0], ne[1]]) is False:
                            break_flag = False
                            if (i - passive_background) % 2 == 0 and (i - passive_background) % 8 != 0:
                                four_nes = [xx for xx in [[ne[0] - 1, ne[1]], [ne[0] + 1, ne[1]], [ne[0], ne[1] - 1], [ne[0], ne[1] + 1]]
                                            if 0 <= xx[0] < mesh.graph['H'] and 0 <= xx[1] < mesh.graph['W']]
                                for fne in four_nes:
                                    if bool(mask_map[fne[0], fne[1]]) is True:
                                        break_flag = True
                                        break
                                if break_flag is True:
                                    continue
                            intersect_map[ne[0], ne[1]] = True
                            new_tmp_intersect_nodes.append(ne)
                tmp_intersect_nodes = new_tmp_intersect_nodes

            if i > passive_context and inpaint_iter == 1:
                new_tmp_intersect_context_nodes = []
                for node in tmp_intersect_context_nodes:
                    nes = mesh.neighbors(node)
                    for ne in nes:
                        if bool(context_map[ne[0], ne[1]]) is False and \
                           bool(mask_map[ne[0], ne[1]]) is False and \
                           bool(forbidden_map[ne[0], ne[1]]) is True and \
                           bool(intouched_map[ne[0], ne[1]]) is False and \
                           bool(intersect_map[ne[0], ne[1]]) is False and \
                           bool(intersect_context_map[ne[0], ne[1]]) is False:
                            intersect_context_map[ne[0], ne[1]] = True
                            new_tmp_intersect_context_nodes.append(ne)
                tmp_intersect_context_nodes = new_tmp_intersect_context_nodes

            # Grow the mask by one ring.
            new_tmp_mask_nodes = []
            for node in tmp_mask_nodes:
                four_nes = {xx: [] for xx in [(node[0] - 1, node[1]), (node[0] + 1, node[1]), (node[0], node[1] - 1), (node[0], node[1] + 1)] if
                            0 <= xx[0] < connect_map.shape[0] and 0 <= xx[1] < connect_map.shape[1]}
                if inpaint_iter > 0:
                    for ne in four_nes.keys():
                        # NOTE(review): connect_map stores depth values (xx[2]), so
                        # `== True` only matches a stored value of exactly 1.0 --
                        # confirm whether `!= 0` was intended.
                        if connect_map[ne[0], ne[1]] == True:
                            tmp_context_nodes.append((ne[0], ne[1], connect_map[ne[0], ne[1]]))
                            context_map[ne[0], ne[1]] = True
                nes = mesh.neighbors(node)
                if inpaint_iter > 0:
                    # Visit neighbours pixel by pixel, largest |depth| first.
                    for ne in nes:
                        four_nes[(ne[0], ne[1])].append(ne[2])
                    nes = []
                    for kfne, vfnes in four_nes.items():
                        vfnes.sort(key=abs, reverse=True)
                    for kfne, vfnes in four_nes.items():
                        for vfne in vfnes:
                            nes.append((kfne[0], kfne[1], vfne))
                for ne in nes:
                    if bool(context_map[ne[0], ne[1]]) is False and \
                       bool(mask_map[ne[0], ne[1]]) is False and \
                       bool(forbidden_map[ne[0], ne[1]]) is True and \
                       bool(intouched_map[ne[0], ne[1]]) is False and \
                       bool(intersect_map[ne[0], ne[1]]) is False and \
                       bool(intersect_context_map[ne[0], ne[1]]) is False:
                        if i == passive_background and inpaint_iter == 0:
                            if np.any(context_map[max(ne[0] - 1, 0):min(ne[0] + 2, mesh.graph['H']), max(ne[1] - 1, 0):min(ne[1] + 2, mesh.graph['W'])]) == True:
                                intersect_map[ne[0], ne[1]] = True
                                tmp_intersect_nodes.append(ne)
                                continue
                        if i < background_thickness:
                            if inpaint_iter == 0:
                                cur_mask_cc.append(ne)
                            elif mesh_nodes[ne].get('inpaint_id') == 1:
                                cur_mask_cc.append(ne)
                            else:
                                continue
                            mask_ccs[edge_id].add(ne)
                            if inpaint_iter == 0:
                                if comp_cnt_depth[node[0], node[1]] > 0 and comp_cnt_depth[ne[0], ne[1]] == 0:
                                    comp_cnt_depth[ne[0], ne[1]] = comp_cnt_depth[node[0], node[1]]
                                if mesh_nodes[ne].get('far') is not None:
                                    for comp_far_node in mesh_nodes[ne]['far']:
                                        cur_comp_far_cc.append(comp_far_node)
                                        cur_accomp_near_cc.append(ne)
                                        cur_invalid_extend_edge_cc.append(comp_far_node)
                                if mesh_nodes[ne].get('edge_id') is not None and \
                                   len(context_ccs[mesh_nodes[ne].get('edge_id')]) > 0:
                                    intouched_fars = set(mesh_nodes[ne].get('far')) if mesh_nodes[ne].get('far') is not None else set()
                                    accum_intouched_fars = set(intouched_fars)
                                    for intouched_far in intouched_fars:
                                        accum_intouched_fars |= set([*mesh.neighbors(intouched_far)])
                                    for intouched_far in accum_intouched_fars:
                                        if bool(mask_map[intouched_far[0], intouched_far[1]]) is True or \
                                           bool(context_map[intouched_far[0], intouched_far[1]]) is True:
                                            continue
                                        tmp_redundant_nodes.add(intouched_far)
                                        intouched_map[intouched_far[0], intouched_far[1]] = True
                                    if mesh_nodes[ne].get('near') is not None:
                                        intouched_nears = set(mesh_nodes[ne].get('near'))
                                        for intouched_near in intouched_nears:
                                            if bool(mask_map[intouched_near[0], intouched_near[1]]) is True or \
                                               bool(context_map[intouched_near[0], intouched_near[1]]) is True:
                                                continue
                                            tmp_redundant_nodes.add(intouched_near)
                                            intouched_map[intouched_near[0], intouched_near[1]] = True
                        if not (mesh_nodes[ne].get('inpaint_id') != 1 and inpaint_iter == 1):
                            new_tmp_mask_nodes.append(ne)
                            mask_map[ne[0], ne[1]] = True
            tmp_mask_nodes = new_tmp_mask_nodes

            # Grow the context by one ring.
            new_tmp_context_nodes = []
            for node in tmp_context_nodes:
                nes = mesh.neighbors(node)
                if inpaint_iter > 0:
                    four_nes = {(node[0] - 1, node[1]): [], (node[0] + 1, node[1]): [], (node[0], node[1] - 1): [], (node[0], node[1] + 1): []}
                    for ne in nes:
                        four_nes[(ne[0], ne[1])].append(ne[2])
                    nes = []
                    for kfne, vfnes in four_nes.items():
                        vfnes.sort(key=abs, reverse=True)
                    for kfne, vfnes in four_nes.items():
                        for vfne in vfnes:
                            nes.append((kfne[0], kfne[1], vfne))
                for ne in nes:
                    mask_flag = (bool(mask_map[ne[0], ne[1]]) is False)
                    if bool(context_map[ne[0], ne[1]]) is False and mask_flag and \
                       bool(forbidden_map[ne[0], ne[1]]) is True and bool(noncont_map[ne[0], ne[1]]) is False and \
                       bool(intersect_context_map[ne[0], ne[1]]) is False:
                        if i == passive_context and inpaint_iter == 1:
                            mnes = mesh.neighbors(ne)
                            if any([mask_map[mne[0], mne[1]] == True for mne in mnes]) is True:
                                intersect_context_map[ne[0], ne[1]] = True
                                tmp_intersect_context_nodes.append(ne)
                                continue
                        if False and mesh_nodes[ne].get('near') is not None and mesh_nodes[ne].get('edge_id') != edge_id:
                            noncont_nears = set(mesh_nodes[ne].get('near'))
                            for noncont_near in noncont_nears:
                                if bool(context_map[noncont_near[0], noncont_near[1]]) is False:
                                    tmp_noncont_nodes.add(noncont_near)
                                    noncont_map[noncont_near[0], noncont_near[1]] = True
                        new_tmp_context_nodes.append(ne)
                        context_map[ne[0], ne[1]] = True
                        context_depth[ne[0], ne[1]] = ne[2]
            cur_context_cc.extend(new_tmp_context_nodes)
            tmp_erode.append(new_tmp_context_nodes)
            tmp_context_nodes = new_tmp_context_nodes

            # Propagate the blocked ("intouched") region by one ring.
            new_tmp_intouched_nodes = []
            for node in tmp_intouched_nodes:
                if bool(context_map[node[0], node[1]]) is True or bool(mask_map[node[0], node[1]]) is True:
                    continue
                nes = mesh.neighbors(node)
                for ne in nes:
                    if bool(context_map[ne[0], ne[1]]) is False and \
                       bool(mask_map[ne[0], ne[1]]) is False and \
                       bool(intouched_map[ne[0], ne[1]]) is False and \
                       bool(forbidden_map[ne[0], ne[1]]) is True:
                        new_tmp_intouched_nodes.append(ne)
                        intouched_map[ne[0], ne[1]] = True
            tmp_intouched_nodes = set(new_tmp_intouched_nodes)
            new_tmp_redundant_nodes = []
            for node in tmp_redundant_nodes:
                if bool(context_map[node[0], node[1]]) is True or \
                   bool(mask_map[node[0], node[1]]) is True:
                    continue
                nes = mesh.neighbors(node)
                for ne in nes:
                    if bool(context_map[ne[0], ne[1]]) is False and \
                       bool(mask_map[ne[0], ne[1]]) is False and \
                       bool(intouched_map[ne[0], ne[1]]) is False and \
                       bool(forbidden_map[ne[0], ne[1]]) is True:
                        new_tmp_redundant_nodes.append(ne)
                        intouched_map[ne[0], ne[1]] = True
            tmp_redundant_nodes = set(new_tmp_redundant_nodes)
            new_tmp_noncont_nodes = []
            for node in tmp_noncont_nodes:
                if bool(context_map[node[0], node[1]]) is True or \
                   bool(mask_map[node[0], node[1]]) is True:
                    continue
                nes = mesh.neighbors(node)
                for ne in nes:
                    if bool(context_map[ne[0], ne[1]]) is False and \
                       bool(mask_map[ne[0], ne[1]]) is False and \
                       bool(noncont_map[ne[0], ne[1]]) is False and \
                       bool(forbidden_map[ne[0], ne[1]]) is True:
                        patch_context_map = context_map[max(ne[0] - 1, 0):min(ne[0] + 2, context_map.shape[0]),
                                                        max(ne[1] - 1, 0):min(ne[1] + 2, context_map.shape[1])]
                        if bool(np.any(patch_context_map)) is True:
                            new_tmp_noncont_nodes.append(ne)
                            noncont_map[ne[0], ne[1]] = True
            tmp_noncont_nodes = set(new_tmp_noncont_nodes)
        if inpaint_iter == 0:
            # Keep only far/near companion pairs where the inpainted near depth
            # is closer (smaller |depth|) than the far node.
            depth_dict = get_depth_from_maps(context_map, mask_map, context_depth, mesh.graph['H'], mesh.graph['W'], log_depth=config['log_depth'])
            mask_size = get_valid_size(depth_dict['mask'])
            mask_size = dilate_valid_size(mask_size, depth_dict['mask'], dilate=[20, 20])
            context_size = get_valid_size(depth_dict['context'])
            context_size = dilate_valid_size(context_size, depth_dict['context'], dilate=[20, 20])
            union_size = size_operation(mask_size, context_size, operation='+')
            depth_dict = depth_inpainting(None, None, None, None, mesh, config, union_size, depth_feat_model, None, given_depth_dict=depth_dict, spdb=False)
            near_depth_map, raw_near_depth_map = np.zeros((mesh.graph['H'], mesh.graph['W'])), np.zeros((mesh.graph['H'], mesh.graph['W']))
            filtered_comp_far_cc, filtered_accomp_near_cc = set(), set()
            for node in cur_accomp_near_cc:
                near_depth_map[node[0], node[1]] = depth_dict['output'][node[0], node[1]]
                raw_near_depth_map[node[0], node[1]] = node[2]
            for node in cur_comp_far_cc:
                four_nes = [xx for xx in [(node[0] - 1, node[1]), (node[0] + 1, node[1]), (node[0], node[1] - 1), (node[0], node[1] + 1)]
                            if 0 <= xx[0] < mesh.graph['H'] and 0 <= xx[1] < mesh.graph['W'] and
                            near_depth_map[xx[0], xx[1]] != 0 and
                            abs(near_depth_map[xx[0], xx[1]]) < abs(node[2])]
                if len(four_nes) > 0:
                    filtered_comp_far_cc.add(node)
                    for ne in four_nes:
                        filtered_accomp_near_cc.add((ne[0], ne[1], -abs(raw_near_depth_map[ne[0], ne[1]])))
            cur_comp_far_cc, cur_accomp_near_cc = filtered_comp_far_cc, filtered_accomp_near_cc
        mask_ccs[edge_id] |= set(cur_mask_cc)
        context_ccs[edge_id] |= set(cur_context_cc)
        accomp_extend_context_ccs[edge_id] |= set(cur_accomp_near_cc).intersection(cur_mask_cc)
        extend_edge_ccs[edge_id] |= set(cur_accomp_near_cc).intersection(cur_mask_cc)
        extend_context_ccs[edge_id] |= set(cur_comp_far_cc)
        invalid_extend_edge_ccs[edge_id] |= set(cur_invalid_extend_edge_cc)
        # erode_size[k] = number of context nodes in the first k rings.
        erode_size = [0]
        for erode_ring in tmp_erode:
            erode_size.append(len(erode_ring))
            if len(erode_size) > 1:
                erode_size[-1] += erode_size[-2]
        if inpaint_iter == 0:
            tmp_width = config['depth_edge_dilate']
        else:
            tmp_width = 0
        # Shrink the eroded band until it holds at most 30% of the context.
        while float(erode_size[tmp_width]) / (erode_size[-1] + 1e-6) > 0.3:
            tmp_width = tmp_width - 1
        # Union of the innermost `tmp_width` rings (empty when tmp_width == 0).
        erode_context_ccs[edge_id] = set(reduce(lambda x, y: x + y, tmp_erode[:tmp_width], []))
        erode_context_cc = copy.deepcopy(erode_context_ccs[edge_id])
        for erode_context_node in erode_context_cc:
            if (inpaint_iter != 0 and (mesh_nodes[erode_context_node].get('inpaint_id') is None or
                                       mesh_nodes[erode_context_node].get('inpaint_id') == 0)):
                erode_context_ccs[edge_id].remove(erode_context_node)
            else:
                context_ccs[edge_id].remove(erode_context_node)
        extend_context_ccs[edge_id] = extend_context_ccs[edge_id] - mask_ccs[edge_id] - accomp_extend_context_ccs[edge_id]
    if inpaint_iter == 0:
        # Second phase: grow each extended context inside the contexts of the
        # edge components its accompanying near nodes lie on.
        for ecnt_id, ecnt_cc in enumerate(extend_context_ccs):
            constraint_context_ids = set()
            constraint_context_cc = set()
            constraint_erode_context_cc = set()
            tmp_mask_cc = set()
            accum_context_cc = []
            for ecnt_node in accomp_extend_context_ccs[ecnt_id]:
                if edge_maps[ecnt_node[0], ecnt_node[1]] > -1:
                    constraint_context_ids.add(int(round(edge_maps[ecnt_node[0], ecnt_node[1]])))
            constraint_erode_context_cc = erode_context_ccs[ecnt_id]
            for constraint_context_id in constraint_context_ids:
                constraint_context_cc = constraint_context_cc | context_ccs[constraint_context_id] | erode_context_ccs[constraint_context_id]
                constraint_erode_context_cc = constraint_erode_context_cc | erode_context_ccs[constraint_context_id]
            for i in range(background_thickness):
                if i == 0:
                    tmp_context_nodes = copy.deepcopy(ecnt_cc)
                    tmp_invalid_context_nodes = copy.deepcopy(invalid_extend_edge_ccs[ecnt_id])
                    tmp_mask_nodes = copy.deepcopy(accomp_extend_context_ccs[ecnt_id])
                    tmp_context_map = np.zeros((mesh.graph['H'], mesh.graph['W'])).astype(bool)
                    tmp_mask_map = np.zeros((mesh.graph['H'], mesh.graph['W'])).astype(bool)
                    tmp_invalid_context_map = np.zeros((mesh.graph['H'], mesh.graph['W'])).astype(bool)
                    for node in tmp_mask_nodes:
                        tmp_mask_map[node[0], node[1]] = True
                    for node in context_ccs[ecnt_id]:
                        tmp_context_map[node[0], node[1]] = True
                    for node in erode_context_ccs[ecnt_id]:
                        tmp_context_map[node[0], node[1]] = True
                    for node in extend_context_ccs[ecnt_id]:
                        tmp_context_map[node[0], node[1]] = True
                    for node in invalid_extend_edge_ccs[ecnt_id]:
                        tmp_invalid_context_map[node[0], node[1]] = True
                    init_invalid_context_map = tmp_invalid_context_map.copy()
                    if (tmp_mask_map.astype(np.uint8) * tmp_context_map.astype(np.uint8)).max() > 0:
                        import pdb; pdb.set_trace()
                    if vis_edge_id is not None and ecnt_id == vis_edge_id:
                        f, ((ax1, ax2)) = plt.subplots(1, 2, sharex=True, sharey=True)
                        ax1.imshow(tmp_context_map * 1); ax2.imshow(init_invalid_context_map * 1 + tmp_context_map * 2)
                        plt.show()
                        import pdb; pdb.set_trace()
                else:
                    tmp_context_nodes = new_tmp_context_nodes
                    tmp_mask_nodes = new_tmp_mask_nodes
                    tmp_invalid_context_nodes = new_tmp_invalid_context_nodes
                new_tmp_context_nodes = []
                new_tmp_invalid_context_nodes = []
                new_tmp_mask_nodes = set([])
                for node in tmp_context_nodes:
                    for ne in mesh.neighbors(node):
                        if ne in constraint_context_cc and \
                           bool(tmp_mask_map[ne[0], ne[1]]) is False and \
                           bool(tmp_context_map[ne[0], ne[1]]) is False and \
                           bool(forbidden_map[ne[0], ne[1]]) is True:
                            new_tmp_context_nodes.append(ne)
                            tmp_context_map[ne[0], ne[1]] = True
                accum_context_cc.extend(new_tmp_context_nodes)
                for node in tmp_invalid_context_nodes:
                    for ne in mesh.neighbors(node):
                        if bool(tmp_mask_map[ne[0], ne[1]]) is False and \
                           bool(tmp_context_map[ne[0], ne[1]]) is False and \
                           bool(tmp_invalid_context_map[ne[0], ne[1]]) is False and \
                           bool(forbidden_map[ne[0], ne[1]]) is True:
                            tmp_invalid_context_map[ne[0], ne[1]] = True
                            new_tmp_invalid_context_nodes.append(ne)
                for node in tmp_mask_nodes:
                    for ne in mesh.neighbors(node):
                        if bool(tmp_mask_map[ne[0], ne[1]]) is False and \
                           bool(tmp_context_map[ne[0], ne[1]]) is False and \
                           bool(tmp_invalid_context_map[ne[0], ne[1]]) is False and \
                           bool(forbidden_map[ne[0], ne[1]]) is True:
                            new_tmp_mask_nodes.add(ne)
                            tmp_mask_map[ne[0], ne[1]] = True
                init_invalid_context_map[tmp_context_map] = False
                if (tmp_mask_map.astype(np.uint8) * tmp_context_map.astype(np.uint8)).max() > 0:
                    import pdb; pdb.set_trace()
                if vis_edge_id is not None and ecnt_id == vis_edge_id:
                    # Component labels are only needed for this debug plot.
                    _, tmp_label_map = cv2.connectedComponents((init_invalid_context_map | tmp_context_map).astype(np.uint8), connectivity=8)
                    f, ((ax1, ax2)) = plt.subplots(1, 2, sharex=True, sharey=True)
                    ax1.imshow(tmp_label_map); ax2.imshow(init_invalid_context_map * 1 + tmp_context_map * 2)
                    plt.show()
                    import pdb; pdb.set_trace()
            extend_context_ccs[ecnt_id] |= set(accum_context_cc)
            extend_context_ccs[ecnt_id] = extend_context_ccs[ecnt_id] - mask_ccs[ecnt_id]
            extend_erode_context_ccs[ecnt_id] = extend_context_ccs[ecnt_id] & constraint_erode_context_cc
            extend_context_ccs[ecnt_id] = extend_context_ccs[ecnt_id] - extend_erode_context_ccs[ecnt_id] - erode_context_ccs[ecnt_id]
            tmp_context_cc = context_ccs[ecnt_id] - extend_erode_context_ccs[ecnt_id] - erode_context_ccs[ecnt_id]
            if len(tmp_context_cc) > 0:
                context_ccs[ecnt_id] = tmp_context_cc
            tmp_mask_cc = tmp_mask_cc - context_ccs[ecnt_id] - erode_context_ccs[ecnt_id]
            mask_ccs[ecnt_id] = mask_ccs[ecnt_id] | tmp_mask_cc

    return context_ccs, mask_ccs, broken_mask_ccs, edge_ccs, erode_context_ccs, invalid_extend_edge_ccs, edge_maps, extend_context_ccs, extend_edge_ccs, extend_erode_context_ccs
|
|
|
def DL_inpaint_edge(mesh,
                    info_on_pix,
                    config,
                    image,
                    depth,
                    context_ccs,
                    erode_context_ccs,
                    extend_context_ccs,
                    extend_erode_context_ccs,
                    mask_ccs,
                    broken_mask_ccs,
                    edge_ccs,
                    extend_edge_ccs,
                    init_mask_connect,
                    edge_maps,
                    rgb_model=None,
                    depth_edge_model=None,
                    depth_edge_model_init=None,
                    depth_feat_model=None,
                    specific_edge_id=-1,
                    specific_edge_loc=None,
                    inpaint_iter=0):
    """Inpaint depth edges, depth and color for every hole component and add
    the synthesized pixels to ``mesh`` and ``info_on_pix``.

    The function makes two passes over the per-component node sets (all the
    ``*_ccs`` lists are indexed by the same component / edge id):

    1. For each component it builds full-size maps with
       ``get_edge_from_nodes``. When ``inpaint_iter == 0`` and
       ``require_depth_edge`` is true for the cropped patch, it runs
       ``depth_edge_model.forward_3P`` to predict edge pixels inside the mask,
       traces ``npaths`` / ``fpaths`` with ``clean_far_edge_new``, and records
       each path under its end point(s) in ``edges_infos`` / ``edges_in_mask``.
    2. For each component it rebuilds the path maps from those records
       (resolving end points shared with other components), then calls
       ``create_placeholder`` and ``depth_inpainting``. It runs
       ``rgb_model.forward_3P`` on the resized patch and inserts one new node
       per synthesized pixel via ``refresh_node``, plus a new ``info_on_pix``
       entry.

    Afterwards, the stored ``update_color`` of eroded-context nodes is copied
    into ``info_on_pix``, ``mesh`` and the returned image. 2-tuple nodes are
    removed, and the nodes created in this round are regrouped by their
    ``edge_id``.

    Args:
        mesh: Graph whose ``graph['H']`` / ``graph['W']`` give the map size.
            Nodes are keyed by ``(x, y, depth)`` tuples. ``(x, y)`` keys are
            also looked up (2-D nodes) and removed at the end.
        info_on_pix: Dict mapping ``(x, y)`` to a list of per-layer dicts.
            Keys read or written here include ``'depth'``, ``'real_depth'``,
            ``'inpaint_id'``, ``'color'`` and ``'update_color'``. The dict is
            mutated in place.
        config: Dict. Keys read here: ``'gpu_ids'``, ``'ext_edge_threshold'``,
            ``'repeat_inpaint_edge'``, ``'largest_size'``, ``'gray_image'``.
        image: Numpy array with a trailing channel axis (it is transposed with
            ``(2, 0, 1)`` below).
        depth: Numpy depth map, either 2-D or with a trailing channel axis.
        context_ccs, erode_context_ccs, extend_context_ccs,
        extend_erode_context_ccs, mask_ccs, edge_ccs, extend_edge_ccs:
            Per-component sets of mesh nodes.
        broken_mask_ccs, init_mask_connect, depth_edge_model_init:
            Accepted but not read in this body.
        edge_maps: Forwarded to the helpers as ``edge_maps_with_id``.
        rgb_model, depth_edge_model, depth_feat_model: Networks used for color
            inpainting, edge prediction and (via ``depth_inpainting``) depth
            inpainting.
        specific_edge_id: List of component ids to restrict processing to; an
            empty list means all components.
            NOTE(review): the default ``-1`` is not a sequence, so
            ``len(specific_edge_id)`` would raise TypeError. The callers in
            this file always pass a list.
        specific_edge_loc: Optional ``(x, y)``. Components whose mask does not
            cover this pixel are skipped.
        inpaint_iter: Inpainting round (the callers in this file pass 0 and
            then 1). Edge prediction and path tracing only run when it is 0.

    Returns:
        ``(mesh, info_on_pix, specific_mask_nodes, new_edge_ccs,
        connnect_points_ccs, np_image)``. ``specific_mask_nodes`` is always
        ``None``. ``np_image`` is a numpy copy of the input image with the
        eroded-context colors updated.
    """

    # Run the networks on GPU `config["gpu_ids"]` when it is a non-negative
    # int, otherwise on the CPU.
    if isinstance(config["gpu_ids"], int) and (config["gpu_ids"] >= 0):
        device = config["gpu_ids"]
    else:
        device = "cpu"

    # NOTE(review): `edge_map` and `new_edge_ccs` are reassigned before they
    # are read. `image_c`, the *_time counters, `input_edge_feat`,
    # `x_anchor`/`y_anchor` and `pre_npath_map` are never read afterwards.
    edge_map = np.zeros_like(depth)
    new_edge_ccs = [set() for _ in range(len(edge_ccs))]
    edge_maps_with_id = edge_maps
    # A node is drawn into edge_map when it carries a non-empty 'far' list.
    edge_condition = lambda x, m: m.nodes[x].get('far') is not None and len(m.nodes[x].get('far')) > 0
    edge_map = get_map_from_ccs(edge_ccs, mesh.graph['H'], mesh.graph['W'], mesh, edge_condition)
    # Keep numpy copies. `image` and `depth` are then rebound to 4-D float
    # tensors (batch axis added, last axis moved first) on `device`.
    np_depth, np_image = depth.copy(), image.copy()
    image_c = image.shape[-1]
    image = torch.FloatTensor(image.transpose(2, 0, 1)).unsqueeze(0).to(device)
    if depth.ndim < 3:
        depth = depth[..., None]
    depth = torch.FloatTensor(depth.transpose(2, 0, 1)).unsqueeze(0).to(device)
    mesh.graph['max_edge_id'] = len(edge_ccs)
    connnect_points_ccs = [set() for _ in range(len(edge_ccs))]
    gp_time, tmp_mesh_time, bilateral_time = 0, 0, 0
    # edges_infos: (x, y) path end point -> list of path records (new_edge_info).
    # edges_in_mask[edge_id]: end points of the paths traced for that component.
    edges_infos = dict()
    edges_in_mask = [set() for _ in range(len(edge_ccs))]
    # Ids that pass the filters of pass 2; they restrict the final color update.
    tmp_specific_edge_id = []
    # ---- Pass 1: predict depth edges inside each mask and trace their paths. ----
    for edge_id, (context_cc, mask_cc, erode_context_cc, extend_context_cc, edge_cc) in enumerate(zip(context_ccs, mask_ccs, erode_context_ccs, extend_context_ccs, edge_ccs)):
        if len(specific_edge_id) > 0:
            if edge_id not in specific_edge_id:
                continue
        if len(context_cc) < 1 or len(mask_cc) < 1:
            continue
        edge_dict = get_edge_from_nodes(context_cc | extend_context_cc, erode_context_cc | extend_erode_context_ccs[edge_id], mask_cc, edge_cc, extend_edge_ccs[edge_id],
                                        mesh.graph['H'], mesh.graph['W'], mesh)
        edge_dict['edge'], end_depth_maps, _ = \
            filter_irrelevant_edge_new(edge_dict['self_edge'], edge_dict['comp_edge'],
                                    edge_map,
                                    edge_maps_with_id,
                                    edge_id,
                                    edge_dict['context'],
                                    edge_dict['depth'], mesh, context_cc | erode_context_cc | extend_context_cc | extend_erode_context_ccs[edge_id], spdb=False)
        # Skip components whose mask does not cover specific_edge_loc
        # (the repeated `is not None` test is redundant but harmless).
        if specific_edge_loc is not None and \
            (specific_edge_loc is not None and edge_dict['mask'][specific_edge_loc[0], specific_edge_loc[1]] == 0):
            continue
        # Crop window: x/y extents of the mask and of the context, each
        # dilated by 20, combined with size_operation('+').
        mask_size = get_valid_size(edge_dict['mask'])
        mask_size = dilate_valid_size(mask_size, edge_dict['mask'], dilate=[20, 20])
        context_size = get_valid_size(edge_dict['context'])
        context_size = dilate_valid_size(context_size, edge_dict['context'], dilate=[20, 20])
        union_size = size_operation(mask_size, context_size, operation='+')
        patch_edge_dict = dict()
        patch_edge_dict['mask'], patch_edge_dict['context'], patch_edge_dict['rgb'], \
            patch_edge_dict['disp'], patch_edge_dict['edge'] = \
            crop_maps_by_size(union_size, edge_dict['mask'], edge_dict['context'],
                                edge_dict['rgb'], edge_dict['disp'], edge_dict['edge'])
        x_anchor, y_anchor = [union_size['x_min'], union_size['x_max']], [union_size['y_min'], union_size['y_max']]
        tensor_edge_dict = convert2tensor(patch_edge_dict)
        input_edge_feat = torch.cat((tensor_edge_dict['rgb'],
                                    tensor_edge_dict['disp'],
                                    tensor_edge_dict['edge'],
                                    1 - tensor_edge_dict['context'],
                                    tensor_edge_dict['mask']), dim=1)
        if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) and inpaint_iter == 0:
            with torch.no_grad():
                depth_edge_output = depth_edge_model.forward_3P(tensor_edge_dict['mask'],
                                                                tensor_edge_dict['context'],
                                                                tensor_edge_dict['rgb'],
                                                                tensor_edge_dict['disp'],
                                                                tensor_edge_dict['edge'],
                                                                unit_length=128,
                                                                cuda=device)
            depth_edge_output = depth_edge_output.cpu()
            # Binarize the prediction, keep it only inside the mask, and add
            # the already-known edges.
            tensor_edge_dict['output'] = (depth_edge_output> config['ext_edge_threshold']).float() * tensor_edge_dict['mask'] + tensor_edge_dict['edge']
        else:
            tensor_edge_dict['output'] = tensor_edge_dict['edge']
            depth_edge_output = tensor_edge_dict['edge'] + 0
        patch_edge_dict['output'] = tensor_edge_dict['output'].squeeze().data.cpu().numpy()
        # Paste the patch result back into a full-size (H, W) map.
        edge_dict['output'] = np.zeros((mesh.graph['H'], mesh.graph['W']))
        edge_dict['output'][union_size['x_min']:union_size['x_max'], union_size['y_min']:union_size['y_max']] = \
            patch_edge_dict['output']
        if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) and inpaint_iter == 0:
            # Trace paths only when at least one predicted edge pixel lies inside the mask.
            if ((depth_edge_output> config['ext_edge_threshold']).float() * tensor_edge_dict['mask']).max() > 0:
                try:
                    edge_dict['fpath_map'], edge_dict['npath_map'], break_flag, npaths, fpaths, invalid_edge_id = \
                        clean_far_edge_new(edge_dict['output'], end_depth_maps, edge_dict['mask'], edge_dict['context'], mesh, info_on_pix, edge_dict['self_edge'], inpaint_iter, config)
                except:
                    # NOTE(review): a bare except drops into the interactive
                    # debugger instead of raising.
                    import pdb; pdb.set_trace()
                pre_npath_map = edge_dict['npath_map'].copy()
                if config.get('repeat_inpaint_edge') is True:
                    # Re-run the edge model twice, feeding the traced npath
                    # pixels (npath_map > -1) plus known edges back in as the
                    # input edge map, then re-trace the paths.
                    # NOTE(review): unlike the first call, these forward
                    # passes are not wrapped in torch.no_grad().
                    for _ in range(2):
                        tmp_input_edge = ((edge_dict['npath_map'] > -1) + edge_dict['edge']).clip(0, 1)
                        patch_tmp_input_edge = crop_maps_by_size(union_size, tmp_input_edge)[0]
                        tensor_input_edge = torch.FloatTensor(patch_tmp_input_edge)[None, None, ...]
                        depth_edge_output = depth_edge_model.forward_3P(tensor_edge_dict['mask'],
                                                    tensor_edge_dict['context'],
                                                    tensor_edge_dict['rgb'],
                                                    tensor_edge_dict['disp'],
                                                    tensor_input_edge,
                                                    unit_length=128,
                                                    cuda=device)
                        depth_edge_output = depth_edge_output.cpu()
                        depth_edge_output = (depth_edge_output> config['ext_edge_threshold']).float() * tensor_edge_dict['mask'] + tensor_edge_dict['edge']
                        depth_edge_output = depth_edge_output.squeeze().data.cpu().numpy()
                        full_depth_edge_output = np.zeros((mesh.graph['H'], mesh.graph['W']))
                        full_depth_edge_output[union_size['x_min']:union_size['x_max'], union_size['y_min']:union_size['y_max']] = \
                            depth_edge_output
                        edge_dict['fpath_map'], edge_dict['npath_map'], break_flag, npaths, fpaths, invalid_edge_id = \
                            clean_far_edge_new(full_depth_edge_output, end_depth_maps, edge_dict['mask'], edge_dict['context'], mesh, info_on_pix, edge_dict['self_edge'], inpaint_iter, config)
                # Record each traced path under its start point. When the last
                # npath point also lies on end_depth_maps, store a reversed
                # copy under that end point as well.
                for nid in npaths.keys():
                    npath, fpath = npaths[nid], fpaths[nid]
                    start_mx, start_my, end_mx, end_my = -1, -1, -1, -1
                    if end_depth_maps[npath[0][0], npath[0][1]] != 0:
                        start_mx, start_my = npath[0][0], npath[0][1]
                    if end_depth_maps[npath[-1][0], npath[-1][1]] != 0:
                        end_mx, end_my = npath[-1][0], npath[-1][1]
                    # Debug trap: the first npath point is expected to lie on end_depth_maps.
                    if start_mx == -1:
                        import pdb; pdb.set_trace()
                    # cont_end_pts is (x, y, depth of the first info_on_pix
                    # layer) for the other end, or () when it does not touch
                    # end_depth_maps.
                    valid_end_pt = () if end_mx == -1 else (end_mx, end_my, info_on_pix[(end_mx, end_my)][0]['depth'])
                    new_edge_info = dict(fpath=fpath,
                                         npath=npath,
                                         cont_end_pts=valid_end_pt,
                                         mask_id=edge_id,
                                         comp_edge_id=nid,
                                         depth=end_depth_maps[start_mx, start_my])
                    if edges_infos.get((start_mx, start_my)) is None:
                        edges_infos[(start_mx, start_my)] = []
                    edges_infos[(start_mx, start_my)].append(new_edge_info)
                    edges_in_mask[edge_id].add((start_mx, start_my))
                    if len(valid_end_pt) > 0:
                        new_edge_info = dict(fpath=fpath[::-1],
                                             npath=npath[::-1],
                                             cont_end_pts=(start_mx, start_my, info_on_pix[(start_mx, start_my)][0]['depth']),
                                             mask_id=edge_id,
                                             comp_edge_id=nid,
                                             depth=end_depth_maps[end_mx, end_my])
                        if edges_infos.get((end_mx, end_my)) is None:
                            edges_infos[(end_mx, end_my)] = []
                        edges_infos[(end_mx, end_my)].append(new_edge_info)
                        edges_in_mask[edge_id].add((end_mx, end_my))
    # ---- Pass 2: build final path maps, inpaint depth and color, add nodes. ----
    for edge_id, (context_cc, mask_cc, erode_context_cc, extend_context_cc, edge_cc) in enumerate(zip(context_ccs, mask_ccs, erode_context_ccs, extend_context_ccs, edge_ccs)):
        if len(specific_edge_id) > 0:
            if edge_id not in specific_edge_id:
                continue
        if len(context_cc) < 1 or len(mask_cc) < 1:
            continue
        edge_dict = get_edge_from_nodes(context_cc | extend_context_cc, erode_context_cc | extend_erode_context_ccs[edge_id], mask_cc, edge_cc, extend_edge_ccs[edge_id],
                                        mesh.graph['H'], mesh.graph['W'], mesh)
        if specific_edge_loc is not None and \
            (specific_edge_loc is not None and edge_dict['mask'][specific_edge_loc[0], specific_edge_loc[1]] == 0):
            continue
        else:
            tmp_specific_edge_id.append(edge_id)
        edge_dict['edge'], end_depth_maps, _ = \
            filter_irrelevant_edge_new(edge_dict['self_edge'], edge_dict['comp_edge'],
                                    edge_map,
                                    edge_maps_with_id,
                                    edge_id,
                                    edge_dict['context'],
                                    edge_dict['depth'], mesh, context_cc | erode_context_cc | extend_context_cc | extend_erode_context_ccs[edge_id], spdb=False)
        # Pixels whose path was taken over from another component; their
        # 2-D nodes get 'inpaint_twice' = False below.
        discard_map = np.zeros_like(edge_dict['edge'])
        # Same crop window as in pass 1.
        mask_size = get_valid_size(edge_dict['mask'])
        mask_size = dilate_valid_size(mask_size, edge_dict['mask'], dilate=[20, 20])
        context_size = get_valid_size(edge_dict['context'])
        context_size = dilate_valid_size(context_size, edge_dict['context'], dilate=[20, 20])
        union_size = size_operation(mask_size, context_size, operation='+')
        patch_edge_dict = dict()
        patch_edge_dict['mask'], patch_edge_dict['context'], patch_edge_dict['rgb'], \
            patch_edge_dict['disp'], patch_edge_dict['edge'] = \
            crop_maps_by_size(union_size, edge_dict['mask'], edge_dict['context'],
                                edge_dict['rgb'], edge_dict['disp'], edge_dict['edge'])
        x_anchor, y_anchor = [union_size['x_min'], union_size['x_max']], [union_size['y_min'], union_size['y_max']]
        tensor_edge_dict = convert2tensor(patch_edge_dict)
        input_edge_feat = torch.cat((tensor_edge_dict['rgb'],
                                    tensor_edge_dict['disp'],
                                    tensor_edge_dict['edge'],
                                    1 - tensor_edge_dict['context'],
                                    tensor_edge_dict['mask']), dim=1)
        # Default edge guidance for depth inpainting: the known edges only.
        edge_dict['output'] = edge_dict['edge'].copy()

        if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) and inpaint_iter == 0:
            # Reset both path maps to all -1 (same shape), then repaint them
            # from the paths recorded in pass 1.
            edge_dict['fpath_map'], edge_dict['npath_map'] = edge_dict['fpath_map'] * 0 - 1, edge_dict['npath_map'] * 0 - 1
            end_pts = edges_in_mask[edge_id]
            for end_pt in end_pts:
                cur_edge_infos = edges_infos[(end_pt[0], end_pt[1])]
                cur_info = [xx for xx in cur_edge_infos if xx['mask_id'] == edge_id][0]
                other_infos = [xx for xx in cur_edge_infos if xx['mask_id'] != edge_id and len(xx['cont_end_pts']) > 0]
                # Paint this component's own path when it has a continuing end,
                # or when no other component has a continuing path from this point.
                if len(cur_info['cont_end_pts']) > 0 or (len(cur_info['cont_end_pts']) == 0 and len(other_infos) == 0):
                    for fnode in cur_info['fpath']:
                        edge_dict['fpath_map'][fnode[0], fnode[1]] = cur_info['comp_edge_id']
                    for fnode in cur_info['npath']:
                        edge_dict['npath_map'][fnode[0], fnode[1]] = cur_info['comp_edge_id']
                # fnmap: combined occupancy of fpath and npath labels (npath wins).
                fnmap = edge_dict['fpath_map'] * 1
                fnmap[edge_dict['npath_map'] != -1] = edge_dict['npath_map'][edge_dict['npath_map'] != -1]
            for end_pt in end_pts:
                cur_edge_infos = edges_infos[(end_pt[0], end_pt[1])]
                cur_info = [xx for xx in cur_edge_infos if xx['mask_id'] == edge_id][0]
                cur_depth = cur_info['depth']
                other_infos = [xx for xx in cur_edge_infos if xx['mask_id'] != edge_id and len(xx['cont_end_pts']) > 0]
                comp_edge_id = cur_info['comp_edge_id']
                if len(cur_info['cont_end_pts']) == 0 and len(other_infos) > 0:
                    # Borrow another component's continuing path: try them in
                    # order of |depth| closeness of their far end to cur_depth
                    # and take the first whose fpath and npath do not overlap
                    # any pixel already in fnmap.
                    other_infos = sorted(other_infos, key=lambda aa: abs(abs(aa['cont_end_pts'][2]) - abs(cur_depth)))
                    for other_info in other_infos:
                        tmp_fmap, tmp_nmap = np.zeros((mesh.graph['H'], mesh.graph['W'])) - 1, np.zeros((mesh.graph['H'], mesh.graph['W'])) - 1
                        for fnode in other_info['fpath']:
                            if fnmap[fnode[0], fnode[1]] != -1:
                                tmp_fmap = tmp_fmap * 0 - 1
                                break
                            else:
                                tmp_fmap[fnode[0], fnode[1]] = comp_edge_id
                        # `fnode` is the pixel where the loop stopped; if it is
                        # occupied the path conflicted, so try the next candidate.
                        if fnmap[fnode[0], fnode[1]] != -1:
                            continue
                        for fnode in other_info['npath']:
                            if fnmap[fnode[0], fnode[1]] != -1:
                                tmp_nmap = tmp_nmap * 0 - 1
                                break
                            else:
                                tmp_nmap[fnode[0], fnode[1]] = comp_edge_id
                        if fnmap[fnode[0], fnode[1]] != -1:
                            continue
                        break
                    # Adopt the borrowed path only if both maps got at least one label.
                    # NOTE(review): this replaces the path maps built so far
                    # instead of merging into them.
                    if min(tmp_fmap.max(), tmp_nmap.max()) != -1:
                        edge_dict['fpath_map'] = tmp_fmap
                        edge_dict['fpath_map'][edge_dict['valid_area'] == 0] = -1
                        edge_dict['npath_map'] = tmp_nmap
                        edge_dict['npath_map'][edge_dict['valid_area'] == 0] = -1
                        discard_map = ((tmp_nmap != -1).astype(np.uint8) + (tmp_fmap != -1).astype(np.uint8)) * edge_dict['mask']
                else:
                    for fnode in cur_info['fpath']:
                        edge_dict['fpath_map'][fnode[0], fnode[1]] = cur_info['comp_edge_id']
                    for fnode in cur_info['npath']:
                        edge_dict['npath_map'][fnode[0], fnode[1]] = cur_info['comp_edge_id']
            # Debug trap: stop if either path map's minimum label is 0.
            if edge_dict['npath_map'].min() == 0 or edge_dict['fpath_map'].min() == 0:
                import pdb; pdb.set_trace()
            # Edge guidance: npath pixels inside the mask plus known edges in the context.
            edge_dict['output'] = (edge_dict['npath_map'] > -1) * edge_dict['mask'] + edge_dict['context'] * edge_dict['edge']
            mesh, _, _, _ = create_placeholder(edge_dict['context'], edge_dict['mask'],
                                  edge_dict['depth'], edge_dict['fpath_map'],
                                  edge_dict['npath_map'], mesh, inpaint_iter,
                                  edge_ccs,
                                  extend_edge_ccs[edge_id],
                                  edge_maps_with_id,
                                  edge_id)

            dxs, dys = np.where(discard_map != 0)
            for dx, dy in zip(dxs, dys):
                mesh.nodes[(dx, dy)]['inpaint_twice'] = False
            # NOTE(review): as indented here, depth_dict is assigned only
            # inside this branch, yet it is read later for every component
            # (e.g. whenever inpaint_iter > 0). Confirm the intended nesting.
            depth_dict = depth_inpainting(context_cc, extend_context_cc, erode_context_cc | extend_erode_context_ccs[edge_id], mask_cc, mesh, config, union_size, depth_feat_model, edge_dict['output'])
            refine_depth_output = depth_dict['output']*depth_dict['mask']
            # Refine depth around each path label; [1:] skips the smallest
            # label (presumably the -1 background).
            for near_id in np.unique(edge_dict['npath_map'])[1:]:
                refine_depth_output = refine_depth_around_edge(refine_depth_output.copy(),
                                                            (edge_dict['fpath_map'] == near_id).astype(np.uint8) * edge_dict['mask'],
                                                            (edge_dict['fpath_map'] == near_id).astype(np.uint8),
                                                            (edge_dict['npath_map'] == near_id).astype(np.uint8) * edge_dict['mask'],
                                                            depth_dict['mask'].copy(),
                                                            depth_dict['output'] * depth_dict['context'],
                                                            config)
            depth_dict['output'][depth_dict['mask'] > 0] = refine_depth_output[depth_dict['mask'] > 0]
        rgb_dict = get_rgb_from_nodes(context_cc | extend_context_cc,
                                      erode_context_cc | extend_erode_context_ccs[edge_id], mask_cc, mesh.graph['H'], mesh.graph['W'], mesh)
        # NOTE(review): np.all returns numpy.bool_, so `is False` is never
        # true and this debug check never fires.
        if np.all(rgb_dict['mask'] == edge_dict['mask']) is False:
            import pdb; pdb.set_trace()
        rgb_dict['edge'] = edge_dict['output']
        patch_rgb_dict = dict()
        patch_rgb_dict['mask'], patch_rgb_dict['context'], patch_rgb_dict['rgb'], \
            patch_rgb_dict['edge'] = crop_maps_by_size(union_size, rgb_dict['mask'],
                                                        rgb_dict['context'], rgb_dict['rgb'],
                                                        rgb_dict['edge'])
        tensor_rgb_dict = convert2tensor(patch_rgb_dict)
        resize_rgb_dict = {k: v.clone() for k, v in tensor_rgb_dict.items()}
        # Scale factor: rescale the patch so its area would be about
        # largest_size**2, round the longer side down to a multiple of 128,
        # and express that as a fraction of the original longer side.
        max_hw = np.array([*patch_rgb_dict['mask'].shape[-2:]]).max()
        init_frac = config['largest_size'] / (np.array([*patch_rgb_dict['mask'].shape[-2:]]).prod() ** 0.5)
        resize_hw = [patch_rgb_dict['mask'].shape[-2] * init_frac, patch_rgb_dict['mask'].shape[-1] * init_frac]
        resize_max_hw = max(resize_hw)
        frac = (np.floor(resize_max_hw / 128.) * 128.) / max_hw
        # NOTE(review): as indented here, everything from the downsampling
        # through the node insertion below runs only when frac < 1. The inner
        # `if frac < 1` checks are then always true and
        # depth_edge_dilate_2_color_flag is always False. Confirm the intended
        # nesting.
        if frac < 1:
            # Area-downsample mask and context together: mask keeps any
            # covered pixel (> 0), context only fully covered ones (== 1),
            # and context is cleared wherever mask is set.
            resize_mark = torch.nn.functional.interpolate(torch.cat((resize_rgb_dict['mask'],
                                                            resize_rgb_dict['context']),
                                                            dim=1),
                                                            scale_factor=frac,
                                                            mode='area')
            resize_rgb_dict['mask'] = (resize_mark[:, 0:1] > 0).float()
            resize_rgb_dict['context'] = (resize_mark[:, 1:2] == 1).float()
            resize_rgb_dict['context'][resize_rgb_dict['mask'] > 0] = 0
            resize_rgb_dict['rgb'] = torch.nn.functional.interpolate(resize_rgb_dict['rgb'],
                                                                    scale_factor=frac,
                                                                    mode='area')
            resize_rgb_dict['rgb'] = resize_rgb_dict['rgb'] * resize_rgb_dict['context']
            resize_rgb_dict['edge'] = torch.nn.functional.interpolate(resize_rgb_dict['edge'],
                                                                    scale_factor=frac,
                                                                    mode='area')
            # The trailing `* 0` makes the edge input to rgb_model all zeros.
            resize_rgb_dict['edge'] = (resize_rgb_dict['edge'] > 0).float() * 0
            resize_rgb_dict['edge'] = resize_rgb_dict['edge'] * (resize_rgb_dict['context'] + resize_rgb_dict['mask'])
            # NOTE(review): rgb_input_feat is built but not passed to rgb_model.
            rgb_input_feat = torch.cat((resize_rgb_dict['rgb'], resize_rgb_dict['edge']), dim=1)
            rgb_input_feat[:, 3] = 1 - rgb_input_feat[:, 3]
            resize_mask = open_small_mask(resize_rgb_dict['mask'], resize_rgb_dict['context'], 3, 41)
            specified_hole = resize_mask
            with torch.no_grad():
                rgb_output = rgb_model.forward_3P(specified_hole,
                                                resize_rgb_dict['context'],
                                                resize_rgb_dict['rgb'],
                                                resize_rgb_dict['edge'],
                                                unit_length=128,
                                                cuda=device)
            rgb_output = rgb_output.cpu()
            # Optional grayscale output: channel mean repeated to 3 channels.
            if config.get('gray_image') is True:
                rgb_output = rgb_output.mean(1, keepdim=True).repeat((1,3,1,1))
            rgb_output = rgb_output.cpu()
            # Network output inside the mask, known (context-masked) rgb elsewhere.
            resize_rgb_dict['output'] = rgb_output * resize_rgb_dict['mask'] + resize_rgb_dict['rgb']
            tensor_rgb_dict['output'] = resize_rgb_dict['output']
            if frac < 1:
                # Upsample back to the patch size and re-composite against the known colors.
                tensor_rgb_dict['output'] = torch.nn.functional.interpolate(tensor_rgb_dict['output'],
                                                                            size=tensor_rgb_dict['mask'].shape[-2:],
                                                                            mode='bicubic')
                tensor_rgb_dict['output'] = tensor_rgb_dict['output'] * \
                                            tensor_rgb_dict['mask'] + (tensor_rgb_dict['rgb'] * tensor_rgb_dict['context'])
            patch_rgb_dict['output'] = tensor_rgb_dict['output'].data.cpu().numpy().squeeze().transpose(1,2,0)
            # Paste the patch colors back into a full-size (H, W, 3) map.
            rgb_dict['output'] = np.zeros((mesh.graph['H'], mesh.graph['W'], 3))
            rgb_dict['output'][union_size['x_min']:union_size['x_max'], union_size['y_min']:union_size['y_max']] = \
                patch_rgb_dict['output']

            if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) or inpaint_iter > 0:
                edge_occlusion = True
            else:
                edge_occlusion = False
            # Stash new colors for eroded-context nodes inside the mask; they
            # are applied after both passes (see the erode_context_ccs loop below).
            for node in erode_context_cc:
                if rgb_dict['mask'][node[0], node[1]] > 0:
                    for info in info_on_pix[(node[0], node[1])]:
                        if abs(info['depth']) == abs(node[2]):
                            info['update_color'] = (rgb_dict['output'][node[0], node[1]] * 255).astype(np.uint8)
            if frac < 1.:
                depth_edge_dilate_2_color_flag = False
            else:
                depth_edge_dilate_2_color_flag = True
            # Create one new layer node per mask pixel that is not in the erode map.
            hxs, hys = np.where((rgb_dict['mask'] > 0) & (rgb_dict['erode'] == 0))
            for hx, hy in zip(hxs, hys):
                real_depth = None
                # If the inpainted |depth| does not exceed the original
                # |depth| here, replace it with np_depth + 0.01.
                if abs(depth_dict['output'][hx, hy]) <= abs(np_depth[hx, hy]):
                    depth_dict['output'][hx, hy] = np_depth[hx, hy] + 0.01
                node = (hx, hy, -depth_dict['output'][hx, hy])
                if info_on_pix.get((node[0], node[1])) is not None:
                    for info in info_on_pix.get((node[0], node[1])):
                        # NOTE(review): abs() wraps the boolean comparison, so
                        # this behaves like `info['inpaint_id'] < ...`; possibly
                        # `abs(info['inpaint_id']) < ...` was intended.
                        if info.get('inpaint_id') is None or abs(info['inpaint_id'] < mesh.nodes[(hx, hy)]['inpaint_id']):
                            pre_depth = info['depth'] if info.get('real_depth') is None else info['real_depth']
                            # Push the new node 0.001 beyond any such existing layer.
                            if abs(node[2]) < abs(pre_depth):
                                node = (node[0], node[1], -(abs(pre_depth) + 0.001))
                if mesh.has_node(node):
                    real_depth = node[2]
                # Step |depth| by 0.001 until the key is unused, so the new
                # node does not collide with an existing one.
                while True:
                    if mesh.has_node(node):
                        node = (node[0], node[1], -(abs(node[2]) + 0.001))
                    else:
                        break
                if real_depth == node[2]:
                    real_depth = None
                cur_disp = 1./node[2]
                if not(mesh.has_node(node)):
                    # Debug trap: the 2-D (x, y) node is needed by refresh_node below.
                    if not mesh.has_node((node[0], node[1])):
                        print("2D node not found.")
                        import pdb; pdb.set_trace()
                    # Both branches compute the same paint value.
                    if inpaint_iter == 1:
                        paint = (rgb_dict['output'][hx, hy] * 255).astype(np.uint8)
                    else:
                        paint = (rgb_dict['output'][hx, hy] * 255).astype(np.uint8)
                    ndict = dict(color=paint,
                                    synthesis=True,
                                    disp=cur_disp,
                                    cc_id=set([edge_id]),
                                    overlap_number=1.0,
                                    refine_depth=False,
                                    edge_occlusion=edge_occlusion,
                                    depth_edge_dilate_2_color_flag=depth_edge_dilate_2_color_flag,
                                    real_depth=real_depth)
                    mesh, _, _ = refresh_node((node[0], node[1]), mesh.nodes[(node[0], node[1])], node, ndict, mesh, stime=True)
                    # In round 0, remember new nodes with fewer than 4 neighbors.
                    if inpaint_iter == 0 and mesh.degree(node) < 4:
                        connnect_points_ccs[edge_id].add(node)
                if info_on_pix.get((hx, hy)) is None:
                    info_on_pix[(hx, hy)] = []
                new_info = {'depth':node[2],
                            'color': paint,
                            'synthesis':True,
                            'disp':cur_disp,
                            'cc_id':set([edge_id]),
                            'inpaint_id':inpaint_iter + 1,
                            'edge_occlusion':edge_occlusion,
                            'overlap_number':1.0,
                            'real_depth': real_depth}
                info_on_pix[(hx, hy)].append(new_info)
    # Apply the stashed update_color to eroded-context nodes of the processed
    # components (info_on_pix, mesh attribute and the numpy image copy).
    specific_edge_id = tmp_specific_edge_id
    for erode_id, erode_context_cc in enumerate(erode_context_ccs):
        if len(specific_edge_id) > 0 and erode_id not in specific_edge_id:
            continue
        for erode_node in erode_context_cc:
            for info in info_on_pix[(erode_node[0], erode_node[1])]:
                if info['depth'] == erode_node[2]:
                    info['color'] = info['update_color']
                    mesh.nodes[erode_node]['color'] = info['update_color']
                    np_image[(erode_node[0], erode_node[1])] = info['update_color']
    # Regroup nodes created in this round (inpaint_id == inpaint_iter + 1) by
    # 'edge_id', skipping nodes flagged inpaint_twice=False.
    new_edge_ccs = [set() for _ in range(mesh.graph['max_edge_id'] + 1)]
    # NOTE(review): nodes are removed while iterating mesh.nodes; with
    # networkx this typically raises RuntimeError if any 2-tuple node is
    # present. Confirm none can remain at this point.
    for node in mesh.nodes:
        if len(node) == 2:
            mesh.remove_node(node)
            continue
        if mesh.nodes[node].get('edge_id') is not None and mesh.nodes[node].get('inpaint_id') == inpaint_iter + 1:
            if mesh.nodes[node].get('inpaint_twice') is False:
                continue
            try:
                new_edge_ccs[mesh.nodes[node].get('edge_id')].add(node)
            except:
                import pdb; pdb.set_trace()
    specific_mask_nodes = None
    if inpaint_iter == 0:
        mesh, info_on_pix = refine_color_around_edge(mesh, info_on_pix, new_edge_ccs, config, False)

    return mesh, info_on_pix, specific_mask_nodes, new_edge_ccs, connnect_points_ccs, np_image
|
|
|
|
|
def write_ply(image,
              depth,
              int_mtx,
              ply_name,
              config,
              rgb_model,
              depth_edge_model,
              depth_edge_model_init,
              depth_feat_model):
    """Build the layered mesh for one image, inpaint it, and export it.

    Steps (the helpers are defined elsewhere in this module or in
    mesh_tools):

    1. ``create_mesh``, followed by a series of edge tearing, grouping and
       cleanup passes.
    2. When ``config['extrapolate_border']`` is True, border enlargement and
       ``extrapolate`` in eight directions.
    3. ``context_and_holes`` + ``DL_inpaint_edge`` with ``inpaint_iter=0``,
       then again with ``inpaint_iter=1``.
    4. Every entry of ``info_on_pix`` becomes a vertex, and ``generate_face``
       produces the faces.

    Args:
        image: Input image array. It is replaced by the border-enlarged image
            when extrapolating.
        depth: Depth map; cast to float64 before use.
        int_mtx: Intrinsic matrix, forwarded to ``create_mesh`` (which stores
            it as ``cam_param``).
        ply_name: Output path, used only when ``config['save_ply']`` is True.
        config: Dict. Keys read directly here: ``'depth_threshold'``,
            ``'extrapolate_border'``, ``'redundant_number'``, ``'save_ply'``.
            The helpers read further keys.
        rgb_model, depth_edge_model, depth_edge_model_init, depth_feat_model:
            Networks forwarded to the inpainting helpers.

    Returns:
        If ``config['save_ply']`` is True: the final mesh graph, after writing
        ``ply_name``. Otherwise: the tuple ``(node_str_point, node_str_color,
        str_faces, H, W, hFov, vFov)``, where the RGB part of
        ``node_str_color`` is scaled to [0, 1].
        Returns ``False`` as soon as an ``info_on_pix`` entry has no matching
        ``(x, y, depth)`` mesh node.
    """
    # NOTE(review): several locals below are assigned but never read here:
    # edge_canvas, mask, context, pre_depth, aft_depth, aft_mark, noext_H/W,
    # FG_edge_maps, mask_canvas, context_canvas, erode_context_ccs_canvas,
    # background_canvas, the *_time counters and point_list.
    depth = depth.astype(np.float64)
    input_mesh, xy2depth, image, depth = create_mesh(depth, image, int_mtx, config)

    H, W = input_mesh.graph['H'], input_mesh.graph['W']
    input_mesh = tear_edges(input_mesh, config['depth_threshold'], xy2depth)
    input_mesh, info_on_pix = generate_init_node(input_mesh, config, min_node_in_cc=200)
    edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=False)
    edge_canvas = np.zeros((H, W)) - 1

    input_mesh, info_on_pix, depth = reassign_floating_island(input_mesh, info_on_pix, image, depth)
    # NOTE(review): update_status is called here with two arguments and a
    # single return value, but later with three arguments and three return
    # values. Confirm that both call forms are supported.
    input_mesh = update_status(input_mesh, info_on_pix)
    specific_edge_id = []
    edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
    pre_depth = depth.copy()
    input_mesh, info_on_pix, edge_mesh, depth, aft_mark = remove_dangling(input_mesh, edge_ccs, edge_mesh, info_on_pix, image, depth, config)

    input_mesh, depth, info_on_pix = update_status(input_mesh, info_on_pix, depth)
    edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
    edge_canvas = np.zeros((H, W)) - 1

    # NOTE(review): the returned graph is bound to `mesh`, which is never
    # read; later code keeps using `input_mesh`. Confirm that
    # fill_missing_node mutates its argument in place.
    mesh, info_on_pix, depth = fill_missing_node(input_mesh, info_on_pix, image, depth)
    if config['extrapolate_border'] is True:
        # Enlarge the canvas, then re-run the edge cleanup passes on the
        # enlarged mesh before extrapolating in eight directions.
        pre_depth = depth.copy()
        input_mesh, info_on_pix, depth = refresh_bord_depth(input_mesh, info_on_pix, image, depth)
        input_mesh = remove_node_feat(input_mesh, 'edge_id')
        aft_depth = depth.copy()
        input_mesh, info_on_pix, depth, image = enlarge_border(input_mesh, info_on_pix, depth, image, config)
        noext_H, noext_W = H, W
        H, W = image.shape[:2]
        input_mesh, info_on_pix = fill_dummy_bord(input_mesh, info_on_pix, image, depth, config)
        edge_ccs, input_mesh, edge_mesh = \
            group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
        input_mesh = combine_end_node(input_mesh, edge_mesh, edge_ccs, depth)
        input_mesh, depth, info_on_pix = update_status(input_mesh, info_on_pix, depth)
        edge_ccs, input_mesh, edge_mesh = \
            group_edges(input_mesh, config, image, remove_conflict_ordinal=True, spdb=False)
        input_mesh = remove_redundant_edge(input_mesh, edge_mesh, edge_ccs, info_on_pix, config, redundant_number=config['redundant_number'], spdb=False)
        input_mesh, depth, info_on_pix = update_status(input_mesh, info_on_pix, depth)
        edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
        input_mesh = combine_end_node(input_mesh, edge_mesh, edge_ccs, depth)
        input_mesh = remove_redundant_edge(input_mesh, edge_mesh, edge_ccs, info_on_pix, config, redundant_number=config['redundant_number'], invalid=True, spdb=False)
        input_mesh, depth, info_on_pix = update_status(input_mesh, info_on_pix, depth)
        edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
        input_mesh = combine_end_node(input_mesh, edge_mesh, edge_ccs, depth)
        input_mesh, depth, info_on_pix = update_status(input_mesh, info_on_pix, depth)
        edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
        # A node is drawn into edge_map when it carries a non-empty 'far' list.
        edge_condition = lambda x, m: m.nodes[x].get('far') is not None and len(m.nodes[x].get('far')) > 0
        edge_map = get_map_from_ccs(edge_ccs, input_mesh.graph['H'], input_mesh.graph['W'], input_mesh, edge_condition)
        other_edge_with_id = get_map_from_ccs(edge_ccs, input_mesh.graph['H'], input_mesh.graph['W'], real_id=True)
        info_on_pix, input_mesh, image, depth, edge_ccs = extrapolate(input_mesh, info_on_pix, image, depth, other_edge_with_id, edge_map, edge_ccs,
                                                depth_edge_model, depth_feat_model, rgb_model, config, direc="up")
        info_on_pix, input_mesh, image, depth, edge_ccs = extrapolate(input_mesh, info_on_pix, image, depth, other_edge_with_id, edge_map, edge_ccs,
                                                depth_edge_model, depth_feat_model, rgb_model, config, direc="left")
        info_on_pix, input_mesh, image, depth, edge_ccs = extrapolate(input_mesh, info_on_pix, image, depth, other_edge_with_id, edge_map, edge_ccs,
                                                depth_edge_model, depth_feat_model, rgb_model, config, direc="down")
        info_on_pix, input_mesh, image, depth, edge_ccs = extrapolate(input_mesh, info_on_pix, image, depth, other_edge_with_id, edge_map, edge_ccs,
                                                depth_edge_model, depth_feat_model, rgb_model, config, direc="right")
        info_on_pix, input_mesh, image, depth, edge_ccs = extrapolate(input_mesh, info_on_pix, image, depth, other_edge_with_id, edge_map, edge_ccs,
                                                depth_edge_model, depth_feat_model, rgb_model, config, direc="right-up")
        info_on_pix, input_mesh, image, depth, edge_ccs = extrapolate(input_mesh, info_on_pix, image, depth, other_edge_with_id, edge_map, edge_ccs,
                                                depth_edge_model, depth_feat_model, rgb_model, config, direc="right-down")
        info_on_pix, input_mesh, image, depth, edge_ccs = extrapolate(input_mesh, info_on_pix, image, depth, other_edge_with_id, edge_map, edge_ccs,
                                                depth_edge_model, depth_feat_model, rgb_model, config, direc="left-up")
        info_on_pix, input_mesh, image, depth, edge_ccs = extrapolate(input_mesh, info_on_pix, image, depth, other_edge_with_id, edge_map, edge_ccs,
                                                depth_edge_model, depth_feat_model, rgb_model, config, direc="left-down")
    # ---- Inpainting round 0: context/hole sets from the grouped edges. ----
    specific_edge_loc = None
    specific_edge_id = []
    vis_edge_id = None
    context_ccs, mask_ccs, broken_mask_ccs, edge_ccs, erode_context_ccs, \
        init_mask_connect, edge_maps, extend_context_ccs, extend_edge_ccs, extend_erode_context_ccs = \
            context_and_holes(input_mesh,
                              edge_ccs,
                              config,
                              specific_edge_id,
                              specific_edge_loc,
                              depth_feat_model,
                              inpaint_iter=0,
                              vis_edge_id=vis_edge_id)
    edge_canvas = np.zeros((H, W))
    mask = np.zeros((H, W))
    context = np.zeros((H, W))
    vis_edge_ccs = filter_edge(input_mesh, edge_ccs, config)
    edge_canvas = np.zeros((input_mesh.graph['H'], input_mesh.graph['W'])) - 1
    specific_edge_loc = None
    FG_edge_maps = edge_maps.copy()
    edge_canvas = np.zeros((input_mesh.graph['H'], input_mesh.graph['W'])) - 1

    input_mesh, info_on_pix, specific_edge_nodes, new_edge_ccs, connect_points_ccs, image = DL_inpaint_edge(input_mesh,
                                                                                    info_on_pix,
                                                                                    config,
                                                                                    image,
                                                                                    depth,
                                                                                    context_ccs,
                                                                                    erode_context_ccs,
                                                                                    extend_context_ccs,
                                                                                    extend_erode_context_ccs,
                                                                                    mask_ccs,
                                                                                    broken_mask_ccs,
                                                                                    edge_ccs,
                                                                                    extend_edge_ccs,
                                                                                    init_mask_connect,
                                                                                    edge_maps,
                                                                                    rgb_model,
                                                                                    depth_edge_model,
                                                                                    depth_edge_model_init,
                                                                                    depth_feat_model,
                                                                                    specific_edge_id,
                                                                                    specific_edge_loc,
                                                                                    inpaint_iter=0)
    # ---- Inpainting round 1: holes around the edges created in round 0. ----
    specific_edge_id = []
    edge_canvas = np.zeros((input_mesh.graph['H'], input_mesh.graph['W']))
    connect_points_ccs = [set() for _ in connect_points_ccs]
    context_ccs, mask_ccs, broken_mask_ccs, edge_ccs, erode_context_ccs, init_mask_connect, \
        edge_maps, extend_context_ccs, extend_edge_ccs, extend_erode_context_ccs = \
            context_and_holes(input_mesh, new_edge_ccs, config, specific_edge_id, specific_edge_loc, depth_feat_model, connect_points_ccs, inpaint_iter=1)
    mask_canvas = np.zeros((input_mesh.graph['H'], input_mesh.graph['W']))
    context_canvas = np.zeros((input_mesh.graph['H'], input_mesh.graph['W']))
    erode_context_ccs_canvas = np.zeros((input_mesh.graph['H'], input_mesh.graph['W']))
    edge_canvas = np.zeros((input_mesh.graph['H'], input_mesh.graph['W']))

    specific_edge_id = []
    input_mesh, info_on_pix, specific_edge_nodes, new_edge_ccs, _, image = DL_inpaint_edge(input_mesh,
                                                                                info_on_pix,
                                                                                config,
                                                                                image,
                                                                                depth,
                                                                                context_ccs,
                                                                                erode_context_ccs,
                                                                                extend_context_ccs,
                                                                                extend_erode_context_ccs,
                                                                                mask_ccs,
                                                                                broken_mask_ccs,
                                                                                edge_ccs,
                                                                                extend_edge_ccs,
                                                                                init_mask_connect,
                                                                                edge_maps,
                                                                                rgb_model,
                                                                                depth_edge_model,
                                                                                depth_edge_model_init,
                                                                                depth_feat_model,
                                                                                specific_edge_id,
                                                                                specific_edge_loc,
                                                                                inpaint_iter=1)
    # ---- Serialization: one vertex per info_on_pix layer, then faces. ----
    vertex_id = 0
    # The exported H/W are the pre-extrapolation sizes.
    input_mesh.graph['H'], input_mesh.graph['W'] = input_mesh.graph['noext_H'], input_mesh.graph['noext_W']
    background_canvas = np.zeros((input_mesh.graph['H'],
                                  input_mesh.graph['W'],
                                  3))
    # NOTE(review): ply_flag uses config.get('save_ply'), but the export
    # branch below indexes config['save_ply'] directly (KeyError if absent).
    ply_flag = config.get('save_ply')
    if ply_flag is True:
        node_str_list = []
    else:
        node_str_color = []
        node_str_point = []
    # In PLY mode every field is stringified; otherwise raw values are kept.
    out_fmt = lambda x, x_flag: str(x) if x_flag is True else x
    point_time = 0
    hlight_time = 0
    cur_id_time = 0
    node_str_time = 0
    generate_face_time = 0
    point_list = []
    # Entries of the inverse pixel-space camera matrix used for reprojection.
    k_00, k_02, k_11, k_12 = \
        input_mesh.graph['cam_param_pix_inv'][0, 0], input_mesh.graph['cam_param_pix_inv'][0, 2], \
        input_mesh.graph['cam_param_pix_inv'][1, 1], input_mesh.graph['cam_param_pix_inv'][1, 2]
    w_offset = input_mesh.graph['woffset']
    h_offset = input_mesh.graph['hoffset']
    for pix_xy, pix_list in info_on_pix.items():
        for pix_idx, pix_info in enumerate(pix_list):
            # Position uses 'real_depth' when present, otherwise 'depth'.
            pix_depth = pix_info['depth'] if pix_info.get('real_depth') is None else pix_info['real_depth']
            str_pt = [out_fmt(x, ply_flag) for x in reproject_3d_int_detail(pix_xy[0], pix_xy[1], pix_depth,
                      k_00, k_02, k_11, k_12, w_offset, h_offset)]
            # NOTE(review): the whole export aborts with False on the first
            # missing node; the `continue` after `return` is unreachable.
            if input_mesh.has_node((pix_xy[0], pix_xy[1], pix_info['depth'])) is False:
                return False
                continue
            # Color is divided by 'overlap_number' when that key is present.
            if pix_info.get('overlap_number') is not None:
                str_color = [out_fmt(x, ply_flag) for x in (pix_info['color']/pix_info['overlap_number']).astype(np.uint8).tolist()]
            else:
                str_color = [out_fmt(x, ply_flag) for x in pix_info['color'].tolist()]
            # Fourth color column (written as PLY 'alpha') is a per-vertex code:
            # 4 = edge_occlusion, 1 = no inpaint_id, inpaint_id + 1 otherwise;
            # overridden by 5 for modified_border / ext_pixel entries.
            if pix_info.get('edge_occlusion') is True:
                str_color.append(out_fmt(4, ply_flag))
            else:
                if pix_info.get('inpaint_id') is None:
                    str_color.append(out_fmt(1, ply_flag))
                else:
                    str_color.append(out_fmt(pix_info.get('inpaint_id') + 1, ply_flag))
            if pix_info.get('modified_border') is True or pix_info.get('ext_pixel') is True:
                if len(str_color) == 4:
                    str_color[-1] = out_fmt(5, ply_flag)
                else:
                    str_color.append(out_fmt(5, ply_flag))
            # Record the vertex index on both the pixel record and the mesh node.
            pix_info['cur_id'] = vertex_id
            input_mesh.nodes[(pix_xy[0], pix_xy[1], pix_info['depth'])]['cur_id'] = out_fmt(vertex_id, ply_flag)
            vertex_id += 1
            if ply_flag is True:
                node_str_list.append(' '.join(str_pt) + ' ' + ' '.join(str_color) + '\n')
            else:
                node_str_color.append(str_color)
                node_str_point.append(str_pt)
    str_faces = generate_face(input_mesh, info_on_pix, config)
    if config['save_ply'] is True:
        print("Writing mesh file %s ..." % ply_name)
        with open(ply_name, 'w') as ply_fi:
            ply_fi.write('ply\n' + 'format ascii 1.0\n')
            ply_fi.write('comment H ' + str(int(input_mesh.graph['H'])) + '\n')
            ply_fi.write('comment W ' + str(int(input_mesh.graph['W'])) + '\n')
            ply_fi.write('comment hFov ' + str(float(input_mesh.graph['hFov'])) + '\n')
            ply_fi.write('comment vFov ' + str(float(input_mesh.graph['vFov'])) + '\n')
            ply_fi.write('element vertex ' + str(len(node_str_list)) + '\n')
            ply_fi.write('property float x\n' + \
                         'property float y\n' + \
                         'property float z\n' + \
                         'property uchar red\n' + \
                         'property uchar green\n' + \
                         'property uchar blue\n' + \
                         'property uchar alpha\n')
            ply_fi.write('element face ' + str(len(str_faces)) + '\n')
            ply_fi.write('property list uchar int vertex_index\n')
            ply_fi.write('end_header\n')
            ply_fi.writelines(node_str_list)
            ply_fi.writelines(str_faces)
            # NOTE(review): redundant; the `with` block already closes the file.
            ply_fi.close()
        return input_mesh
    else:
        H = int(input_mesh.graph['H'])
        W = int(input_mesh.graph['W'])
        hFov = input_mesh.graph['hFov']
        vFov = input_mesh.graph['vFov']
        node_str_color = np.array(node_str_color).astype(np.float32)
        # Scale only the RGB columns to [0, 1]; the code column stays raw.
        node_str_color[..., :3] = node_str_color[..., :3] / 255.
        node_str_point = np.array(node_str_point)
        str_faces = np.array(str_faces)

        return node_str_point, node_str_color, str_faces, H, W, hFov, vFov
|
|
|
def read_ply(mesh_fi, default_alpha=255.):
    """Load an ASCII PLY mesh together with the metadata comments stored in its header.

    The header is scanned for ``element vertex <N>`` and for the comment lines
    ``comment H <int>``, ``comment W <int>``, ``comment hFov <float>`` and
    ``comment vFov <float>``. After ``end_header`` the first ``N`` lines are
    vertices (``x y z r g b [a]``) and every remaining non-blank line is a
    triangle face (``<count> v1 v2 v3``).

    Args:
        mesh_fi: path of the ``.ply`` file.
        default_alpha: value stored in the fourth colour column for vertex lines
            that carry only six values (no fourth colour channel). Previously
            such lines raised ``NameError`` or silently reused the previous
            vertex's value.

    Returns:
        ``(verts, colors, faces, Height, Width, hFov, vFov)``. Here ``verts`` is an
        ``(N, 3)`` float array. ``colors`` is an ``(N, 4)`` float array whose first
        three columns are divided by 255; the fourth column is left as read.
        ``faces`` is an ``(M, 3)`` int array. ``Height``/``Width``/``hFov``/``vFov``
        are ``None`` when the corresponding comment is absent.

    Raises:
        ValueError: if the file ends before an ``end_header`` line, or the header
            declares no ``element vertex`` count.
    """
    Height = None
    Width = None
    hFov = None
    vFov = None
    num_vertex = None
    with open(mesh_fi, 'r') as ply_fi:
        while True:
            raw = ply_fi.readline()
            if not raw:
                # readline() returns '' only at EOF; the old loop spun forever here.
                raise ValueError("PLY header of %s has no 'end_header' line" % mesh_fi)
            tokens = raw.split()
            if not tokens:
                continue
            if tokens[0] == 'end_header':
                break
            if tokens[0] == 'element' and len(tokens) >= 3 and tokens[1] == 'vertex':
                num_vertex = int(tokens[-1])
            elif tokens[0] == 'comment' and len(tokens) >= 3:
                key, value = tokens[1], tokens[-1]
                if key == 'H':
                    Height = int(value)
                elif key == 'W':
                    Width = int(value)
                elif key == 'hFov':
                    hFov = float(value)
                elif key == 'vFov':
                    vFov = float(value)
            # 'element face' is not needed: faces are simply everything after the vertices.
        contents = ply_fi.readlines()

    if num_vertex is None:
        raise ValueError("PLY header of %s declares no 'element vertex' count" % mesh_fi)

    verts = []
    colors = []
    for v_line in contents[:num_vertex]:
        values = [float(v) for v in v_line.split()]
        if len(values) == 6:
            values.append(float(default_alpha))
        vx, vy, vz, r, g, b, alpha = values
        verts.append([vx, vy, vz])
        colors.append([r, g, b, alpha])

    faces = []
    for f_line in contents[num_vertex:]:
        f_tokens = f_line.split()
        if not f_tokens:
            continue  # tolerate blank / trailing lines after the face list
        _, v1, v2, v3 = [int(f) for f in f_tokens]
        faces.append([v1, v2, v3])

    # reshape keeps the documented 2-D shapes even for an empty mesh.
    verts = np.array(verts, dtype=np.float64).reshape(-1, 3)
    colors = np.array(colors, dtype=np.float64).reshape(-1, 4)
    colors[:, :3] /= 255.
    faces = np.array(faces, dtype=int).reshape(-1, 3)

    return verts, colors, faces, Height, Width, hFov, vFov
|
|
|
|
|
class Canvas_view():
    """Render a vertex-coloured triangle mesh with vispy from a movable camera.

    Wraps a ``scene.SceneCanvas`` that holds one view, a ``'perspective'`` camera
    and a single unshaded ``visuals.Mesh``. The camera pose is changed by
    composing translations and rotations directly onto
    ``self.view.camera.transform`` through ``translate``/``rotate``. The callers
    in this file invoke ``view_changed()`` after changing the pose and before
    ``render()``.
    """

    def __init__(self,
                 fov,
                 verts,
                 faces,
                 colors,
                 canvas_size,
                 factor=1,
                 bgcolor='gray',
                 proj='perspective',
                 ):
        """Build the canvas, camera and mesh visual.

        Args:
            fov: camera field of view in degrees.
            verts, faces: mesh vertex positions and triangle indices.
            colors: per-vertex colours; only the first three columns are used.
            canvas_size: side length of the square canvas before ``factor``.
            factor: multiplier on ``canvas_size`` (render supersampling).
            bgcolor: canvas background colour.
            proj: NOTE(review): accepted but ignored; the camera is always
                ``'perspective'``.
        """
        self.canvas = scene.SceneCanvas(bgcolor=bgcolor, size=(canvas_size*factor, canvas_size*factor))
        self.view = self.canvas.central_widget.add_view()
        self.view.camera = 'perspective'
        self.view.camera.fov = fov
        self.mesh = visuals.Mesh(shading=None)
        self.mesh.attach(Alpha(1.0))
        self.view.add(self.mesh)
        # Keep a handle on the camera transform; translate/rotate edit it in place.
        self.tr = self.view.camera.transform
        self.mesh.set_data(vertices=verts, faces=faces, vertex_colors=colors[:, :3])
        self.translate([0, 0, 0])
        # Flip 180 degrees about x; presumably this maps the mesh's camera
        # convention onto vispy's (TODO confirm).
        self.rotate(axis=[1, 0, 0], angle=180)
        self.view_changed()

    def translate(self, trans=(0, 0, 0)):
        """Compose a translation by ``trans`` onto the camera transform."""
        self.tr.translate(trans)

    def rotate(self, axis=(1, 0, 0), angle=0):
        """Compose a rotation of ``angle`` degrees about ``axis`` onto the camera transform."""
        self.tr.rotate(axis=axis, angle=angle)

    def view_changed(self):
        """Notify the camera that its transform or parameters changed."""
        self.view.camera.view_changed()

    def render(self):
        """Render the canvas and return the image produced by ``SceneCanvas.render``."""
        return self.canvas.render()

    def reinit_mesh(self, verts, faces, colors):
        """Replace the mesh data; only the first three colour columns are used."""
        self.mesh.set_data(vertices=verts, faces=faces, vertex_colors=colors[:, :3])

    def reinit_camera(self, fov):
        """Set the camera field of view (degrees) and refresh the view."""
        self.view.camera.fov = fov
        self.view.camera.view_changed()
|
|
|
|
|
def _compute_view_anchor(img_shape, aspect_ratio, ref_h, ref_w):
    """Return ``[top, bottom, left, right]`` of the centred band of a square render that holds the image.

    For a portrait image (``aspect_ratio > 1``) all rows are kept and a centred
    column band of width ``ref_h / aspect_ratio`` is selected. Otherwise all
    columns are kept and a centred row band of height ``ref_w * aspect_ratio``
    is selected. The far edge is clamped to ``size - 1``.
    """
    rows, cols = img_shape[0], img_shape[1]
    if aspect_ratio > 1:
        band = ref_h / aspect_ratio
        anchor = [0,
                  rows,
                  int(max(0, int(cols // 2 - band // 2))),
                  int(min(int(cols // 2 + band // 2), cols - 1))]
    else:
        band = ref_w * aspect_ratio
        anchor = [int(max(0, int(rows // 2 - band // 2))),
                  int(min(int(rows // 2 + band // 2), rows - 1)),
                  0,
                  cols]
    return np.array(anchor)


def _postprocess_frame(img, init_factor, anchor, border, crop_border):
    """Turn one raw canvas render into an RGB video frame.

    The frame is processed in this order:
    1. Gaussian-blur with an odd ``init_factor // 2 * 2 + 1`` kernel (1x1 when
       ``init_factor`` is 1).
    2. Downsample by ``init_factor`` with area interpolation.
    3. Crop to ``anchor`` and then to ``border``.
    4. When any ``crop_border`` fraction ``[top, left, bottom, right]`` is
       positive, cut those fractions off and resize back to the pre-cut size.
    5. Keep only the first three channels.
    """
    ksize = int(init_factor // 2 * 2 + 1)
    img = cv2.GaussianBlur(img, (ksize, ksize), 0)
    img = cv2.resize(img, (int(img.shape[1] / init_factor), int(img.shape[0] / init_factor)),
                     interpolation=cv2.INTER_AREA)
    img = img[anchor[0]:anchor[1], anchor[2]:anchor[3]]
    img = img[int(border[0]):int(border[1]), int(border[2]):int(border[3])]
    if any(np.array(crop_border) > 0.0):
        H_c, W_c, _ = img.shape
        o_t = int(H_c * crop_border[0])
        o_l = int(W_c * crop_border[1])
        o_b = int(H_c * crop_border[2])
        o_r = int(W_c * crop_border[3])
        img = img[o_t:H_c - o_b, o_l:W_c - o_r]
        img = cv2.resize(img, (W_c, H_c), interpolation=cv2.INTER_CUBIC)
    return img[..., :3]


def output_3d_photo(verts, colors, faces, Height, Width, hFov, vFov, tgt_poses, video_traj_types, ref_pose,
                    output_dir, ref_image, int_mtx, config, image, videos_poses, video_basename, original_H=None, original_W=None,
                    border=None, depth=None, normal_canvas=None, all_canvas=None, mean_loc_depth=None):
    """Render one MP4 video per camera trajectory from a coloured mesh.

    The canvas is square with side ``max(ref_h, ref_w) * init_factor``.
    ``init_factor`` is 3 when ``config['anti_flickering']`` is True, else 1.
    Every frame is post-processed by ``_postprocess_frame``. All frames of a
    trajectory are then trimmed to even height/width and written with
    ``ImageSequenceClip``.

    Args:
        verts, colors, faces: mesh arrays; only the first three colour columns are used.
        Height, Width: mesh image size. They convert ``int_mtx`` rows 0/1 to
            pixel units and are the fallbacks for ``original_H``/``original_W``.
        hFov, vFov, tgt_poses, ref_image, image, depth: unused here; kept for
            interface compatibility. The FoV is recomputed from ``int_mtx``.
        video_traj_types: one name per trajectory. The name is used in the
            output file name, and any name containing ``'dolly'`` switches to the
            dolly-zoom FoV.
        ref_pose: reference camera pose. Each pose ``tp`` is applied as
            ``rel_pose = inv(tp @ inv(ref_pose))``. Its rotation (axis-angle) and
            translation are composed onto the camera for that frame and then
            undone. Presumably 4x4 homogeneous matrices (TODO confirm).
        output_dir: directory for ``<video_basename>_<traj_type>.mp4``.
        int_mtx: intrinsics normalised by image size.
        config: needs ``'crop_border'`` (fractions ``[top, left, bottom, right]``)
            and ``'fps'``; ``'anti_flickering'`` is optional.
        videos_poses: one sequence of poses per trajectory.
        video_basename: output file stem, or a list whose first item is used.
        original_H, original_W: un-padded image size, each defaulting to
            ``Height``/``Width`` when None.
        border: ``[top, bottom, left, right]`` pixel crop applied after the
            anchor crop; defaults to the whole frame.
        normal_canvas: ``Canvas_view`` to reuse (its mesh and camera are reset),
            or None to create one.
        all_canvas: returned unchanged.
        mean_loc_depth: reference scene depth; required. Its magnitude times
            ``tan(fov / 2)`` fixes the plane width kept constant by dolly shots.

    Returns:
        ``(normal_canvas, all_canvas)``.

    Raises:
        TypeError: if ``mean_loc_depth`` is None.
    """
    if mean_loc_depth is None:
        raise TypeError("output_3d_photo: mean_loc_depth is required")

    # Per-dimension fallback: previously a None original size crashed the FoV
    # computation below before the later "is None" branches could ever run.
    ref_h = original_H if original_H is not None else Height
    ref_w = original_W if original_W is not None else Width

    # int_mtx is normalised by image size, so scaling rows 0/1 gives pixel focal lengths.
    int_mtx_real_x = int_mtx[0] * Width
    int_mtx_real_y = int_mtx[1] * Height
    h_fov_rad = 2 * np.arctan((1. / 2.) * (ref_w / int_mtx_real_x[0]))
    v_fov_rad = 2 * np.arctan((1. / 2.) * (ref_h / int_mtx_real_y[1]))
    colors = colors[..., :3]

    fov_in_rad = max(v_fov_rad, h_fov_rad)
    fov = (fov_in_rad * 180 / np.pi)
    print("fov: " + str(fov))

    init_factor = 3 if config.get('anti_flickering') is True else 1
    canvas_size = max(ref_h, ref_w)
    if normal_canvas is None:
        normal_canvas = Canvas_view(fov,
                                    verts,
                                    faces,
                                    colors,
                                    canvas_size=canvas_size,
                                    factor=init_factor,
                                    bgcolor='gray',
                                    proj='perspective')
    else:
        normal_canvas.reinit_mesh(verts, faces, colors)
        normal_canvas.reinit_camera(fov)

    # One reference render, downsampled like the video frames, fixes the frame geometry.
    img = normal_canvas.render()
    img = cv2.resize(img, (int(img.shape[1] / init_factor), int(img.shape[0] / init_factor)),
                     interpolation=cv2.INTER_AREA)
    if border is None:
        border = [0, img.shape[0], 0, img.shape[1]]
    anchor = _compute_view_anchor(img.shape, ref_h / ref_w, ref_h, ref_w)

    plane_width = np.tan(fov_in_rad / 2.) * np.abs(mean_loc_depth)
    if isinstance(video_basename, list):
        video_basename = video_basename[0]

    for video_pose, video_traj_type in zip(videos_poses, video_traj_types):
        stereos = []
        for tp in video_pose:
            rel_pose = np.linalg.inv(np.dot(tp, np.linalg.inv(ref_pose)))
            axis, angle = transforms3d.axangles.mat2axangle(rel_pose[0:3, 0:3])
            angle_deg = (angle * 180) / np.pi
            normal_canvas.rotate(axis=axis, angle=angle_deg)
            normal_canvas.translate(rel_pose[:3, 3])
            if 'dolly' in video_traj_type:
                # Dolly zoom: choose the FoV that keeps plane_width filling the view
                # at the new depth (scalar arctan2 avoids array->float deprecation).
                new_mean_loc_depth = mean_loc_depth - float(rel_pose[2, 3])
                new_fov = float(np.arctan2(plane_width, np.abs(new_mean_loc_depth)) * 180. / np.pi * 2)
                normal_canvas.reinit_camera(new_fov)
            else:
                normal_canvas.reinit_camera(fov)
            normal_canvas.view_changed()
            stereos.append(_postprocess_frame(normal_canvas.render(), init_factor, anchor, border,
                                              config['crop_border']))
            # Undo the pose change (inverse ops, reverse order) before the next frame.
            normal_canvas.translate(-rel_pose[:3, 3])
            normal_canvas.rotate(axis=axis, angle=-angle_deg)
            normal_canvas.view_changed()

        if not stereos:
            print("Skipping trajectory %s: no poses to render" % video_traj_type)
            continue
        # Trim to even height/width (commonly required by H.264/yuv420p encoders).
        frame_h, frame_w = stereos[-1].shape[:2]
        even_h, even_w = frame_h - frame_h % 2, frame_w - frame_w % 2
        stereos = [stereo[:even_h, :even_w, :3].astype(np.uint8) for stereo in stereos]
        clip = ImageSequenceClip(stereos, fps=config['fps'])
        clip.write_videofile(os.path.join(output_dir, video_basename + '_' + video_traj_type + '.mp4'), fps=config['fps'])

    return normal_canvas, all_canvas
|
|