# MuseVSpace / MuseV — musev/data/data_util.py
# (imported via commit 96d7ad8 "add musev" by anchorxia)
from typing import List, Dict, Literal, Union, Tuple
import os
import string
import logging
import torch
import numpy as np
from einops import rearrange, repeat
logger = logging.getLogger(__name__)
def generate_tasks_of_dir(
    path: str,
    output_dir: str,
    exts: Tuple[str],
    same_dir_name: bool = False,
    **kwargs,
) -> List[Dict]:
    """Convert a video directory into a list of task dicts.

    Walks ``path`` recursively and builds one task per file whose name ends
    with one of ``exts``; each task records the source video path and where
    the converted ``.h5py`` output should be stored.

    Args:
        path (str): root directory to scan for video files.
        output_dir (str): directory where outputs will be written.
        exts (Tuple[str]): accepted extensions; matched case-insensitively
            against the lowercased basename, so entries should be lowercase.
        same_dir_name (bool, optional): 存储路径是否保留和源视频相同的父文件名.
            Whether to keep the same parent dir name as the source video.
            Defaults to False.
        **kwargs: extra key/values copied into every task dict.

    Returns:
        List[Dict]: one dict per matched file, with keys video_path,
            output_path, output_dir, filename, ext, plus ``kwargs``.
    """
    tasks = []
    for rootdir, dirs, files in os.walk(path):
        for basename in files:
            if basename.lower().endswith(exts):
                video_path = os.path.join(rootdir, basename)
                # rsplit keeps dots inside the stem intact, e.g. "a.b.mp4"
                filename, ext = basename.rsplit(".", 1)
                rootdir_name = os.path.basename(rootdir)
                if same_dir_name:
                    save_path = os.path.join(
                        output_dir, rootdir_name, f"{filename}.h5py"
                    )
                    save_dir = os.path.join(output_dir, rootdir_name)
                else:
                    save_path = os.path.join(output_dir, f"{filename}.h5py")
                    save_dir = output_dir
                task = {
                    "video_path": video_path,
                    "output_path": save_path,
                    "output_dir": save_dir,
                    "filename": filename,
                    "ext": ext,
                }
                task.update(kwargs)
                tasks.append(task)
    return tasks
def sample_by_idx(
    T: int,
    n_sample: int,
    sample_rate: int,
    sample_start_idx: int = None,
    change_sample_rate: bool = False,
    seed: int = None,
    whether_random: bool = True,
    n_independent: int = 0,
) -> List[int]:
    """Sample ``n_sample`` indices with stride ``sample_rate`` from ``range(T)``.

    Args:
        T (int): size of the candidate index range ``[0, T)``.
        n_sample (int): 目标采样数目. target sample number.
        sample_rate (int): 采样率, 每隔sample_rate个采样一个. sample interval,
            pick one per sample_rate frames.
        sample_start_idx (int, optional): 采样开始位置的选择. start position to
            sample; when None it is chosen randomly (or 0 when not random).
            Defaults to None.
        change_sample_rate (bool, optional): 是否可以通过降低sample_rate的方式来完成采样.
            whether allow decreasing sample_rate to make sampling feasible.
            Defaults to False.
        seed (int, optional): numpy seed for the random start. Defaults to None.
        whether_random (bool, optional): 是否最后随机选择开始点. whether to randomly
            choose the sample start position. Defaults to True.
        n_independent (int, optional): number of extra indices drawn outside
            the strided window. Defaults to 0.

    Raises:
        ValueError: T / sample_rate should be larger than n_sample.

    Returns:
        Tuple[List[int], int, np.ndarray | None]: 采样的索引位置 sampled indices,
        the (possibly lowered) sample_rate, and the independent indices
        (None when n_independent == 0).
    """
    if T < n_sample:
        raise ValueError(f"T({T}) < n_sample({n_sample})")
    else:
        if T / sample_rate < n_sample:
            if not change_sample_rate:
                raise ValueError(
                    f"T({T}) / sample_rate({sample_rate}) < n_sample({n_sample})"
                )
            else:
                while T / sample_rate < n_sample:
                    sample_rate -= 1
                    logger.error(
                        f"sample_rate{sample_rate+1} is too large, decrease to {sample_rate}"
                    )
                    # guard inside the loop: re-evaluating the while condition
                    # with sample_rate == 0 would raise ZeroDivisionError
                    if sample_rate == 0:
                        raise ValueError("T / sample_rate < n_sample")

    if sample_start_idx is None:
        if whether_random:
            # +1 so the last feasible start (T - n_sample * sample_rate) is a
            # candidate; without it, T == n_sample * sample_rate made the
            # candidate array empty and np.random.choice crashed
            sample_start_idx_candidates = np.arange(T - n_sample * sample_rate + 1)
            if seed is not None:
                np.random.seed(seed)
            sample_start_idx = np.random.choice(sample_start_idx_candidates, 1)[0]
        else:
            sample_start_idx = 0
    sample_end_idx = sample_start_idx + sample_rate * n_sample
    sample = list(range(sample_start_idx, sample_end_idx, sample_rate))
    if n_independent == 0:
        n_independent_sample = None
    else:
        left_candidate = np.array(
            list(range(0, sample_start_idx)) + list(range(sample_end_idx, T))
        )
        if len(left_candidate) >= n_independent:
            # 使用两端的剩余空间采样, use the left space to sample
            n_independent_sample = np.random.choice(left_candidate, n_independent)
        else:
            # 当两端没有剩余采样空间时,使用任意不是sample中的帧
            # if no enough space to sample, use any frame not in sample
            # fix: was `set(range(T) - set(sample))`, a TypeError (range minus set)
            left_candidate = np.array(list(set(range(T)) - set(sample)))
            n_independent_sample = np.random.choice(left_candidate, n_independent)
    return sample, sample_rate, n_independent_sample
