bradnow committed
Commit 0d67078 · 1 Parent(s): a220efd

Add selector for model and do some layout

Files changed (3)
  1. README.md +1 -1
  2. app.py +143 -103
  3. utils.py +81 -0
README.md CHANGED
@@ -11,4 +11,4 @@ license: mit
 short_description: ServiceNow-AI model chat
 ---
 
-An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+A chatbot for ServiceNow-AI model chat. This is a demo of the Apriel Nemotron Chat model. The chatbot can answer questions, provide information, etc.
app.py CHANGED
@@ -1,99 +1,66 @@
-import os
-import sys
 import datetime
 
 from openai import OpenAI
 import gradio as gr
-from gradio.components.chatbot import ChatMessage, Message
-from typing import (
-    Any,
-    Literal,
-)
 
-DEBUG_LOG = False or os.environ.get("DEBUG_LOG") == "True"
+from utils import COMMUNITY_POSTFIX_URL, get_model_config, log_message, check_format, models_config
 
 print(f"Gradio version: {gr.__version__}")
 
-title = None  # "ServiceNow-AI Chat" # modelConfig.get('MODE_DISPLAY_NAME')
-description = "Please use the community section on this space to provide feedback! <a href=\"https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker/discussions\">ServiceNow-AI/Apriel-Nemotron-Chat</a>"
+DEFAULT_MODEL_NAME = "Apriel-Nemotron-15b-Thinker"
 
 chat_start_count = 0
+model_config = None
+client = None
 
-model_config = {
-    "MODEL_NAME": os.environ.get("MODEL_NAME"),
-    "MODE_DISPLAY_NAME": os.environ.get("MODE_DISPLAY_NAME"),
-    "MODEL_HF_URL": os.environ.get("MODEL_HF_URL"),
-    "VLLM_API_URL": os.environ.get("VLLM_API_URL"),
-    "AUTH_TOKEN": os.environ.get("AUTH_TOKEN")
-}
-
-# Initialize the OpenAI client with the vLLM API URL and token
-client = OpenAI(
-    api_key=model_config.get('AUTH_TOKEN'),
-    base_url=model_config.get('VLLM_API_URL')
-)
-
-
-def log_message(message):
-    if DEBUG_LOG is True:
-        print(message)
-
-
-# Gradio 5.0.1 had issues with checking the message formats. 5.29.0 does not!
-def _check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
-    if type == "messages":
-        all_valid = all(
-            isinstance(message, dict)
-            and "role" in message
-            and "content" in message
-            or isinstance(message, ChatMessage | Message)
-            for message in messages
-        )
-        if not all_valid:
-            # Display which message is not valid
-            for i, message in enumerate(messages):
-                if not (isinstance(message, dict) and
-                        "role" in message and
-                        "content" in message) and not isinstance(message, ChatMessage | Message):
-                    print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
-                    break
-
-            raise Exception(
-                "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
-            )
-        # else:
-        #     print("_check_format() --> All messages are valid.")
-    elif not all(
-        isinstance(message, (tuple, list)) and len(message) == 2
-        for message in messages
-    ):
-        raise Exception(
-            "Data incompatible with tuples format. Each message should be a list of length 2."
-        )
+
+def setup_model(model_name, intial=False):
+    global model_config, client
+    model_config = get_model_config(model_name)
+    log_message(f"update_model() --> Model config: {model_config}")
+    client = OpenAI(
+        api_key=model_config.get('AUTH_TOKEN'),
+        base_url=model_config.get('VLLM_API_URL')
+    )
+
+    _model_hf_name = model_config.get("MODEL_HF_URL").split('https://huggingface.co/')[1]
+    _link = f"<a href='{model_config.get('MODEL_HF_URL')}{COMMUNITY_POSTFIX_URL}' target='_blank'>{_model_hf_name}</a>"
+    _description = f"Please use the community section on this space to provide feedback! {_link}"
+
+    print(f"Switched to model {_model_hf_name}")
+
+    if intial:
+        return
+    else:
+        return _description
 
 
 def chat_fn(message, history):
-    log_message(f"{'-' * 80}\nchat_fn() --> Message: {message}")
+    log_message(f"{'-' * 80}")
+    log_message(f"chat_fn() --> Message: {message}")
+    log_message(f"chat_fn() --> History: {history}")
 
     global chat_start_count
     chat_start_count = chat_start_count + 1
     print(
         f"{datetime.datetime.now()}: chat_start_count: {chat_start_count}, turns: {int(len(history if history else []) / 3)}")
 
+    is_reasoning = model_config.get("REASONING")
+
     # Remove any assistant messages with metadata from history for multiple turns
     log_message(f"Original History: {history}")
-    _check_format(history, "messages")
+    check_format(history, "messages")
     history = [item for item in history if
                not (isinstance(item, dict) and
                     item.get("role") == "assistant" and
                     isinstance(item.get("metadata"), dict) and
                     item.get("metadata", {}).get("title") is not None)]
     log_message(f"Updated History: {history}")
-    _check_format(history, "messages")
+    check_format(history, "messages")
 
     history.append({"role": "user", "content": message})
     log_message(f"History with user message: {history}")
-    _check_format(history, "messages")
+    check_format(history, "messages")
 
     # Create the streaming response
     stream = client.chat.completions.create(
@@ -103,13 +70,14 @@ def chat_fn(message, history):
         stream=True
     )
 
-    history.append(gr.ChatMessage(
-        role="assistant",
-        content="Thinking...",
-        metadata={"title": "🧠 Thought"}
-    ))
-    log_message(f"History added thinking: {history}")
-    _check_format(history, "messages")
+    if is_reasoning:
+        history.append(gr.ChatMessage(
+            role="assistant",
+            content="Thinking...",
+            metadata={"title": "🧠 Thought"}
+        ))
+        log_message(f"History added thinking: {history}")
+        check_format(history, "messages")
 
     output = ""
     completion_started = False
@@ -118,49 +86,121 @@
         content = getattr(chunk.choices[0].delta, "content", "")
         output += content
 
-        parts = output.split("[BEGIN FINAL RESPONSE]")
+        if is_reasoning:
+            parts = output.split("[BEGIN FINAL RESPONSE]")
 
-        if len(parts) > 1:
-            if parts[1].endswith("[END FINAL RESPONSE]"):
-                parts[1] = parts[1].replace("[END FINAL RESPONSE]", "")
-            if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"):
-                parts[1] = parts[1].replace("[END FINAL RESPONSE]\n<|end|>", "")
+            if len(parts) > 1:
+                if parts[1].endswith("[END FINAL RESPONSE]"):
+                    parts[1] = parts[1].replace("[END FINAL RESPONSE]", "")
+                if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"):
+                    parts[1] = parts[1].replace("[END FINAL RESPONSE]\n<|end|>", "")
+                if parts[1].endswith("<|end|>"):
+                    parts[1] = parts[1].replace("<|end|>", "")
 
-        history[-1 if not completion_started else -2] = gr.ChatMessage(
-            role="assistant",
-            content=parts[0],
-            metadata={"title": "🧠 Thought"}
-        )
-        if completion_started:
-            history[-1] = gr.ChatMessage(
+            history[-1 if not completion_started else -2] = gr.ChatMessage(
                 role="assistant",
-                content=parts[1]
+                content=parts[0],
+                metadata={"title": "🧠 Thought"}
             )
-        elif len(parts) > 1 and not completion_started:
-            completion_started = True
-            history.append(gr.ChatMessage(
+            if completion_started:
+                history[-1] = gr.ChatMessage(
+                    role="assistant",
+                    content=parts[1]
+                )
+            elif len(parts) > 1 and not completion_started:
+                completion_started = True
+                history.append(gr.ChatMessage(
+                    role="assistant",
+                    content=parts[1]
+                ))
+        else:
+            if output.endswith("<|end|>"):
+                output = output.replace("<|end|>", "")
+            history[-1] = gr.ChatMessage(
                 role="assistant",
-                content=parts[1]
-            ))
+                content=output
+            )
 
         # only yield the most recent assistant messages
         messages_to_yield = history[-1:] if not completion_started else history[-2:]
-        # _check_format(messages_to_yield, "messages")
+        # check_format(messages_to_yield, "messages")
+        # log_message(f"Yielding messages: {messages_to_yield}")
        yield messages_to_yield
 
     log_message(f"Final History: {history}")
-    _check_format(history, "messages")
+    check_format(history, "messages")
+
+
+title = None
+description = None
+
+with gr.Blocks(theme=gr.themes.Default(primary_hue="green")) as demo:
+    gr.HTML("""
+    <style>
+        .model-message {
+            text-align: end;
+        }
+
+        .model-dropdown-container {
+            display: flex;
+            align-items: center;
+            gap: 10px;
+            padding: 0;
+        }
+
+        @media (max-width: 800px) {
+            .responsive-row {
+                flex-direction: column;
+            }
+            .model-dropdown-container {
+                flex-direction: column;
+                align-items: flex-start;
+            }
+        }
+    """)
+
+    with gr.Row(variant="panel", elem_classes="responsive-row"):
+        with gr.Column(scale=1, min_width=400, elem_classes="model-dropdown-container"):
+            model_dropdown = gr.Dropdown(
+                choices=[f"Model: {model}" for model in models_config.keys()],
+                value=f"Model: {DEFAULT_MODEL_NAME}",
+                label=None,
+                interactive=True,
+                container=False,
+                scale=0,
+                min_width=400
+            )
+        with gr.Column(scale=4, min_width=0):
+            description_html = gr.HTML(description, elem_classes="model-message")
 
+    chat_bot = gr.Chatbot(
+        type="messages",
+        height="calc(100vh - 320px)",
+    )
 
-# Add the model display name and Hugging Face URL to the description
-# description = f"### Model: [{MODE_DISPLAY_NAME}]({MODEL_HF_URL})"
+    chat_interface = gr.ChatInterface(
+        chat_fn,
+        description="",
+        type="messages",
+        chatbot=chat_bot
+    )
 
-print(f"Running model {model_config.get('MODE_DISPLAY_NAME')} ({model_config.get('MODEL_NAME')})")
+    # Add this line to ensure the model is reset to default on page reload
+    demo.load(lambda: setup_model(DEFAULT_MODEL_NAME, intial=False), [], [description_html])
+
+
+    def update_model_and_clear(model_name):
+        # Remove the "Model: " prefix to get the actual model name
+        actual_model_name = model_name.replace("Model: ", "")
+        desc = setup_model(actual_model_name)
+        chat_bot.clear()  # Critical line
+        return desc
+
+
+    model_dropdown.change(
+        fn=update_model_and_clear,
+        inputs=[model_dropdown],
+        outputs=[description_html]
+    )
 
-gr.ChatInterface(
-    chat_fn,
-    title=title,
-    description=description,
-    theme=gr.themes.Default(primary_hue="green"),
-    type="messages",
-).launch()
+demo.launch()
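For reasoning models, the new chat_fn accumulates the streamed text and splits it on the "[BEGIN FINAL RESPONSE]" marker so the thought and the final answer render as separate chat messages, stripping trailing "[END FINAL RESPONSE]" and "<|end|>" markers. The helper below is not part of this commit; it is a minimal standalone sketch of that same splitting logic (the function name is invented here), which can be exercised without Gradio or a vLLM endpoint:

def split_reasoning_output(output: str):
    """Return (thought, final_response) for a partially streamed reply.

    final_response is None until "[BEGIN FINAL RESPONSE]" has appeared.
    Trailing end markers are stripped, mirroring the cleanup in chat_fn.
    """
    parts = output.split("[BEGIN FINAL RESPONSE]")
    if len(parts) == 1:
        return parts[0], None

    final = parts[1]
    for marker in ("[END FINAL RESPONSE]\n<|end|>", "[END FINAL RESPONSE]", "<|end|>"):
        if final.endswith(marker):
            final = final[: -len(marker)]
    return parts[0], final


# During streaming: thought only, then thought plus the final answer.
print(split_reasoning_output("Adding the numbers..."))
# -> ('Adding the numbers...', None)
print(split_reasoning_output("Adding the numbers...[BEGIN FINAL RESPONSE]4[END FINAL RESPONSE]"))
# -> ('Adding the numbers...', '4')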
 
 
 
 
 
 
utils.py ADDED
@@ -0,0 +1,81 @@
+import os
+import sys
+from typing import Any, Literal
+
+from gradio import ChatMessage
+from gradio.components.chatbot import Message
+
+COMMUNITY_POSTFIX_URL = "/discussions"
+DEBUG_MODE = False or os.environ.get("DEBUG_MODE") == "True"
+
+models_config = {
+    "Apriel-Nemotron-15b-Thinker": {
+        "MODEL_DISPLAY_NAME": "Apriel-Nemotron-15b-Thinker",
+        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker",
+        "MODEL_NAME": os.environ.get("MODEL_NAME_NEMO_15B"),
+        "VLLM_API_URL": os.environ.get("VLLM_API_URL_NEMO_15B"),
+        "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
+        "REASONING": True
+    },
+    "Apriel-5b": {
+        "MODEL_DISPLAY_NAME": "Apriel-5b",
+        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct",
+        "MODEL_NAME": os.environ.get("MODEL_NAME_5B"),
+        "VLLM_API_URL": os.environ.get("VLLM_API_URL_5B"),
+        "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
+        "REASONING": False
+    }
+}
+
+
+def get_model_config(model_name: str) -> dict:
+    config = models_config.get(model_name)
+    if not config:
+        raise ValueError(f"Model {model_name} not found in models_config")
+    if not config.get("MODEL_NAME"):
+        raise ValueError(f"Model name not found in config for {model_name}")
+    if not config.get("VLLM_API_URL"):
+        raise ValueError(f"VLLM API URL not found in config for {model_name}")
+
+    return config
+
+
+def log_message(message):
+    if DEBUG_MODE is True:
+        print(f"≫≫≫ {message}")
+
+
+# Gradio 5.0.1 had issues with checking the message formats. 5.29.0 does not!
+def check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
+    if not DEBUG_MODE:
+        return
+
+    if type == "messages":
+        all_valid = all(
+            isinstance(message, dict)
+            and "role" in message
+            and "content" in message
+            or isinstance(message, ChatMessage | Message)
+            for message in messages
+        )
+        if not all_valid:
+            # Display which message is not valid
+            for i, message in enumerate(messages):
+                if not (isinstance(message, dict) and
+                        "role" in message and
+                        "content" in message) and not isinstance(message, ChatMessage | Message):
+                    print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
+                    break
+
+            raise Exception(
+                "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
+            )
+        # else:
+        #     print("_check_format() --> All messages are valid.")
+    elif not all(
+        isinstance(message, (tuple, list)) and len(message) == 2
+        for message in messages
+    ):
+        raise Exception(
+            "Data incompatible with tuples format. Each message should be a list of length 2."
+        )