zyznull committed
Commit
581c2b2
1 Parent(s): 8fdde9c

Update scripts/eval_mteb.py

Files changed (1)
  1. scripts/eval_mteb.py +21 -7
scripts/eval_mteb.py CHANGED
@@ -119,7 +119,6 @@ CMTEB_TASK_LIST = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'Onl
     'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
     'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC', 'STS22']
 
-
 MTEB_PL = [
     "CBD","PolEmo2.0-IN","PolEmo2.0-OUT","AllegroReviews","PAC","MassiveIntentClassification","MassiveScenarioClassification",
     "SICK-E-PL","PPC","CDSC-E","PSC","8TagsClustering","SICK-R-PL","CDSC-R","STS22",
@@ -406,9 +405,9 @@ class Wrapper:
         self._target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.eod_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
         self.instruction = instruction
-        self.default_query = default_query
+        self.default_query = default_query
+        self.sep = sep
         self.force_default = force_default
-
         if self.tokenizer.padding_side != 'right':
             logger.warning(f"Change tokenizer.padding_side from {self.tokenizer.padding_side} to right")
             self.tokenizer.padding_side = 'right'
@@ -675,13 +674,15 @@ class Wrapper:
 def main(args):
     tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
     encoder = Encoder(args.model, args.pooling)
+    default_query = args.default_type == 'query'
     model = Wrapper(
         tokenizer, encoder,
         batch_size=args.batch_size,
         max_seq_len=args.max_seq_len,
-        normalize_embeddings=args.norm
+        normalize_embeddings=args.norm,
+        default_query=default_query
     )
-
+    sym_retrievals = ['QuoraRetrieval', 'ArguAna', 'CQADupstack']
     if args.task == 'mteb':
         task_names = MTEB_TASK_LIST
         lang = ['en']
@@ -709,8 +710,21 @@ def main(args):
             eval_splits = task_cls.description['eval_splits']
         else:
             eval_splits = ["test"]
-
+        sym = False
+        for name in sym_retrievals:
+            if task.startswith(name):
+                sym = True
+                break
+            else:
+                sym = False
+        if sym:
+            logger.info(f"Switch to symmetric mode for {task}, all as {'query' if default_query else 'doc'}.")
+            model.force_default = True
         evaluation.run(model, output_folder=args.output_dir, eval_splits=eval_splits)
+
+        if sym:
+            logger.info(f"Switch back.")
+            model.force_default = force_default_ori
         print('\n')
 
 
@@ -729,4 +743,4 @@ if __name__ == "__main__":
     )
     _PARSER.add_argument("--norm", action="store_true")
     _ARGS = _PARSER.parse_args()
-    main(_ARGS)
+    main(_ARGS)
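
Note on the symmetric-retrieval toggle added above: the hunk reads force_default_ori when switching back, but the line that captures it lies outside the displayed context; it is presumably saved from model.force_default before the override. The sketch below is illustrative only and not part of the commit: the function name and signature are invented, and any(...) stands in for the for/break/else loop in the hunk with equivalent prefix matching.

import logging

logger = logging.getLogger(__name__)

# Mirrors `sym_retrievals` from the diff: tasks whose queries and documents are
# the same kind of text, so both sides are encoded with the default type.
SYM_RETRIEVALS = ['QuoraRetrieval', 'ArguAna', 'CQADupstack']

def run_task_with_symmetric_toggle(model, evaluation, task, eval_splits,
                                   output_dir, default_query):
    """Run one task, forcing the default encoding type for symmetric retrieval tasks."""
    force_default_ori = model.force_default  # assumed capture point; not shown in the hunk
    sym = any(task.startswith(name) for name in SYM_RETRIEVALS)
    if sym:
        logger.info("Switch to symmetric mode for %s, all as %s.",
                    task, 'query' if default_query else 'doc')
        model.force_default = True
    evaluation.run(model, output_folder=output_dir, eval_splits=eval_splits)
    if sym:
        logger.info("Switch back.")
        model.force_default = force_default_ori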