Nol00 commited on
Commit
ef03bcf
·
verified ·
1 Parent(s): 6fe92ef

Update src/model/UNICRS.py

Browse files
Files changed (1) hide show
  1. src/model/UNICRS.py +20 -3
src/model/UNICRS.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
  import sys
3
  from collections import defaultdict
4
  from typing import Any, Dict, List, Tuple
@@ -449,6 +450,7 @@ class UNICRS:
449
  id2entity: Dict[int, str],
450
  options: Tuple[str, Dict[str, str]],
451
  state: List[float],
 
452
  ) -> Tuple[str, List[float]]:
453
  """Generates a response given a conversation context.
454
 
@@ -463,6 +465,7 @@ class UNICRS:
463
  id2entity: Mapping from entity ID to entity name.
464
  options: Prompt with options and dictionary of options.
465
  state: State of the option choices.
 
466
 
467
  Returns:
468
  Generated response and updated state.
@@ -473,9 +476,10 @@ class UNICRS:
473
  # Get the choice between recommend and generate
474
  choice = self.get_choice(generated_inputs, options_letter, state)
475
 
 
 
 
476
  if choice == options_letter[-1]:
477
- # Generate a recommendation
478
- recommended_items, _ = self.get_rec(conv_dict)
479
  recommended_items_str = ""
480
  for i, item_id in enumerate(recommended_items[0][:3]):
481
  recommended_items_str += f"{i+1}: {id2entity[item_id]} \n"
@@ -489,7 +493,20 @@ class UNICRS:
489
  # response = (
490
  # options[1].get(choice, {}).get("template", generated_response)
491
  # )
492
- response = generated_response
 
 
 
 
 
 
 
 
 
 
 
 
 
493
 
494
  # Update the state. Hack: penalize the choice to reduce the
495
  # likelihood of selecting the same choice again
 
1
  import json
2
+ import logging
3
  import sys
4
  from collections import defaultdict
5
  from typing import Any, Dict, List, Tuple
 
450
  id2entity: Dict[int, str],
451
  options: Tuple[str, Dict[str, str]],
452
  state: List[float],
453
+ movie_token: str = "<mask>",
454
  ) -> Tuple[str, List[float]]:
455
  """Generates a response given a conversation context.
456
 
 
465
  id2entity: Mapping from entity ID to entity name.
466
  options: Prompt with options and dictionary of options.
467
  state: State of the option choices.
468
+ movie_token: Mask token for the movie. Defaults to "<mask>".
469
 
470
  Returns:
471
  Generated response and updated state.
 
476
  # Get the choice between recommend and generate
477
  choice = self.get_choice(generated_inputs, options_letter, state)
478
 
479
+ # Generate a recommendation
480
+ recommended_items, _ = self.get_rec(conv_dict)
481
+
482
  if choice == options_letter[-1]:
 
 
483
  recommended_items_str = ""
484
  for i, item_id in enumerate(recommended_items[0][:3]):
485
  recommended_items_str += f"{i+1}: {id2entity[item_id]} \n"
 
493
  # response = (
494
  # options[1].get(choice, {}).get("template", generated_response)
495
  # )
496
+ generated_response = generated_response[
497
+ generated_response.rfind("System:") + len("System:") + 1 :
498
+ ]
499
+ for i in range(str.count(generated_response, movie_token)):
500
+ try:
501
+ generated_response = generated_response.replace(
502
+ movie_token, id2entity[recommended_items[i]], 1
503
+ )
504
+ except IndexError as e:
505
+ logging.error(e)
506
+ generated_response = generated_response.replace(
507
+ movie_token, "", 1
508
+ )
509
+ response = generated_response.strip()
510
 
511
  # Update the state. Hack: penalize the choice to reduce the
512
  # likelihood of selecting the same choice again