Spaces:
Running
Running
File size: 4,825 Bytes
007d500 83f52e6 007d500 83f52e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from typing import List
import torch.nn as nn
import os
import torch
import numpy as np
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from transformers import logging
from torch.nn.functional import normalize
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    A table of shape ``(max_len, 1, d_model)`` is precomputed once and stored
    as a non-persistent buffer (it follows the module across devices but is
    not written to the ``state_dict``). Even feature columns hold sines and
    odd feature columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the feature (last) dimension. Both even and odd
            values are supported.
        max_len: Number of positions to precompute. Inputs whose first
            dimension exceeds this cannot be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the angles to fit. For even d_model this slice is
        # a no-op and the table is identical to the untrimmed version.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Reshape to (max_len, 1, d_model) so the table broadcasts over the
        # second dimension of the input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.shape[0]`` positions.

        The first dimension of ``x`` is treated as the position axis and the
        last dimension must be ``d_model``.
        """
        return x + self.pe[:x.shape[0], :]
class TMR_textencoder(nn.Module):
    """Encode texts into a latent space and score them against embeddings.

    A pretrained text model produces per-token features, which are projected
    to ``latent_dim`` and passed, together with two learned tokens (``mu`` and
    ``logvar``), through a Transformer encoder. The encoder output at the
    ``mu`` token position is the text embedding.

    Args:
        modelpath: Name or path given to ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``.
        latent_dim: Dimension of the latent space (Transformer ``d_model``).
        ff_size: Feed-forward width of each Transformer encoder layer.
        num_layers: Number of Transformer encoder layers.
        num_heads: Number of attention heads per layer.
        activation: Activation name passed to ``nn.TransformerEncoderLayer``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, modelpath: str, latent_dim: int, ff_size: int,
                 num_layers: int, num_heads: int, activation: str,
                 **kwargs) -> None:
        super().__init__()
        # Only report errors from the transformers library.
        logging.set_verbosity_error()

        # Tokenizer (its parallelism is switched off through the environment)
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)

        # Pretrained text model and the width of its hidden states
        self.text_model = AutoModel.from_pretrained(modelpath)
        self.text_encoded_dim = self.text_model.config.hidden_size

        # Projection of the text-model outputs into the latent space
        self.projection = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.text_encoded_dim, latent_dim),
        )

        # Learned tokens prepended to every sequence
        self.mu_token = nn.Parameter(torch.randn(latent_dim))
        self.logvar_token = nn.Parameter(torch.randn(latent_dim))

        self.sequence_pos_encoding = PositionalEncoding(latent_dim)

        seq_trans_encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=0.0,
            activation=activation,
        )
        self.seqTransEncoder = nn.TransformerEncoder(
            seq_trans_encoder_layer,
            num_layers=num_layers,
        )

    def get_last_hidden_state(self, texts: List[str],
                              return_mask: bool = False):
        """Tokenize ``texts`` (with padding) and run the text model.

        Args:
            texts: Batch of input strings.
            return_mask: When true, also return the attention mask.

        Returns:
            The model's ``last_hidden_state``; when ``return_mask`` is true,
            a tuple of that tensor and the tokenizer's attention mask cast to
            boolean.
        """
        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        # Inputs must live on the same device as the text model.
        encoded_inputs = encoded_inputs.to(self.text_model.device)
        hidden = self.text_model(**encoded_inputs).last_hidden_state
        if return_mask:
            return hidden, encoded_inputs.attention_mask.to(dtype=torch.bool)
        return hidden

    def forward(self, texts: List[str]) -> Tensor:
        """Embed ``texts`` and return the encoder output at the ``mu`` token.

        Args:
            texts: Batch of input strings.

        Returns:
            Tensor with one ``latent_dim`` vector per input text.
        """
        text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)
        x = self.projection(text_encoded)
        bs = x.shape[0]

        # The PyTorch Transformer here expects [sequence, batch, features].
        x = x.permute(1, 0, 2)

        # Broadcast the learned tokens over the batch: each is (1, bs, latent).
        mu_token = self.mu_token.expand(bs, -1)[None]
        logvar_token = self.logvar_token.expand(bs, -1)[None]
        xseq = torch.cat((mu_token, logvar_token, x), 0)

        # Extend the mask so the two extra tokens are always attended to.
        token_mask = torch.ones((bs, 2), dtype=torch.bool, device=x.device)
        aug_mask = torch.cat((token_mask, mask), 1)

        xseq = self.sequence_pos_encoding(xseq)
        # src_key_padding_mask marks positions to IGNORE, hence the inversion.
        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)

        # Position 0 is the mu token; only mu is used for inference.
        return final[0]

    def compute_scores(self, texts, unit_embs=None, embs=None):
        """Score texts against a bank of embeddings for retrieval.

        Exactly one of ``unit_embs`` and ``embs`` must be given. The score is
        ``dot / 2 + 0.5`` between the L2-normalized text embedding and each
        bank row, i.e. cosine similarity rescaled to ``[0, 1]`` when the bank
        rows have unit norm.

        Args:
            texts: A single string or a list of strings.
            unit_embs: Bank of already-normalized embeddings (tensor or
                array-like), one row per item.
            embs: Bank of raw embeddings; rows are normalized here.

        Returns:
            NumPy array of shape ``(len(texts), n_items)``, or ``(n_items,)``
            when ``texts`` is a single string.

        Raises:
            ValueError: If neither or both of ``unit_embs`` and ``embs`` are
                provided.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if (unit_embs is None) == (embs is None):
            raise ValueError(
                "Provide exactly one of `unit_embs` or `embs`."
            )

        # A single string yields a 1-D result instead of a 1-row matrix.
        single_text = isinstance(texts, str)
        if single_text:
            texts = [texts]

        with torch.no_grad():
            text_latents = normalize(self(texts))

            # Align the bank with the text latents so the product works when
            # the model is on another device or the bank has another dtype.
            bank = torch.as_tensor(unit_embs if embs is None else embs)
            bank = bank.to(device=text_latents.device, dtype=text_latents.dtype)
            if embs is not None:
                bank = normalize(bank)

            # Map similarity from [-1, 1] to [0, 1].
            scores = (bank @ text_latents.T).T / 2 + 0.5

        scores = scores.cpu().numpy()
        return scores[0] if single_text else scores
|