pmolchanov committed on
Commit
f491420
·
verified ·
1 Parent(s): e77e044

Update app_chat.py

Browse files
Files changed (1) hide show
  1. app_chat.py +1 -14
app_chat.py CHANGED
@@ -30,20 +30,6 @@ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
30
  tokenizer.chat_template = "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}"
31
  #tokenizer.use_default_system_prompt = False
32
 
33
- class StoppingCriteriaSub(StoppingCriteria):
34
- def __init__(self, tokenizer, stops = [], encounters=1):
35
- super().__init__()
36
- self.stops = [stop.to("cuda") for stop in stops]
37
- self.tokenizer = tokenizer
38
- self.num_mamba_stop_ids = 8
39
-
40
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
41
- last_token = input_ids[0][-self.num_mamba_stop_ids:]
42
- for stop in self.stops:
43
- if self.tokenizer.decode(stop) in self.tokenizer.decode(last_token):
44
- return True
45
- return False
46
-
47
  @spaces.GPU
48
  def generate(
49
  message: str,
@@ -66,6 +52,7 @@ def generate(
66
 
67
  stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
68
 
 
69
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
70
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
71
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
30
  tokenizer.chat_template = "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}"
31
  #tokenizer.use_default_system_prompt = False
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  @spaces.GPU
34
  def generate(
35
  message: str,
 
52
 
53
  stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
54
 
55
+
56
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
57
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
58
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")