robert committed on
Commit c42ac02 · 1 Parent(s): e57aba0

Changing the logic in spaces_model_predict

Files changed (1): app.py (+8 -15)
app.py CHANGED
@@ -41,26 +41,20 @@ class StopOnSequence(StoppingCriteria):
     def __call__(
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
     ) -> bool:
-        if input_ids.shape[1] < self.sequence_len:
-            return False
-        return (
-            (
-                input_ids[0, -self.sequence_len :]
-                == torch.tensor(self.sequence_ids, device=input_ids.device)
-            )
-            .all()
-            .item()
-        )
+        for i in range(input_ids.shape[0]):
+            if input_ids[i, -self.sequence_len:].tolist() == self.sequence_ids:
+                return True
+        return False
 
 
 @spaces.GPU(duration=54)
 def spaces_model_predict(message: str, history: list[tuple[str, str]]):
     history_transformer_format = history + [[message, ""]]
-    stop = StopOnSequence("<|human|>", tokenizer)
+    stop = StopOnSequence("<|user|>", tokenizer)
 
     messages = "".join(
         [
-            "".join(["\n<human>:" + item[0], "\n<ai>:" + item[1]])
+            f"<|user|>\n{item[0]}\n<|assistant|>\n{item[1]}"
             for item in history_transformer_format
         ]
     )
@@ -85,9 +79,8 @@ def spaces_model_predict(message: str, history: list[tuple[str, str]]):
 
     partial_message = ""
     for new_token in streamer:
-        if new_token != "<":
-            partial_message += new_token
-    return partial_message
+        partial_message += new_token
+        yield partial_message
 
 
 def predict(
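
For readers reproducing the change outside the Space, below is a minimal self-contained sketch of how the updated pieces fit together. Only the new __call__ body, the <|user|>/<|assistant|> prompt format, and the yield-based streaming loop come from this commit; the class __init__, the model/tokenizer loading, and the generate call are not part of the diff, so their shapes here are assumptions (a placeholder "gpt2" checkpoint stands in for the app's real model, and the @spaces.GPU decorator is omitted).

# Sketch only: placeholder checkpoint and parameters, not the Space's real setup.
from threading import Thread

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model


class StopOnSequence(StoppingCriteria):
    """Stop generation once any sequence in the batch ends with the stop text."""

    def __init__(self, sequence: str, tokenizer):
        # Assumed __init__: the diff only touches __call__.
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_len = len(self.sequence_ids)

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # New logic from the commit: compare the tail of every batch row.
        for i in range(input_ids.shape[0]):
            if input_ids[i, -self.sequence_len:].tolist() == self.sequence_ids:
                return True
        return False


def spaces_model_predict(message: str, history: list[tuple[str, str]]):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnSequence("<|user|>", tokenizer)

    # New prompt format from the commit: <|user|>/<|assistant|> turns.
    messages = "".join(
        f"<|user|>\n{item[0]}\n<|assistant|>\n{item[1]}"
        for item in history_transformer_format
    )

    model_inputs = tokenizer([messages], return_tensors="pt")
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Generate in a background thread so tokens can be consumed as they arrive.
    Thread(
        target=model.generate,
        kwargs=dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=128,
            stopping_criteria=StoppingCriteriaList([stop]),
        ),
    ).start()

    # Yielding (rather than returning) lets the Gradio UI stream the reply.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message

With a model that actually emits <|user|> turns, the stopping criterion halts generation as soon as the model starts a new user turn, while the generator streams each partial reply to the UI.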