File size: 4,678 Bytes
b599481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Utility functions for data preprocessing."""

import json
import re
from typing import Any, Dict, List

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# Corpora/models needed at runtime: "stopwords" backs remove_stopwords and
# "punkt" backs word_tokenize — without punkt, nltk raises a LookupError the
# first time word_tokenize is called.
nltk.download("stopwords")
nltk.download("punkt")

# Placeholder token used to mask item/movie ids across all datasets.
DEFAULT_ITEM_PLACEHOLDER = "ITEM_ID"


def remove_stopwords(utterance: str) -> str:
    """Removes stopwords from an utterance.

    Args:
        utterance: Input utterance.

    Returns:
        Utterance without stopwords.
    """
    # Build the stopword set once: stopwords.words() (no language argument)
    # loads the stopword lists of *every* available language, so evaluating
    # it per token — as the previous comprehension did — re-read the corpora
    # for each word and paired that with an O(n) list membership scan.
    stopword_set = set(stopwords.words())
    tokens = word_tokenize(utterance)
    return " ".join(token for token in tokens if token not in stopword_set)


def expand_contractions(utterance: str) -> str:
    """Expands contractions in an utterance.

    Args:
        utterance: Input utterance.

    Returns:
        Utterance with expanded contractions.
    """
    # Use a context manager so the file handle is closed deterministically;
    # the previous bare open() leaked the handle on every call.
    with open("data/crb_crs/contractions.json", "r") as contractions_file:
        contractions = json.load(contractions_file)
    for word in utterance.split():
        if word.lower() in contractions:
            # NOTE(review): str.replace substitutes every occurrence of
            # `word` anywhere in the utterance, including inside longer
            # tokens — kept as-is to preserve existing behavior.
            utterance = utterance.replace(word, contractions[word.lower()])
    return utterance


def redial_replace_movie_ids(
    utterance: str, movie_placeholder: str = DEFAULT_ITEM_PLACEHOLDER
) -> str:
    """Replaces movie ids with a placeholder in utterance from ReDial dataset.

    ReDial marks movie mentions as "@<id>" tokens.

    Args:
        utterance: Input utterance.
        movie_placeholder: Placeholder for movie id.

    Returns:
        Utterance with movie ids replaced by placeholder.
    """
    # A single regex substitution replaces each "@<id>" token in one pass.
    # The previous findall + sequential str.replace loop corrupted output
    # when one id was a prefix of another: replacing "@12" first turned
    # "@123" into "<placeholder>3" before "@123" could be matched.
    return re.sub(r"@\S+", movie_placeholder, utterance)


def opendialkg_replace_items(
    text: str,
    items: List[str],
    item_placeholder: str = DEFAULT_ITEM_PLACEHOLDER,
):
    """Masks item mentions in an OpenDialKG utterance with a placeholder.

    Each string in `items` (the item mentions annotated in the dataset) is
    substituted wherever it occurs in the utterance.

    Args:
        text: Input utterance.
        items: List of items in the utterance (taken from dataset).
        item_placeholder: Placeholder for item.

    Returns:
        Utterance with items replaced by placeholder.
    """
    masked = text
    for mention in items:
        masked = masked.replace(mention, item_placeholder)
    return masked


def preprocess_utterance(
    utterance: Dict[str, Any],
    dataset: str,
    item_placeholder: str = DEFAULT_ITEM_PLACEHOLDER,
    no_stopwords: bool = True,
) -> str:
    """Preprocesses an utterance.

    Preprocessing includes lowercasing, stripping, replacing item ids with a
    placeholder, converting contractions to full form, and optionally
    removing stopwords.

    Args:
        utterance: Input utterance (dict with at least a "text" field; for
          OpenDialKG also an "items" list).
        dataset: Name of the origin dataset ("redial" or "opendialkg").
        item_placeholder: Placeholder for item id.
        no_stopwords: Whether to remove stopwords.

    Raises:
        ValueError: If dataset is not supported.

    Returns:
        Preprocessed utterance; "**" if preprocessing leaves it empty.
    """
    # NOTE(review): .get("text") returns None when the key is missing, so
    # .lower() would raise AttributeError — confirm callers always set "text".
    processed_utterance = utterance.get("text").lower().strip()

    if dataset == "redial":
        processed_utterance = redial_replace_movie_ids(
            processed_utterance, item_placeholder
        )
    elif dataset == "opendialkg":
        processed_utterance = opendialkg_replace_items(
            processed_utterance, utterance.get("items", []), item_placeholder
        )
    else:
        raise ValueError(f"Dataset {dataset} not supported.")

    processed_utterance = expand_contractions(processed_utterance)
    if no_stopwords:
        processed_utterance = remove_stopwords(processed_utterance)

    # Downstream code expects a non-empty string; "**" marks an utterance
    # that preprocessing reduced to nothing.
    if processed_utterance == "":
        processed_utterance = "**"

    return processed_utterance


def get_preference_keywords(domain: str) -> List[str]:
    """Returns a list of preference keywords.

    Args:
        domain: Domain name ("movies" or "movies_books").

    Raises:
        ValueError: If the domain is not supported.

    Returns:
        List of preference keywords for the domain.
    """
    movies_preference_keywords = [
        "scary",
        "horror",
        "pixar",
        "graphic",
        "classic",
        "comedy",
        "kids",
        "funny",
        "disney",
        "comedies",
        "action",
        "family",
        "adventure",
        "crime",
        "fantasy",
        "thriller",
        "scifi",
        "documentary",
        "science fiction",
        "drama",
        "romance",
        "romances",
        "romantic",
        "mystery",
        "mysteries",
        "history",
        "no preference",
        "suspense",
    ]
    if domain == "movies":
        return movies_preference_keywords
    if domain == "movies_books":
        # TODO: Add keywords related to books. Returns a copy, matching the
        # original `movies_preference_keywords + []` behavior.
        return list(movies_preference_keywords)
    raise ValueError(f"Domain not supported: {domain}")