mipatov commited on
Commit
952c01c
·
1 Parent(s): c37a342

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -25,11 +25,11 @@ def predict_gpt(text, model, tokenizer, temperature=1.0):
25
  with torch.no_grad():
26
  out = model.generate(input_ids,
27
  do_sample=True,
28
- num_beams=2,
29
  temperature= temperature,
30
- top_p=0.75,
31
  max_length=512,
32
- length_penalty = 1.5,
33
  eos_token_id = tokenizer.eos_token_id,
34
  pad_token_id = tokenizer.pad_token_id,
35
  num_return_sequences = 1,
@@ -45,7 +45,7 @@ def predict_t5(text, model, tokenizer, temperature=1.2):
45
  with torch.no_grad():
46
  out = model.generate(input_ids,
47
  do_sample=True,
48
- num_beams=2,
49
  temperature=temperature,
50
  top_p=0.35,
51
  max_length=512,
 
25
  with torch.no_grad():
26
  out = model.generate(input_ids,
27
  do_sample=True,
28
+ num_beams=4,
29
  temperature= temperature,
30
+ top_p=0.65,
31
  max_length=512,
32
+ length_penalty = 2.5,
33
  eos_token_id = tokenizer.eos_token_id,
34
  pad_token_id = tokenizer.pad_token_id,
35
  num_return_sequences = 1,
 
45
  with torch.no_grad():
46
  out = model.generate(input_ids,
47
  do_sample=True,
48
+ num_beams=4,
49
  temperature=temperature,
50
  top_p=0.35,
51
  max_length=512,