# SentenceMixer / t5mixer.py
# Author: Mya-Mya
# Commit: 5ea6a45 "Create T5Mixer" (2.14 kB)
from mixer import Mixer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
class T5Mixer(Mixer):
    """Mix two sentences by interpolating their T5 encoder hidden states.

    Each sentence is encoded separately, the two encoder outputs are blended
    position by position, and the blended states are greedily decoded back
    into text.
    """

    DEFAULT_MODEL_NAME = "llm-book/t5-base-long-livedoor-news-corpus"

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME) -> None:
        """Load the tokenizer and the seq2seq model.

        Args:
            model_name: Hugging Face Hub id (or local path) of the T5 model.
                Defaults to the model this mixer was written for.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        # Reverse vocabulary lookup (token id -> token string). Not used by
        # this class itself; presumably kept for inspecting decoded tokens.
        self.tokenid_to_tokentext = {
            token_id: token_text
            for token_text, token_id in self.tokenizer.get_vocab().items()
        }

    @torch.no_grad()
    def get_encoder_state(self, sentence: str) -> torch.Tensor:
        """Return the encoder's last hidden state for ``sentence``.

        Args:
            sentence: Text to encode.

        Returns:
            Tensor of shape (1, n_tokens, hidden_size), on the model's device.
        """
        inputs = self.tokenizer(sentence, return_tensors="pt").to(self.model.device)
        # Call the module rather than ``.forward`` so registered hooks run.
        encoder_output = self.model.encoder(**inputs)
        return encoder_output.last_hidden_state

    @torch.no_grad()
    def get_mixed_encode_state(
            self, sentence_A: str, sentence_B: str, A_ratio: float = 0.5
    ) -> torch.Tensor:
        """Blend the encoder states of two sentences.

        The result has the length of the longer sentence. At every position
        it equals ``A_ratio * es_A + (1 - A_ratio) * es_B``, where the shorter
        sentence's states are taken as zero past its last token.

        Args:
            sentence_A: First sentence.
            sentence_B: Second sentence.
            A_ratio: Weight of ``sentence_A``; ``sentence_B`` gets
                ``1 - A_ratio``.

        Returns:
            Tensor of shape (1, max(n_tokens_A, n_tokens_B), hidden_size).
        """
        es_A = self.get_encoder_state(sentence_A)
        es_B = self.get_encoder_state(sentence_B)
        n_tokens_A = es_A.size(1)
        n_tokens_B = es_B.size(1)
        batch_size, _, hidden_size = es_A.shape
        # Zero-padded blend: equivalent to scaling the longer sequence and
        # adding the weighted shorter one onto its leading positions.
        es = es_A.new_zeros((batch_size, max(n_tokens_A, n_tokens_B), hidden_size))
        es[:, :n_tokens_A, :] += es_A * A_ratio
        es[:, :n_tokens_B, :] += es_B * (1. - A_ratio)
        return es

    @torch.no_grad()
    def mix_sentences(
            self,
            sentence_A: str,
            sentence_B: str,
            A_ratio: float,
            max_n_tokens: int = 140,
            noise_rate: float = 0.0,
            skip_special_tokens: bool = True,
    ) -> str:
        """Greedily decode a sentence from the blended encoder states.

        Args:
            sentence_A: First sentence.
            sentence_B: Second sentence.
            A_ratio: Weight of ``sentence_A`` in the blend (see
                ``get_mixed_encode_state``).
            max_n_tokens: Maximum number of tokens to generate.
            noise_rate: Standard deviation of Gaussian noise added to the
                blended encoder states. Fresh noise is drawn at every decoding
                step; ``0`` disables it.
            skip_special_tokens: Drop special tokens (e.g. the leading pad and
                trailing EOS) from the returned text.

        Returns:
            The decoded sentence.
        """
        es = self.get_mixed_encode_state(sentence_A, sentence_B, A_ratio)
        # Decoding is seeded with the pad token (T5's decoder start token).
        to = torch.tensor([[self.tokenizer.pad_token_id]], device=es.device)
        for _ in range(max_n_tokens):
            conditioning = es
            if noise_rate:
                conditioning = es + torch.randn_like(es) * noise_rate
            # Use the model's own forward so the decoder output is projected
            # onto the vocabulary exactly as in normal generation (including
            # any rescaling applied when embeddings are tied).
            logits = self.model(
                encoder_outputs=(conditioning,),
                decoder_input_ids=to,
                use_cache=False,
            ).logits
            t_next = logits[0, -1].argmax()
            to = torch.cat((to, t_next.view(1, 1)), dim=-1)
            if int(t_next) == self.tokenizer.eos_token_id:
                break
        return self.tokenizer.batch_decode(
            to, skip_special_tokens=skip_special_tokens)[0]