Update app.py
app.py
CHANGED
@@ -4,44 +4,44 @@ import tokenizers
 import gradio as gr
 import re
 
-from PIL import Image
-
 
 def get_model_gpt(model_name,tokenizer_name):
     tokenizer = transformers.GPT2Tokenizer.from_pretrained(tokenizer_name)
     model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
-    model.eval()
     return model, tokenizer
 
 def get_model_t5(model_name,tokenizer_name):
     tokenizer = transformers.T5Tokenizer.from_pretrained(tokenizer_name)
     model = transformers.T5ForConditionalGeneration.from_pretrained(model_name)
-    model.eval()
     return model, tokenizer
 
 
 def predict_gpt(text, model, tokenizer, temperature=1.0):
     input_ids = tokenizer.encode(text+" \n Описание:", return_tensors="pt")
+
+    model.eval()
     with torch.no_grad():
         out = model.generate(input_ids,
                              do_sample=True,
-                             num_beams=
-                             temperature=
-                             top_p=0.
-                             max_length=
-                             length_penalty = 2.5,
+                             num_beams=3,
+                             temperature=temperature,
+                             top_p=0.75,
+                             max_length=1024,
                              eos_token_id = tokenizer.eos_token_id,
                              pad_token_id = tokenizer.pad_token_id,
+                             repetition_penalty = 2.5,
                              num_return_sequences = 1,
-
-
+                             output_attentions = True,
+                             return_dict_in_generate=True,
                              )
     decode = lambda x : tokenizer.decode(x, skip_special_tokens=True)
-    generated_text = list(map(decode, out['sequences']))
-    return generated_text
+    generated_text = list(map(decode, out['sequences']))
+    return generated_text[0].split('Описание :')[1]
 
 def predict_t5(text, model, tokenizer, temperature=1.2):
     input_ids = tokenizer.encode(text, return_tensors="pt")
+
+    model.eval()
     with torch.no_grad():
         out = model.generate(input_ids,
                              do_sample=True,