jeffreymeetkai commited on
Commit
30e6f2c
·
verified ·
1 Parent(s): b1d8444

Create modeling_functionary.py

Browse files
Files changed (1) hide show
  1. modeling_functionary.py +110 -0
modeling_functionary.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ coding=utf-8
2
+ # Copyright (c) 2024, MeetKai Inc. All rights reserved.
3
+ """PyTorch LLaMA model."""
4
+
5
+ import json
6
+ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.utils.checkpoint
10
+
11
+ from transformers.generation.configuration_utils import GenerationConfig
12
+ from transformers.generation.logits_process import LogitsProcessorList
13
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
14
+ from transformers.generation.utils import (
15
+ GenerateBeamDecoderOnlyOutput,
16
+ GenerateBeamEncoderDecoderOutput,
17
+ GenerateDecoderOnlyOutput,
18
+ GenerateEncoderDecoderOutput
19
+ )
20
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
21
+ from transformers.utils import logging
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.generation.streamers import BaseStreamer
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+ GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
31
+ GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]
32
+ GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]
33
+
34
+
35
+ class FunctionaryForCausalLM(LlamaForCausalLM):
36
+
37
+ def generate_tool_use(
38
+ self,
39
+ inputs: Optional[torch.Tensor] = None,
40
+ generation_config: Optional[GenerationConfig] = None,
41
+ logits_processor: Optional[LogitsProcessorList] = None,
42
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
43
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
44
+ synced_gpus: Optional[bool] = None,
45
+ assistant_model: Optional["PreTrainedModel"] = None,
46
+ streamer: Optional["BaseStreamer"] = None,
47
+ negative_prompt_ids: Optional[torch.Tensor] = None,
48
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
49
+ **kwargs,
50
+ ) -> Union[GenerateOutput, torch.LongTensor]:
51
+
52
+ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we use it to parse raw output
53
+
54
+ results = self.generate(
55
+ inputs=inputs,
56
+ generation_config=generation_config,
57
+ logits_processor=logits_processor,
58
+ stopping_criteria=stopping_criteria,
59
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
60
+ synced_gpus=synced_gpus,
61
+ assistant_model=assistant_model,
62
+ streamer=streamer,
63
+ negative_prompt_ids=negative_prompt_ids,
64
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
65
+ **kwargs,
66
+ )
67
+
68
+ input_ids = kwargs.pop("input_ids")
69
+ function_call_token = ">>>"
70
+
71
+ correct_results = []
72
+ for input_id, result in zip(input_ids, results):
73
+ final_output_json = {"role": "assistant", "content": None, "tool_calls": None}
74
+ tool_calls = []
75
+ raw_output_str = tokenizer.decode(result[len(input_id):].cpu())
76
+ chunks = raw_output_str.split(function_call_token)
77
+ for i, chunk in enumerate(chunks):
78
+ if len(chunk) == 0:
79
+ continue
80
+
81
+ chunk = chunk.replace(tokenizer.pad_token, "")
82
+ has_text = True if chunk.startswith("all") else False
83
+ if i == 0 and has_text is not False:
84
+ final_output_json["content"] = chunk.strip[:-len("<|eot_id|>")] if chunk.endswith("<|eot_id|>") else chunk
85
+ final_output_json["content"] = final_output_json["content"][len("all\n"):]
86
+ else:
87
+ tool_calls.append(
88
+ {
89
+ "name": chunk[: chunk.index("\n{")],
90
+ "arguments": chunk[chunk.index("\n{") + 1: -len("<|eot_id|>")] if chunk.endswith("<|eot_id|>") else chunk[chunk.index("\n{") + 1:]
91
+ }
92
+ )
93
+ if len(tool_calls) > 0:
94
+ final_output_json["tool_calls"] = tool_calls
95
+ final_output_str = json.dumps(final_output_json, indent=4)
96
+ final_output_ids = tokenizer(final_output_str, add_special_tokens=False)["input_ids"]
97
+ correct_results.append(
98
+ torch.cat(
99
+ (result[:len(input_id)].cpu(), torch.tensor(final_output_ids))
100
+ )
101
+ )
102
+ max_len = max([tensor.shape[0] for tensor in correct_results])
103
+ correct_results = [
104
+ torch.nn.functional.pad(
105
+ correct_result, (0, max_len - correct_result.shape[0]), value=tokenizer.eos_token_id
106
+ ) for correct_result in correct_results
107
+ ]
108
+ correct_results = torch.stack(correct_results)
109
+
110
+ return correct_results