Update crs_arena/battle_manager.py
Browse files- crs_arena/battle_manager.py +32 -5
crs_arena/battle_manager.py
CHANGED
@@ -4,7 +4,9 @@ Contains helper functions to select fighters for a battle and generate
|
|
4 |
unique user ids.
|
5 |
"""
|
6 |
|
|
|
7 |
import logging
|
|
|
8 |
import uuid
|
9 |
from collections import defaultdict
|
10 |
from typing import Optional, Tuple
|
@@ -36,8 +38,8 @@ CONVERSATION_COUNTS.update(
|
|
36 |
"kbrd_opendialkg": 7,
|
37 |
"unicrs_redial": 4,
|
38 |
"unicrs_opendialkg": 3,
|
39 |
-
"barcor_redial":
|
40 |
-
"barcor_opendialkg":
|
41 |
"crbcrs_redial": 4,
|
42 |
}
|
43 |
)
|
@@ -50,12 +52,37 @@ def get_crs_fighters() -> Tuple[CRSFighter, CRSFighter]:
|
|
50 |
The selection is based on the number of conversations collected per model.
|
51 |
The ones with the least conversations are selected.
|
52 |
|
|
|
|
|
|
|
53 |
Returns:
|
54 |
CRS models to battle.
|
55 |
"""
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
return fighter1, fighter2
|
60 |
|
61 |
|
|
|
4 |
unique user ids.
|
5 |
"""
|
6 |
|
7 |
+
import itertools
|
8 |
import logging
|
9 |
+
import random
|
10 |
import uuid
|
11 |
from collections import defaultdict
|
12 |
from typing import Optional, Tuple
|
|
|
38 |
"kbrd_opendialkg": 7,
|
39 |
"unicrs_redial": 4,
|
40 |
"unicrs_opendialkg": 3,
|
41 |
+
"barcor_redial": 4,
|
42 |
+
"barcor_opendialkg": 6,
|
43 |
"crbcrs_redial": 4,
|
44 |
}
|
45 |
)
|
|
|
52 |
The selection is based on the number of conversations collected per model.
|
53 |
The ones with the least conversations are selected.
|
54 |
|
55 |
+
Raises:
|
56 |
+
Exception: If there is an error selecting the fighters.
|
57 |
+
|
58 |
Returns:
|
59 |
CRS models to battle.
|
60 |
"""
|
61 |
+
sorted_count = sorted(CONVERSATION_COUNTS.items(), key=lambda x: x[1])
|
62 |
+
# Group models by conversation count.
|
63 |
+
groups = [
|
64 |
+
list(group)
|
65 |
+
for _, group in itertools.groupby(sorted_count, key=lambda x: x[1])
|
66 |
+
]
|
67 |
+
|
68 |
+
model_1, model_2 = None, None
|
69 |
+
|
70 |
+
try:
|
71 |
+
if len(groups[0]) >= 2:
|
72 |
+
model_1 = groups[0].pop(random.randint(0, len(groups[0]) - 1))[0]
|
73 |
+
model_2 = groups[0].pop(random.randint(0, len(groups[0]) - 1))[0]
|
74 |
+
else:
|
75 |
+
model_1 = groups[0].pop(random.randint(0, len(groups[0]) - 1))[0]
|
76 |
+
model_2 = groups[1].pop(random.randint(0, len(groups[1]) - 1))[0]
|
77 |
+
except Exception as e:
|
78 |
+
logging.error(f"Error selecting CRS fighters: {e}")
|
79 |
+
if model_1 is None:
|
80 |
+
model_1 = sorted_count[0][0]
|
81 |
+
if model_2 is None:
|
82 |
+
model_2 = sorted_count[1][0]
|
83 |
+
|
84 |
+
fighter1 = CRSFighter(1, model_1, CRS_MODELS[model_1])
|
85 |
+
fighter2 = CRSFighter(2, model_2, CRS_MODELS[model_2])
|
86 |
return fighter1, fighter2
|
87 |
|
88 |
|