def sample_tensor_by_idx(
    tensor: Union[torch.Tensor, np.ndarray],
    n_sample: int,
    sample_rate: int,
    sample_start_idx: int = 0,
    change_sample_rate: bool = False,
    seed: int = None,
    dim: int = 0,
    return_type: Literal["numpy", "torch"] = "torch",
    whether_random: bool = True,
    n_independent: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]:
    """Sample a sub-tensor along ``dim`` using :func:`sample_by_idx`.

    Args:
        tensor (Union[torch.Tensor, np.ndarray]): input data; ndarray inputs
            are converted to torch tensors first.
        n_sample (int): number of strided samples to take.
        sample_rate (int): sampling stride, see :func:`sample_by_idx`.
        sample_start_idx (int, optional): start position. Defaults to 0.
        change_sample_rate (bool, optional): allow lowering the stride when
            infeasible. Defaults to False.
        seed (int, optional): numpy seed for random start. Defaults to None.
        dim (int, optional): dimension to sample along. Defaults to 0.
        return_type (Literal["numpy", "torch"], optional): format of the main
            sample. Defaults to "torch".
        whether_random (bool, optional): random start position. Defaults to True.
        n_independent (int, optional): 独立于n_sample的采样数量. number of extra
            samples independent of n_sample. Defaults to 0.

    Returns:
        Tuple: (sampled tensor, sample indices, effective sample_rate,
        independent sample tensor or None, independent indices or None).
    """
    if isinstance(tensor, np.ndarray):
        tensor = torch.from_numpy(tensor)
    n_total = tensor.shape[dim]
    picked_idx, sample_rate, extra_idx = sample_by_idx(
        n_total,
        n_sample,
        sample_rate,
        sample_start_idx,
        change_sample_rate,
        seed,
        whether_random=whether_random,
        n_independent=n_independent,
    )
    picked_idx = torch.LongTensor(picked_idx)
    picked = torch.index_select(tensor, dim, picked_idx)
    if extra_idx is None:
        extra = None
    else:
        extra_idx = torch.LongTensor(extra_idx)
        extra = torch.index_select(tensor, dim, extra_idx)
    if return_type == "numpy":
        picked = picked.cpu().numpy()
    return picked, picked_idx, sample_rate, extra, extra_idx
