# DveloperY0115's picture
# init repo
# 801501a
import functools
import json
import os
import pickle
import sys
import time

import matplotlib.pyplot as plt
import pickle5
from PIL import Image
from shutil import copyfile, move

from .. import constants as const
from ..custom_types import *
from ..constants import PROJECT_ROOT
# Make the project root importable so absolute imports resolve even when
# this module is loaded from outside the package directory.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
# sys.path.append("/home/juil/projects/3D_CRISPR/spaghetti_github")
def image_to_display(img) -> ARRAY:
    """Coerce *img* (a file path or an image-like object) into an array for display.

    Args:
        img: a filesystem path to an image, or an image-like object.

    Returns:
        The image converted via ``V`` (the project's array constructor).
    """
    if isinstance(img, str):
        # A path was given: load the image from disk (no redundant str() call).
        img = Image.open(img)
    if type(img) is not V:
        # NOTE(review): `V` comes from ..custom_types; the original identity
        # check is kept as-is since its exact semantics depend on that module.
        img = V(img)
    return img
def imshow(img, title: Optional[str] = None):
    """Show an image with matplotlib, hiding the axes; optional title."""
    display = image_to_display(img)
    plt.imshow(display)
    plt.axis("off")
    if title is not None:
        plt.title(title)
    plt.show()
    plt.close('all')
def load_image(path: str, color_type: str = 'RGB') -> ARRAY:
    # NOTE(review): this definition is shadowed by an identical `load_image`
    # defined later in this module; only the later one takes effect.
    # Resolve `path` by trying known image suffixes, then load and convert.
    for suffix in ('.png', '.jpg'):
        path_ = add_suffix(path, suffix)
        if os.path.isfile(path_):
            path = path_
            break
    image = Image.open(path).convert(color_type)
    return V(image)
def save_image(image: Union[ARRAY, Image.Image], path: str):
    # NOTE(review): shadowed by the decorated `save_image` defined later in
    # this module; only the later definition takes effect.
    if type(image) is ARRAY:
        # Squeeze a trailing singleton channel so PIL treats it as grayscale.
        if image.shape[-1] == 1:
            image = image[:, :, 0]
        image = Image.fromarray(image)
    init_folders(path)
    image.save(path)
def split_path(path: str) -> List[str]:
    """Split *path* into a [directory, stem, extension] triple."""
    dir_name, base = os.path.split(path)
    stem, extension = os.path.splitext(base)
    return [dir_name, stem, extension]
def init_folders(*folders):
    """Create the parent directory of each given path (no-op in DEBUG mode).

    Args:
        *folders: file paths whose parent directories should exist.
    """
    if const.DEBUG:
        return
    for f in folders:
        dir_name = os.path.dirname(f)
        if dir_name:
            # exist_ok avoids the check-then-create race of the original
            # `if not os.path.exists(...)` pattern.
            os.makedirs(dir_name, exist_ok=True)
def is_file(path: str) -> bool:
    """Return True when *path* refers to an existing regular file."""
    return os.path.isfile(path)
def add_suffix(path: str, suffix: str) -> str:
    """Return *path* with *suffix* appended unless it already ends with it."""
    # str.endswith replaces the manual length/slice comparison and also
    # handles paths shorter than the suffix.
    if not path.endswith(suffix):
        path = f'{path}{suffix}'
    return path
def remove_suffix(path: str, suffix: str) -> str:
    """Return *path* without a trailing *suffix*.

    The suffix is only stripped when something remains afterwards, i.e. a
    path that equals the suffix (or an empty suffix) is returned unchanged.
    """
    if suffix and len(path) > len(suffix) and path.endswith(suffix):
        path = path[:-len(suffix)]
    return path
def path_init(suffix: str, path_arg_ind: int, is_save: bool):
    """Decorator factory that normalizes a positional path argument.

    Ensures the argument at *path_arg_ind* carries *suffix* and, when
    *is_save* is set, creates the destination's parent directories before
    calling the wrapped function.
    """
    def wrapper(func):
        @functools.wraps(func)  # preserve the wrapped function's name/docstring
        def do(*args, **kwargs):
            path = add_suffix(args[path_arg_ind], suffix)
            if is_save:
                init_folders(path)
            # Rebuild the positional arguments with the normalized path.
            args = [path if i == path_arg_ind else arg for i, arg in enumerate(args)]
            return func(*args, **kwargs)
        return do
    return wrapper
