Spaces:
Running
Running
Mya-Mya
committed on
Commit
·
5ea6a45
1
Parent(s):
4f1e4fb
Create T5Mixer
Browse files- app.py +2 -1
- t5mixer.py +52 -0
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from dummymixer import DummyMixer
|
|
|
2 |
from frontend import launch
|
3 |
|
4 |
-
launch(
|
|
|
1 |
from dummymixer import DummyMixer
|
2 |
+
from t5mixer import T5Mixer
|
3 |
from frontend import launch
|
4 |
|
5 |
+
# Entry point: hand a freshly constructed T5Mixer to the frontend's launch().
launch(T5Mixer())
|
t5mixer.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mixer import Mixer
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class T5Mixer(Mixer):
    """Mixer that blends two sentences in the encoder-output space of a
    pretrained seq2seq model and greedily decodes the blended state to text.
    """

    def __init__(self) -> None:
        """Load the tokenizer and seq2seq model and build an id -> token map."""
        super().__init__()
        # Single source of truth for the checkpoint so that the tokenizer and
        # the model can never drift apart.
        model_name = "llm-book/t5-base-long-livedoor-news-corpus"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        # get_vocab() maps token text -> id; invert it for id -> token text.
        self.tokenid_to_tokentext = {
            token_id: token_text
            for token_text, token_id in self.tokenizer.get_vocab().items()
        }

    def get_encoder_state(self, sentence: str):
        """Return the encoder's last hidden state for ``sentence``.

        The sentence is tokenized as a single-item batch, so the result has
        a leading batch dimension of size 1; dimension 1 indexes tokens
        (this is what ``get_mixed_encode_state`` relies on).
        """
        inputs = self.tokenizer(sentence, return_tensors="pt")
        # Call the module itself rather than .forward() so hooks are honoured.
        encoder_output = self.model.encoder(**inputs)
        return encoder_output["last_hidden_state"]

    def get_mixed_encode_state(
            self, sentence_A: str, sentence_B: str, A_ratio: float = 0.5):
        """Return a weighted blend of the encoder states of two sentences.

        Positions present in both sentences are mixed as
        ``A_ratio * state_A + (1 - A_ratio) * state_B``. The sentences may
        have different token counts: the result has the length of the longer
        one, and its trailing positions (which have no counterpart in the
        shorter sentence) hold only the longer sentence's state scaled by its
        own weight.

        Args:
            sentence_A: First sentence.
            sentence_B: Second sentence.
            A_ratio: Weight given to ``sentence_A``; ``sentence_B`` receives
                ``1 - A_ratio``.
        """
        weighted_A = self.get_encoder_state(sentence_A) * A_ratio
        weighted_B = self.get_encoder_state(sentence_B) * (1. - A_ratio)
        # On equal lengths A is treated as the "longer" one.
        if weighted_A.size(1) >= weighted_B.size(1):
            longer, shorter = weighted_A, weighted_B
        else:
            longer, shorter = weighted_B, weighted_A
        # Work on a detached copy so the in-place add below never touches
        # (or back-propagates into) the longer sentence's encoder output.
        mixed = longer.detach().clone()
        mixed[:, :shorter.size(1), :] += shorter
        return mixed

    def mix_sentences(self, sentence_A: str, sentence_B: str, A_ratio: float,
                      max_n_tokens: int = 140, noise_rate: float = 0.0):
        """Generate a sentence from the blended encoder state of two inputs.

        Decoding is greedy (argmax at every step) and stops at the
        tokenizer's EOS token or after ``max_n_tokens`` generated tokens,
        whichever comes first.

        Args:
            sentence_A: First sentence.
            sentence_B: Second sentence.
            A_ratio: Weight given to ``sentence_A`` in the blend.
            max_n_tokens: Upper bound on the number of generated tokens.
            noise_rate: Standard deviation of Gaussian noise added to the
                encoder state; fresh noise is drawn at every decoding step.
                The default ``0.0`` disables noise. (Previously this name was
                referenced without being defined, so every call raised
                ``NameError``.)

        Returns:
            The decoded text of the full generated id sequence, including
            the initial pad token used to start decoding and any special
            tokens, exactly as ``tokenizer.batch_decode`` renders them.
        """
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            encoder_state = self.get_mixed_encode_state(
                sentence_A, sentence_B, A_ratio)
            # The pad token id is used as the decoder start token.
            token_ids = torch.tensor(
                [[self.tokenizer.pad_token_id]], device=encoder_state.device)
            eos_token_id = self.tokenizer.eos_token_id
            for _ in range(max_n_tokens):
                if noise_rate:
                    step_state = (
                        encoder_state
                        + torch.randn_like(encoder_state) * noise_rate)
                else:
                    step_state = encoder_state
                decoder_output = self.model.decoder(
                    input_ids=token_ids,
                    encoder_hidden_states=step_state,
                )
                # Only the last position predicts the next token.
                last_hidden = decoder_output.last_hidden_state[0, -1, :]
                next_token = self.model.lm_head(last_hidden).argmax()
                token_ids = torch.cat(
                    (token_ids, next_token[None, None]), dim=-1)
                if int(next_token) == eos_token_id:
                    break
        return self.tokenizer.batch_decode(token_ids)[0]
|