Josh98 commited on
Commit
b9a7a2e
1 Parent(s): 5a101a3

handle cases where there are mulitple correct refs and use the best score

Browse files
Files changed (1) hide show
  1. nl2bash_m.py +25 -23
nl2bash_m.py CHANGED
@@ -109,37 +109,39 @@ class nl2bash_m(evaluate.Metric):
109
 
110
 
111
  final_score = 0
 
112
 
113
- for pred, ref in zip(predictions, references):
114
- if len(pred) == 0 and len(ref[0]) == 0:
115
  score = 1
116
- elif len(pred) == 0 or len(ref[0]) == 0:
117
  score = 0
118
- else:
119
- pred_words, ref_words = pred.split(), ref[0].split()
 
 
120
 
121
-
122
- # Get the cmd of predicted and ref
123
- cmd_corr = 1 if pred_words.pop(0)==ref_words.pop(0) else 0
124
 
125
- # Get the option of predicted and ref
126
- pred_option = [ x for x in pred_words if x[0] == '-']
127
- ref_option = [ x for x in ref_words if x[0] == '-']
128
-
129
- # Get the arguments of predicted and ref
130
- pred_args = [ x for x in pred_words if x[0] != '-']
131
- ref_args = [ x for x in ref_words if x[0] != '-']
132
 
133
- # Calculate scores
134
- cmd_score = cmd_weight * cmd_corr
135
- opt_score = opt_weight * self.get_score(pred_option, ref_option)
136
- arg_score = arg_weight * self.get_score(pred_args, ref_args)
137
 
138
- score = cmd_score + opt_score + arg_score
139
- final_score += score
 
 
140
 
141
  final_score = final_score/len(predictions)
142
- print("f_s: ", final_score)
143
-
144
 
145
  return {"nl2bash_m": (final_score)}
 
109
 
110
 
111
  final_score = 0
112
+ for pred, refs in zip(predictions, references):
113
 
114
+ if len(pred) == 0 and min([len(ref) for ref in refs]) == 0:
 
115
  score = 1
116
+ elif len(pred) == 0 or min([len(ref) for ref in refs]) == 0:
117
  score = 0
118
+ else:
119
+ best_score = 0
120
+ for ref in refs:
121
+ pred_words, ref_words = pred.split(), ref.split()
122
 
123
+
124
+ # Get the cmd of predicted and ref
125
+ cmd_corr = 1 if pred_words.pop(0)==ref_words.pop(0) else 0
126
 
127
+ # Get the option of predicted and ref
128
+ pred_option = [ x for x in pred_words if x[0] == '-']
129
+ ref_option = [ x for x in ref_words if x[0] == '-']
130
+
131
+ # Get the arguments of predicted and ref
132
+ pred_args = [ x for x in pred_words if x[0] != '-']
133
+ ref_args = [ x for x in ref_words if x[0] != '-']
134
 
135
+ # Calculate scores
136
+ cmd_score = cmd_weight * cmd_corr
137
+ opt_score = opt_weight * get_score(pred_option, ref_option)
138
+ arg_score = arg_weight * get_score(pred_args, ref_args)
139
 
140
+ score = cmd_score + opt_score + arg_score
141
+ best_score = max(best_score, score)
142
+
143
+ final_score += best_score
144
 
145
  final_score = final_score/len(predictions)
 
 
146
 
147
  return {"nl2bash_m": (final_score)}