Spaces:
Sleeping
Sleeping
import os | |
from .. import constants as const | |
import pickle | |
import pickle5 | |
from shutil import copyfile, move | |
from ..custom_types import * | |
from PIL import Image | |
import time | |
import json | |
import matplotlib.pyplot as plt | |
import sys | |
from ..constants import PROJECT_ROOT | |
# Append the project root to the import search path once; skipped when it is
# already listed in sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
# sys.path.append("/home/juil/projects/3D_CRISPR/spaghetti_github")
def image_to_display(img) -> ARRAY:
    """Coerce an image-like input into something ``plt.imshow`` can take.

    A value whose exact type is ``str`` is treated as a path and opened with
    ``Image.open``. The result (or the input itself) is passed through ``V``
    unless its type already is ``V``.

    NOTE(review): ``V`` comes from ``custom_types`` (not visible here);
    presumably a numpy array constructor - confirm.
    """
    source = Image.open(str(img)) if type(img) is str else img
    if type(source) is V:
        return source
    return V(source)
def imshow(img, title: Optional[str] = None):
    """Display an image with matplotlib, axes hidden.

    :param img: anything accepted by ``image_to_display`` (e.g. a path string).
    :param title: optional figure title; omitted when ``None``.
    """
    plt.imshow(image_to_display(img))
    plt.axis("off")
    if title is not None:
        plt.title(title)
    plt.show()
    # Release every open figure once the window has been shown.
    plt.close('all')
def load_image(path: str, color_type: str = 'RGB') -> ARRAY:
    """Open an image, convert it to ``color_type`` and wrap it with ``V``.

    ``path`` may omit its extension: ``.png`` and then ``.jpg`` are tried and
    the first candidate that exists on disk is used; otherwise ``path`` is
    opened as given.

    Note: a later definition of the same name in this file replaces this one
    at import time.
    """
    candidates = (add_suffix(path, ext) for ext in ('.png', '.jpg'))
    resolved = next((c for c in candidates if os.path.isfile(c)), path)
    return V(Image.open(resolved).convert(color_type))
def save_image(image: Union[ARRAY, Image.Image], path: str):
    """Save an image to ``path``, creating the parent folder first.

    An input whose exact type is ``ARRAY`` is converted with
    ``Image.fromarray``; a trailing axis of size one is dropped beforehand.

    Note: a later definition of the same name in this file replaces this one
    at import time.
    """
    if type(image) is ARRAY:
        pixels = image[:, :, 0] if image.shape[-1] == 1 else image
        image = Image.fromarray(pixels)
    init_folders(path)
    image.save(path)
def split_path(path: str) -> List[str]:
    """Split ``path`` into ``[directory, base name without extension, extension]``.

    The extension keeps its leading dot and is empty when the file name has
    none (a leading-dot name such as ``.hidden`` counts as having none).
    """
    dir_name, base = os.path.split(path)
    stem, extension = os.path.splitext(base)
    return [dir_name, stem, extension]
def init_folders(*folders):
    """Create the parent directory of every given path if it is missing.

    Does nothing when ``const.DEBUG`` is set. Paths without a directory
    component are skipped.

    :param folders: file paths whose parent directories should exist.
    """
    if const.DEBUG:
        return
    for target in folders:
        parent = os.path.dirname(target)
        if parent and not os.path.exists(parent):
            # exist_ok guards against another process creating the directory
            # between the check above and this call.
            os.makedirs(parent, exist_ok=True)
def is_file(path: str) -> bool:
    """Return the result of ``os.path.isfile`` for ``path``."""
    return os.path.isfile(path)
def add_suffix(path: str, suffix: str) -> str:
    """Return ``path`` guaranteed to end with ``suffix``.

    The suffix is appended only when ``path`` does not already end with it.
    """
    return path if path.endswith(suffix) else f'{path}{suffix}'