def concat_two_tensor(
    data1: torch.Tensor,
    data2: torch.Tensor,
    dim: int,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index"
    ] = "first_in_first_out",
    data1_index: torch.long = None,
    data2_index: torch.long = None,
    return_index: bool = False,
):
    """Concatenate two tensors along ``dim`` with a chosen ordering method.

    Args:
        data1 (torch.Tensor): first-in data.
        data2 (torch.Tensor): last-in data.
        dim (int): dimension to concatenate along.
        method: ordering strategy; "index" scatters the parts to explicit
            positions given by data1_index/data2_index.
        data1_index (torch.long, optional): target positions for data1
            (method "index" only).
        data2_index (torch.long, optional): target positions for data2
            (method "index" only).
        return_index (bool, optional): also return both index sequences.

    Raises:
        NotImplementedError: for method "intertwine".
        ValueError: for any unsupported method.

    Returns:
        torch.Tensor, or (tensor, data1_index, data2_index) when return_index.
    """
    len1, len2 = data1.shape[dim], data2.shape[dim]
    if method == "first_in_first_out":
        res = torch.concat([data1, data2], dim=dim)
        data1_index = range(len1)
        data2_index = [len1 + i for i in range(len2)]
    elif method == "first_in_last_out":
        res = torch.concat([data2, data1], dim=dim)
        data2_index = range(len2)
        data1_index = [len2 + i for i in range(len1)]
    elif method == "intertwine":
        raise NotImplementedError("intertwine")
    elif method == "index":
        res = concat_two_tensor_with_index(
            data1=data1,
            data1_index=data1_index,
            data2=data2,
            data2_index=data2_index,
            dim=dim,
        )
    else:
        raise ValueError(
            "only support first_in_first_out, first_in_last_out, intertwine, index"
        )
    if return_index:
        return res, data1_index, data2_index
    return res
def concat_two_tensor_with_index(
    data1: torch.Tensor,
    data1_index: torch.LongTensor,
    data2: torch.Tensor,
    data2_index: torch.LongTensor,
    dim: int,
) -> torch.Tensor:
    """Scatter two tensors into a fresh combined tensor at given positions.

    Args:
        data1 (torch.Tensor): b1*c1*h1*w1*...
        data1_index (torch.LongTensor): N, if dim=1, N=c1
        data2 (torch.Tensor): b2*c2*h2*w2*...
        data2_index (torch.LongTensor): M, if dim=1, M=c2
        dim (int): dimension whose sizes are summed.

    Returns:
        torch.Tensor: b*c*h*w*..., if dim=1 then b=b1=b2, c=c1+c2, h=h1=h2, ...
    """
    out_shape = list(data1.shape)
    out_shape[dim] = data1.shape[dim] + data2.shape[dim]
    out = torch.zeros(out_shape, device=data1.device, dtype=data1.dtype)
    # copy each part into its slots along `dim`
    for part, slots in ((data1, data1_index), (data2, data2_index)):
        out = batch_index_copy(out, dim=dim, index=slots, source=part)
    return out
