|
"""Battle Manager module. |
|
|
|
Contains helper functions to select fighters for a battle and generate |
|
unique user ids. |
|
""" |
|
|
|
import itertools |
|
import logging |
|
import random |
|
import uuid |
|
from collections import defaultdict |
|
from typing import Optional, Tuple |
|
|
|
from crs_fighter import CRSFighter |
|
from utils import get_crs_model |
|
|
|
|
|
CRS_MODELS = { |
|
"crbcrs_redial": "data/arena/crs_config/CRB_CRS/crb_crs_redial.yaml", |
|
"kbrd_redial": "data/arena/crs_config/KBRD/kbrd_redial.yaml", |
|
"unicrs_redial": "data/arena/crs_config/UniCRS/unicrs_redial.yaml", |
|
"kbrd_opendialkg": "data/arena/crs_config/KBRD/kbrd_opendialkg.yaml", |
|
"chatgpt_redial": "data/arena/crs_config/ChatGPT/chatgpt_redial.yaml", |
|
"barcor_opendialkg": "data/arena/crs_config/BARCOR/barcor_opendialkg.yaml", |
|
"unicrs_opendialkg": "data/arena/crs_config/UniCRS/unicrs_opendialkg.yaml", |
|
"barcor_redial": "data/arena/crs_config/BARCOR/barcor_redial.yaml", |
|
"chatgpt_opendialkg": ( |
|
"data/arena/crs_config/ChatGPT/chatgpt_opendialkg.yaml" |
|
), |
|
} |
|
|
|
CONVERSATION_COUNTS = defaultdict(int).fromkeys(CRS_MODELS.keys(), 0) |
|
|
|
|
|
def get_crs_fighters() -> Tuple[CRSFighter, CRSFighter]: |
|
"""Selects two CRS models for a battle. |
|
|
|
The selection is based on the number of conversations collected per model. |
|
The ones with the least conversations are selected. |
|
|
|
Raises: |
|
Exception: If there is an error selecting the fighters. |
|
|
|
Returns: |
|
CRS models to battle. |
|
""" |
|
sorted_count = sorted(CONVERSATION_COUNTS.items(), key=lambda x: x[1]) |
|
|
|
groups = [ |
|
list(group) |
|
for _, group in itertools.groupby(sorted_count, key=lambda x: x[1]) |
|
] |
|
|
|
model_1, model_2 = None, None |
|
|
|
try: |
|
if len(groups[0]) >= 2: |
|
model_1 = groups[0].pop(random.randint(0, len(groups[0]) - 1))[0] |
|
model_2 = groups[0].pop(random.randint(0, len(groups[0]) - 1))[0] |
|
else: |
|
model_1 = groups[0].pop(random.randint(0, len(groups[0]) - 1))[0] |
|
model_2 = groups[1].pop(random.randint(0, len(groups[1]) - 1))[0] |
|
except Exception as e: |
|
logging.error(f"Error selecting CRS fighters: {e}") |
|
if model_1 is None: |
|
model_1 = sorted_count[0][0] |
|
if model_2 is None: |
|
model_2 = sorted_count[1][0] |
|
|
|
fighter1 = CRSFighter(1, model_1, CRS_MODELS[model_1]) |
|
fighter2 = CRSFighter(2, model_2, CRS_MODELS[model_2]) |
|
return fighter1, fighter2 |
|
|
|
|
|
def get_unique_user_id() -> str: |
|
"""Generates a unique user id. |
|
|
|
Returns: |
|
Unique user id. |
|
""" |
|
return str(uuid.uuid4()) |
|
|
|
|
|
def cache_fighters(n: Optional[int] = None) -> None: |
|
"""Caches n CRS fighters. |
|
|
|
Args: |
|
n: Number of fighters to cache. If None, all fighters are cached. |
|
""" |
|
logging.info(f"Caching {n} CRS fighters.") |
|
for i, (model_name, config_path) in enumerate(CRS_MODELS.items()): |
|
try: |
|
get_crs_model(model_name, config_path) |
|
except Exception as e: |
|
logging.error(f"Caching CRS fighters: {e}") |
|
get_crs_model.clear(model_name, config_path) |
|
if n is not None and i == n: |
|
break |
|
|