File size: 3,234 Bytes
b599481
 
 
 
 
 
cc60b4d
b599481
cc60b4d
b599481
 
 
 
 
 
 
 
 
cedf204
b599481
9b43388
cedf204
a33c7c0
cedf204
9b43388
b599481
a33c7c0
 
 
b599481
 
 
 
 
 
 
 
 
 
 
cc60b4d
 
 
b599481
 
 
cc60b4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79f754c
 
 
 
 
b599481
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Battle Manager module.

Contains helper functions to select two CRS fighters for a battle (favouring
the models with the fewest collected conversations), cache CRS fighters via
``get_crs_model``, and generate unique user ids.
"""

import itertools
import logging
import random
import uuid
from collections import defaultdict
from typing import Optional, Tuple

from crs_fighter import CRSFighter
from utils import get_crs_model

# CRS models with their configuration files.
CRS_MODELS = {
    "crbcrs_redial": "data/arena/crs_config/CRB_CRS/crb_crs_redial.yaml",
    "kbrd_redial": "data/arena/crs_config/KBRD/kbrd_redial.yaml",
    "unicrs_redial": "data/arena/crs_config/UniCRS/unicrs_redial.yaml",
    "kbrd_opendialkg": "data/arena/crs_config/KBRD/kbrd_opendialkg.yaml",
    "chatgpt_redial": "data/arena/crs_config/ChatGPT/chatgpt_redial.yaml",
    "barcor_opendialkg": "data/arena/crs_config/BARCOR/barcor_opendialkg.yaml",
    "unicrs_opendialkg": "data/arena/crs_config/UniCRS/unicrs_opendialkg.yaml",
    "barcor_redial": "data/arena/crs_config/BARCOR/barcor_redial.yaml",
    "chatgpt_opendialkg": (
        "data/arena/crs_config/ChatGPT/chatgpt_opendialkg.yaml"
    ),
}

# Number of conversations collected per CRS model, starting at 0 for every
# configured model. The ``int`` factory is passed to the constructor because
# ``defaultdict(int).fromkeys(...)`` returns a defaultdict whose
# ``default_factory`` is None (``fromkeys`` is a classmethod and ignores the
# instance it is called on).
CONVERSATION_COUNTS = defaultdict(int, dict.fromkeys(CRS_MODELS, 0))


def get_crs_fighters() -> Tuple[CRSFighter, CRSFighter]:
    """Selects two distinct CRS models for a battle.

    The selection is based on the number of conversations collected per model
    (``CONVERSATION_COUNTS``): models are grouped by count and the pair is
    drawn from the least-used groups. If the least-used group has at least two
    models, two of them are picked at random; otherwise its only model is
    paired with a random model from the next group.

    Raises:
        IndexError: If fewer than two configured CRS models are available.

    Returns:
        CRS models to battle.
    """
    # Only consider models that have a configuration file; a stray key in the
    # counts would otherwise be selected and then fail the CRS_MODELS lookup.
    sorted_count = sorted(
        (
            (name, count)
            for name, count in CONVERSATION_COUNTS.items()
            if name in CRS_MODELS
        ),
        key=lambda x: x[1],
    )
    if len(sorted_count) < 2:
        logging.error(
            "Error selecting CRS fighters: need at least two models, got %d.",
            len(sorted_count),
        )
        raise IndexError("At least two CRS models are required for a battle.")

    # Group model names by conversation count; groupby requires its input to
    # be sorted by the same key, which sorted_count is.
    groups = [
        [name for name, _ in group]
        for _, group in itertools.groupby(sorted_count, key=lambda x: x[1])
    ]

    if len(groups[0]) >= 2:
        # Two distinct models, in random order, from the least-used group.
        model_1, model_2 = random.sample(groups[0], 2)
    else:
        # The least-used group has a single model; since there are at least
        # two models in total, a second group is guaranteed to exist.
        model_1 = groups[0][0]
        model_2 = random.choice(groups[1])

    fighter1 = CRSFighter(1, model_1, CRS_MODELS[model_1])
    fighter2 = CRSFighter(2, model_2, CRS_MODELS[model_2])
    return fighter1, fighter2


def get_unique_user_id() -> str:
    """Creates a fresh identifier for a user.

    Returns:
        The string form of a newly generated random (version 4) UUID.
    """
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)


def cache_fighters(n: Optional[int] = None) -> None:
    """Caches the first n CRS fighters by loading them with get_crs_model.

    Models are processed in CRS_MODELS order. If loading a model raises, the
    error is logged and ``get_crs_model.clear`` is called for that model
    (presumably to drop a partial cache entry — confirm against utils), then
    the remaining models are still processed.

    Args:
        n: Number of fighters to cache. If None, all fighters are cached.
            Values <= 0 cache nothing.
    """
    logging.info("Caching %s CRS fighters.", "all" if n is None else n)
    for i, (model_name, config_path) in enumerate(CRS_MODELS.items()):
        # Check before loading so exactly n models are cached (checking after
        # loading index n would cache n + 1).
        if n is not None and i >= n:
            break
        try:
            get_crs_model(model_name, config_path)
        except Exception as e:
            # Best-effort caching: log and continue with the next model.
            logging.error(f"Caching CRS fighters: {e}")
            get_crs_model.clear(model_name, config_path)