Josh98 committed on
Commit
40d109e
·
1 Parent(s): b16de67

version 1 metric

Browse files
Files changed (1) hide show
  1. nl2bash_m.py +45 -17
nl2bash_m.py CHANGED
@@ -94,38 +94,66 @@ class nl2bash_m(evaluate.Metric):
94
  reference_urls=[],
95
  )
96
 
 
 
 
 
 
 
 
 
 
97
  def _compute(
98
  self,
99
  predictions,
100
- references,
101
- regexes_to_ignore=None,
 
 
102
  ignore_case=False,
103
- ignore_punctuation=False,
104
  ignore_numbers=False,
105
  ):
106
 
107
- if regexes_to_ignore is not None:
108
- for s in regexes_to_ignore:
109
- predictions = np.array([re.sub(s, "", x) for x in predictions])
110
- references = np.array([re.sub(s, "", x) for x in references])
111
- else:
112
- predictions = np.asarray(predictions)
113
- references = np.asarray(references)
114
 
115
  if ignore_case:
116
  predictions = np.char.lower(predictions)
117
  references = np.char.lower(references)
118
 
119
- if ignore_punctuation:
120
- repl_table = string.punctuation.maketrans("", "", string.punctuation)
121
- predictions = np.char.translate(predictions, table=repl_table)
122
- references = np.char.translate(references, table=repl_table)
123
-
124
  if ignore_numbers:
125
  repl_table = string.digits.maketrans("", "", string.digits)
126
  predictions = np.char.translate(predictions, table=repl_table)
127
  references = np.char.translate(references, table=repl_table)
128
 
129
- score_list = predictions == references
130
 
131
- return {"exact_match": np.mean(score_list)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  reference_urls=[],
95
  )
96
 
97
def get_score(self, pred, ref):
    """Return the positional overlap ratio between two token lists.

    Tokens are compared index by index; the number of equal positions is
    divided by the length of the longer list, so missing or surplus
    tokens lower the score.

    Args:
        pred: List of predicted tokens.
        ref: List of reference tokens.

    Returns:
        ``1`` when both lists are empty, otherwise a float in ``[0, 1]``.
    """
    # Two empty lists agree perfectly; this also avoids dividing by zero.
    if not pred and not ref:
        return 1

    # zip stops at the shorter list, matching the original index loop.
    matches = sum(1 for p, r in zip(pred, ref) if p == r)

    return matches / max(len(pred), len(ref))
105
+
106
  def _compute(
107
  self,
108
  predictions,
109
+ references,
110
+ cmd_weight = 0.65,
111
+ opt_weight = 0.25,
112
+ arg_weight = 0.15,
113
  ignore_case=False,
 
114
  ignore_numbers=False,
115
  ):
116
 
117
+ predictions = np.asarray(predictions)
118
+ references = np.asarray(references)
 
 
 
 
 
119
 
120
  if ignore_case:
121
  predictions = np.char.lower(predictions)
122
  references = np.char.lower(references)
123
 
 
 
 
 
 
124
  if ignore_numbers:
125
  repl_table = string.digits.maketrans("", "", string.digits)
126
  predictions = np.char.translate(predictions, table=repl_table)
127
  references = np.char.translate(references, table=repl_table)
128
 
 
129
 
130
+ final_score = 0
131
+
132
+ for pred, ref in zip(predictions, references):
133
+ print(pred, ref)
134
+ pred_words, ref_words = pred[0].split(), ref[0].split()
135
+ # Get the cmd of predicted and ref
136
+ cmd_corr = 1 if pred_words.pop(0)==ref_words.pop(0) else 0
137
+
138
+ # Get the option of predicted and ref
139
+ pred_option = [ x for x in pred_words if x[0] == '-']
140
+ ref_option = [ x for x in ref_words if x[0] == '-']
141
+
142
+ # Get the arguments of predicted and ref
143
+ pred_args = [ x for x in pred_words if x[0] != '-']
144
+ ref_args = [ x for x in ref_words if x[0] != '-']
145
+
146
+ # Calculate scores
147
+ cmd_score = cmd_weight * cmd_corr
148
+ opt_score = opt_weight * self.get_score(pred_option, ref_option)
149
+ arg_score = arg_weight * self.get_score(pred_args, ref_args)
150
+
151
+ score = cmd_score + opt_score + arg_score
152
+ final_score += score
153
+ print(score)
154
+
155
+ final_score = final_score/len(self.preds)
156
+ print("f_s: ", final_score)
157
+
158
+
159
+ return {"nl2bash_m": (final_score)}