def remove_suffix(path: str, suffix: str) -> str:
    """Return ``path`` with one trailing ``suffix`` stripped.

    ``path`` is returned unchanged when it does not end with ``suffix``, when
    ``suffix`` is empty, or when ``path`` consists of the suffix alone.
    """
    strippable = bool(suffix) and len(path) > len(suffix) and path.endswith(suffix)
    if not strippable:
        return path
    return path[:len(path) - len(suffix)]
def path_init(suffix: str, path_arg_ind: int, is_save: bool):
    """Build a decorator that normalises a positional path argument.

    Before the wrapped function runs, the positional argument at index
    ``path_arg_ind`` gets ``suffix`` appended (when missing) and, if
    ``is_save`` is true, its parent folder is created via ``init_folders``.

    :param suffix: extension the path argument must end with.
    :param path_arg_ind: position of the path among the positional arguments;
        the path must be passed positionally.
    :param is_save: whether the wrapped function writes to the path.
    :return: a decorator preserving the wrapped function's metadata.
    """
    import functools

    def wrapper(func):
        # wraps keeps __name__ / __doc__ of the decorated function intact.
        @functools.wraps(func)
        def do(*args, **kwargs):
            path = add_suffix(args[path_arg_ind], suffix)
            if is_save:
                init_folders(path)
            args = [path if i == path_arg_ind else arg for i, arg in enumerate(args)]
            return func(*args, **kwargs)

        return do

    return wrapper
def copy_file(src: str, dest: str, force=False):
    """Copy ``src`` to ``dest`` unless the destination already exists.

    :param force: overwrite an existing destination file when true.
    :return: ``True`` when the copy happened, ``False`` when ``src`` is not a
        file or the destination was kept, ``None`` when ``const.DEBUG`` is set.
    """
    if const.DEBUG:
        return
    if not os.path.isfile(src):
        return False
    if not force and os.path.isfile(dest):
        print("Destination file already exist. To override, set force=True")
        return False
    copyfile(src, dest)
    return True
def load_image(path: str, color_type: str = 'RGB') -> ARRAY:
    """Open an image, convert it to ``color_type`` and wrap it with ``V``.

    When ``path`` has no ``.png`` / ``.jpg`` suffix, those are tried in that
    order and the first existing file wins; if neither exists, ``path`` is
    opened unchanged.
    """
    resolved = path
    for ext in ('.png', '.jpg'):
        candidate = add_suffix(path, ext)
        if os.path.isfile(candidate):
            resolved = candidate
            break
    converted = Image.open(resolved).convert(color_type)
    return V(converted)
def save_image(image: Union[ARRAY, Image.Image], path: str):
    """Save an image to ``path``, creating the parent folder when needed.

    This definition follows an earlier one of the same name in this file and
    is therefore the one in effect; it is kept consistent with that earlier
    version (folder creation, accepted input types).

    :param image: a PIL image, or a value whose exact type is ``ARRAY``, which
        is converted with ``Image.fromarray``. A three-dimensional array with
        a trailing axis of size one has that axis dropped first.
    :param path: destination file path.
    """
    if type(image) is ARRAY:
        # Only drop the channel axis when it really is a third axis; a 2-D
        # array of width one must not be indexed with three subscripts.
        if len(image.shape) == 3 and image.shape[-1] == 1:
            image = image[:, :, 0]
        image = Image.fromarray(image)
    init_folders(path)
    image.save(path)
def save_np(arr_or_dict: Union[ARRAY, T, dict], path: str):
    """Persist an array, a tensor or a dict of arrays with numpy.

    A ``dict`` is written with ``np.savez_compressed`` to a path ending in
    ``.npz``. Anything else goes through ``np.save``; a value of type ``T`` is
    first detached, moved to CPU and converted with ``.numpy()``. A trailing
    ``.npy`` is stripped from the path before it is handed to ``np.save``.
    Does nothing when ``const.DEBUG`` is set.
    """
    if const.DEBUG:
        return
    init_folders(path)
    if type(arr_or_dict) is dict:
        np.savez_compressed(add_suffix(path, '.npz'), **arr_or_dict)
        return
    payload = arr_or_dict.detach().cpu().numpy() if type(arr_or_dict) is T else arr_or_dict
    np.save(remove_suffix(path, '.npy'), payload)
