da03 commited on
Commit
0d4c25e
·
1 Parent(s): 73a674e
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -39,15 +39,18 @@ def predict_product(num1, num2):
39
 
40
  generated_ids = inputs['input_ids']
41
  past_key_values = None
42
- for _ in range(MAX_PRODUCT_DIGITS): # Set a maximum limit to prevent infinite loops
43
- outputs = model.generate(
44
- input_ids=generated_ids,
45
- max_new_tokens=1,
46
- do_sample=False,
47
- past_key_values=past_key_values,
48
- return_dict_in_generate=True,
49
- use_cache=True
50
- )
 
 
 
51
  generated_ids = outputs.sequences
52
  next_token_id = generated_ids[0, -1]
53
  print (next_token_id)
 
39
 
40
  generated_ids = inputs['input_ids']
41
  past_key_values = None
42
+ for step in range(MAX_PRODUCT_DIGITS): # Set a maximum limit to prevent infinite loops
43
+ generation_kwargs = {
44
+ 'input_ids': generated_ids,
45
+ 'max_new_tokens': 1,
46
+ 'do_sample': False,
47
+ 'past_key_values': past_key_values,
48
+ 'return_dict_in_generate': True,
49
+ 'use_cache': True
50
+ }
51
+ if step == 0:
52
+ del generation_kwargs['past_key_values']
53
+ outputs = model.generate(**generation_kwargs)
54
  generated_ids = outputs.sequences
55
  next_token_id = generated_ids[0, -1]
56
  print (next_token_id)