def copy_file(src: str, dest: str, force=False):
    """Copy *src* to *dest* (no-op in DEBUG mode).

    Args:
        src: source file path.
        dest: destination file path.
        force: overwrite an existing destination when True.

    Returns:
        True when the copy happened, False otherwise (None in DEBUG mode,
        preserving the original early-return).
    """
    if const.DEBUG:
        return
    if not os.path.isfile(src):
        # The original fell through and implicitly returned None here;
        # make the "source missing" outcome an explicit False.
        return False
    if force or not os.path.isfile(dest):
        copyfile(src, dest)
        return True
    print("Destination file already exists. To override, set force=True")
    return False
def load_image(path: str, color_type: str = 'RGB') -> ARRAY:
    """Load the image at *path* and return it as an array.

    If *path* lacks an extension, '.png' then '.jpg' are tried.
    """
    for extension in ('.png', '.jpg'):
        candidate = add_suffix(path, extension)
        if os.path.isfile(candidate):
            path = candidate
            break
    return V(Image.open(path).convert(color_type))
@path_init('.png', 1, True)
def save_image(image: Union[ARRAY, Image.Image], path: str):
    """Save *image* to *path* (a '.png' suffix is enforced by the decorator).

    Numpy arrays are converted to PIL images first; a trailing singleton
    channel is squeezed so grayscale data round-trips correctly.
    """
    if type(image) is ARRAY:
        if image.shape[-1] == 1:
            image = image[:, :, 0]
        image = Image.fromarray(image)
    image.save(path)
def save_np(arr_or_dict: Union[ARRAY, T, dict], path: str):
    """Persist an array, tensor, or dict of arrays with numpy (skipped in DEBUG)."""
    if const.DEBUG:
        return
    init_folders(path)
    if type(arr_or_dict) is dict:
        # Dicts are stored as compressed .npz archives.
        np.savez_compressed(add_suffix(path, '.npz'), **arr_or_dict)
        return
    data = arr_or_dict
    if type(data) is T:
        # Torch tensors become CPU numpy arrays before saving.
        data = data.detach().cpu().numpy()
    # np.save appends '.npy' itself, so strip a pre-existing one.
    np.save(remove_suffix(path, '.npy'), data)
@path_init('.npy', 0, False)
def load_np(path: str):
    """Load a numpy array from *path* ('.npy' suffix enforced by the decorator)."""
    return np.load(path)
@path_init('.pkl', 0, False)
def load_pickle(path: str):
    """Load a pickled object from *path*, or return None when the file is missing.

    Falls back to the `pickle5` backport when the stock unpickler rejects the
    file (e.g. a protocol-5 pickle on an older interpreter).

    NOTE(review): unpickling executes arbitrary code -- only call this on
    trusted files.
    """
    data = None
    if os.path.isfile(path):
        try:
            with open(path, 'rb') as f:
                data = pickle.load(f)
        except ValueError:
            # "unsupported pickle protocol" is raised as ValueError.
            with open(path, 'rb') as f:
                data = pickle5.load(f)
    return data
@path_init('.pkl', 1, True)
def save_pickle(obj, path: str):
    """Pickle *obj* to *path* using the highest protocol (skipped in DEBUG)."""
    if const.DEBUG:
        return
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle, pickle.HIGHEST_PROTOCOL)
def load_txt_labels(path: str) -> VN:
    """Load integer labels from a '.txt' or '.seg' file, shifted to 0-based.

    Returns None when neither candidate file exists.
    """
    for extension in ('.txt', '.seg'):
        candidate = add_suffix(path, extension)
        if os.path.isfile(candidate):
            return np.loadtxt(candidate, dtype=np.int64) - 1
    return None
@path_init('.txt', 0, False)
def load_txt(path: str) -> List[str]:
    """Read *path* and return its stripped lines; [] when the file is missing."""
    if not os.path.isfile(path):
        return []
    with open(path, 'r') as f:
        return [line.strip() for line in f]
# def load_points(path: str) -> T:
# path = add_suffix(path, '.pts')
# points = [int_b(num) for num in load_txt(path)]
# return torch.tensor(points, dtype=torch.int64)
def save_txt(array, path: str):
    """Write *array* items to a '.txt' file as one space-separated line (skipped in DEBUG)."""
    if const.DEBUG:
        return
    out_path = add_suffix(path, '.txt')
    with open(out_path, 'w') as f:
        # Space-separated with no trailing separator, same as the index loop.
        f.write(' '.join(f'{item}' for item in array))
