mipatov committed
Commit b17f4f4 · 1 parent: a0b0d25

Update app.py

Files changed (1): app.py (+4 -4)
app.py CHANGED

@@ -21,7 +21,7 @@ def get_model_t5(model_name,tokenizer_name):
 
 
 def predict_gpt(text, model, tokenizer, temperature=1.0):
-    input_ids = tokenizer.encode(text, return_tensors="pt")
+    input_ids = tokenizer.encode(text+" \n Описание:", return_tensors="pt")
     with torch.no_grad():
         out = model.generate(input_ids,
                              do_sample=True,
@@ -38,7 +38,7 @@ def predict_gpt(text, model, tokenizer, temperature=1.0):
                              )
     decode = lambda x : tokenizer.decode(x, skip_special_tokens=True)
     generated_text = list(map(decode, out['sequences']))[0].replace(text,'')
-    return generated_text
+    return "Описание : "+generated_text
 
 def predict_t5(text, model, tokenizer, temperature=1.2):
     input_ids = tokenizer.encode(text, return_tensors="pt")
@@ -57,7 +57,7 @@ def predict_t5(text, model, tokenizer, temperature=1.2):
                              )
     decode = lambda x : tokenizer.decode(x, skip_special_tokens=True)
     generated_text = list(map(decode, out['sequences']))[0]
-    return 'Описание '+generated_text
+    return 'Описание :'+generated_text
 
 def generate(model,temp,text):
     if model == 'GPT':
@@ -76,7 +76,7 @@ demo = gr.Interface(
     fn=generate,
     inputs=[
         gr.components.Dropdown(label="Модель", choices=('GPT', 'T5')),
-        gr.components.Slider(label="Температура",minimum = 1.0,maximum = 3.0,step = 0.1),
+        gr.components.Slider(label="Вариативность",minimum = 1.0,maximum = 3.0,step = 0.1),
         gr.components.Textbox(label="Характеристики",value = example),
     ],
     outputs=[
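
For orientation, a minimal sketch (not code from this commit) of how the GPT path reads once the two predict_gpt hunks are applied. It assumes a Hugging Face causal LM and tokenizer are loaded elsewhere in app.py and that the generate() arguments elided by the diff include return_dict_in_generate=True, which is what makes out['sequences'] available; the max_new_tokens value is a placeholder. "Описание" is Russian for "description".

# Sketch only: predict_gpt as it behaves after this commit, with assumed
# generate() settings standing in for the lines the diff elides.
import torch

def predict_gpt(text, model, tokenizer, temperature=1.0):
    # The commit appends an explicit "Описание:" cue so the model continues
    # the prompt with a description.
    input_ids = tokenizer.encode(text + " \n Описание:", return_tensors="pt")
    with torch.no_grad():
        out = model.generate(
            input_ids,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=128,           # placeholder for the app's real limit
            return_dict_in_generate=True,  # needed for out['sequences']
        )
    decode = lambda x: tokenizer.decode(x, skip_special_tokens=True)
    # Strip the original input text from the decoded sequence, then prefix the
    # result so the UI always shows a labelled description.
    generated_text = list(map(decode, out["sequences"]))[0].replace(text, "")
    return "Описание : " + generated_text

The remaining two hunks make the T5 path return the same "Описание :" prefix and rename the Gradio slider from "Температура" ("temperature") to "Вариативность" ("variability"); the slider's 1.0–3.0 range and 0.1 step are unchanged.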