def repeat_index_to_target_size(
    index: torch.LongTensor, target_size: int
) -> torch.LongTensor:
    """Expand a 1D or 2D index tensor to ``target_size`` rows.

    A 1D index is first lifted so every batch row shares the same indices;
    a 2D index is row-repeated, which requires target_size to be a multiple
    of the current row count.
    """
    if len(index.shape) == 1:
        # (n,) -> (target_size, n): same indices for every batch row
        index = index.unsqueeze(0).repeat(target_size, 1)
    if len(index.shape) == 2:
        remainder = target_size % index.shape[0]
        assert (
            remainder == 0
        ), f"target_size % index.shape[0] must be zero, but give {target_size % index.shape[0]}"
        # repeat each row consecutively, equivalent to einops "b n -> (b c) n"
        index = index.repeat_interleave(target_size // index.shape[0], dim=0)
    return index
def batch_concat_two_tensor_with_index(
    data1: torch.Tensor,
    data1_index: torch.LongTensor,
    data2: torch.Tensor,
    data2_index: torch.LongTensor,
    dim: int,
) -> torch.Tensor:
    """Batch-aware alias of :func:`concat_two_tensor_with_index`."""
    return concat_two_tensor_with_index(
        data1=data1,
        data1_index=data1_index,
        data2=data2,
        data2_index=data2_index,
        dim=dim,
    )
def interwine_two_tensor(
    data1: torch.Tensor,
    data2: torch.Tensor,
    dim: int,
    return_index: bool = False,
) -> torch.Tensor:
    """Interleave two tensors along ``dim``.

    data1 occupies the even slots 0, 2, ..., 2*len(data1)-2; data2 fills all
    remaining slots in order.

    Args:
        data1 (torch.Tensor): tensor placed at even positions.
        data2 (torch.Tensor): tensor placed at the remaining positions.
        dim (int): dimension to interleave along.
        return_index (bool, optional): also return both slot index tensors.

    Returns:
        torch.Tensor, or (tensor, data1_index, data2_index) when return_index.
    """
    n1 = data1.shape[dim]
    n2 = data2.shape[dim]
    merged_shape = list(data1.shape)
    merged_shape[dim] = n1 + n2
    merged = torch.zeros(merged_shape, device=data1.device, dtype=data1.dtype)
    # move `dim` to the front so slots can be addressed with row indexing
    merged = torch.swapaxes(merged, 0, dim)
    even_slots = torch.LongTensor(list(range(0, 2 * n1, 2)))
    rest_slots = torch.LongTensor(
        sorted(set(range(n1 + n2)) - set(even_slots.tolist()))
    )
    merged[even_slots, ...] = torch.swapaxes(data1, 0, dim)
    merged[rest_slots, ...] = torch.swapaxes(data2, 0, dim)
    merged = torch.swapaxes(merged, 0, dim)
    if return_index:
        return merged, even_slots, rest_slots
    return merged
def split_index(
    indexs: torch.Tensor,
    n_first: int = None,
    n_last: int = None,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
    ] = "first_in_first_out",
):
    """Split an index tensor into a first part and a last part.

    Args:
        indexs (torch.Tensor): 1D tensor of indices to split.
        n_first (int, optional): size of the first part; inferred from
            n_last when None.
        n_last (int, optional): size of the last part; inferred from
            n_first when None.
        method: how the elements are assigned to the two parts.

    Raises:
        ValueError: when neither n_first nor n_last is given, or the method
            is unsupported.
        NotImplementedError: for method "intertwine".

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (first_index, last_index).
    """
    # fix: the commented-out assert intended this guard but its condition was
    # inverted; without it both-None inputs raised an opaque TypeError below
    if n_first is None and n_last is None:
        raise ValueError("must assign one value for n_first or n_last")
    n_total = len(indexs)
    if n_first is None:
        n_first = n_total - n_last
    if n_last is None:
        n_last = n_total - n_first
    assert len(indexs) == n_first + n_last
    if method == "first_in_first_out":
        first_index = indexs[:n_first]
        last_index = indexs[n_first:]
    elif method == "first_in_last_out":
        first_index = indexs[n_last:]
        last_index = indexs[:n_last]
    elif method == "random":
        idx_ = torch.randperm(len(indexs))
        first_index = indexs[idx_[:n_first]]
        last_index = indexs[idx_[n_first:]]
    elif method == "intertwine":
        raise NotImplementedError
    else:
        # fix: unknown methods (including the declared-but-unhandled "index")
        # previously fell through and raised UnboundLocalError at the return
        raise ValueError(f"unsupported method {method}")
    return first_index, last_index
