Mya-Mya commited on
Commit
5ea6a45
·
1 Parent(s): 4f1e4fb

Create T5Mixer

Browse files
Files changed (2) hide show
  1. app.py +2 -1
  2. t5mixer.py +52 -0
app.py CHANGED
@@ -1,4 +1,5 @@
1
  from dummymixer import DummyMixer
 
2
  from frontend import launch
3
 
4
- launch(DummyMixer())
 
1
  from dummymixer import DummyMixer
2
+ from t5mixer import T5Mixer
3
  from frontend import launch
4
 
5
+ launch(T5Mixer())
t5mixer.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mixer import Mixer
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
+
5
+
6
+ class T5Mixer(Mixer):
7
+ def __init__(self) -> None:
8
+ super().__init__()
9
+ self.tokenizer = AutoTokenizer.from_pretrained(
10
+ "llm-book/t5-base-long-livedoor-news-corpus")
11
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
12
+ "llm-book/t5-base-long-livedoor-news-corpus")
13
+ self.tokenid_to_tokentext = {i: t for t,
14
+ i in self.tokenizer.get_vocab().items()}
15
+
16
+ def get_encoder_state(self, sentence: str):
17
+ inputs = self.tokenizer(sentence, return_tensors="pt")
18
+ eo = self.model.encoder.forward(**inputs)
19
+ es = eo["last_hidden_state"]
20
+ return es
21
+
22
+ def get_mixed_encode_state(
23
+ self, sentence_A: str, sentence_B: str, A_ratio: float = 0.5):
24
+ es_A = self.get_encoder_state(sentence_A)
25
+ es_B = self.get_encoder_state(sentence_B)
26
+ n_tokens_A = es_A.size(1)
27
+ n_tokens_B = es_B.size(1)
28
+ if n_tokens_A >= n_tokens_B:
29
+ es = es_A.clone().detach()*A_ratio
30
+ es[:, :n_tokens_B, :] += es_B*(1.-A_ratio)
31
+ else:
32
+ es = es_B.clone().detach()*(1.-A_ratio)
33
+ es[:, :n_tokens_A, :] += es_A*A_ratio
34
+ return es
35
+
36
+ def mix_sentences(self, sentence_A: str, sentence_B: str, A_ratio: float, max_n_tokens: int = 140):
37
+ es = self.get_mixed_encode_state(sentence_A, sentence_B, A_ratio)
38
+ to = torch.tensor([[self.tokenizer.pad_token_id]])
39
+ for i in range(max_n_tokens):
40
+ od = self.model.decoder.forward(
41
+ input_ids=to,
42
+ encoder_hidden_states=es+torch.randn_like(es)*noise_rate
43
+ )
44
+ sd = od.last_hidden_state
45
+ l = self.model.lm_head(sd[0, -1, :])
46
+ t_next = l.argmax()
47
+ ttext_next = self.tokenid_to_tokentext[int(t_next)]
48
+ to = torch.cat((to, t_next[None, None]), dim=-1)
49
+ if t_next == self.tokenizer.eos_token_id:
50
+ break
51
+ sentence = self.tokenizer.batch_decode(to)[0]
52
+ return sentence