import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from mixer import Mixer


class T5Mixer(Mixer):
    """Blend two sentences by interpolating their T5 encoder states.

    Each sentence is encoded separately, the two encoder hidden states are
    combined with a weighted sum, and the decoder then greedily generates text
    conditioned on the mixed state.
    """

    # Checkpoint loaded when no other model name is given.
    DEFAULT_MODEL_NAME = "llm-book/t5-base-long-livedoor-news-corpus"

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME) -> None:
        """Load the tokenizer and seq2seq model.

        Args:
            model_name: Hugging Face hub id or local path of a T5-style
                encoder-decoder checkpoint. Defaults to the checkpoint this
                class originally hard-coded.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        # Reverse vocabulary lookup: token id -> token string.
        self.tokenid_to_tokentext = {
            i: t for t, i in self.tokenizer.get_vocab().items()
        }

    def get_encoder_state(self, sentence: str) -> torch.Tensor:
        """Return the encoder's ``last_hidden_state`` for ``sentence``.

        The result is indexed as (batch, tokens, hidden) by the mixing code;
        for a single input sentence the batch dimension is presumably 1.
        """
        inputs = self.tokenizer(sentence, return_tensors="pt")
        # Keep the inputs on the model's device so a model moved with
        # ``self.model.to(...)`` still works.
        inputs = inputs.to(self.model.device)
        # Call the module (not ``.forward``) so any registered hooks run.
        encoder_output = self.model.encoder(**inputs)
        return encoder_output["last_hidden_state"]

    def get_mixed_encode_state(
            self, sentence_A: str, sentence_B: str, A_ratio: float = 0.5
    ) -> torch.Tensor:
        """Return a weighted blend of the encoder states of two sentences.

        Sentence A is weighted by ``A_ratio`` and sentence B by
        ``1 - A_ratio``. The result has the length of the longer sentence.
        Positions past the end of the shorter sentence hold only the longer
        sentence's state scaled by its own weight; they are not renormalized.

        Args:
            sentence_A: First sentence.
            sentence_B: Second sentence.
            A_ratio: Weight of sentence A in the blend.

        Returns:
            The mixed encoder state, shaped like the longer sentence's state.
        """
        es_A = self.get_encoder_state(sentence_A)
        es_B = self.get_encoder_state(sentence_B)
        w_A, w_B = A_ratio, 1. - A_ratio

        # Start from the longer sequence so every position exists, then add
        # the shorter one into the shared prefix.
        if es_A.size(1) >= es_B.size(1):
            long_es, w_long, short_es, w_short = es_A, w_A, es_B, w_B
        else:
            long_es, w_long, short_es, w_short = es_B, w_B, es_A, w_A

        # The multiplication yields a fresh tensor, so the in-place add below
        # never writes into the encoder output itself.
        es = long_es.detach() * w_long
        es[:, :short_es.size(1), :] += short_es * w_short
        return es

    @torch.no_grad()
    def mix_sentences(
            self,
            sentence_A: str,
            sentence_B: str,
            A_ratio: float,
            max_n_tokens: int = 140,
            noise_rate: float = 0.0,
    ) -> str:
        """Greedily decode text from the mixed encoder state of two sentences.

        Runs without autograd because generation never needs gradients.

        Args:
            sentence_A: First sentence.
            sentence_B: Second sentence.
            A_ratio: Weight of sentence A in the encoder-state blend.
            max_n_tokens: Maximum number of tokens to generate.
            noise_rate: Standard deviation of Gaussian noise added to the
                mixed encoder state, resampled at every decoding step.
                ``0.0`` (the default) gives deterministic greedy decoding.

        Returns:
            The decoded text. Special tokens are kept: the leading pad start
            token and, if one was generated, the EOS token.
        """
        es = self.get_mixed_encode_state(sentence_A, sentence_B, A_ratio)
        eos_id = self.tokenizer.eos_token_id
        # The decoder is seeded with the pad token as its start token.
        tokens = torch.tensor([[self.tokenizer.pad_token_id]], device=es.device)

        for _ in range(max_n_tokens):
            if noise_rate:
                cond = es + torch.randn_like(es) * noise_rate
            else:
                cond = es
            # The full prefix is re-decoded every step. No KV cache is kept,
            # because the conditioning changes per step when noise is enabled.
            decoder_output = self.model.decoder(
                input_ids=tokens,
                encoder_hidden_states=cond,
                use_cache=False,
            )
            logits = self.model.lm_head(decoder_output.last_hidden_state[0, -1, :])
            next_token = logits.argmax()
            tokens = torch.cat((tokens, next_token.view(1, 1)), dim=-1)
            if int(next_token) == eos_id:
                break

        return self.tokenizer.batch_decode(tokens)[0]