gagan3012 commited on
Commit
973f235
·
1 Parent(s): c07e561
Files changed (1) hide show
  1. src/models/model.py +6 -5
src/models/model.py CHANGED
@@ -467,7 +467,7 @@ class Summarization:
467
  )
468
  for g in generated_ids
469
  ]
470
- return preds
471
 
472
  def evaluate(
473
  self,
@@ -475,11 +475,12 @@ class Summarization:
475
  metrics: str = "rouge"
476
  ):
477
  metric = load_metric(metrics)
478
- input_text = test_df['input_text']
479
- references = test_df['output_text']
480
 
481
- for x in input_text:
482
- results = metric.add_batch(predictions=self.predict(x), references=references)
 
483
 
484
  output = {
485
  'Rouge 1': {
 
467
  )
468
  for g in generated_ids
469
  ]
470
+ return preds[0]
471
 
472
  def evaluate(
473
  self,
 
475
  metrics: str = "rouge"
476
  ):
477
  metric = load_metric(metrics)
478
+ input_text = test_df['input_text'][:5]
479
+ references = test_df['output_text'][:5]
480
 
481
+ predictions = [self.predict(x) for x in input_text]
482
+
483
+ results = metric.add_batch(predictions=predictions, references=references)
484
 
485
  output = {
486
  'Rouge 1': {