def move_file(src: str, dest: str):
    """Move *src* to *dest*; True on success, False when src is missing.

    No-op (returns None) in DEBUG mode.
    """
    if const.DEBUG:
        return
    if not os.path.isfile(src):
        return False
    move(src, dest)
    return True
@path_init('.json', 1, True)
def save_json(obj, path: str):
    """Serialize *obj* as indented JSON to *path*.

    Skipped in DEBUG mode, consistent with the other save_* helpers in this
    module (save_pickle, save_txt, save_np, save_model).
    """
    if const.DEBUG:
        return
    with open(path, 'w') as f:
        json.dump(obj, f, indent=4)
def collect(root: str, *suffix, prefix='') -> List[List[str]]:
    """Gather files under *root* matching the given extensions.

    Args:
        root: a directory to walk, or a single file path.
        *suffix: accepted file extensions, dot included (e.g. '.obj').
        prefix: only keep files whose name starts with this prefix.

    Returns:
        A list of [folder, name, extension] triples (sorted when walking a
        directory). Join a triple with ''.join(...) to rebuild the full path.
    """
    if os.path.isfile(root):
        folder = os.path.split(root)[0] + '/'
        extension = os.path.splitext(root)[-1]
        # Explicit end index: the original `-len(extension)` slice produced an
        # empty name when the file had no extension (-0 == 0).
        name = root[len(folder): len(root) - len(extension)]
        paths = [[folder, name, extension]]
    else:
        paths = []
        root = add_suffix(root, '/')
        if not os.path.isdir(root):
            print(f'Warning: trying to collect from {root} but dir isn\'t exist')
        else:
            p_len = len(prefix)
            for path, _, files in os.walk(root):
                for file in files:
                    file_name, file_extension = os.path.splitext(file)
                    p_len_ = min(p_len, len(file_name))
                    if file_extension in suffix and file_name[:p_len_] == prefix:
                        # Use a list, matching the single-file branch and the
                        # declared List[List[str]] return type (was a tuple).
                        paths.append([f'{add_suffix(path, "/")}', file_name, file_extension])
            paths.sort(key=lambda x: os.path.join(x[1], x[2]))
    return paths
def delete_all(root: str, *suffix: str):
    """Delete every file under *root* whose extension is in *suffix* (no-op in DEBUG)."""
    if const.DEBUG:
        return
    for parts in collect(root, *suffix):
        os.remove(''.join(parts))
def delete_single(path: str) -> bool:
    """Remove *path* if it is a file; report whether a deletion happened."""
    if not os.path.isfile(path):
        return False
    os.remove(path)
    return True
def colors_to_colors(colors: COLORS, mesh: T_Mesh) -> T:
    """Normalize *colors* into a float tensor aligned with the mesh vertices."""
    if type(colors) is not T:
        # Convert array / sequence input into a long tensor first.
        if type(colors) is V:
            colors = torch.from_numpy(colors).long()
        else:
            colors = torch.tensor(colors, dtype=torch.int64)
    if colors.max() > 1:
        # Byte colors (0-255) are rescaled to 0-1 floats.
        colors = colors.float() / 255
    if colors.dim() == 1:
        # A single color (3 values) broadcasts over axis 0; otherwise
        # per-vertex scalars broadcast over axis 1.
        expand_dim = 0 if colors.shape[0] == 3 else 1
        colors = colors.unsqueeze(expand_dim).expand_as(mesh[0])
    return colors