def load_np(path: str):
    """Return whatever ``np.load`` reads from ``path``."""
    return np.load(path)
def load_pickle(path: str):
    """Unpickle and return the object stored at ``path``.

    Returns ``None`` when ``path`` is not an existing file. If the standard
    ``pickle`` module raises ``ValueError``, the file is read again with
    ``pickle5``.

    NOTE: unpickling executes arbitrary code - only load trusted files.
    """
    if not os.path.isfile(path):
        return None
    try:
        with open(path, 'rb') as handle:
            return pickle.load(handle)
    except ValueError:
        with open(path, 'rb') as handle:
            return pickle5.load(handle)
def save_pickle(obj, path: str):
    """Pickle ``obj`` to ``path`` with the highest available protocol.

    Does nothing when ``const.DEBUG`` is set.
    """
    if not const.DEBUG:
        with open(path, 'wb') as handle:
            pickle.dump(obj, handle, pickle.HIGHEST_PROTOCOL)
def load_txt_labels(path: str) -> VN:
    """Load integer labels from a ``.txt`` or ``.seg`` file, shifted down by one.

    Both suffixes are tried in that order. Returns ``None`` when neither
    candidate exists.
    """
    candidates = (add_suffix(path, ext) for ext in ('.txt', '.seg'))
    found = next((c for c in candidates if os.path.isfile(c)), None)
    if found is None:
        return None
    # Stored labels are decremented by one after loading.
    return np.loadtxt(found, dtype=np.int64) - 1
def load_txt(path: str) -> List[str]:
    """Return the whitespace-stripped lines of a text file.

    An empty list is returned when ``path`` is not an existing file.
    """
    if not os.path.isfile(path):
        return []
    with open(path, 'r') as handle:
        return [raw.strip() for raw in handle]
# def load_points(path: str) -> T: | |
# path = add_suffix(path, '.pts') | |
# points = [int_b(num) for num in load_txt(path)] | |
# return torch.tensor(points, dtype=torch.int64) | |
def save_txt(array, path: str):
    """Write the items of ``array`` to a text file, separated by single spaces.

    ``.txt`` is appended to ``path`` when missing. No trailing separator or
    newline is written. Does nothing when ``const.DEBUG`` is set.

    :param array: any iterable; each item is formatted with ``format()``.
    :param path: destination path, with or without the ``.txt`` suffix.
    """
    if const.DEBUG:
        return
    with open(add_suffix(path, '.txt'), 'w') as handle:
        # One join and one write instead of a len() call and a write per item.
        handle.write(' '.join(f'{num}' for num in array))
def move_file(src: str, dest: str):
    """Move ``src`` to ``dest`` when ``src`` is an existing file.

    :return: ``True`` if the move was performed, ``False`` if ``src`` is not a
        file, ``None`` when ``const.DEBUG`` is set.
    """
    if const.DEBUG:
        return
    src_is_file = os.path.isfile(src)
    if src_is_file:
        move(src, dest)
    return src_is_file
def save_json(obj, path: str):
    """Serialise ``obj`` as JSON (four-space indent) into the file at ``path``."""
    with open(path, 'w') as handle:
        json.dump(obj, handle, indent=4)