def split_tensor(
    tensor: torch.Tensor,
    dim: int,
    n_first=None,
    n_last=None,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
    ] = "first_in_first_out",
    need_return_index: bool = False,
):
    """Split ``tensor`` along ``dim`` into two parts via :func:`split_index`.

    Args:
        tensor (torch.Tensor): input tensor.
        dim (int): dimension to split along.
        n_first (int, optional): size of the first part; inferred when None.
        n_last (int, optional): size of the last part; inferred when None.
        method: assignment strategy, see :func:`split_index`.
        need_return_index (bool, optional): also return the index tensors.

    Returns:
        (first_tensor, last_tensor) or
        (first_tensor, last_tensor, first_index, last_index).
    """
    total = tensor.shape[dim]
    if n_first is None:
        n_first = total - n_last
    if n_last is None:
        n_last = total - n_first
    all_index = torch.arange(total, dtype=torch.long, device=tensor.device)
    first_index, last_index = split_index(
        indexs=all_index,
        n_first=n_first,
        method=method,
    )
    first_part = torch.index_select(tensor, dim=dim, index=first_index)
    last_part = torch.index_select(tensor, dim=dim, index=last_index)
    if need_return_index:
        return first_part, last_part, first_index, last_index
    return first_part, last_part
# TODO: optimization of batch_index_select still to be determined
def batch_index_select(
    tensor: torch.Tensor, index: torch.LongTensor, dim: int
) -> torch.Tensor:
    """Per-batch ``index_select``.

    Args:
        tensor (torch.Tensor): D1*D2*D3*D4...
        index (torch.LongTensor): D1*N or N, with N <= tensor.shape[dim].
        dim (int): dimension to select along.

    Returns:
        torch.Tensor: D1*...*N*...
    """
    # TODO: currently every batch row must select the same number of elements
    if index.ndim == 1:
        return torch.index_select(tensor, dim=dim, index=index)
    index = repeat_index_to_target_size(index, tensor.shape[0])
    # select row-by-row with the batch dim stripped, then re-stack
    selected = [
        torch.index_select(tensor[b], dim=dim - 1, index=index[b])
        for b in range(tensor.shape[0])
    ]
    return torch.stack(selected).to(dtype=tensor.dtype)
def batch_index_copy(
    tensor: torch.Tensor, dim: int, index: torch.LongTensor, source: torch.Tensor
) -> torch.Tensor:
    """Per-batch in-place ``index_copy_``.

    Args:
        tensor (torch.Tensor): b*c*h
        dim (int): dimension of ``tensor`` to copy into.
        index (torch.LongTensor): b*d (per batch row) or d (shared).
        source (torch.Tensor):
            b*d*h*..., if dim=1
            b*c*d*..., if dim=2

    Returns:
        torch.Tensor: the mutated input tensor, b*c*d*...
    """
    if index.ndim == 1:
        tensor.index_copy_(dim=dim, index=index, source=source)
        return tensor
    index = repeat_index_to_target_size(index, tensor.shape[0])
    # copy row-by-row, shifting `dim` down by one for the stripped batch axis
    for b in range(tensor.shape[0]):
        row = tensor[b]
        row.index_copy_(dim=dim - 1, index=index[b], source=source[b])
        tensor[b] = row
    return tensor
def batch_index_fill(
    tensor: torch.Tensor,
    dim: int,
    index: torch.LongTensor,
    value: Union[torch.Tensor, float],
) -> torch.Tensor:
    """Per-batch in-place ``index_fill_``.

    Args:
        tensor (torch.Tensor): b*c*h
        dim (int): dimension of ``tensor`` to fill along.
        index (torch.LongTensor): b*d or d (broadcast across the batch).
        value (Union[torch.Tensor, float]): per-batch scalars (shape b) or a
            single scalar shared by every batch row.
            Fix: the annotation previously used typing.Literal, which only
            accepts literal values, not types — Union matches the isinstance
            check below.

    Returns:
        torch.Tensor: the mutated input tensor, b*c*d*...
    """
    index = repeat_index_to_target_size(index, tensor.shape[0])
    batch_size = tensor.shape[0]
    for b in torch.arange(batch_size):
        sub_index = index[b]
        # per-row scalar when value is a tensor, otherwise the shared scalar
        sub_value = value[b] if isinstance(value, torch.Tensor) else value
        sub_tensor = tensor[b]
        sub_tensor.index_fill_(dim - 1, sub_index, sub_value)
        tensor[b] = sub_tensor
    return tensor
