Nol00 commited on
Commit
cc60b4d
·
verified ·
1 Parent(s): 9476267

Update crs_arena/battle_manager.py

Browse files
Files changed (1) hide show
  1. crs_arena/battle_manager.py +32 -5
crs_arena/battle_manager.py CHANGED
@@ -4,7 +4,9 @@ Contains helper functions to select fighters for a battle and generate
4
  unique user ids.
5
  """
6
 
 
7
  import logging
 
8
  import uuid
9
  from collections import defaultdict
10
  from typing import Optional, Tuple
@@ -36,8 +38,8 @@ CONVERSATION_COUNTS.update(
36
  "kbrd_opendialkg": 7,
37
  "unicrs_redial": 4,
38
  "unicrs_opendialkg": 3,
39
- "barcor_redial": 2,
40
- "barcor_opendialkg": 2,
41
  "crbcrs_redial": 4,
42
  }
43
  )
@@ -50,12 +52,37 @@ def get_crs_fighters() -> Tuple[CRSFighter, CRSFighter]:
50
  The selection is based on the number of conversations collected per model.
51
  The ones with the least conversations are selected.
52
 
 
 
 
53
  Returns:
54
  CRS models to battle.
55
  """
56
- pair = sorted(CONVERSATION_COUNTS.items(), key=lambda x: x[1])[:2]
57
- fighter1 = CRSFighter(1, pair[0][0], CRS_MODELS[pair[0][0]])
58
- fighter2 = CRSFighter(2, pair[1][0], CRS_MODELS[pair[1][0]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  return fighter1, fighter2
60
 
61
 
 
4
  unique user ids.
5
  """
6
 
7
+ import itertools
8
  import logging
9
+ import random
10
  import uuid
11
  from collections import defaultdict
12
  from typing import Optional, Tuple
 
38
  "kbrd_opendialkg": 7,
39
  "unicrs_redial": 4,
40
  "unicrs_opendialkg": 3,
41
+ "barcor_redial": 4,
42
+ "barcor_opendialkg": 6,
43
  "crbcrs_redial": 4,
44
  }
45
  )
 
52
  The selection is based on the number of conversations collected per model.
53
  The ones with the least conversations are selected.
54
 
55
+ Raises:
56
+ Exception: If there is an error selecting the fighters.
57
+
58
  Returns:
59
  CRS models to battle.
60
  """
61
+ sorted_count = sorted(CONVERSATION_COUNTS.items(), key=lambda x: x[1])
62
+ # Group models by conversation count.
63
+ groups = [
64
+ list(group)
65
+ for _, group in itertools.groupby(sorted_count, key=lambda x: x[1])
66
+ ]
67
+
68
+ model_1, model_2 = None, None
69
+
70
+ try:
71
+ if len(groups[0]) >= 2:
72
+ model_1 = groups[0].pop(random.randint(0, len(groups[0]) - 1))[0]
73
+ model_2 = groups[0].pop(random.randint(0, len(groups[0]) - 1))[0]
74
+ else:
75
+ model_1 = groups[0].pop(random.randint(0, len(groups[0]) - 1))[0]
76
+ model_2 = groups[1].pop(random.randint(0, len(groups[1]) - 1))[0]
77
+ except Exception as e:
78
+ logging.error(f"Error selecting CRS fighters: {e}")
79
+ if model_1 is None:
80
+ model_1 = sorted_count[0][0]
81
+ if model_2 is None:
82
+ model_2 = sorted_count[1][0]
83
+
84
+ fighter1 = CRSFighter(1, model_1, CRS_MODELS[model_1])
85
+ fighter2 = CRSFighter(2, model_2, CRS_MODELS[model_2])
86
  return fighter1, fighter2
87
 
88