File size: 2,177 Bytes
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Battle Manager module.

Contains helper functions to select fighters for a battle and generate
unique user ids.
"""

import logging
import uuid
from collections import defaultdict
from typing import Optional, Tuple

from crs_fighter import CRSFighter
from utils import get_crs_model

# CRS models with their configuration files.
CRS_MODELS = {
    "kbrd_redial": "data/arena/crs_config/KBRD/kbrd_redial.yaml",
    "kbrd_opendialkg": "data/arena/crs_config/KBRD/kbrd_opendialkg.yaml",
    "unicrs_redial": "data/arena/crs_config/UniCRS/unicrs_redial.yaml",
    "unicrs_opendialkg": "data/arena/crs_config/UniCRS/unicrs_opendialkg.yaml",
    "barcor_redial": "data/arena/crs_config/BARCOR/barcor_redial.yaml",
    "barcor_opendialkg": "data/arena/crs_config/BARCOR/barcor_opendialkg.yaml",
    "chatgpt_redial": "data/arena/crs_config/ChatGPT/chatgpt_redial.yaml",
    "chatgpt_opendialkg": (
        "data/arena/crs_config/ChatGPT/chatgpt_opendialkg.yaml"
    ),
    "crbcrs_redial": "data/arena/crs_config/CRB_CRS/crb_crs_redial.yaml",
}

# Number of conversations collected per CRS model; every model starts at 0.
# get_crs_fighters() picks the models with the lowest counts. The counts are
# presumably incremented elsewhere — verify against callers.
# This is deliberately a plain dict: an unknown model name raises KeyError
# instead of silently creating an entry that get_crs_fighters() could select
# even though it has no config path in CRS_MODELS.
CONVERSATION_COUNTS = dict.fromkeys(CRS_MODELS, 0)


def get_crs_fighters() -> Tuple[CRSFighter, CRSFighter]:
    """Pick the two least-used CRS models and wrap them as fighters.

    Model names are ranked by their value in ``CONVERSATION_COUNTS``,
    fewest collected conversations first. The sort is stable, so ties keep
    the iteration order of ``CONVERSATION_COUNTS``.

    Returns:
        Two ``CRSFighter`` objects, built with ids 1 and 2, for the two
        lowest-ranked models and their config paths from ``CRS_MODELS``.
    """
    ranked = sorted(CONVERSATION_COUNTS, key=CONVERSATION_COUNTS.get)
    first = CRSFighter(1, ranked[0], CRS_MODELS[ranked[0]])
    second = CRSFighter(2, ranked[1], CRS_MODELS[ranked[1]])
    return first, second


def get_unique_user_id() -> str:
    """Return a freshly generated random user identifier.

    The identifier is the string form of a random (version 4) UUID.

    Returns:
        The new user id.
    """
    fresh = uuid.uuid4()
    return str(fresh)


def cache_fighters(n: Optional[int] = None) -> None:
    """Caches the first n CRS fighters.

    Walks ``CRS_MODELS`` in insertion order and calls ``get_crs_model`` for
    each entry so the models are created ahead of time. This assumes
    ``get_crs_model`` keeps what it creates — TODO confirm in ``utils``.

    Args:
        n: Number of fighters to cache. If None, all fighters are cached.
            Values larger than the number of models cache all of them.

    Raises:
        ValueError: If n is negative.
    """
    if n is not None and n < 0:
        raise ValueError(f"n must be non-negative or None, got {n}.")
    logging.info("Caching %s CRS fighters.", "all" if n is None else n)
    for i, (model_name, config_path) in enumerate(CRS_MODELS.items()):
        # Check the limit *before* loading: checking after the call would
        # load one model too many (n + 1 in total).
        if n is not None and i >= n:
            break
        get_crs_model(model_name, config_path)