jeffreymeetkai committed on
Commit 25693df · verified
1 Parent(s): 9d78f40

Delete tokenization_functionary.py

Files changed (1)
  1. tokenization_functionary.py +0 -520
tokenization_functionary.py DELETED
@@ -1,520 +0,0 @@
- # Copyright (c) 2024, MeetKai Inc. All rights reserved.
-
- from copy import deepcopy
- import json
- from typing import Any, Dict, List, Literal, Optional, Union
-
- import jsonref
- from pydantic import BaseModel, Field, model_validator
- from typing_extensions import Self
-
- from transformers.tokenization_utils_base import BatchEncoding
- from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
- from transformers.utils import TensorType, logging
-
-
- logger = logging.get_logger(__name__)
- SYSTEM_PROMPT = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
- CODE_INTERPRETER_SYSTEM_PROMPT = """When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."""
-
- class Function(BaseModel):
-     name: str
-     description: Optional[str] = Field(default="")
-     parameters: Optional[dict] = None
-
-
- class Tool(BaseModel):
-     type: Literal["function", "code_interpreter"]
-     function: Optional[Function] = None
-
-     @model_validator(mode="after")
-     def check_type_function_matches(self) -> Self:
-         if self.type == "function":
-             assert self.function is not None, '"function" must contain function description when `"type": "function"`'
-         else:
-             assert self.function is None, '"function" must not be provided when `"type": "code_interpreter"`'
-         return self
-
-
- def convert_data_type(param_type: str) -> str:
-     """convert data_type to typescript data type
-
-     Args:
-         param_type (str): param_type
-
-     Returns:
-         str: param type in typescript
-     """
-     if param_type == "integer" or param_type == "float":
-         return "number"
-     return param_type
-
-
- def get_param_type(param: Dict) -> str:
-     """get param_type of parameter
-
-     Args:
-         param (Dict): param dict in properties
-
-     Returns:
-         str: _description_
-     """
-     param_type = "any"
-     if "type" in param:
-         raw_param_type = param["type"]
-         if type(raw_param_type) is list:
-             param_type = " | ".join(raw_param_type)
-         else:
-             param_type = raw_param_type
-
-     else: # in many cases, the json schema contains: oneOf instead of "type"
-         if "oneOf" in param:
-             one_of_types = []
-             for item in param["oneOf"]:
-                 if "type" in item:
-                     one_of_types.append(convert_data_type(item["type"]))
-             one_of_types = list(set(one_of_types))
-             param_type = " | ".join(one_of_types)
-     return convert_data_type(param_type)
-
-
- def get_format_param(param: Dict) -> Optional[str]:
-     """Get "format" from param. There are cases where format is not directly in param but in oneOf
-
-     Args:
-         param (Dict): _description_
-
-     Returns:
-         Optional[str]: _description_
-     """
-     if "format" in param:
-         return param["format"]
-     if "oneOf" in param:
-         formats = []
-         for item in param["oneOf"]:
-             if "format" in item:
-                 formats.append(item["format"])
-         if len(formats) > 0:
-             return " or ".join(formats)
-     return None
-
-
- def get_param_info(param: Dict) -> Optional[str]:
-     """get additional information about parameter such as: format, default value, min, max, ...
-
-     Args:
-         param (Dict): _description_
-
-     Returns:
-         Optional[str]: _description_
-     """
-     param_type = param.get("type", "any")
-     info_list = []
-     if "description" in param:
-         desc = param["description"]
-         if not desc.endswith("."):
-             desc += "."
-         info_list.append(desc)
-
-     if "default" in param:
-         default_value = param["default"]
-         if param_type == "string":
-             default_value = f'"{default_value}"' # if string --> add ""
-         info_list.append(f"Default={default_value}.")
-
-     format_param = get_format_param(param)
-     if format_param is not None:
-         info_list.append("Format=" + format_param)
-
-     for field, field_name in [
-         ("maximum", "Maximum"),
-         ("minimum", "Minimum"),
-         ("maxLength", "Maximum length"),
-         ("minLength", "Minimum length"),
-     ]:
-         if field in param:
-             info_list.append(f"{field_name}=" + str(param[field]))
-
-     if len(info_list) > 0:
-         result = "// " + " ".join(info_list)
-         result = result.replace("\n", " ")
-         return result
-     return None
-
-
- def append_new_param_info(
-     info_list: List[str],
-     param_declaration: str,
-     comment_info: Optional[str],
-     examples_info: List,
-     depth: int,
- ):
-     """Append a new parameter with comment to the info_list
-
-     Args:
-         info_lines (List[str]): current info_list
-         param_declaration (str): param: type
-         comment_info (Optional[str]): information of comment
-         examples_info (List): information of examples given
-         depth (int): level of nested param
-     """
-     offset = ""
-     if depth >= 1:
-         offset = "".join([" " for _ in range(depth)])
-     if comment_info is not None:
-         # if depth == 0: # format: //comment\nparam: type
-         info_list.append(f"{offset}{comment_info}")
-         if len(examples_info) > 0:
-             for example in examples_info:
-                 info_list.append(f"{offset}{example}")
-         info_list.append(f"{offset}{param_declaration}")
-         # else: # format: param: type // comment
-         # info_list.append(f"{offset}{param_declaration} {comment_info}")
-     else:
-         info_list.append(f"{offset}{param_declaration}")
-
-
- def get_examples_info(param_name: str, examples: List) -> List:
-     """get information about examples provided
-
-     Args:
-         param_name (str): _description_
-         examples (List): _description_
-
-     Returns:
-         List: _description_
-     """
-     examples_list = [f"// Example {param_name}:"]
-     for example in examples:
-         if isinstance(example, dict) or isinstance(example, list):
-             example_str = json.dumps(example, ensure_ascii=False).replace('\n', '\\n')
-         else:
-             example_str = str(example).replace('\n', '\\n')
-         examples_list.append(f"// {example_str}")
-
-     return examples_list
-
-
- def get_enum_option_str(enum_options: List) -> str:
-     """get enum option separated by: "|"
-
-     Args:
-         enum_options (List): list of options
-
-     Returns:
-         _type_: concatenation of options separated by "|"
-     """
-     # if each option is string --> add quote
-     return " | ".join([f'"{v}"' if type(v) is str else str(v) for v in enum_options])
-
-
- def get_array_typescript(
-     param_name: Optional[str], param_dic: dict, depth: int = 0
- ) -> str:
-     """recursive implementation for generating type script of array
-
-     Args:
-         param_name (Optional[str]): name of param, optional
-         param_dic (dict): param_dic
-         depth (int, optional): nested level. Defaults to 0.
-
-     Returns:
-         _type_: typescript of array
-     """
-     offset = ""
-     if depth >= 1:
-         offset = "".join([" " for _ in range(depth)])
-     items_info = param_dic.get("items", {})
-
-     if len(items_info) == 0:
-         if param_name is not None:
-             return f"{offset}{param_name}: []"
-         else:
-             return "[]"
-     array_type = get_param_type(items_info)
-     if array_type == "object":
-         info_lines = []
-         child_lines = get_parameter_typescript(
-             items_info.get("properties", {}), items_info.get("required", []), depth + 1
-         )
-         # if comment_info is not None:
-         # info_lines.append(f"{offset}{comment_info}")
-         if param_name is not None:
-             info_lines.append(f"{offset}{param_name}" + ": {")
-         else:
-             info_lines.append(f"{offset}" + "{")
-         info_lines.extend(child_lines)
-         info_lines.append(f"{offset}" + "}[]")
-         return "\n".join(info_lines)
-
-     elif array_type == "array":
-         item_info = get_array_typescript(None, items_info, depth + 1)
-         if param_name is None:
-             return f"{item_info}[]"
-         return f"{offset}{param_name}: {item_info.strip()}[]"
-
-     else:
-         if "enum" in items_info:
-             item_type = get_enum_option_str(items_info["enum"])
-             if param_name is None:
-                 return f"({item_type})[]"
-             else:
-                 return f"{offset}{param_name}: ({item_type})[]"
-         else:
-             if param_name is None:
-                 return f"{array_type}[]"
-             else:
-                 return f"{offset}{param_name}: {array_type}[],"
-
-
- def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
-     """Recursion, returning the information about parameters including data type, description and other information
-     These kinds of information will be put into the prompt
-
-     Args:
-         properties (_type_): properties in parameters
-         required_params (_type_): List of required parameters
-         depth (int, optional): the depth of params (nested level). Defaults to 0.
-
-     Returns:
-         _type_: list of lines containing information about all parameters
-     """
-     tp_lines = []
-     for param_name, param in properties.items():
-         # Sometimes properties have "required" field as a list of string.
-         # Even though its supposed to be not under properties. So we skip it
-         if not isinstance(param, dict):
-             continue
-         # Param Description
-         comment_info = get_param_info(param)
-         # Param Examples
-         examples_info = []
-         if "examples" in param:
-             examples_info = get_examples_info(param_name, param["examples"])
-         # Param Name declaration
-         param_declaration = f"{param_name}"
-         if isinstance(required_params, list):
-             if param_name not in required_params:
-                 param_declaration += "?"
-         param_type = get_param_type(param)
-
-         offset = ""
-         if depth >= 1:
-             offset = "".join([" " for _ in range(depth)])
-
-         if param_type == "object": # param_type is object
-             child_lines = get_parameter_typescript(
-                 param.get("properties", {}), param.get("required", []), depth + 1
-             )
-             if comment_info is not None:
-                 tp_lines.append(f"{offset}{comment_info}")
-             if len(examples_info) > 0:
-                 for example in examples_info:
-                     tp_lines.append(f"{offset}{example}")
-
-             param_declaration += ": {"
-             tp_lines.append(f"{offset}{param_declaration}")
-             tp_lines.extend(child_lines)
-             tp_lines.append(f"{offset}" + "},")
-
-         elif param_type == "array": # param_type is an array
-             item_info = param.get("items", {})
-             if "type" not in item_info: # don't know type of array
-                 param_declaration += ": [],"
-                 append_new_param_info(
-                     tp_lines, param_declaration, comment_info, examples_info, depth
-                 )
-             else:
-                 array_declaration = get_array_typescript(
-                     param_declaration, param, depth
-                 )
-                 if not array_declaration.endswith(","):
-                     array_declaration += ","
-                 if comment_info is not None:
-                     tp_lines.append(f"{offset}{comment_info}")
-                 if len(examples_info) > 0:
-                     for example in examples_info:
-                         tp_lines.append(f"{offset}{example}")
-                 tp_lines.append(array_declaration)
-         else:
-             if "enum" in param:
-                 param_type = get_enum_option_str(param["enum"])
-                 # param_type = " | ".join([f'"{v}"' for v in param["enum"]])
-             if "nullable" in param and param["nullable"] is True:
-                 param_type += " | null"
-             param_declaration += f": {param_type},"
-             append_new_param_info(
-                 tp_lines, param_declaration, comment_info, examples_info, depth
-             )
-
-     return tp_lines
-
- def generate_schema_from_functions(
-     functions: List[Function], namespace="functions"
- ) -> str:
-     """
-     Convert functions schema to a schema that language models can understand.
-     """
-
-     schema = "// Supported function definitions that should be called when necessary.\n"
-     schema += f"namespace {namespace} {{\n\n"
-
-     for function in functions:
-         # Convert a Function object to dict, if necessary
-         if not isinstance(function, dict):
-             function = function.model_dump()
-         function_name = function.get("name", None)
-         if function_name is None:
-             continue
-
-         description = function.get("description", "")
-         schema += f"// {description}\n"
-         schema += f"type {function_name}"
-
-         parameters = function.get("parameters", None)
-         if parameters is not None and parameters.get("properties") is not None:
-             parameters = deepcopy(jsonref.JsonRef.replace_refs(parameters))
-             schema += " = (_: {\n"
-             required_params = parameters.get("required", [])
-             tp_lines = get_parameter_typescript(
-                 parameters.get("properties"),
-                 required_params,
-                 0,
-             )
-             schema += "\n".join(tp_lines)
-             schema += "\n}) => any;\n\n"
-         else:
-             # Doesn't have any parameters
-             schema += " = () => any;\n\n"
-
-     schema += f"}} // namespace {namespace}"
-
-     return schema
-
- class FunctionaryTokenizer(PreTrainedTokenizerFast):
-     def apply_chat_template(
-         self,
-         conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], str],
-         tools: Optional[List[Dict[str, Any]]],
-         chat_template: Optional[str] = None,
-         add_generation_prompt: bool = False,
-         tokenize: bool = True,
-         padding: bool = False,
-         truncation: bool = False,
-         max_length: Optional[int] = None,
-         return_tensors: Optional[Union[str, TensorType]] = None,
-         return_dict: bool = False,
-         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
-         **kwargs,
-     ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
-
-         if return_dict and not tokenize:
-             raise ValueError(
-                 "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
-                 "of tokenizer outputs to return."
-             )
-
-         if tokenizer_kwargs is None:
-             tokenizer_kwargs = {}
-
-         using_default_template = False
-
-         # First, handle the cases when the model has a dict of multiple templates
-         if isinstance(self.chat_template, dict) or (
-             self.chat_template is None and isinstance(self.default_chat_template, dict)
-         ):
-             if self.chat_template is not None:
-                 template_dict = self.chat_template
-                 using_default_dict = False
-             else:
-                 template_dict = self.default_chat_template
-                 using_default_dict = True
-             if chat_template is not None and chat_template in template_dict:
-                 # The user can pass the name of a template to the chat template argument instead of an entire template
-                 chat_template = template_dict[chat_template]
-                 if using_default_dict:
-                     using_default_template = True
-             elif chat_template is None and "default" in template_dict:
-                 chat_template = template_dict["default"]
-                 if using_default_dict:
-                     using_default_template = True
-             elif chat_template is None:
-                 raise ValueError(
-                     "This model has multiple chat templates with no default specified! Please either pass a chat "
-                     "template or the name of the template you wish to use to the `chat_template` argument. Available "
-                     f"template names are {sorted(template_dict.keys())}."
-                 )
-         elif chat_template is None:
-             # These are the cases when the model has a single template
-             # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
-             if self.chat_template is not None:
-                 chat_template = self.chat_template
-             else:
-                 chat_template = self.default_chat_template
-                 using_default_template = True
-
-         if using_default_template:
-             logger.warning_once(
-                 "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
-                 "very error-prone, because models are often trained with templates different from the class default! "
-                 "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
-                 "point any code depending on them will stop working. We recommend setting a valid chat template before "
-                 "then to ensure that this model continues working without issues."
-             )
-
-         # Prepare tools/functions into schema
-         functions_pydantic_to_render = []
-         has_code_interpreter = False
-         for i in range(len(tools)):
-             tool_pydantic = Tool.model_validate(tools[i])
-             if tool_pydantic.type == "function":
-                 functions_pydantic_to_render.append(tool_pydantic.function)
-             else:
-                 has_code_interpreter = True
-         conversation.insert(0, {"role": "system", "content": generate_schema_from_functions(functions_pydantic_to_render)})
-         # Insert system prompt
-         system_prompt_to_use = SYSTEM_PROMPT if not has_code_interpreter else CODE_INTERPRETER_SYSTEM_PROMPT
-         conversation.insert(1, {"role": "system", "content": system_prompt_to_use})
-
-         # Compilation function uses a cache to avoid recompiling the same template
-         compiled_template = self._compile_jinja_template(chat_template)
-
-         if isinstance(conversation, (list, tuple)) and (
-             isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
-         ):
-             conversations = conversation
-             is_batched = True
-         else:
-             conversations = [conversation]
-             is_batched = False
-
-         rendered = []
-         template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
-         for chat in conversations:
-             if hasattr(chat, "messages"):
-                 # Indicates it's a Conversation object
-                 chat = chat.messages
-             rendered_chat = compiled_template.render(
-                 messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
-             )
-             rendered.append(rendered_chat)
-
-         if not is_batched:
-             rendered = rendered[0]
-
-         if tokenize:
-             out = self(
-                 rendered,
-                 padding=padding,
-                 truncation=truncation,
-                 max_length=max_length,
-                 add_special_tokens=False,
-                 return_tensors=return_tensors,
-                 **tokenizer_kwargs,
-             )
-             if return_dict:
-                 return out
-             else:
-                 return out["input_ids"]
-         else:
-             return rendered
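
For context, the core helper in the deleted file, generate_schema_from_functions, renders OpenAI-style function definitions into a TypeScript-like "namespace functions { ... }" block that apply_chat_template prepends to the conversation as a system message. A minimal usage sketch follows; it assumes the removed file is still available locally as tokenization_functionary.py, and the get_weather tool definition is purely hypothetical.

```python
# Illustrative sketch only; assumes the deleted module is available locally
# as tokenization_functionary.py. The tool definition below is hypothetical.
from tokenization_functionary import Function, generate_schema_from_functions

get_weather = Function(
    name="get_weather",
    description="Get the current weather for a city",
    parameters={
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "City name"},
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "default": "celsius"},
        },
        "required": ["city"],
    },
)

print(generate_schema_from_functions([get_weather]))
# Expected shape of the output:
# // Supported function definitions that should be called when necessary.
# namespace functions {
#
# // Get the current weather for a city
# type get_weather = (_: {
# // City name.
# city: string,
# // Default="celsius".
# unit?: "celsius" | "fahrenheit",
# }) => any;
#
# } // namespace functions
```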