def load_mesh(file_name: str, dtype: Union[type(T), type(V)] = T,
              device: D = CPU) -> Union[T_Mesh, V_Mesh, T, Tuple[T, List[List[int]]]]:
    """Load a mesh from an .obj/.off/.ply file as (vertices, faces).

    Returns torch tensors by default, numpy arrays when dtype is V, and the
    vertex tensor alone when the file contains no faces (a point cloud).
    Quad/ngon .obj meshes keep faces as plain Python lists of indices.
    """
    def off_parser():
        # Stateful line classifier for OFF files: the first 3-token line is
        # the header (counts), later 3-token lines are vertices, longer lines
        # are faces.
        header = None

        def parser_(clean_line: list):
            nonlocal header
            if not clean_line:
                return False
            if len(clean_line) == 3 and not header:
                header = True
            elif len(clean_line) == 3:
                # (target list 0 = vertices, start index 0, parse as float)
                return 0, 0, float
            elif len(clean_line) > 3:
                # Faces: first token is the vertex count; a negative start
                # index keeps exactly that many trailing tokens.
                return 1, -int(clean_line[0]), int
        return parser_

    def obj_parser(clean_line: list):
        # Classify an .obj line: 'v' -> vertex, 'f' -> face, anything else
        # is skipped.
        nonlocal is_quad
        if not clean_line:
            return False
        elif clean_line[0] == 'v':
            return 0, 1, float
        elif clean_line[0] == 'f':
            # Any face that is not a triangle flags the mesh as quad/ngon.
            is_quad = is_quad or len(clean_line) != 4
            return 1, 1, int
        return False

    def fetch(lst: list, idx: int, dtype: type):
        # Split 'v/vt/vn'-style face tokens, keeping only the vertex index;
        # uv indices are currently unused.
        uv_vs_ids = None
        if '/' in lst[idx]:
            lst = [item.split('/') for item in lst[idx:]]
            lst = [item[0] for item in lst]
            idx = 0
        face_vs_ids = [dtype(c.split('/')[0]) for c in lst[idx:]]
        if dtype is float and len(face_vs_ids) > 3:
            # Vertex lines keep only x, y, z.
            face_vs_ids = face_vs_ids[:3]
        return face_vs_ids, uv_vs_ids

    def load_from_txt(parser) -> TS:
        # mesh_[0] collects vertices, mesh_[1] collects faces.
        mesh_ = [[], []]
        with open(file_name, 'r') as f:
            for line in f:
                clean_line = line.strip().split()
                info = parser(clean_line)
                if not info:
                    continue
                data = fetch(clean_line, info[1], info[2])
                mesh_[info[0]].append(data[0])
        if is_quad:
            # Ragged faces cannot become a tensor; shift to 0-based in place.
            faces = mesh_[1]
            for face in faces:
                for i in range(len(face)):
                    face[i] -= 1
        else:
            faces = torch.tensor(mesh_[1], dtype=torch.int64)
            # .obj indices are 1-based; only shift when nothing is 0 already.
            if len(faces) > 0 and faces.min() != 0:
                faces -= 1
        mesh_ = torch.tensor(mesh_[0], dtype=torch.float32), faces
        return mesh_

    # Resolve the file by trying the known mesh suffixes.
    for suffix in ['.obj', '.off', '.ply']:
        file_name_tmp = add_suffix(file_name, suffix)
        if os.path.isfile(file_name_tmp):
            file_name = file_name_tmp
            break
    is_quad = False
    name, extension = os.path.splitext(file_name)
    if extension == '.obj':
        mesh = load_from_txt(obj_parser)
    elif extension == '.off':
        mesh = load_from_txt(off_parser())
    elif extension == '.ply':
        mesh = load_ply(file_name)
    else:
        raise ValueError(f'mesh file {file_name} is not exist or not supported')
    # Sanity check: every face index must reference an existing vertex.
    if type(mesh[1]) is T and not ((mesh[1] >= 0) * (mesh[1] < mesh[0].shape[0])).all():
        print(f"err: {file_name}")
    assert type(mesh[1]) is not T or ((mesh[1] >= 0) * (mesh[1] < mesh[0].shape[0])).all()
    if dtype is V:
        mesh = mesh[0].numpy(), mesh[1].numpy()
    elif device != CPU:
        mesh = mesh[0].to(device), mesh[1].to(device)
    # No faces but some vertices: return the point cloud alone.
    if len(mesh[1]) == 0 and len(mesh[0]) > 0:
        return mesh[0]
    return mesh
@path_init('.xyz', 1, True)
def export_xyz(pc: T, path: str, normals: Optional[T] = None):
    """Write a point cloud (and optional per-point normals) as an .xyz file."""
    points = pc.tolist()
    normal_rows = normals.tolist() if normals is not None else None
    lines = []
    for row_idx, (x, y, z) in enumerate(points):
        line = f'{x} {y} {z}'
        if normal_rows is not None:
            nx, ny, nz = normal_rows[row_idx]
            line += f' {nx} {ny} {nz}'
        lines.append(line)
    with open(path, 'w') as f:
        # Newline-separated with no trailing newline, same as the index loop.
        f.write('\n'.join(lines))
