mipatov commited on
Commit
7a6a8cc
·
1 Parent(s): fb6ce8e
Files changed (1) hide show
  1. app.py +2 -19
app.py CHANGED
@@ -15,25 +15,8 @@ def get_model(model_name, model_path):
15
  return model, tokenizer
16
 
17
 
18
- def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, length_of_generated=300):
19
- # text += '\n'
20
- input_ids = tokenizer.encode(text, return_tensors="pt")
21
- length_of_prompt = len(input_ids[0])
22
- with torch.no_grad():
23
- out = model.generate(input_ids,
24
- do_sample=True,
25
- num_beams=n_beams,
26
- temperature=temperature,
27
- top_p=top_p,
28
- max_length=length_of_prompt + length_of_generated,
29
- eos_token_id=tokenizer.eos_token_id
30
- )
31
-
32
- generated = list(map(tokenizer.decode, out))[0]
33
- return generated.replace('\n[EOS]\n', '')
34
-
35
  def predict_gpt(text, model, tokenizer,):
36
- input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
37
  with torch.no_grad():
38
  out = model.generate(input_ids,
39
  do_sample=True,
@@ -53,7 +36,7 @@ def predict_gpt(text, model, tokenizer,):
53
  return generated_text
54
 
55
  def predict_t5(text, model, tokenizer,):
56
- input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
57
  with torch.no_grad():
58
  out = model.generate(input_ids,
59
  do_sample=True,
 
15
  return model, tokenizer
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def predict_gpt(text, model, tokenizer,):
19
+ input_ids = tokenizer.encode(text, return_tensors="pt")
20
  with torch.no_grad():
21
  out = model.generate(input_ids,
22
  do_sample=True,
 
36
  return generated_text
37
 
38
  def predict_t5(text, model, tokenizer,):
39
+ input_ids = tokenizer.encode(text, return_tensors="pt")
40
  with torch.no_grad():
41
  out = model.generate(input_ids,
42
  do_sample=True,