Update src/model/UNICRS.py
Browse files- src/model/UNICRS.py +20 -3
src/model/UNICRS.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import json
|
|
|
2 |
import sys
|
3 |
from collections import defaultdict
|
4 |
from typing import Any, Dict, List, Tuple
|
@@ -449,6 +450,7 @@ class UNICRS:
|
|
449 |
id2entity: Dict[int, str],
|
450 |
options: Tuple[str, Dict[str, str]],
|
451 |
state: List[float],
|
|
|
452 |
) -> Tuple[str, List[float]]:
|
453 |
"""Generates a response given a conversation context.
|
454 |
|
@@ -463,6 +465,7 @@ class UNICRS:
|
|
463 |
id2entity: Mapping from entity ID to entity name.
|
464 |
options: Prompt with options and dictionary of options.
|
465 |
state: State of the option choices.
|
|
|
466 |
|
467 |
Returns:
|
468 |
Generated response and updated state.
|
@@ -473,9 +476,10 @@ class UNICRS:
|
|
473 |
# Get the choice between recommend and generate
|
474 |
choice = self.get_choice(generated_inputs, options_letter, state)
|
475 |
|
|
|
|
|
|
|
476 |
if choice == options_letter[-1]:
|
477 |
-
# Generate a recommendation
|
478 |
-
recommended_items, _ = self.get_rec(conv_dict)
|
479 |
recommended_items_str = ""
|
480 |
for i, item_id in enumerate(recommended_items[0][:3]):
|
481 |
recommended_items_str += f"{i+1}: {id2entity[item_id]} \n"
|
@@ -489,7 +493,20 @@ class UNICRS:
|
|
489 |
# response = (
|
490 |
# options[1].get(choice, {}).get("template", generated_response)
|
491 |
# )
|
492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
|
494 |
# Update the state. Hack: penalize the choice to reduce the
|
495 |
# likelihood of selecting the same choice again
|
|
|
1 |
import json
|
2 |
+
import logging
|
3 |
import sys
|
4 |
from collections import defaultdict
|
5 |
from typing import Any, Dict, List, Tuple
|
|
|
450 |
id2entity: Dict[int, str],
|
451 |
options: Tuple[str, Dict[str, str]],
|
452 |
state: List[float],
|
453 |
+
movie_token: str = "<mask>",
|
454 |
) -> Tuple[str, List[float]]:
|
455 |
"""Generates a response given a conversation context.
|
456 |
|
|
|
465 |
id2entity: Mapping from entity ID to entity name.
|
466 |
options: Prompt with options and dictionary of options.
|
467 |
state: State of the option choices.
|
468 |
+
movie_token: Mask token for the movie. Defaults to "<mask>".
|
469 |
|
470 |
Returns:
|
471 |
Generated response and updated state.
|
|
|
476 |
# Get the choice between recommend and generate
|
477 |
choice = self.get_choice(generated_inputs, options_letter, state)
|
478 |
|
479 |
+
# Generate a recommendation
|
480 |
+
recommended_items, _ = self.get_rec(conv_dict)
|
481 |
+
|
482 |
if choice == options_letter[-1]:
|
|
|
|
|
483 |
recommended_items_str = ""
|
484 |
for i, item_id in enumerate(recommended_items[0][:3]):
|
485 |
recommended_items_str += f"{i+1}: {id2entity[item_id]} \n"
|
|
|
493 |
# response = (
|
494 |
# options[1].get(choice, {}).get("template", generated_response)
|
495 |
# )
|
496 |
+
generated_response = generated_response[
|
497 |
+
generated_response.rfind("System:") + len("System:") + 1 :
|
498 |
+
]
|
499 |
+
for i in range(str.count(generated_response, movie_token)):
|
500 |
+
try:
|
501 |
+
generated_response = generated_response.replace(
|
502 |
+
movie_token, id2entity[recommended_items[i]], 1
|
503 |
+
)
|
504 |
+
except IndexError as e:
|
505 |
+
logging.error(e)
|
506 |
+
generated_response = generated_response.replace(
|
507 |
+
movie_token, "", 1
|
508 |
+
)
|
509 |
+
response = generated_response.strip()
|
510 |
|
511 |
# Update the state. Hack: penalize the choice to reduce the
|
512 |
# likelihood of selecting the same choice again
|