jgauthier committed on
Commit
7eae31f
·
1 Parent(s): 8a3618a

Don't mix up condition indexing in suites whose items list their conditions in different orders (e.g. number_prep in syntaxgym2020): look up each condition's regions by condition name rather than by position.

Browse files
Files changed (1) hide show
  1. syntaxgym.py +5 -4
syntaxgym.py CHANGED
@@ -265,12 +265,13 @@ class SyntaxGym(evaluate.EvaluationModule):
265
  }
266
  return results
267
 
268
- def get_region_edges(self, item, condition_idx):
269
  """
270
  Get left edge of each region as a character index.
271
  """
272
  # NB this is coupled with `condition_to_string` logic of course
273
 
 
274
  regions = item["conditions"]["regions"][condition_idx]
275
 
276
  idx = 0
@@ -298,8 +299,8 @@ class SyntaxGym(evaluate.EvaluationModule):
298
 
299
  max_long = torch.iinfo(torch.int64).max
300
 
301
- for i_cond, i_offsets in enumerate(offset_mapping):
302
- region_edges = self.get_region_edges(item, i_cond)
303
 
304
  t_cursor, r_cursor = 0, 0
305
  while t_cursor < i_offsets.shape[0]:
@@ -321,7 +322,7 @@ class SyntaxGym(evaluate.EvaluationModule):
321
  r_cursor += 1
322
  continue
323
 
324
- region2tokens[condition_order[i_cond]][r_cursor + 1].append(t_cursor)
325
  t_cursor += 1
326
 
327
  return region2tokens
 
265
  }
266
  return results
267
 
268
+ def get_region_edges(self, item, condition_name):
269
  """
270
  Get left edge of each region as a character index.
271
  """
272
  # NB this is coupled with `condition_to_string` logic of course
273
 
274
+ condition_idx = item["conditions"]["condition_name"].index(condition_name)
275
  regions = item["conditions"]["regions"][condition_idx]
276
 
277
  idx = 0
 
299
 
300
  max_long = torch.iinfo(torch.int64).max
301
 
302
+ for cond_name, i_offsets in zip(condition_order, offset_mapping):
303
+ region_edges = self.get_region_edges(item, cond_name)
304
 
305
  t_cursor, r_cursor = 0, 0
306
  while t_cursor < i_offsets.shape[0]:
 
322
  r_cursor += 1
323
  continue
324
 
325
+ region2tokens[cond_name][r_cursor + 1].append(t_cursor)
326
  t_cursor += 1
327
 
328
  return region2tokens