Update crs_arena/crs_fighter.py
Browse files- crs_arena/crs_fighter.py +11 -1
crs_arena/crs_fighter.py
CHANGED
@@ -43,8 +43,18 @@ class CRSFighter:
|
|
43 |
# Load options
|
44 |
self.options = get_options(self.model.crs_model.kg_dataset)
|
45 |
|
46 |
-
# Generation arguments
|
47 |
self.response_generation_args = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
def _load_entity_data(self):
|
50 |
"""Loads entity data."""
|
|
|
43 |
# Load options
|
44 |
self.options = get_options(self.model.crs_model.kg_dataset)
|
45 |
|
46 |
+
# Generation arguments
|
47 |
self.response_generation_args = {}
|
48 |
+
if self.name.split("_")[0] == "unicrs":
|
49 |
+
self.response_generation_args.update(
|
50 |
+
{
|
51 |
+
"movie_token": (
|
52 |
+
"<movie>"
|
53 |
+
if self.model.crs_model.kg_dataset.startswith("redial")
|
54 |
+
else "<mask>"
|
55 |
+
),
|
56 |
+
}
|
57 |
+
)
|
58 |
|
59 |
def _load_entity_data(self):
|
60 |
"""Loads entity data."""
|