def collect(root: str, *suffix, prefix='') -> List[List[str]]:
    """Collect files as ``(folder, name, extension)`` triplets.

    When ``root`` is a file, a single triplet describing it is returned (as a
    list; ``suffix`` and ``prefix`` are not applied). Otherwise ``root`` is
    walked recursively and every file whose extension is in ``suffix`` and
    whose base name starts with ``prefix`` is reported (as a tuple), sorted by
    name and extension. ``''.join(triplet)`` reproduces the file path.

    :param root: a file or a directory.
    :param suffix: accepted extensions, each including the leading dot.
    :param prefix: required beginning of the extension-less file name.
    :return: the triplets; empty (after printing a warning) when ``root`` is
        neither a file nor a directory.
    """
    def with_slash(directory: str) -> str:
        # Same effect as add_suffix(directory, '/').
        return directory if directory.endswith('/') else f'{directory}/'

    if os.path.isfile(root):
        folder, base = os.path.split(root)
        # Split the base name itself so extension-less files and bare file
        # names (no directory part) keep their full name and an empty folder.
        name, extension = os.path.splitext(base)
        return [[with_slash(folder) if folder else '', name, extension]]

    root = with_slash(root)
    if not os.path.isdir(root):
        print(f'Warning: trying to collect from {root} but dir isn\'t exist')
        return []

    paths = []
    for path, _, files in os.walk(root):
        folder = with_slash(path)
        for file in files:
            file_name, file_extension = os.path.splitext(file)
            if file_extension in suffix and file_name.startswith(prefix):
                paths.append((folder, file_name, file_extension))
    paths.sort(key=lambda x: os.path.join(x[1], x[2]))
    return paths
def delete_all(root: str, *suffix: str):
    """Remove every file under ``root`` that ``collect`` reports for ``suffix``.

    Does nothing when ``const.DEBUG`` is set.
    """
    if const.DEBUG:
        return
    for parts in collect(root, *suffix):
        # Each entry is (folder, name, extension); joined they form the path.
        os.remove(''.join(parts))
def delete_single(path: str) -> bool:
    """Delete the file at ``path``; report whether a file was removed."""
    if not os.path.isfile(path):
        return False
    os.remove(path)
    return True
def colors_to_colors(colors: COLORS, mesh: T_Mesh) -> T:
    """Normalise a colour specification into a tensor shaped like ``mesh[0]``.

    Non-tensor input is converted to an int64 tensor. Values whose maximum
    exceeds 1 are divided by 255. A one-dimensional result is broadcast with
    ``expand_as(mesh[0])``: as a row when it has exactly three entries,
    otherwise as a column.

    NOTE(review): non-tensor input is cast to int64, so fractional colours
    given as plain lists would be truncated - confirm callers pass 0-255 ints.
    """
    if type(colors) is not T:
        if type(colors) is V:
            colors = torch.from_numpy(colors).long()
        else:
            colors = torch.tensor(colors, dtype=torch.int64)
    if colors.max() > 1:
        colors = colors.float() / 255
    if colors.dim() != 1:
        return colors
    # Three entries: one colour for all rows (new axis 0).
    # Any other length: one value per row (new axis 1).
    new_axis = 0 if colors.shape[0] == 3 else 1
    return colors.unsqueeze(new_axis).expand_as(mesh[0])
