da03 commited on
Commit
685026a
·
1 Parent(s): 6cc23f5
Files changed (1) hide show
  1. app.py +17 -34
app.py CHANGED
@@ -23,48 +23,31 @@ def predict_product(num1, num2):
23
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
24
 
25
  generated_ids = inputs['input_ids']
26
- prediction = ""
27
- correct_product = ""
28
- valid_input = True
29
 
30
  try:
31
  num1_int = int(num1)
32
  num2_int = int(num2)
33
  correct_product = str(num1_int * num2_int)
34
  except ValueError:
35
- valid_input = False
36
 
37
- for _ in range(40): # Adjust the range to control the maximum number of generated tokens
38
- outputs = model.generate(generated_ids, max_new_tokens=1, do_sample=False)
39
- generated_ids = torch.cat((generated_ids, outputs[:, -1:]), dim=-1)
40
- output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
41
- prediction = postprocess(output_text)
42
-
43
- # Manually create the diff for HighlightedText
44
- diff = []
45
- for i in range(len(prediction)):
46
- if i < len(correct_product) and prediction[i] == correct_product[i]:
47
- diff.append((prediction[i], None)) # No highlight for correct digits
48
- else:
49
- diff.append((prediction[i], "+")) # Highlight incorrect digits in red
50
 
51
- yield diff, ""
52
 
53
- if valid_input:
54
- is_correct = prediction == correct_product
55
- result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
56
- else:
57
- result_message = "Invalid input. Could not evaluate correctness."
58
-
59
- # Final diff for the complete prediction
60
- final_diff = []
61
- for i in range(len(prediction)):
62
- if i < len(correct_product) and prediction[i] == correct_product[i]:
63
- final_diff.append((prediction[i], None)) # No highlight for correct digits
64
- else:
65
- final_diff.append((prediction[i], "+")) # Highlight incorrect digits in red
66
-
67
- yield final_diff, result_message
68
 
69
  demo = gr.Interface(
70
  fn=predict_product,
@@ -73,7 +56,7 @@ demo = gr.Interface(
73
  gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
74
  ],
75
  outputs=[
76
- gr.HighlightedText(label='Predicted Product with Matching Digits Highlighted', combine_adjacent=True, show_legend=True, color_map={"+": "red"}),
77
  gr.HTML(label='Result Message')
78
  ],
79
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
 
23
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
24
 
25
  generated_ids = inputs['input_ids']
26
+ outputs = model.generate(generated_ids, max_new_tokens=40, do_sample=False)
27
+ full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
+ prediction = postprocess(full_output[len(input_text):])
29
 
30
  try:
31
  num1_int = int(num1)
32
  num2_int = int(num2)
33
  correct_product = str(num1_int * num2_int)
34
  except ValueError:
35
+ return [], "Invalid input. Could not evaluate correctness."
36
 
37
+ # Create the diff for HighlightedText
38
+ diff = []
39
+ max_len = max(len(prediction), len(correct_product))
40
+ for i in range(max_len):
41
+ if i < len(prediction) and i < len(correct_product) and prediction[i] == correct_product[i]:
42
+ diff.append((prediction[i], None)) # No highlight for correct digits
43
+ elif i < len(prediction) and (i >= len(correct_product) or prediction[i] != correct_product[i]):
44
+ diff.append((prediction[i], "+")) # Highlight incorrect digits in red
45
+ if i < len(correct_product) and (i >= len(prediction) or prediction[i] != correct_product[i]):
46
+ diff.append((correct_product[i], "-")) # Highlight missing/incorrect digits in green
 
 
 
47
 
48
+ result_message = "Correct!" if prediction == correct_product else f"Incorrect! The correct product is {correct_product}."
49
 
50
+ return diff, result_message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  demo = gr.Interface(
53
  fn=predict_product,
 
56
  gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
57
  ],
58
  outputs=[
59
+ gr.HighlightedText(label='Predicted Product with Matching and Unmatching Digits Highlighted', combine_adjacent=True, show_legend=True, color_map={"-": "green", "+": "red"}),
60
  gr.HTML(label='Result Message')
61
  ],
62
  title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',