da03 commited on
Commit
ae9daf1
·
1 Parent(s): bf65d9e
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -49,8 +49,10 @@ def predict_product(num1, num2):
49
 
50
  next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
51
  generated_ids = torch.cat((generated_ids, next_token_id.view(1,-1)), dim=-1)
 
52
 
53
  if next_token_id.item() == eos_token_id:
 
54
  break
55
  past_key_values = outputs.past_key_values
56
 
@@ -113,8 +115,8 @@ demo = gr.Interface(
113
  gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
114
  ],
115
  outputs=[
116
- gr.Textbox(label='Ground Truth Product'),
117
- gr.HighlightedText(label='Predicted Product', combine_adjacent=False, show_legend=False, color_map={"-": "green", "+": "red"}),
118
  gr.HTML(label='Result Message')
119
  ],
120
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
 
49
 
50
  next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
51
  generated_ids = torch.cat((generated_ids, next_token_id.view(1,-1)), dim=-1)
52
+ print (next_token_id)
53
 
54
  if next_token_id.item() == eos_token_id:
55
+ print ('berak')
56
  break
57
  past_key_values = outputs.past_key_values
58
 
 
115
  gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
116
  ],
117
  outputs=[
118
+ gr.HighlightedText(label='Ground Truth Product', combine_adjacent=False, show_legend=False, color_map={"-": "green", "+": "red"}),
119
+ gr.HighlightedText(label='GPT2 Predicted Product', combine_adjacent=False, show_legend=False, color_map={"-": "green", "+": "red"}),
120
  gr.HTML(label='Result Message')
121
  ],
122
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',