def adaptive_instance_normalization(
    src: torch.Tensor,
    dst: torch.Tensor,
    eps: float = 1e-6,
):
    """AdaIN: re-normalize ``src`` to match the instance statistics of ``dst``.

    Args:
        src (torch.Tensor): b c t h w (also supports b c h w and b c t)
        dst (torch.Tensor): same layout as ``src``; its batch dim is aligned
            to ``src`` via align_repeat_tensor_single_dim.
        eps (float, optional): variance floor for numerical stability.
            Defaults to 1e-6.

    Raises:
        ValueError: if src.ndim is not 3, 4 or 5.

    Returns:
        torch.Tensor: src whitened by its own stats, then rescaled/shifted by
        dst's std and mean.
    """
    ndim = src.ndim
    if ndim == 5:
        dim = (2, 3, 4)
    elif ndim == 4:
        dim = (2, 3)
    elif ndim == 3:
        dim = 2
    else:
        # fix: message was missing its f-prefix and printed "{ndim}" literally
        raise ValueError(f"only support ndim in [3,4,5], but given {ndim}")
    var, mean = torch.var_mean(src, dim=dim, keepdim=True, correction=0)
    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
    dst = align_repeat_tensor_single_dim(dst, src.shape[0], dim=0)
    # fix: torch.var_mean returns (var, mean); the unpacking names were
    # swapped before, so the result was scaled by dst's mean and shifted by
    # dst's variance instead of proper AdaIN
    var_acc, mean_acc = torch.var_mean(dst, dim=dim, keepdim=True, correction=0)
    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
    src = (((src - mean) / std) * std_acc) + mean_acc
    return src
def adaptive_instance_normalization_with_ref(
    src: torch.LongTensor,
    dst: torch.LongTensor,
    style_fidelity: float = 0.5,
    do_classifier_free_guidance: bool = True,
):
    """AdaIN blend that is aware of classifier-free guidance.

    Assumes the batch is laid out as [uncond half, cond half] (size 2*b):
    when cfg is active, the unconditional half keeps the raw ``src`` before
    blending with ``style_fidelity``.
    """
    half = src.shape[0] // 2
    # mask marking the first (unconditional) half of the batch
    uc_mask = torch.Tensor([1] * half + [0] * half).type_as(src).bool()
    normed = adaptive_instance_normalization(src, dst)
    blended = normed.clone()
    # TODO(original): this path assumes
    # do_classifier_free_guidance and style_fidelity > 0 by default
    if do_classifier_free_guidance and style_fidelity > 0:
        blended[uc_mask] = src[uc_mask]
    return style_fidelity * blended + (1.0 - style_fidelity) * normed
