LukasHug committed on
Commit
88c2435
·
1 Parent(s): 999258b

update reward, prevent reward hacking

Browse files
VerifiableRewardsForScalableLogicalReasoning.py CHANGED
@@ -109,7 +109,7 @@ def validate_rule_no_hardcoded_cars(prediction):
109
  matches = re.findall(hardcoded_pattern, prediction)
110
 
111
  if matches:
112
- return False, f"Rule contains ground cars: {matches[0]}"
113
 
114
  return True, "Rule is valid"
115
 
@@ -131,12 +131,8 @@ def _evaluate_with_prolog(prediction, validation_program, eval_config, timeout=5
131
  positive_pred = eval_config.get("positive_predicate", "eastbound")
132
  negative_pred = eval_config.get("negative_predicate", "westbound")
133
 
134
-
135
- validation_program = anonymize_entities(validation_program)
136
-
137
-
138
  # extract predicate from rule_to_evaluate
139
- rule_to_evaluate = extract_ilp_from_text_v2(prediction)
140
  if positive_pred not in rule_to_evaluate:
141
  logger.warning(f"Rule '{rule_to_evaluate}' does not contain positive predicate '{positive_pred}'")
142
  return {
@@ -245,15 +241,16 @@ def extract_ilp_from_text(text):
245
  return p_code
246
 
247
 
248
- def extract_ilp_from_text_v2(text, target_predicates=None):
 
249
  # Pre-process: collapse code blocks to single lines
250
  text = re.sub(r'\n\s*', ' ', text) # crude: flatten all to one line
251
- # Optionally restrict to only some predicates
252
- preds = '|'.join([re.escape(p) for p in (target_predicates or [])])
253
- head_pat = rf"(?:{preds})" if preds else r"[a-zA-Z_][a-zA-Z0-9_]*"
254
  # Rule pattern, across newlines
255
- rule_pattern = re.compile(rf'({head_pat}\([^()]*\)\s*:-.*?\.)')
256
- rules = set(rule_pattern.findall(text))
 
 
 
257
  # Remove rules that are also captured as facts
258
  p_code = ''
259
  for rule in rules:
@@ -262,7 +259,6 @@ def extract_ilp_from_text_v2(text, target_predicates=None):
262
  if not statement.endswith('.'):
263
  statement += '.'
264
  p_code += statement + '\n'
265
- print(p_code)
266
  return p_code.strip() # Ensure no trailing whitespace
267
 
268
 
 
109
  matches = re.findall(hardcoded_pattern, prediction)
110
 
111
  if matches:
112
+ return False, f"Cars must be variables: {matches[0]}"
113
 
114
  return True, "Rule is valid"
115
 
 
131
  positive_pred = eval_config.get("positive_predicate", "eastbound")
132
  negative_pred = eval_config.get("negative_predicate", "westbound")
133
 
 
 
 
 
134
  # extract predicate from rule_to_evaluate
135
+ rule_to_evaluate = extract_ilp_from_text_v2(prediction, positive_pred)
136
  if positive_pred not in rule_to_evaluate:
137
  logger.warning(f"Rule '{rule_to_evaluate}' does not contain positive predicate '{positive_pred}'")
138
  return {
 
241
  return p_code
242
 
243
 
244
+ def extract_ilp_from_text_v2(text, target_predicate=None):
245
+ text = re.sub(r'%.*?(?=\n|$)', '', text) # remove comments
246
  # Pre-process: collapse code blocks to single lines
247
  text = re.sub(r'\n\s*', ' ', text) # crude: flatten all to one line
 
 
 
248
  # Rule pattern, across newlines
249
+ rule_pattern = re.compile(rf'({target_predicate}\([^()]*\)\s*:-.*?\.)')
250
+ rules = list(rule_pattern.findall(text))
251
+ if len(rules) > 1:
252
+ logger.warning(f"Found multiple rules in text: {rules}. Using only the first one.")
253
+ rules = rules[:1] # Use only the first match
254
  # Remove rules that are also captured as facts
255
  p_code = ''
256
  for rule in rules:
 
259
  if not statement.endswith('.'):
260
  statement += '.'
261
  p_code += statement + '\n'
 
262
  return p_code.strip() # Ensure no trailing whitespace
263
 
264