Spaces:
Running
Running
File size: 1,320 Bytes
83f52e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import os
import orjson
import torch
import numpy as np
from model import TMR_textencoder
# Directory containing the per-split "<split>.keyids" files and
# "<split>_motion_embs_unit.npy" arrays read by the loaders in this module.
EMBS = "data/unit_motion_embs"
def load_json(path):
    """Read the file at *path* in binary mode and parse it with ``orjson``.

    Returns whatever object ``orjson.loads`` produces for the file contents.
    """
    with open(path, "rb") as handle:
        raw_bytes = handle.read()
    return orjson.loads(raw_bytes)
def load_keyids(split):
    """Return the keyids listed for *split* as a NumPy array of strings.

    Reads ``<EMBS>/<split>.keyids`` line by line and strips surrounding
    whitespace from each line; entry i of the result is line i of the file.
    """
    keyids_path = os.path.join(EMBS, f"{split}.keyids")
    with open(keyids_path) as handle:
        stripped_lines = [line.strip() for line in handle]
    return np.array(stripped_lines)
def load_keyids_splits(splits):
    """Map every split name in *splits* to the array returned by ``load_keyids``."""
    keyids_by_split = {}
    for split_name in splits:
        keyids_by_split[split_name] = load_keyids(split_name)
    return keyids_by_split
def load_unit_motion_embs(split, device):
    """Load ``<EMBS>/<split>_motion_embs_unit.npy`` as a torch tensor on *device*.

    The array is read with ``np.load``, wrapped with ``torch.from_numpy``
    (so dtype follows the stored array) and then moved with ``.to(device)``.
    NOTE(review): the "unit" suffix presumably means the rows are
    L2-normalized — not verified here.
    """
    npy_path = os.path.join(EMBS, f"{split}_motion_embs_unit.npy")
    embeddings = np.load(npy_path)
    return torch.from_numpy(embeddings).to(device)
def load_unit_motion_embs_splits(splits, device):
    """Map every split name in *splits* to its motion-embedding tensor on *device*."""
    embs_by_split = {}
    for split_name in splits:
        embs_by_split[split_name] = load_unit_motion_embs(split_name, device)
    return embs_by_split
def load_model(device, checkpoint_path="data/textencoder.pt"):
    """Build the TMR text encoder and load its saved weights.

    Args:
        device: passed to ``torch.load`` as ``map_location``, so the
            checkpoint tensors are materialized on this device.
        checkpoint_path: location of the saved state dict. Defaults to
            ``"data/textencoder.pt"``, the path this loader always used.

    Returns:
        The ``TMR_textencoder`` instance, switched to eval mode.

    NOTE(review): only the checkpoint tensors are mapped to *device*;
    the model itself is never ``.to(device)``-ed here, so its parameters
    stay wherever ``TMR_textencoder`` creates them. Confirm callers move
    the model if they need it on *device*.
    """
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Load values for the transformer only: strict=False lets keys that
    # exist in just one of (checkpoint, model) be skipped instead of raising.
    model.load_state_dict(state_dict, strict=False)
    return model.eval()
|