def batch_adain_conditioned_tensor(
    tensor: torch.Tensor,
    src_index: torch.LongTensor,
    dst_index: torch.LongTensor,
    keep_dim: bool = True,
    num_frames: int = None,
    dim: int = 2,
    style_fidelity: float = 0.5,
    do_classifier_free_guidance: bool = True,
    need_style_fidelity: bool = False,
):
    """Apply AdaIN from frames at ``dst_index`` onto frames at ``src_index``.

    Args:
        tensor (torch.Tensor): b c t h w, or (b t) c h w when num_frames is
            given (it is unflattened first).
        src_index (torch.LongTensor): frame indices to be re-normalized.
        dst_index (torch.LongTensor): frame indices providing the statistics.
        keep_dim (bool, optional): return a full-size tensor with the src
            frames replaced in place. Defaults to True.
        num_frames (int, optional): frame count used to unflatten 4D input.
        dim (int, optional): frame dimension. Defaults to 2.
        style_fidelity (float, optional): blend weight for the cfg-aware path.
        do_classifier_free_guidance (bool, optional): forwarded to the
            cfg-aware AdaIN variant. Defaults to True.
        need_style_fidelity (bool, optional): use the cfg-aware AdaIN variant.

    Returns:
        torch.Tensor: AdaIN-ed tensor, flattened back to (b t) c h w when the
        input was 4D.
    """
    ndim = tensor.ndim
    dtype = tensor.dtype
    if ndim == 4 and num_frames is not None:
        tensor = rearrange(tensor, "(b t) c h w-> b c t h w ", t=num_frames)
    src = batch_index_select(tensor, dim=dim, index=src_index).contiguous()
    dst = batch_index_select(tensor, dim=dim, index=dst_index).contiguous()
    if need_style_fidelity:
        # fix: adaptive_instance_normalization_with_ref has no
        # need_style_fidelity parameter; passing it raised TypeError
        src = adaptive_instance_normalization_with_ref(
            src=src,
            dst=dst,
            style_fidelity=style_fidelity,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
    else:
        src = adaptive_instance_normalization(
            src=src,
            dst=dst,
        )
    if keep_dim:
        src = batch_concat_two_tensor_with_index(
            src.to(dtype=dtype),
            src_index,
            dst.to(dtype=dtype),
            dst_index,
            dim=dim,
        )
    if ndim == 4 and num_frames is not None:
        # fix: flatten the computed result, not the untouched input tensor
        # (the old code returned the input reshaped, discarding the AdaIN)
        src = rearrange(src, "b c t h w ->(b t) c h w")
    return src
def align_repeat_tensor_single_dim(
    src: torch.Tensor,
    target_length: int,
    dim: int = 0,
    n_src_base_length: int = 1,
    src_base_index: List[int] = None,
) -> torch.Tensor:
    """沿着 dim 纬度, 补齐 src 的长度到目标 target_length。
    Align the length of ``src`` along ``dim`` to ``target_length``.

    When src is shorter than target_length, repeat it (or, when the lengths
    do not divide evenly, repeat only its first n_src_base_length entries or
    the entries at src_base_index); when longer, truncate.

    Args:
        src (torch.Tensor): 输入 tensor, input tensor.
        target_length (int): 目标长度, target length along ``dim``.
        dim (int, optional): 处理纬度, target dim. Defaults to 0.
        n_src_base_length (int, optional): src 的基本单元长度, basic repeat unit
            length of src. Defaults to 1.
        src_base_index (List[int], optional): explicit indices of the repeat
            unit; overrides n_src_base_length when given.

    Returns:
        torch.Tensor: tensor of size target_length along ``dim``.
    """
    cur_length = src.shape[dim]
    if target_length == cur_length:
        return src
    if target_length < cur_length:
        # longer than needed: keep the leading target_length entries
        keep = torch.arange(target_length, dtype=torch.long, device=src.device)
        return src.index_select(dim=dim, index=keep)
    if target_length % cur_length == 0:
        # exact multiple: repeat every entry in place
        return src.repeat_interleave(repeats=target_length // cur_length, dim=dim)
    # uneven: repeat only a base unit of the tensor
    if src_base_index is None and n_src_base_length is not None:
        src_base_index = torch.arange(n_src_base_length)
    base = src.index_select(
        dim=dim,
        index=torch.LongTensor(src_base_index).to(device=src.device),
    )
    return base.repeat_interleave(
        repeats=target_length // len(src_base_index),
        dim=dim,
    )
def fuse_part_tensor(
    src: torch.Tensor,
    dst: torch.Tensor,
    overlap: int,
    weight: float = 0.5,
    skip_step: int = 0,
) -> torch.Tensor:
    """Blend the last ``overlap`` frames of ``src`` into ``dst``.

    out = src_fused_part * weight + dst * (1 - weight) over the overlap window.

    Args:
        src (torch.Tensor): b c t h w
        dst (torch.Tensor): b c t h w; mutated in place over the window.
        overlap (int): number of frames to blend; 0 returns dst untouched.
        weight (float, optional): weight of the src tensor part. Defaults to 0.5.
        skip_step (int, optional): offset into dst where the blended window
            starts. Defaults to 0.

    Returns:
        torch.Tensor: the fused tensor (same object as dst).
    """
    if overlap == 0:
        return dst
    window = slice(skip_step, skip_step + overlap)
    blended = weight * src[:, :, -overlap:] + (1 - weight) * dst[:, :, window]
    dst[:, :, window] = blended
    return dst