gagan3012 commited on
Commit
c07e561
·
1 Parent(s): fc890f9
Files changed (1) hide show
  1. src/models/model.py +3 -3
src/models/model.py CHANGED
@@ -467,7 +467,7 @@ class Summarization:
467
  )
468
  for g in generated_ids
469
  ]
470
- return preds[0]
471
 
472
  def evaluate(
473
  self,
@@ -477,9 +477,9 @@ class Summarization:
477
  metric = load_metric(metrics)
478
  input_text = test_df['input_text']
479
  references = test_df['output_text']
480
- predictions = [self.predict(x) for x in input_text]
481
 
482
- results = metric.compute(predictions=predictions, references=references)
 
483
 
484
  output = {
485
  'Rouge 1': {
 
467
  )
468
  for g in generated_ids
469
  ]
470
+ return preds
471
 
472
  def evaluate(
473
  self,
 
477
  metric = load_metric(metrics)
478
  input_text = test_df['input_text']
479
  references = test_df['output_text']
 
480
 
481
+ for x in input_text:
482
+ results = metric.add_batch(predictions=self.predict(x), references=references)
483
 
484
  output = {
485
  'Rouge 1': {