def load_mesh(file_name: str, dtype: Union[type(T), type(V)] = T,
              device: D = CPU) -> Union[T_Mesh, V_Mesh, T, Tuple[T, List[List[int]]]]:
    """Load a mesh from an ``.obj``, ``.off`` or ``.ply`` file.

    ``file_name`` may omit its extension: the three suffixes are tried in that
    order and the first candidate that exists on disk is used.

    :param file_name: path of the mesh file, with or without extension.
    :param dtype: pass ``V`` to have both parts converted with ``.numpy()``;
        any other value leaves the loaded tensors as they are.
    :param device: when ``dtype`` is not ``V`` and this differs from ``CPU``,
        both parts are moved to it.
    :return: ``(vertices, faces)``, or only ``vertices`` when no faces were
        read but vertices exist. For an ``.obj`` file with at least one face
        line that does not have exactly four tokens, ``faces`` is a list of
        index lists instead of a tensor.
    :raises ValueError: when the resolved extension is none of the three.
    :raises AssertionError: when tensor faces contain an index outside
        ``[0, number of vertices)``.
    """

    def off_parser():
        # Tracks whether the first three-token line has been consumed.
        header = None

        def parser_(clean_line: list):
            nonlocal header
            if not clean_line:
                return False
            if len(clean_line) == 3 and not header:
                # First three-token line: remembered and skipped (falls through
                # to an implicit None, which the caller treats as "ignore").
                header = True
            elif len(clean_line) == 3:
                # (target list, first token index, converter): a vertex line.
                return 0, 0, float
            elif len(clean_line) > 3:
                # Face line: the negative start index selects the last n
                # tokens, n being the leading count on the line.
                return 1, -int(clean_line[0]), int

        return parser_

    def obj_parser(clean_line: list):
        nonlocal is_quad
        if not clean_line:
            return False
        elif clean_line[0] == 'v':
            return 0, 1, float
        elif clean_line[0] == 'f':
            # Any face line that is not "f a b c" switches to list-based faces.
            is_quad = is_quad or len(clean_line) != 4
            return 1, 1, int
        return False

    def fetch(lst: list, idx: int, dtype: type):
        # Convert tokens from position idx onward; for "a/b/c" style tokens
        # only the part before the first slash is kept.
        uv_vs_ids = None  # never populated; always returned as None
        if '/' in lst[idx]:
            lst = [item.split('/') for item in lst[idx:]]
            lst = [item[0] for item in lst]
            idx = 0
        face_vs_ids = [dtype(c.split('/')[0]) for c in lst[idx:]]
        if dtype is float and len(face_vs_ids) > 3:
            # Float rows (vertices) keep only their first three values.
            face_vs_ids = face_vs_ids[:3]
        return face_vs_ids, uv_vs_ids

    def load_from_txt(parser) -> TS:
        # mesh_[0] collects vertex rows, mesh_[1] collects face rows.
        mesh_ = [[], []]
        with open(file_name, 'r') as f:
            for line in f:
                clean_line = line.strip().split()
                info = parser(clean_line)
                if not info:
                    continue
                data = fetch(clean_line, info[1], info[2])
                mesh_[info[0]].append(data[0])
        if is_quad:
            # Faces stay a list of lists; every index is shifted down by one.
            faces = mesh_[1]
            for face in faces:
                for i in range(len(face)):
                    face[i] -= 1
        else:
            faces = torch.tensor(mesh_[1], dtype=torch.int64)
            # Shift to 0-based only when the smallest index is not already 0.
            # NOTE(review): a 0-based file that never references vertex 0
            # would be shifted as well - confirm this cannot occur.
            if len(faces) > 0 and faces.min() != 0:
                faces -= 1
        mesh_ = torch.tensor(mesh_[0], dtype=torch.float32), faces
        return mesh_

    # Resolve a missing extension against the files present on disk.
    for suffix in ['.obj', '.off', '.ply']:
        file_name_tmp = add_suffix(file_name, suffix)
        if os.path.isfile(file_name_tmp):
            file_name = file_name_tmp
            break
    is_quad = False
    name, extension = os.path.splitext(file_name)
    if extension == '.obj':
        mesh = load_from_txt(obj_parser)
    elif extension == '.off':
        mesh = load_from_txt(off_parser())
    elif extension == '.ply':
        mesh = load_ply(file_name)
    else:
        raise ValueError(f'mesh file {file_name} is not exist or not supported')
    # Report the offending file before the assertion below fires.
    if type(mesh[1]) is T and not ((mesh[1] >= 0) * (mesh[1] < mesh[0].shape[0])).all():
        print(f"err: {file_name}")
    assert type(mesh[1]) is not T or ((mesh[1] >= 0) * (mesh[1] < mesh[0].shape[0])).all()
    if dtype is V:
        mesh = mesh[0].numpy(), mesh[1].numpy()
    elif device != CPU:
        mesh = mesh[0].to(device), mesh[1].to(device)
    # No faces but some vertices: return the vertices alone.
    if len(mesh[1]) == 0 and len(mesh[0]) > 0:
        return mesh[0]
    return mesh
