Nol00 commited on
Commit
f57c0ab
·
verified ·
1 Parent(s): 3511e77

Update src/model/UNICRS.py

Browse files
Files changed (1) hide show
  1. src/model/UNICRS.py +7 -4
src/model/UNICRS.py CHANGED
@@ -421,7 +421,7 @@ class UNICRS:
421
  }
422
 
423
  gen_seqs = self.model.generate(**input_batch["context"], **gen_args)
424
- gen_str = self.tokenizer.decode(gen_seqs[0], skip_special_tokens=True)
425
 
426
  return input_batch, gen_str
427
 
@@ -450,7 +450,7 @@ class UNICRS:
450
  id2entity: Dict[int, str],
451
  options: Tuple[str, Dict[str, str]],
452
  state: List[float],
453
- movie_token: str = "<mask>",
454
  ) -> Tuple[str, List[float]]:
455
  """Generates a response given a conversation context.
456
 
@@ -465,7 +465,7 @@ class UNICRS:
465
  id2entity: Mapping from entity ID to entity name.
466
  options: Prompt with options and dictionary of options.
467
  state: State of the option choices.
468
- movie_token: Mask token for the movie. Defaults to "<mask>".
469
 
470
  Returns:
471
  Generated response and updated state.
@@ -496,10 +496,13 @@ class UNICRS:
496
  generated_response = generated_response[
497
  generated_response.rfind("System:") + len("System:") + 1 :
498
  ]
 
 
 
499
  for i in range(str.count(generated_response, movie_token)):
500
  try:
501
  generated_response = generated_response.replace(
502
- movie_token, id2entity[recommended_items[i]], 1
503
  )
504
  except IndexError as e:
505
  logging.error(e)
 
421
  }
422
 
423
  gen_seqs = self.model.generate(**input_batch["context"], **gen_args)
424
+ gen_str = self.tokenizer.decode(gen_seqs[0], skip_special_tokens=False)
425
 
426
  return input_batch, gen_str
427
 
 
450
  id2entity: Dict[int, str],
451
  options: Tuple[str, Dict[str, str]],
452
  state: List[float],
453
+ movie_token: str = "<pad>",
454
  ) -> Tuple[str, List[float]]:
455
  """Generates a response given a conversation context.
456
 
 
465
  id2entity: Mapping from entity ID to entity name.
466
  options: Prompt with options and dictionary of options.
467
  state: State of the option choices.
468
+ movie_token: Mask token for the movie. Defaults to "<pad>".
469
 
470
  Returns:
471
  Generated response and updated state.
 
496
  generated_response = generated_response[
497
  generated_response.rfind("System:") + len("System:") + 1 :
498
  ]
499
+ generated_response = generated_response.replace(
500
+ "<|endoftext|>", ""
501
+ )
502
  for i in range(str.count(generated_response, movie_token)):
503
  try:
504
  generated_response = generated_response.replace(
505
+ movie_token, id2entity[recommended_items[0][i]], 1
506
  )
507
  except IndexError as e:
508
  logging.error(e)