mipatov committed on
Commit
4f971e0
·
1 Parent(s): 952c01c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -4,44 +4,44 @@ import tokenizers
4
  import gradio as gr
5
  import re
6
 
7
- from PIL import Image
8
-
9
 
10
  def get_model_gpt(model_name,tokenizer_name):
11
  tokenizer = transformers.GPT2Tokenizer.from_pretrained(tokenizer_name)
12
  model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
13
- model.eval()
14
  return model, tokenizer
15
 
16
  def get_model_t5(model_name,tokenizer_name):
17
  tokenizer = transformers.T5Tokenizer.from_pretrained(tokenizer_name)
18
  model = transformers.T5ForConditionalGeneration.from_pretrained(model_name)
19
- model.eval()
20
  return model, tokenizer
21
 
22
 
23
  def predict_gpt(text, model, tokenizer, temperature=1.0):
24
  input_ids = tokenizer.encode(text+" \n Описание:", return_tensors="pt")
 
 
25
  with torch.no_grad():
26
  out = model.generate(input_ids,
27
  do_sample=True,
28
- num_beams=4,
29
- temperature= temperature,
30
- top_p=0.65,
31
- max_length=512,
32
- length_penalty = 2.5,
33
  eos_token_id = tokenizer.eos_token_id,
34
  pad_token_id = tokenizer.pad_token_id,
 
35
  num_return_sequences = 1,
36
- output_attentions = True,
37
- return_dict_in_generate=True,
38
  )
39
  decode = lambda x : tokenizer.decode(x, skip_special_tokens=True)
40
- generated_text = list(map(decode, out['sequences']))[0].replace(text,'')
41
- return generated_text
42
 
43
  def predict_t5(text, model, tokenizer, temperature=1.2):
44
  input_ids = tokenizer.encode(text, return_tensors="pt")
 
 
45
  with torch.no_grad():
46
  out = model.generate(input_ids,
47
  do_sample=True,
 
4
  import gradio as gr
5
  import re
6
 
 
 
7
 
8
  def get_model_gpt(model_name,tokenizer_name):
9
  tokenizer = transformers.GPT2Tokenizer.from_pretrained(tokenizer_name)
10
  model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
 
11
  return model, tokenizer
12
 
13
  def get_model_t5(model_name,tokenizer_name):
14
  tokenizer = transformers.T5Tokenizer.from_pretrained(tokenizer_name)
15
  model = transformers.T5ForConditionalGeneration.from_pretrained(model_name)
 
16
  return model, tokenizer
17
 
18
 
19
  def predict_gpt(text, model, tokenizer, temperature=1.0):
20
  input_ids = tokenizer.encode(text+" \n Описание:", return_tensors="pt")
21
+
22
+ model.eval()
23
  with torch.no_grad():
24
  out = model.generate(input_ids,
25
  do_sample=True,
26
+ num_beams=3,
27
+ temperature=temperature,
28
+ top_p=0.75,
29
+ max_length=1024,
 
30
  eos_token_id = tokenizer.eos_token_id,
31
  pad_token_id = tokenizer.pad_token_id,
32
+ repetition_penalty = 2.5,
33
  num_return_sequences = 1,
34
+ output_attentions = True,
35
+ return_dict_in_generate=True,
36
  )
37
  decode = lambda x : tokenizer.decode(x, skip_special_tokens=True)
38
+ generated_text = list(map(decode, out['sequences']))
39
+ return generated_text[0].split('Описание :')[1]
40
 
41
  def predict_t5(text, model, tokenizer, temperature=1.2):
42
  input_ids = tokenizer.encode(text, return_tensors="pt")
43
+
44
+ model.eval()
45
  with torch.no_grad():
46
  out = model.generate(input_ids,
47
  do_sample=True,