def export_xyz(pc: T, path: str, normals: Optional[T] = None):
    """Write points as text, one ``x y z`` row per line.

    When ``normals`` is given, the three components of the matching normal are
    appended to each row. Rows are separated by newlines; the last row has no
    trailing newline.
    """
    points = pc.tolist()
    normal_rows = None if normals is None else normals.tolist()
    last = len(points) - 1
    with open(path, 'w') as handle:
        for idx, (x, y, z) in enumerate(points):
            handle.write(f'{x} {y} {z}')
            if normal_rows is not None:
                x, y, z = normal_rows[idx]
                handle.write(f' {x} {y} {z}')
            if idx < last:
                handle.write('\n')
def export_gmm(gmm: TS, item: int, file_name: str, included: Optional[List[int]] = None):
    """Write one GMM of a batch to a five-line text file.

    The four entries of ``gmm`` are read as ``mu, p, phi, eigen``. For each,
    ``tensor[item, 0]`` is flattened and written with five decimals, in the
    order ``phi, mu, eigen, p``. The fifth line holds the integer ``included``
    flags, which default to all ones with length ``gmm[0].shape[2]``.
    """
    if included is None:
        included = [1] * gmm[0].shape[2]
    mu, p, phi, eigen = (tensor[item, 0].flatten().cpu() for tensor in gmm)
    # phi = phi.softmax(0)
    with open(file_name, 'w') as handle:
        for flat in (phi, mu, eigen, p):
            row = ' '.join(f'{number:.5f}' for number in flat.tolist())
            handle.write(f"{row}\n")
        flags = ' '.join(f'{number:d}' for number in included)
        handle.write(f"{flags}\n")
def load_gmm(path, as_np: bool = False, device: D = CPU):
    """Parse a text file in the layout written by ``export_gmm``.

    Every line is read as space-separated floats. Lines 1 and 2 are reshaped
    to ``(-1, 3)``, line 3 to ``(-1, 3, 3)``, and line 4 is cast to boolean.
    Line 0 and any line past index 4 stay flat.

    :param as_np: build each row with ``V`` instead of ``torch.tensor``.
    :param device: device for the tensors when ``as_np`` is false.
    :return: list of parsed rows, in file order.
    """
    with open(path, 'r') as handle:
        rows = [raw.strip() for raw in handle]
    parsed = []
    for row_idx, row in enumerate(rows):
        values = [float(token) for token in row.split(" ")]
        arr = V(values) if as_np else torch.tensor(values, device=device)
        if row_idx in (1, 2):
            # The original carried a disabled swap of columns 1 and 2 here.
            arr = arr.reshape((-1, 3))
        elif row_idx == 3:
            # The original carried a disabled transpose of the last two axes here.
            arr = arr.reshape((-1, 3, 3))
        elif row_idx == 4:
            arr = arr.astype(np.bool_) if as_np else arr.bool()
        parsed.append(arr)
    return parsed
def export_list(lst: List[Any], path: str):
    """Write each item of ``lst`` on its own newline-terminated line."""
    with open(path, "w") as handle:
        for entry in lst:
            handle.write(f'{entry}\n')
