Spaces:
Paused
Paused
Update app_chat.py
Browse files- app_chat.py +1 -14
app_chat.py
CHANGED
@@ -30,20 +30,6 @@ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
|
30 |
tokenizer.chat_template = "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}"
|
31 |
#tokenizer.use_default_system_prompt = False
|
32 |
|
33 |
-
class StoppingCriteriaSub(StoppingCriteria):
|
34 |
-
def __init__(self, tokenizer, stops = [], encounters=1):
|
35 |
-
super().__init__()
|
36 |
-
self.stops = [stop.to("cuda") for stop in stops]
|
37 |
-
self.tokenizer = tokenizer
|
38 |
-
self.num_mamba_stop_ids = 8
|
39 |
-
|
40 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
41 |
-
last_token = input_ids[0][-self.num_mamba_stop_ids:]
|
42 |
-
for stop in self.stops:
|
43 |
-
if self.tokenizer.decode(stop) in self.tokenizer.decode(last_token):
|
44 |
-
return True
|
45 |
-
return False
|
46 |
-
|
47 |
@spaces.GPU
|
48 |
def generate(
|
49 |
message: str,
|
@@ -66,6 +52,7 @@ def generate(
|
|
66 |
|
67 |
stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
|
68 |
|
|
|
69 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
70 |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
71 |
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
|
|
30 |
tokenizer.chat_template = "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}"
|
31 |
#tokenizer.use_default_system_prompt = False
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
@spaces.GPU
|
34 |
def generate(
|
35 |
message: str,
|
|
|
52 |
|
53 |
stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
|
54 |
|
55 |
+
|
56 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
57 |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
58 |
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|