from typing import List, Dict

import torch
from torch import nn

try:
    from transformers import AutoTokenizer, AutoModelForTokenClassification
except ImportError:
    from ditk import logging
    logging.warning("transformers is not installed, please install it with: pip install transformers")

from ding.utils import MODEL_REGISTRY


@MODEL_REGISTRY.register('language_transformer')
class LanguageTransformer(nn.Module):
    """
    Overview:
        The LanguageTransformer network. Downloads a pre-trained language model and adds a head on top of it.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            model_name: str = "bert-base-uncased",
            add_linear: bool = False,
            embedding_size: int = 128,
            freeze_encoder: bool = True
    ) -> None:
        """
        Overview:
            Initialize the LanguageTransformer model according to the input arguments.
        Arguments:
            - model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
            - add_linear (:obj:`bool`): Whether to add a linear layer on top of the language model, \
                defaults to ``False``.
            - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
            - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
                defaults to ``True``.
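        Examples:
            >>> # A minimal construction sketch; downloading "bert-base-uncased" requires network access.
            >>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
            >>> # The encoder is frozen by default, so only the added linear head receives gradients.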
""" | |
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)

        # Freeze the transformer encoder and only train the added linear layer.
        if freeze_encoder:
            for param in self.model.parameters():
                param.requires_grad = False

        if add_linear:
            # Add a small, adjustable linear layer on top of the language model, tuned through RL.
            self.embedding_size = embedding_size
            self.linear = nn.Linear(
                self.model.config.hidden_size, embedding_size
            )  # 768 for bert-base-uncased, distilbert-base-uncased
        else:
            self.linear = None

    def _calc_embedding(self, x: list) -> torch.Tensor:
        # ``truncation=True`` means that if the length of a prompt exceeds the ``max_length`` of the tokenizer,
        # the exceeding part will be truncated. ``padding=True`` means that shorter prompts in the batch will be
        # padded up to the length of the longest one. Together, these settings ensure that all encoded token
        # sequences in a batch share the same length, which enables batch-wise computing.
        input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
        output = self.model(**input, output_hidden_states=True)
        # Get the hidden states of the last layer.
        last_hidden_states = output.hidden_states[-1]
        # Take the [CLS] token's hidden state as the sentence embedding.
        sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size
        if self.linear is not None:
            sentence_embedding = self.linear(sentence_embedding)  # len(input_list) x embedding_size
        return sentence_embedding

    def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
        """
        Overview:
            LanguageTransformer forward computation graph: take two lists of strings as input and predict their \
            matching scores.
        Arguments:
            - train_samples (:obj:`List[str]`): One list of strings.
            - candidate_samples (:obj:`List[str]`): The other list of strings, against which the matching scores \
                are calculated.
        Returns:
            - output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
                corresponding ``torch.distributions.Categorical`` object.
        Examples:
            >>> test_pids = [1]
            >>> cand_pids = [0, 2, 4]
            >>> problems = [ \
                "This is problem 0", "This is the first question", "Second problem is here", "Another problem", \
                "This is the last problem" \
            ]
            >>> ctxt_list = [problems[pid] for pid in test_pids]
            >>> cands_list = [problems[pid] for pid in cand_pids]
            >>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
            >>> output = model(ctxt_list, cands_list)
            >>> assert output['logit'].shape == (1, 3)
        """
        prompt_embedding = self._calc_embedding(train_samples)
        cands_embedding = self._calc_embedding(candidate_samples)
        # Pairwise dot-product similarity: (num_train, emb_size) x (emb_size, num_cand) -> (num_train, num_cand).
        scores = torch.mm(prompt_embedding, cands_embedding.t())
        return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
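

# A minimal usage sketch, mirroring the docstring example above. It assumes network access to
# download "bert-base-uncased" and is illustrative only, not part of the module's public API.
if __name__ == "__main__":
    problems = [
        "This is problem 0", "This is the first question", "Second problem is here",
        "Another problem", "This is the last problem"
    ]
    ctxt_list = [problems[pid] for pid in [1]]
    cands_list = [problems[pid] for pid in [0, 2, 4]]
    model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
    output = model(ctxt_list, cands_list)
    assert output['logit'].shape == (1, 3)
    # The Categorical distribution can be sampled to select a candidate, e.g. inside an RL policy.
    action = output['dist'].sample()
    print(action)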