import torch
import copy
import numpy as np
from typing import OrderedDict
from scipy.ndimage import gaussian_filter1d
from transformers import PreTrainedModel
from in2in.utils.configs import get_config
from in2in.models.in2in import in2IN
from in2in.utils.preprocess import MotionNormalizer
from .config import in2INConfig

# Per-frame feature width used to reshape the "individual"-mode output.
_INDIVIDUAL_FEATURE_DIM = 262
# The leading _NUM_JOINTS * 3 features of each frame are reshaped to (_NUM_JOINTS, 3).
_NUM_JOINTS = 22
# Value written into batch["motion_lens"] when the caller does not specify one.
_DEFAULT_WINDOW_SIZE = 210
# Standard deviation of the Gaussian applied along axis 0 of the joint array.
_SMOOTHING_SIGMA = 1


class in2INModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that generates motion from text with ``in2IN``.

    ``forward`` builds a batch from the prompts, runs ``in2IN.forward_test``,
    de-normalizes the output with ``MotionNormalizer.backward`` and returns
    Gaussian-smoothed joint arrays.
    """

    config_class = in2INConfig

    def __init__(self, config):
        super().__init__(config)
        self.mode = config.MODE
        self.model = in2IN(config, mode=config.MODE)
        self.normalizer = MotionNormalizer()

    def forward(self, prompt_interaction, prompt_individual1, prompt_individual2,
                window_size=_DEFAULT_WINDOW_SIZE):
        """Generate motion for the given text prompts.

        Args:
            prompt_interaction: interaction prompt, passed to the model as
                ``batch["text"]``.
            prompt_individual1: prompt passed as ``batch["text_individual1"]``;
                not used when ``config.MODE == "individual"``.
            prompt_individual2: prompt passed as ``batch["text_individual2"]``;
                not used when ``config.MODE == "individual"``.
            window_size: value written into ``batch["motion_lens"]`` (defaults
                to 210, the previously hard-coded value).

        Returns:
            In "individual" mode a single ndarray of shape (T, 22, 3); otherwise
            a list of two such arrays (presumably one per individual).
        """
        self.model.eval()
        batch = OrderedDict({})
        # Placeholder of shape (1, 1); generate_loop fills it with window_size.
        # Created on the model's own device instead of unconditionally on CUDA.
        batch["motion_lens"] = torch.zeros(1, 1, dtype=torch.long, device=self.device)
        batch["prompt_interaction"] = prompt_interaction
        if self.mode != "individual":
            batch["prompt_individual1"] = prompt_individual1
            batch["prompt_individual2"] = prompt_individual2
        return self.generate_loop(batch, window_size)

    def generate_loop(self, batch, window_size):
        """Run one generation pass and post-process the result.

        Args:
            batch: dict holding "motion_lens" and "prompt_interaction" (plus
                "prompt_individual1"/"prompt_individual2" outside "individual"
                mode). It is deep-copied, so the caller's dict is not mutated.
            window_size: value written into every entry of ``batch["motion_lens"]``.

        Returns:
            See ``forward``.
        """
        prompt_interaction = batch["prompt_interaction"]
        if self.mode != "individual":
            prompt_individual1 = batch["prompt_individual1"]
            prompt_individual2 = batch["prompt_individual2"]

        batch = copy.deepcopy(batch)
        batch["motion_lens"][:] = window_size
        batch["text"] = [prompt_interaction]
        if self.mode != "individual":
            batch["text_individual1"] = [prompt_individual1]
            batch["text_individual2"] = [prompt_individual2]

        # Inference only: the output is detached below, so skip building an
        # autograd graph.
        with torch.no_grad():
            batch = self.model.forward_test(batch)
        output = batch["output"][0]

        if self.mode == "individual":
            motion = output.reshape(-1, _INDIVIDUAL_FEATURE_DIM)
            motion = self.normalizer.backward(motion.cpu().detach().numpy())
            return self._smooth_joints(motion)

        # Split the per-frame features into a size-2 axis, then post-process
        # each slice along that axis independently.
        motion_both = output.reshape(output.shape[0], 2, -1)
        motion_both = self.normalizer.backward(motion_both.cpu().detach().numpy())
        return [self._smooth_joints(motion_both[:, j]) for j in range(2)]

    @staticmethod
    def _smooth_joints(motion):
        """Reshape the leading 22*3 features per row to (T, 22, 3) and smooth along axis 0."""
        joints3d = motion[:, :_NUM_JOINTS * 3].reshape(-1, _NUM_JOINTS, 3)
        return gaussian_filter1d(joints3d, _SMOOTHING_SIGMA, axis=0, mode="nearest")