LHC88 committed on
Commit
1d427be
·
1 Parent(s): dfe2122

added tokenization_functionary.py

Browse files
Files changed (1) hide show
  1. tokenization_functionary.py +524 -0
tokenization_functionary.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, MeetKai Inc. All rights reserved.
2
+
3
+ from copy import deepcopy
4
+ import json
5
+ from typing import Any, Dict, List, Literal, Optional, Union
6
+
7
+ import jsonref
8
+ from pydantic import BaseModel, Field, model_validator
9
+ from typing_extensions import Self
10
+
11
+ from transformers.tokenization_utils_base import BatchEncoding
12
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
13
+ from transformers.utils import TensorType, logging
14
+
15
+
16
+ logger = logging.get_logger(__name__)
17
+ SYSTEM_PROMPT = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
18
+ CODE_INTERPRETER_SYSTEM_PROMPT = """When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."""
19
+
20
+ class Function(BaseModel):
21
+ name: str
22
+ description: Optional[str] = Field(default="")
23
+ parameters: Optional[dict] = None
24
+
25
+
26
+ class Tool(BaseModel):
27
+ type: Literal["function", "code_interpreter"]
28
+ function: Optional[Function] = None
29
+
30
+ @model_validator(mode="after")
31
+ def check_type_function_matches(self) -> Self:
32
+ if self.type == "function":
33
+ assert self.function is not None, '"function" must contain function description when `"type": "function"`'
34
+ else:
35
+ assert self.function is None, '"function" must not be provided when `"type": "code_interpreter"`'
36
+ return self
37
+
38
+
39
+ def convert_data_type(param_type: str) -> str:
40
+ """convert data_type to typescript data type
41
+ Args:
42
+ param_type (str): param_type
43
+ Returns:
44
+ str: param type in typescript
45
+ """
46
+ if param_type == "integer" or param_type == "float":
47
+ return "number"
48
+ return param_type
49
+
50
+
51
+ def get_param_type(param: Dict) -> str:
52
+ """get param_type of parameter
53
+ Args:
54
+ param (Dict): param dict in properties
55
+ Returns:
56
+ str: _description_
57
+ """
58
+ param_type = "any"
59
+ if "type" in param:
60
+ raw_param_type = param["type"]
61
+ if type(raw_param_type) is list:
62
+ param_type = " | ".join(raw_param_type)
63
+ else:
64
+ param_type = raw_param_type
65
+
66
+ else: # in many cases, the json schema contains: oneOf instead of "type"
67
+ if "oneOf" in param:
68
+ one_of_types = []
69
+ for item in param["oneOf"]:
70
+ if "type" in item:
71
+ one_of_types.append(convert_data_type(item["type"]))
72
+ one_of_types = list(set(one_of_types))
73
+ param_type = " | ".join(one_of_types)
74
+ return convert_data_type(param_type)
75
+
76
+
77
+ def get_format_param(param: Dict) -> Optional[str]:
78
+ """Get "format" from param. There are cases where format is not directly in param but in oneOf
79
+ Args:
80
+ param (Dict): _description_
81
+ Returns:
82
+ Optional[str]: _description_
83
+ """
84
+ if "format" in param:
85
+ return param["format"]
86
+ if "oneOf" in param:
87
+ formats = []
88
+ for item in param["oneOf"]:
89
+ if "format" in item:
90
+ formats.append(item["format"])
91
+ if len(formats) > 0:
92
+ return " or ".join(formats)
93
+ return None
94
+
95
+
96
+ def get_param_info(param: Dict) -> Optional[str]:
97
+ """get additional information about parameter such as: format, default value, min, max, ...
98
+ Args:
99
+ param (Dict): _description_
100
+ Returns:
101
+ Optional[str]: _description_
102
+ """
103
+ param_type = param.get("type", "any")
104
+ info_list = []
105
+ if "description" in param:
106
+ desc = param["description"]
107
+ if not desc.endswith("."):
108
+ desc += "."
109
+ info_list.append(desc)
110
+
111
+ if "default" in param:
112
+ default_value = param["default"]
113
+ if param_type == "string":
114
+ default_value = f'"{default_value}"' # if string --> add ""
115
+ info_list.append(f"Default={default_value}.")
116
+
117
+ format_param = get_format_param(param)
118
+ if format_param is not None:
119
+ info_list.append("Format=" + format_param)
120
+
121
+ for field, field_name in [
122
+ ("maximum", "Maximum"),
123
+ ("minimum", "Minimum"),
124
+ ("maxLength", "Maximum length"),
125
+ ("minLength", "Minimum length"),
126
+ ]:
127
+ if field in param:
128
+ info_list.append(f"{field_name}=" + str(param[field]))
129
+
130
+ if len(info_list) > 0:
131
+ result = "// " + " ".join(info_list)
132
+ result = result.replace("\n", " ")
133
+ return result
134
+ return None
135
+
136
+
137
+ def append_new_param_info(
138
+ info_list: List[str],
139
+ param_declaration: str,
140
+ comment_info: Optional[str],
141
+ examples_info: List,
142
+ depth: int,
143
+ ):
144
+ """Append a new parameter with comment to the info_list
145
+ Args:
146
+ info_lines (List[str]): current info_list
147
+ param_declaration (str): param: type
148
+ comment_info (Optional[str]): information of comment
149
+ examples_info (List): information of examples given
150
+ depth (int): level of nested param
151
+ """
152
+ offset = ""
153
+ if depth >= 1:
154
+ offset = "".join([" " for _ in range(depth)])
155
+ if comment_info is not None:
156
+ # if depth == 0: # format: //comment\nparam: type
157
+ info_list.append(f"{offset}{comment_info}")
158
+ if len(examples_info) > 0:
159
+ for example in examples_info:
160
+ info_list.append(f"{offset}{example}")
161
+ info_list.append(f"{offset}{param_declaration}")
162
+ # else: # format: param: type // comment
163
+ # info_list.append(f"{offset}{param_declaration} {comment_info}")
164
+ else:
165
+ info_list.append(f"{offset}{param_declaration}")
166
+
167
+
168
+ def get_examples_info(param_name: str, examples: List) -> List:
169
+ """get information about examples provided
170
+ Args:
171
+ param_name (str): _description_
172
+ examples (List): _description_
173
+ Returns:
174
+ List: _description_
175
+ """
176
+ examples_list = [f"// Example {param_name}:"]
177
+ for example in examples:
178
+ if isinstance(example, dict) or isinstance(example, list):
179
+ example_str = json.dumps(example, ensure_ascii=False).replace('\n', '\\n')
180
+ else:
181
+ example_str = str(example).replace('\n', '\\n')
182
+ examples_list.append(f"// {example_str}")
183
+
184
+ return examples_list
185
+
186
+
187
+ def get_enum_option_str(enum_options: List) -> str:
188
+ """get enum option separated by: "|"
189
+ Args:
190
+ enum_options (List): list of options
191
+ Returns:
192
+ _type_: concatenation of options separated by "|"
193
+ """
194
+ # if each option is string --> add quote
195
+ return " | ".join([f'"{v}"' if type(v) is str else str(v) for v in enum_options])
196
+
197
+
198
+ def get_array_typescript(
199
+ param_name: Optional[str], param_dic: dict, depth: int = 0
200
+ ) -> str:
201
+ """recursive implementation for generating type script of array
202
+ Args:
203
+ param_name (Optional[str]): name of param, optional
204
+ param_dic (dict): param_dic
205
+ depth (int, optional): nested level. Defaults to 0.
206
+ Returns:
207
+ _type_: typescript of array
208
+ """
209
+ offset = ""
210
+ if depth >= 1:
211
+ offset = "".join([" " for _ in range(depth)])
212
+ items_info = param_dic.get("items", {})
213
+
214
+ if len(items_info) == 0:
215
+ if param_name is not None:
216
+ return f"{offset}{param_name}: []"
217
+ else:
218
+ return "[]"
219
+ array_type = get_param_type(items_info)
220
+ if array_type == "object":
221
+ info_lines = []
222
+ child_lines = get_parameter_typescript(
223
+ items_info.get("properties", {}), items_info.get("required", []), depth + 1
224
+ )
225
+ # if comment_info is not None:
226
+ # info_lines.append(f"{offset}{comment_info}")
227
+ if param_name is not None:
228
+ info_lines.append(f"{offset}{param_name}" + ": {")
229
+ else:
230
+ info_lines.append(f"{offset}" + "{")
231
+ info_lines.extend(child_lines)
232
+ info_lines.append(f"{offset}" + "}[]")
233
+ return "\n".join(info_lines)
234
+
235
+ elif array_type == "array":
236
+ item_info = get_array_typescript(None, items_info, depth + 1)
237
+ if param_name is None:
238
+ return f"{item_info}[]"
239
+ return f"{offset}{param_name}: {item_info.strip()}[]"
240
+
241
+ else:
242
+ if "enum" in items_info:
243
+ item_type = get_enum_option_str(items_info["enum"])
244
+ if param_name is None:
245
+ return f"({item_type})[]"
246
+ else:
247
+ return f"{offset}{param_name}: ({item_type})[]"
248
+ else:
249
+ if param_name is None:
250
+ return f"{array_type}[]"
251
+ else:
252
+ return f"{offset}{param_name}: {array_type}[],"
253
+
254
+
255
+ def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
256
+ """Recursion, returning the information about parameters including data type, description and other information
257
+ These kinds of information will be put into the prompt
258
+ Args:
259
+ properties (_type_): properties in parameters
260
+ required_params (_type_): List of required parameters
261
+ depth (int, optional): the depth of params (nested level). Defaults to 0.
262
+ Returns:
263
+ _type_: list of lines containing information about all parameters
264
+ """
265
+ tp_lines = []
266
+ for param_name, param in properties.items():
267
+ # Sometimes properties have "required" field as a list of string.
268
+ # Even though its supposed to be not under properties. So we skip it
269
+ if not isinstance(param, dict):
270
+ continue
271
+ # Param Description
272
+ comment_info = get_param_info(param)
273
+ # Param Examples
274
+ examples_info = []
275
+ if "examples" in param:
276
+ examples_info = get_examples_info(param_name, param["examples"])
277
+ # Param Name declaration
278
+ param_declaration = f"{param_name}"
279
+ if isinstance(required_params, list):
280
+ if param_name not in required_params:
281
+ param_declaration += "?"
282
+ param_type = get_param_type(param)
283
+
284
+ offset = ""
285
+ if depth >= 1:
286
+ offset = "".join([" " for _ in range(depth)])
287
+
288
+ if param_type == "object": # param_type is object
289
+ child_lines = get_parameter_typescript(
290
+ param.get("properties", {}), param.get("required", []), depth + 1
291
+ )
292
+ if comment_info is not None:
293
+ tp_lines.append(f"{offset}{comment_info}")
294
+ if len(examples_info) > 0:
295
+ for example in examples_info:
296
+ tp_lines.append(f"{offset}{example}")
297
+
298
+ param_declaration += ": {"
299
+ tp_lines.append(f"{offset}{param_declaration}")
300
+ tp_lines.extend(child_lines)
301
+ tp_lines.append(f"{offset}" + "},")
302
+
303
+ elif param_type == "array": # param_type is an array
304
+ item_info = param.get("items", {})
305
+ if "type" not in item_info: # don't know type of array
306
+ param_declaration += ": [],"
307
+ append_new_param_info(
308
+ tp_lines, param_declaration, comment_info, examples_info, depth
309
+ )
310
+ else:
311
+ array_declaration = get_array_typescript(
312
+ param_declaration, param, depth
313
+ )
314
+ if not array_declaration.endswith(","):
315
+ array_declaration += ","
316
+ if comment_info is not None:
317
+ tp_lines.append(f"{offset}{comment_info}")
318
+ if len(examples_info) > 0:
319
+ for example in examples_info:
320
+ tp_lines.append(f"{offset}{example}")
321
+ tp_lines.append(array_declaration)
322
+ else:
323
+ if "enum" in param:
324
+ param_type = get_enum_option_str(param["enum"])
325
+ # param_type = " | ".join([f'"{v}"' for v in param["enum"]])
326
+ if "nullable" in param and param["nullable"] is True:
327
+ param_type += " | null"
328
+ param_declaration += f": {param_type},"
329
+ append_new_param_info(
330
+ tp_lines, param_declaration, comment_info, examples_info, depth
331
+ )
332
+
333
+ return tp_lines
334
+
335
+ def generate_schema_from_functions(
336
+ functions: List[Function], namespace="functions"
337
+ ) -> str:
338
+ """
339
+ Convert functions schema to a schema that language models can understand.
340
+ """
341
+
342
+ schema = "// Supported function definitions that should be called when necessary.\n"
343
+ schema += f"namespace {namespace} {{\n\n"
344
+
345
+ for function in functions:
346
+ # Convert a Function object to dict, if necessary
347
+ if not isinstance(function, dict):
348
+ function = function.model_dump()
349
+ function_name = function.get("name", None)
350
+ if function_name is None:
351
+ continue
352
+
353
+ description = function.get("description", "")
354
+ schema += f"// {description}\n"
355
+ schema += f"type {function_name}"
356
+
357
+ parameters = function.get("parameters", None)
358
+ if parameters is not None and parameters.get("properties") is not None:
359
+ parameters = deepcopy(jsonref.JsonRef.replace_refs(parameters))
360
+ schema += " = (_: {\n"
361
+ required_params = parameters.get("required", [])
362
+ tp_lines = get_parameter_typescript(
363
+ parameters.get("properties"),
364
+ required_params,
365
+ 0,
366
+ )
367
+ schema += "\n".join(tp_lines)
368
+ schema += "\n}) => any;\n\n"
369
+ else:
370
+ # Doesn't have any parameters
371
+ schema += " = () => any;\n\n"
372
+
373
+ schema += f"}} // namespace {namespace}"
374
+
375
+ return schema
376
+
377
+ class FunctionaryTokenizer(PreTrainedTokenizerFast):
378
+ def apply_chat_template(
379
+ self,
380
+ conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], str],
381
+ tools: Optional[List[Dict[str, Any]]],
382
+ chat_template: Optional[str] = None,
383
+ add_generation_prompt: bool = False,
384
+ tokenize: bool = True,
385
+ padding: bool = False,
386
+ truncation: bool = False,
387
+ max_length: Optional[int] = None,
388
+ return_tensors: Optional[Union[str, TensorType]] = None,
389
+ return_dict: bool = False,
390
+ tokenizer_kwargs: Optional[Dict[str, Any]] = None,
391
+ **kwargs,
392
+ ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
393
+
394
+ if return_dict and not tokenize:
395
+ raise ValueError(
396
+ "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
397
+ "of tokenizer outputs to return."
398
+ )
399
+
400
+ if tokenizer_kwargs is None:
401
+ tokenizer_kwargs = {}
402
+
403
+ using_default_template = False
404
+
405
+ # First, handle the cases when the model has a dict of multiple templates
406
+ if isinstance(self.chat_template, dict) or (
407
+ self.chat_template is None and isinstance(self.default_chat_template, dict)
408
+ ):
409
+ if self.chat_template is not None:
410
+ template_dict = self.chat_template
411
+ using_default_dict = False
412
+ else:
413
+ template_dict = self.default_chat_template
414
+ using_default_dict = True
415
+ if chat_template is not None and chat_template in template_dict:
416
+ # The user can pass the name of a template to the chat template argument instead of an entire template
417
+ chat_template = template_dict[chat_template]
418
+ if using_default_dict:
419
+ using_default_template = True
420
+ elif chat_template is None and "default" in template_dict:
421
+ chat_template = template_dict["default"]
422
+ if using_default_dict:
423
+ using_default_template = True
424
+ elif chat_template is None:
425
+ raise ValueError(
426
+ "This model has multiple chat templates with no default specified! Please either pass a chat "
427
+ "template or the name of the template you wish to use to the `chat_template` argument. Available "
428
+ f"template names are {sorted(template_dict.keys())}."
429
+ )
430
+ elif chat_template is None:
431
+ # These are the cases when the model has a single template
432
+ # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
433
+ if self.chat_template is not None:
434
+ chat_template = self.chat_template
435
+ else:
436
+ chat_template = self.default_chat_template
437
+ using_default_template = True
438
+
439
+ if using_default_template:
440
+ logger.warning_once(
441
+ "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
442
+ "very error-prone, because models are often trained with templates different from the class default! "
443
+ "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
444
+ "point any code depending on them will stop working. We recommend setting a valid chat template before "
445
+ "then to ensure that this model continues working without issues."
446
+ )
447
+
448
+ PYTHON_RUN_SYS_MSG = "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."
449
+ SYSTEM_CONTENT = """You are capable of executing available function(s) if required.
450
+ Only execute function(s) when absolutely necessary.
451
+ Ask for the required input to:recipient==all
452
+ Use JSON for function arguments.
453
+ Respond in this format:
454
+ >>>${recipient}
455
+ ${content}
456
+ Available functions:
457
+ """
458
+
459
+ # Prepare tools/functions into schema
460
+ functions_pydantic_to_render = []
461
+ has_code_interpreter = False
462
+ if tools is not None:
463
+ for item in tools:
464
+ if (
465
+ "function" in item and item["function"] is not None
466
+ ): # new data format: tools: [{"type": xx, "function": xxx}]
467
+ functions_pydantic_to_render.append(item["function"])
468
+ elif "type" in item and item["type"] == "code_interpreter":
469
+ has_code_interpreter = True
470
+ else:
471
+ functions_pydantic_to_render.append(item) # old format
472
+
473
+ conversation.insert(
474
+ 0,
475
+ {
476
+ "role": "system",
477
+ "content": SYSTEM_CONTENT + generate_schema_from_functions(functions_pydantic_to_render),
478
+ },
479
+ )
480
+ if has_code_interpreter:
481
+ conversation.insert(1, {"role": "system", "content": PYTHON_RUN_SYS_MSG})
482
+
483
+ # Compilation function uses a cache to avoid recompiling the same template
484
+ compiled_template = self._compile_jinja_template(chat_template)
485
+
486
+ if isinstance(conversation, (list, tuple)) and (
487
+ isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
488
+ ):
489
+ conversations = conversation
490
+ is_batched = True
491
+ else:
492
+ conversations = [conversation]
493
+ is_batched = False
494
+
495
+ rendered = []
496
+ template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
497
+ for chat in conversations:
498
+ if hasattr(chat, "messages"):
499
+ # Indicates it's a Conversation object
500
+ chat = chat.messages
501
+ rendered_chat = compiled_template.render(
502
+ messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
503
+ )
504
+ rendered.append(rendered_chat)
505
+
506
+ if not is_batched:
507
+ rendered = rendered[0]
508
+
509
+ if tokenize:
510
+ out = self(
511
+ rendered,
512
+ padding=padding,
513
+ truncation=truncation,
514
+ max_length=max_length,
515
+ add_special_tokens=False,
516
+ return_tensors=return_tensors,
517
+ **tokenizer_kwargs,
518
+ )
519
+ if return_dict:
520
+ return out
521
+ else:
522
+ return out["input_ids"]
523
+ else:
524
+ return rendered