abhishekchohan committed on
Commit eaa81b0 · verified · 1 Parent(s): e64ccba

Upload 2 files

Files changed (2)
  1. templates/gemma3.jinja +139 -0
  2. tools/gemma_tool_parser.py +291 -0
templates/gemma3.jinja ADDED
@@ -0,0 +1,139 @@
+ {{- bos_token }}
+ {%- if custom_tools is defined %}
+ {%- set tools = custom_tools %}
+ {%- endif %}
+ {%- if not tools_in_user_message is defined %}
+ {%- set tools_in_user_message = false %}
+ {%- endif %}
+ {%- if not date_string is defined %}
+ {%- if strftime_now is defined %}
+ {%- set date_string = strftime_now("%d %b %Y") %}
+ {%- else %}
+ {%- set date_string = "15 Mar 2025" %}
+ {%- endif %}
+ {%- endif %}
+ {%- if not tools is defined %}
+ {%- set tools = none %}
+ {%- endif %}
+
+ {#- Find out if there are any images #}
+ {% set image_ns = namespace(has_images=false) %}
+ {%- for message in messages %}
+ {%- if message['content'] is not string %}
+ {%- for content in message['content'] %}
+ {%- if content['type'] == 'image' %}
+ {%- set image_ns.has_images = true %}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {%- endfor %}
+
+ {#- This block extracts the system message, so we can slot it into the right place. #}
+ {%- if messages[0]['role'] == 'system' %}
+ {%- if messages[0]['content'] is string %}
+ {%- set system_message = messages[0]['content']|trim %}
+ {%- else %}
+ {%- set system_message = messages[0]['content'][0]['text']|trim %}
+ {%- endif %}
+ {%- set messages = messages[1:] %}
+ {%- else %}
+ {%- if tools is not none %}
+ {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %}
+ {%- else %}
+ {%- set system_message = "" %}
+ {%- endif %}
+ {%- endif %}
+
+ {#- System message if there are no images, if the user supplied one, or if tools are used (default tool system message) #}
+ {%- if system_message or not image_ns.has_images %}
+ {{- "<start_of_turn>system\n" }}
+ {%- if tools is not none %}
+ {{- "Environment: ipython\n" }}
+ {%- endif %}
+ {{- "Cutting Knowledge Date: December 2023\n" }}
+ {{- "Today Date: " + date_string + "\n\n" }}
+ {%- if tools is not none and not tools_in_user_message %}
+ {{- "You have access to the following functions. To call a function, please respond with JSON for a function call. " }}
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }}
+ {{- "Do not use variables.\n\n" }}
+ {%- for t in tools %}
+ {{- t | tojson(indent=4) }}
+ {{- "\n\n" }}
+ {%- endfor %}
+ {%- endif %}
+ {{- system_message }}
+ {{- "<end_of_turn>\n" }}
+ {%- endif %}
+
+ {#- Custom tools are passed in a user message with some extra guidance #}
+ {%- if tools_in_user_message and not tools is none %}
+ {#- Extract the first user message so we can plug it in here #}
+ {%- if messages | length != 0 %}
+ {%- if messages[0]['content'] is string %}
+ {%- set first_user_message = messages[0]['content']|trim %}
+ {%- else %}
+ {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %}
+ {%- endif %}
+ {%- set messages = messages[1:] %}
+ {%- else %}
+ {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
+ {%- endif %}
+ {{- '<start_of_turn>user\n' -}}
+ {{- "Given the following functions, please respond with a JSON for a function call " }}
+ {{- "with its proper arguments that best answers the given prompt.\n\n" }}
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }}
+ {{- "Do not use variables.\n\n" }}
+ {%- for t in tools %}
+ {{- t | tojson(indent=4) }}
+ {{- "\n\n" }}
+ {%- endfor %}
+ {{- first_user_message + "<end_of_turn>\n"}}
+ {%- endif %}
+
+ {%- for message in messages %}
+ {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
+ {%- if message.role == 'assistant' %}
+ {{- '<start_of_turn>model\n' }}
+ {%- else %}
+ {{- '<start_of_turn>' + message['role'] + '\n' }}
+ {%- endif %}
+ {%- if message['content'] is string %}
+ {{- message['content'] | trim}}
+ {%- else %}
+ {%- for content in message['content'] %}
+ {%- if content['type'] == 'image' %}
+ {{- '<start_of_image>' }}
+ {%- elif content['type'] == 'text' %}
+ {{- content['text'] | trim }}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<end_of_turn>\n' }}
+ {%- elif 'tool_calls' in message %}
+ {%- if not message.tool_calls|length == 1 %}
+ {{- raise_exception("This model only supports single tool-calls at once!") }}
+ {%- endif %}
+ {%- set tool_call = message.tool_calls[0].function %}
+ {{- '<start_of_turn>model\n' -}}
+ {{- '{"name": "' + tool_call.name + '", ' }}
+ {{- '"parameters": ' }}
+ {{- tool_call.arguments | tojson }}
+ {{- "}" }}
+ {{- "<end_of_turn>\n" }}
+ {%- elif message.role == "tool" or message.role == "ipython" %}
+ {{- "<start_of_turn>ipython\n" }}
+ {%- if message.content is string %}
+ {{- { "output": message.content } | tojson }}
+ {%- else %}
+ {%- for content in message['content'] %}
+ {%- if content['type'] == 'text' %}
+ {{- { "output": content['text'] } | tojson }}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {{- "<end_of_turn>\n" }}
+ {%- endif %}
+ {%- endfor %}
+ {%- if add_generation_prompt %}
+ {{- '<start_of_turn>model\n' }}
+ {%- endif %}
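
For reference, an assistant tool-call turn, the tool result fed back as an ipython turn, and the generation prompt render as follows under this template (the get_weather call and its output are illustrative values, not defined anywhere in this repo):

<start_of_turn>model
{"name": "get_weather", "parameters": {"city": "Paris"}}<end_of_turn>
<start_of_turn>ipython
{"output": "sunny, 18 degrees"}<end_of_turn>
<start_of_turn>model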
tools/gemma_tool_parser.py ADDED
@@ -0,0 +1,291 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ import json
+ import re
+ from collections.abc import Sequence
+ from json import JSONDecoder
+ from typing import Union
+
+ import partial_json_parser
+ from partial_json_parser.core.options import Allow
+ from transformers import PreTrainedTokenizerBase
+
+ from vllm.entrypoints.openai.protocol import (
+     ChatCompletionRequest,
+     DeltaFunctionCall,
+     DeltaMessage,
+     DeltaToolCall,
+     ExtractedToolCallInformation,
+     FunctionCall,
+     ToolCall,
+ )
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+     ToolParser,
+     ToolParserManager,
+ )
+ from vllm.entrypoints.openai.tool_parsers.utils import (
+     find_common_prefix,
+     is_complete_json,
+     partial_json_loads,
+ )
+ from vllm.logger import init_logger
+ from vllm.utils import random_uuid
+
+ logger = init_logger(__name__)
+
+
+ @ToolParserManager.register_module("gemma_json")
+ class GemmaJsonToolParser(ToolParser):
+     """
+     Tool call parser for Gemma 3 models intended for use with the
+     appropriate Gemma chat template.
+
+     Used when --enable-auto-tool-choice --tool-call-parser gemma_json
+     are all set
+     """
+
+     def __init__(self, tokenizer: PreTrainedTokenizerBase):
+         super().__init__(tokenizer)
+
+         # initialize properties used for state when parsing tool calls in
+         # streaming mode
+         self.prev_tool_call_arr: list[dict] = []
+         self.current_tool_id: int = -1
+         self.current_tool_name_sent: bool = False
+         self.streamed_args_for_tool: list[str] = []
+
+         # Gemma specific tokens
+         self.bos_token = "<bos>"
+         self.model_token = "<start_of_turn>model"
+         self.user_token = "<start_of_turn>user"
+         self.end_turn_token = "<end_of_turn>"
+
+         # For JSON detection
+         self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
+
+     def extract_tool_calls(
+         self, model_output: str, request: ChatCompletionRequest
+     ) -> ExtractedToolCallInformation:
+         """
+         Extract the tool calls from a complete model response.
+         """
+         # case -- if the response doesn't contain JSON, return a text response
+         if not model_output.startswith("{"):
+             return ExtractedToolCallInformation(
+                 tools_called=False, tool_calls=[], content=model_output
+             )
+
+         try:
+             # load the JSON, and then use it to build the Function and
+             # Tool Call
+             dec = JSONDecoder()
+             function_call_arr = []
+
+             start_idx = 0
+             while start_idx < len(model_output):
+                 try:
+                     (obj, end_idx) = dec.raw_decode(model_output[start_idx:])
+                     start_idx += end_idx
+                     # Skip any separators like semicolons or commas
+                     while start_idx < len(model_output) and model_output[start_idx] in [
+                         ";",
+                         ",",
+                         " ",
+                     ]:
+                         start_idx += 1
+                     function_call_arr.append(obj)
+                 except json.JSONDecodeError:
+                     break
+
+             tool_calls: list[ToolCall] = [
+                 ToolCall(
+                     type="function",
+                     function=FunctionCall(
+                         name=raw_function_call["name"],
+                         # function call args are JSON but as a string
+                         arguments=json.dumps(
+                             raw_function_call["arguments"]
+                             if "arguments" in raw_function_call
+                             else raw_function_call["parameters"]
+                         ),
+                     ),
+                 )
+                 for raw_function_call in function_call_arr
+             ]
+
+             return ExtractedToolCallInformation(
+                 tools_called=True, tool_calls=tool_calls, content=None
+             )
+
+         except Exception:
+             logger.exception("Error in extracting tool call from response.")
+             # return information to just treat the tool call as regular JSON
+             return ExtractedToolCallInformation(
+                 tools_called=False, tool_calls=[], content=model_output
+             )
+
+     def extract_tool_calls_streaming(
+         self,
+         previous_text: str,
+         current_text: str,
+         delta_text: str,
+         previous_token_ids: Sequence[int],
+         current_token_ids: Sequence[int],
+         delta_token_ids: Sequence[int],
+         request: ChatCompletionRequest,
+     ) -> Union[DeltaMessage, None]:
+
+         # Skip if not JSON format
+         if not current_text.startswith("{"):
+             return DeltaMessage(content=delta_text)
+
+         # bit mask flags for partial JSON parsing
+         flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
+         try:
+             tool_call_arr = []
+             is_complete = []
+             try:
+                 start_idx = 0
+                 while start_idx < len(current_text):
+                     (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
+                     is_complete.append(
+                         is_complete_json(current_text[start_idx : start_idx + end_idx])
+                     )
+                     start_idx += end_idx
+                     # Skip any separators like semicolons or commas
+                     while start_idx < len(current_text) and current_text[start_idx] in [
+                         ";",
+                         ",",
+                         " ",
+                     ]:
+                         start_idx += 1
+
+                     # Handle parameters field as arguments if needed
+                     if "parameters" in obj:
+                         assert (
+                             "arguments" not in obj
+                         ), "model generated both parameters and arguments"
+                         obj["arguments"] = obj["parameters"]
+                     tool_call_arr.append(obj)
+             except partial_json_parser.core.exceptions.MalformedJSON:
+                 logger.debug("not enough tokens to parse into JSON yet")
+                 return None
+
+             # select as the current tool call the one we're on the state at
+             current_tool_call: dict = (
+                 tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
+             )
+
+             # case -- if no tokens have been streamed for the tool, e.g.
+             # only the array brackets, stream nothing
+             if len(tool_call_arr) == 0:
+                 return None
+
+             # case: we are starting a new tool in the array
+             # -> array has > 0 length AND length has moved past cursor
+             elif (
+                 len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
+             ):
+                 # if we're moving on to a new call, first make sure we
+                 # haven't missed anything in the previous one that was
+                 # auto-generated due to JSON completions, but wasn't
+                 # streamed to the client yet.
+                 if self.current_tool_id >= 0:
+                     cur_arguments = current_tool_call.get("arguments")
+                     if cur_arguments:
+                         cur_args_json = json.dumps(cur_arguments)
+                         sent = len(self.streamed_args_for_tool[self.current_tool_id])
+                         argument_diff = cur_args_json[sent:]
+
+                         logger.debug("got arguments diff: %s", argument_diff)
+                         delta = DeltaMessage(
+                             tool_calls=[
+                                 DeltaToolCall(
+                                     index=self.current_tool_id,
+                                     function=DeltaFunctionCall(
+                                         arguments=argument_diff
+                                     ).model_dump(exclude_none=True),
+                                 )
+                             ]
+                         )
+                         self.streamed_args_for_tool[
+                             self.current_tool_id
+                         ] += argument_diff
+                     else:
+                         delta = None
+                 else:
+                     delta = None
+                 # re-set stuff pertaining to progress in the current tool
+                 self.current_tool_id = len(tool_call_arr) - 1
+                 self.current_tool_name_sent = False
+                 self.streamed_args_for_tool.append("")
+                 logger.debug("starting on new tool %d", self.current_tool_id)
+                 return delta
+
+             # if the current tool name hasn't been sent, send if available
+             # - otherwise send nothing
+             elif not self.current_tool_name_sent:
+                 function_name = current_tool_call.get("name")
+                 if function_name:
+                     delta = DeltaMessage(
+                         tool_calls=[
+                             DeltaToolCall(
+                                 index=self.current_tool_id,
+                                 type="function",
+                                 id=f"chatcmpl-tool-{random_uuid()}",
+                                 function=DeltaFunctionCall(
+                                     name=function_name
+                                 ).model_dump(exclude_none=True),
+                             )
+                         ]
+                     )
+                     self.current_tool_name_sent = True
+                 else:
+                     delta = None
+
+             # now we know we're on the same tool call and we're streaming
+             # arguments
+             else:
+                 cur_arguments = current_tool_call.get("arguments")
+                 delta = None
+
+                 if cur_arguments:
+                     sent = len(self.streamed_args_for_tool[self.current_tool_id])
+                     cur_args_json = json.dumps(cur_arguments)
+                     prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
+                         "arguments"
+                     )
+
+                     argument_diff = None
+                     if is_complete[self.current_tool_id]:
+                         argument_diff = cur_args_json[sent:]
+                     elif prev_arguments:
+                         prev_args_json = json.dumps(prev_arguments)
+                         if cur_args_json != prev_args_json:
+                             prefix = find_common_prefix(prev_args_json, cur_args_json)
+                             argument_diff = prefix[sent:]
+
+                     if argument_diff is not None:
+                         delta = DeltaMessage(
+                             tool_calls=[
+                                 DeltaToolCall(
+                                     index=self.current_tool_id,
+                                     function=DeltaFunctionCall(
+                                         arguments=argument_diff
+                                     ).model_dump(exclude_none=True),
+                                 )
+                             ]
+                         )
+                         self.streamed_args_for_tool[
+                             self.current_tool_id
+                         ] += argument_diff
+
+             self.prev_tool_call_arr = tool_call_arr
+             return delta
+
+         except Exception:
+             logger.exception("Error trying to handle streaming tool call.")
+             logger.debug(
+                 "Skipping chunk as a result of tool streaming extraction error"
+             )
+             return None
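
A minimal usage sketch: the docstring above names --enable-auto-tool-choice and --tool-call-parser gemma_json; the other flags, the model name, the port, and the get_weather tool below are assumptions for illustration and may need adjusting to your vLLM version and setup.

# Assumed launch command (only the last two flags come from the docstring above):
#   vllm serve google/gemma-3-27b-it \
#       --chat-template templates/gemma3.jinja \
#       --tool-parser-plugin tools/gemma_tool_parser.py \
#       --enable-auto-tool-choice --tool-call-parser gemma_json
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# "get_weather" is an illustrative tool definition, not part of this repo.
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="google/gemma-3-27b-it",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)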