def export_mesh(mesh: Union[V_Mesh, T_Mesh, T, Tuple[T, List[List[int]]]], file_name: str,
                colors: Optional[COLORS] = None, normals: TN = None, edges=None, spheres=None):
    """Write a mesh (or bare vertices) in a Wavefront-OBJ-like text layout.

    Output lines: ``v x y z [r g b]`` per vertex, ``vn`` per normal, ``f`` per
    face with indices shifted up by one, then optional ``e`` and ``sp`` lines.
    Nothing is written, and no error is raised, when ``file_name`` names a
    directory that does not exist.

    :param mesh: ``(vertices, faces)`` or vertices alone. Vertices with fewer
        than three columns are zero-padded to three (tensor input assumed for
        that case - TODO confirm for numpy meshes).
    :param file_name: destination path; a bare file name writes to the
        current working directory.
    :param colors: optional colours, normalised by ``colors_to_colors``; a
        vertex whose first colour component is negative gets no colour.
    :param normals: optional iterable of normals.
    :param edges: optional integer pairs written as ``e a b``.
    :param spheres: optional integer entries written as ``sp i``.
    """
    if type(mesh) is not tuple and type(mesh) is not list:
        mesh = mesh, None
    vs, faces = mesh
    if vs.shape[1] < 3:
        vs = torch.cat((vs, torch.zeros(len(vs), 3 - vs.shape[1], device=vs.device)), dim=1)
    if colors is not None:
        # Expand against the padded vertices so colours always get the same
        # number of columns as the rows written below.
        colors = colors_to_colors(colors, (vs, faces))
    dir_name = os.path.dirname(file_name)
    # An empty dirname means the current directory, which always exists.
    if dir_name and not os.path.isdir(dir_name):
        return
    faces_lst = None
    if faces is not None:
        if type(faces) is T:
            faces_lst = (faces + 1).tolist()
        else:
            faces_lst = [[v_id + 1 for v_id in face] for face in faces]
    with open(file_name, 'w') as f:
        for vi, v in enumerate(vs):
            if colors is None or colors[vi, 0] < 0:
                v_color = ''
            else:
                v_color = ' %f %f %f' % (colors[vi, 0].item(), colors[vi, 1].item(), colors[vi, 2].item())
            f.write("v %f %f %f%s\n" % (v[0], v[1], v[2], v_color))
        if normals is not None:
            for n in normals:
                f.write("vn %f %f %f\n" % (n[0], n[1], n[2]))
        if faces_lst is not None:
            for face in faces_lst:
                f.write(f'f {" ".join(str(v_id) for v_id in face)}\n')
        if edges is not None:
            for edges_id in range(edges.shape[0]):
                f.write(f'\ne {edges[edges_id][0].item():d} {edges[edges_id][1].item():d}')
        if spheres is not None:
            for sphere_id in range(spheres.shape[0]):
                f.write(f'\nsp {spheres[sphere_id].item():d}')
def export_ply(mesh: T_Mesh, path: str, colors: T):
    """Write a coloured mesh as an ASCII PLY file.

    Before writing, a copy of the vertices has columns 1 and 2 swapped, is
    centred on the midpoint of its bounding box, divided by its largest
    remaining value, and shifted so that column 2 starts at zero. Colours are
    normalised by ``colors_to_colors`` and written as integers scaled by 255.
    Each face is written with exactly three indices.
    """
    colors = (colors_to_colors(colors, mesh) * 255).long()
    vs, faces = mesh
    vs = vs.clone()
    # Swap columns 1 and 2; the indexed right-hand side is a copy.
    vs[:, [1, 2]] = vs[:, [2, 1]]
    lower, upper = vs.min(0)[0], vs.max(0)[0]
    vs = vs - ((lower + upper) / 2)[None, :]
    vs = vs / vs.max()
    vs[:, 2] = vs[:, 2] - vs[:, 2].min()
    num_vs, num_faces = vs.shape[0], faces.shape[0]
    header = (f'ply\nformat ascii 1.0\n'
              f'element vertex {num_vs:d}\nproperty float x\nproperty float y\nproperty float z\n'
              f'property uchar red\nproperty uchar green\nproperty uchar blue\n'
              f'element face {num_faces:d}\nproperty list uchar int vertex_indices\nend_header\n')
    with open(path, 'w') as handle:
        handle.write(header)
        for vi, v in enumerate(vs):
            rgb = f'{colors[vi, 0].item():d} {colors[vi, 1].item():d} {colors[vi, 2].item():d}'
            handle.write(f'{v[0].item():f} {v[1].item():f} {v[2].item():f} {rgb}\n')
        for face in faces:
            handle.write(f'3 {face[0].item():d} {face[1].item():d} {face[2].item():d}\n')