@path_init('.txt', 2, True)
def export_gmm(gmm: TS, item: int, file_name: str, included: Optional[List[int]] = None):
    """Write one item's GMM parameters as text: one line each for phi, mu,
    eigen, p (5-decimal floats), then a 0/1 inclusion-mask line.

    NOTE(review): assumes each tensor in `gmm` is indexable as
    [item, 0, ...] with the number of Gaussians at gmm[0].shape[2] --
    confirm against the callers.
    """
    if included is None:
        # Default: mark every Gaussian as included.
        included = [1] * gmm[0].shape[2]
    mu, p, phi, eigen = [tensor[item, 0].flatten().cpu() for tensor in gmm]
    # phi = phi.softmax(0)
    with open(file_name, 'w') as f:
        for tensor in (phi, mu, eigen, p):
            tensor_str = [f'{number:.5f}' for number in tensor.tolist()]
            f.write(f"{' '.join(tensor_str)}\n")
        list_str = [f'{number:d}' for number in included]
        f.write(f"{' '.join(list_str)}\n")
@path_init('.txt', 0, False)
def load_gmm(path, as_np: bool = False, device: D = CPU):
    """Parse a GMM text file back into [phi, mu, eigen, p, included].

    Lines 1-2 are reshaped to (-1, 3), line 3 to (-1, 3, 3) and line 4
    becomes a boolean inclusion mask. Arrays are numpy when as_np is set,
    otherwise torch tensors on *device*.
    """
    parsed = []
    with open(path, 'r') as f:
        raw_lines = [ln.strip() for ln in f]
    for line_idx, raw in enumerate(raw_lines):
        values = [float(tok) for tok in raw.split(" ")]
        if as_np:
            data = V(values)
        else:
            data = torch.tensor(values, device=device)
        if 0 < line_idx < 3:
            data = data.reshape((-1, 3))
        elif line_idx == 3:
            data = data.reshape((-1, 3, 3))
        elif line_idx == 4:
            data = data.astype(np.bool_) if as_np else data.bool()
        parsed.append(data)
    return parsed
@path_init('.txt', 1, True)
def export_list(lst: List[Any], path: str):
    """Write each element of *lst* on its own line of a '.txt' file."""
    with open(path, "w") as f:
        f.writelines(f'{item}\n' for item in lst)
@path_init('.obj', 1, True)
def export_mesh(mesh: Union[V_Mesh, T_Mesh, T, Tuple[T, List[List[int]]]], file_name: str,
                colors: Optional[COLORS] = None, normals: TN = None, edges=None, spheres=None):
    """Export a mesh (or bare point cloud) to a Wavefront .obj file.

    Optionally writes per-vertex colors, vertex normals, edge records ('e')
    and sphere records ('sp'); the latter two are non-standard .obj
    extensions used by this project's tooling.
    """
    # return
    if type(mesh) is not tuple and type(mesh) is not list:
        # A bare vertex tensor: treat it as a point cloud with no faces.
        mesh = mesh, None
    vs, faces = mesh
    if vs.shape[1] < 3:
        # Pad vertices with zero columns up to full xyz.
        vs = torch.cat((vs, torch.zeros(len(vs), 3 - vs.shape[1], device=vs.device)), dim=1)
    if colors is not None:
        colors = colors_to_colors(colors, mesh)
    if not os.path.isdir(os.path.dirname(file_name)):
        # Silently skip when the destination folder does not exist.
        return
    if faces is not None:
        if type(faces) is T:
            # .obj face indices are 1-based.
            faces: T = faces + 1
            faces_lst = faces.tolist()
        else:
            # Ragged (quad/ngon) faces arrive as plain lists; shift each index.
            faces_lst_: List[List[int]] = faces
            faces_lst = []
            for face in faces_lst_:
                faces_lst.append([face[i] + 1 for i in range(len(face))])
    with open(file_name, 'w') as f:
        for vi, v in enumerate(vs):
            # A negative first channel marks "no color" for this vertex.
            if colors is None or colors[vi, 0] < 0:
                v_color = ''
            else:
                v_color = ' %f %f %f' % (colors[vi, 0].item(), colors[vi, 1].item(), colors[vi, 2].item())
            f.write("v %f %f %f%s\n" % (v[0], v[1], v[2], v_color))
        if normals is not None:
            for n in normals:
                f.write("vn %f %f %f\n" % (n[0], n[1], n[2]))
        if faces is not None:
            for face in faces_lst:
                face = [str(f) for f in face]
                f.write(f'f {" ".join(face)}\n')
        if edges is not None:
            for edges_id in range(edges.shape[0]):
                f.write(f'\ne {edges[edges_id][0].item():d} {edges[edges_id][1].item():d}')
        if spheres is not None:
            for sphere_id in range(spheres.shape[0]):
                f.write(f'\nsp {spheres[sphere_id].item():d}')
