Add selector for model and do some layout
README.md CHANGED

@@ -11,4 +11,4 @@ license: mit
 short_description: ServiceNow-AI model chat
 ---
 
-
+A chatbot for ServiceNow-AI model chat. This is a demo of the Apriel Nemotron Chat model. The chatbot can answer questions, provide information, etc.
app.py CHANGED

@@ -1,99 +1,66 @@
-import os
-import sys
 import datetime
 
 from openai import OpenAI
 import gradio as gr
-from gradio.components.chatbot import ChatMessage, Message
-from typing import (
-    Any,
-    Literal,
-)
+from utils import COMMUNITY_POSTFIX_URL, get_model_config, log_message, check_format, models_config
 
 print(f"Gradio version: {gr.__version__}")
 
-description = "Please use the community section on this space to provide feedback! <a href=\"https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker/discussions\">ServiceNow-AI/Apriel-Nemotron-Chat</a>"
+DEFAULT_MODEL_NAME = "Apriel-Nemotron-15b-Thinker"
 
 chat_start_count = 0
+model_config = None
+client = None
+
+
+def setup_model(model_name, intial=False):
+    global model_config, client
+    model_config = get_model_config(model_name)
+    log_message(f"update_model() --> Model config: {model_config}")
+    client = OpenAI(
+        api_key=model_config.get('AUTH_TOKEN'),
+        base_url=model_config.get('VLLM_API_URL')
+    )
+
+    _model_hf_name = model_config.get("MODEL_HF_URL").split('https://huggingface.co/')[1]
+    _link = f"<a href='{model_config.get('MODEL_HF_URL')}{COMMUNITY_POSTFIX_URL}' target='_blank'>{_model_hf_name}</a>"
+    _description = f"Please use the community section on this space to provide feedback! {_link}"
+
+    print(f"Switched to model {_model_hf_name}")
+
+    if intial:
+        return
+    else:
+        return _description
 
-# Gradio 5.0.1 had issues with checking the message formats. 5.29.0 does not!
-def _check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
-    if type == "messages":
-        all_valid = all(
-            isinstance(message, dict)
-            and "role" in message
-            and "content" in message
-            or isinstance(message, ChatMessage | Message)
-            for message in messages
-        )
-        if not all_valid:
-            # Display which message is not valid
-            for i, message in enumerate(messages):
-                if not (isinstance(message, dict) and
-                        "role" in message and
-                        "content" in message) and not isinstance(message, ChatMessage | Message):
-                    print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
-                    break
-
-            raise Exception(
-                "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
-            )
-        # else:
-        #     print("_check_format() --> All messages are valid.")
-    elif not all(
-        isinstance(message, (tuple, list)) and len(message) == 2
-        for message in messages
-    ):
-        raise Exception(
-            "Data incompatible with tuples format. Each message should be a list of length 2."
-        )
 
 
 def chat_fn(message, history):
-    log_message(f"{'-' * 80}")
+    log_message(f"{'-' * 80}")
+    log_message(f"chat_fn() --> Message: {message}")
+    log_message(f"chat_fn() --> History: {history}")
 
     global chat_start_count
     chat_start_count = chat_start_count + 1
     print(
         f"{datetime.datetime.now()}: chat_start_count: {chat_start_count}, turns: {int(len(history if history else []) / 3)}")
 
+    is_reasoning = model_config.get("REASONING")
+
     # Remove any assistant messages with metadata from history for multiple turns
     log_message(f"Original History: {history}")
+    check_format(history, "messages")
     history = [item for item in history if
               not (isinstance(item, dict) and
                    item.get("role") == "assistant" and
                    isinstance(item.get("metadata"), dict) and
                    item.get("metadata", {}).get("title") is not None)]
     log_message(f"Updated History: {history}")
+    check_format(history, "messages")
 
     history.append({"role": "user", "content": message})
    log_message(f"History with user message: {history}")
+    check_format(history, "messages")
 
     # Create the streaming response
     stream = client.chat.completions.create(
@@ -103,13 +70,14 @@ def chat_fn(message, history):
         stream=True
     )
 
+    if is_reasoning:
+        history.append(gr.ChatMessage(
+            role="assistant",
+            content="Thinking...",
+            metadata={"title": "🧠 Thought"}
+        ))
+        log_message(f"History added thinking: {history}")
+        check_format(history, "messages")
 
     output = ""
     completion_started = False
@@ -118,49 +86,121 @@ def chat_fn(message, history):
         content = getattr(chunk.choices[0].delta, "content", "")
         output += content
 
+        if is_reasoning:
+            parts = output.split("[BEGIN FINAL RESPONSE]")
 
+            if len(parts) > 1:
+                if parts[1].endswith("[END FINAL RESPONSE]"):
+                    parts[1] = parts[1].replace("[END FINAL RESPONSE]", "")
+                if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"):
+                    parts[1] = parts[1].replace("[END FINAL RESPONSE]\n<|end|>", "")
+                if parts[1].endswith("<|end|>"):
+                    parts[1] = parts[1].replace("<|end|>", "")
 
-            role="assistant",
-            content=parts[0],
-            metadata={"title": "🧠 Thought"}
-        )
-        if completion_started:
-            history[-1] = gr.ChatMessage(
+            history[-1 if not completion_started else -2] = gr.ChatMessage(
                 role="assistant",
+                content=parts[0],
+                metadata={"title": "🧠 Thought"}
             )
+            if completion_started:
+                history[-1] = gr.ChatMessage(
+                    role="assistant",
+                    content=parts[1]
+                )
+            elif len(parts) > 1 and not completion_started:
+                completion_started = True
+                history.append(gr.ChatMessage(
+                    role="assistant",
+                    content=parts[1]
+                ))
+        else:
+            if output.endswith("<|end|>"):
+                output = output.replace("<|end|>", "")
+            history[-1] = gr.ChatMessage(
                role="assistant",
+                content=output
+            )
 
     # only yield the most recent assistant messages
     messages_to_yield = history[-1:] if not completion_started else history[-2:]
+    # check_format(messages_to_yield, "messages")
+    # log_message(f"Yielding messages: {messages_to_yield}")
     yield messages_to_yield
 
     log_message(f"Final History: {history}")
+    check_format(history, "messages")
+
+
+title = None
+description = None
+
+with gr.Blocks(theme=gr.themes.Default(primary_hue="green")) as demo:
+    gr.HTML("""
+    <style>
+        .model-message {
+            text-align: end;
+        }
+
+        .model-dropdown-container {
+            display: flex;
+            align-items: center;
+            gap: 10px;
+            padding: 0;
+        }
+
+        @media (max-width: 800px) {
+            .responsive-row {
+                flex-direction: column;
+            }
+            .model-dropdown-container {
+                flex-direction: column;
+                align-items: flex-start;
+            }
+        }
+    """)
+
+    with gr.Row(variant="panel", elem_classes="responsive-row"):
+        with gr.Column(scale=1, min_width=400, elem_classes="model-dropdown-container"):
+            model_dropdown = gr.Dropdown(
+                choices=[f"Model: {model}" for model in models_config.keys()],
+                value=f"Model: {DEFAULT_MODEL_NAME}",
+                label=None,
+                interactive=True,
+                container=False,
+                scale=0,
+                min_width=400
+            )
+        with gr.Column(scale=4, min_width=0):
+            description_html = gr.HTML(description, elem_classes="model-message")
+
+    chat_bot = gr.Chatbot(
+        type="messages",
+        height="calc(100vh - 320px)",
+    )
+
+    chat_interface = gr.ChatInterface(
+        chat_fn,
+        description="",
+        type="messages",
+        chatbot=chat_bot
+    )
+
+    # Add this line to ensure the model is reset to default on page reload
+    demo.load(lambda: setup_model(DEFAULT_MODEL_NAME, intial=False), [], [description_html])
+
+
+    def update_model_and_clear(model_name):
+        # Remove the "Model: " prefix to get the actual model name
+        actual_model_name = model_name.replace("Model: ", "")
+        desc = setup_model(actual_model_name)
+        chat_bot.clear()  # Critical line
+        return desc
+
+
+    model_dropdown.change(
+        fn=update_model_and_clear,
+        inputs=[model_dropdown],
+        outputs=[description_html]
+    )
 
-gr.ChatInterface(
-    chat_fn,
-    title=title,
-    description=description,
-    theme=gr.themes.Default(primary_hue="green"),
-    type="messages",
-).launch()
+demo.launch()
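The reasoning model streams its chain of thought first and brackets the final answer in [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE] markers, which chat_fn splits into a collapsible "🧠 Thought" bubble and a normal reply. Below is a minimal sketch of that marker handling, runnable against canned chunks instead of a live endpoint; the sample text is invented for illustration.

# Canned-stream sketch of chat_fn's marker handling; sample chunks are invented.
REASONING_BEGIN = "[BEGIN FINAL RESPONSE]"
REASONING_END = "[END FINAL RESPONSE]"


def split_stream(chunks):
    # Accumulate the stream and split thought text from the final answer,
    # stripping the end markers the same way chat_fn does.
    output = ""
    for content in chunks:
        output += content
        parts = output.split(REASONING_BEGIN)
        thought = parts[0]
        answer = parts[1] if len(parts) > 1 else None
        if answer is not None:
            for marker in (REASONING_END + "\n<|end|>", REASONING_END, "<|end|>"):
                if answer.endswith(marker):
                    answer = answer.replace(marker, "")
        yield thought, answer


chunks = ["Let me think", " about this.", "[BEGIN FINAL RESPONSE]", "Hello!", "[END FINAL RESPONSE]"]
for thought, answer in split_stream(chunks):
    print(repr(thought), "|", repr(answer))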
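setup_model wires up a standard OpenAI-compatible client against the model's vLLM endpoint, so the deployment can also be exercised outside Gradio. The following is a minimal sketch under assumptions: the environment variable names come from utils.py below, the endpoint must be reachable, and the create() call is assumed to take the usual model parameter (the exact arguments are elided from this diff).

# Hypothetical direct call to the same vLLM endpoint the Space uses.
import os

from openai import OpenAI

client = OpenAI(
    api_key=os.environ.get("AUTH_TOKEN"),
    base_url=os.environ.get("VLLM_API_URL_NEMO_15B"),
)

stream = client.chat.completions.create(
    model=os.environ.get("MODEL_NAME_NEMO_15B"),  # served model name, deployment-specific
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in stream:
    # Same delta access pattern as chat_fn; content may be None on some chunks.
    print(getattr(chunk.choices[0].delta, "content", "") or "", end="")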
utils.py ADDED

@@ -0,0 +1,81 @@
+import os
+import sys
+from typing import Any, Literal
+
+from gradio import ChatMessage
+from gradio.components.chatbot import Message
+
+COMMUNITY_POSTFIX_URL = "/discussions"
+DEBUG_MODE = False or os.environ.get("DEBUG_MODE") == "True"
+
+models_config = {
+    "Apriel-Nemotron-15b-Thinker": {
+        "MODEL_DISPLAY_NAME": "Apriel-Nemotron-15b-Thinker",
+        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker",
+        "MODEL_NAME": os.environ.get("MODEL_NAME_NEMO_15B"),
+        "VLLM_API_URL": os.environ.get("VLLM_API_URL_NEMO_15B"),
+        "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
+        "REASONING": True
+    },
+    "Apriel-5b": {
+        "MODEL_DISPLAY_NAME": "Apriel-5b",
+        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct",
+        "MODEL_NAME": os.environ.get("MODEL_NAME_5B"),
+        "VLLM_API_URL": os.environ.get("VLLM_API_URL_5B"),
+        "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
+        "REASONING": False
+    }
+}
+
+
+def get_model_config(model_name: str) -> dict:
+    config = models_config.get(model_name)
+    if not config:
+        raise ValueError(f"Model {model_name} not found in models_config")
+    if not config.get("MODEL_NAME"):
+        raise ValueError(f"Model name not found in config for {model_name}")
+    if not config.get("VLLM_API_URL"):
+        raise ValueError(f"VLLM API URL not found in config for {model_name}")
+
+    return config
+
+
+def log_message(message):
+    if DEBUG_MODE is True:
+        print(f"≫≫≫ {message}")
+
+
+# Gradio 5.0.1 had issues with checking the message formats. 5.29.0 does not!
+def check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
+    if not DEBUG_MODE:
+        return
+
+    if type == "messages":
+        all_valid = all(
+            isinstance(message, dict)
+            and "role" in message
+            and "content" in message
+            or isinstance(message, ChatMessage | Message)
+            for message in messages
+        )
+        if not all_valid:
+            # Display which message is not valid
+            for i, message in enumerate(messages):
+                if not (isinstance(message, dict) and
+                        "role" in message and
+                        "content" in message) and not isinstance(message, ChatMessage | Message):
+                    print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
+                    break
+
+            raise Exception(
+                "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
+            )
+        # else:
+        #     print("_check_format() --> All messages are valid.")
+    elif not all(
+        isinstance(message, (tuple, list)) and len(message) == 2
+        for message in messages
+    ):
+        raise Exception(
+            "Data incompatible with tuples format. Each message should be a list of length 2."
+        )
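A quick usage sketch for these helpers, assuming the Space's environment variables are set and DEBUG_MODE=True so that log_message and check_format are active:

from utils import check_format, get_model_config, log_message

config = get_model_config("Apriel-Nemotron-15b-Thinker")
print(config["MODEL_DISPLAY_NAME"], "reasoning:", config["REASONING"])

history = [{"role": "user", "content": "Hi"}]
check_format(history, "messages")  # raises if a message is malformed
log_message("history validated")   # printed only when DEBUG_MODE is on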