yzha commited on
Commit
2c0bdcf
1 Parent(s): 9af909b
Files changed (1) hide show
  1. ctc_eval.py +19 -5
ctc_eval.py CHANGED
@@ -13,6 +13,7 @@
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
 
16
  import evaluate
17
  import datasets
18
 
@@ -89,11 +90,24 @@ class CTC_Eval(evaluate.EvaluationModule):
89
  def install(package):
90
  subprocess.check_call([sys.executable, "-m", "pip", "install", package])
91
 
92
- install('ctc-score')
93
- from ctc_score import StyleTransferScorer, SummarizationScorer, DialogScorer
94
- self.scorer = SummarizationScorer(align='D-cnndm')
 
 
 
 
 
 
95
 
96
- self.compute(references=['hello world'], predictions=['hi world'])
 
 
 
 
 
 
 
97
 
98
  def _compute(self, predictions, references):
99
  """Returns the scores"""
@@ -104,5 +118,5 @@ class CTC_Eval(evaluate.EvaluationModule):
104
  print(references)
105
  ctc_score = self.scorer.score(doc=references[0], refs=[], hypo=predictions[0], aspect='consistency')
106
  return {
107
- "ctc_score": [ctc_score]
108
  }
 
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
16
+ from typing import final
17
  import evaluate
18
  import datasets
19
 
 
90
  def install(package):
91
  subprocess.check_call([sys.executable, "-m", "pip", "install", package])
92
 
93
+
94
+ try:
95
+ from ctc_score import StyleTransferScorer, SummarizationScorer, DialogScorer
96
+ except:
97
+ print('ctc package is not installed. installing...')
98
+ install('ctc-score')
99
+
100
+ if self.config_name == 'default':
101
+ self.config_name = 'D-cnndm,consistency'
102
 
103
+ model_name, self.aspect = self.config_name.split(',')
104
+ if self.aspect in ['consistency, relevance']:
105
+ self.scorer = SummarizationScorer(align=model_name)
106
+ elif self.aspect in ['preservation']:
107
+ self.scorer = StyleTransferScorer(align=model_name)
108
+ elif self.aspect in ['engagingness', 'groundedness']:
109
+ self.scorer = DialogScorer(align=model_name)
110
+
111
 
112
  def _compute(self, predictions, references):
113
  """Returns the scores"""
 
118
  print(references)
119
  ctc_score = self.scorer.score(doc=references[0], refs=[], hypo=predictions[0], aspect='consistency')
120
  return {
121
+ "ctc_score": ctc_score
122
  }