Spaces:
Paused
Paused
Update app_chat.py
Browse files
app_chat.py: +3 −5
app_chat.py
CHANGED
@@ -19,16 +19,14 @@ DEFAULT_MAX_NEW_TOKENS = 512
|
|
19 |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
|
20 |
|
21 |
DESCRIPTION = """\
|
22 |
-
# Hymba-1.5B chat
|
23 |
-
|
24 |
"""
|
25 |
|
26 |
model_id = "nvidia/Hymba-1.5B-Instruct"
|
27 |
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True)
|
28 |
model = model.cuda().to(torch.bfloat16)
|
29 |
model.compile()
|
30 |
-
#model.to('cuda')
|
31 |
-
#model.eval()
|
32 |
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
33 |
tokenizer.chat_template = "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}"
|
34 |
#tokenizer.use_default_system_prompt = False
|
@@ -73,7 +71,7 @@ def generate(
|
|
73 |
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
74 |
input_ids = input_ids.to(model.device)
|
75 |
|
76 |
-
streamer = TextIteratorStreamer(tokenizer, timeout=
|
77 |
generate_kwargs = dict(
|
78 |
{"input_ids": input_ids},
|
79 |
streamer=streamer,
|
|
|
19 |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
|
20 |
|
21 |
DESCRIPTION = """\
|
22 |
+
# Hymba-1.5B-Instruct chat
|
23 |
+
Feel free to chat with our model! More details: [Paper](https://arxiv.org/abs/2411.13676), [Model card](https://huggingface.co/nvidia/Hymba-1.5B-Instruct), [GitHub](https://github.com/NVlabs/hymba).
|
24 |
"""
|
25 |
|
26 |
model_id = "nvidia/Hymba-1.5B-Instruct"
|
27 |
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True)
|
28 |
model = model.cuda().to(torch.bfloat16)
|
29 |
model.compile()
|
|
|
|
|
30 |
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
31 |
tokenizer.chat_template = "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}"
|
32 |
#tokenizer.use_default_system_prompt = False
|
|
|
71 |
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
72 |
input_ids = input_ids.to(model.device)
|
73 |
|
74 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=1.0, skip_prompt=True, skip_special_tokens=False)
|
75 |
generate_kwargs = dict(
|
76 |
{"input_ids": input_ids},
|
77 |
streamer=streamer,
|