File size: 1,856 Bytes
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import sys
from typing import Any, Dict, List, Tuple

sys.path.append("..")

from src.model.BARCOR import BARCOR
from src.model.CHATGPT import CHATGPT
from src.model.CRB_CRS import CRBCRSModel
from src.model.KBRD import KBRD
from src.model.UNICRS import UNICRS

# Registry consulted by CRSModel: maps the model name passed to
# CRSModel(crs_model, ...) to the implementation class it instantiates.
name2class = dict(
    kbrd=KBRD,
    barcor=BARCOR,
    unicrs=UNICRS,
    chatgpt=CHATGPT,
    crbcrs=CRBCRSModel,
)


class CRSModel:
    """Facade that instantiates a CRS model by name and forwards calls to it.

    The concrete class is looked up in the module-level ``name2class``
    registry. Every public method delegates to the wrapped instance stored
    in ``self.crs_model``.
    """

    def __init__(self, crs_model, *args, **kwargs) -> None:
        """Instantiates the model registered under ``crs_model``.

        Args:
            crs_model: Registry key of the model to build.
            *args: Positional arguments forwarded to the model constructor.
            **kwargs: Keyword arguments forwarded to the model constructor.

        Raises:
            KeyError: If ``crs_model`` is not a registered model name.
        """
        try:
            model_class = name2class[crs_model]
        except KeyError:
            # Still raise KeyError so existing handlers keep working, but
            # list the valid names so a typo is easy to diagnose.
            raise KeyError(
                f"Unknown CRS model {crs_model!r}; expected one of "
                f"{sorted(name2class)}"
            ) from None
        self.crs_model = model_class(*args, **kwargs)

    def get_rec(self, conv_dict: Dict[str, Any]):
        """Generates recommendations given a conversation context.

        Args:
            conv_dict: Conversation context, passed through unchanged.

        Returns:
            Whatever the wrapped model's ``get_rec`` returns.
        """
        return self.crs_model.get_rec(conv_dict)

    def get_conv(self, conv_dict: Dict[str, Any]):
        """Generates an utterance given a conversation context.

        Args:
            conv_dict: Conversation context, passed through unchanged.

        Returns:
            Whatever the wrapped model's ``get_conv`` returns.
        """
        return self.crs_model.get_conv(conv_dict)

    def get_response(
        self,
        conv_dict: Dict[str, Any],
        id2entity: Dict[int, str],
        options: Tuple[str, Dict[str, str]],
        state: List[float],
        **kwargs
    ) -> Tuple[str, List[float]]:
        """Generates a response given a conversation context.

        Args:
            conv_dict: Conversation context.
            id2entity: Mapping from entity id to entity name.
            options: Prompt with options and dictionary of options.
            state: State of the option choices.
            **kwargs: Extra keyword arguments forwarded to the wrapped
                model's ``get_response``.

        Returns:
            Generated response and updated state.
        """
        return self.crs_model.get_response(
            conv_dict, id2entity, options, state, **kwargs
        )

    def get_choice(self, gen_inputs, option, state, conv_dict=None):
        """Generates a choice between options given a conversation context.

        All arguments are forwarded positionally to the wrapped model's
        ``get_choice``; ``conv_dict`` defaults to ``None`` when omitted.

        Returns:
            Whatever the wrapped model's ``get_choice`` returns.
        """
        return self.crs_model.get_choice(gen_inputs, option, state, conv_dict)