def load_ply(path: str):
    """Read vertices and triangular faces from a PLY file using ``plyfile``.

    The first element of the file is taken as the vertex list (first three
    properties of each record) and the second as the face list (first three
    indices of each record's first property).

    :return: ``(vertices, faces)`` as tensors built with ``torch.tensor``.
    """
    import plyfile
    plydata = plyfile.PlyData.read(path)
    vertex_records = plydata.elements[0].data
    face_records = plydata.elements[1].data
    vertices = torch.tensor([[float(rec[axis]) for axis in range(3)] for rec in vertex_records])
    faces = torch.tensor([[int(rec[0][corner]) for corner in range(3)] for rec in face_records])
    return vertices, faces
def save_model(model: Union[Optimizer, nn.Module], model_path: str):
    """Save ``model.state_dict()`` to ``model_path``, creating its folder.

    Does nothing when ``const.DEBUG`` is set.
    """
    if not const.DEBUG:
        init_folders(model_path)
        state = model.state_dict()
        torch.save(state, model_path)
def load_model(model: Union[Optimizer, nn.Module], model_path: str, device: D, verbose: bool = False):
    """Restore ``model`` from a saved state dict when the file exists.

    :param model_path: checkpoint path; when it is not a file, ``model`` is
        returned untouched.
    :param device: passed to ``torch.load`` as ``map_location``.
    :param verbose: print whether the state was loaded or left as initialised.
    :return: the same ``model`` object.
    """
    checkpoint_found = os.path.isfile(model_path)
    if checkpoint_found:
        model.load_state_dict(torch.load(model_path, map_location=device))
    if verbose:
        if checkpoint_found:
            print(f'loading {type(model).__name__} from {model_path}')
        else:
            print(f'init {type(model).__name__}')
    return model
def measure_time(func, num_iters: int, *args):
    """Call ``func(*args)`` ``num_iters`` times and print total and mean time.

    :param func: callable to time; its ``__name__`` labels the report, with
        the type name as fallback for callables that have none.
    :param num_iters: number of calls; for zero or fewer, nothing is called
        and the average is reported as ``0.0``.
    :param args: positional arguments forwarded to every call.
    """
    label = getattr(func, '__name__', None) or type(func).__name__
    # perf_counter is monotonic and has higher resolution than time.time().
    start_time = time.perf_counter()
    for _ in range(num_iters):
        func(*args)
    total_time = time.perf_counter() - start_time
    avg_time = total_time / num_iters if num_iters > 0 else 0.0
    print(f"{label} total time: {total_time}, average time: {avg_time}")
def get_time_name(name: str, format_="%m_%d-%H_%M") -> str:
    """Return ``name`` followed by an underscore and the current local time.

    :param format_: ``time.strftime`` pattern used for the timestamp.
    """
    stamp = time.strftime(format_)
    return f'{name}_{stamp}'
def load_shapenet_seg(path: str) -> TS:
    """Read per-point coordinates and labels from a whitespace-separated file.

    On every line the first three tokens are parsed as floats (the point) and
    the last token, cut at its first ``.``, as the integer label.

    :return: ``(points, labels)`` as float32 and int64 tensors.
    """
    with open(path, 'r') as handle:
        rows = [raw.strip().split() for raw in handle]
    vs = [[float(token) for token in row[:3]] for row in rows]
    labels = [int(row[-1].split('.')[0]) for row in rows]
    return torch.tensor(vs, dtype=torch.float32), torch.tensor(labels, dtype=torch.int64)
def load_json(path: str):
    """Parse the JSON document stored at ``path`` and return the result."""
    with open(path, 'r') as handle:
        return json.load(handle)