Update src/model/UNICRS.py
Browse files- src/model/UNICRS.py +7 -4
src/model/UNICRS.py
CHANGED
@@ -421,7 +421,7 @@ class UNICRS:
|
|
421 |
}
|
422 |
|
423 |
gen_seqs = self.model.generate(**input_batch["context"], **gen_args)
|
424 |
-
gen_str = self.tokenizer.decode(gen_seqs[0], skip_special_tokens=
|
425 |
|
426 |
return input_batch, gen_str
|
427 |
|
@@ -450,7 +450,7 @@ class UNICRS:
|
|
450 |
id2entity: Dict[int, str],
|
451 |
options: Tuple[str, Dict[str, str]],
|
452 |
state: List[float],
|
453 |
-
movie_token: str = "<
|
454 |
) -> Tuple[str, List[float]]:
|
455 |
"""Generates a response given a conversation context.
|
456 |
|
@@ -465,7 +465,7 @@ class UNICRS:
|
|
465 |
id2entity: Mapping from entity ID to entity name.
|
466 |
options: Prompt with options and dictionary of options.
|
467 |
state: State of the option choices.
|
468 |
-
movie_token: Mask token for the movie. Defaults to "<
|
469 |
|
470 |
Returns:
|
471 |
Generated response and updated state.
|
@@ -496,10 +496,13 @@ class UNICRS:
|
|
496 |
generated_response = generated_response[
|
497 |
generated_response.rfind("System:") + len("System:") + 1 :
|
498 |
]
|
|
|
|
|
|
|
499 |
for i in range(str.count(generated_response, movie_token)):
|
500 |
try:
|
501 |
generated_response = generated_response.replace(
|
502 |
-
movie_token, id2entity[recommended_items[i]], 1
|
503 |
)
|
504 |
except IndexError as e:
|
505 |
logging.error(e)
|
|
|
421 |
}
|
422 |
|
423 |
gen_seqs = self.model.generate(**input_batch["context"], **gen_args)
|
424 |
+
gen_str = self.tokenizer.decode(gen_seqs[0], skip_special_tokens=False)
|
425 |
|
426 |
return input_batch, gen_str
|
427 |
|
|
|
450 |
id2entity: Dict[int, str],
|
451 |
options: Tuple[str, Dict[str, str]],
|
452 |
state: List[float],
|
453 |
+
movie_token: str = "<pad>",
|
454 |
) -> Tuple[str, List[float]]:
|
455 |
"""Generates a response given a conversation context.
|
456 |
|
|
|
465 |
id2entity: Mapping from entity ID to entity name.
|
466 |
options: Prompt with options and dictionary of options.
|
467 |
state: State of the option choices.
|
468 |
+
movie_token: Mask token for the movie. Defaults to "<pad>".
|
469 |
|
470 |
Returns:
|
471 |
Generated response and updated state.
|
|
|
496 |
generated_response = generated_response[
|
497 |
generated_response.rfind("System:") + len("System:") + 1 :
|
498 |
]
|
499 |
+
generated_response = generated_response.replace(
|
500 |
+
"<|endoftext|>", ""
|
501 |
+
)
|
502 |
for i in range(str.count(generated_response, movie_token)):
|
503 |
try:
|
504 |
generated_response = generated_response.replace(
|
505 |
+
movie_token, id2entity[recommended_items[0][i]], 1
|
506 |
)
|
507 |
except IndexError as e:
|
508 |
logging.error(e)
|