mipatov commited on
Commit
40ec473
·
1 Parent(s): 3408ba9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -17,7 +17,8 @@ def get_model_t5(model_name,tokenizer_name):
17
 
18
 
19
  def predict_gpt(text, model, tokenizer, temperature=1.0):
20
- input_ids = tokenizer.encode(text+" \n Описание:", return_tensors="pt")
 
21
 
22
  model.eval()
23
  with torch.no_grad():
@@ -67,7 +68,7 @@ def generate(model,temp,text):
67
  return result
68
 
69
 
70
- gpt_model, gpt_tokenizer = get_model_gpt('gpt/', 'gpt/')
71
  t5_model, t5_tokenizer = get_model_t5('mipatov/rut5_nb_descr', 'mipatov/rut5_nb_descr')
72
 
73
 
 
17
 
18
 
19
  def predict_gpt(text, model, tokenizer, temperature=1.0):
20
+ text = text.replace('\n','')+'Описание:'
21
+ input_ids = tokenizer.encode(text, return_tensors="pt")
22
 
23
  model.eval()
24
  with torch.no_grad():
 
68
  return result
69
 
70
 
71
+ gpt_model, gpt_tokenizer = get_model_gpt('mipatov/rugpt3_nb_descr', 'mipatov/rugpt3_nb_descr')
72
  t5_model, t5_tokenizer = get_model_t5('mipatov/rut5_nb_descr', 'mipatov/rut5_nb_descr')
73
 
74