da03 commited on
Commit
fe5b36f
·
1 Parent(s): 3738ecc
Files changed (1) hide show
  1. app.py +11 -0
app.py CHANGED
@@ -90,6 +90,17 @@ def predict_product(num1, num2):
90
 
91
  predicted_annotations = []
92
  is_correct_sofar = True
 
 
 
 
 
 
 
 
 
 
 
93
  for i in range(len(predicted_digits_reversed)):
94
  predicted_digit = predicted_digits_reversed[i]
95
  if i >= len(ground_truth_digits_reversed):
 
90
 
91
  predicted_annotations = []
92
  is_correct_sofar = True
93
+ if model_name == 'implicit':
94
+ num_equal_signs = sum([1 for token in predicted_digits_reversed if token == '='])
95
+ if num_equal_signs < 2:
96
+ predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed]
97
+ predicted_digits_reversed = []
98
+ else:
99
+ first_equal_sign_position = predicted_digits_reversed.index('=')
100
+ second_equal_sign_position = predicted_digits_reversed[first_equal_sign_position+1:].index('=') + first_equal_sign_position+1
101
+ predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed[:second_equal_sign_position+1]]
102
+ predicted_digits_reversed = predicted_digits_reversed[second_equal_sign_position+1:]
103
+
104
  for i in range(len(predicted_digits_reversed)):
105
  predicted_digit = predicted_digits_reversed[i]
106
  if i >= len(ground_truth_digits_reversed):