File size: 2,143 Bytes
5ea6a45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from mixer import Mixer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


class T5Mixer(Mixer):
    """Blend two sentences by interpolating their T5 encoder states.

    Each sentence is encoded with the T5 checkpoint named in ``MODEL_NAME``.
    The two encoder hidden-state sequences are mixed linearly, and the T5
    decoder then greedily generates text conditioned on the mixed states.
    """

    # Hugging Face Hub checkpoint used for both the tokenizer and the model.
    MODEL_NAME = "llm-book/t5-base-long-livedoor-news-corpus"

    def __init__(self) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.MODEL_NAME)
        # Reverse vocabulary lookup: token id -> token string.
        self.tokenid_to_tokentext = {
            token_id: token_text
            for token_text, token_id in self.tokenizer.get_vocab().items()
        }

    def get_encoder_state(self, sentence: str) -> torch.Tensor:
        """Return the encoder's last hidden state for a single sentence.

        Args:
            sentence: Text to encode.

        Returns:
            The encoder ``last_hidden_state``. Dim 0 is the batch (size 1,
            since one string is tokenized) and dim 1 indexes tokens.
        """
        # Put the token tensors on the same device as the model so this also
        # works after the caller moves the model off the CPU.
        inputs = self.tokenizer(sentence, return_tensors="pt").to(
            self.model.device)
        # Call the module (not .forward) so nn.Module hooks still run.
        encoder_output = self.model.encoder(**inputs)
        return encoder_output.last_hidden_state

    def get_mixed_encode_state(
            self, sentence_A: str, sentence_B: str, A_ratio: float = 0.5):
        """Linearly mix the encoder states of two sentences.

        The two sequences are aligned at token position 0. Positions present
        in both get ``A_ratio * es_A + (1 - A_ratio) * es_B``. The result has
        the length of the longer sequence. Positions past the end of the
        shorter sequence carry only the longer sentence's state, scaled by
        that sentence's own weight (there is no renormalisation).

        Args:
            sentence_A: First sentence.
            sentence_B: Second sentence.
            A_ratio: Weight applied to sentence A; sentence B gets
                ``1 - A_ratio``.

        Returns:
            The mixed encoder states. The longer sequence's contribution is
            detached from the autograd graph before mixing.
        """
        es_A = self.get_encoder_state(sentence_A)
        es_B = self.get_encoder_state(sentence_B)
        weighted_A = es_A * A_ratio
        weighted_B = es_B * (1. - A_ratio)
        # The longer sequence provides the output buffer; the shorter one is
        # added onto its leading positions.
        if weighted_A.size(1) >= weighted_B.size(1):
            longer, shorter = weighted_A, weighted_B
        else:
            longer, shorter = weighted_B, weighted_A
        mixed = longer.detach().clone()
        mixed[:, :shorter.size(1), :] += shorter
        return mixed

    def mix_sentences(self, sentence_A: str, sentence_B: str, A_ratio: float,
                      max_n_tokens: int = 140, noise_rate: float = 0.0):
        """Greedily decode a sentence from the mixed encoder states.

        Args:
            sentence_A: First sentence.
            sentence_B: Second sentence.
            A_ratio: Weight of sentence A in the mix (B gets ``1 - A_ratio``).
            max_n_tokens: Maximum number of tokens to generate.
            noise_rate: Standard deviation of the Gaussian noise added to the
                mixed encoder states. The noise is re-sampled at every decoding
                step. ``0.0`` (the default) disables noise.

        Returns:
            The decoded text. It is produced by ``batch_decode`` without
            skipping special tokens, so the leading start token (and the EOS
            token, if one was generated) appear in the string.
        """
        # Inference only: skip autograd bookkeeping for the encoder and for
        # every decoding step.
        with torch.no_grad():
            es = self.get_mixed_encode_state(sentence_A, sentence_B, A_ratio)
            # T5 uses the pad token as the decoder start token.
            decoded = torch.tensor(
                [[self.tokenizer.pad_token_id]], device=es.device)
            for _ in range(max_n_tokens):
                if noise_rate:
                    encoder_states = es + torch.randn_like(es) * noise_rate
                else:
                    encoder_states = es
                # No KV cache: the whole prefix is re-run on every step.
                decoder_output = self.model.decoder(
                    input_ids=decoded,
                    encoder_hidden_states=encoder_states,
                )
                last_hidden = decoder_output.last_hidden_state[0, -1, :]
                # NOTE(review): lm_head is applied directly to the decoder
                # output. Any rescaling the full seq2seq forward may apply
                # first is skipped here; a positive scale would not change
                # the argmax, but confirm for this checkpoint.
                next_token = self.model.lm_head(last_hidden).argmax()
                decoded = torch.cat((decoded, next_token[None, None]), dim=-1)
                if next_token.item() == self.tokenizer.eos_token_id:
                    break
        return self.tokenizer.batch_decode(decoded)[0]