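# Evaluation data pipeline for text-driven motion generation:
#   - EvaluationDataset / MMGeneratedDataset wrap motions sampled from a trained generator,
#   - Text2MotionDatasetV2 serves ground-truth text-motion pairs for the test split,
#   - get_dataset_motion_loader / get_motion_loader build the corresponding DataLoaders,
#   - EvaluatorModelWrapper loads the pretrained text and motion encoders used by the metrics.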
import torch
from utils.word_vectorizer import WordVectorizer, POS_enumerator
from utils.get_opt import get_opt
from models import MotionTransformer
from torch.utils.data import Dataset, DataLoader
from os.path import join as pjoin
from tqdm import tqdm
import numpy as np
from .evaluator_models import *
import os
import codecs as cs
import random
from torch.utils.data._utils.collate import default_collate

class EvaluationDataset(Dataset):

    def __init__(self, opt, trainer, dataset, w_vectorizer, mm_num_samples, mm_num_repeats):
        assert mm_num_samples < len(dataset)
        print(opt.model_dir)

        dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
        epoch, it = trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar'))

        generated_motion = []
        min_mov_length = 10 if opt.dataset_name == 't2m' else 6

        trainer.eval_mode()
        trainer.to(opt.device)

        # Pre-process all target captions
        mm_generated_motions = []
        # Indices of the samples for which multiple motions are generated (multimodality metric)
        mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
        mm_idxs = np.sort(mm_idxs)

        all_caption = []
        all_m_lens = []
        all_data = []

        # First pass: collect captions and target lengths for every sample
        # (repeated mm_num_repeats times for the multimodality subset).
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader)):
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                all_data.append(data)
                tokens = tokens[0].split('_')
                mm_num_now = len(mm_generated_motions)
                is_mm = (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
                repeat_times = mm_num_repeats if is_mm else 1
                m_lens = max(m_lens // opt.unit_length * opt.unit_length, min_mov_length * opt.unit_length)
                m_lens = min(m_lens, opt.max_motion_length)
                if isinstance(m_lens, int):
                    m_lens = torch.LongTensor([m_lens]).to(opt.device)
                else:
                    m_lens = m_lens.to(opt.device)
                for t in range(repeat_times):
                    all_m_lens.append(m_lens)
                    all_caption.extend(caption)
                if is_mm:
                    mm_generated_motions.append(0)  # placeholder; only the count is used here
            all_m_lens = torch.stack(all_m_lens)

        # Generate all sequences at once
        with torch.no_grad():
            all_pred_motions = trainer.generate(all_caption, all_m_lens, opt.dim_pose)

        # Second pass: split the generated motions back into per-sample entries
        cur_idx = 0
        mm_generated_motions = []
        with torch.no_grad():
            for i, data_dummy in tqdm(enumerate(dataloader)):
                data = all_data[i]
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                tokens = tokens[0].split('_')
                mm_num_now = len(mm_generated_motions)
                is_mm = (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
                repeat_times = mm_num_repeats if is_mm else 1
                mm_motions = []
                m_lens = max(m_lens // opt.unit_length * opt.unit_length, min_mov_length * opt.unit_length)
                m_lens = min(m_lens, opt.max_motion_length)
                if isinstance(m_lens, int):
                    m_lens = torch.LongTensor([m_lens]).to(opt.device)
                else:
                    m_lens = m_lens.to(opt.device)
                for t in range(repeat_times):
                    m_len = m_lens[0].item()
                    pred_motions = all_pred_motions[cur_idx][:m_len]
                    assert pred_motions.shape[0] == m_len
                    cur_idx += 1
                    if t == 0:
                        sub_dict = {'motion': pred_motions.cpu().numpy(),
                                    'length': pred_motions.shape[0],
                                    'caption': caption[0],
                                    'cap_len': cap_lens[0].item(),
                                    'tokens': tokens}
                        generated_motion.append(sub_dict)
                    if is_mm:
                        mm_motions.append({
                            'motion': pred_motions.cpu().numpy(),
                            'length': m_len
                        })
                if is_mm:
                    mm_generated_motions.append({'caption': caption[0],
                                                 'tokens': tokens,
                                                 'cap_len': cap_lens[0].item(),
                                                 'mm_motions': mm_motions})

        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.opt = opt
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        data = self.generated_motion[item]
        motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
        sent_len = data['cap_len']

        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        # Zero-pad the motion to the fixed evaluation length
        if m_length < self.opt.max_motion_length:
            motion = np.concatenate([motion,
                                     np.zeros((self.opt.max_motion_length - m_length, motion.shape[1]))
                                     ], axis=0)
        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)

def collate_fn(batch):
    # Sort by caption length (descending); the evaluator's text encoder
    # expects length-sorted batches for sequence packing.
    batch.sort(key=lambda x: x[3], reverse=True)
    return default_collate(batch)

'''Ground-truth dataset, used for training the text-motion matching model and for evaluations.'''
class Text2MotionDatasetV2(Dataset):

    def __init__(self, opt, mean, std, split_file, w_vectorizer):
        self.opt = opt
        self.w_vectorizer = w_vectorizer
        self.max_length = 20
        self.pointer = 0
        self.max_motion_length = opt.max_motion_length
        min_motion_len = 40 if self.opt.dataset_name == 't2m' else 24

        data_dict = {}
        id_list = []
        with cs.open(split_file, 'r') as f:
            for line in f.readlines():
                id_list.append(line.strip())

        new_name_list = []
        length_list = []
        for name in tqdm(id_list):
            try:
                motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
                if (len(motion)) < min_motion_len or (len(motion) >= 200):
                    continue
                text_data = []
                flag = False
                with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
                    for line in f.readlines():
                        text_dict = {}
                        # Each annotation line is: caption#tokens#start_tag#end_tag
                        line_split = line.strip().split('#')
                        caption = line_split[0]
                        tokens = line_split[1].split(' ')
                        f_tag = float(line_split[2])
                        to_tag = float(line_split[3])
                        f_tag = 0.0 if np.isnan(f_tag) else f_tag
                        to_tag = 0.0 if np.isnan(to_tag) else to_tag

                        text_dict['caption'] = caption
                        text_dict['tokens'] = tokens
                        if f_tag == 0.0 and to_tag == 0.0:
                            flag = True
                            text_data.append(text_dict)
                        else:
                            try:
                                # The caption describes a sub-segment; crop it out
                                # (tags are converted to frame indices at 20 fps)
                                n_motion = motion[int(f_tag * 20): int(to_tag * 20)]
                                if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
                                    continue
                                new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                                while new_name in data_dict:
                                    new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
                                data_dict[new_name] = {'motion': n_motion,
                                                       'length': len(n_motion),
                                                       'text': [text_dict]}
                                new_name_list.append(new_name)
                                length_list.append(len(n_motion))
                            except:
                                print(line_split)
                                print(line_split[2], line_split[3], f_tag, to_tag, name)
                                # break

                if flag:
                    data_dict[name] = {'motion': motion,
                                       'length': len(motion),
                                       'text': text_data}
                    new_name_list.append(name)
                    length_list.append(len(motion))
            except:
                # Skip samples with missing or unreadable motion/text files
                pass

        name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))

        self.mean = mean
        self.std = std
        self.length_arr = np.array(length_list)
        self.data_dict = data_dict
        self.name_list = name_list
        self.reset_max_len(self.max_length)

    def reset_max_len(self, length):
        assert length <= self.max_motion_length
        self.pointer = np.searchsorted(self.length_arr, length)
        print("Pointer Pointing at %d" % self.pointer)
        self.max_length = length

    def inv_transform(self, data):
        return data * self.std + self.mean

    def __len__(self):
        return len(self.data_dict) - self.pointer

    def __getitem__(self, item):
        idx = self.pointer + item
        data = self.data_dict[self.name_list[idx]]
        motion, m_length, text_list = data['motion'], data['length'], data['text']

        # Randomly select a caption
        text_data = random.choice(text_list)
        caption, tokens = text_data['caption'], text_data['tokens']

        if len(tokens) < self.opt.max_text_len:
            # pad with "unk"
            tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
            sent_len = len(tokens)
            tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
        else:
            # crop
            tokens = tokens[:self.opt.max_text_len]
            tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
            sent_len = len(tokens)
        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        # Crop the motion length to a multiple of opt.unit_length, with a small random variation
        if self.opt.unit_length < 10:
            coin2 = np.random.choice(['single', 'single', 'double'])
        else:
            coin2 = 'single'
        if coin2 == 'double':
            m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
        elif coin2 == 'single':
            m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
        idx = random.randint(0, len(motion) - m_length)
        motion = motion[idx:idx + m_length]

        # Z-normalization
        motion = (motion - self.mean) / self.std

        if m_length < self.max_motion_length:
            motion = np.concatenate([motion,
                                     np.zeros((self.max_motion_length - m_length, motion.shape[1]))
                                     ], axis=0)
        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
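
    # The 7-tuple returned above has, with the option values this pipeline normally uses,
    # roughly these shapes:
    #   word_embeddings: (max_text_len + 2, dim_word)      GloVe vectors per token
    #   pos_one_hots:    (max_text_len + 2, dim_pos_ohot)  part-of-speech one-hots
    #   caption:         str                               raw caption text
    #   sent_len:        int                               token count including sos/eos
    #   motion:          (max_motion_length, dim_pose)     Z-normalized, zero-padded
    #   m_length:        int                               motion length before padding
    #   tokens:          str                               '_'-joined tokens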

def get_dataset_motion_loader(opt_path, batch_size, device):
    opt = get_opt(opt_path, device)

    # The configurations of the T2M and KIT datasets are almost the same
    if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
        print('Loading dataset %s ...' % opt.dataset_name)

        mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
        std = np.load(pjoin(opt.meta_dir, 'std.npy'))

        w_vectorizer = WordVectorizer('./data/glove', 'our_vab')
        split_file = pjoin(opt.data_root, 'test.txt')
        dataset = Text2MotionDatasetV2(opt, mean, std, split_file, w_vectorizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True,
                                collate_fn=collate_fn, shuffle=True)
    else:
        raise KeyError('Dataset not recognized!!')
    print('Ground Truth Dataset Loading Completed!!!')
    return dataloader, dataset

class MMGeneratedDataset(Dataset):

    def __init__(self, opt, motion_dataset, w_vectorizer):
        self.opt = opt
        self.dataset = motion_dataset.mm_generated_motion
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        data = self.dataset[item]
        mm_motions = data['mm_motions']
        m_lens = []
        motions = []
        for mm_motion in mm_motions:
            m_lens.append(mm_motion['length'])
            motion = mm_motion['motion']
            if len(motion) < self.opt.max_motion_length:
                motion = np.concatenate([motion,
                                         np.zeros((self.opt.max_motion_length - len(motion), motion.shape[1]))
                                         ], axis=0)
            motion = motion[None, :]
            motions.append(motion)
        m_lens = np.array(m_lens, dtype=int)  # np.int is removed in recent NumPy versions
        motions = np.concatenate(motions, axis=0)
        # Sort the repeated generations by length (descending)
        sort_indx = np.argsort(m_lens)[::-1].copy()
        m_lens = m_lens[sort_indx]
        motions = motions[sort_indx]
        return motions, m_lens

def get_motion_loader(opt, batch_size, trainer, ground_truth_dataset, mm_num_samples, mm_num_repeats):
    # Currently the configurations of the two datasets are almost the same
    if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
        w_vectorizer = WordVectorizer('./data/glove', 'our_vab')
    else:
        raise KeyError('Dataset not recognized!!')

    print('Generating %s ...' % opt.name)

    dataset = EvaluationDataset(opt, trainer, ground_truth_dataset, w_vectorizer, mm_num_samples, mm_num_repeats)
    mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)

    motion_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, drop_last=True, num_workers=4)
    mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)

    print('Generated Dataset Loading Completed!!!')
    return motion_loader, mm_motion_loader
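
# motion_loader yields batches in the same 7-tuple format as the ground-truth loader above,
# while mm_motion_loader yields one item per selected caption, roughly of shape
# (1, mm_num_repeats, max_motion_length, dim_pose) together with the per-repeat lengths,
# intended for the multimodality metric.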

def build_models(opt):
    movement_enc = MovementConvEncoder(opt.dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
    text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
                                  pos_size=opt.dim_pos_ohot,
                                  hidden_size=opt.dim_text_hidden,
                                  output_size=opt.dim_coemb_hidden,
                                  device=opt.device)

    motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
                                      hidden_size=opt.dim_motion_hidden,
                                      output_size=opt.dim_coemb_hidden,
                                      device=opt.device)

    checkpoint = torch.load(pjoin('data/pretrained_models', opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
                            map_location=opt.device)
    movement_enc.load_state_dict(checkpoint['movement_encoder'])
    text_enc.load_state_dict(checkpoint['text_encoder'])
    motion_enc.load_state_dict(checkpoint['motion_encoder'])
    print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
    return text_enc, motion_enc, movement_enc

class EvaluatorModelWrapper(object):

    def __init__(self, opt):
        if opt.dataset_name == 't2m':
            opt.dim_pose = 263
        elif opt.dataset_name == 'kit':
            opt.dim_pose = 251
        else:
            raise KeyError('Dataset not recognized!!!')
        opt.dim_word = 300
        opt.max_motion_length = 196
        opt.dim_pos_ohot = len(POS_enumerator)
        opt.dim_motion_hidden = 1024
        opt.max_text_len = 20
        opt.dim_text_hidden = 512
        opt.dim_coemb_hidden = 512

        self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
        self.opt = opt
        self.device = opt.device

        self.text_encoder.to(opt.device)
        self.motion_encoder.to(opt.device)
        self.movement_encoder.to(opt.device)

        self.text_encoder.eval()
        self.motion_encoder.eval()
        self.movement_encoder.eval()

    # Please note that the outputs do not follow the order of the inputs
    def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
        with torch.no_grad():
            word_embs = word_embs.detach().to(self.device).float()
            pos_ohot = pos_ohot.detach().to(self.device).float()
            motions = motions.detach().to(self.device).float()

            # Sort motions by length (descending); the motion encoder expects sorted lengths
            align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
            motions = motions[align_idx]
            m_lens = m_lens[align_idx]

            '''Movement Encoding'''
            movements = self.movement_encoder(motions[..., :-4]).detach()
            m_lens = m_lens // self.opt.unit_length
            motion_embedding = self.motion_encoder(movements, m_lens)

            '''Text Encoding'''
            text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
            text_embedding = text_embedding[align_idx]
        return text_embedding, motion_embedding

    # Please note that the outputs do not follow the order of the inputs
    def get_motion_embeddings(self, motions, m_lens):
        with torch.no_grad():
            motions = motions.detach().to(self.device).float()

            align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
            motions = motions[align_idx]
            m_lens = m_lens[align_idx]

            '''Movement Encoding'''
            movements = self.movement_encoder(motions[..., :-4]).detach()
            m_lens = m_lens // self.opt.unit_length
            motion_embedding = self.motion_encoder(movements, m_lens)
        return motion_embedding
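

# A minimal usage sketch, not part of the evaluation script itself. Assumptions: the
# placeholder 'path/to/dataset_opt.txt' points at a valid options file, and the opt it
# yields carries dataset_name, device, unit_length and the encoder dimensions used by
# build_models; adjust the path and batch size to your checkpoint layout.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset_opt_path = 'path/to/dataset_opt.txt'  # assumed placeholder, not a real path

    # Ground-truth test loader and the pretrained evaluator networks
    gt_loader, gt_dataset = get_dataset_motion_loader(dataset_opt_path, batch_size=32, device=device)
    eval_wrapper = EvaluatorModelWrapper(get_opt(dataset_opt_path, device))

    # Embed one ground-truth batch; generated batches from get_motion_loader(...) would be
    # embedded the same way when computing FID / R-precision style metrics.
    for word_embs, pos_ohot, _, cap_lens, motions, m_lens, _ in gt_loader:
        text_emb, motion_emb = eval_wrapper.get_co_embeddings(word_embs, pos_ohot, cap_lens, motions, m_lens)
        print(text_emb.shape, motion_emb.shape)
        break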