@path_init('.ply', 1, True)
def export_ply(mesh: T_Mesh, path: str, colors: T):
    """Write a colored triangle mesh as an ASCII PLY file.

    Vertices are y/z-swapped, centered, scaled by the global max coordinate
    and lifted so min z is zero before writing; colors become 0-255 uchars.
    """
    colors = colors_to_colors(colors, mesh)
    colors = (colors * 255).long()
    vs, faces = mesh
    vs = vs.clone()
    # Swap the y and z axes.
    tmp = vs[:, 1].clone()
    vs[:, 1] = vs[:, 2]
    vs[:, 2] = tmp
    lo, hi = vs.min(0)[0], vs.max(0)[0]
    vs = vs - ((lo + hi) / 2)[None, :]
    vs = vs / vs.max()
    vs[:, 2] = vs[:, 2] - vs[:, 2].min()
    header = (
        f'ply\nformat ascii 1.0\n'
        f'element vertex {vs.shape[0]:d}\nproperty float x\nproperty float y\nproperty float z\n'
        f'property uchar red\nproperty uchar green\nproperty uchar blue\n'
        f'element face {faces.shape[0]:d}\nproperty list uchar int vertex_indices\nend_header\n'
    )
    with open(path, 'w') as f:
        f.write(header)
        for vi, v in enumerate(vs):
            rgb = f'{colors[vi, 0].item():d} {colors[vi, 1].item():d} {colors[vi, 2].item():d}'
            f.write(f'{v[0].item():f} {v[1].item():f} {v[2].item():f} {rgb}\n')
        for face in faces:
            f.write(f'3 {face[0].item():d} {face[1].item():d} {face[2].item():d}\n')
@path_init('.ply', 0, False)
def load_ply(path: str):
    """Read a PLY file into a (vertex tensor, face tensor) pair."""
    import plyfile  # local import: plyfile is only needed for .ply meshes
    ply_data = plyfile.PlyData.read(path)
    raw_vertices = ply_data.elements[0].data
    vertices = torch.tensor([[float(v[0]), float(v[1]), float(v[2])] for v in raw_vertices])
    raw_faces = ply_data.elements[1].data
    faces = torch.tensor([[int(f[0][0]), int(f[0][1]), int(f[0][2])] for f in raw_faces])
    return vertices, faces
@path_init('', 1, True)
def save_model(model: Union[Optimizer, nn.Module], model_path: str):
    """Persist a model's or optimizer's state_dict (skipped in DEBUG mode)."""
    if not const.DEBUG:
        init_folders(model_path)
        torch.save(model.state_dict(), model_path)
def load_model(model: Union[Optimizer, nn.Module], model_path: str, device: D, verbose: bool = False):
    """Load weights into *model* from *model_path* when the file exists.

    Returns the (possibly updated) model; logs progress when verbose.
    """
    found = os.path.isfile(model_path)
    if found:
        model.load_state_dict(torch.load(model_path, map_location=device))
    if verbose:
        if found:
            print(f'loading {type(model).__name__} from {model_path}')
        else:
            print(f'init {type(model).__name__}')
    return model
def measure_time(func, num_iters: int, *args):
    """Call *func* with *args* num_iters times and print total/average wall time."""
    t0 = time.time()
    for _ in range(num_iters):
        func(*args)
    total_time = time.time() - t0
    avg_time = total_time / num_iters
    print(f"{str(func).split()[1].split('.')[-1]} total time: {total_time}, average time: {avg_time}")
def get_time_name(name: str, format_="%m_%d-%H_%M") -> str:
    """Append the current local time (formatted) to *name* with an underscore."""
    timestamp = time.strftime(format_)
    return f'{name}_{timestamp}'
@path_init('.txt', 0, False)
def load_shapenet_seg(path: str) -> TS:
    """Parse a ShapeNet part-segmentation file: xyz coords + an int label per line."""
    labels, points = [], []
    with open(path, 'r') as f:
        for raw in f:
            tokens = raw.strip().split()
            points.append([float(t) for t in tokens[:3]])
            # Labels may be written as '3.0'; keep the integer part only.
            labels.append(int(tokens[-1].split('.')[0]))
    return torch.tensor(points, dtype=torch.float32), torch.tensor(labels, dtype=torch.int64)
@path_init('.json', 0, False)
def load_json(path: str):
    """Load and return the JSON document stored at *path*."""
    with open(path, 'r') as f:
        return json.load(f)