Isaachh committed
Commit 71e86f7
1 Parent(s): 4e2a680

temporarily switch

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +1 -1
  2. app.py +130 -73
  3. {bunny/serve/examples → assets}/example_1.png +0 -0
  4. {bunny/serve/examples → assets}/example_2.png +0 -0
  5. {bunny/serve/examples → assets}/icon.jpg +0 -0
  6. {bunny/serve/examples → assets}/user.png +0 -0
  7. bunny/constants.py +0 -7
  8. bunny/conversation.py +0 -239
  9. bunny/eval/m4c_evaluator.py +0 -334
  10. bunny/eval/model_vqa.py +0 -111
  11. bunny/eval/model_vqa_cmmmu.py +0 -234
  12. bunny/eval/model_vqa_loader.py +0 -143
  13. bunny/eval/model_vqa_mmbench.py +0 -167
  14. bunny/eval/model_vqa_mmmu.py +0 -326
  15. bunny/eval/model_vqa_science.py +0 -119
  16. bunny/model/__init__.py +0 -6
  17. bunny/model/builder.py +0 -197
  18. bunny/model/bunny_arch.py +0 -230
  19. bunny/model/language_model/bunny_llama.py +0 -102
  20. bunny/model/language_model/bunny_minicpm.py +0 -103
  21. bunny/model/language_model/bunny_phi.py +0 -100
  22. bunny/model/language_model/bunny_phi3.py +0 -100
  23. bunny/model/language_model/bunny_qwen.py +0 -100
  24. bunny/model/language_model/bunny_stablelm.py +0 -100
  25. bunny/model/language_model/llama/__init__.py +0 -114
  26. bunny/model/language_model/llama/configuration_llama.py +0 -191
  27. bunny/model/language_model/llama/modeling_llama.py +0 -1844
  28. bunny/model/language_model/llama/tokenization_llama.py +0 -471
  29. bunny/model/language_model/llama/tokenization_llama_fast.py +0 -281
  30. bunny/model/language_model/minicpm/configuration_minicpm.py +0 -202
  31. bunny/model/language_model/minicpm/modeling_minicpm.py +0 -1456
  32. bunny/model/language_model/phi/__init__.py +0 -69
  33. bunny/model/language_model/phi/configuration_phi.py +0 -195
  34. bunny/model/language_model/phi/modeling_phi.py +0 -1374
  35. bunny/model/language_model/phi3/__init__.py +0 -69
  36. bunny/model/language_model/phi3/configuration_phi3.py +0 -213
  37. bunny/model/language_model/phi3/modeling_phi3.py +0 -1597
  38. bunny/model/language_model/qwen2/__init__.py +0 -80
  39. bunny/model/language_model/qwen2/configuration_qwen2.py +0 -144
  40. bunny/model/language_model/qwen2/modeling_qwen2.py +0 -1403
  41. bunny/model/language_model/qwen2/tokenization_qwen2.py +0 -345
  42. bunny/model/language_model/qwen2/tokenization_qwen2_fast.py +0 -143
  43. bunny/model/language_model/stable_lm/configuration_stablelm_epoch.py +0 -113
  44. bunny/model/language_model/stable_lm/modeling_stablelm_epoch.py +0 -917
  45. bunny/model/multimodal_encoder/builder.py +0 -29
  46. bunny/model/multimodal_encoder/clip/clip_encoder.py +0 -76
  47. bunny/model/multimodal_encoder/eva_clip/eva_clip_encoder.py +0 -63
  48. bunny/model/multimodal_encoder/eva_clip/eva_clip_processors.py +0 -68
  49. bunny/model/multimodal_encoder/eva_clip/eva_vit.py +0 -851
  50. bunny/model/multimodal_encoder/siglip/siglip_encoder.py +0 -129
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🚀
 colorFrom: red
 colorTo: red
 sdk: gradio
-sdk_version: 5.7.0
+sdk_version: 5.8.0
 app_file: app.py
 pinned: false
 ---
app.py CHANGED
@@ -1,74 +1,131 @@
-import sys
-import os
+import torch
+import transformers
+import warnings
 import time
-import argparse
-import subprocess
-
-import bunny.serve.gradio_web_server as gws
-
-subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-e', '.'])
-
-
-def start_controller():
-    controller_command = [
-        sys.executable, '-m', 'bunny.serve.controller',
-        '--host', '0.0.0.0',
-        '--port', '10000'
-    ]
-    return subprocess.Popen(controller_command)
-
-
-def start_worker(port: int, model_path: str, model_type: str):
-    worker_command = [
-        sys.executable, '-m', 'bunny.serve.model_worker',
-        '--host', '0.0.0.0',
-        '--controller', 'http://localhost:10000',
-        '--port', f'{port}',
-        '--worker', f'http://localhost:{port}',
-        '--model-path', model_path,
-        '--model-type', model_type
-    ]
-    return subprocess.Popen(worker_command)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="0.0.0.0")
-    parser.add_argument("--port", type=int)
-    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
-    parser.add_argument("--concurrency-count", type=int, default=5)
-    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
-    parser.add_argument("--share", action="store_true")
-    parser.add_argument("--moderate", action="store_true")
-    parser.add_argument("--embed", action="store_true")
-    gws.args = parser.parse_args()
-    gws.models = []
-
-    controller_proc = start_controller()
-
-    worker_procs = []
-
-    worker_procs.append(start_worker(port=40000, model_path='BAAI/Bunny-v1_1-Llama-3-8B-V', model_type='llama3-8b'))
-    worker_procs.append(start_worker(port=40001, model_path='BAAI/Bunny-v1_1-4B', model_type='phi-3'))
-    worker_procs.append(start_worker(port=40002, model_path='BAAI/Bunny-v1_0-3B', model_type='phi-2'))
-
-    time.sleep(60)
-
-    exit_status = 0
-    try:
-        demo = gws.build_demo(embed_mode=gws.args.embed)
-        demo.launch(
-            server_name=gws.args.host,
-            server_port=gws.args.port,
-            share=gws.args.share,
-            debug=True,
-            max_threads=10
-        )
-    except Exception as e:
-        print(e)
-        exit_status = 1
-    finally:
-        for worker_proc in worker_procs:
-            worker_proc.kill()
-        controller_proc.kill()
-        sys.exit(exit_status)
+import spaces
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from PIL import Image
+from threading import Thread
+
+
+transformers.logging.set_verbosity_error()
+transformers.logging.disable_progress_bar()
+warnings.filterwarnings("ignore")
+
+
+device = "cuda" # or cpu
+torch.set_default_device(device)
+
+model_name = "BAAI/Bunny-v1_1-Llama-3-8B-V"
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.float16, # float32 for cpu
+    device_map="auto",
+    trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(
+    model_name,
+    trust_remote_code=True)
+
+
+@spaces.GPU
+def bot_streaming(message, history):
+    print(message)
+    if message["files"]:
+        # message["files"][-1] is a Dict or just a string
+        if type(message["files"][-1]) == dict:
+            image_file = message["files"][-1]["path"]
+        else:
+            image_file = message["files"][-1]
+    else:
+        image_file = None
+        # if there's no image uploaded for this turn, look for images in the past turns
+        # kept inside tuples, take the last one
+        for hist in history:
+            if type(hist[0]) == tuple:
+                image_file = hist[0][0]
+
+    prompt = message["text"]
+    if image_file is None:
+        text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {prompt} ASSISTANT:"
+        input_ids = torch.tensor(tokenizer(text).input_ids, dtype=torch.long).unsqueeze(0).to(device)
+    else:
+        text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
+        text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")]
+        input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long).unsqueeze(0).to(device)
+
+    if image_file is not None:
+        image = Image.open(image_file)
+        image_tensor = model.process_images([image], model.config).to(dtype=model.dtype, device=device)
+    else:
+        image_tensor = None
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
+
+    thread = Thread(target=model.generate, kwargs=dict(
+        inputs=input_ids,
+        images=image_tensor,
+        do_sample=True,
+        temperature=0.2,
+        top_p=0.7,
+        max_new_tokens=512,
+        streamer=streamer,
+        use_cache=True,
+        repetition_penalty=1.08
+    ))
+    thread.start()
+
+    buffer = ""
+    time.sleep(0.5)
+    for new_text in streamer:
+        if "<|end_of_text|>" in new_text:
+            new_text = new_text.split("<|end_of_text|>")[0]
+        buffer += new_text
+
+        # generated_text_without_prompt = buffer[len(text_prompt):]
+        generated_text_without_prompt = buffer
+        # print(generated_text_without_prompt)
+        time.sleep(0.06)
+        # print(f"new_text: {generated_text_without_prompt}")
+        yield generated_text_without_prompt
+
+
+title_markdown = ("""
+# 🐰 Bunny: A family of lightweight multimodal models
+
+[📖 [Technical report](https://arxiv.org/abs/2402.11530)] | [🏠 [Code](https://github.com/BAAI-DCAI/Bunny)] | [🤗 [Bunny-v1.1-Llama-3-8B-V](https://huggingface.co/BAAI/Bunny-v1_1-Llama-3-8B-V)] | [🤗 [Bunny-v1.1-4B](https://huggingface.co/BAAI/Bunny-v1_1-4B)] | [🤗 [Bunny-v1.0-3B](https://huggingface.co/BAAI/Bunny-v1_0-3B)]
+
+""")
+
+chatbot = gr.Chatbot(
+    elem_id="chatbot",
+    label="Bunny-v1.1-Llama-3-8B-V",
+    avatar_images=[f"./assets/user.png", f"./assets/icon.jpg"],
+    height=550
+)
+
+chat_input = gr.MultimodalTextbox(
+    interactive=True,
+    file_types=["image"],
+    placeholder="Enter message or upload file...",
+    show_label=False
+)
+
+with gr.Blocks(fill_height=True) as demo:
+    gr.Markdown(title_markdown)
+
+    gr.ChatInterface(
+        fn=bot_streaming,
+        stop_btn="Stop Generation",
+        multimodal=True,
+        textbox=chat_input,
+        chatbot=chatbot
+    )
+
+    gr.Examples(examples=[{"text": "What is the astronaut holding in his hand?", "files": ["./assets/example_1.png"]},
+                          {"text": "Why is the image funny?", "files": ["./assets/example_2.png"]}], inputs=chat_input)
+
+
+demo.queue(api_open=False)
+demo.launch(show_api=False, share=False)
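
The new app.py drops the controller/worker servers and calls the model directly: the prompt is split on the `<image>` placeholder and the sentinel id -200 (IMAGE_TOKEN_INDEX in the removed bunny/constants.py) is spliced between the token chunks, while generation streams from a background thread via TextIteratorStreamer. Below is a minimal, self-contained sketch of that splicing step; the stub tokenizer and its ids are illustrative only, the real ids come from the checkpoint's AutoTokenizer.

```python
# A minimal sketch of the <image>-token splicing above, with a stub tokenizer
# standing in for the checkpoint's AutoTokenizer (ids here are illustrative only).
import torch

IMAGE_TOKEN_INDEX = -200  # same sentinel id as in the removed bunny/constants.py


def splice_image_token(text, tokenize):
    """Split on '<image>' and insert the sentinel id between the two token chunks."""
    chunks = [tokenize(chunk) for chunk in text.split("<image>")]
    # drop the duplicated BOS token at the start of the second chunk, as app.py does
    ids = chunks[0] + [IMAGE_TOKEN_INDEX] + chunks[1][1:]
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)


# toy tokenizer: a BOS id followed by one fake id per whitespace-separated word
fake_tokenize = lambda s: [1] + [hash(w) % 1000 for w in s.split()]

prompt = "USER: <image>\nWhat is in the picture? ASSISTANT:"
print(splice_image_token(prompt, fake_tokenize))  # shape (1, N), contains one -200
```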
{bunny/serve/examples → assets}/example_1.png RENAMED
File without changes
{bunny/serve/examples → assets}/example_2.png RENAMED
File without changes
{bunny/serve/examples → assets}/icon.jpg RENAMED
File without changes
{bunny/serve/examples → assets}/user.png RENAMED
File without changes
bunny/constants.py DELETED
@@ -1,7 +0,0 @@
-# Model Constants
-IGNORE_INDEX = -100
-IMAGE_TOKEN_INDEX = -200
-DEFAULT_IMAGE_TOKEN = "<image>"
-CONTROLLER_HEART_BEAT_EXPIRATION = 30
-LOGDIR = "gradio-logs"
-WORKER_HEART_BEAT_INTERVAL = 15
bunny/conversation.py DELETED
@@ -1,239 +0,0 @@
1
- import dataclasses
2
- from enum import auto, Enum
3
- from typing import List
4
-
5
-
6
- class SeparatorStyle(Enum):
7
- """Different separator style."""
8
- TWO = auto()
9
- PLAIN = auto()
10
-
11
-
12
- @dataclasses.dataclass
13
- class Conversation:
14
- """A class that keeps all conversation history."""
15
- system: str
16
- roles: List[str]
17
- messages: List[List[str]]
18
- offset: int
19
- sep_style: SeparatorStyle
20
- sep: str = "###"
21
- sep2: str = None
22
- version: str = "Unknown"
23
-
24
- skip_next: bool = False
25
-
26
- def get_prompt(self):
27
- messages = self.messages
28
- if len(messages) > 0 and type(messages[0][1]) is tuple:
29
- messages = self.messages.copy()
30
- init_role, init_msg = messages[0].copy()
31
- init_msg = init_msg[0].replace("<image>", "").strip()
32
- if 'mmtag' in self.version:
33
- messages[0] = (init_role, init_msg)
34
- messages.insert(0, (self.roles[0], "<Image><image></Image>"))
35
- messages.insert(1, (self.roles[1], "Received."))
36
- else:
37
- messages[0] = (init_role, "<image>\n" + init_msg)
38
-
39
- if self.sep_style == SeparatorStyle.TWO:
40
- seps = [self.sep, self.sep2]
41
- ret = self.system + seps[0]
42
- for i, (role, message) in enumerate(messages):
43
- if message:
44
- if type(message) is tuple:
45
- message, _, _ = message
46
- ret += role + ": " + message + seps[i % 2]
47
- else:
48
- ret += role + ":"
49
-
50
- elif self.sep_style == SeparatorStyle.PLAIN:
51
- seps = [self.sep, self.sep2]
52
- ret = self.system
53
- for i, (role, message) in enumerate(messages):
54
- if message:
55
- if type(message) is tuple:
56
- message, _, _ = message
57
- ret += message + seps[i % 2]
58
- else:
59
- ret += ""
60
- else:
61
- raise ValueError(f"Invalid style: {self.sep_style}")
62
-
63
- return ret
64
-
65
- def append_message(self, role, message):
66
- self.messages.append([role, message])
67
-
68
- def get_images(self, return_pil=False):
69
- images = []
70
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
71
- if i % 2 == 0:
72
- if type(msg) is tuple:
73
- import base64
74
- from io import BytesIO
75
- from PIL import Image
76
- msg, image, image_process_mode = msg
77
- if image_process_mode == "Pad":
78
- def expand2square(pil_img, background_color=(122, 116, 104)):
79
- width, height = pil_img.size
80
- if width == height:
81
- return pil_img
82
- elif width > height:
83
- result = Image.new(pil_img.mode, (width, width), background_color)
84
- result.paste(pil_img, (0, (width - height) // 2))
85
- return result
86
- else:
87
- result = Image.new(pil_img.mode, (height, height), background_color)
88
- result.paste(pil_img, ((height - width) // 2, 0))
89
- return result
90
-
91
- image = expand2square(image)
92
- elif image_process_mode in ["Default", "Crop"]:
93
- pass
94
- elif image_process_mode == "Resize":
95
- image = image.resize((336, 336))
96
- else:
97
- raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
98
-
99
- if return_pil:
100
- images.append(image)
101
- else:
102
- buffered = BytesIO()
103
- image.save(buffered, format="PNG")
104
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
105
- images.append(img_b64_str)
106
- return images
107
-
108
- def to_gradio_chatbot(self):
109
- ret = []
110
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
111
- if i % 2 == 0:
112
- if type(msg) is tuple:
113
- import base64
114
- from io import BytesIO
115
- msg, image, image_process_mode = msg
116
- max_hw, min_hw = max(image.size), min(image.size)
117
- aspect_ratio = max_hw / min_hw
118
- max_len, min_len = 800, 400
119
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
120
- longest_edge = int(shortest_edge * aspect_ratio)
121
- W, H = image.size
122
- if H > W:
123
- H, W = longest_edge, shortest_edge
124
- else:
125
- H, W = shortest_edge, longest_edge
126
- image = image.resize((W, H))
127
- buffered = BytesIO()
128
- image.save(buffered, format="JPEG")
129
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
130
- img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
131
- msg = img_str + msg.replace('<image>', '').strip()
132
- ret.append([msg, None])
133
- else:
134
- ret.append([msg, None])
135
- else:
136
- ret[-1][-1] = msg
137
- return ret
138
-
139
- def copy(self):
140
- return Conversation(
141
- system=self.system,
142
- roles=self.roles,
143
- messages=[[x, y] for x, y in self.messages],
144
- offset=self.offset,
145
- sep_style=self.sep_style,
146
- sep=self.sep,
147
- sep2=self.sep2,
148
- version=self.version)
149
-
150
- def dict(self):
151
- if len(self.get_images()) > 0:
152
- return {
153
- "system": self.system,
154
- "roles": self.roles,
155
- "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
156
- "offset": self.offset,
157
- "sep": self.sep,
158
- "sep2": self.sep2,
159
- }
160
- return {
161
- "system": self.system,
162
- "roles": self.roles,
163
- "messages": self.messages,
164
- "offset": self.offset,
165
- "sep": self.sep,
166
- "sep2": self.sep2,
167
- }
168
-
169
-
170
- conv_bunny = Conversation(
171
- system="A chat between a curious user and an artificial intelligence assistant. "
172
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
173
- roles=("USER", "ASSISTANT"),
174
- version="bunny",
175
- messages=(),
176
- offset=0,
177
- sep_style=SeparatorStyle.TWO,
178
- sep=" ",
179
- sep2="<|endoftext|>",
180
- )
181
-
182
- conv_phi3 = Conversation(
183
- system="A chat between a curious user and an artificial intelligence assistant. "
184
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
185
- roles=("USER", "ASSISTANT"),
186
- version="phi3",
187
- messages=(),
188
- offset=0,
189
- sep_style=SeparatorStyle.TWO,
190
- sep=" ",
191
- sep2="<|endoftext|>",
192
- )
193
-
194
- conv_minicpm = Conversation(
195
- system="A chat between a curious user and an artificial intelligence assistant. "
196
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
197
- roles=("USER", "ASSISTANT"),
198
- version="minicpm",
199
- messages=(),
200
- offset=0,
201
- sep_style=SeparatorStyle.TWO,
202
- sep=" ",
203
- sep2="</s>",
204
- )
205
-
206
- conv_llama = Conversation(
207
- system="A chat between a curious user and an artificial intelligence assistant. "
208
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
209
- roles=("USER", "ASSISTANT"),
210
- version="llama",
211
- messages=(),
212
- offset=0,
213
- sep_style=SeparatorStyle.TWO,
214
- sep=" ",
215
- sep2="<|end_of_text|>",
216
- )
217
-
218
- conv_plain = Conversation(
219
- system="",
220
- roles=("", ""),
221
- messages=(
222
- ),
223
- offset=0,
224
- sep_style=SeparatorStyle.PLAIN,
225
- sep="\n",
226
- )
227
-
228
- default_conversation = conv_bunny
229
- conv_templates = {
230
- "default": conv_bunny,
231
- "bunny": conv_bunny,
232
- "phi3": conv_phi3,
233
- "plain": conv_plain,
234
- 'minicpm': conv_minicpm,
235
- 'llama': conv_llama
236
- }
237
-
238
- if __name__ == "__main__":
239
- print(default_conversation.get_prompt())
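
The removed bunny/conversation.py held the chat templates that the surviving app.py now inlines as a literal f-string. A standalone sketch, reconstructed from the code above, of what conv_bunny.get_prompt() produced with SeparatorStyle.TWO, sep=" " and sep2="<|endoftext|>":

```python
# Standalone sketch of conv_bunny.get_prompt() from the removed module above.
system = ("A chat between a curious user and an artificial intelligence assistant. "
          "The assistant gives helpful, detailed, and polite answers to the user's questions.")
messages = [("USER", "<image>\nWhat is shown here?"), ("ASSISTANT", None)]
seps = (" ", "<|endoftext|>")

prompt = system + seps[0]
for i, (role, msg) in enumerate(messages):
    # a role with no message yet just opens the turn ("ASSISTANT:")
    prompt += f"{role}: {msg}{seps[i % 2]}" if msg else f"{role}:"
print(prompt)
# ... polite answers to the user's questions. USER: <image>
# What is shown here? ASSISTANT:
```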
bunny/eval/m4c_evaluator.py DELETED
@@ -1,334 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- import re
3
-
4
- from tqdm import tqdm
5
-
6
-
7
- class EvalAIAnswerProcessor:
8
- """
9
- Processes an answer similar to Eval AI
10
- copied from
11
- https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
12
- """
13
-
14
- CONTRACTIONS = {
15
- "aint": "ain't",
16
- "arent": "aren't",
17
- "cant": "can't",
18
- "couldve": "could've",
19
- "couldnt": "couldn't",
20
- "couldn'tve": "couldn't've",
21
- "couldnt've": "couldn't've",
22
- "didnt": "didn't",
23
- "doesnt": "doesn't",
24
- "dont": "don't",
25
- "hadnt": "hadn't",
26
- "hadnt've": "hadn't've",
27
- "hadn'tve": "hadn't've",
28
- "hasnt": "hasn't",
29
- "havent": "haven't",
30
- "hed": "he'd",
31
- "hed've": "he'd've",
32
- "he'dve": "he'd've",
33
- "hes": "he's",
34
- "howd": "how'd",
35
- "howll": "how'll",
36
- "hows": "how's",
37
- "Id've": "I'd've",
38
- "I'dve": "I'd've",
39
- "Im": "I'm",
40
- "Ive": "I've",
41
- "isnt": "isn't",
42
- "itd": "it'd",
43
- "itd've": "it'd've",
44
- "it'dve": "it'd've",
45
- "itll": "it'll",
46
- "let's": "let's",
47
- "maam": "ma'am",
48
- "mightnt": "mightn't",
49
- "mightnt've": "mightn't've",
50
- "mightn'tve": "mightn't've",
51
- "mightve": "might've",
52
- "mustnt": "mustn't",
53
- "mustve": "must've",
54
- "neednt": "needn't",
55
- "notve": "not've",
56
- "oclock": "o'clock",
57
- "oughtnt": "oughtn't",
58
- "ow's'at": "'ow's'at",
59
- "'ows'at": "'ow's'at",
60
- "'ow'sat": "'ow's'at",
61
- "shant": "shan't",
62
- "shed've": "she'd've",
63
- "she'dve": "she'd've",
64
- "she's": "she's",
65
- "shouldve": "should've",
66
- "shouldnt": "shouldn't",
67
- "shouldnt've": "shouldn't've",
68
- "shouldn'tve": "shouldn't've",
69
- "somebody'd": "somebodyd",
70
- "somebodyd've": "somebody'd've",
71
- "somebody'dve": "somebody'd've",
72
- "somebodyll": "somebody'll",
73
- "somebodys": "somebody's",
74
- "someoned": "someone'd",
75
- "someoned've": "someone'd've",
76
- "someone'dve": "someone'd've",
77
- "someonell": "someone'll",
78
- "someones": "someone's",
79
- "somethingd": "something'd",
80
- "somethingd've": "something'd've",
81
- "something'dve": "something'd've",
82
- "somethingll": "something'll",
83
- "thats": "that's",
84
- "thered": "there'd",
85
- "thered've": "there'd've",
86
- "there'dve": "there'd've",
87
- "therere": "there're",
88
- "theres": "there's",
89
- "theyd": "they'd",
90
- "theyd've": "they'd've",
91
- "they'dve": "they'd've",
92
- "theyll": "they'll",
93
- "theyre": "they're",
94
- "theyve": "they've",
95
- "twas": "'twas",
96
- "wasnt": "wasn't",
97
- "wed've": "we'd've",
98
- "we'dve": "we'd've",
99
- "weve": "we've",
100
- "werent": "weren't",
101
- "whatll": "what'll",
102
- "whatre": "what're",
103
- "whats": "what's",
104
- "whatve": "what've",
105
- "whens": "when's",
106
- "whered": "where'd",
107
- "wheres": "where's",
108
- "whereve": "where've",
109
- "whod": "who'd",
110
- "whod've": "who'd've",
111
- "who'dve": "who'd've",
112
- "wholl": "who'll",
113
- "whos": "who's",
114
- "whove": "who've",
115
- "whyll": "why'll",
116
- "whyre": "why're",
117
- "whys": "why's",
118
- "wont": "won't",
119
- "wouldve": "would've",
120
- "wouldnt": "wouldn't",
121
- "wouldnt've": "wouldn't've",
122
- "wouldn'tve": "wouldn't've",
123
- "yall": "y'all",
124
- "yall'll": "y'all'll",
125
- "y'allll": "y'all'll",
126
- "yall'd've": "y'all'd've",
127
- "y'alld've": "y'all'd've",
128
- "y'all'dve": "y'all'd've",
129
- "youd": "you'd",
130
- "youd've": "you'd've",
131
- "you'dve": "you'd've",
132
- "youll": "you'll",
133
- "youre": "you're",
134
- "youve": "you've",
135
- }
136
-
137
- NUMBER_MAP = {
138
- "none": "0",
139
- "zero": "0",
140
- "one": "1",
141
- "two": "2",
142
- "three": "3",
143
- "four": "4",
144
- "five": "5",
145
- "six": "6",
146
- "seven": "7",
147
- "eight": "8",
148
- "nine": "9",
149
- "ten": "10",
150
- }
151
- ARTICLES = ["a", "an", "the"]
152
- PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
153
- COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
154
- PUNCTUATIONS = [
155
- ";",
156
- r"/",
157
- "[",
158
- "]",
159
- '"',
160
- "{",
161
- "}",
162
- "(",
163
- ")",
164
- "=",
165
- "+",
166
- "\\",
167
- "_",
168
- "-",
169
- ">",
170
- "<",
171
- "@",
172
- "`",
173
- ",",
174
- "?",
175
- "!",
176
- ]
177
-
178
- def __init__(self, *args, **kwargs):
179
- pass
180
-
181
- def word_tokenize(self, word):
182
- word = word.lower()
183
- word = word.replace(",", "").replace("?", "").replace("'s", " 's")
184
- return word.strip()
185
-
186
- def process_punctuation(self, in_text):
187
- out_text = in_text
188
- for p in self.PUNCTUATIONS:
189
- if (p + " " in in_text or " " + p in in_text) or (
190
- re.search(self.COMMA_STRIP, in_text) is not None
191
- ):
192
- out_text = out_text.replace(p, "")
193
- else:
194
- out_text = out_text.replace(p, " ")
195
- out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
196
- return out_text
197
-
198
- def process_digit_article(self, in_text):
199
- out_text = []
200
- temp_text = in_text.lower().split()
201
- for word in temp_text:
202
- word = self.NUMBER_MAP.setdefault(word, word)
203
- if word not in self.ARTICLES:
204
- out_text.append(word)
205
- else:
206
- pass
207
- for word_id, word in enumerate(out_text):
208
- if word in self.CONTRACTIONS:
209
- out_text[word_id] = self.CONTRACTIONS[word]
210
- out_text = " ".join(out_text)
211
- return out_text
212
-
213
- def __call__(self, item):
214
- item = self.word_tokenize(item)
215
- item = item.replace("\n", " ").replace("\t", " ").strip()
216
- item = self.process_punctuation(item)
217
- item = self.process_digit_article(item)
218
- return item
219
-
220
-
221
- class TextVQAAccuracyEvaluator:
222
- def __init__(self):
223
- self.answer_processor = EvalAIAnswerProcessor()
224
-
225
- def _compute_answer_scores(self, raw_answers):
226
- """
227
- compute the accuracy (soft score) of human answers
228
- """
229
- answers = [self.answer_processor(a) for a in raw_answers]
230
- assert len(answers) == 10
231
- gt_answers = list(enumerate(answers))
232
- unique_answers = set(answers)
233
- unique_answer_scores = {}
234
-
235
- for unique_answer in unique_answers:
236
- accs = []
237
- for gt_answer in gt_answers:
238
- other_answers = [item for item in gt_answers if item != gt_answer]
239
- matching_answers = [
240
- item for item in other_answers if item[1] == unique_answer
241
- ]
242
- acc = min(1, float(len(matching_answers)) / 3)
243
- accs.append(acc)
244
- unique_answer_scores[unique_answer] = sum(accs) / len(accs)
245
-
246
- return unique_answer_scores
247
-
248
- def eval_pred_list(self, pred_list):
249
- pred_scores = []
250
- for entry in tqdm(pred_list):
251
- pred_answer = self.answer_processor(entry["pred_answer"])
252
- unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
253
- score = unique_answer_scores.get(pred_answer, 0.0)
254
- pred_scores.append(score)
255
-
256
- accuracy = sum(pred_scores) / len(pred_scores)
257
- return accuracy
258
-
259
-
260
- class STVQAAccuracyEvaluator:
261
- def __init__(self):
262
- self.answer_processor = EvalAIAnswerProcessor()
263
-
264
- def eval_pred_list(self, pred_list):
265
- pred_scores = []
266
- for entry in pred_list:
267
- pred_answer = self.answer_processor(entry["pred_answer"])
268
- gts = [self.answer_processor(a) for a in entry["gt_answers"]]
269
- score = 1.0 if pred_answer in gts else 0.0
270
- pred_scores.append(score)
271
-
272
- accuracy = sum(pred_scores) / len(pred_scores)
273
- return accuracy
274
-
275
-
276
- class STVQAANLSEvaluator:
277
- def __init__(self):
278
- import editdistance # install with `pip install editdistance`
279
-
280
- self.get_edit_distance = editdistance.eval
281
-
282
- def get_anls(self, s1, s2):
283
- s1 = s1.lower().strip()
284
- s2 = s2.lower().strip()
285
- iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
286
- anls = iou if iou >= 0.5 else 0.0
287
- return anls
288
-
289
- def eval_pred_list(self, pred_list):
290
- pred_scores = []
291
- for entry in pred_list:
292
- anls = max(
293
- self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
294
- )
295
- pred_scores.append(anls)
296
-
297
- accuracy = sum(pred_scores) / len(pred_scores)
298
- return accuracy
299
-
300
-
301
- class TextCapsBleu4Evaluator:
302
- def __init__(self):
303
- # The following script requires Java 1.8.0 and pycocotools installed.
304
- # The pycocoevalcap can be installed with pip as
305
- # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
306
- # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
307
- # but has no python3 support yet.
308
- try:
309
- from pycocoevalcap.bleu.bleu import Bleu
310
- from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
311
- except ModuleNotFoundError:
312
- print(
313
- "Please install pycocoevalcap module using "
314
- "pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa
315
- )
316
- raise
317
-
318
- self.tokenizer = PTBTokenizer()
319
- self.scorer = Bleu(4)
320
-
321
- def eval_pred_list(self, pred_list):
322
- # Create reference and hypotheses captions.
323
- gts = {}
324
- res = {}
325
- for idx, entry in enumerate(pred_list):
326
- gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
327
- res[idx] = [{"caption": entry["pred_answer"]}]
328
-
329
- gts = self.tokenizer.tokenize(gts)
330
- res = self.tokenizer.tokenize(res)
331
- score, _ = self.scorer.compute_score(gts, res)
332
-
333
- bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
334
- return bleu4
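
For reference, the metric implemented by _compute_answer_scores above is the standard leave-one-annotator-out VQA soft score: a candidate answer earns min(matches among the other 9 annotations / 3, 1) per held-out annotation, averaged over all 10. A small worked example (the string normalization done by EvalAIAnswerProcessor is skipped here):

```python
# Worked example of the VQA soft score from the evaluator being removed above.
def soft_score(candidate, gt_answers):
    """Leave-one-annotator-out accuracy, mirroring _compute_answer_scores."""
    assert len(gt_answers) == 10
    accs = []
    for i in range(len(gt_answers)):
        others = gt_answers[:i] + gt_answers[i + 1:]        # the other 9 annotations
        accs.append(min(1.0, others.count(candidate) / 3))  # capped at 1
    return sum(accs) / len(accs)


gts = ["cat"] * 2 + ["kitten"] * 5 + ["dog"] * 3
print(soft_score("cat", gts))     # 0.6
print(soft_score("kitten", gts))  # 1.0
```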
bunny/eval/model_vqa.py DELETED
@@ -1,111 +0,0 @@
1
- import argparse
2
- import torch
3
- import os
4
- import json
5
- from tqdm import tqdm
6
- import shortuuid
7
-
8
- from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
9
- from bunny.conversation import conv_templates, SeparatorStyle
10
- from bunny.model.builder import load_pretrained_model
11
- from bunny.util.utils import disable_torch_init
12
- from bunny.util.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
13
-
14
- from PIL import Image
15
- import math
16
-
17
-
18
- def split_list(lst, n):
19
- """Split a list into n (roughly) equal-sized chunks"""
20
- chunk_size = math.ceil(len(lst) / n) # integer division
21
- return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
22
-
23
-
24
- def get_chunk(lst, n, k):
25
- chunks = split_list(lst, n)
26
- return chunks[k]
27
-
28
-
29
- def eval_model(args):
30
- # Model
31
- disable_torch_init()
32
- model_path = os.path.expanduser(args.model_path)
33
- model_name = get_model_name_from_path(model_path)
34
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
35
- args.model_type)
36
-
37
- questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
38
- questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
39
- answers_file = os.path.expanduser(args.answers_file)
40
- os.makedirs(os.path.dirname(answers_file), exist_ok=True)
41
- ans_file = open(answers_file, "w")
42
- for line in tqdm(questions):
43
- idx = line["question_id"]
44
- image_file = line["image"]
45
- qs = line["text"]
46
- cur_prompt = qs
47
-
48
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
49
-
50
- conv = conv_templates[args.conv_mode].copy()
51
- conv.append_message(conv.roles[0], qs)
52
- conv.append_message(conv.roles[1], None)
53
- prompt = conv.get_prompt()
54
-
55
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
56
-
57
- image = Image.open(os.path.join(args.image_folder, image_file))
58
- image_tensor = process_images([image], image_processor, model.config)[0]
59
-
60
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
61
-
62
- with torch.inference_mode():
63
- output_ids = model.generate(
64
- input_ids,
65
- images=image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
66
- do_sample=True if args.temperature > 0 else False,
67
- temperature=args.temperature,
68
- top_p=args.top_p,
69
- num_beams=args.num_beams,
70
- # no_repeat_ngram_size=3,
71
- max_new_tokens=1024,
72
- use_cache=True)
73
-
74
- input_token_len = input_ids.shape[1]
75
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
76
- if n_diff_input_output > 0:
77
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
78
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
79
- outputs = outputs.strip()
80
- if outputs.endswith(stop_str):
81
- outputs = outputs[:-len(stop_str)]
82
- outputs = outputs.strip()
83
-
84
- ans_id = shortuuid.uuid()
85
- ans_file.write(json.dumps({"question_id": idx,
86
- "prompt": cur_prompt,
87
- "text": outputs,
88
- "answer_id": ans_id,
89
- "model_id": model_name,
90
- "metadata": {}}) + "\n")
91
- ans_file.flush()
92
- ans_file.close()
93
-
94
-
95
- if __name__ == "__main__":
96
- parser = argparse.ArgumentParser()
97
- parser.add_argument("--model-path", type=str, default=None)
98
- parser.add_argument("--model-base", type=str, default=None)
99
- parser.add_argument("--model-type", type=str, default=None)
100
- parser.add_argument("--image-folder", type=str, default=None)
101
- parser.add_argument("--question-file", type=str, default=None)
102
- parser.add_argument("--answers-file", type=str, default=None)
103
- parser.add_argument("--conv-mode", type=str, default=None)
104
- parser.add_argument("--num-chunks", type=int, default=1)
105
- parser.add_argument("--chunk-idx", type=int, default=0)
106
- parser.add_argument("--temperature", type=float, default=0.2)
107
- parser.add_argument("--top_p", type=float, default=None)
108
- parser.add_argument("--num_beams", type=int, default=1)
109
- args = parser.parse_args()
110
-
111
- eval_model(args)
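
All of the eval scripts shard their question list the same way via --num-chunks/--chunk-idx; split_list below is copied from model_vqa.py above to show the resulting split for a toy list (ceil-sized chunks, so the last shard may be shorter):

```python
# How --num-chunks / --chunk-idx shard work across parallel eval jobs.
import math


def split_list(lst, n):
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


questions = list(range(10))
print(split_list(questions, 3))     # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(split_list(questions, 3)[1])  # the shard selected by --chunk-idx 1
```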
bunny/eval/model_vqa_cmmmu.py DELETED
@@ -1,234 +0,0 @@
1
- import random
2
- import numpy as np
3
- import os
4
- import json
5
- import yaml
6
- import torch
7
-
8
- from tqdm import tqdm
9
- from datasets import load_dataset, concatenate_datasets
10
- from argparse import ArgumentParser
11
-
12
- from bunny.model.builder import load_pretrained_model
13
- from bunny.util.mm_utils import get_model_name_from_path, tokenizer_image_token, process_images
14
- from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
15
- from bunny.conversation import conv_templates
16
-
17
- CAT_CN2EN = {'艺术与设计': 'art_and_design',
18
- '商业': 'business',
19
- '健康与医学': 'health_and_medicine',
20
- '人文社会科学': 'humanities_and_social_sciences',
21
- '科学': 'science',
22
- '技术与工程': 'technology_and_engineering'}
23
-
24
-
25
- def call_bunny_engine_df(args, sample, model, tokenizer=None, processor=None):
26
- def deal_with_prompt(input_text):
27
- qs = input_text
28
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
29
- return qs
30
-
31
- prompt = sample['final_input_prompt']
32
- prompt = deal_with_prompt(prompt)
33
-
34
- conv = conv_templates[args.conv_mode].copy()
35
- conv.append_message(conv.roles[0], prompt)
36
- conv.append_message(conv.roles[1], None)
37
- prompt = conv.get_prompt()
38
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
39
-
40
- image = sample['image_1']
41
- if sample['image_2'] is not None: # multiple images actually
42
- if sample['type'] == '选择':
43
- all_choices = sample['all_choices']
44
- response = random.choice(all_choices)
45
- else:
46
- response = 'INVALID GENERATION FOR MULTIPLE IMAGE INPUTS'
47
- elif image is not None:
48
- output_ids = model.generate(
49
- input_ids,
50
- images=image.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
51
- do_sample=False,
52
- temperature=0,
53
- top_p=None,
54
- # num_beams=5,
55
- max_new_tokens=128,
56
- use_cache=True)
57
-
58
- input_token_len = input_ids.shape[1]
59
- # n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
60
- # if n_diff_input_output > 0:
61
- # print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
62
- response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
63
-
64
- return response
65
-
66
-
67
- def load_yaml(file_path):
68
- with open(file_path, 'r') as stream:
69
- try:
70
- yaml_dict = yaml.safe_load(stream)
71
- except yaml.YAMLError as exc:
72
- print(exc)
73
-
74
- return yaml_dict
75
-
76
-
77
- # DATA PROCESSING
78
- def construct_prompt(sample, config):
79
- question = sample['question']
80
- options = []
81
- for i in range(1, 5):
82
- if sample[f'option{i}'] is None:
83
- break
84
- options.append(sample[f'option{i}'])
85
-
86
- example = ""
87
- if sample['type'] == '选择':
88
- start_chr = 'A'
89
- prediction_range = []
90
- for option in options:
91
- prediction_range.append(start_chr)
92
- example += f"({start_chr}) {option}\n"
93
- start_chr = chr(ord(start_chr) + 1)
94
- empty_prompt_sample_structure = config['multi_choice_example_format']
95
- empty_prompt = empty_prompt_sample_structure.format(question, example)
96
- res_dict = {}
97
- res_dict['correct_choice'] = sample['answer']
98
- res_dict['all_choices'] = prediction_range
99
- res_dict['empty_prompt'] = empty_prompt
100
- if config['task_instructions']:
101
- res_dict['final_input_prompt'] = config['task_instructions'][0].strip() + '\n\n' + empty_prompt
102
- else:
103
- res_dict['final_input_prompt'] = empty_prompt
104
-
105
- res_dict['gt_content'] = sample['answer']
106
- elif sample['type'] == '判断':
107
- empty_prompt_sample_structure = config['T/F_example_format']
108
- empty_prompt = empty_prompt_sample_structure.format(question, example)
109
- res_dict = {}
110
- res_dict['empty_prompt'] = empty_prompt
111
- if config['task_instructions']:
112
- res_dict['final_input_prompt'] = config['task_instructions'][1].strip() + '\n\n' + empty_prompt
113
- else:
114
- res_dict['final_input_prompt'] = empty_prompt
115
- res_dict['gt_content'] = sample['answer']
116
- else:
117
- empty_prompt_sample_structure = config['short_ans_example_format']
118
- empty_prompt = empty_prompt_sample_structure.format(question)
119
- res_dict = {}
120
- res_dict['empty_prompt'] = empty_prompt
121
- if config['task_instructions']:
122
- res_dict['final_input_prompt'] = config['task_instructions'][2].strip() + '\n\n' + empty_prompt
123
- else:
124
- res_dict['final_input_prompt'] = empty_prompt
125
- res_dict['gt_content'] = sample['answer']
126
-
127
- res_dict.update(sample)
128
- return res_dict
129
-
130
-
131
- def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None):
132
- out_samples = []
133
- with torch.no_grad():
134
- for sample in tqdm(samples):
135
- if args.small_gpu_usage:
136
- sample['image_1'] = sample['image_1'].cuda()
137
- response = call_model_engine_fn(args, sample, model, tokenizer, processor)
138
- if args.small_gpu_usage:
139
- sample['image_1'] = sample['image_1'].cpu()
140
-
141
- out_sample = dict()
142
- out_sample['id'] = sample['id']
143
- out_sample['type'] = sample['type']
144
- out_sample['response'] = response
145
- out_samples.append(out_sample)
146
- return out_samples
147
-
148
-
149
- def set_seed(seed_value):
150
- """
151
- Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results.
152
-
153
- :param seed_value: An integer value to be used as the seed.
154
- """
155
- torch.manual_seed(seed_value)
156
- if torch.cuda.is_available():
157
- torch.cuda.manual_seed(seed_value)
158
- torch.cuda.manual_seed_all(seed_value) # For multi-GPU setups
159
- random.seed(seed_value)
160
- np.random.seed(seed_value)
161
- torch.backends.cudnn.deterministic = True
162
- torch.backends.cudnn.benchmark = False
163
-
164
-
165
- def main():
166
- parser = ArgumentParser()
167
- parser.add_argument('--model-path', type=str, default=None)
168
- parser.add_argument('--model-base', type=str, default=None)
169
- parser.add_argument("--model-type", type=str, default=None)
170
- parser.add_argument("--conv-mode", type=str, default=None)
171
- parser.add_argument('--data-path', type=str, default=None)
172
- parser.add_argument('--config-path', type=str, default=None)
173
- parser.add_argument('--output-path', type=str, default=None)
174
- parser.add_argument('--split', type=str, default='validation')
175
- parser.add_argument('--seed', type=int, default=42)
176
- parser.add_argument("--small-gpu-usage", action="store_true")
177
-
178
- args = parser.parse_args()
179
- device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
180
- set_seed(args.seed)
181
-
182
- print('bunny_initializing...')
183
- processor = None
184
- call_model_engine = call_bunny_engine_df
185
-
186
- # load config and process to one value
187
- args.config = load_yaml(args.config_path)
188
- for key, value in args.config.items():
189
- if key == 'task_instructions':
190
- args.config[key] = value
191
- elif key != 'eval_params' and type(value) == list:
192
- assert len(value) == 1, 'key {} has more than one value'.format(key)
193
- args.config[key] = value[0]
194
-
195
- # run for each subject
196
- sub_dataset_list = []
197
- for subject in CAT_CN2EN.values():
198
- sub_dataset = load_dataset(args.data_path, subject, split=args.split)
199
- sub_dataset_list.append(sub_dataset)
200
-
201
- # merge all dataset
202
- dataset = concatenate_datasets(sub_dataset_list)
203
-
204
- # load model
205
- model_path = os.path.expanduser(args.model_path)
206
- model_name = get_model_name_from_path(model_path)
207
- tokenizer, model, vis_processors, context_len = load_pretrained_model(model_path, args.model_base, model_name,
208
- args.model_type)
209
-
210
- samples = []
211
- print('Processing CMMMU dataset...')
212
- for sample in tqdm(dataset):
213
-
214
- sample = construct_prompt(sample, args.config)
215
- if sample['image_1']:
216
- sample['image_1'] = process_images([sample['image_1'].convert('RGB')], vis_processors, model.config)[0]
217
- if not args.small_gpu_usage:
218
- sample['image_1'] = sample['image_1'].to(device)
219
-
220
- samples.append(sample)
221
-
222
- print('Start to evaluate...')
223
- # run ex
224
- out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor)
225
-
226
- os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
227
-
228
- with open(args.output_path, 'w') as f:
229
- for out_sample in out_samples:
230
- f.write(json.dumps(out_sample) + '\n')
231
-
232
-
233
- if __name__ == '__main__':
234
- main()
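
The multiple-choice branch of construct_prompt above enumerates the options as (A), (B), ... before filling the YAML-supplied template. A small illustration of that loop; the template string here is a hypothetical stand-in, since the real multi_choice_example_format comes from the file passed to --config-path:

```python
# Illustration of the multiple-choice prompt assembly in construct_prompt above.
question = "Which pigment absorbs light for photosynthesis?"
options = ["Chlorophyll", "Hemoglobin", "Keratin", "Melanin"]

example, start_chr = "", "A"
for option in options:
    example += f"({start_chr}) {option}\n"
    start_chr = chr(ord(start_chr) + 1)

multi_choice_example_format = "{}\n{}Answer with the option letter."  # assumed template
print(multi_choice_example_format.format(question, example))
```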
bunny/eval/model_vqa_loader.py DELETED
@@ -1,143 +0,0 @@
1
- import argparse
2
- import torch
3
- import os
4
- import json
5
- from tqdm import tqdm
6
- import shortuuid
7
-
8
- from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
9
- from bunny.conversation import conv_templates
10
- from bunny.model.builder import load_pretrained_model
11
- from bunny.util.utils import disable_torch_init
12
- from bunny.util.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13
- from torch.utils.data import Dataset, DataLoader
14
-
15
- from PIL import Image
16
- import math
17
-
18
-
19
- def split_list(lst, n):
20
- """Split a list into n (roughly) equal-sized chunks"""
21
- chunk_size = math.ceil(len(lst) / n) # integer division
22
- return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
23
-
24
-
25
- def get_chunk(lst, n, k):
26
- chunks = split_list(lst, n)
27
- return chunks[k]
28
-
29
-
30
- # Custom dataset class
31
- class CustomDataset(Dataset):
32
- def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
33
- self.questions = questions
34
- self.image_folder = image_folder
35
- self.tokenizer = tokenizer
36
- self.image_processor = image_processor
37
- self.model_config = model_config
38
-
39
- def __getitem__(self, index):
40
- line = self.questions[index]
41
- image_file = line["image"]
42
- qs = line["text"]
43
-
44
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
45
-
46
- conv = conv_templates[args.conv_mode].copy()
47
- conv.append_message(conv.roles[0], qs)
48
- conv.append_message(conv.roles[1], None)
49
- prompt = conv.get_prompt()
50
-
51
- image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
52
- image_tensor = process_images([image], self.image_processor, self.model_config)[0]
53
-
54
- input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
55
-
56
- return input_ids, image_tensor
57
-
58
- def __len__(self):
59
- return len(self.questions)
60
-
61
-
62
- # DataLoader
63
- def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
64
- assert batch_size == 1, "batch_size must be 1"
65
- dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
66
- data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
67
- return data_loader
68
-
69
-
70
- def eval_model(args):
71
- # Model
72
- disable_torch_init()
73
- model_path = os.path.expanduser(args.model_path)
74
- model_name = get_model_name_from_path(model_path)
75
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
76
- args.model_type)
77
-
78
- questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
79
- questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
80
- answers_file = os.path.expanduser(args.answers_file)
81
- os.makedirs(os.path.dirname(answers_file), exist_ok=True)
82
- ans_file = open(answers_file, "w")
83
-
84
- if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
85
- args.conv_mode = args.conv_mode + '_mmtag'
86
- print(
87
- f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
88
-
89
- data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
90
-
91
- for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
92
- idx = line["question_id"]
93
- cur_prompt = line["text"]
94
-
95
- input_ids = input_ids.to(device='cuda', non_blocking=True)
96
-
97
- with torch.inference_mode():
98
- output_ids = model.generate(
99
- input_ids,
100
- images=image_tensor.to(dtype=model.dtype, device='cuda', non_blocking=True),
101
- do_sample=True if args.temperature > 0 else False,
102
- temperature=args.temperature,
103
- top_p=args.top_p,
104
- num_beams=args.num_beams,
105
- max_new_tokens=args.max_new_tokens,
106
- use_cache=True)
107
-
108
- input_token_len = input_ids.shape[1]
109
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
110
- if n_diff_input_output > 0:
111
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
112
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
113
- outputs = outputs.strip()
114
-
115
- ans_id = shortuuid.uuid()
116
- ans_file.write(json.dumps({"question_id": idx,
117
- "prompt": cur_prompt,
118
- "text": outputs,
119
- "answer_id": ans_id,
120
- "model_id": model_name,
121
- "metadata": {}}) + "\n")
122
- # ans_file.flush()
123
- ans_file.close()
124
-
125
-
126
- if __name__ == "__main__":
127
- parser = argparse.ArgumentParser()
128
- parser.add_argument("--model-path", type=str, default=None)
129
- parser.add_argument("--model-base", type=str, default=None)
130
- parser.add_argument("--model-type", type=str, default=None)
131
- parser.add_argument("--image-folder", type=str, default=None)
132
- parser.add_argument("--question-file", type=str, default=None)
133
- parser.add_argument("--answers-file", type=str, default=None)
134
- parser.add_argument("--conv-mode", type=str, default=None)
135
- parser.add_argument("--num-chunks", type=int, default=1)
136
- parser.add_argument("--chunk-idx", type=int, default=0)
137
- parser.add_argument("--temperature", type=float, default=0.2)
138
- parser.add_argument("--top_p", type=float, default=None)
139
- parser.add_argument("--num_beams", type=int, default=1)
140
- parser.add_argument("--max_new_tokens", type=int, default=128)
141
- args = parser.parse_args()
142
-
143
- eval_model(args)
bunny/eval/model_vqa_mmbench.py DELETED
@@ -1,167 +0,0 @@
1
- import argparse
2
- import torch
3
- import os
4
- import json
5
- import pandas as pd
6
- from tqdm import tqdm
7
- import shortuuid
8
-
9
- from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
10
- from bunny.conversation import conv_templates, SeparatorStyle
11
- from bunny.model.builder import load_pretrained_model
12
- from bunny.util.utils import disable_torch_init
13
- from bunny.util.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, \
14
- get_model_name_from_path
15
-
16
- import math
17
-
18
- all_options = ['A', 'B', 'C', 'D']
19
-
20
-
21
- def split_list(lst, n):
22
- """Split a list into n (roughly) equal-sized chunks"""
23
- chunk_size = math.ceil(len(lst) / n) # integer division
24
- return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
25
-
26
-
27
- def get_chunk(lst, n, k):
28
- chunks = split_list(lst, n)
29
- return chunks[k]
30
-
31
-
32
- def is_none(value):
33
- if value is None:
34
- return True
35
- if type(value) is float and math.isnan(value):
36
- return True
37
- if type(value) is str and value.lower() == 'nan':
38
- return True
39
- if type(value) is str and value.lower() == 'none':
40
- return True
41
- return False
42
-
43
-
44
- def get_options(row, options):
45
- parsed_options = []
46
- for option in options:
47
- option_value = row[option]
48
- if is_none(option_value):
49
- break
50
- parsed_options.append(option_value)
51
- return parsed_options
52
-
53
-
54
- def eval_model(args):
55
- # Model
56
- disable_torch_init()
57
- model_path = os.path.expanduser(args.model_path)
58
- model_name = get_model_name_from_path(model_path)
59
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
60
- args.model_type)
61
-
62
- questions = pd.read_table(os.path.expanduser(args.question_file))
63
- questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
64
- answers_file = os.path.expanduser(args.answers_file)
65
- os.makedirs(os.path.dirname(answers_file), exist_ok=True)
66
- ans_file = open(answers_file, "w")
67
-
68
- for index, row in tqdm(questions.iterrows(), total=len(questions)):
69
- options = get_options(row, all_options)
70
- cur_option_char = all_options[:len(options)]
71
-
72
- if args.all_rounds:
73
- num_rounds = len(options)
74
- else:
75
- num_rounds = 1
76
-
77
- for round_idx in range(num_rounds):
78
- idx = row['index']
79
- question = row['question']
80
- hint = row['hint']
81
- image = load_image_from_base64(row['image'])
82
- if not is_none(hint):
83
- question = hint + '\n' + question
84
- for option_char, option in zip(all_options[:len(options)], options):
85
- question = question + '\n' + option_char + '. ' + option
86
- qs = cur_prompt = question
87
-
88
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
89
-
90
- if args.single_pred_prompt:
91
- if args.lang == 'cn':
92
- qs = qs + '\n' + "请直接回答选项字母。"
93
- else:
94
- qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
95
-
96
- conv = conv_templates[args.conv_mode].copy()
97
- conv.append_message(conv.roles[0], qs)
98
- conv.append_message(conv.roles[1], None)
99
- prompt = conv.get_prompt()
100
-
101
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(
102
- 0).cuda()
103
-
104
- image_tensor = process_images([image], image_processor, model.config)[0]
105
-
106
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
107
-
108
- with torch.inference_mode():
109
- output_ids = model.generate(
110
- input_ids,
111
- images=image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
112
- do_sample=True if args.temperature > 0 else False,
113
- temperature=args.temperature,
114
- top_p=args.top_p,
115
- num_beams=args.num_beams,
116
- # no_repeat_ngram_size=3,
117
- max_new_tokens=128,
118
- use_cache=True)
119
-
120
- input_token_len = input_ids.shape[1]
121
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
122
- if n_diff_input_output > 0:
123
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
124
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
125
- outputs = outputs.strip()
126
- if outputs.endswith(stop_str):
127
- outputs = outputs[:-len(stop_str)]
128
- outputs = outputs.strip()
129
-
130
- ans_id = shortuuid.uuid()
131
- ans_file.write(json.dumps({"question_id": idx,
132
- "round_id": round_idx,
133
- "prompt": cur_prompt,
134
- "text": outputs,
135
- "options": options,
136
- "option_char": cur_option_char,
137
- "answer_id": ans_id,
138
- "model_id": model_name,
139
- "metadata": {}}) + "\n")
140
- ans_file.flush()
141
-
142
- # rotate options
143
- options = options[1:] + options[:1]
144
- cur_option_char = cur_option_char[1:] + cur_option_char[:1]
145
- ans_file.close()
146
-
147
-
148
- if __name__ == "__main__":
149
- parser = argparse.ArgumentParser()
150
- parser.add_argument("--model-path", type=str, default=None)
151
- parser.add_argument("--model-base", type=str, default=None)
152
- parser.add_argument("--model-type", type=str, default=None)
153
- parser.add_argument("--image-folder", type=str, default=None)
154
- parser.add_argument("--question-file", type=str, default=None)
155
- parser.add_argument("--answers-file", type=str, default=None)
156
- parser.add_argument("--conv-mode", type=str, default=None)
157
- parser.add_argument("--num-chunks", type=int, default=1)
158
- parser.add_argument("--chunk-idx", type=int, default=0)
159
- parser.add_argument("--temperature", type=float, default=0.2)
160
- parser.add_argument("--top_p", type=float, default=None)
161
- parser.add_argument("--num_beams", type=int, default=1)
162
- parser.add_argument("--all-rounds", action="store_true")
163
- parser.add_argument("--single-pred-prompt", action="store_true")
164
- parser.add_argument("--lang", type=str, default="en")
165
- args = parser.parse_args()
166
-
167
- eval_model(args)
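
With --all-rounds, model_vqa_mmbench.py above asks each question len(options) times, circularly shifting the candidate answers so every answer appears under a different letter while cur_option_char keeps track of the original letters for later answer mapping. A compact illustration of that rotation:

```python
# The rotation behind --all-rounds in model_vqa_mmbench.py above.
all_options = ['A', 'B', 'C', 'D']
options = ["red", "green", "blue", "yellow"]
cur_option_char = all_options[:len(options)]

for round_idx in range(len(options)):
    shown = [f"{c}. {o}" for c, o in zip(all_options, options)]
    print(round_idx, shown, "original letters:", cur_option_char)
    # rotate options and their original-letter record by one slot
    options = options[1:] + options[:1]
    cur_option_char = cur_option_char[1:] + cur_option_char[:1]
```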
bunny/eval/model_vqa_mmmu.py DELETED
@@ -1,326 +0,0 @@
1
- import re
2
- import random
3
- import numpy as np
4
- import os
5
- import json
6
- import yaml
7
- import torch
8
-
9
- from tqdm import tqdm
10
- from datasets import load_dataset, concatenate_datasets
11
- from argparse import ArgumentParser
12
-
13
- from bunny.model.builder import load_pretrained_model
14
- from bunny.util.mm_utils import get_model_name_from_path, tokenizer_image_token, process_images
15
- from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
16
- from bunny.conversation import conv_templates
17
-
18
- CAT_SHORT2LONG = {
19
- 'acc': 'Accounting',
20
- 'agri': 'Agriculture',
21
- 'arch': 'Architecture_and_Engineering',
22
- 'art': 'Art',
23
- 'art_theory': 'Art_Theory',
24
- 'bas_med': 'Basic_Medical_Science',
25
- 'bio': 'Biology',
26
- 'chem': 'Chemistry',
27
- 'cli_med': 'Clinical_Medicine',
28
- 'cs': 'Computer_Science',
29
- 'design': 'Design',
30
- 'diag_med': 'Diagnostics_and_Laboratory_Medicine',
31
- 'econ': 'Economics',
32
- 'elec': 'Electronics',
33
- 'ep': 'Energy_and_Power',
34
- 'fin': 'Finance',
35
- 'geo': 'Geography',
36
- 'his': 'History',
37
- 'liter': 'Literature',
38
- 'manage': 'Manage',
39
- 'mark': 'Marketing',
40
- 'mate': 'Materials',
41
- 'math': 'Math',
42
- 'mech': 'Mechanical_Engineering',
43
- 'music': 'Music',
44
- 'phar': 'Pharmacy',
45
- 'phys': 'Physics',
46
- 'psy': 'Psychology',
47
- 'pub_health': 'Public_Health',
48
- 'socio': 'Sociology'
49
- }
50
-
51
-
52
- # ----------- Process Multi-choice -------------
53
- def parse_multi_choice_response(response, all_choices, index2ans):
54
- """
55
- Parse the prediction from the generated response.
56
- Return the predicted index e.g., A, B, C, D.
57
- """
58
- for char in [',', '.', '!', '?', ';', ':', "'"]:
59
- response = response.strip(char)
60
- response = " " + response + " " # add space to avoid partial match
61
-
62
- index_ans = True
63
- ans_with_brack = False
64
- candidates = []
65
- for choice in all_choices: # e.g., (A) (B) (C) (D)
66
- if f'({choice})' in response:
67
- candidates.append(choice)
68
- ans_with_brack = True
69
-
70
- if len(candidates) == 0:
71
- for choice in all_choices: # e.g., A B C D
72
- if f' {choice} ' in response:
73
- candidates.append(choice)
74
-
75
- # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
76
- if len(candidates) == 0 and len(response.split()) > 5:
77
- for index, ans in index2ans.items():
78
- if ans.lower() in response.lower():
79
- candidates.append(index)
80
- index_ans = False # it's content ans.
81
-
82
- if len(candidates) == 0: # still not get answer, randomly choose one.
83
- pred_index = random.choice(all_choices)
84
- elif len(candidates) > 1:
85
- start_indexes = []
86
- if index_ans:
87
- if ans_with_brack:
88
- for can in candidates:
89
- index = response.rfind(f'({can})')
90
- start_indexes.append(index) # -1 will be ignored anyway
91
- # start_indexes = [generated_response.index(f'({can})') for can in candidates]
92
- else:
93
- for can in candidates:
94
- index = response.rfind(f" {can} ")
95
- start_indexes.append(index)
96
- else:
97
- for can in candidates:
98
- index = response.lower().rfind(index2ans[can].lower())
99
- start_indexes.append(index)
100
- # get the last one
101
- pred_index = candidates[np.argmax(start_indexes)]
102
- else: # if only one candidate, use it.
103
- pred_index = candidates[0]
104
-
105
- return pred_index
106
-
107
-
108
- def call_bunny_engine_df(args, sample, model, tokenizer=None, processor=None):
109
- def deal_with_prompt(input_text):
110
- qs = input_text
111
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
112
- return qs
113
-
114
- prompt = sample['final_input_prompt']
115
- prompt = deal_with_prompt(prompt)
116
-
117
- conv = conv_templates[args.conv_mode].copy()
118
- conv.append_message(conv.roles[0], prompt)
119
- conv.append_message(conv.roles[1], None)
120
- prompt = conv.get_prompt()
121
-
122
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
123
-
124
- image = sample['image']
125
- if image is not None:
126
- output_ids = model.generate(
127
- input_ids,
128
- images=image.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
129
- do_sample=False,
130
- temperature=0,
131
- top_p=None,
132
- # num_beams=5,
133
- max_new_tokens=128,
134
- use_cache=True)
135
-
136
- input_token_len = input_ids.shape[1]
137
- # n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
138
- # if n_diff_input_output > 0:
139
- # print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
140
- response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
141
- else: # multiple images actually
142
- if sample['question_type'] == 'multiple-choice':
143
- all_choices = sample['all_choices']
144
- response = random.choice(all_choices)
145
- else:
146
- response = 'INVALID GENERATION FOR MULTIPLE IMAGE INPUTS'
147
-
148
- return response
149
-
150
-
151
- def load_yaml(file_path):
152
- with open(file_path, 'r') as stream:
153
- try:
154
- yaml_dict = yaml.safe_load(stream)
155
- except yaml.YAMLError as exc:
156
- print(exc)
157
-
158
- return yaml_dict
159
-
160
-
161
- def parse_img_path(text):
162
- matches = re.findall("<img='(.*?)'>", text)
163
- return matches
164
-
165
-
166
- def process_single_sample(data):
167
- question = data['question']
168
- o_imgs_paths = []
169
- for option in data['options']:
170
- current_o_imgs_paths = parse_img_path(option)
171
- for img_path in current_o_imgs_paths:
172
- o_imgs_paths.append(img_path)
173
-
174
- if len(o_imgs_paths) > 1: # multiple images in options, used for random selection
175
- return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
176
- 'image': None, 'question_type': data['question_type']}
177
- else:
178
- return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
179
- 'image': data['image_1'], 'question_type': data['question_type']}
180
-
181
-
182
- # DATA PROCESSING
183
- def construct_prompt(sample, config):
184
- question = sample['question']
185
- options = eval(sample['options'])
186
- example = ""
187
- if sample['question_type'] == 'multiple-choice':
188
- start_chr = 'A'
189
- prediction_range = []
190
- index2ans = {}
191
- for option in options:
192
- prediction_range.append(start_chr)
193
- example += f"({start_chr}) {option}\n"
194
- index2ans[start_chr] = option
195
- start_chr = chr(ord(start_chr) + 1)
196
- empty_prompt_sample_structure = config['multi_choice_example_format']
197
- empty_prompt = empty_prompt_sample_structure.format(question, example)
198
- res_dict = {}
199
- res_dict['index2ans'] = index2ans
200
- res_dict['correct_choice'] = sample['answer']
201
- res_dict['all_choices'] = prediction_range
202
- res_dict['empty_prompt'] = empty_prompt
203
- if config['task_instructions']:
204
- res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
205
- else:
206
- res_dict['final_input_prompt'] = empty_prompt
207
-
208
- res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
209
- else:
210
- empty_prompt_sample_structure = config['short_ans_example_format']
211
- empty_prompt = empty_prompt_sample_structure.format(question)
212
- res_dict = {}
213
- res_dict['empty_prompt'] = empty_prompt
214
- if config['task_instructions']:
215
- res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
216
- else:
217
- res_dict['final_input_prompt'] = empty_prompt
218
- res_dict['gt_content'] = sample['answer']
219
-
220
- res_dict.update(sample)
221
- return res_dict
222
-
223
-
224
- def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None):
225
- out_samples = dict()
226
- with torch.no_grad():
227
- for sample in tqdm(samples):
228
- if args.small_gpu_usage:
229
- sample['image'] = sample['image'].cuda()
230
- response = call_model_engine_fn(args, sample, model, tokenizer, processor)
231
- if args.small_gpu_usage:
232
- sample['image'] = sample['image'].cpu()
233
-
234
- if sample['question_type'] == 'multiple-choice':
235
- pred_ans = parse_multi_choice_response(response, sample['all_choices'], sample['index2ans'])
236
- else: # open question
237
- pred_ans = response
238
- out_samples[sample['id']] = pred_ans
239
- return out_samples
240
-
241
-
242
- def set_seed(seed_value):
243
- """
244
- Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results.
245
-
246
- :param seed_value: An integer value to be used as the seed.
247
- """
248
- torch.manual_seed(seed_value)
249
- if torch.cuda.is_available():
250
- torch.cuda.manual_seed(seed_value)
251
- torch.cuda.manual_seed_all(seed_value) # For multi-GPU setups
252
- random.seed(seed_value)
253
- np.random.seed(seed_value)
254
- torch.backends.cudnn.deterministic = True
255
- torch.backends.cudnn.benchmark = False
256
-
257
-
258
- def main():
259
- parser = ArgumentParser()
260
- parser.add_argument('--model-path', type=str, default=None)
261
- parser.add_argument('--model-base', type=str, default=None)
262
- parser.add_argument("--model-type", type=str, default=None)
263
- parser.add_argument("--conv-mode", type=str, default=None)
264
- parser.add_argument('--data-path', type=str, default=None)
265
- parser.add_argument('--config-path', type=str, default=None)
266
- parser.add_argument('--output-path', type=str, default=None)
267
- parser.add_argument('--split', type=str, default='validation')
268
- parser.add_argument('--seed', type=int, default=42)
269
- parser.add_argument("--small-gpu-usage", action="store_true")
270
-
271
- args = parser.parse_args()
272
- device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
273
- set_seed(args.seed)
274
-
275
- print('bunny_initializing...')
276
- processor = None
277
- call_model_engine = call_bunny_engine_df
278
-
279
- # load config and process to one value
280
- args.config = load_yaml(args.config_path)
281
- for key, value in args.config.items():
282
- if key != 'eval_params' and type(value) == list:
283
- assert len(value) == 1, 'key {} has more than one value'.format(key)
284
- args.config[key] = value[0]
285
-
286
- # run for each subject
287
- sub_dataset_list = []
288
- for subject in CAT_SHORT2LONG.values():
289
- sub_dataset = load_dataset(args.data_path, subject, split=args.split)
290
- sub_dataset_list.append(sub_dataset)
291
-
292
- # merge all dataset
293
- dataset = concatenate_datasets(sub_dataset_list)
294
-
295
- # load model
296
- model_path = os.path.expanduser(args.model_path)
297
- model_name = get_model_name_from_path(model_path)
298
- tokenizer, model, vis_processors, context_len = load_pretrained_model(model_path, args.model_base, model_name,
299
- args.model_type)
300
-
301
- samples = []
302
- print('Processing MMMU dataset...')
303
- for sample in tqdm(dataset):
304
- sample = process_single_sample(sample)
305
-
306
- sample = construct_prompt(sample, args.config)
307
- if sample['image']:
308
- sample['image'] = process_images([sample['image'].convert('RGB')], vis_processors, model.config)[0]
309
-
310
- if not args.small_gpu_usage:
311
- sample['image'] = sample['image'].to(device)
312
-
313
- samples.append(sample)
314
-
315
- print('Start to evaluate...')
316
- # run ex
317
- out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor)
318
-
319
- os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
320
-
321
- with open(args.output_path, 'w') as f:
322
- json.dump(out_samples, f, indent=4)
323
-
324
-
325
- if __name__ == '__main__':
326
- main()
bunny/eval/model_vqa_science.py DELETED
@@ -1,119 +0,0 @@
- import argparse
- import torch
- import os
- import json
- from tqdm import tqdm
- import shortuuid
-
- from bunny.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
- from bunny.conversation import conv_templates, SeparatorStyle
- from bunny.model.builder import load_pretrained_model
- from bunny.util.utils import disable_torch_init
- from bunny.util.mm_utils import tokenizer_image_token, get_model_name_from_path
-
- from PIL import Image
- import math
-
-
- def split_list(lst, n):
-     """Split a list into n (roughly) equal-sized chunks"""
-     chunk_size = math.ceil(len(lst) / n)  # integer division
-     return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
-
-
- def get_chunk(lst, n, k):
-     chunks = split_list(lst, n)
-     return chunks[k]
-
-
- def eval_model(args):
-     # Model
-     disable_torch_init()
-     model_path = os.path.expanduser(args.model_path)
-     model_name = get_model_name_from_path(model_path)
-     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name,
-                                                                            args.model_type)
-
-     questions = json.load(open(os.path.expanduser(args.question_file), "r"))
-     questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
-     answers_file = os.path.expanduser(args.answers_file)
-     os.makedirs(os.path.dirname(answers_file), exist_ok=True)
-     ans_file = open(answers_file, "w")
-     for i, line in enumerate(tqdm(questions)):
-         idx = line["id"]
-         question = line['conversations'][0]
-         qs = question['value'].replace('<image>', '').strip()
-         cur_prompt = qs
-
-         if 'image' in line:
-             image_file = line["image"]
-             image = Image.open(os.path.join(args.image_folder, image_file))
-             image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
-             images = image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True)
-
-             qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
-             cur_prompt = '<image>' + '\n' + cur_prompt
-         else:
-             images = None
-
-         if args.single_pred_prompt:
-             qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
-             cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
-
-         conv = conv_templates[args.conv_mode].copy()
-         conv.append_message(conv.roles[0], qs)
-         conv.append_message(conv.roles[1], None)
-         prompt = conv.get_prompt()
-
-         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
-
-         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
-
-         with torch.inference_mode():
-             output_ids = model.generate(
-                 input_ids,
-                 images=images,
-                 do_sample=True if args.temperature > 0 else False,
-                 temperature=args.temperature,
-                 max_new_tokens=1024,
-                 use_cache=True
-             )
-
-         input_token_len = input_ids.shape[1]
-         n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
-         if n_diff_input_output > 0:
-             print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
-         outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
-         outputs = outputs.strip()
-         if outputs.endswith(stop_str):
-             outputs = outputs[:-len(stop_str)]
-         outputs = outputs.strip()
-
-         ans_id = shortuuid.uuid()
-         ans_file.write(json.dumps({"question_id": idx,
-                                    "prompt": cur_prompt,
-                                    "text": outputs,
-                                    "answer_id": ans_id,
-                                    "model_id": model_name,
-                                    "metadata": {}}) + "\n")
-         ans_file.flush()
-     ans_file.close()
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-path", type=str, default=None)
-     parser.add_argument("--model-base", type=str, default=None)
-     parser.add_argument("--model-type", type=str, default=None)
-     parser.add_argument("--image-folder", type=str, default=None)
-     parser.add_argument("--question-file", type=str, default=None)
-     parser.add_argument("--answers-file", type=str, default=None)
-     parser.add_argument("--conv-mode", type=str, default=None)
-     parser.add_argument("--num-chunks", type=int, default=1)
-     parser.add_argument("--chunk-idx", type=int, default=0)
-     parser.add_argument("--temperature", type=float, default=0.2)
-     parser.add_argument("--single-pred-prompt", action="store_true")
-
-     args = parser.parse_args()
-
-     eval_model(args)
bunny/model/__init__.py DELETED
@@ -1,6 +0,0 @@
- from .language_model.bunny_phi import BunnyPhiForCausalLM, BunnyPhiConfig
- from .language_model.bunny_stablelm import BunnyStableLMForCausalLM, BunnyStableLMConfig
- from .language_model.bunny_qwen import BunnyQwen2ForCausalLM, BunnyQwen2Config
- from .language_model.bunny_minicpm import BunnyMiniCPMForCausalLM, BunnyMiniCPMConfig
- from .language_model.bunny_llama import BunnyLlamaForCausalLM, BunnyLlamaConfig
- from .language_model.bunny_phi3 import BunnyPhi3ForCausalLM, BunnyPhi3Config
bunny/model/builder.py DELETED
@@ -1,197 +0,0 @@
1
- import os
2
- import warnings
3
- import torch
4
-
5
- from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig, logging
6
-
7
- logging.set_verbosity_error()
8
- warnings.filterwarnings('ignore')
9
-
10
- from bunny.model import *
11
-
12
-
13
- def load_pretrained_model(model_path, model_base, model_name, model_type, load_8bit=False, load_4bit=False,
14
- device_map="auto", device="cuda", **kwargs):
15
- if model_type not in {'phi-1.5', 'phi-2', 'phi-3', 'stablelm-2', 'qwen1.5-1.8b', 'minicpm', 'llama3-8b'}:
16
- raise ValueError(f"Unknown Model Type {model_type}")
17
-
18
- kwargs = {"device_map": device_map, **kwargs}
19
-
20
- if device != "cuda":
21
- kwargs['device_map'] = {"": device}
22
-
23
- if load_8bit:
24
- kwargs['load_in_8bit'] = True
25
- elif load_4bit:
26
- kwargs['load_in_4bit'] = True
27
- kwargs['quantization_config'] = BitsAndBytesConfig(
28
- load_in_4bit=True,
29
- bnb_4bit_compute_dtype=torch.float16,
30
- bnb_4bit_use_double_quant=True,
31
- bnb_4bit_quant_type='nf4'
32
- )
33
- else:
34
- kwargs['torch_dtype'] = torch.float16
35
-
36
- # Load Bunny model
37
- if 'lora' in model_name.lower() and model_base is None:
38
- warnings.warn(
39
- 'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument.')
40
- if 'lora' in model_name.lower() and model_base is not None:
41
- lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
42
-
43
- print('Loading Bunny from base model...')
44
- if model_type == 'phi-1.5' or model_type == 'phi-2':
45
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
46
- model = BunnyPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
47
- config=lora_cfg_pretrained, **kwargs)
48
- elif model_type == 'phi-3':
49
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
50
- model = BunnyPhi3ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
51
- config=lora_cfg_pretrained, **kwargs)
52
- elif model_type == 'stablelm-2':
53
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True, trust_remote_code=True)
54
- model = BunnyStableLMForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
55
- config=lora_cfg_pretrained, **kwargs)
56
- elif model_type == 'qwen1.5-1.8b':
57
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
58
- model = BunnyQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained,
59
- **kwargs)
60
- elif model_type == 'minicpm':
61
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
62
- model = BunnyMiniCPMForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
63
- config=lora_cfg_pretrained,
64
- **kwargs)
65
- elif model_type == 'llama3-8b':
66
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
67
- model = BunnyLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
68
- config=lora_cfg_pretrained,
69
- **kwargs)
70
-
71
- token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
72
- if model.lm_head.weight.shape[0] != token_num:
73
- model.lm_head.weight = torch.nn.Parameter(
74
- torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
75
- model.model.embed_tokens.weight = torch.nn.Parameter(
76
- torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
77
-
78
- print('Loading additional Bunny weights...')
79
- if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
80
- non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
81
- else:
82
- # this is probably from HF Hub
83
- from huggingface_hub import hf_hub_download
84
- def load_from_hf(repo_id, filename, subfolder=None):
85
- cache_file = hf_hub_download(
86
- repo_id=repo_id,
87
- filename=filename,
88
- subfolder=subfolder)
89
- return torch.load(cache_file, map_location='cpu')
90
-
91
- non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
92
-
93
- non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
94
- non_lora_trainables.items()}
95
- if any(k.startswith('model.model.') for k in non_lora_trainables):
96
- non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in
97
- non_lora_trainables.items()}
98
- model.load_state_dict(non_lora_trainables, strict=False)
99
-
100
- from peft import PeftModel
101
- print('Loading LoRA weights...')
102
- model = PeftModel.from_pretrained(model, model_path)
103
- print('Merging LoRA weights...')
104
- model = model.merge_and_unload()
105
- print('Model is loaded...')
106
- elif model_base is not None:
107
- # this may be mm projector only
108
- print('Loading Bunny from base model...')
109
-
110
- cfg_pretrained = AutoConfig.from_pretrained(model_path)
111
- if model_type == 'phi-1.5' or model_type == 'phi-2':
112
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
113
- model = BunnyPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
114
- config=cfg_pretrained, **kwargs)
115
- elif model_type == 'phi-3':
116
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
117
- model = BunnyPhi3ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
118
- config=cfg_pretrained, **kwargs)
119
- elif model_type == 'stablelm-2':
120
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True, trust_remote_code=True)
121
- model = BunnyStableLMForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
122
- config=cfg_pretrained, **kwargs)
123
- elif model_type == 'qwen1.5-1.8b':
124
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
125
- model = BunnyQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
126
- **kwargs)
127
- elif model_type == 'minicpm':
128
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
129
- model = BunnyMiniCPMForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
130
- **kwargs)
131
- elif model_type == 'llama3-8b':
132
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
133
- model = BunnyLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
134
- **kwargs)
135
-
136
- mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
137
- mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
138
- model.load_state_dict(mm_projector_weights, strict=False)
139
- else:
140
- if model_type == 'phi-1.5' or model_type == 'phi-2':
141
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
142
- model = BunnyPhiForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
143
- elif model_type == 'phi-3':
144
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
145
- model = BunnyPhi3ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
146
- elif model_type == 'stablelm-2':
147
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
148
- model = BunnyStableLMForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
149
- elif model_type == 'qwen1.5-1.8b':
150
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
151
- model = BunnyQwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
152
- elif model_type == 'minicpm':
153
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
154
- model = BunnyMiniCPMForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
155
- elif model_type == 'llama3-8b':
156
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
157
- model = BunnyLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
158
-
159
- model.resize_token_embeddings(len(tokenizer))
160
-
161
- vision_tower = model.get_vision_tower()
162
- if not vision_tower.is_loaded:
163
- vision_tower.load_model()
164
-
165
- # if getattr(model.config, "unfreeze_vision_tower", False):
166
- # if 'lora' in model_name.lower():
167
- # assert model_base is not None
168
- # vision_non_lora_trainables = {k[19:]: v for k, v in non_lora_trainables.items() if
169
- # k.startswith('model.vision_tower.')}
170
- # vision_tower.load_state_dict(vision_non_lora_trainables, strict=False)
171
- # else:
172
- # assert model_base is None
173
- # from safetensors.torch import load_file
174
- # vision_weights = {}
175
- # for file_name in os.listdir(model_path):
176
- # if file_name.endswith('safetensors'):
177
- # vision_weights.update(
178
- # {k[19:]: v for k, v in load_file(os.path.join(model_path, file_name)).items() if
179
- # k.startswith('model.vision_tower.')})
180
- # vision_tower.load_state_dict(vision_weights, strict=True)
181
-
182
- vision_tower.to(device=device, dtype=torch.float16)
183
- image_processor = vision_tower.image_processor
184
-
185
- if hasattr(model.config, "max_sequence_length"):
186
- context_len = model.config.max_sequence_length
187
- else:
188
- context_len = 2048
189
-
190
- if model_type == 'llama3-8b':
191
- tokenizer.eos_token_id = 128001
192
- model.generation_config.pad_token_id = tokenizer.eos_token_id
193
-
194
- if model.generation_config.pad_token_id is None:
195
- model.generation_config.pad_token_id = model.generation_config.eos_token_id
196
-
197
- return tokenizer, model, image_processor, context_len
bunny/model/bunny_arch.py DELETED
@@ -1,230 +0,0 @@
1
- from abc import ABC, abstractmethod
2
-
3
- import torch
4
-
5
- from .multimodal_encoder.builder import build_vision_tower
6
- from .multimodal_projector.builder import build_vision_projector
7
-
8
- from bunny.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
9
-
10
-
11
- class BunnyMetaModel:
12
-
13
- def __init__(self, config):
14
- super(BunnyMetaModel, self).__init__(config)
15
-
16
- if hasattr(config, "mm_vision_tower"):
17
- self.vision_tower = build_vision_tower(config, delay_load=False)
18
- # self.vision_tower = build_vision_tower(config, delay_load=not getattr(config, 'continuous_training', False))
19
- if getattr(config, 'continuous_training', False):
20
- config.continuous_training = False
21
- self.mm_projector = build_vision_projector(config)
22
-
23
- def get_vision_tower(self):
24
- vision_tower = getattr(self, 'vision_tower', None)
25
- if type(vision_tower) is list:
26
- vision_tower = vision_tower[0]
27
- return vision_tower
28
-
29
- def initialize_vision_modules(self, model_args):
30
- vision_tower = model_args.vision_tower
31
-
32
- pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
33
-
34
- self.config.mm_vision_tower = vision_tower
35
-
36
- if self.get_vision_tower() is None:
37
- vision_tower = build_vision_tower(model_args)
38
- self.vision_tower = vision_tower
39
- else:
40
- vision_tower = self.vision_tower
41
- vision_tower.load_model()
42
-
43
- self.config.use_mm_proj = True
44
- self.config.mm_projector_type = getattr(model_args, 'mm_projector_type')
45
- self.config.mm_hidden_size = vision_tower.hidden_size
46
-
47
- if getattr(self, 'mm_projector', None) is None:
48
- self.mm_projector = build_vision_projector(self.config)
49
- else:
50
- # In case it is frozen by LoRA
51
- for p in self.mm_projector.parameters():
52
- p.requires_grad = True
53
-
54
- if pretrain_mm_mlp_adapter is not None:
55
- mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
56
-
57
- def get_w(weights, keyword):
58
- return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
59
-
60
- self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
61
-
62
-
63
- class BunnyMetaForCausalLM(ABC):
64
-
65
- @abstractmethod
66
- def get_model(self):
67
- pass
68
-
69
- def get_vision_tower(self):
70
- return self.get_model().get_vision_tower()
71
-
72
- def encode_images(self, images):
73
- image_features = self.get_model().get_vision_tower()(images)
74
- image_features = self.get_model().mm_projector(image_features)
75
- return image_features
76
-
77
- def prepare_inputs_labels_for_multimodal(
78
- self, input_ids, position_ids, attention_mask, past_key_values, labels, images
79
- ):
80
- vision_tower = self.get_vision_tower()
81
- if vision_tower is None or images is None or input_ids.shape[1] == 1:
82
- if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[
83
- 1] == 1:
84
- target_shape = past_key_values[-1][-1].shape[-2] + 1
85
- attention_mask = torch.cat((attention_mask, torch.ones(
86
- (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
87
- dtype=attention_mask.dtype,
88
- device=attention_mask.device
89
- )), dim=1)
90
- position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
91
- return input_ids, position_ids, attention_mask, past_key_values, None, labels
92
-
93
- if type(images) is list or images.ndim == 5:
94
- concat_images = torch.cat([image for image in images], dim=0)
95
- image_features = self.encode_images(concat_images)
96
- split_sizes = [image.shape[0] for image in images]
97
- image_features = torch.split(image_features, split_sizes, dim=0)
98
- image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
99
- else:
100
- image_features = self.encode_images(images).to(self.device)
101
-
102
- # Let's just add dummy tensors if they do not exist,
103
- # it is a headache to deal with None all the time.
104
- # But it is not ideal, and if you have a better idea,
105
- # please open an issue / submit a PR, thanks.
106
- _labels = labels
107
- _position_ids = position_ids
108
- _attention_mask = attention_mask
109
- if attention_mask is None:
110
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
111
- else:
112
- attention_mask = attention_mask.bool()
113
- if position_ids is None:
114
- position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
115
- if labels is None:
116
- labels = torch.full_like(input_ids, IGNORE_INDEX)
117
-
118
- input_ids_temp = input_ids # points to the actual input_ids tensor
119
-
120
- # remove the padding using attention_mask -- TODO: double check
121
- input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
122
- zip(input_ids, attention_mask)]
123
- labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
124
-
125
- # -- TODO: better implementation?
126
- # replace IMAGE_TOKEN_INDEX(-200) with 0 to be compatible with repetition penalty
127
- input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0
128
-
129
- new_input_embeds = []
130
- new_labels = []
131
- cur_image_idx = 0
132
- for batch_idx, cur_input_ids in enumerate(input_ids):
133
- num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
134
- if num_images == 0:
135
- cur_image_features = image_features[cur_image_idx]
136
- cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
137
- cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
138
- new_input_embeds.append(cur_input_embeds)
139
- new_labels.append(labels[batch_idx])
140
- cur_image_idx += 1
141
- continue
142
-
143
- image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
144
- cur_input_ids.shape[0]]
145
- cur_input_ids_noim = []
146
- cur_labels = labels[batch_idx]
147
- cur_labels_noim = []
148
- for i in range(len(image_token_indices) - 1):
149
- cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
150
- cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
151
- split_sizes = [x.shape[0] for x in cur_labels_noim]
152
- cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
153
- cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
154
- cur_new_input_embeds = []
155
- cur_new_labels = []
156
-
157
- for i in range(num_images + 1):
158
- cur_new_input_embeds.append(cur_input_embeds_no_im[i])
159
- cur_new_labels.append(cur_labels_noim[i])
160
- if i < num_images:
161
- cur_image_features = image_features[cur_image_idx]
162
- cur_image_idx += 1
163
- cur_new_input_embeds.append(cur_image_features)
164
- cur_new_labels.append(
165
- torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device,
166
- dtype=cur_labels.dtype))
167
-
168
- cur_new_input_embeds = torch.cat(cur_new_input_embeds)
169
- cur_new_labels = torch.cat(cur_new_labels)
170
-
171
- new_input_embeds.append(cur_new_input_embeds)
172
- new_labels.append(cur_new_labels)
173
-
174
- # Truncate sequences to max length as image embeddings can make the sequence longer
175
- tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
176
- if tokenizer_model_max_length is not None:
177
- new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
178
- new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
179
-
180
- # Combine them
181
- max_len = max(x.shape[0] for x in new_input_embeds)
182
- batch_size = len(new_input_embeds)
183
-
184
- new_input_embeds_padded = []
185
- new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype,
186
- device=new_labels[0].device)
187
- attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
188
- position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
189
-
190
- for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
191
- cur_len = cur_new_embed.shape[0]
192
- if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
193
- new_input_embeds_padded.append(torch.cat((
194
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
195
- device=cur_new_embed.device),
196
- cur_new_embed
197
- ), dim=0))
198
- if cur_len > 0:
199
- new_labels_padded[i, -cur_len:] = cur_new_labels
200
- attention_mask[i, -cur_len:] = True
201
- position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype,
202
- device=position_ids.device)
203
- else:
204
- new_input_embeds_padded.append(torch.cat((
205
- cur_new_embed,
206
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
207
- device=cur_new_embed.device)
208
- ), dim=0))
209
- if cur_len > 0:
210
- new_labels_padded[i, :cur_len] = cur_new_labels
211
- attention_mask[i, :cur_len] = True
212
- position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype,
213
- device=position_ids.device)
214
-
215
- new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
216
-
217
- if _labels is None:
218
- new_labels = None
219
- else:
220
- new_labels = new_labels_padded
221
-
222
- if _attention_mask is None:
223
- attention_mask = None
224
- else:
225
- attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
226
-
227
- if _position_ids is None:
228
- position_ids = None
229
-
230
- return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
bunny/model/language_model/bunny_llama.py DELETED
@@ -1,102 +0,0 @@
1
- from typing import List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
- from transformers import AutoConfig, AutoModelForCausalLM
6
-
7
- from .llama import LlamaModel, LlamaConfig, LlamaForCausalLM
8
-
9
- from transformers.modeling_outputs import CausalLMOutputWithPast
10
-
11
- from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
-
13
-
14
- class BunnyLlamaConfig(LlamaConfig):
15
- model_type = "bunny-llama"
16
-
17
-
18
- class BunnyLlamaModel(BunnyMetaModel, LlamaModel):
19
- config_class = BunnyLlamaConfig
20
-
21
- def __init__(self, config: LlamaConfig):
22
- super(BunnyLlamaModel, self).__init__(config)
23
-
24
-
25
- class BunnyLlamaForCausalLM(LlamaForCausalLM, BunnyMetaForCausalLM):
26
- config_class = BunnyLlamaConfig
27
-
28
- def __init__(self, config):
29
- super(LlamaForCausalLM, self).__init__(config)
30
- self.model = BunnyLlamaModel(config)
31
- self.vocab_size = config.vocab_size
32
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
33
-
34
- # Initialize weights and apply final processing
35
- self.post_init()
36
-
37
- def get_model(self):
38
- return self.model
39
-
40
- def forward(
41
- self,
42
- input_ids: torch.LongTensor = None,
43
- attention_mask: Optional[torch.Tensor] = None,
44
- position_ids: Optional[torch.LongTensor] = None,
45
- past_key_values: Optional[List[torch.FloatTensor]] = None,
46
- inputs_embeds: Optional[torch.FloatTensor] = None,
47
- labels: Optional[torch.LongTensor] = None,
48
- use_cache: Optional[bool] = None,
49
- output_attentions: Optional[bool] = None,
50
- output_hidden_states: Optional[bool] = None,
51
- images: Optional[torch.FloatTensor] = None,
52
- return_dict: Optional[bool] = None,
53
- cache_position: Optional[torch.LongTensor] = None,
54
- ) -> Union[Tuple, CausalLMOutputWithPast]:
55
- if inputs_embeds is None:
56
- (
57
- input_ids,
58
- position_ids,
59
- attention_mask,
60
- past_key_values,
61
- inputs_embeds,
62
- labels
63
- ) = self.prepare_inputs_labels_for_multimodal(
64
- input_ids,
65
- position_ids,
66
- attention_mask,
67
- past_key_values,
68
- labels,
69
- images
70
- )
71
-
72
- return super().forward(
73
- input_ids=input_ids,
74
- attention_mask=attention_mask,
75
- position_ids=position_ids,
76
- past_key_values=past_key_values,
77
- inputs_embeds=inputs_embeds,
78
- labels=labels,
79
- use_cache=use_cache,
80
- output_attentions=output_attentions,
81
- output_hidden_states=output_hidden_states,
82
- return_dict=return_dict,
83
- cache_position=None
84
- )
85
-
86
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
87
- **kwargs):
88
- images = kwargs.pop("images", None)
89
-
90
- _inputs = super().prepare_inputs_for_generation(
91
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
92
- **kwargs
93
- )
94
-
95
- if images is not None:
96
- _inputs['images'] = images
97
-
98
- return _inputs
99
-
100
-
101
- AutoConfig.register("bunny-llama", BunnyLlamaConfig)
102
- AutoModelForCausalLM.register(BunnyLlamaConfig, BunnyLlamaForCausalLM)
bunny/model/language_model/bunny_minicpm.py DELETED
@@ -1,103 +0,0 @@
1
- from typing import List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
- from transformers import AutoConfig, AutoModelForCausalLM
6
-
7
- from bunny.model.language_model.minicpm.modeling_minicpm import MiniCPMModel, MiniCPMForCausalLM
8
- from bunny.model.language_model.minicpm.configuration_minicpm import MiniCPMConfig
9
-
10
- from transformers.modeling_outputs import CausalLMOutputWithPast
11
-
12
- from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
13
-
14
-
15
- class BunnyMiniCPMConfig(MiniCPMConfig):
16
- model_type = "bunny-minicpm"
17
-
18
-
19
- class BunnyMiniCPMModel(BunnyMetaModel, MiniCPMModel):
20
- config_class = BunnyMiniCPMConfig
21
-
22
- def __init__(self, config: MiniCPMConfig):
23
- super(BunnyMiniCPMModel, self).__init__(config)
24
-
25
-
26
- class BunnyMiniCPMForCausalLM(MiniCPMForCausalLM, BunnyMetaForCausalLM):
27
- config_class = BunnyMiniCPMConfig
28
-
29
- def __init__(self, config):
30
- super(MiniCPMForCausalLM, self).__init__(config)
31
- self.model = BunnyMiniCPMModel(config)
32
- self.vocab_size = config.vocab_size
33
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
34
-
35
- # Initialize weights and apply final processing
36
- self.post_init()
37
-
38
- def get_model(self):
39
- return self.model
40
-
41
- def forward(
42
- self,
43
- input_ids: torch.LongTensor = None,
44
- attention_mask: Optional[torch.Tensor] = None,
45
- position_ids: Optional[torch.LongTensor] = None,
46
- past_key_values: Optional[List[torch.FloatTensor]] = None,
47
- inputs_embeds: Optional[torch.FloatTensor] = None,
48
- labels: Optional[torch.LongTensor] = None,
49
- use_cache: Optional[bool] = None,
50
- output_attentions: Optional[bool] = None,
51
- output_hidden_states: Optional[bool] = None,
52
- images: Optional[torch.FloatTensor] = None,
53
- return_dict: Optional[bool] = None,
54
- ) -> Union[Tuple, CausalLMOutputWithPast]:
55
-
56
- if inputs_embeds is None:
57
- (
58
- input_ids,
59
- position_ids,
60
- attention_mask,
61
- past_key_values,
62
- inputs_embeds,
63
- labels
64
- ) = self.prepare_inputs_labels_for_multimodal(
65
- input_ids,
66
- position_ids,
67
- attention_mask,
68
- past_key_values,
69
- labels,
70
- images
71
- )
72
- if inputs_embeds is not None:
73
- inputs_embeds *= self.get_model().config.scale_emb
74
-
75
- return super().forward(
76
- input_ids=input_ids,
77
- attention_mask=attention_mask,
78
- position_ids=position_ids,
79
- past_key_values=past_key_values,
80
- inputs_embeds=inputs_embeds,
81
- labels=labels,
82
- use_cache=use_cache,
83
- output_attentions=output_attentions,
84
- output_hidden_states=output_hidden_states,
85
- return_dict=return_dict
86
- )
87
-
88
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
89
- **kwargs):
90
- images = kwargs.pop("images", None)
91
-
92
- _inputs = super().prepare_inputs_for_generation(
93
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
94
- **kwargs
95
- )
96
-
97
- if images is not None:
98
- _inputs['images'] = images
99
- return _inputs
100
-
101
-
102
- AutoConfig.register("bunny-minicpm", BunnyMiniCPMConfig)
103
- AutoModelForCausalLM.register(BunnyMiniCPMConfig, BunnyMiniCPMForCausalLM)
bunny/model/language_model/bunny_phi.py DELETED
@@ -1,100 +0,0 @@
1
- from typing import List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
- from transformers import AutoConfig, AutoModelForCausalLM
6
-
7
- from .phi import PhiModel, PhiConfig, PhiForCausalLM
8
-
9
- from transformers.modeling_outputs import CausalLMOutputWithPast
10
-
11
- from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
-
13
-
14
- class BunnyPhiConfig(PhiConfig):
15
- model_type = "bunny-phi"
16
-
17
-
18
- class BunnyPhiModel(BunnyMetaModel, PhiModel):
19
- config_class = BunnyPhiConfig
20
-
21
- def __init__(self, config: PhiConfig):
22
- super(BunnyPhiModel, self).__init__(config)
23
-
24
-
25
- class BunnyPhiForCausalLM(PhiForCausalLM, BunnyMetaForCausalLM):
26
- config_class = BunnyPhiConfig
27
-
28
- def __init__(self, config):
29
- super(PhiForCausalLM, self).__init__(config)
30
- self.model = BunnyPhiModel(config)
31
- self.vocab_size = config.vocab_size
32
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
33
-
34
- # Initialize weights and apply final processing
35
- self.post_init()
36
-
37
- def get_model(self):
38
- return self.model
39
-
40
- def forward(
41
- self,
42
- input_ids: torch.LongTensor = None,
43
- attention_mask: Optional[torch.Tensor] = None,
44
- position_ids: Optional[torch.LongTensor] = None,
45
- past_key_values: Optional[List[torch.FloatTensor]] = None,
46
- inputs_embeds: Optional[torch.FloatTensor] = None,
47
- labels: Optional[torch.LongTensor] = None,
48
- use_cache: Optional[bool] = None,
49
- output_attentions: Optional[bool] = None,
50
- output_hidden_states: Optional[bool] = None,
51
- images: Optional[torch.FloatTensor] = None,
52
- return_dict: Optional[bool] = None,
53
- ) -> Union[Tuple, CausalLMOutputWithPast]:
54
-
55
- if inputs_embeds is None:
56
- (
57
- input_ids,
58
- position_ids,
59
- attention_mask,
60
- past_key_values,
61
- inputs_embeds,
62
- labels
63
- ) = self.prepare_inputs_labels_for_multimodal(
64
- input_ids,
65
- position_ids,
66
- attention_mask,
67
- past_key_values,
68
- labels,
69
- images
70
- )
71
-
72
- return super().forward(
73
- input_ids=input_ids,
74
- attention_mask=attention_mask,
75
- position_ids=position_ids,
76
- past_key_values=past_key_values,
77
- inputs_embeds=inputs_embeds,
78
- labels=labels,
79
- use_cache=use_cache,
80
- output_attentions=output_attentions,
81
- output_hidden_states=output_hidden_states,
82
- return_dict=return_dict
83
- )
84
-
85
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
86
- **kwargs):
87
- images = kwargs.pop("images", None)
88
-
89
- _inputs = super().prepare_inputs_for_generation(
90
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
91
- **kwargs
92
- )
93
-
94
- if images is not None:
95
- _inputs['images'] = images
96
- return _inputs
97
-
98
-
99
- AutoConfig.register("bunny-phi", BunnyPhiConfig)
100
- AutoModelForCausalLM.register(BunnyPhiConfig, BunnyPhiForCausalLM)
bunny/model/language_model/bunny_phi3.py DELETED
@@ -1,100 +0,0 @@
1
- from typing import List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
- from transformers import AutoConfig, AutoModelForCausalLM
6
-
7
- from .phi3 import Phi3Model, Phi3Config, Phi3ForCausalLM
8
-
9
- from transformers.modeling_outputs import CausalLMOutputWithPast
10
-
11
- from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
-
13
-
14
- class BunnyPhi3Config(Phi3Config):
15
- model_type = "bunny-phi3"
16
-
17
-
18
- class BunnyPhi3Model(BunnyMetaModel, Phi3Model):
19
- config_class = BunnyPhi3Config
20
-
21
- def __init__(self, config: Phi3Config):
22
- super(BunnyPhi3Model, self).__init__(config)
23
-
24
-
25
- class BunnyPhi3ForCausalLM(Phi3ForCausalLM, BunnyMetaForCausalLM):
26
- config_class = BunnyPhi3Config
27
-
28
- def __init__(self, config):
29
- super(Phi3ForCausalLM, self).__init__(config)
30
- self.model = BunnyPhi3Model(config)
31
- self.vocab_size = config.vocab_size
32
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
33
-
34
- # Initialize weights and apply final processing
35
- self.post_init()
36
-
37
- def get_model(self):
38
- return self.model
39
-
40
- def forward(
41
- self,
42
- input_ids: torch.LongTensor = None,
43
- attention_mask: Optional[torch.Tensor] = None,
44
- position_ids: Optional[torch.LongTensor] = None,
45
- past_key_values: Optional[List[torch.FloatTensor]] = None,
46
- inputs_embeds: Optional[torch.FloatTensor] = None,
47
- labels: Optional[torch.LongTensor] = None,
48
- use_cache: Optional[bool] = None,
49
- output_attentions: Optional[bool] = None,
50
- output_hidden_states: Optional[bool] = None,
51
- images: Optional[torch.FloatTensor] = None,
52
- return_dict: Optional[bool] = None,
53
- ) -> Union[Tuple, CausalLMOutputWithPast]:
54
-
55
- if inputs_embeds is None:
56
- (
57
- input_ids,
58
- position_ids,
59
- attention_mask,
60
- past_key_values,
61
- inputs_embeds,
62
- labels
63
- ) = self.prepare_inputs_labels_for_multimodal(
64
- input_ids,
65
- position_ids,
66
- attention_mask,
67
- past_key_values,
68
- labels,
69
- images
70
- )
71
-
72
- return super().forward(
73
- input_ids=input_ids,
74
- attention_mask=attention_mask,
75
- position_ids=position_ids,
76
- past_key_values=past_key_values,
77
- inputs_embeds=inputs_embeds,
78
- labels=labels,
79
- use_cache=use_cache,
80
- output_attentions=output_attentions,
81
- output_hidden_states=output_hidden_states,
82
- return_dict=return_dict
83
- )
84
-
85
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
86
- **kwargs):
87
- images = kwargs.pop("images", None)
88
-
89
- _inputs = super().prepare_inputs_for_generation(
90
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
91
- **kwargs
92
- )
93
-
94
- if images is not None:
95
- _inputs['images'] = images
96
- return _inputs
97
-
98
-
99
- AutoConfig.register("bunny-phi3", BunnyPhi3Config)
100
- AutoModelForCausalLM.register(BunnyPhi3Config, BunnyPhi3ForCausalLM)
bunny/model/language_model/bunny_qwen.py DELETED
@@ -1,100 +0,0 @@
1
- from typing import List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
- from transformers import AutoConfig, AutoModelForCausalLM
6
-
7
- from .qwen2 import Qwen2Model, Qwen2Config, Qwen2ForCausalLM
8
-
9
- from transformers.modeling_outputs import CausalLMOutputWithPast
10
-
11
- from ..bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
12
-
13
-
14
- class BunnyQwen2Config(Qwen2Config):
15
- model_type = "bunny-qwen2"
16
-
17
-
18
- class BunnyQwen2Model(BunnyMetaModel, Qwen2Model):
19
- config_class = BunnyQwen2Config
20
-
21
- def __init__(self, config: Qwen2Config):
22
- super(BunnyQwen2Model, self).__init__(config)
23
-
24
-
25
- class BunnyQwen2ForCausalLM(Qwen2ForCausalLM, BunnyMetaForCausalLM):
26
- config_class = BunnyQwen2Config
27
-
28
- def __init__(self, config):
29
- super(Qwen2ForCausalLM, self).__init__(config)
30
- self.model = BunnyQwen2Model(config)
31
- self.vocab_size = config.vocab_size
32
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
33
-
34
- # Initialize weights and apply final processing
35
- self.post_init()
36
-
37
- def get_model(self):
38
- return self.model
39
-
40
- def forward(
41
- self,
42
- input_ids: torch.LongTensor = None,
43
- attention_mask: Optional[torch.Tensor] = None,
44
- position_ids: Optional[torch.LongTensor] = None,
45
- past_key_values: Optional[List[torch.FloatTensor]] = None,
46
- inputs_embeds: Optional[torch.FloatTensor] = None,
47
- labels: Optional[torch.LongTensor] = None,
48
- use_cache: Optional[bool] = None,
49
- output_attentions: Optional[bool] = None,
50
- output_hidden_states: Optional[bool] = None,
51
- images: Optional[torch.FloatTensor] = None,
52
- return_dict: Optional[bool] = None,
53
- ) -> Union[Tuple, CausalLMOutputWithPast]:
54
-
55
- if inputs_embeds is None:
56
- (
57
- input_ids,
58
- position_ids,
59
- attention_mask,
60
- past_key_values,
61
- inputs_embeds,
62
- labels
63
- ) = self.prepare_inputs_labels_for_multimodal(
64
- input_ids,
65
- position_ids,
66
- attention_mask,
67
- past_key_values,
68
- labels,
69
- images
70
- )
71
-
72
- return super().forward(
73
- input_ids=input_ids,
74
- attention_mask=attention_mask,
75
- position_ids=position_ids,
76
- past_key_values=past_key_values,
77
- inputs_embeds=inputs_embeds,
78
- labels=labels,
79
- use_cache=use_cache,
80
- output_attentions=output_attentions,
81
- output_hidden_states=output_hidden_states,
82
- return_dict=return_dict
83
- )
84
-
85
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
86
- **kwargs):
87
- images = kwargs.pop("images", None)
88
-
89
- _inputs = super().prepare_inputs_for_generation(
90
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
91
- **kwargs
92
- )
93
-
94
- if images is not None:
95
- _inputs['images'] = images
96
- return _inputs
97
-
98
-
99
- AutoConfig.register("bunny-qwen2", BunnyQwen2Config)
100
- AutoModelForCausalLM.register(BunnyQwen2Config, BunnyQwen2ForCausalLM)
bunny/model/language_model/bunny_stablelm.py DELETED
@@ -1,100 +0,0 @@
1
- from typing import List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
- from transformers import AutoConfig, AutoModelForCausalLM
6
-
7
- from bunny.model.language_model.stable_lm.modeling_stablelm_epoch import StableLMEpochModel, StableLMEpochConfig, \
8
- StableLMEpochForCausalLM
9
-
10
- from transformers.modeling_outputs import CausalLMOutputWithPast
11
-
12
- from bunny.model.bunny_arch import BunnyMetaModel, BunnyMetaForCausalLM
13
-
14
-
15
- class BunnyStableLMConfig(StableLMEpochConfig):
16
- model_type = "bunny-stablelm"
17
-
18
-
19
- class BunnyStableLMModel(BunnyMetaModel, StableLMEpochModel):
20
- config_class = BunnyStableLMConfig
21
-
22
- def __init__(self, config: StableLMEpochConfig):
23
- super(BunnyStableLMModel, self).__init__(config)
24
-
25
-
26
- class BunnyStableLMForCausalLM(StableLMEpochForCausalLM, BunnyMetaForCausalLM):
27
- config_class = BunnyStableLMConfig
28
-
29
- def __init__(self, config):
30
- super(StableLMEpochForCausalLM, self).__init__(config)
31
- self.model = BunnyStableLMModel(config)
32
- self.vocab_size = config.vocab_size
33
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
34
-
35
- # Initialize weights and apply final processing
36
- self.post_init()
37
-
38
- def get_model(self):
39
- return self.model
40
-
41
- def forward(
42
- self,
43
- input_ids: torch.LongTensor = None,
44
- attention_mask: Optional[torch.Tensor] = None,
45
- position_ids: Optional[torch.LongTensor] = None,
46
- past_key_values: Optional[List[torch.FloatTensor]] = None,
47
- inputs_embeds: Optional[torch.FloatTensor] = None,
48
- labels: Optional[torch.LongTensor] = None,
49
- use_cache: Optional[bool] = None,
50
- output_attentions: Optional[bool] = None,
51
- output_hidden_states: Optional[bool] = None,
52
- images: Optional[torch.FloatTensor] = None,
53
- return_dict: Optional[bool] = None,
54
- ) -> Union[Tuple, CausalLMOutputWithPast]:
55
- if inputs_embeds is None:
56
- (
57
- input_ids,
58
- position_ids,
59
- attention_mask,
60
- past_key_values,
61
- inputs_embeds,
62
- labels
63
- ) = self.prepare_inputs_labels_for_multimodal(
64
- input_ids,
65
- position_ids,
66
- attention_mask,
67
- past_key_values,
68
- labels,
69
- images
70
- )
71
-
72
- return super().forward(
73
- input_ids=input_ids,
74
- attention_mask=attention_mask,
75
- position_ids=position_ids,
76
- past_key_values=past_key_values,
77
- inputs_embeds=inputs_embeds,
78
- labels=labels,
79
- use_cache=use_cache,
80
- output_attentions=output_attentions,
81
- output_hidden_states=output_hidden_states,
82
- return_dict=return_dict
83
- )
84
-
85
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
86
- **kwargs):
87
- images = kwargs.pop("images", None)
88
-
89
- _inputs = super().prepare_inputs_for_generation(
90
- input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
91
- **kwargs
92
- )
93
-
94
- if images is not None:
95
- _inputs['images'] = images
96
- return _inputs
97
-
98
-
99
- AutoConfig.register("bunny-stablelm", BunnyStableLMConfig)
100
- AutoModelForCausalLM.register(BunnyStableLMConfig, BunnyStableLMForCausalLM)
bunny/model/language_model/llama/__init__.py DELETED
@@ -1,114 +0,0 @@
1
- # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import TYPE_CHECKING
15
-
16
- from transformers.utils import (
17
- OptionalDependencyNotAvailable,
18
- _LazyModule,
19
- is_flax_available,
20
- is_sentencepiece_available,
21
- is_tokenizers_available,
22
- is_torch_available,
23
- )
24
-
25
-
26
- _import_structure = {
27
- "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"],
28
- }
29
-
30
- try:
31
- if not is_sentencepiece_available():
32
- raise OptionalDependencyNotAvailable()
33
- except OptionalDependencyNotAvailable:
34
- pass
35
- else:
36
- _import_structure["tokenization_llama"] = ["LlamaTokenizer"]
37
-
38
- try:
39
- if not is_tokenizers_available():
40
- raise OptionalDependencyNotAvailable()
41
- except OptionalDependencyNotAvailable:
42
- pass
43
- else:
44
- _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"]
45
-
46
- try:
47
- if not is_torch_available():
48
- raise OptionalDependencyNotAvailable()
49
- except OptionalDependencyNotAvailable:
50
- pass
51
- else:
52
- _import_structure["modeling_llama"] = [
53
- "LlamaForCausalLM",
54
- "LlamaModel",
55
- "LlamaPreTrainedModel",
56
- "LlamaForSequenceClassification",
57
- "LlamaForQuestionAnswering",
58
- ]
59
-
60
- try:
61
- if not is_flax_available():
62
- raise OptionalDependencyNotAvailable()
63
- except OptionalDependencyNotAvailable:
64
- pass
65
- else:
66
- _import_structure["modeling_flax_llama"] = ["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"]
67
-
68
-
69
- if TYPE_CHECKING:
70
- from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig
71
-
72
- try:
73
- if not is_sentencepiece_available():
74
- raise OptionalDependencyNotAvailable()
75
- except OptionalDependencyNotAvailable:
76
- pass
77
- else:
78
- from .tokenization_llama import LlamaTokenizer
79
-
80
- try:
81
- if not is_tokenizers_available():
82
- raise OptionalDependencyNotAvailable()
83
- except OptionalDependencyNotAvailable:
84
- pass
85
- else:
86
- from .tokenization_llama_fast import LlamaTokenizerFast
87
-
88
- try:
89
- if not is_torch_available():
90
- raise OptionalDependencyNotAvailable()
91
- except OptionalDependencyNotAvailable:
92
- pass
93
- else:
94
- from .modeling_llama import (
95
- LlamaForCausalLM,
96
- LlamaForQuestionAnswering,
97
- LlamaForSequenceClassification,
98
- LlamaModel,
99
- LlamaPreTrainedModel,
100
- )
101
-
102
- try:
103
- if not is_flax_available():
104
- raise OptionalDependencyNotAvailable()
105
- except OptionalDependencyNotAvailable:
106
- pass
107
- else:
108
- from .modeling_flax_llama import FlaxLlamaForCausalLM, FlaxLlamaModel, FlaxLlamaPreTrainedModel
109
-
110
-
111
- else:
112
- import sys
113
-
114
- sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
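
This deleted `__init__.py` builds its export table conditionally and hands it to `_LazyModule`, so heavy backends are only imported on first attribute access. A minimal sketch of the optional-dependency guard that gates each entry, assuming only that `transformers` is installed (the module names are reused from the listing above purely for illustration):

```python
# Sketch of the guard pattern: an export is added only when its backend is available.
from transformers.utils import OptionalDependencyNotAvailable, is_torch_available

_import_structure = {"configuration_llama": ["LlamaConfig"]}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass  # torch is missing: the modeling classes are simply not exported
else:
    _import_structure["modeling_llama"] = ["LlamaForCausalLM", "LlamaModel"]

print(_import_structure)
```
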
bunny/model/language_model/llama/configuration_llama.py DELETED
@@ -1,191 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- """ LLaMA model configuration"""
21
-
22
- from transformers.configuration_utils import PretrainedConfig
23
- from transformers.utils import logging
24
-
25
-
26
- logger = logging.get_logger(__name__)
27
-
28
-
29
- # from ..deprecated._archive_maps import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP # noqa: F401, E402
30
-
31
-
32
- class LlamaConfig(PretrainedConfig):
33
- r"""
34
- This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
35
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
36
- defaults will yield a similar configuration to that of the LLaMA-7B.
37
-
38
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
- documentation from [`PretrainedConfig`] for more information.
40
-
41
-
42
- Args:
43
- vocab_size (`int`, *optional*, defaults to 32000):
44
- Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
45
- `inputs_ids` passed when calling [`LlamaModel`]
46
- hidden_size (`int`, *optional*, defaults to 4096):
47
- Dimension of the hidden representations.
48
- intermediate_size (`int`, *optional*, defaults to 11008):
49
- Dimension of the MLP representations.
50
- num_hidden_layers (`int`, *optional*, defaults to 32):
51
- Number of hidden layers in the Transformer decoder.
52
- num_attention_heads (`int`, *optional*, defaults to 32):
53
- Number of attention heads for each attention layer in the Transformer decoder.
54
- num_key_value_heads (`int`, *optional*):
55
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
57
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
58
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59
- by meanpooling all the original heads within that group. For more details checkout [this
60
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
61
- `num_attention_heads`.
62
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
63
- The non-linear activation function (function or string) in the decoder.
64
- max_position_embeddings (`int`, *optional*, defaults to 2048):
65
- The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
66
- Llama 2 up to 4096, CodeLlama up to 16384.
67
- initializer_range (`float`, *optional*, defaults to 0.02):
68
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
70
- The epsilon used by the rms normalization layers.
71
- use_cache (`bool`, *optional*, defaults to `True`):
72
- Whether or not the model should return the last key/values attentions (not used by all models). Only
73
- relevant if `config.is_decoder=True`.
74
- pad_token_id (`int`, *optional*):
75
- Padding token id.
76
- bos_token_id (`int`, *optional*, defaults to 1):
77
- Beginning of stream token id.
78
- eos_token_id (`int`, *optional*, defaults to 2):
79
- End of stream token id.
80
- pretraining_tp (`int`, *optional*, defaults to 1):
81
- Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
82
- document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
83
- necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
84
- issue](https://github.com/pytorch/pytorch/issues/76232).
85
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
86
- Whether to tie weight embeddings
87
- rope_theta (`float`, *optional*, defaults to 10000.0):
88
- The base period of the RoPE embeddings.
89
- rope_scaling (`Dict`, *optional*):
90
- Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
91
- strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
92
- `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
93
- `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
94
- these scaling strategies behave:
95
- https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
96
- experimental feature, subject to breaking API changes in future versions.
97
- attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
98
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
99
- attention_dropout (`float`, *optional*, defaults to 0.0):
100
- The dropout ratio for the attention probabilities.
101
-
102
- ```python
103
- >>> from transformers import LlamaModel, LlamaConfig
104
-
105
- >>> # Initializing a LLaMA llama-7b style configuration
106
- >>> configuration = LlamaConfig()
107
-
108
- >>> # Initializing a model from the llama-7b style configuration
109
- >>> model = LlamaModel(configuration)
110
-
111
- >>> # Accessing the model configuration
112
- >>> configuration = model.config
113
- ```"""
114
-
115
- model_type = "llama"
116
- keys_to_ignore_at_inference = ["past_key_values"]
117
-
118
- def __init__(
119
- self,
120
- vocab_size=32000,
121
- hidden_size=4096,
122
- intermediate_size=11008,
123
- num_hidden_layers=32,
124
- num_attention_heads=32,
125
- num_key_value_heads=None,
126
- hidden_act="silu",
127
- max_position_embeddings=2048,
128
- initializer_range=0.02,
129
- rms_norm_eps=1e-6,
130
- use_cache=True,
131
- pad_token_id=None,
132
- bos_token_id=1,
133
- eos_token_id=2,
134
- pretraining_tp=1,
135
- tie_word_embeddings=False,
136
- rope_theta=10000.0,
137
- rope_scaling=None,
138
- attention_bias=False,
139
- attention_dropout=0.0,
140
- **kwargs,
141
- ):
142
- self.vocab_size = vocab_size
143
- self.max_position_embeddings = max_position_embeddings
144
- self.hidden_size = hidden_size
145
- self.intermediate_size = intermediate_size
146
- self.num_hidden_layers = num_hidden_layers
147
- self.num_attention_heads = num_attention_heads
148
-
149
- # for backward compatibility
150
- if num_key_value_heads is None:
151
- num_key_value_heads = num_attention_heads
152
-
153
- self.num_key_value_heads = num_key_value_heads
154
- self.hidden_act = hidden_act
155
- self.initializer_range = initializer_range
156
- self.rms_norm_eps = rms_norm_eps
157
- self.pretraining_tp = pretraining_tp
158
- self.use_cache = use_cache
159
- self.rope_theta = rope_theta
160
- self.rope_scaling = rope_scaling
161
- self._rope_scaling_validation()
162
- self.attention_bias = attention_bias
163
- self.attention_dropout = attention_dropout
164
-
165
- super().__init__(
166
- pad_token_id=pad_token_id,
167
- bos_token_id=bos_token_id,
168
- eos_token_id=eos_token_id,
169
- tie_word_embeddings=tie_word_embeddings,
170
- **kwargs,
171
- )
172
-
173
- def _rope_scaling_validation(self):
174
- """
175
- Validate the `rope_scaling` configuration.
176
- """
177
- if self.rope_scaling is None:
178
- return
179
-
180
- if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
181
- raise ValueError(
182
- "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
183
- )
184
- rope_scaling_type = self.rope_scaling.get("type", None)
185
- rope_scaling_factor = self.rope_scaling.get("factor", None)
186
- if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
187
- raise ValueError(
188
- f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
189
- )
190
- if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
191
- raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
bunny/model/language_model/llama/modeling_llama.py DELETED
@@ -1,1844 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- """PyTorch LLaMA model."""
21
-
22
- import math
23
- import warnings
24
- from typing import List, Optional, Tuple, Union
25
-
26
- import torch
27
- import torch.nn.functional as F
28
- import torch.utils.checkpoint
29
- from torch import nn
30
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
-
32
- from transformers.activations import ACT2FN
33
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
34
- # from transformers.modeling_attn_mask_utils import AttentionMaskConverter
35
- from dataclasses import dataclass
36
- @dataclass
37
- class AttentionMaskConverter:
38
- """
39
- A utility attention mask class that allows one to:
40
- - Create a causal 4d mask
41
- - Create a causal 4d mask with slided window
42
- - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
43
- key_value_length) that can be multiplied with attention scores
44
-
45
- Examples:
46
-
47
- ```python
48
- >>> import torch
49
- >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
50
-
51
- >>> converter = AttentionMaskConverter(True)
52
- >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
53
- tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
54
- [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
55
- [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
56
- [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
57
- [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
58
- ```
59
-
60
- Parameters:
61
- is_causal (`bool`):
62
- Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
63
-
64
- sliding_window (`int`, *optional*):
65
- Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
66
- """
67
-
68
- is_causal: bool
69
- sliding_window: int
70
-
71
- def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
72
- self.is_causal = is_causal
73
- self.sliding_window = sliding_window
74
-
75
- if self.sliding_window is not None and self.sliding_window <= 0:
76
- raise ValueError(
77
- f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
78
- )
79
-
80
- def to_causal_4d(
81
- self,
82
- batch_size: int,
83
- query_length: int,
84
- key_value_length: int,
85
- dtype: torch.dtype,
86
- device: Union[torch.device, "str"] = "cpu",
87
- ) -> Optional[torch.Tensor]:
88
- """
89
- Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
90
- bias to upper right hand triangular matrix (causal mask).
91
- """
92
- if not self.is_causal:
93
- raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
94
-
95
- # If shape is not cached, create a new causal mask and cache it
96
- input_shape = (batch_size, query_length)
97
- past_key_values_length = key_value_length - query_length
98
-
99
- # create causal mask
100
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
101
- causal_4d_mask = None
102
- if input_shape[-1] > 1 or self.sliding_window is not None:
103
- causal_4d_mask = self._make_causal_mask(
104
- input_shape,
105
- dtype,
106
- device=device,
107
- past_key_values_length=past_key_values_length,
108
- sliding_window=self.sliding_window,
109
- )
110
-
111
- return causal_4d_mask
112
-
113
- def to_4d(
114
- self,
115
- attention_mask_2d: torch.Tensor,
116
- query_length: int,
117
- dtype: torch.dtype,
118
- key_value_length: Optional[int] = None,
119
- ) -> torch.Tensor:
120
- """
121
- Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
122
- key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
123
- causal, a causal mask will be added.
124
- """
125
- input_shape = (attention_mask_2d.shape[0], query_length)
126
-
127
- # create causal mask
128
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
129
- causal_4d_mask = None
130
- if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
131
- if key_value_length is None:
132
- raise ValueError(
133
- "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
134
- )
135
-
136
- past_key_values_length = key_value_length - query_length
137
- causal_4d_mask = self._make_causal_mask(
138
- input_shape,
139
- dtype,
140
- device=attention_mask_2d.device,
141
- past_key_values_length=past_key_values_length,
142
- sliding_window=self.sliding_window,
143
- )
144
- elif self.sliding_window is not None:
145
- raise NotImplementedError("Sliding window is currently only implemented for causal masking")
146
-
147
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
148
- expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
149
- attention_mask_2d.device
150
- )
151
-
152
- if causal_4d_mask is not None:
153
- expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
154
-
155
- # expanded_attn_mask + causal_4d_mask can cause some overflow
156
- expanded_4d_mask = expanded_attn_mask
157
-
158
- return expanded_4d_mask
159
-
160
- @staticmethod
161
- def _make_causal_mask(
162
- input_ids_shape: torch.Size,
163
- dtype: torch.dtype,
164
- device: torch.device,
165
- past_key_values_length: int = 0,
166
- sliding_window: Optional[int] = None,
167
- ):
168
- """
169
- Make causal mask used for bi-directional self-attention.
170
- """
171
- bsz, tgt_len = input_ids_shape
172
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
173
- mask_cond = torch.arange(mask.size(-1), device=device)
174
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
175
-
176
- mask = mask.to(dtype)
177
-
178
- if past_key_values_length > 0:
179
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
180
-
181
- # add lower triangular sliding window mask if necessary
182
- if sliding_window is not None:
183
- diagonal = past_key_values_length - sliding_window - 1
184
-
185
- context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
186
- mask.masked_fill_(context_mask, torch.finfo(dtype).min)
187
-
188
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
189
-
190
- @staticmethod
191
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
192
- """
193
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
194
- """
195
- bsz, src_len = mask.size()
196
- tgt_len = tgt_len if tgt_len is not None else src_len
197
-
198
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
199
-
200
- inverted_mask = 1.0 - expanded_mask
201
-
202
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
203
-
204
- @staticmethod
205
- def _unmask_unattended(
206
- expanded_mask: torch.FloatTensor,
207
- min_dtype: float,
208
- ):
209
- # fmt: off
210
- """
211
- Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
212
- using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
213
- Details: https://github.com/pytorch/pytorch/issues/110213
214
-
215
- `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
216
- `attention_mask` is [bsz, src_seq_len].
217
-
218
- The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
219
-
220
- For example, if `expanded_mask` is (e.g. here left-padding case)
221
- ```
222
- [[[[0, 0, 0],
223
- [0, 0, 0],
224
- [0, 0, 1]]],
225
- [[[1, 0, 0],
226
- [1, 1, 0],
227
- [1, 1, 1]]],
228
- [[[0, 0, 0],
229
- [0, 1, 0],
230
- [0, 1, 1]]]]
231
- ```
232
- then the modified `expanded_mask` will be
233
- ```
234
- [[[[1, 1, 1], <-- modified
235
- [1, 1, 1], <-- modified
236
- [0, 0, 1]]],
237
- [[[1, 0, 0],
238
- [1, 1, 0],
239
- [1, 1, 1]]],
240
- [[[1, 1, 1], <-- modified
241
- [0, 1, 0],
242
- [0, 1, 1]]]]
243
- ```
244
- """
245
- # fmt: on
246
- if expanded_mask.dtype == torch.bool:
247
- raise ValueError(
248
- "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
249
- )
250
-
251
- return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
252
-
253
- @staticmethod
254
- def _ignore_causal_mask_sdpa(
255
- attention_mask: Optional[torch.Tensor],
256
- inputs_embeds: torch.Tensor,
257
- past_key_values_length: int,
258
- sliding_window: Optional[int] = None,
259
- ) -> bool:
260
- """
261
- Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
262
-
263
- In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
264
- `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
265
- allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
266
- """
267
-
268
- batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
269
- key_value_length = query_length + past_key_values_length
270
-
271
- is_tracing = (
272
- torch.jit.is_tracing()
273
- or isinstance(inputs_embeds, torch.fx.Proxy)
274
- or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
275
- )
276
-
277
- ignore_causal_mask = False
278
-
279
- if attention_mask is None:
280
- # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
281
- # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
282
- # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
283
- #
284
- # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
285
- if (
286
- not is_tracing
287
- and (query_length == 1 or key_value_length == query_length)
288
- and (sliding_window is None or key_value_length < sliding_window)
289
- ):
290
- ignore_causal_mask = True
291
- elif sliding_window is None or key_value_length < sliding_window:
292
- if len(attention_mask.shape) == 4:
293
- expected_shape = (batch_size, 1, query_length, key_value_length)
294
- if tuple(attention_mask.shape) != expected_shape:
295
- raise ValueError(
296
- f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
297
- )
298
- elif not is_tracing and torch.all(attention_mask == 1):
299
- if query_length == 1 or key_value_length == query_length:
300
- # For query_length == 1, causal attention and bi-directional attention are the same.
301
- ignore_causal_mask = True
302
-
303
- # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
304
- # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
305
- # Reference: https://github.com/pytorch/pytorch/issues/108108
306
- # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
307
-
308
- return ignore_causal_mask
309
-
310
-
311
- from transformers.modeling_outputs import (
312
- BaseModelOutputWithPast,
313
- CausalLMOutputWithPast,
314
- QuestionAnsweringModelOutput,
315
- SequenceClassifierOutputWithPast,
316
- )
317
- from transformers.modeling_utils import PreTrainedModel
318
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
319
- from transformers.utils import (
320
- add_start_docstrings,
321
- add_start_docstrings_to_model_forward,
322
- is_flash_attn_2_available,
323
- is_flash_attn_greater_or_equal_2_10,
324
- logging,
325
- replace_return_docstrings,
326
- )
327
- from .configuration_llama import LlamaConfig
328
-
329
-
330
- if is_flash_attn_2_available():
331
- from flash_attn import flash_attn_func, flash_attn_varlen_func
332
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
333
-
334
-
335
- logger = logging.get_logger(__name__)
336
-
337
- _CONFIG_FOR_DOC = "LlamaConfig"
338
-
339
-
340
- def _get_unpad_data(attention_mask):
341
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
342
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
343
- max_seqlen_in_batch = seqlens_in_batch.max().item()
344
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
345
- return (
346
- indices,
347
- cu_seqlens,
348
- max_seqlen_in_batch,
349
- )
350
-
351
-
352
- class LlamaRMSNorm(nn.Module):
353
- def __init__(self, hidden_size, eps=1e-6):
354
- """
355
- LlamaRMSNorm is equivalent to T5LayerNorm
356
- """
357
- super().__init__()
358
- self.weight = nn.Parameter(torch.ones(hidden_size))
359
- self.variance_epsilon = eps
360
-
361
- def forward(self, hidden_states):
362
- input_dtype = hidden_states.dtype
363
- hidden_states = hidden_states.to(torch.float32)
364
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
365
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
366
- return self.weight * hidden_states.to(input_dtype)
367
-
368
-
369
- ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
370
-
371
-
372
- class LlamaRotaryEmbedding(nn.Module):
373
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
374
- super().__init__()
375
- self.scaling_factor = scaling_factor
376
- self.dim = dim
377
- self.max_position_embeddings = max_position_embeddings
378
- self.base = base
379
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
380
- self.register_buffer("inv_freq", inv_freq, persistent=False)
381
- # For BC we register cos and sin cached
382
- self.max_seq_len_cached = max_position_embeddings
383
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
384
- t = t / self.scaling_factor
385
- freqs = torch.outer(t, self.inv_freq)
386
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
387
- emb = torch.cat((freqs, freqs), dim=-1)
388
- self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
389
- self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
390
-
391
- @property
392
- def sin_cached(self):
393
- logger.warning_once(
394
- "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
395
- "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
396
- )
397
- return self._sin_cached
398
-
399
- @property
400
- def cos_cached(self):
401
- logger.warning_once(
402
- "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
403
- "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
404
- )
405
- return self._cos_cached
406
-
407
- @torch.no_grad()
408
- def forward(self, x, position_ids):
409
- # x: [bs, num_attention_heads, seq_len, head_size]
410
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
411
- position_ids_expanded = position_ids[:, None, :].float()
412
- # Force float32 since bfloat16 loses precision on long contexts
413
- # See https://github.com/huggingface/transformers/pull/29285
414
- device_type = x.device.type
415
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
416
- with torch.autocast(device_type=device_type, enabled=False):
417
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
418
- emb = torch.cat((freqs, freqs), dim=-1)
419
- cos = emb.cos()
420
- sin = emb.sin()
421
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
422
-
423
-
424
- class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
425
- """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
426
-
427
- def forward(self, x, position_ids):
428
- # difference to the original RoPE: a scaling factor is aplied to the position ids
429
- position_ids = position_ids.float() / self.scaling_factor
430
- cos, sin = super().forward(x, position_ids)
431
- return cos, sin
432
-
433
-
434
- class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
435
- """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
436
-
437
- def forward(self, x, position_ids):
438
- # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
439
- seq_len = torch.max(position_ids) + 1
440
- if seq_len > self.max_position_embeddings:
441
- base = self.base * (
442
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
443
- ) ** (self.dim / (self.dim - 2))
444
- inv_freq = 1.0 / (
445
- base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
446
- )
447
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
448
-
449
- cos, sin = super().forward(x, position_ids)
450
- return cos, sin
451
-
452
-
453
- def rotate_half(x):
454
- """Rotates half the hidden dims of the input."""
455
- x1 = x[..., : x.shape[-1] // 2]
456
- x2 = x[..., x.shape[-1] // 2 :]
457
- return torch.cat((-x2, x1), dim=-1)
458
-
459
-
460
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
461
- """Applies Rotary Position Embedding to the query and key tensors.
462
-
463
- Args:
464
- q (`torch.Tensor`): The query tensor.
465
- k (`torch.Tensor`): The key tensor.
466
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
467
- sin (`torch.Tensor`): The sine part of the rotary embedding.
468
- position_ids (`torch.Tensor`, *optional*):
469
- Deprecated and unused.
470
- unsqueeze_dim (`int`, *optional*, defaults to 1):
471
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
472
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
473
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
474
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
475
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
476
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
477
- Returns:
478
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
479
- """
480
- cos = cos.unsqueeze(unsqueeze_dim)
481
- sin = sin.unsqueeze(unsqueeze_dim)
482
- q_embed = (q * cos) + (rotate_half(q) * sin)
483
- k_embed = (k * cos) + (rotate_half(k) * sin)
484
- return q_embed, k_embed
485
-
486
-
487
- class LlamaMLP(nn.Module):
488
- def __init__(self, config):
489
- super().__init__()
490
- self.config = config
491
- self.hidden_size = config.hidden_size
492
- self.intermediate_size = config.intermediate_size
493
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
494
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
495
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
496
- self.act_fn = ACT2FN[config.hidden_act]
497
-
498
- def forward(self, x):
499
- if self.config.pretraining_tp > 1:
500
- slice = self.intermediate_size // self.config.pretraining_tp
501
- gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
502
- up_proj_slices = self.up_proj.weight.split(slice, dim=0)
503
- down_proj_slices = self.down_proj.weight.split(slice, dim=1)
504
-
505
- gate_proj = torch.cat(
506
- [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
507
- )
508
- up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
509
-
510
- intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
511
- down_proj = [
512
- F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
513
- ]
514
- down_proj = sum(down_proj)
515
- else:
516
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
517
-
518
- return down_proj
519
-
520
-
521
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
522
- """
523
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
524
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
525
- """
526
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
527
- if n_rep == 1:
528
- return hidden_states
529
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
530
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
531
-
532
-
533
- class LlamaAttention(nn.Module):
534
- """Multi-headed attention from 'Attention Is All You Need' paper"""
535
-
536
- def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
537
- super().__init__()
538
- self.config = config
539
- self.layer_idx = layer_idx
540
- if layer_idx is None:
541
- logger.warning_once(
542
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
543
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
544
- "when creating this class."
545
- )
546
-
547
- self.attention_dropout = config.attention_dropout
548
- self.hidden_size = config.hidden_size
549
- self.num_heads = config.num_attention_heads
550
- self.head_dim = self.hidden_size // self.num_heads
551
- self.num_key_value_heads = config.num_key_value_heads
552
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
553
- self.max_position_embeddings = config.max_position_embeddings
554
- self.rope_theta = config.rope_theta
555
- self.is_causal = True
556
-
557
- if (self.head_dim * self.num_heads) != self.hidden_size:
558
- raise ValueError(
559
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
560
- f" and `num_heads`: {self.num_heads})."
561
- )
562
-
563
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
564
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
565
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
566
- self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
567
- self._init_rope()
568
-
569
- def _init_rope(self):
570
- if self.config.rope_scaling is None:
571
- self.rotary_emb = LlamaRotaryEmbedding(
572
- self.head_dim,
573
- max_position_embeddings=self.max_position_embeddings,
574
- base=self.rope_theta,
575
- )
576
- else:
577
- scaling_type = self.config.rope_scaling["type"]
578
- scaling_factor = self.config.rope_scaling["factor"]
579
- if scaling_type == "linear":
580
- self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
581
- self.head_dim,
582
- max_position_embeddings=self.max_position_embeddings,
583
- scaling_factor=scaling_factor,
584
- base=self.rope_theta,
585
- )
586
- elif scaling_type == "dynamic":
587
- self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
588
- self.head_dim,
589
- max_position_embeddings=self.max_position_embeddings,
590
- scaling_factor=scaling_factor,
591
- base=self.rope_theta,
592
- )
593
- else:
594
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
595
-
596
- def forward(
597
- self,
598
- hidden_states: torch.Tensor,
599
- attention_mask: Optional[torch.Tensor] = None,
600
- position_ids: Optional[torch.LongTensor] = None,
601
- past_key_value: Optional[Cache] = None,
602
- output_attentions: bool = False,
603
- use_cache: bool = False,
604
- cache_position: Optional[torch.LongTensor] = None,
605
- **kwargs,
606
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
607
- bsz, q_len, _ = hidden_states.size()
608
-
609
- if self.config.pretraining_tp > 1:
610
- key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
611
- query_slices = self.q_proj.weight.split(
612
- (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
613
- )
614
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
615
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
616
-
617
- query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
618
- query_states = torch.cat(query_states, dim=-1)
619
-
620
- key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
621
- key_states = torch.cat(key_states, dim=-1)
622
-
623
- value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
624
- value_states = torch.cat(value_states, dim=-1)
625
-
626
- else:
627
- query_states = self.q_proj(hidden_states)
628
- key_states = self.k_proj(hidden_states)
629
- value_states = self.v_proj(hidden_states)
630
-
631
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
632
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
633
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
634
-
635
- past_key_value = getattr(self, "past_key_value", past_key_value)
636
- cos, sin = self.rotary_emb(value_states, position_ids)
637
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
638
-
639
- if past_key_value is not None:
640
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
641
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
642
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
643
-
644
- key_states = repeat_kv(key_states, self.num_key_value_groups)
645
- value_states = repeat_kv(value_states, self.num_key_value_groups)
646
-
647
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
648
-
649
- if attention_mask is not None: # no matter the length, we just slice it
650
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
651
- attn_weights = attn_weights + causal_mask
652
-
653
- # upcast attention to fp32
654
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
655
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
656
- attn_output = torch.matmul(attn_weights, value_states)
657
-
658
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
659
- raise ValueError(
660
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
661
- f" {attn_output.size()}"
662
- )
663
-
664
- attn_output = attn_output.transpose(1, 2).contiguous()
665
-
666
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
667
-
668
- if self.config.pretraining_tp > 1:
669
- attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
670
- o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
671
- attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
672
- else:
673
- attn_output = self.o_proj(attn_output)
674
-
675
- if not output_attentions:
676
- attn_weights = None
677
-
678
- return attn_output, attn_weights, past_key_value
679
-
680
-
681
- class LlamaFlashAttention2(LlamaAttention):
682
- """
683
- Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
684
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
685
- flash attention and deal with padding tokens in case the input contains any of them.
686
- """
687
-
688
- def __init__(self, *args, **kwargs):
689
- super().__init__(*args, **kwargs)
690
-
691
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
692
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
693
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
694
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
695
-
696
- def forward(
697
- self,
698
- hidden_states: torch.Tensor,
699
- attention_mask: Optional[torch.LongTensor] = None,
700
- position_ids: Optional[torch.LongTensor] = None,
701
- past_key_value: Optional[Cache] = None,
702
- output_attentions: bool = False,
703
- use_cache: bool = False,
704
- cache_position: Optional[torch.LongTensor] = None,
705
- **kwargs,
706
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
707
- output_attentions = False
708
-
709
- bsz, q_len, _ = hidden_states.size()
710
-
711
- query_states = self.q_proj(hidden_states)
712
- key_states = self.k_proj(hidden_states)
713
- value_states = self.v_proj(hidden_states)
714
-
715
- # Flash attention requires the input to have the shape
716
- # batch_size x seq_length x head_dim x hidden_dim
717
- # therefore we just need to keep the original shape
718
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
719
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
720
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
721
-
722
- cos, sin = self.rotary_emb(value_states, position_ids)
723
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
724
-
725
- past_key_value = getattr(self, "past_key_value", past_key_value)
726
-
727
- if past_key_value is not None:
728
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
729
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
730
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
731
-
732
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
733
- # to be able to avoid many of these transpose/reshape/view.
734
- query_states = query_states.transpose(1, 2)
735
- key_states = key_states.transpose(1, 2)
736
- value_states = value_states.transpose(1, 2)
737
-
738
- dropout_rate = self.attention_dropout if self.training else 0.0
739
-
740
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
741
- # therefore the input hidden states gets silently casted in float32. Hence, we need
742
- # cast them back in the correct dtype just to be sure everything works as expected.
743
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
744
- # in fp32. (LlamaRMSNorm handles it correctly)
745
-
746
- input_dtype = query_states.dtype
747
- if input_dtype == torch.float32:
748
- if torch.is_autocast_enabled():
749
- target_dtype = torch.get_autocast_gpu_dtype()
750
- # Handle the case where the model is quantized
751
- elif hasattr(self.config, "_pre_quantization_dtype"):
752
- target_dtype = self.config._pre_quantization_dtype
753
- else:
754
- target_dtype = self.q_proj.weight.dtype
755
-
756
- logger.warning_once(
757
- f"The input hidden states seems to be silently casted in float32, this might be related to"
758
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
759
- f" {target_dtype}."
760
- )
761
-
762
- query_states = query_states.to(target_dtype)
763
- key_states = key_states.to(target_dtype)
764
- value_states = value_states.to(target_dtype)
765
-
766
- attn_output = self._flash_attention_forward(
767
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
768
- )
769
-
770
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
771
- attn_output = self.o_proj(attn_output)
772
-
773
- if not output_attentions:
774
- attn_weights = None
775
-
776
- return attn_output, attn_weights, past_key_value
777
-
778
- def _flash_attention_forward(
779
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
780
- ):
781
- """
782
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
783
- first unpad the input, then computes the attention scores and pad the final attention scores.
784
-
785
- Args:
786
- query_states (`torch.Tensor`):
787
- Input query states to be passed to Flash Attention API
788
- key_states (`torch.Tensor`):
789
- Input key states to be passed to Flash Attention API
790
- value_states (`torch.Tensor`):
791
- Input value states to be passed to Flash Attention API
792
- attention_mask (`torch.Tensor`):
793
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
794
- position of padding tokens and 1 for the position of non-padding tokens.
795
- dropout (`float`):
796
- Attention dropout
797
- softmax_scale (`float`, *optional*):
798
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
799
- """
800
- if not self._flash_attn_uses_top_left_mask:
801
- causal = self.is_causal
802
- else:
803
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
804
- causal = self.is_causal and query_length != 1
805
-
806
- # Contains at least one padding token in the sequence
807
- if attention_mask is not None:
808
- batch_size = query_states.shape[0]
809
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
810
- query_states, key_states, value_states, attention_mask, query_length
811
- )
812
-
813
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
814
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
815
-
816
- attn_output_unpad = flash_attn_varlen_func(
817
- query_states,
818
- key_states,
819
- value_states,
820
- cu_seqlens_q=cu_seqlens_q,
821
- cu_seqlens_k=cu_seqlens_k,
822
- max_seqlen_q=max_seqlen_in_batch_q,
823
- max_seqlen_k=max_seqlen_in_batch_k,
824
- dropout_p=dropout,
825
- softmax_scale=softmax_scale,
826
- causal=causal,
827
- )
828
-
829
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
830
- else:
831
- attn_output = flash_attn_func(
832
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
833
- )
834
-
835
- return attn_output
836
-
837
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
838
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
839
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
840
-
841
- key_layer = index_first_axis(
842
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
843
- )
844
- value_layer = index_first_axis(
845
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
846
- )
847
- if query_length == kv_seq_len:
848
- query_layer = index_first_axis(
849
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
850
- )
851
- cu_seqlens_q = cu_seqlens_k
852
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
853
- indices_q = indices_k
854
- elif query_length == 1:
855
- max_seqlen_in_batch_q = 1
856
- cu_seqlens_q = torch.arange(
857
- batch_size + 1, dtype=torch.int32, device=query_layer.device
858
- ) # There is a memcpy here, that is very bad.
859
- indices_q = cu_seqlens_q[:-1]
860
- query_layer = query_layer.squeeze(1)
861
- else:
862
- # The -q_len: slice assumes left padding.
863
- attention_mask = attention_mask[:, -query_length:]
864
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
865
-
866
- return (
867
- query_layer,
868
- key_layer,
869
- value_layer,
870
- indices_q,
871
- (cu_seqlens_q, cu_seqlens_k),
872
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
873
- )
874
-
875
-
876
- class LlamaSdpaAttention(LlamaAttention):
877
- """
878
- Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
879
- `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
880
- SDPA API.
881
- """
882
-
883
- # Adapted from LlamaAttention.forward
884
- def forward(
885
- self,
886
- hidden_states: torch.Tensor,
887
- attention_mask: Optional[torch.Tensor] = None,
888
- position_ids: Optional[torch.LongTensor] = None,
889
- past_key_value: Optional[Cache] = None,
890
- output_attentions: bool = False,
891
- use_cache: bool = False,
892
- cache_position: Optional[torch.LongTensor] = None,
893
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
894
- if output_attentions:
895
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
896
- logger.warning_once(
897
- "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
898
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
899
- )
900
- return super().forward(
901
- hidden_states=hidden_states,
902
- attention_mask=attention_mask,
903
- position_ids=position_ids,
904
- past_key_value=past_key_value,
905
- output_attentions=output_attentions,
906
- use_cache=use_cache,
907
- cache_position=cache_position,
908
- )
909
-
910
- bsz, q_len, _ = hidden_states.size()
911
-
912
- query_states = self.q_proj(hidden_states)
913
- key_states = self.k_proj(hidden_states)
914
- value_states = self.v_proj(hidden_states)
915
-
916
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
917
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
918
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
919
-
920
- cos, sin = self.rotary_emb(value_states, position_ids)
921
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
922
-
923
- # In case static cache is used, it is an instance attribute.
924
- past_key_value = getattr(self, "past_key_value", past_key_value)
925
-
926
- if past_key_value is not None:
927
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
928
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
929
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
930
-
931
- key_states = repeat_kv(key_states, self.num_key_value_groups)
932
- value_states = repeat_kv(value_states, self.num_key_value_groups)
933
-
934
- causal_mask = attention_mask
935
- if attention_mask is not None:
936
- causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
937
-
938
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
939
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
940
- if query_states.device.type == "cuda" and causal_mask is not None:
941
- query_states = query_states.contiguous()
942
- key_states = key_states.contiguous()
943
- value_states = value_states.contiguous()
944
-
945
- # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
946
- # relying on the `is_causal` argument.
947
- attn_output = torch.nn.functional.scaled_dot_product_attention(
948
- query_states,
949
- key_states,
950
- value_states,
951
- attn_mask=causal_mask,
952
- dropout_p=self.attention_dropout if self.training else 0.0,
953
- is_causal=causal_mask is None and q_len > 1,
954
- )
955
-
956
- attn_output = attn_output.transpose(1, 2).contiguous()
957
- attn_output = attn_output.view(bsz, q_len, self.hidden_size)
958
-
959
- attn_output = self.o_proj(attn_output)
960
-
961
- return attn_output, None, past_key_value
962
-
963
-
964
- LLAMA_ATTENTION_CLASSES = {
965
- "eager": LlamaAttention,
966
- "flash_attention_2": LlamaFlashAttention2,
967
- "sdpa": LlamaSdpaAttention,
968
- }
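The comments in the SDPA forward above note that SDPA can only dispatch to its Flash Attention backend when no explicit `attn_mask` is passed and causality is expressed through `is_causal` instead. A minimal standalone sketch of that call, with toy shapes assumed purely for illustration:

    import torch
    import torch.nn.functional as F

    # assumed toy shapes: (batch, num_heads, seq_len, head_dim)
    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    # attn_mask=None plus is_causal=True lets SDPA choose its Flash Attention backend when available;
    # supplying an explicit mask (as the deleted forward does when one is given) forces other backends.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)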
969
-
970
-
971
- class LlamaDecoderLayer(nn.Module):
972
- def __init__(self, config: LlamaConfig, layer_idx: int):
973
- super().__init__()
974
- self.hidden_size = config.hidden_size
975
-
976
- self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
977
-
978
- self.mlp = LlamaMLP(config)
979
- self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
980
- self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
981
-
982
- def forward(
983
- self,
984
- hidden_states: torch.Tensor,
985
- attention_mask: Optional[torch.Tensor] = None,
986
- position_ids: Optional[torch.LongTensor] = None,
987
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
988
- output_attentions: Optional[bool] = False,
989
- use_cache: Optional[bool] = False,
990
- cache_position: Optional[torch.LongTensor] = None,
991
- **kwargs,
992
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
993
- """
994
- Args:
995
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
996
- attention_mask (`torch.FloatTensor`, *optional*):
997
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
998
- query_sequence_length, key_sequence_length)` if default attention is used.
999
- output_attentions (`bool`, *optional*):
1000
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1001
- returned tensors for more detail.
1002
- use_cache (`bool`, *optional*):
1003
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1004
- (see `past_key_values`).
1005
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1006
- """
1007
- if "padding_mask" in kwargs:
1008
- warnings.warn(
1009
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1010
- )
1011
-
1012
- residual = hidden_states
1013
-
1014
- hidden_states = self.input_layernorm(hidden_states)
1015
-
1016
- # Self Attention
1017
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
1018
- hidden_states=hidden_states,
1019
- attention_mask=attention_mask,
1020
- position_ids=position_ids,
1021
- past_key_value=past_key_value,
1022
- output_attentions=output_attentions,
1023
- use_cache=use_cache,
1024
- cache_position=cache_position,
1025
- **kwargs,
1026
- )
1027
- hidden_states = residual + hidden_states
1028
-
1029
- # Fully Connected
1030
- residual = hidden_states
1031
- hidden_states = self.post_attention_layernorm(hidden_states)
1032
- hidden_states = self.mlp(hidden_states)
1033
- hidden_states = residual + hidden_states
1034
-
1035
- outputs = (hidden_states,)
1036
-
1037
- if output_attentions:
1038
- outputs += (self_attn_weights,)
1039
-
1040
- if use_cache:
1041
- outputs += (present_key_value,)
1042
-
1043
- return outputs
1044
-
1045
-
1046
- LLAMA_START_DOCSTRING = r"""
1047
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1048
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1049
- etc.)
1050
-
1051
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1052
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1053
- and behavior.
1054
-
1055
- Parameters:
1056
- config ([`LlamaConfig`]):
1057
- Model configuration class with all the parameters of the model. Initializing with a config file does not
1058
- load the weights associated with the model, only the configuration. Check out the
1059
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1060
- """
1061
-
1062
-
1063
- @add_start_docstrings(
1064
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
1065
- LLAMA_START_DOCSTRING,
1066
- )
1067
- class LlamaPreTrainedModel(PreTrainedModel):
1068
- config_class = LlamaConfig
1069
- base_model_prefix = "model"
1070
- supports_gradient_checkpointing = True
1071
- _no_split_modules = ["LlamaDecoderLayer"]
1072
- _skip_keys_device_placement = ["past_key_values"]
1073
- _supports_flash_attn_2 = True
1074
- _supports_sdpa = True
1075
- _supports_cache_class = True
1076
-
1077
- def _init_weights(self, module):
1078
- std = self.config.initializer_range
1079
- if isinstance(module, nn.Linear):
1080
- module.weight.data.normal_(mean=0.0, std=std)
1081
- if module.bias is not None:
1082
- module.bias.data.zero_()
1083
- elif isinstance(module, nn.Embedding):
1084
- module.weight.data.normal_(mean=0.0, std=std)
1085
- if module.padding_idx is not None:
1086
- module.weight.data[module.padding_idx].zero_()
1087
-
1088
- def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
1089
- if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
1090
- raise ValueError(
1091
- "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
1092
- "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
1093
- )
1094
-
1095
- for layer in self.model.layers:
1096
- device = layer.input_layernorm.weight.device
1097
- if hasattr(self.config, "_pre_quantization_dtype"):
1098
- dtype = self.config._pre_quantization_dtype
1099
- else:
1100
- dtype = layer.self_attn.o_proj.weight.dtype
1101
- layer.self_attn.past_key_value = cache_cls(
1102
- self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
1103
- )
1104
-
1105
- def _reset_cache(self):
1106
- for layer in self.model.layers:
1107
- layer.self_attn.past_key_value = None
1108
-
1109
-
1110
- LLAMA_INPUTS_DOCSTRING = r"""
1111
- Args:
1112
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1113
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1114
- it.
1115
-
1116
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1117
- [`PreTrainedTokenizer.__call__`] for details.
1118
-
1119
- [What are input IDs?](../glossary#input-ids)
1120
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1121
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1122
-
1123
- - 1 for tokens that are **not masked**,
1124
- - 0 for tokens that are **masked**.
1125
-
1126
- [What are attention masks?](../glossary#attention-mask)
1127
-
1128
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1129
- [`PreTrainedTokenizer.__call__`] for details.
1130
-
1131
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1132
- `past_key_values`).
1133
-
1134
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1135
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1136
- information on the default strategy.
1137
-
1138
- - 1 indicates the head is **not masked**,
1139
- - 0 indicates the head is **masked**.
1140
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1141
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1142
- config.n_positions - 1]`.
1143
-
1144
- [What are position IDs?](../glossary#position-ids)
1145
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1146
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1147
- blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
1148
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1149
-
1150
- Two formats are allowed:
1151
- - a [`~cache_utils.Cache`] instance;
1152
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1153
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1154
- cache format.
1155
-
1156
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1157
- legacy cache format will be returned.
1158
-
1159
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1160
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1161
- of shape `(batch_size, sequence_length)`.
1162
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1163
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1164
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1165
- model's internal embedding lookup matrix.
1166
- use_cache (`bool`, *optional*):
1167
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1168
- `past_key_values`).
1169
- output_attentions (`bool`, *optional*):
1170
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1171
- tensors for more detail.
1172
- output_hidden_states (`bool`, *optional*):
1173
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1174
- more detail.
1175
- return_dict (`bool`, *optional*):
1176
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1177
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1178
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
1179
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
1180
- the complete sequence length.
1181
- """
1182
-
1183
-
1184
- @add_start_docstrings(
1185
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
1186
- LLAMA_START_DOCSTRING,
1187
- )
1188
- class LlamaModel(LlamaPreTrainedModel):
1189
- """
1190
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
1191
-
1192
- Args:
1193
- config: LlamaConfig
1194
- """
1195
-
1196
- def __init__(self, config: LlamaConfig):
1197
- super().__init__(config)
1198
- self.padding_idx = config.pad_token_id
1199
- self.vocab_size = config.vocab_size
1200
-
1201
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1202
- self.layers = nn.ModuleList(
1203
- [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1204
- )
1205
- self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1206
- self.gradient_checkpointing = False
1207
-
1208
- # Initialize weights and apply final processing
1209
- self.post_init()
1210
-
1211
- def get_input_embeddings(self):
1212
- return self.embed_tokens
1213
-
1214
- def set_input_embeddings(self, value):
1215
- self.embed_tokens = value
1216
-
1217
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1218
- def forward(
1219
- self,
1220
- input_ids: torch.LongTensor = None,
1221
- attention_mask: Optional[torch.Tensor] = None,
1222
- position_ids: Optional[torch.LongTensor] = None,
1223
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1224
- inputs_embeds: Optional[torch.FloatTensor] = None,
1225
- use_cache: Optional[bool] = None,
1226
- output_attentions: Optional[bool] = None,
1227
- output_hidden_states: Optional[bool] = None,
1228
- return_dict: Optional[bool] = None,
1229
- cache_position: Optional[torch.LongTensor] = None,
1230
- ) -> Union[Tuple, BaseModelOutputWithPast]:
1231
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1232
- output_hidden_states = (
1233
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1234
- )
1235
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1236
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1237
-
1238
- if (input_ids is None) ^ (inputs_embeds is not None):
1239
- raise ValueError(
1240
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1241
- )
1242
-
1243
- if self.gradient_checkpointing and self.training and use_cache:
1244
- logger.warning_once(
1245
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1246
- )
1247
- use_cache = False
1248
-
1249
- if inputs_embeds is None:
1250
- inputs_embeds = self.embed_tokens(input_ids)
1251
-
1252
- past_seen_tokens = 0
1253
- if use_cache: # kept for BC (cache positions)
1254
- if not isinstance(past_key_values, StaticCache):
1255
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1256
- past_seen_tokens = past_key_values.get_seq_length()
1257
-
1258
- if cache_position is None:
1259
- if isinstance(past_key_values, StaticCache):
1260
- raise ValueError("cache_position is a required argument when using StaticCache.")
1261
- cache_position = torch.arange(
1262
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1263
- )
1264
-
1265
- if position_ids is None:
1266
- position_ids = cache_position.unsqueeze(0)
1267
-
1268
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
1269
-
1270
- # embed positions
1271
- hidden_states = inputs_embeds
1272
-
1273
- # decoder layers
1274
- all_hidden_states = () if output_hidden_states else None
1275
- all_self_attns = () if output_attentions else None
1276
- next_decoder_cache = None
1277
-
1278
- for decoder_layer in self.layers:
1279
- if output_hidden_states:
1280
- all_hidden_states += (hidden_states,)
1281
-
1282
- if self.gradient_checkpointing and self.training:
1283
- layer_outputs = self._gradient_checkpointing_func(
1284
- decoder_layer.__call__,
1285
- hidden_states,
1286
- causal_mask,
1287
- position_ids,
1288
- past_key_values,
1289
- output_attentions,
1290
- use_cache,
1291
- cache_position,
1292
- )
1293
- else:
1294
- layer_outputs = decoder_layer(
1295
- hidden_states,
1296
- attention_mask=causal_mask,
1297
- position_ids=position_ids,
1298
- past_key_value=past_key_values,
1299
- output_attentions=output_attentions,
1300
- use_cache=use_cache,
1301
- cache_position=cache_position,
1302
- )
1303
-
1304
- hidden_states = layer_outputs[0]
1305
-
1306
- if use_cache:
1307
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1308
-
1309
- if output_attentions:
1310
- all_self_attns += (layer_outputs[1],)
1311
-
1312
- hidden_states = self.norm(hidden_states)
1313
-
1314
- # add hidden states from the last decoder layer
1315
- if output_hidden_states:
1316
- all_hidden_states += (hidden_states,)
1317
-
1318
- next_cache = None
1319
- if use_cache:
1320
- next_cache = (
1321
- next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1322
- )
1323
- if not return_dict:
1324
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1325
- return BaseModelOutputWithPast(
1326
- last_hidden_state=hidden_states,
1327
- past_key_values=next_cache,
1328
- hidden_states=all_hidden_states,
1329
- attentions=all_self_attns,
1330
- )
1331
-
1332
- def _update_causal_mask(
1333
- self,
1334
- attention_mask: torch.Tensor,
1335
- input_tensor: torch.Tensor,
1336
- cache_position: torch.Tensor,
1337
- past_seen_tokens: int,
1338
- ):
1339
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1340
- # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1341
- # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1342
- # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1343
-
1344
- if self.config._attn_implementation == "flash_attention_2":
1345
- if attention_mask is not None and 0.0 in attention_mask:
1346
- return attention_mask
1347
- return None
1348
-
1349
- if self.config._attn_implementation == "sdpa":
1350
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
1351
- # in order to dispatch on Flash Attention 2.
1352
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
1353
- attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
1354
- ):
1355
- return None
1356
-
1357
- dtype, device = input_tensor.dtype, input_tensor.device
1358
- min_dtype = torch.finfo(dtype).min
1359
- sequence_length = input_tensor.shape[1]
1360
- if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
1361
- target_length = self.config.max_position_embeddings
1362
- else: # dynamic cache
1363
- target_length = (
1364
- attention_mask.shape[-1]
1365
- if isinstance(attention_mask, torch.Tensor)
1366
- else past_seen_tokens + sequence_length + 1
1367
- )
1368
-
1369
- causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1370
- if sequence_length != 1:
1371
- causal_mask = torch.triu(causal_mask, diagonal=1)
1372
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1373
- causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1374
- if attention_mask is not None:
1375
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1376
- if attention_mask.dim() == 2:
1377
- mask_length = attention_mask.shape[-1]
1378
- padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1379
- causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1380
- elif attention_mask.dim() == 4:
1381
- # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
1382
- # cache. In that case, the 4D attention mask attends to the newest tokens only.
1383
- if attention_mask.shape[-2] < cache_position[0] + sequence_length:
1384
- offset = cache_position[0]
1385
- else:
1386
- offset = 0
1387
- mask_shape = attention_mask.shape
1388
- mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
1389
- causal_mask[
1390
- : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
1391
- ] = mask_slice
1392
-
1393
- if (
1394
- self.config._attn_implementation == "sdpa"
1395
- and attention_mask is not None
1396
- and attention_mask.device.type == "cuda"
1397
- ):
1398
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1399
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1400
- # Details: https://github.com/pytorch/pytorch/issues/110213
1401
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1402
-
1403
- return causal_mask
1404
-
1405
-
1406
- class LlamaForCausalLM(LlamaPreTrainedModel):
1407
- _tied_weights_keys = ["lm_head.weight"]
1408
-
1409
- def __init__(self, config):
1410
- super().__init__(config)
1411
- self.model = LlamaModel(config)
1412
- self.vocab_size = config.vocab_size
1413
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1414
-
1415
- # Initialize weights and apply final processing
1416
- self.post_init()
1417
-
1418
- def get_input_embeddings(self):
1419
- return self.model.embed_tokens
1420
-
1421
- def set_input_embeddings(self, value):
1422
- self.model.embed_tokens = value
1423
-
1424
- def get_output_embeddings(self):
1425
- return self.lm_head
1426
-
1427
- def set_output_embeddings(self, new_embeddings):
1428
- self.lm_head = new_embeddings
1429
-
1430
- def set_decoder(self, decoder):
1431
- self.model = decoder
1432
-
1433
- def get_decoder(self):
1434
- return self.model
1435
-
1436
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1437
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1438
- def forward(
1439
- self,
1440
- input_ids: torch.LongTensor = None,
1441
- attention_mask: Optional[torch.Tensor] = None,
1442
- position_ids: Optional[torch.LongTensor] = None,
1443
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1444
- inputs_embeds: Optional[torch.FloatTensor] = None,
1445
- labels: Optional[torch.LongTensor] = None,
1446
- use_cache: Optional[bool] = None,
1447
- output_attentions: Optional[bool] = None,
1448
- output_hidden_states: Optional[bool] = None,
1449
- return_dict: Optional[bool] = None,
1450
- cache_position: Optional[torch.LongTensor] = None,
1451
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1452
- r"""
1453
- Args:
1454
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1455
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1456
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1457
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1458
-
1459
- Returns:
1460
-
1461
- Example:
1462
-
1463
- ```python
1464
- >>> from transformers import AutoTokenizer, LlamaForCausalLM
1465
-
1466
- >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
1467
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
1468
-
1469
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1470
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1471
-
1472
- >>> # Generate
1473
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1474
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1475
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1476
- ```"""
1477
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1478
- output_hidden_states = (
1479
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1480
- )
1481
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1482
-
1483
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1484
- outputs = self.model(
1485
- input_ids=input_ids,
1486
- attention_mask=attention_mask,
1487
- position_ids=position_ids,
1488
- past_key_values=past_key_values,
1489
- inputs_embeds=inputs_embeds,
1490
- use_cache=use_cache,
1491
- output_attentions=output_attentions,
1492
- output_hidden_states=output_hidden_states,
1493
- return_dict=return_dict,
1494
- cache_position=cache_position,
1495
- )
1496
-
1497
- hidden_states = outputs[0]
1498
- if self.config.pretraining_tp > 1:
1499
- lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1500
- logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1501
- logits = torch.cat(logits, dim=-1)
1502
- else:
1503
- logits = self.lm_head(hidden_states)
1504
- logits = logits.float()
1505
-
1506
- loss = None
1507
- if labels is not None:
1508
- # Shift so that tokens < n predict n
1509
- shift_logits = logits[..., :-1, :].contiguous()
1510
- shift_labels = labels[..., 1:].contiguous()
1511
- # Flatten the tokens
1512
- loss_fct = CrossEntropyLoss()
1513
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1514
- shift_labels = shift_labels.view(-1)
1515
- # Enable model parallelism
1516
- shift_labels = shift_labels.to(shift_logits.device)
1517
- loss = loss_fct(shift_logits, shift_labels)
1518
-
1519
- if not return_dict:
1520
- output = (logits,) + outputs[1:]
1521
- return (loss,) + output if loss is not None else output
1522
-
1523
- return CausalLMOutputWithPast(
1524
- loss=loss,
1525
- logits=logits,
1526
- past_key_values=outputs.past_key_values,
1527
- hidden_states=outputs.hidden_states,
1528
- attentions=outputs.attentions,
1529
- )
1530
-
1531
- def prepare_inputs_for_generation(
1532
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
1533
- ):
1534
- # With static cache, the `past_key_values` is None
1535
- # TODO joao: standardize interface for the different Cache classes and remove of this if
1536
- has_static_cache = False
1537
- if past_key_values is None:
1538
- past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
1539
- has_static_cache = past_key_values is not None
1540
-
1541
- past_length = 0
1542
- if past_key_values is not None:
1543
- if isinstance(past_key_values, Cache):
1544
- past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
1545
- max_cache_length = (
1546
- torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
1547
- if past_key_values.get_max_length() is not None
1548
- else None
1549
- )
1550
- cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
1551
- # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
1552
- else:
1553
- cache_length = past_length = past_key_values[0][0].shape[2]
1554
- max_cache_length = None
1555
-
1556
- # Keep only the unprocessed tokens:
1557
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1558
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1559
- # input)
1560
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1561
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1562
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1563
- # input_ids based on the past_length.
1564
- elif past_length < input_ids.shape[1]:
1565
- input_ids = input_ids[:, past_length:]
1566
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1567
- else:
1568
- remove_prefix_length = input_ids.shape[1] - 1
1569
- input_ids = input_ids[:, remove_prefix_length:]
1570
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1571
- if (
1572
- max_cache_length is not None
1573
- and attention_mask is not None
1574
- and cache_length + input_ids.shape[1] > max_cache_length
1575
- ):
1576
- attention_mask = attention_mask[:, -max_cache_length:]
1577
-
1578
- position_ids = kwargs.get("position_ids", None)
1579
- if attention_mask is not None and position_ids is None:
1580
- # create position_ids on the fly for batch generation
1581
- position_ids = attention_mask.long().cumsum(-1) - 1
1582
- position_ids.masked_fill_(attention_mask == 0, 1)
1583
- if past_key_values:
1584
- position_ids = position_ids[:, -input_ids.shape[1] :]
1585
-
1586
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1587
- if inputs_embeds is not None and past_key_values is None:
1588
- model_inputs = {"inputs_embeds": inputs_embeds}
1589
- else:
1590
- # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1591
- # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1592
- # TODO: use `next_tokens` directly instead.
1593
- model_inputs = {"input_ids": input_ids.contiguous()}
1594
-
1595
- input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1596
- if cache_position is None:
1597
- cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
1598
- else:
1599
- cache_position = cache_position[-input_length:]
1600
-
1601
- if has_static_cache:
1602
- past_key_values = None
1603
-
1604
- model_inputs.update(
1605
- {
1606
- "position_ids": position_ids,
1607
- "cache_position": cache_position,
1608
- "past_key_values": past_key_values,
1609
- "use_cache": kwargs.get("use_cache"),
1610
- "attention_mask": attention_mask,
1611
- }
1612
- )
1613
- return model_inputs
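The "create position_ids on the fly for batch generation" branch above derives positions from the attention mask so that left padding does not shift them. A small worked example, with the mask values assumed for illustration:

    import torch

    # left-padded batch: 0 marks padding
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])

    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    # -> tensor([[1, 1, 0, 1, 2],
    #            [0, 1, 2, 3, 4]])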
1614
-
1615
- @staticmethod
1616
- def _reorder_cache(past_key_values, beam_idx):
1617
- reordered_past = ()
1618
- for layer_past in past_key_values:
1619
- reordered_past += (
1620
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1621
- )
1622
- return reordered_past
1623
-
1624
-
1625
- @add_start_docstrings(
1626
- """
1627
- The LLaMa Model transformer with a sequence classification head on top (linear layer).
1628
-
1629
- [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1630
- (e.g. GPT-2) do.
1631
-
1632
- Since it does classification on the last token, it requires to know the position of the last token. If a
1633
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1634
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1635
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1636
- each row of the batch).
1637
- """,
1638
- LLAMA_START_DOCSTRING,
1639
- )
1640
- class LlamaForSequenceClassification(LlamaPreTrainedModel):
1641
- def __init__(self, config):
1642
- super().__init__(config)
1643
- self.num_labels = config.num_labels
1644
- self.model = LlamaModel(config)
1645
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1646
-
1647
- # Initialize weights and apply final processing
1648
- self.post_init()
1649
-
1650
- def get_input_embeddings(self):
1651
- return self.model.embed_tokens
1652
-
1653
- def set_input_embeddings(self, value):
1654
- self.model.embed_tokens = value
1655
-
1656
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1657
- def forward(
1658
- self,
1659
- input_ids: torch.LongTensor = None,
1660
- attention_mask: Optional[torch.Tensor] = None,
1661
- position_ids: Optional[torch.LongTensor] = None,
1662
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1663
- inputs_embeds: Optional[torch.FloatTensor] = None,
1664
- labels: Optional[torch.LongTensor] = None,
1665
- use_cache: Optional[bool] = None,
1666
- output_attentions: Optional[bool] = None,
1667
- output_hidden_states: Optional[bool] = None,
1668
- return_dict: Optional[bool] = None,
1669
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1670
- r"""
1671
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1672
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1673
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1674
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1675
- """
1676
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1677
-
1678
- transformer_outputs = self.model(
1679
- input_ids,
1680
- attention_mask=attention_mask,
1681
- position_ids=position_ids,
1682
- past_key_values=past_key_values,
1683
- inputs_embeds=inputs_embeds,
1684
- use_cache=use_cache,
1685
- output_attentions=output_attentions,
1686
- output_hidden_states=output_hidden_states,
1687
- return_dict=return_dict,
1688
- )
1689
- hidden_states = transformer_outputs[0]
1690
- logits = self.score(hidden_states)
1691
-
1692
- if input_ids is not None:
1693
- batch_size = input_ids.shape[0]
1694
- else:
1695
- batch_size = inputs_embeds.shape[0]
1696
-
1697
- if self.config.pad_token_id is None and batch_size != 1:
1698
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1699
- if self.config.pad_token_id is None:
1700
- sequence_lengths = -1
1701
- else:
1702
- if input_ids is not None:
1703
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1704
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1705
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1706
- sequence_lengths = sequence_lengths.to(logits.device)
1707
- else:
1708
- sequence_lengths = -1
1709
-
1710
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1711
-
1712
- loss = None
1713
- if labels is not None:
1714
- labels = labels.to(logits.device)
1715
- if self.config.problem_type is None:
1716
- if self.num_labels == 1:
1717
- self.config.problem_type = "regression"
1718
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1719
- self.config.problem_type = "single_label_classification"
1720
- else:
1721
- self.config.problem_type = "multi_label_classification"
1722
-
1723
- if self.config.problem_type == "regression":
1724
- loss_fct = MSELoss()
1725
- if self.num_labels == 1:
1726
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1727
- else:
1728
- loss = loss_fct(pooled_logits, labels)
1729
- elif self.config.problem_type == "single_label_classification":
1730
- loss_fct = CrossEntropyLoss()
1731
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1732
- elif self.config.problem_type == "multi_label_classification":
1733
- loss_fct = BCEWithLogitsLoss()
1734
- loss = loss_fct(pooled_logits, labels)
1735
- if not return_dict:
1736
- output = (pooled_logits,) + transformer_outputs[1:]
1737
- return ((loss,) + output) if loss is not None else output
1738
-
1739
- return SequenceClassifierOutputWithPast(
1740
- loss=loss,
1741
- logits=pooled_logits,
1742
- past_key_values=transformer_outputs.past_key_values,
1743
- hidden_states=transformer_outputs.hidden_states,
1744
- attentions=transformer_outputs.attentions,
1745
- )
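The last-token pooling described in the class docstring locates the final non-padding position through the pad token. A tiny sketch of that index computation, with the ids and pad token assumed for illustration:

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[5, 9, 7, pad_token_id, pad_token_id]])

    # index of the last non-padding token per row, as used to pool the classification logits
    sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    sequence_lengths = sequence_lengths % input_ids.shape[-1]
    # -> tensor([2]); logits[:, 2, :] becomes the pooled logits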
1746
-
1747
-
1748
- @add_start_docstrings(
1749
- """
1750
- The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
1751
- SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1752
- """,
1753
- LLAMA_START_DOCSTRING,
1754
- )
1755
- class LlamaForQuestionAnswering(LlamaPreTrainedModel):
1756
- base_model_prefix = "transformer"
1757
-
1758
- # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
1759
- def __init__(self, config):
1760
- super().__init__(config)
1761
- self.transformer = LlamaModel(config)
1762
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
1763
-
1764
- # Initialize weights and apply final processing
1765
- self.post_init()
1766
-
1767
- def get_input_embeddings(self):
1768
- return self.transformer.embed_tokens
1769
-
1770
- def set_input_embeddings(self, value):
1771
- self.transformer.embed_tokens = value
1772
-
1773
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1774
- def forward(
1775
- self,
1776
- input_ids: Optional[torch.LongTensor] = None,
1777
- attention_mask: Optional[torch.FloatTensor] = None,
1778
- position_ids: Optional[torch.LongTensor] = None,
1779
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1780
- inputs_embeds: Optional[torch.FloatTensor] = None,
1781
- start_positions: Optional[torch.LongTensor] = None,
1782
- end_positions: Optional[torch.LongTensor] = None,
1783
- output_attentions: Optional[bool] = None,
1784
- output_hidden_states: Optional[bool] = None,
1785
- return_dict: Optional[bool] = None,
1786
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1787
- r"""
1788
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1789
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1790
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1791
- are not taken into account for computing the loss.
1792
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1793
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
1794
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1795
- are not taken into account for computing the loss.
1796
- """
1797
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1798
-
1799
- outputs = self.transformer(
1800
- input_ids,
1801
- attention_mask=attention_mask,
1802
- position_ids=position_ids,
1803
- past_key_values=past_key_values,
1804
- inputs_embeds=inputs_embeds,
1805
- output_attentions=output_attentions,
1806
- output_hidden_states=output_hidden_states,
1807
- return_dict=return_dict,
1808
- )
1809
-
1810
- sequence_output = outputs[0]
1811
-
1812
- logits = self.qa_outputs(sequence_output)
1813
- start_logits, end_logits = logits.split(1, dim=-1)
1814
- start_logits = start_logits.squeeze(-1).contiguous()
1815
- end_logits = end_logits.squeeze(-1).contiguous()
1816
-
1817
- total_loss = None
1818
- if start_positions is not None and end_positions is not None:
1819
- # If we are on multi-GPU, splitting can add a dimension; squeeze it
1820
- if len(start_positions.size()) > 1:
1821
- start_positions = start_positions.squeeze(-1).to(start_logits.device)
1822
- if len(end_positions.size()) > 1:
1823
- end_positions = end_positions.squeeze(-1).to(end_logits.device)
1824
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1825
- ignored_index = start_logits.size(1)
1826
- start_positions = start_positions.clamp(0, ignored_index)
1827
- end_positions = end_positions.clamp(0, ignored_index)
1828
-
1829
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1830
- start_loss = loss_fct(start_logits, start_positions)
1831
- end_loss = loss_fct(end_logits, end_positions)
1832
- total_loss = (start_loss + end_loss) / 2
1833
-
1834
- if not return_dict:
1835
- output = (start_logits, end_logits) + outputs[2:]
1836
- return ((total_loss,) + output) if total_loss is not None else output
1837
-
1838
- return QuestionAnsweringModelOutput(
1839
- loss=total_loss,
1840
- start_logits=start_logits,
1841
- end_logits=end_logits,
1842
- hidden_states=outputs.hidden_states,
1843
- attentions=outputs.attentions,
1844
- )
 
bunny/model/language_model/llama/tokenization_llama.py DELETED
@@ -1,471 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
-
21
- """Tokenization classes for LLaMA."""
22
- import os
23
- from shutil import copyfile
24
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
25
-
26
- import sentencepiece as spm
27
-
28
- from transformers.convert_slow_tokenizer import import_protobuf
29
- from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
30
- from transformers.utils import logging
31
-
32
-
33
- if TYPE_CHECKING:
34
- from transformers.tokenization_utils_base import TextInput
35
-
36
- logger = logging.get_logger(__name__)
37
-
38
- VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
39
-
40
- SPIECE_UNDERLINE = "▁"
41
-
42
- B_INST, E_INST = "[INST]", "[/INST]"
43
- B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
44
-
45
- # fmt: off
46
- DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
47
- answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
48
- that your responses are socially unbiased and positive in nature.
49
-
50
- If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
51
- correct. If you don't know the answer to a question, please don't share false information."""
52
- # fmt: on
53
-
54
-
55
- class LlamaTokenizer(PreTrainedTokenizer):
56
- """
57
- Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
58
- no padding token in the original model.
59
-
60
- Args:
61
- vocab_file (`str`):
62
- Path to the vocabulary file.
63
- unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
64
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
65
- token instead.
66
- bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
67
- The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
68
- eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
69
- The end of sequence token.
70
- pad_token (`str` or `tokenizers.AddedToken`, *optional*):
71
- A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
72
- attention mechanisms or loss computation.
73
- sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
74
- Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
75
- SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
76
- to set:
77
-
78
- - `enable_sampling`: Enable subword regularization.
79
- - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
80
-
81
- - `nbest_size = {0,1}`: No sampling is performed.
82
- - `nbest_size > 1`: samples from the nbest_size results.
83
- - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
84
- using forward-filtering-and-backward-sampling algorithm.
85
-
86
- - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
87
- BPE-dropout.
88
-
89
- add_bos_token (`bool`, *optional*, defaults to `True`):
90
- Whether or not to add a `bos_token` at the start of sequences.
91
- add_eos_token (`bool`, *optional*, defaults to `False`):
92
- Whether or not to add an `eos_token` at the end of sequences.
93
- clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
94
- Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
95
- extra spaces.
96
- use_default_system_prompt (`bool`, *optional*, defaults to `False`):
97
- Whether or not the default system prompt for Llama should be used.
98
- spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
99
- Whether or not to add spaces between special tokens.
100
- legacy (`bool`, *optional*):
101
- Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
102
- and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
103
- example:
104
-
105
- - `legacy=True`:
106
- ```python
107
- >>> from transformers import T5Tokenizer
108
-
109
- >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True)
110
- >>> tokenizer.encode("Hello <extra_id_0>.")
111
- [8774, 32099, 3, 5, 1]
112
- ```
113
- - `legacy=False`:
114
- ```python
115
- >>> from transformers import T5Tokenizer
116
-
117
- >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
118
- >>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
119
- [8774, 32099, 5, 1]
120
- ```
121
- Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
122
- add_prefix_space (`bool`, *optional*, defaults to `True`):
123
- Whether or not to add an initial space to the input. This allows to treat the leading word just as any
124
- other word.
125
-
126
- """
127
-
128
- vocab_files_names = VOCAB_FILES_NAMES
129
- model_input_names = ["input_ids", "attention_mask"]
130
-
131
- def __init__(
132
- self,
133
- vocab_file,
134
- unk_token="<unk>",
135
- bos_token="<s>",
136
- eos_token="</s>",
137
- pad_token=None,
138
- sp_model_kwargs: Optional[Dict[str, Any]] = None,
139
- add_bos_token=True,
140
- add_eos_token=False,
141
- clean_up_tokenization_spaces=False,
142
- use_default_system_prompt=False,
143
- spaces_between_special_tokens=False,
144
- legacy=None,
145
- add_prefix_space=True,
146
- **kwargs,
147
- ):
148
- self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
149
- bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
150
- eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
151
- unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
152
- pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
153
-
154
- if legacy is None:
155
- logger.warning_once(
156
- f"You are using the default legacy behaviour of the {self.__class__}. This is"
157
- " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
158
- " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
159
- " means, and thoroughly read the reason why this was added as explained in"
160
- " https://github.com/huggingface/transformers/pull/24565"
161
- )
162
- legacy = True
163
-
164
- self.legacy = legacy
165
- self.vocab_file = vocab_file
166
- self.add_bos_token = add_bos_token
167
- self.add_eos_token = add_eos_token
168
- self.use_default_system_prompt = use_default_system_prompt
169
- self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
170
- self.add_prefix_space = add_prefix_space
171
-
172
- super().__init__(
173
- bos_token=bos_token,
174
- eos_token=eos_token,
175
- unk_token=unk_token,
176
- pad_token=pad_token,
177
- add_bos_token=add_bos_token,
178
- add_eos_token=add_eos_token,
179
- sp_model_kwargs=self.sp_model_kwargs,
180
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
181
- use_default_system_prompt=use_default_system_prompt,
182
- spaces_between_special_tokens=spaces_between_special_tokens,
183
- legacy=legacy,
184
- add_prefix_space=add_prefix_space,
185
- **kwargs,
186
- )
187
-
188
- @property
189
- def unk_token_length(self):
190
- return len(self.sp_model.encode(str(self.unk_token)))
191
-
192
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
193
- def get_spm_processor(self, from_slow=False):
194
- tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
195
- if self.legacy or from_slow: # no dependency on protobuf
196
- tokenizer.Load(self.vocab_file)
197
- return tokenizer
198
-
199
- with open(self.vocab_file, "rb") as f:
200
- sp_model = f.read()
201
- model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
202
- model = model_pb2.ModelProto.FromString(sp_model)
203
- normalizer_spec = model_pb2.NormalizerSpec()
204
- normalizer_spec.add_dummy_prefix = False
205
- model.normalizer_spec.MergeFrom(normalizer_spec)
206
- sp_model = model.SerializeToString()
207
- tokenizer.LoadFromSerializedProto(sp_model)
208
- return tokenizer
209
-
210
- def __getstate__(self):
211
- state = self.__dict__.copy()
212
- state["sp_model"] = None
213
- state["sp_model_proto"] = self.sp_model.serialized_model_proto()
214
- return state
215
-
216
- def __setstate__(self, d):
217
- self.__dict__ = d
218
- self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
219
- self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
220
-
221
- @property
222
- def vocab_size(self):
223
- """Returns vocab size"""
224
- return self.sp_model.get_piece_size()
225
-
226
- def get_vocab(self):
227
- """Returns vocab as a dict"""
228
- vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
229
- vocab.update(self.added_tokens_encoder)
230
- return vocab
231
-
232
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
233
- def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
234
- """
235
- Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
236
- first token is special.
237
- """
238
- if self.legacy or len(text) == 0:
239
- return super().tokenize(text, **kwargs)
240
-
241
- text = text.replace(SPIECE_UNDERLINE, " ")
242
- if self.add_prefix_space:
243
- text = SPIECE_UNDERLINE + text
244
-
245
- tokens = super().tokenize(text, **kwargs)
246
-
247
- if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
248
- tokens = tokens[1:]
249
- return tokens
250
-
251
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
252
- def _tokenize(self, text, **kwargs):
253
- """
254
- Returns a tokenized string.
255
-
256
- We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
257
- SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
258
- `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
259
- `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
260
- `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
261
- """
262
- tokens = self.sp_model.encode(text, out_type=str)
263
- if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
264
- return tokens
265
-
266
- # 1. Encode string + prefix ex: "<unk> Hey"
267
- tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
268
- # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
269
- return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
270
-
271
- def _convert_token_to_id(self, token):
272
- """Converts a token (str) in an id using the vocab."""
273
- return self.sp_model.piece_to_id(token)
274
-
275
- def _convert_id_to_token(self, index):
276
- """Converts an index (integer) in a token (str) using the vocab."""
277
- token = self.sp_model.IdToPiece(index)
278
- return token
279
-
280
- def convert_tokens_to_string(self, tokens):
281
- """Converts a sequence of tokens (string) in a single string."""
282
- # since we manually add the prefix space, we have to remove it when decoding
283
- if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
284
- tokens[0] = tokens[0][1:]
285
-
286
- current_sub_tokens = []
287
- out_string = ""
288
- prev_is_special = False
289
- for i, token in enumerate(tokens):
290
- # make sure that special tokens are not decoded using sentencepiece model
291
- if token in self.all_special_tokens:
292
- if not prev_is_special and i != 0 and self.legacy:
293
- out_string += " "
294
- out_string += self.sp_model.decode(current_sub_tokens) + token
295
- prev_is_special = True
296
- current_sub_tokens = []
297
- else:
298
- if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE):
299
- out_string += " "
300
- current_sub_tokens.append(token)
301
- prev_is_special = False
302
- out_string += self.sp_model.decode(current_sub_tokens)
303
- return out_string
304
-
305
- def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
306
- """
307
- Save the vocabulary and special tokens file to a directory.
308
-
309
- Args:
310
- save_directory (`str`):
311
- The directory in which to save the vocabulary.
312
-
313
- Returns:
314
- `Tuple(str)`: Paths to the files saved.
315
- """
316
- if not os.path.isdir(save_directory):
317
- logger.error(f"Vocabulary path ({save_directory}) should be a directory")
318
- return
319
- out_vocab_file = os.path.join(
320
- save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
321
- )
322
-
323
- if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
324
- copyfile(self.vocab_file, out_vocab_file)
325
- elif not os.path.isfile(self.vocab_file):
326
- with open(out_vocab_file, "wb") as fi:
327
- content_spiece_model = self.sp_model.serialized_model_proto()
328
- fi.write(content_spiece_model)
329
-
330
- return (out_vocab_file,)
331
-
332
- def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
333
- bos_token_id = [self.bos_token_id] if self.add_bos_token else []
334
- eos_token_id = [self.eos_token_id] if self.add_eos_token else []
335
-
336
- output = bos_token_id + token_ids_0 + eos_token_id
337
-
338
- if token_ids_1 is not None:
339
- output = output + bos_token_id + token_ids_1 + eos_token_id
340
-
341
- return output
342
-
343
- def get_special_tokens_mask(
344
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
345
- ) -> List[int]:
346
- """
347
- Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
348
- special tokens using the tokenizer `prepare_for_model` method.
349
-
350
- Args:
351
- token_ids_0 (`List[int]`):
352
- List of IDs.
353
- token_ids_1 (`List[int]`, *optional*):
354
- Optional second list of IDs for sequence pairs.
355
- already_has_special_tokens (`bool`, *optional*, defaults to `False`):
356
- Whether or not the token list is already formatted with special tokens for the model.
357
-
358
- Returns:
359
- `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
360
- """
361
- if already_has_special_tokens:
362
- return super().get_special_tokens_mask(
363
- token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
364
- )
365
-
366
- bos_token_id = [1] if self.add_bos_token else []
367
- eos_token_id = [1] if self.add_eos_token else []
368
-
369
- if token_ids_1 is None:
370
- return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
371
- return (
372
- bos_token_id
373
- + ([0] * len(token_ids_0))
374
- + eos_token_id
375
- + bos_token_id
376
- + ([0] * len(token_ids_1))
377
- + eos_token_id
378
- )
379
-
380
- def create_token_type_ids_from_sequences(
381
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
382
- ) -> List[int]:
383
- """
384
- Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
385
- sequence pair mask has the following format:
386
-
387
- ```
388
- 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
389
- | first sequence | second sequence |
390
- ```
391
-
392
- if token_ids_1 is None, only returns the first portion of the mask (0s).
393
-
394
- Args:
395
- token_ids_0 (`List[int]`):
396
- List of ids.
397
- token_ids_1 (`List[int]`, *optional*):
398
- Optional second list of IDs for sequence pairs.
399
-
400
- Returns:
401
- `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
402
- """
403
- bos_token_id = [self.bos_token_id] if self.add_bos_token else []
404
- eos_token_id = [self.eos_token_id] if self.add_eos_token else []
405
-
406
- output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
407
-
408
- if token_ids_1 is not None:
409
- output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
410
-
411
- return output
412
-
413
- @property
414
- def default_chat_template(self):
415
- """
416
- LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
417
- Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
418
- user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
419
- rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
420
- results in an unusual token ordering when it is present. This template should definitely be changed if you wish
421
- to fine-tune a model with more flexible role ordering!
422
-
423
- The output should look something like:
424
-
425
- <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
426
- <bos>[INST] Prompt [/INST]
427
-
428
- The reference for this chat template is [this code
429
- snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
430
- in the original repository.
431
- """
432
- logger.warning_once(
433
- "\nNo chat template is defined for this tokenizer - using the default template "
434
- f"for the {self.__class__.__name__} class. If the default is not appropriate for "
435
- "your model, please set `tokenizer.chat_template` to an appropriate template. "
436
- "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
437
- )
438
- template = (
439
- "{% if messages[0]['role'] == 'system' %}"
440
- "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
441
- "{% set system_message = messages[0]['content'] %}"
442
- "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
443
- "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
444
- "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
445
- "{% else %}"
446
- "{% set loop_messages = messages %}"
447
- "{% set system_message = false %}"
448
- "{% endif %}"
449
- "{% for message in loop_messages %}" # Loop over all non-system messages
450
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
451
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
452
- "{% endif %}"
453
- "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
454
- "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
455
- "{% else %}"
456
- "{% set content = message['content'] %}"
457
- "{% endif %}"
458
- "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
459
- "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
460
- "{% elif message['role'] == 'system' %}"
461
- "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
462
- "{% elif message['role'] == 'assistant' %}"
463
- "{{ ' ' + content.strip() + ' ' + eos_token }}"
464
- "{% endif %}"
465
- "{% endfor %}"
466
- )
467
- template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
468
- default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
469
- template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
470
-
471
- return template
 
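The `build_inputs_with_special_tokens` and `get_special_tokens_mask` methods in the deleted tokenizer above reduce to a simple bos/eos list composition. A minimal, self-contained sketch of that behaviour, assuming hypothetical token ids (1 for `<s>`, 2 for `</s>`) rather than anything read from a real vocabulary:

```python
# Minimal sketch of the bos/eos composition and the special-tokens mask above.
# Token ids are hypothetical (1 = <s>, 2 = </s>); this is not code from the repository.
from typing import List, Optional

BOS_ID, EOS_ID = 1, 2

def build_inputs(ids_0: List[int], ids_1: Optional[List[int]] = None,
                 add_bos: bool = True, add_eos: bool = False) -> List[int]:
    # Mirrors build_inputs_with_special_tokens: bos + sequence + eos, repeated per pair.
    bos = [BOS_ID] if add_bos else []
    eos = [EOS_ID] if add_eos else []
    out = bos + ids_0 + eos
    if ids_1 is not None:
        out += bos + ids_1 + eos
    return out

def special_tokens_mask(ids_0: List[int], ids_1: Optional[List[int]] = None,
                        add_bos: bool = True, add_eos: bool = False) -> List[int]:
    # Mirrors get_special_tokens_mask: 1 marks a special token, 0 a sequence token.
    bos = [1] if add_bos else []
    eos = [1] if add_eos else []
    mask = bos + [0] * len(ids_0) + eos
    if ids_1 is not None:
        mask += bos + [0] * len(ids_1) + eos
    return mask

print(build_inputs([15043, 445], [263, 1243]))         # [1, 15043, 445, 1, 263, 1243]
print(special_tokens_mask([15043, 445], [263, 1243]))  # [1, 0, 0, 1, 0, 0]
```

With the default `add_bos=True, add_eos=False`, only the bos id is prepended to each sequence, which matches the template the fast tokenizer below builds in `update_post_processor`.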
bunny/model/language_model/llama/tokenization_llama_fast.py DELETED
@@ -1,281 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2020 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import os
16
- from shutil import copyfile
17
- from typing import Optional, Tuple
18
-
19
- from tokenizers import processors
20
-
21
- from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
22
- from transformers.utils import is_sentencepiece_available, logging
23
- from transformers.utils.versions import require_version
24
-
25
-
26
- require_version("tokenizers>=0.13.3")
27
-
28
- if is_sentencepiece_available():
29
- from .tokenization_llama import LlamaTokenizer
30
- else:
31
- LlamaTokenizer = None
32
-
33
- logger = logging.get_logger(__name__)
34
- VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
35
-
36
- B_INST, E_INST = "[INST]", "[/INST]"
37
- B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
38
-
39
- # fmt: off
40
- DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
41
- answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
42
- that your responses are socially unbiased and positive in nature.
43
-
44
- If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
45
- correct. If you don't know the answer to a question, please don't share false information."""
46
- # fmt: on
47
-
48
-
49
- class LlamaTokenizerFast(PreTrainedTokenizerFast):
50
- """
51
- Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
52
-
53
- This uses notably ByteFallback and no normalization.
54
-
55
- ```python
56
- >>> from transformers import LlamaTokenizerFast
57
-
58
- >>> tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
59
- >>> tokenizer.encode("Hello this is a test")
60
- [1, 15043, 445, 338, 263, 1243]
61
- ```
62
-
63
- If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
64
- call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
65
- values of the first token and final token of an encoded sequence will not be correct). For more details, check out
66
- the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
67
-
68
-
69
- This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
70
- refer to this superclass for more information regarding those methods.
71
-
72
- Args:
73
- vocab_file (`str`, *optional*):
74
- [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
75
- contains the vocabulary necessary to instantiate a tokenizer.
76
- tokenizer_file (`str`, *optional*):
77
- [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
78
- contains everything needed to load the tokenizer.
79
- clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
80
- Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
81
- extra spaces.
82
- unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
83
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
84
- token instead.
85
- bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
86
- The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
87
- eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
88
- The end of sequence token.
89
- add_bos_token (`bool`, *optional*, defaults to `True`):
90
- Whether or not to add a `bos_token` at the start of sequences.
91
- add_eos_token (`bool`, *optional*, defaults to `False`):
92
- Whether or not to add an `eos_token` at the end of sequences.
93
- use_default_system_prompt (`bool`, *optional*, defaults to `False`):
94
- Whether or not the default system prompt for Llama should be used.
95
- add_prefix_space (`bool`, *optional*):
96
- Whether or not the tokenizer should automatically add a prefix space
97
- """
98
-
99
- vocab_files_names = VOCAB_FILES_NAMES
100
- slow_tokenizer_class = LlamaTokenizer
101
- padding_side = "left"
102
- model_input_names = ["input_ids", "attention_mask"]
103
-
104
- def __init__(
105
- self,
106
- vocab_file=None,
107
- tokenizer_file=None,
108
- clean_up_tokenization_spaces=False,
109
- unk_token="<unk>",
110
- bos_token="<s>",
111
- eos_token="</s>",
112
- add_bos_token=True,
113
- add_eos_token=False,
114
- use_default_system_prompt=False,
115
- add_prefix_space=None,
116
- **kwargs,
117
- ):
118
- if add_prefix_space is not None:
119
- logger.warning_once(
120
- "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"
121
- )
122
- kwargs["from_slow"] = True
123
-
124
- super().__init__(
125
- vocab_file=vocab_file,
126
- tokenizer_file=tokenizer_file,
127
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
128
- unk_token=unk_token,
129
- bos_token=bos_token,
130
- eos_token=eos_token,
131
- add_bos_token=add_bos_token,
132
- add_eos_token=add_eos_token,
133
- use_default_system_prompt=use_default_system_prompt,
134
- **kwargs,
135
- )
136
- self._add_bos_token = add_bos_token
137
- self._add_eos_token = add_eos_token
138
- self.update_post_processor()
139
- self.use_default_system_prompt = use_default_system_prompt
140
- self.vocab_file = vocab_file
141
-
142
- @property
143
- def can_save_slow_tokenizer(self) -> bool:
144
- return os.path.isfile(self.vocab_file) if self.vocab_file else False
145
-
146
- def update_post_processor(self):
147
- """
148
- Updates the underlying post processor with the current `bos_token` and `eos_token`.
149
- """
150
- bos = self.bos_token
151
- bos_token_id = self.bos_token_id
152
- if bos is None and self.add_bos_token:
153
- raise ValueError("add_bos_token = True but bos_token = None")
154
-
155
- eos = self.eos_token
156
- eos_token_id = self.eos_token_id
157
- if eos is None and self.add_eos_token:
158
- raise ValueError("add_eos_token = True but eos_token = None")
159
-
160
- single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
161
- pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
162
-
163
- special_tokens = []
164
- if self.add_bos_token:
165
- special_tokens.append((bos, bos_token_id))
166
- if self.add_eos_token:
167
- special_tokens.append((eos, eos_token_id))
168
- self._tokenizer.post_processor = processors.TemplateProcessing(
169
- single=single, pair=pair, special_tokens=special_tokens
170
- )
171
-
172
- @property
173
- def add_eos_token(self):
174
- return self._add_eos_token
175
-
176
- @property
177
- def add_bos_token(self):
178
- return self._add_bos_token
179
-
180
- @add_eos_token.setter
181
- def add_eos_token(self, value):
182
- self._add_eos_token = value
183
- self.update_post_processor()
184
-
185
- @add_bos_token.setter
186
- def add_bos_token(self, value):
187
- self._add_bos_token = value
188
- self.update_post_processor()
189
-
190
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
191
- if not self.can_save_slow_tokenizer:
192
- raise ValueError(
193
- "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
194
- "tokenizer."
195
- )
196
-
197
- if not os.path.isdir(save_directory):
198
- logger.error(f"Vocabulary path ({save_directory}) should be a directory")
199
- return
200
- out_vocab_file = os.path.join(
201
- save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
202
- )
203
-
204
- if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
205
- copyfile(self.vocab_file, out_vocab_file)
206
-
207
- return (out_vocab_file,)
208
-
209
- @property
210
- # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
211
- def default_chat_template(self):
212
- """
213
- LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
214
- Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
215
- user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
216
- rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
217
- results in an unusual token ordering when it is present. This template should definitely be changed if you wish
218
- to fine-tune a model with more flexible role ordering!
219
-
220
- The output should look something like:
221
-
222
- <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
223
- <bos>[INST] Prompt [/INST]
224
-
225
- The reference for this chat template is [this code
226
- snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
227
- in the original repository.
228
- """
229
- logger.warning_once(
230
- "\nNo chat template is defined for this tokenizer - using the default template "
231
- f"for the {self.__class__.__name__} class. If the default is not appropriate for "
232
- "your model, please set `tokenizer.chat_template` to an appropriate template. "
233
- "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
234
- )
235
- template = (
236
- "{% if messages[0]['role'] == 'system' %}"
237
- "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
238
- "{% set system_message = messages[0]['content'] %}"
239
- "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
240
- "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
241
- "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
242
- "{% else %}"
243
- "{% set loop_messages = messages %}"
244
- "{% set system_message = false %}"
245
- "{% endif %}"
246
- "{% for message in loop_messages %}" # Loop over all non-system messages
247
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
248
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
249
- "{% endif %}"
250
- "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
251
- "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
252
- "{% else %}"
253
- "{% set content = message['content'] %}"
254
- "{% endif %}"
255
- "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
256
- "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
257
- "{% elif message['role'] == 'system' %}"
258
- "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
259
- "{% elif message['role'] == 'assistant' %}"
260
- "{{ ' ' + content.strip() + ' ' + eos_token }}"
261
- "{% endif %}"
262
- "{% endfor %}"
263
- )
264
- template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
265
- default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
266
- template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
267
-
268
- return template
269
-
270
- # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
271
- # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
272
- def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
273
- bos_token_id = [self.bos_token_id] if self.add_bos_token else []
274
- eos_token_id = [self.eos_token_id] if self.add_eos_token else []
275
-
276
- output = bos_token_id + token_ids_0 + eos_token_id
277
-
278
- if token_ids_1 is not None:
279
- output = output + bos_token_id + token_ids_1 + eos_token_id
280
-
281
- return output
 
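The `update_post_processor` method deleted above encodes the same bos/eos choices as `tokenizers` `TemplateProcessing` strings. A small sketch of how those template strings are assembled, using illustrative defaults (`bos="<s>"`, `eos="</s>"`, `add_bos=True`, `add_eos=False`) rather than values from any checkpoint:

```python
# Sketch of how update_post_processor above assembles the TemplateProcessing strings
# from add_bos_token / add_eos_token. Values are illustrative defaults, not read from
# any checkpoint; no tokenizers objects are constructed here.
def template_strings(bos: str = "<s>", eos: str = "</s>",
                     add_bos: bool = True, add_eos: bool = False):
    single = f"{(bos + ':0 ') if add_bos else ''}$A:0{(' ' + eos + ':0') if add_eos else ''}"
    pair = f"{single}{(' ' + bos + ':1') if add_bos else ''} $B:1{(' ' + eos + ':1') if add_eos else ''}"
    return single, pair

single, pair = template_strings()
print(single)  # <s>:0 $A:0
print(pair)    # <s>:0 $A:0 <s>:1 $B:1
```

In these templates, `$A` and `$B` stand for the first and second input sequence and the `:0`/`:1` suffixes are the sequence ids, as used by `tokenizers.processors.TemplateProcessing`.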
bunny/model/language_model/minicpm/configuration_minicpm.py DELETED
@@ -1,202 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- """ MiniCPM model configuration"""
21
-
22
- from transformers.configuration_utils import PretrainedConfig
23
- from transformers.utils import logging
24
-
25
-
26
- logger = logging.get_logger(__name__)
27
-
28
- MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
29
-
30
-
31
- class MiniCPMConfig(PretrainedConfig):
32
- r"""
33
- This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM
34
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
35
- defaults will yield a similar configuration to that of the MiniCPM-7B.
36
-
37
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
- documentation from [`PretrainedConfig`] for more information.
39
-
40
-
41
- Args:
42
- vocab_size (`int`, *optional*, defaults to 32000):
43
- Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the
44
- `inputs_ids` passed when calling [`MiniCPMModel`]
45
- hidden_size (`int`, *optional*, defaults to 4096):
46
- Dimension of the hidden representations.
47
- intermediate_size (`int`, *optional*, defaults to 11008):
48
- Dimension of the MLP representations.
49
- num_hidden_layers (`int`, *optional*, defaults to 32):
50
- Number of hidden layers in the Transformer decoder.
51
- num_attention_heads (`int`, *optional*, defaults to 32):
52
- Number of attention heads for each attention layer in the Transformer decoder.
53
- num_key_value_heads (`int`, *optional*):
54
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
55
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
56
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
57
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
58
- by meanpooling all the original heads within that group. For more details checkout [this
59
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
60
- `num_attention_heads`.
61
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62
- The non-linear activation function (function or string) in the decoder.
63
- max_position_embeddings (`int`, *optional*, defaults to 2048):
64
- The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens,
65
- MiniCPM 2 up to 4096, CodeMiniCPM up to 16384.
66
- initializer_range (`float`, *optional*, defaults to 0.02):
67
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
69
- The epsilon used by the rms normalization layers.
70
- use_cache (`bool`, *optional*, defaults to `True`):
71
- Whether or not the model should return the last key/values attentions (not used by all models). Only
72
- relevant if `config.is_decoder=True`.
73
- pad_token_id (`int`, *optional*):
74
- Padding token id.
75
- bos_token_id (`int`, *optional*, defaults to 1):
76
- Beginning of stream token id.
77
- eos_token_id (`int`, *optional*, defaults to 2):
78
- End of stream token id.
79
- pretraining_tp (`int`, *optional*, defaults to 1):
80
- Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
81
- document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
82
- necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
83
- issue](https://github.com/pytorch/pytorch/issues/76232).
84
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
85
- Whether to tie the input and output word embeddings.
86
- rope_theta (`float`, *optional*, defaults to 10000.0):
87
- The base period of the RoPE embeddings.
88
- rope_scaling (`Dict`, *optional*):
89
- Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
90
- strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
91
- `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
92
- `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
93
- these scaling strategies behave:
94
- https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
95
- experimental feature, subject to breaking API changes in future versions.
96
- attention_bias (`bool`, *optional*, defaults to `False`):
97
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
98
- attention_dropout (`float`, *optional*, defaults to 0.0):
99
- The dropout ratio for the attention probabilities.
100
-
101
- ```python
102
- >>> from transformers import MiniCPMModel, MiniCPMConfig
103
-
104
- >>> # Initializing a MiniCPM minicpm-7b style configuration
105
- >>> configuration = MiniCPMConfig()
106
-
107
- >>> # Initializing a model from the minicpm-7b style configuration
108
- >>> model = MiniCPMModel(configuration)
109
-
110
- >>> # Accessing the model configuration
111
- >>> configuration = model.config
112
- ```"""
113
-
114
- model_type = "minicpm"
115
- keys_to_ignore_at_inference = ["past_key_values"]
116
-
117
- def __init__(
118
- self,
119
- vocab_size=32000,
120
- hidden_size=4096,
121
- intermediate_size=11008,
122
- num_hidden_layers=32,
123
- num_attention_heads=32,
124
- num_key_value_heads=None,
125
- hidden_act="silu",
126
- max_position_embeddings=2048,
127
- initializer_range=0.02,
128
- rms_norm_eps=1e-6,
129
- use_cache=True,
130
- pad_token_id=None,
131
- bos_token_id=1,
132
- eos_token_id=2,
133
- pretraining_tp=1,
134
- tie_word_embeddings=True,
135
- rope_theta=10000.0,
136
- rope_scaling=None,
137
- attention_bias=False,
138
- attention_dropout=0.0,
139
- scale_emb=1,
140
- dim_model_base=1,
141
- scale_depth=1,
142
- **kwargs,
143
- ):
144
- self.vocab_size = vocab_size
145
- self.max_position_embeddings = max_position_embeddings
146
- self.hidden_size = hidden_size
147
- self.intermediate_size = intermediate_size
148
- self.num_hidden_layers = num_hidden_layers
149
- self.num_attention_heads = num_attention_heads
150
-
151
- # for backward compatibility
152
- if num_key_value_heads is None:
153
- num_key_value_heads = num_attention_heads
154
-
155
- self.num_key_value_heads = num_key_value_heads
156
- self.hidden_act = hidden_act
157
- self.initializer_range = initializer_range
158
- self.rms_norm_eps = rms_norm_eps
159
- self.pretraining_tp = pretraining_tp
160
- self.use_cache = use_cache
161
- self.rope_theta = rope_theta
162
- self.rope_scaling = rope_scaling
163
- self._rope_scaling_validation()
164
- self.attention_bias = attention_bias
165
- self.attention_dropout = attention_dropout
166
- self.scale_emb = scale_emb
167
- self.dim_model_base = dim_model_base
168
- self.scale_depth = scale_depth
169
-
170
- super().__init__(
171
- pad_token_id=pad_token_id,
172
- bos_token_id=bos_token_id,
173
- eos_token_id=eos_token_id,
174
- tie_word_embeddings=tie_word_embeddings,
175
- **kwargs,
176
- )
177
- try:
178
- import flash_attn
179
- self._attn_implementation = "flash_attention_2"
180
- except:
181
- pass
182
-
183
- def _rope_scaling_validation(self):
184
- """
185
- Validate the `rope_scaling` configuration.
186
- """
187
- if self.rope_scaling is None:
188
- return
189
-
190
- if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
191
- raise ValueError(
192
- "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
193
- f"got {self.rope_scaling}"
194
- )
195
- rope_scaling_type = self.rope_scaling.get("type", None)
196
- rope_scaling_factor = self.rope_scaling.get("factor", None)
197
- if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
198
- raise ValueError(
199
- f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
200
- )
201
- if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
202
- raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
 
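The `_rope_scaling_validation` helper deleted above accepts only a two-key dict whose `type` is `linear` or `dynamic` and whose `factor` is a float greater than 1. A standalone sketch of that check (an illustration, not code imported from this repository):

```python
# Standalone sketch of the rope_scaling validation above: a two-key dict with
# "type" in {"linear", "dynamic"} and a float "factor" > 1. Not imported from
# this repository; shown only to summarise the accepted shape.
def validate_rope_scaling(rope_scaling):
    if rope_scaling is None:
        return
    if not isinstance(rope_scaling, dict) or len(rope_scaling) != 2:
        raise ValueError(f"`rope_scaling` must be a dict with `type` and `factor`, got {rope_scaling}")
    rope_type = rope_scaling.get("type")
    factor = rope_scaling.get("factor")
    if rope_type not in ("linear", "dynamic"):
        raise ValueError(f"`rope_scaling`'s type must be 'linear' or 'dynamic', got {rope_type}")
    if not isinstance(factor, float) or factor <= 1.0:
        raise ValueError(f"`rope_scaling`'s factor must be a float > 1, got {factor}")

validate_rope_scaling({"type": "dynamic", "factor": 2.0})   # passes silently
# validate_rope_scaling({"type": "yarn", "factor": 2.0})    # would raise ValueError
```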
bunny/model/language_model/minicpm/modeling_minicpm.py DELETED
@@ -1,1456 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- """ PyTorch MiniCPM model."""
21
- import math
22
- import warnings
23
- from typing import List, Optional, Tuple, Union, Dict
24
-
25
- import torch
26
- import torch.nn.functional as F
27
- import torch.utils.checkpoint
28
- from torch import nn
29
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
-
31
- from transformers.activations import ACT2FN
32
- from transformers.cache_utils import Cache, DynamicCache
33
- from transformers.modeling_attn_mask_utils import (
34
- AttentionMaskConverter,
35
- _prepare_4d_attention_mask,
36
- _prepare_4d_causal_attention_mask,
37
- _prepare_4d_causal_attention_mask_for_sdpa,
38
- )
39
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
40
- from transformers.modeling_utils import PreTrainedModel
41
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
42
- from transformers.utils import (
43
- add_start_docstrings,
44
- add_start_docstrings_to_model_forward,
45
- is_flash_attn_2_available,
46
- is_flash_attn_greater_or_equal_2_10,
47
- logging,
48
- replace_return_docstrings,
49
- )
50
- from transformers.utils.import_utils import is_torch_fx_available
51
- from .configuration_minicpm import MiniCPMConfig
52
- import re
53
-
54
- try:
55
- from flash_attn import flash_attn_func, flash_attn_varlen_func
56
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
57
- except:
58
- pass
59
-
60
-
61
- # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
62
- # It means that the function will not be traced through and simply appear as a node in the graph.
63
- if is_torch_fx_available():
64
- if not is_torch_greater_or_equal_than_1_13:
65
- import torch.fx
66
-
67
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
68
-
69
-
70
- logger = logging.get_logger(__name__)
71
-
72
- _CONFIG_FOR_DOC = "MiniCPMConfig"
73
-
74
-
75
- def _get_unpad_data(attention_mask):
76
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
77
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
78
- max_seqlen_in_batch = seqlens_in_batch.max().item()
79
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
80
- return (
81
- indices,
82
- cu_seqlens,
83
- max_seqlen_in_batch,
84
- )
85
-
86
-
87
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
88
- warnings.warn(
89
- "Calling `transformers.models.minicpm.modeling_minicpm._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask` instead."
90
- )
91
- return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
92
-
93
-
94
- def _make_causal_mask(
95
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
96
- ):
97
- warnings.warn(
98
- "Calling `transformers.models.minicpm.modeling_minicpm._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.minicpm.modeling_minicpm.AttentionMaskConverter._make_causal_mask` instead."
99
- )
100
- return AttentionMaskConverter._make_causal_mask(
101
- input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
102
- )
103
-
104
- # @torch.jit.script # type: ignore
105
- def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
106
- old_dtype = hidden.dtype
107
- variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
108
- hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
109
- return hidden * weight
110
-
111
-
112
- class MiniCPMRMSNorm(nn.Module):
113
- def __init__(self, hidden_size, eps=1e-6):
114
- """
115
- MiniCPMRMSNorm is equivalent to T5LayerNorm
116
- """
117
- super().__init__()
118
- self.weight = nn.Parameter(torch.ones(hidden_size))
119
- self.variance_epsilon = eps
120
-
121
- def forward(self, hidden_states):
122
- return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
123
-
124
-
125
- ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm)
126
-
127
-
128
- class MiniCPMRotaryEmbedding(nn.Module):
129
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
130
- super().__init__()
131
-
132
- self.dim = dim
133
- self.max_position_embeddings = max_position_embeddings
134
- self.base = base
135
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
136
- self.register_buffer("inv_freq", inv_freq, persistent=False)
137
-
138
- # Build here to make `torch.jit.trace` work.
139
- self._set_cos_sin_cache(
140
- # seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
141
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
142
- )
143
-
144
- def _set_cos_sin_cache(self, seq_len, device, dtype):
145
- self.max_seq_len_cached = seq_len
146
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
147
- freqs = torch.outer(t, self.inv_freq)
148
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
149
- emb = torch.cat((freqs, freqs), dim=-1)
150
-
151
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
152
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
153
-
154
- def forward(self, x, seq_len=None):
155
- # x: [bs, num_attention_heads, seq_len, head_size]
156
- if seq_len > self.max_seq_len_cached:
157
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
158
-
159
- return (
160
- self.cos_cached[:seq_len].to(dtype=x.dtype),
161
- self.sin_cached[:seq_len].to(dtype=x.dtype),
162
- )
163
-
164
-
165
- class MiniCPMLinearScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
166
- """MiniCPMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
167
-
168
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
169
- self.scaling_factor = scaling_factor
170
- super().__init__(dim, max_position_embeddings, base, device)
171
-
172
- def _set_cos_sin_cache(self, seq_len, device, dtype):
173
- self.max_seq_len_cached = seq_len
174
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
175
- t = t / self.scaling_factor
176
-
177
- freqs = torch.outer(t, self.inv_freq)
178
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
179
- emb = torch.cat((freqs, freqs), dim=-1)
180
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
181
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
182
-
183
-
184
- class MiniCPMDynamicNTKScalingRotaryEmbedding(MiniCPMRotaryEmbedding):
185
- """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
186
-
187
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
188
- self.scaling_factor = scaling_factor
189
- super().__init__(dim, max_position_embeddings, base, device)
190
-
191
- def _set_cos_sin_cache(self, seq_len, device, dtype):
192
- self.max_seq_len_cached = seq_len
193
-
194
- if seq_len > self.max_position_embeddings:
195
- base = self.base * (
196
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
197
- ) ** (self.dim / (self.dim - 2))
198
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
199
- self.register_buffer("inv_freq", inv_freq, persistent=False)
200
-
201
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
202
-
203
- freqs = torch.outer(t, self.inv_freq)
204
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
205
- emb = torch.cat((freqs, freqs), dim=-1)
206
-
207
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
208
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
209
-
210
-
211
- def rotate_half(x):
212
- """Rotates half the hidden dims of the input."""
213
- x1 = x[..., : x.shape[-1] // 2]
214
- x2 = x[..., x.shape[-1] // 2 :]
215
- return torch.cat((-x2, x1), dim=-1)
216
-
217
-
218
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
219
- """Applies Rotary Position Embedding to the query and key tensors.
220
-
221
- Args:
222
- q (`torch.Tensor`): The query tensor.
223
- k (`torch.Tensor`): The key tensor.
224
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
225
- sin (`torch.Tensor`): The sine part of the rotary embedding.
226
- position_ids (`torch.Tensor`):
227
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
228
- used to pass offsetted position ids when working with a KV-cache.
229
- unsqueeze_dim (`int`, *optional*, defaults to 1):
230
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
231
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
232
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
233
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
234
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
235
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
236
- Returns:
237
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
238
- """
239
- # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
240
- # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
241
- # q_embed = (q * cos) + (rotate_half(q) * sin)
242
- # k_embed = (k * cos) + (rotate_half(k) * sin)
243
- orig_dtype = k.dtype
244
- cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
245
- sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
246
- q_fp32 = q.to(dtype=torch.float32, device=q.device)
247
- k_fp32 = k.to(dtype=torch.float32, device=k.device)
248
- q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
249
- k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
250
- return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
251
-
252
- class MiniCPMMLP(nn.Module):
253
- def __init__(self, config):
254
- super().__init__()
255
- self.config = config
256
- self.hidden_size = config.hidden_size
257
- self.intermediate_size = config.intermediate_size
258
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
259
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
260
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
261
- self.act_fn = ACT2FN[config.hidden_act]
262
-
263
- def forward(self, x):
264
- if self.config.pretraining_tp > 1:
265
- slice = self.intermediate_size // self.config.pretraining_tp
266
- gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
267
- up_proj_slices = self.up_proj.weight.split(slice, dim=0)
268
- down_proj_slices = self.down_proj.weight.split(slice, dim=1)
269
-
270
- gate_proj = torch.cat(
271
- [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
272
- )
273
- up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
274
-
275
- intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
276
- down_proj = [
277
- F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
278
- ]
279
- down_proj = sum(down_proj)
280
- else:
281
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
282
-
283
- return down_proj
284
-
285
-
286
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
287
- """
288
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
289
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
290
- """
291
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
292
- if n_rep == 1:
293
- return hidden_states
294
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
295
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
296
-
297
-
298
-
299
- class MiniCPMAttention(nn.Module):
300
- """Multi-headed attention from 'Attention Is All You Need' paper"""
301
-
302
- def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None):
303
- super().__init__()
304
- self.config = config
305
- self.layer_idx = layer_idx
306
- if layer_idx is None:
307
- logger.warning_once(
308
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
309
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
310
- "when creating this class."
311
- )
312
-
313
- self.attention_dropout = config.attention_dropout
314
- self.hidden_size = config.hidden_size
315
- self.num_heads = config.num_attention_heads
316
- self.head_dim = self.hidden_size // self.num_heads
317
- self.num_key_value_heads = config.num_key_value_heads
318
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
319
- self.max_position_embeddings = config.max_position_embeddings
320
- self.rope_theta = config.rope_theta
321
- self.is_causal = True
322
-
323
- if (self.head_dim * self.num_heads) != self.hidden_size:
324
- raise ValueError(
325
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
326
- f" and `num_heads`: {self.num_heads})."
327
- )
328
-
329
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
330
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
331
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
332
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
333
- self._init_rope()
334
-
335
- def _init_rope(self):
336
- if self.config.rope_scaling is None:
337
- self.rotary_emb = MiniCPMRotaryEmbedding(
338
- self.head_dim,
339
- max_position_embeddings=self.max_position_embeddings,
340
- base=self.rope_theta,
341
- )
342
- else:
343
- scaling_type = self.config.rope_scaling["type"]
344
- scaling_factor = self.config.rope_scaling["factor"]
345
- if scaling_type == "linear":
346
- self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding(
347
- self.head_dim,
348
- max_position_embeddings=self.max_position_embeddings,
349
- scaling_factor=scaling_factor,
350
- base=self.rope_theta,
351
- )
352
- elif scaling_type == "dynamic":
353
- self.rotary_emb = MiniCPMDynamicNTKScalingRotaryEmbedding(
354
- self.head_dim,
355
- max_position_embeddings=self.max_position_embeddings,
356
- scaling_factor=scaling_factor,
357
- base=self.rope_theta,
358
- )
359
- else:
360
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
361
-
362
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
363
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
364
-
365
- def forward(
366
- self,
367
- hidden_states: torch.Tensor,
368
- attention_mask: Optional[torch.Tensor] = None,
369
- position_ids: Optional[torch.LongTensor] = None,
370
- past_key_value: Optional[Cache] = None,
371
- output_attentions: bool = False,
372
- use_cache: bool = False,
373
- **kwargs,
374
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
375
- if "padding_mask" in kwargs:
376
- warnings.warn(
377
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
378
- )
379
-
380
- bsz, q_len, _ = hidden_states.size()
381
-
382
- if self.config.pretraining_tp > 1:
383
- key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
384
- query_slices = self.q_proj.weight.split(
385
- (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
386
- )
387
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
388
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
389
-
390
- query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
391
- query_states = torch.cat(query_states, dim=-1)
392
-
393
- key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
394
- key_states = torch.cat(key_states, dim=-1)
395
-
396
- value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
397
- value_states = torch.cat(value_states, dim=-1)
398
-
399
- else:
400
- query_states = self.q_proj(hidden_states)
401
- key_states = self.k_proj(hidden_states)
402
- value_states = self.v_proj(hidden_states)
403
-
404
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
405
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
406
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
407
-
408
- kv_seq_len = key_states.shape[-2]
409
- if past_key_value is not None:
410
- if self.layer_idx is None:
411
- raise ValueError(
412
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
413
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
414
- "with a layer index."
415
- )
416
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
417
- cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
418
-
419
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
420
-
421
- if past_key_value is not None:
422
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
423
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
424
-
425
- key_states = repeat_kv(key_states, self.num_key_value_groups)
426
- value_states = repeat_kv(value_states, self.num_key_value_groups)
427
-
428
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
429
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
430
- raise ValueError(
431
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
432
- f" {attn_weights.size()}"
433
- )
434
-
435
- if attention_mask is not None:
436
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
437
- raise ValueError(
438
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
439
- )
440
- attn_weights = attn_weights + attention_mask
441
-
442
- # upcast attention to fp32
443
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
444
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
445
- attn_output = torch.matmul(attn_weights, value_states)
446
-
447
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
448
- raise ValueError(
449
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
450
- f" {attn_output.size()}"
451
- )
452
-
453
- attn_output = attn_output.transpose(1, 2).contiguous()
454
-
455
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
456
-
457
- if self.config.pretraining_tp > 1:
458
- attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
459
- o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
460
- attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
461
- else:
462
- attn_output = self.o_proj(attn_output)
463
-
464
- if not output_attentions:
465
- attn_weights = None
466
-
467
- return attn_output, attn_weights, past_key_value
468
-
469
-
470
- class MiniCPMFlashAttention2(MiniCPMAttention):
471
- """
472
- MiniCPM flash attention module. This module inherits from `MiniCPMAttention` as the weights of the module stay
473
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
474
- flash attention and deal with padding tokens in case the input contains any of them.
475
- """
476
-
477
- def __init__(self, *args, **kwargs):
478
- super().__init__(*args, **kwargs)
479
-
480
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
481
- # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is a bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
482
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
483
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
484
-
485
- def forward(
486
- self,
487
- hidden_states: torch.Tensor,
488
- attention_mask: Optional[torch.LongTensor] = None,
489
- position_ids: Optional[torch.LongTensor] = None,
490
- past_key_value: Optional[Cache] = None,
491
- output_attentions: bool = False,
492
- use_cache: bool = False,
493
- **kwargs,
494
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
495
- # MiniCPMFlashAttention2 attention does not support output_attentions
496
- if "padding_mask" in kwargs:
497
- warnings.warn(
498
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
499
- )
500
-
501
- # overwrite attention_mask with padding_mask
502
- attention_mask = kwargs.pop("padding_mask")
503
-
504
- output_attentions = False
505
-
506
- bsz, q_len, _ = hidden_states.size()
507
-
508
- query_states = self.q_proj(hidden_states)
509
- key_states = self.k_proj(hidden_states)
510
- value_states = self.v_proj(hidden_states)
511
-
512
- # Flash attention requires the input to have the shape
513
- # batch_size x seq_length x num_heads x head_dim
514
- # therefore we just need to keep the original shape
515
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
516
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
517
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
518
-
519
- kv_seq_len = key_states.shape[-2]
520
- if past_key_value is not None:
521
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
522
- cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
523
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
524
-
525
- if past_key_value is not None:
526
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
527
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
528
-
529
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
530
- # to be able to avoid many of these transpose/reshape/view.
531
- query_states = query_states.transpose(1, 2)
532
- key_states = key_states.transpose(1, 2)
533
- value_states = value_states.transpose(1, 2)
534
-
535
- dropout_rate = self.attention_dropout if self.training else 0.0
536
-
537
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
538
- # therefore the input hidden states gets silently casted in float32. Hence, we need
539
- # cast them back in the correct dtype just to be sure everything works as expected.
540
- # This might slow down training & inference so it is recommended not to cast the LayerNorms
541
- # in fp32. (MiniCPMRMSNorm handles it correctly)
542
-
543
- input_dtype = query_states.dtype
544
- if input_dtype == torch.float32:
545
- # Handle the case where the model is quantized
546
- if hasattr(self.config, "_pre_quantization_dtype"):
547
- target_dtype = self.config._pre_quantization_dtype
548
- else:
549
- target_dtype = self.q_proj.weight.dtype
550
-
551
- logger.warning_once(
552
- f"The input hidden states seems to be silently casted in float32, this might be related to"
553
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
554
- f" {target_dtype}."
555
- )
556
-
557
- query_states = query_states.to(target_dtype)
558
- key_states = key_states.to(target_dtype)
559
- value_states = value_states.to(target_dtype)
560
-
561
- attn_output = self._flash_attention_forward(
562
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
563
- )
564
-
565
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
566
- attn_output = self.o_proj(attn_output)
567
-
568
- if not output_attentions:
569
- attn_weights = None
570
-
571
- return attn_output, attn_weights, past_key_value
572
-
573
- def _flash_attention_forward(
574
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
575
- ):
576
- """
577
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
578
- first unpad the input, then compute the attention scores and pad the final attention scores.
579
-
580
- Args:
581
- query_states (`torch.Tensor`):
582
- Input query states to be passed to Flash Attention API
583
- key_states (`torch.Tensor`):
584
- Input key states to be passed to Flash Attention API
585
- value_states (`torch.Tensor`):
586
- Input value states to be passed to Flash Attention API
587
- attention_mask (`torch.Tensor`):
588
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
589
- position of padding tokens and 1 for the position of non-padding tokens.
590
- dropout (`int`, *optional*):
591
- Attention dropout
592
- softmax_scale (`float`, *optional*):
593
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
594
- """
595
- if not self._flash_attn_uses_top_left_mask:
596
- causal = self.is_causal
597
- else:
598
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
599
- causal = self.is_causal and query_length != 1
600
- # Contains at least one padding token in the sequence
601
- if attention_mask is not None:
602
- batch_size = query_states.shape[0]
603
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
604
- query_states, key_states, value_states, attention_mask, query_length
605
- )
606
-
607
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
608
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
609
- attn_output_unpad = flash_attn_varlen_func(
610
- query_states,
611
- key_states,
612
- value_states,
613
- cu_seqlens_q=cu_seqlens_q,
614
- cu_seqlens_k=cu_seqlens_k,
615
- max_seqlen_q=max_seqlen_in_batch_q,
616
- max_seqlen_k=max_seqlen_in_batch_k,
617
- dropout_p=dropout,
618
- softmax_scale=softmax_scale,
619
- causal=causal,
620
- )
621
-
622
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
623
- else:
624
- attn_output = flash_attn_func(
625
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
626
- )
627
-
628
- return attn_output
629
-
630
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
631
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
632
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
633
-
634
- key_layer = index_first_axis(
635
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
636
- )
637
- value_layer = index_first_axis(
638
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
639
- )
640
- if query_length == kv_seq_len:
641
- query_layer = index_first_axis(
642
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
643
- )
644
- cu_seqlens_q = cu_seqlens_k
645
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
646
- indices_q = indices_k
647
- elif query_length == 1:
648
- max_seqlen_in_batch_q = 1
649
- cu_seqlens_q = torch.arange(
650
- batch_size + 1, dtype=torch.int32, device=query_layer.device
651
- ) # There is a memcpy here, that is very bad.
652
- indices_q = cu_seqlens_q[:-1]
653
- query_layer = query_layer.squeeze(1)
654
- else:
655
- # The -q_len: slice assumes left padding.
656
- attention_mask = attention_mask[:, -query_length:]
657
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
658
-
659
- return (
660
- query_layer,
661
- key_layer,
662
- value_layer,
663
- indices_q,
664
- (cu_seqlens_q, cu_seqlens_k),
665
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
666
- )
667
-
668
-
669
- class MiniCPMSdpaAttention(MiniCPMAttention):
670
- """
671
- MiniCPM attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
672
- `MiniCPMAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
673
- SDPA API.
674
- """
675
-
676
- # Adapted from MiniCPMAttention.forward
677
- def forward(
678
- self,
679
- hidden_states: torch.Tensor,
680
- attention_mask: Optional[torch.Tensor] = None,
681
- position_ids: Optional[torch.LongTensor] = None,
682
- past_key_value: Optional[Cache] = None,
683
- output_attentions: bool = False,
684
- use_cache: bool = False,
685
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
686
- if output_attentions:
687
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
688
- logger.warning_once(
689
- "MiniCPMModel is using MiniCPMSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
690
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
691
- )
692
- return super().forward(
693
- hidden_states=hidden_states,
694
- attention_mask=attention_mask,
695
- position_ids=position_ids,
696
- past_key_value=past_key_value,
697
- output_attentions=output_attentions,
698
- use_cache=use_cache,
699
- )
700
-
701
- bsz, q_len, _ = hidden_states.size()
702
-
703
- query_states = self.q_proj(hidden_states)
704
- key_states = self.k_proj(hidden_states)
705
- value_states = self.v_proj(hidden_states)
706
-
707
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
708
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
709
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
710
-
711
- kv_seq_len = key_states.shape[-2]
712
- if past_key_value is not None:
713
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
714
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
715
-
716
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
717
-
718
- if past_key_value is not None:
719
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
720
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
721
-
722
- key_states = repeat_kv(key_states, self.num_key_value_groups)
723
- value_states = repeat_kv(value_states, self.num_key_value_groups)
724
-
725
- if attention_mask is not None:
726
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
727
- raise ValueError(
728
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
729
- )
730
-
731
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
732
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
733
- if query_states.device.type == "cuda" and attention_mask is not None:
734
- query_states = query_states.contiguous()
735
- key_states = key_states.contiguous()
736
- value_states = value_states.contiguous()
737
-
738
- attn_output = torch.nn.functional.scaled_dot_product_attention(
739
- query_states,
740
- key_states,
741
- value_states,
742
- attn_mask=attention_mask,
743
- dropout_p=self.attention_dropout if self.training else 0.0,
744
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
745
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
746
- )
747
-
748
- attn_output = attn_output.transpose(1, 2).contiguous()
749
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
750
-
751
- attn_output = self.o_proj(attn_output)
752
-
753
- return attn_output, None, past_key_value
754
-
755
-
756
- MINICPM_ATTENTION_CLASSES = {
757
- "eager": MiniCPMAttention,
758
- "flash_attention_2": MiniCPMFlashAttention2,
759
- "sdpa": MiniCPMSdpaAttention,
760
- }
761
-
762
-
763
- class MiniCPMDecoderLayer(nn.Module):
764
- def __init__(self, config: MiniCPMConfig, layer_idx: int):
765
- super().__init__()
766
- self.hidden_size = config.hidden_size
767
- self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
768
-
769
- self.mlp = MiniCPMMLP(config)
770
- self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
771
- self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
772
-
773
- self.scale_depth = config.scale_depth
774
- self.num_hidden_layers = config.num_hidden_layers
775
-
776
- def forward(
777
- self,
778
- hidden_states: torch.Tensor,
779
- attention_mask: Optional[torch.Tensor] = None,
780
- position_ids: Optional[torch.LongTensor] = None,
781
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
782
- output_attentions: Optional[bool] = False,
783
- use_cache: Optional[bool] = False,
784
- **kwargs,
785
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
786
- """
787
- Args:
788
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
789
- attention_mask (`torch.FloatTensor`, *optional*):
790
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
791
- query_sequence_length, key_sequence_length)` if default attention is used.
792
- output_attentions (`bool`, *optional*):
793
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
794
- returned tensors for more detail.
795
- use_cache (`bool`, *optional*):
796
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
797
- (see `past_key_values`).
798
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
799
- """
800
- if "padding_mask" in kwargs:
801
- warnings.warn(
802
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
803
- )
804
-
805
- residual = hidden_states
806
- hidden_states = self.input_layernorm(hidden_states)
807
- # Self Attention
808
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
809
- hidden_states=hidden_states,
810
- attention_mask=attention_mask,
811
- position_ids=position_ids,
812
- past_key_value=past_key_value,
813
- output_attentions=output_attentions,
814
- use_cache=use_cache,
815
- **kwargs,
816
- )
817
-
818
- hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
819
-
820
- # Fully Connected
821
- residual = hidden_states
822
- hidden_states = self.post_attention_layernorm(hidden_states)
823
-
824
- hidden_states = self.mlp(hidden_states)
825
- hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
826
-
827
- outputs = (hidden_states,)
828
-
829
- if output_attentions:
830
- outputs += (self_attn_weights,)
831
-
832
- if use_cache:
833
- outputs += (present_key_value,)
834
-
835
- return outputs
836
-
837
-
838
- MINICPM_START_DOCSTRING = r"""
839
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
840
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
841
- etc.)
842
-
843
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
844
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
845
- and behavior.
846
-
847
- Parameters:
848
- config ([`MiniCPMConfig`]):
849
- Model configuration class with all the parameters of the model. Initializing with a config file does not
850
- load the weights associated with the model, only the configuration. Check out the
851
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
852
- """
853
-
854
-
855
- @add_start_docstrings(
856
- "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.",
857
- MINICPM_START_DOCSTRING,
858
- )
859
- class MiniCPMPreTrainedModel(PreTrainedModel):
860
- config_class = MiniCPMConfig
861
- base_model_prefix = "model"
862
- supports_gradient_checkpointing = True
863
- _no_split_modules = ["MiniCPMDecoderLayer"]
864
- _skip_keys_device_placement = "past_key_values"
865
- _supports_flash_attn_2 = True
866
- _supports_sdpa = True
867
- _supports_cache_class = True
868
-
869
- def _init_weights(self, module):
870
- std = self.config.initializer_range
871
- if isinstance(module, nn.Linear):
872
- module.weight.data.normal_(mean=0.0, std=std)
873
- if module.bias is not None:
874
- module.bias.data.zero_()
875
- elif isinstance(module, nn.Embedding):
876
- module.weight.data.normal_(mean=0.0, std=std)
877
- if module.padding_idx is not None:
878
- module.weight.data[module.padding_idx].zero_()
879
-
880
-
881
- MINICPM_INPUTS_DOCSTRING = r"""
882
- Args:
883
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
884
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
885
- it.
886
-
887
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
888
- [`PreTrainedTokenizer.__call__`] for details.
889
-
890
- [What are input IDs?](../glossary#input-ids)
891
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
892
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
893
-
894
- - 1 for tokens that are **not masked**,
895
- - 0 for tokens that are **masked**.
896
-
897
- [What are attention masks?](../glossary#attention-mask)
898
-
899
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
900
- [`PreTrainedTokenizer.__call__`] for details.
901
-
902
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
903
- `past_key_values`).
904
-
905
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
906
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
907
- information on the default strategy.
908
-
909
- - 1 indicates the head is **not masked**,
910
- - 0 indicates the head is **masked**.
911
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
912
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
913
- config.n_positions - 1]`.
914
-
915
- [What are position IDs?](../glossary#position-ids)
916
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
917
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
918
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
919
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
920
-
921
- Two formats are allowed:
922
- - a [`~cache_utils.Cache`] instance;
923
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
924
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
925
- cache format.
926
-
927
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
928
- legacy cache format will be returned.
929
-
930
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
931
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
932
- of shape `(batch_size, sequence_length)`.
933
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
934
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
935
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
936
- model's internal embedding lookup matrix.
937
- use_cache (`bool`, *optional*):
938
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
939
- `past_key_values`).
940
- output_attentions (`bool`, *optional*):
941
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
942
- tensors for more detail.
943
- output_hidden_states (`bool`, *optional*):
944
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
945
- more detail.
946
- return_dict (`bool`, *optional*):
947
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
948
- """
949
-
950
-
951
- @add_start_docstrings(
952
- "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.",
953
- MINICPM_START_DOCSTRING,
954
- )
955
- class MiniCPMModel(MiniCPMPreTrainedModel):
956
- """
957
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
958
-
959
- Args:
960
- config: MiniCPMConfig
961
- """
962
-
963
- def __init__(self, config: MiniCPMConfig):
964
- super().__init__(config)
965
- self.padding_idx = config.pad_token_id
966
- self.vocab_size = config.vocab_size
967
-
968
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
969
- self.layers = nn.ModuleList(
970
- [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
971
- )
972
- self._use_sdpa = config._attn_implementation == "sdpa"
973
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
974
-
975
- self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
976
-
977
- self.gradient_checkpointing = False
978
- # Initialize weights and apply final processing
979
- self.post_init()
980
-
981
- def get_input_embeddings(self):
982
- return self.embed_tokens
983
-
984
- def set_input_embeddings(self, value):
985
- self.embed_tokens = value
986
-
987
- @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
988
- def forward(
989
- self,
990
- input_ids: torch.LongTensor = None,
991
- attention_mask: Optional[torch.Tensor] = None,
992
- position_ids: Optional[torch.LongTensor] = None,
993
- past_key_values: Optional[List[torch.FloatTensor]] = None,
994
- inputs_embeds: Optional[torch.FloatTensor] = None,
995
- use_cache: Optional[bool] = None,
996
- output_attentions: Optional[bool] = None,
997
- output_hidden_states: Optional[bool] = None,
998
- return_dict: Optional[bool] = None,
999
- ) -> Union[Tuple, BaseModelOutputWithPast]:
1000
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1001
- output_hidden_states = (
1002
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1003
- )
1004
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1005
-
1006
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1007
-
1008
- # retrieve input_ids and inputs_embeds
1009
- if input_ids is not None and inputs_embeds is not None:
1010
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1011
- elif input_ids is not None:
1012
- batch_size, seq_length = input_ids.shape[:2]
1013
- elif inputs_embeds is not None:
1014
- batch_size, seq_length = inputs_embeds.shape[:2]
1015
- else:
1016
- raise ValueError("You have to specify either input_ids or inputs_embeds")
1017
-
1018
- if self.gradient_checkpointing and self.training:
1019
- if use_cache:
1020
- logger.warning_once(
1021
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1022
- )
1023
- use_cache = False
1024
-
1025
- past_key_values_length = 0
1026
- if use_cache:
1027
- use_legacy_cache = not isinstance(past_key_values, Cache)
1028
- if use_legacy_cache:
1029
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1030
- past_key_values_length = past_key_values.get_usable_length(seq_length)
1031
-
1032
- if position_ids is None:
1033
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1034
- position_ids = torch.arange(
1035
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1036
- )
1037
- position_ids = position_ids.unsqueeze(0)
1038
-
1039
- if inputs_embeds is None:
1040
- inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb
1041
-
1042
-
1043
- if self._use_flash_attention_2:
1044
- # 2d mask is passed through the layers
1045
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1046
- elif self._use_sdpa and not output_attentions:
1047
- # output_attentions=True can not be supported when using SDPA, and we fall back on
1048
- # the manual implementation that requires a 4D causal mask in all cases.
1049
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1050
- attention_mask,
1051
- (batch_size, seq_length),
1052
- inputs_embeds,
1053
- past_key_values_length,
1054
- )
1055
- else:
1056
- # 4d mask is passed through the layers
1057
- attention_mask = _prepare_4d_causal_attention_mask(
1058
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1059
- )
1060
-
1061
- # embed positions
1062
- hidden_states = inputs_embeds
1063
-
1064
- # decoder layers
1065
- all_hidden_states = () if output_hidden_states else None
1066
- all_self_attns = () if output_attentions else None
1067
- next_decoder_cache = None
1068
-
1069
- for decoder_layer in self.layers:
1070
- if output_hidden_states:
1071
- all_hidden_states += (hidden_states,)
1072
-
1073
- if self.gradient_checkpointing and self.training:
1074
- layer_outputs = self._gradient_checkpointing_func(
1075
- decoder_layer.__call__,
1076
- hidden_states,
1077
- attention_mask,
1078
- position_ids,
1079
- past_key_values,
1080
- output_attentions,
1081
- use_cache,
1082
- )
1083
- else:
1084
- layer_outputs = decoder_layer(
1085
- hidden_states,
1086
- attention_mask=attention_mask,
1087
- position_ids=position_ids,
1088
- past_key_value=past_key_values,
1089
- output_attentions=output_attentions,
1090
- use_cache=use_cache,
1091
- )
1092
-
1093
- hidden_states = layer_outputs[0]
1094
-
1095
- if use_cache:
1096
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1097
-
1098
- if output_attentions:
1099
- all_self_attns += (layer_outputs[1],)
1100
-
1101
- hidden_states = self.norm(hidden_states)
1102
-
1103
- # add hidden states from the last decoder layer
1104
- if output_hidden_states:
1105
- all_hidden_states += (hidden_states,)
1106
-
1107
- next_cache = None
1108
- if use_cache:
1109
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1110
- if not return_dict:
1111
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1112
- return BaseModelOutputWithPast(
1113
- last_hidden_state=hidden_states,
1114
- past_key_values=next_cache,
1115
- hidden_states=all_hidden_states,
1116
- attentions=all_self_attns,
1117
- )
1118
-
1119
-
1120
- class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
1121
- _tied_weights_keys = ["lm_head.weight"]
1122
-
1123
- def __init__(self, config):
1124
- super().__init__(config)
1125
- self.model = MiniCPMModel(config)
1126
- self.vocab_size = config.vocab_size
1127
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1128
-
1129
- # Initialize weights and apply final processing
1130
- self.post_init()
1131
-
1132
- def get_input_embeddings(self):
1133
- return self.model.embed_tokens
1134
-
1135
- def set_input_embeddings(self, value):
1136
- self.model.embed_tokens = value
1137
-
1138
- def get_output_embeddings(self):
1139
- return self.lm_head
1140
-
1141
- def set_output_embeddings(self, new_embeddings):
1142
- self.lm_head = new_embeddings
1143
-
1144
- def set_decoder(self, decoder):
1145
- self.model = decoder
1146
-
1147
- def get_decoder(self):
1148
- return self.model
1149
-
1150
- @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
1151
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1152
- def forward(
1153
- self,
1154
- input_ids: torch.LongTensor = None,
1155
- attention_mask: Optional[torch.Tensor] = None,
1156
- position_ids: Optional[torch.LongTensor] = None,
1157
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1158
- inputs_embeds: Optional[torch.FloatTensor] = None,
1159
- labels: Optional[torch.LongTensor] = None,
1160
- use_cache: Optional[bool] = None,
1161
- output_attentions: Optional[bool] = None,
1162
- output_hidden_states: Optional[bool] = None,
1163
- return_dict: Optional[bool] = None,
1164
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1165
- r"""
1166
- Args:
1167
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1168
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1169
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1170
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1171
-
1172
- Returns:
1173
-
1174
- Example:
1175
-
1176
- ```python
1177
- >>> from transformers import AutoTokenizer, MiniCPMForCausalLM
1178
-
1179
- >>> model = MiniCPMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1180
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1181
-
1182
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1183
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1184
-
1185
- >>> # Generate
1186
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1187
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1188
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1189
- ```"""
1190
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1191
- output_hidden_states = (
1192
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1193
- )
1194
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1195
-
1196
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1197
- outputs = self.model(
1198
- input_ids=input_ids,
1199
- attention_mask=attention_mask,
1200
- position_ids=position_ids,
1201
- past_key_values=past_key_values,
1202
- inputs_embeds=inputs_embeds,
1203
- use_cache=use_cache,
1204
- output_attentions=output_attentions,
1205
- output_hidden_states=output_hidden_states,
1206
- return_dict=return_dict,
1207
- )
1208
-
1209
- hidden_states = outputs[0]
1210
- if self.config.pretraining_tp > 1:
1211
- lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1212
- logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1213
- logits = torch.cat(logits, dim=-1)
1214
- else:
1215
- logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base))
1216
- logits = logits.float()
1217
-
1218
- loss = None
1219
- if labels is not None:
1220
- # Shift so that tokens < n predict n
1221
- shift_logits = logits[..., :-1, :].contiguous()
1222
- shift_labels = labels[..., 1:].contiguous()
1223
- # Flatten the tokens
1224
- loss_fct = CrossEntropyLoss()
1225
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1226
- shift_labels = shift_labels.view(-1)
1227
- # Enable model parallelism
1228
- shift_labels = shift_labels.to(shift_logits.device)
1229
- loss = loss_fct(shift_logits, shift_labels)
1230
-
1231
- if not return_dict:
1232
- output = (logits,) + outputs[1:]
1233
- return (loss,) + output if loss is not None else output
1234
-
1235
- return CausalLMOutputWithPast(
1236
- loss=loss,
1237
- logits=logits,
1238
- past_key_values=outputs.past_key_values,
1239
- hidden_states=outputs.hidden_states,
1240
- attentions=outputs.attentions,
1241
- )
1242
-
1243
- def prepare_inputs_for_generation(
1244
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1245
- ):
1246
- if past_key_values is not None:
1247
- if isinstance(past_key_values, Cache):
1248
- cache_length = past_key_values.get_seq_length()
1249
- past_length = past_key_values.seen_tokens
1250
- max_cache_length = past_key_values.get_max_length()
1251
- else:
1252
- cache_length = past_length = past_key_values[0][0].shape[2]
1253
- max_cache_length = None
1254
-
1255
- # Keep only the unprocessed tokens:
1256
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1257
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1258
- # input)
1259
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1260
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1261
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1262
- # input_ids based on the past_length.
1263
- elif past_length < input_ids.shape[1]:
1264
- input_ids = input_ids[:, past_length:]
1265
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1266
- else:
1267
- remove_prefix_length = input_ids.shape[1] - 1
1268
- input_ids = input_ids[:, remove_prefix_length:]
1269
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1270
- if (
1271
- max_cache_length is not None
1272
- and attention_mask is not None
1273
- and cache_length + input_ids.shape[1] > max_cache_length
1274
- ):
1275
- attention_mask = attention_mask[:, -max_cache_length:]
1276
-
1277
- position_ids = kwargs.get("position_ids", None)
1278
- if attention_mask is not None and position_ids is None:
1279
- # create position_ids on the fly for batch generation
1280
- position_ids = attention_mask.long().cumsum(-1) - 1
1281
- position_ids.masked_fill_(attention_mask == 0, 1)
1282
- if past_key_values:
1283
- position_ids = position_ids[:, -input_ids.shape[1] :]
1284
-
1285
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1286
- if inputs_embeds is not None and past_key_values is None:
1287
- model_inputs = {"inputs_embeds": inputs_embeds}
1288
- else:
1289
- model_inputs = {"input_ids": input_ids}
1290
-
1291
- model_inputs.update(
1292
- {
1293
- "position_ids": position_ids,
1294
- "past_key_values": past_key_values,
1295
- "use_cache": kwargs.get("use_cache"),
1296
- "attention_mask": attention_mask,
1297
- }
1298
- )
1299
- return model_inputs
1300
-
1301
- @staticmethod
1302
- def _reorder_cache(past_key_values, beam_idx):
1303
- reordered_past = ()
1304
- for layer_past in past_key_values:
1305
- reordered_past += (
1306
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1307
- )
1308
- return reordered_past
1309
-
1310
- @torch.inference_mode()
1311
- def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
1312
- max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
1313
- **kwargs):
1314
- if history is None:
1315
- history = []
1316
- if logits_processor:
1317
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1318
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1319
- else:
1320
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1321
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1322
-
1323
- history.append({"role": role, "content": query})
1324
- history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
1325
- inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
1326
- outputs = self.generate(**inputs, **gen_kwargs)
1327
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1328
- response = tokenizer.decode(outputs)
1329
- pattern = re.compile(r".*?(?=<AI>|<用户>)", re.DOTALL)
1330
- matches = pattern.findall(response)
1331
- if len(matches) > 0:
1332
- response = matches[0]
1333
- history.append({"role": "assistant", "content": response})
1334
- return response, history
1335
-
1336
-
1337
- @add_start_docstrings(
1338
- """
1339
- The MiniCPM Model transformer with a sequence classification head on top (linear layer).
1340
-
1341
- [`MiniCPMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1342
- (e.g. GPT-2) do.
1343
-
1344
- Since it does classification on the last token, it requires to know the position of the last token. If a
1345
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1346
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1347
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1348
- each row of the batch).
1349
- """,
1350
- MINICPM_START_DOCSTRING,
1351
- )
1352
- class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
1353
- def __init__(self, config):
1354
- super().__init__(config)
1355
- self.num_labels = config.num_labels
1356
- self.model = MiniCPMModel(config)
1357
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1358
-
1359
- # Initialize weights and apply final processing
1360
- self.post_init()
1361
-
1362
- def get_input_embeddings(self):
1363
- return self.model.embed_tokens
1364
-
1365
- def set_input_embeddings(self, value):
1366
- self.model.embed_tokens = value
1367
-
1368
- @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING)
1369
- def forward(
1370
- self,
1371
- input_ids: torch.LongTensor = None,
1372
- attention_mask: Optional[torch.Tensor] = None,
1373
- position_ids: Optional[torch.LongTensor] = None,
1374
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1375
- inputs_embeds: Optional[torch.FloatTensor] = None,
1376
- labels: Optional[torch.LongTensor] = None,
1377
- use_cache: Optional[bool] = None,
1378
- output_attentions: Optional[bool] = None,
1379
- output_hidden_states: Optional[bool] = None,
1380
- return_dict: Optional[bool] = None,
1381
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1382
- r"""
1383
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1384
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1385
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1386
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1387
- """
1388
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1389
-
1390
- transformer_outputs = self.model(
1391
- input_ids,
1392
- attention_mask=attention_mask,
1393
- position_ids=position_ids,
1394
- past_key_values=past_key_values,
1395
- inputs_embeds=inputs_embeds,
1396
- use_cache=use_cache,
1397
- output_attentions=output_attentions,
1398
- output_hidden_states=output_hidden_states,
1399
- return_dict=return_dict,
1400
- )
1401
- hidden_states = transformer_outputs[0]
1402
- logits = self.score(hidden_states)
1403
-
1404
- if input_ids is not None:
1405
- batch_size = input_ids.shape[0]
1406
- else:
1407
- batch_size = inputs_embeds.shape[0]
1408
-
1409
- if self.config.pad_token_id is None and batch_size != 1:
1410
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1411
- if self.config.pad_token_id is None:
1412
- sequence_lengths = -1
1413
- else:
1414
- if input_ids is not None:
1415
- sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1416
- logits.device
1417
- )
1418
- else:
1419
- sequence_lengths = -1
1420
-
1421
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1422
-
1423
- loss = None
1424
- if labels is not None:
1425
- labels = labels.to(logits.device)
1426
- if self.config.problem_type is None:
1427
- if self.num_labels == 1:
1428
- self.config.problem_type = "regression"
1429
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1430
- self.config.problem_type = "single_label_classification"
1431
- else:
1432
- self.config.problem_type = "multi_label_classification"
1433
-
1434
- if self.config.problem_type == "regression":
1435
- loss_fct = MSELoss()
1436
- if self.num_labels == 1:
1437
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1438
- else:
1439
- loss = loss_fct(pooled_logits, labels)
1440
- elif self.config.problem_type == "single_label_classification":
1441
- loss_fct = CrossEntropyLoss()
1442
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1443
- elif self.config.problem_type == "multi_label_classification":
1444
- loss_fct = BCEWithLogitsLoss()
1445
- loss = loss_fct(pooled_logits, labels)
1446
- if not return_dict:
1447
- output = (pooled_logits,) + transformer_outputs[1:]
1448
- return ((loss,) + output) if loss is not None else output
1449
-
1450
- return SequenceClassifierOutputWithPast(
1451
- loss=loss,
1452
- logits=pooled_logits,
1453
- past_key_values=transformer_outputs.past_key_values,
1454
- hidden_states=transformer_outputs.hidden_states,
1455
- attentions=transformer_outputs.attentions,
1456
- )
 
 
bunny/model/language_model/phi/__init__.py DELETED
@@ -1,69 +0,0 @@
1
- # Copyright 2023 Microsoft and The HuggingFace Inc. team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from typing import TYPE_CHECKING
17
-
18
- from transformers.utils import (
19
- OptionalDependencyNotAvailable,
20
- _LazyModule,
21
- is_sentencepiece_available,
22
- is_tokenizers_available,
23
- is_torch_available,
24
- )
25
-
26
-
27
- _import_structure = {
28
- "configuration_phi": ["PHI_PRETRAINED_CONFIG_ARCHIVE_MAP", "PhiConfig"],
29
- }
30
-
31
- try:
32
- if not is_torch_available():
33
- raise OptionalDependencyNotAvailable()
34
- except OptionalDependencyNotAvailable:
35
- pass
36
- else:
37
- _import_structure["modeling_phi"] = [
38
- "PHI_PRETRAINED_MODEL_ARCHIVE_LIST",
39
- "PhiPreTrainedModel",
40
- "PhiModel",
41
- "PhiForCausalLM",
42
- "PhiForSequenceClassification",
43
- "PhiForTokenClassification",
44
- ]
45
-
46
-
47
- if TYPE_CHECKING:
48
- from .configuration_phi import PHI_PRETRAINED_CONFIG_ARCHIVE_MAP, PhiConfig
49
-
50
- try:
51
- if not is_torch_available():
52
- raise OptionalDependencyNotAvailable()
53
- except OptionalDependencyNotAvailable:
54
- pass
55
- else:
56
- from .modeling_phi import (
57
- PHI_PRETRAINED_MODEL_ARCHIVE_LIST,
58
- PhiForCausalLM,
59
- PhiForSequenceClassification,
60
- PhiForTokenClassification,
61
- PhiModel,
62
- PhiPreTrainedModel,
63
- )
64
-
65
-
66
- else:
67
- import sys
68
-
69
- sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
 
 
bunny/model/language_model/phi/configuration_phi.py DELETED
@@ -1,195 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """ Phi model configuration"""
17
-
18
-
19
- from transformers.configuration_utils import PretrainedConfig
20
- from transformers.utils import logging
21
-
22
-
23
- logger = logging.get_logger(__name__)
24
-
25
- PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
- "microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/resolve/main/config.json",
27
- "microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/resolve/main/config.json",
28
- "microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
29
- }
30
-
31
-
32
- class PhiConfig(PretrainedConfig):
33
- r"""
34
- This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate an Phi
35
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
36
- defaults will yield a similar configuration to that of the Phi
37
- [microsoft/phi-1](https://huggingface.co/microsoft/phi-1).
38
-
39
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
- documentation from [`PretrainedConfig`] for more information.
41
-
42
- Args:
43
- vocab_size (`int`, *optional*, defaults to 51200):
44
- Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
45
- `inputs_ids` passed when calling [`PhiModel`].
46
- hidden_size (`int`, *optional*, defaults to 2048):
47
- Dimension of the hidden representations.
48
- intermediate_size (`int`, *optional*, defaults to 8192):
49
- Dimension of the MLP representations.
50
- num_hidden_layers (`int`, *optional*, defaults to 24):
51
- Number of hidden layers in the Transformer decoder.
52
- num_attention_heads (`int`, *optional*, defaults to 32):
53
- Number of attention heads for each attention layer in the Transformer decoder.
54
- num_key_value_heads (`int`, *optional*):
55
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
56
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
57
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
58
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
59
- by meanpooling all the original heads within that group. For more details checkout [this
60
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
61
- `num_attention_heads`.
62
- resid_pdrop (`float`, *optional*, defaults to 0.0):
63
- Dropout probability for mlp outputs.
64
- embd_pdrop (`int`, *optional*, defaults to 0.0):
65
- The dropout ratio for the embeddings.
66
- attention_dropout (`float`, *optional*, defaults to 0.0):
67
- The dropout ratio after computing the attention scores.
68
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
69
- The non-linear activation function (function or string) in the decoder.
70
- max_position_embeddings (`int`, *optional*, defaults to 2048):
71
- The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
72
- tokens.
73
- initializer_range (`float`, *optional*, defaults to 0.02):
74
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
75
- layer_norm_eps (`float`, *optional*, defaults to 1e-05):
76
- The epsilon used by the rms normalization layers.
77
- use_cache (`bool`, *optional*, defaults to `True`):
78
- Whether or not the model should return the last key/values attentions (not used by all models). Only
79
- relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
80
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
81
- Whether to tie weight embeddings
82
- rope_theta (`float`, *optional*, defaults to 10000.0):
83
- The base period of the RoPE embeddings.
84
- rope_scaling (`Dict`, *optional*):
85
- Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
86
- strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
87
- is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
88
- `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
89
- these scaling strategies behave:
90
- https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
91
- is an experimental feature, subject to breaking API changes in future versions.
92
- partial_rotary_factor (`float`, *optional*, defaults to 0.5):
93
- Percentage of the query and keys which will have rotary embedding.
94
- qk_layernorm (`bool`, *optional*, defaults to `False`):
95
- Whether or not to normalize the Queries and Keys after projecting the hidden states.
96
- bos_token_id (`int`, *optional*, defaults to 1):
97
- Denotes beginning of sequences token id.
98
- eos_token_id (`int`, *optional*, defaults to 2):
99
- Denotes end of sequences token id.
100
-
101
- Example:
102
-
103
- ```python
104
- >>> from transformers import PhiModel, PhiConfig
105
-
106
- >>> # Initializing a Phi-1 style configuration
107
- >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")
108
-
109
- >>> # Initializing a model from the configuration
110
- >>> model = PhiModel(configuration)
111
-
112
- >>> # Accessing the model configuration
113
- >>> configuration = model.config
114
- ```"""
115
-
116
- model_type = "phi"
117
- keys_to_ignore_at_inference = ["past_key_values"]
118
-
119
- def __init__(
120
- self,
121
- vocab_size=51200,
122
- hidden_size=2048,
123
- intermediate_size=8192,
124
- num_hidden_layers=24,
125
- num_attention_heads=32,
126
- num_key_value_heads=None,
127
- resid_pdrop=0.0,
128
- embd_pdrop=0.0,
129
- attention_dropout=0.0,
130
- hidden_act="gelu_new",
131
- max_position_embeddings=2048,
132
- initializer_range=0.02,
133
- layer_norm_eps=1e-5,
134
- use_cache=True,
135
- tie_word_embeddings=False,
136
- rope_theta=10000.0,
137
- rope_scaling=None,
138
- partial_rotary_factor=0.5,
139
- qk_layernorm=False,
140
- bos_token_id=1,
141
- eos_token_id=2,
142
- **kwargs,
143
- ):
144
- self.vocab_size = vocab_size
145
- self.hidden_size = hidden_size
146
- self.intermediate_size = intermediate_size
147
- self.num_hidden_layers = num_hidden_layers
148
- self.num_attention_heads = num_attention_heads
149
-
150
- if num_key_value_heads is None:
151
- num_key_value_heads = num_attention_heads
152
-
153
- self.num_key_value_heads = num_key_value_heads
154
- self.resid_pdrop = resid_pdrop
155
- self.embd_pdrop = embd_pdrop
156
- self.attention_dropout = attention_dropout
157
- self.hidden_act = hidden_act
158
- self.max_position_embeddings = max_position_embeddings
159
- self.initializer_range = initializer_range
160
- self.layer_norm_eps = layer_norm_eps
161
- self.use_cache = use_cache
162
- self.rope_theta = rope_theta
163
- self.rope_scaling = rope_scaling
164
- self.partial_rotary_factor = partial_rotary_factor
165
- self.qk_layernorm = qk_layernorm
166
- self._rope_scaling_validation()
167
-
168
- super().__init__(
169
- bos_token_id=bos_token_id,
170
- eos_token_id=eos_token_id,
171
- tie_word_embeddings=tie_word_embeddings,
172
- **kwargs,
173
- )
174
-
175
- # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
176
- def _rope_scaling_validation(self):
177
- """
178
- Validate the `rope_scaling` configuration.
179
- """
180
- if self.rope_scaling is None:
181
- return
182
-
183
- if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
184
- raise ValueError(
185
- "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
186
- f"got {self.rope_scaling}"
187
- )
188
- rope_scaling_type = self.rope_scaling.get("type", None)
189
- rope_scaling_factor = self.rope_scaling.get("factor", None)
190
- if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
191
- raise ValueError(
192
- f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
193
- )
194
- if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
195
- raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
 
 
bunny/model/language_model/phi/modeling_phi.py DELETED
@@ -1,1374 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """ PyTorch Phi model."""
17
-
18
-
19
- import math
20
- from typing import List, Optional, Tuple, Union
21
-
22
- import torch
23
- import torch.nn.functional as F
24
- import torch.utils.checkpoint
25
- from torch import nn
26
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
-
28
- from transformers.activations import ACT2FN
29
- from transformers.cache_utils import Cache, DynamicCache
30
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
31
- from transformers.modeling_outputs import (
32
- BaseModelOutputWithPast,
33
- CausalLMOutputWithPast,
34
- SequenceClassifierOutputWithPast,
35
- TokenClassifierOutput,
36
- )
37
- from transformers.modeling_utils import PreTrainedModel
38
- from transformers.utils import (
39
- add_code_sample_docstrings,
40
- add_start_docstrings,
41
- add_start_docstrings_to_model_forward,
42
- is_flash_attn_2_available,
43
- is_flash_attn_greater_or_equal_2_10,
44
- logging,
45
- replace_return_docstrings,
46
- )
47
- from .configuration_phi import PhiConfig
48
-
49
-
50
- if is_flash_attn_2_available():
51
- from flash_attn import flash_attn_func, flash_attn_varlen_func
52
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
53
-
54
-
55
- logger = logging.get_logger(__name__)
56
-
57
- _CHECKPOINT_FOR_DOC = "microsoft/phi-1"
58
- _CONFIG_FOR_DOC = "PhiConfig"
59
-
60
- PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
61
- "microsoft/phi-1",
62
- "microsoft/phi-1_5",
63
- "microsoft/phi-2",
64
- # See all Phi models at https://huggingface.co/models?filter=phi
65
- ]
66
-
67
-
68
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
69
- def _get_unpad_data(attention_mask):
70
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
71
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
72
- max_seqlen_in_batch = seqlens_in_batch.max().item()
73
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
74
- return (
75
- indices,
76
- cu_seqlens,
77
- max_seqlen_in_batch,
78
- )
79
-
80
-
81
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi
82
- class PhiRotaryEmbedding(nn.Module):
83
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
84
- super().__init__()
85
-
86
- self.dim = dim
87
- self.max_position_embeddings = max_position_embeddings
88
- self.base = base
89
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
90
- self.register_buffer("inv_freq", inv_freq, persistent=False)
91
-
92
- # Build here to make `torch.jit.trace` work.
93
- self._set_cos_sin_cache(
94
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
95
- )
96
-
97
- def _set_cos_sin_cache(self, seq_len, device, dtype):
98
- self.max_seq_len_cached = seq_len
99
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
100
-
101
- freqs = torch.outer(t, self.inv_freq)
102
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
103
- emb = torch.cat((freqs, freqs), dim=-1)
104
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
105
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
106
-
107
- def forward(self, x, seq_len=None):
108
- # x: [bs, num_attention_heads, seq_len, head_size]
109
- if seq_len > self.max_seq_len_cached:
110
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
111
-
112
- return (
113
- self.cos_cached[:seq_len].to(dtype=x.dtype),
114
- self.sin_cached[:seq_len].to(dtype=x.dtype),
115
- )
116
-
117
-
118
- # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
119
- class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
120
- """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
121
-
122
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
123
- self.scaling_factor = scaling_factor
124
- super().__init__(dim, max_position_embeddings, base, device)
125
-
126
- def _set_cos_sin_cache(self, seq_len, device, dtype):
127
- self.max_seq_len_cached = seq_len
128
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
129
- t = t / self.scaling_factor
130
-
131
- freqs = torch.outer(t, self.inv_freq)
132
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
133
- emb = torch.cat((freqs, freqs), dim=-1)
134
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
135
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
136
-
137
-
138
- # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
139
- class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
140
- """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
141
-
142
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
143
- self.scaling_factor = scaling_factor
144
- super().__init__(dim, max_position_embeddings, base, device)
145
-
146
- def _set_cos_sin_cache(self, seq_len, device, dtype):
147
- self.max_seq_len_cached = seq_len
148
-
149
- if seq_len > self.max_position_embeddings:
150
- base = self.base * (
151
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
152
- ) ** (self.dim / (self.dim - 2))
153
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
154
- self.register_buffer("inv_freq", inv_freq, persistent=False)
155
-
156
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
157
-
158
- freqs = torch.outer(t, self.inv_freq)
159
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
160
- emb = torch.cat((freqs, freqs), dim=-1)
161
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
162
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
163
-
164
-
165
- # Copied from transformers.models.llama.modeling_llama.rotate_half
166
- def rotate_half(x):
167
- """Rotates half the hidden dims of the input."""
168
- x1 = x[..., : x.shape[-1] // 2]
169
- x2 = x[..., x.shape[-1] // 2 :]
170
- return torch.cat((-x2, x1), dim=-1)
171
-
172
-
173
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
174
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
175
- """Applies Rotary Position Embedding to the query and key tensors.
176
-
177
- Args:
178
- q (`torch.Tensor`): The query tensor.
179
- k (`torch.Tensor`): The key tensor.
180
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
181
- sin (`torch.Tensor`): The sine part of the rotary embedding.
182
- position_ids (`torch.Tensor`):
183
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
184
- used to pass offsetted position ids when working with a KV-cache.
185
- unsqueeze_dim (`int`, *optional*, defaults to 1):
186
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
187
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
188
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
189
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
190
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
191
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
192
- Returns:
193
- `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
194
- """
195
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
196
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
197
- q_embed = (q * cos) + (rotate_half(q) * sin)
198
- k_embed = (k * cos) + (rotate_half(k) * sin)
199
- return q_embed, k_embed
200
-
201
-
202
- # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
203
- class PhiMLP(nn.Module):
204
- def __init__(self, config):
205
- super().__init__()
206
- self.config = config
207
- self.activation_fn = ACT2FN[config.hidden_act]
208
- self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
209
- self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
210
-
211
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
212
- hidden_states = self.fc1(hidden_states)
213
- hidden_states = self.activation_fn(hidden_states)
214
- hidden_states = self.fc2(hidden_states)
215
- return hidden_states
216
-
217
-
218
- # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
219
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
220
- """
221
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
222
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
223
- """
224
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
225
- if n_rep == 1:
226
- return hidden_states
227
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
228
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
229
-
230
-
231
- class PhiAttention(nn.Module):
232
- """Multi-headed attention from 'Attention Is All You Need' paper"""
233
-
234
- def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
235
- super().__init__()
236
- self.config = config
237
- self.layer_idx = layer_idx
238
- if layer_idx is None:
239
- logger.warning_once(
240
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
241
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
242
- "when creating this class."
243
- )
244
-
245
- self.attention_dropout = config.attention_dropout
246
- self.hidden_size = config.hidden_size
247
- self.num_heads = config.num_attention_heads
248
- self.head_dim = self.hidden_size // self.num_heads
249
- self.num_key_value_heads = config.num_key_value_heads
250
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
251
- self.max_position_embeddings = config.max_position_embeddings
252
- self.rope_theta = config.rope_theta
253
- self.partial_rotary_factor = config.partial_rotary_factor
254
- self.is_causal = True
255
-
256
- if (self.head_dim * self.num_heads) != self.hidden_size:
257
- raise ValueError(
258
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
259
- f" and `num_heads`: {self.num_heads})."
260
- )
261
-
262
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
263
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
264
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
265
- self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
266
-
267
- self.qk_layernorm = config.qk_layernorm
268
- if self.qk_layernorm:
269
- self.q_layernorm = nn.LayerNorm(
270
- config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
271
- )
272
- self.k_layernorm = nn.LayerNorm(
273
- config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
274
- )
275
-
276
- self._init_rope()
277
-
278
- def _init_rope(self):
279
- if self.config.rope_scaling is None:
280
- self.rotary_emb = PhiRotaryEmbedding(
281
- int(self.partial_rotary_factor * self.head_dim),
282
- max_position_embeddings=self.max_position_embeddings,
283
- base=self.rope_theta,
284
- )
285
- else:
286
- scaling_type = self.config.rope_scaling["type"]
287
- scaling_factor = self.config.rope_scaling["factor"]
288
- if scaling_type == "linear":
289
- self.rotary_emb = PhiLinearScalingRotaryEmbedding(
290
- int(self.partial_rotary_factor * self.head_dim),
291
- max_position_embeddings=self.max_position_embeddings,
292
- scaling_factor=scaling_factor,
293
- base=self.rope_theta,
294
- )
295
- elif scaling_type == "dynamic":
296
- self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
297
- int(self.partial_rotary_factor * self.head_dim),
298
- max_position_embeddings=self.max_position_embeddings,
299
- scaling_factor=scaling_factor,
300
- base=self.rope_theta,
301
- )
302
- else:
303
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
304
-
305
- def forward(
306
- self,
307
- hidden_states: torch.Tensor,
308
- attention_mask: Optional[torch.Tensor] = None,
309
- position_ids: Optional[torch.LongTensor] = None,
310
- past_key_value: Optional[Cache] = None,
311
- output_attentions: bool = False,
312
- use_cache: bool = False,
313
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
314
- bsz, q_len, _ = hidden_states.size()
315
-
316
- query_states = self.q_proj(hidden_states)
317
- key_states = self.k_proj(hidden_states)
318
- value_states = self.v_proj(hidden_states)
319
-
320
- if self.qk_layernorm:
321
- query_states = self.q_layernorm(query_states)
322
- key_states = self.k_layernorm(key_states)
323
-
324
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
325
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
326
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
327
-
328
- kv_seq_len = key_states.shape[-2]
329
- if past_key_value is not None:
330
- if self.layer_idx is None:
331
- raise ValueError(
332
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
333
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
334
- "with a layer index."
335
- )
336
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
337
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
338
-
339
- # Partial rotary embedding
340
- query_rot, query_pass = (
341
- query_states[..., : self.rotary_emb.dim],
342
- query_states[..., self.rotary_emb.dim :],
343
- )
344
- key_rot, key_pass = (
345
- key_states[..., : self.rotary_emb.dim],
346
- key_states[..., self.rotary_emb.dim :],
347
- )
348
- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
349
- query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
350
-
351
- # [batch_size, seq_length, num_heads, head_dim]
352
- query_states = torch.cat((query_rot, query_pass), dim=-1)
353
- key_states = torch.cat((key_rot, key_pass), dim=-1)
354
-
355
- if past_key_value is not None:
356
- cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
357
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
358
-
359
- key_states = repeat_kv(key_states, self.num_key_value_groups)
360
- value_states = repeat_kv(value_states, self.num_key_value_groups)
361
-
362
- # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
363
- attn_weights = torch.matmul(
364
- query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
365
- ) / math.sqrt(self.head_dim)
366
-
367
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
368
- raise ValueError(
369
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
370
- f" {attn_weights.size()}"
371
- )
372
-
373
- if attention_mask is not None:
374
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
375
- raise ValueError(
376
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
377
- )
378
- attn_weights = attn_weights + attention_mask
379
-
380
- # upcast attention to fp32
381
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
382
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
383
-
384
- attn_output = torch.matmul(attn_weights, value_states)
385
-
386
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
387
- raise ValueError(
388
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
389
- f" {attn_output.size()}"
390
- )
391
-
392
- attn_output = attn_output.transpose(1, 2).contiguous()
393
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
394
-
395
- attn_output = self.dense(attn_output)
396
-
397
- if not output_attentions:
398
- attn_weights = None
399
-
400
- return attn_output, attn_weights, past_key_value
401
-
402
-
403
- class PhiFlashAttention2(PhiAttention):
404
- """
405
- Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stay
406
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
407
- flash attention and deal with padding tokens in case the input contains any of them.
408
- """
409
-
410
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
411
- def __init__(self, *args, **kwargs):
412
- super().__init__(*args, **kwargs)
413
-
414
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
415
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
416
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
417
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
418
-
419
- def forward(
420
- self,
421
- hidden_states: torch.Tensor,
422
- attention_mask: Optional[torch.LongTensor] = None,
423
- position_ids: Optional[torch.LongTensor] = None,
424
- past_key_value: Optional[Cache] = None,
425
- output_attentions: bool = False,
426
- use_cache: bool = False,
427
- **kwargs,
428
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
429
- # PhiFlashAttention2 attention does not support output_attentions
430
-
431
- output_attentions = False
432
-
433
- bsz, q_len, _ = hidden_states.size()
434
-
435
- query_states = self.q_proj(hidden_states)
436
- key_states = self.k_proj(hidden_states)
437
- value_states = self.v_proj(hidden_states)
438
-
439
- if self.qk_layernorm:
440
- query_states = self.q_layernorm(query_states)
441
- key_states = self.k_layernorm(key_states)
442
-
443
- # Flash attention requires the input to have the shape
444
- # batch_size x seq_length x head_dim x hidden_dim
445
- # therefore we just need to keep the original shape
446
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
447
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
448
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
449
-
450
- kv_seq_len = key_states.shape[-2]
451
- if past_key_value is not None:
452
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
453
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
454
-
455
- # Partial rotary embedding
456
- query_rot, query_pass = (
457
- query_states[..., : self.rotary_emb.dim],
458
- query_states[..., self.rotary_emb.dim :],
459
- )
460
- key_rot, key_pass = (
461
- key_states[..., : self.rotary_emb.dim],
462
- key_states[..., self.rotary_emb.dim :],
463
- )
464
- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
465
- query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
466
-
467
- # [batch_size, seq_length, num_heads, head_dim]
468
- query_states = torch.cat((query_rot, query_pass), dim=-1)
469
- key_states = torch.cat((key_rot, key_pass), dim=-1)
470
-
471
- if past_key_value is not None:
472
- cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
473
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
474
-
475
- # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
476
- # to be able to avoid many of these transpose/reshape/view.
477
- query_states = query_states.transpose(1, 2)
478
- key_states = key_states.transpose(1, 2)
479
- value_states = value_states.transpose(1, 2)
480
-
481
- attn_dropout = self.attention_dropout if self.training else 0.0
482
-
483
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
484
- # therefore the input hidden states get silently cast to float32. Hence, we need
485
- # cast them back in the correct dtype just to be sure everything works as expected.
486
- # This might slow down training & inference, so it is recommended not to cast the LayerNorms
487
- # in fp32.
488
-
489
- if query_states.dtype == torch.float32:
490
- if torch.is_autocast_enabled():
491
- target_dtype = torch.get_autocast_gpu_dtype()
492
- # Handle the case where the model is quantized
493
- elif hasattr(self.config, "_pre_quantization_dtype"):
494
- target_dtype = self.config._pre_quantization_dtype
495
- else:
496
- target_dtype = self.q_proj.weight.dtype
497
-
498
- logger.warning_once(
499
- f"The input hidden states seem to be silently cast to float32; this might be related to"
500
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
501
- f" {target_dtype}."
502
- )
503
-
504
- query_states = query_states.to(target_dtype)
505
- key_states = key_states.to(target_dtype)
506
- value_states = value_states.to(target_dtype)
507
-
508
- attn_output = self._flash_attention_forward(
509
- query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
510
- )
511
-
512
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
513
- attn_output = self.dense(attn_output)
514
-
515
- if not output_attentions:
516
- attn_weights = None
517
-
518
- return attn_output, attn_weights, past_key_value
519
-
520
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
521
- def _flash_attention_forward(
522
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
523
- ):
524
- """
525
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
526
- first unpad the input, then computes the attention scores and pad the final attention scores.
527
-
528
- Args:
529
- query_states (`torch.Tensor`):
530
- Input query states to be passed to Flash Attention API
531
- key_states (`torch.Tensor`):
532
- Input key states to be passed to Flash Attention API
533
- value_states (`torch.Tensor`):
534
- Input value states to be passed to Flash Attention API
535
- attention_mask (`torch.Tensor`):
536
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
537
- position of padding tokens and 1 for the position of non-padding tokens.
538
- dropout (`float`, *optional*):
539
- Attention dropout
540
- softmax_scale (`float`, *optional*):
541
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
542
- """
543
- if not self._flash_attn_uses_top_left_mask:
544
- causal = self.is_causal
545
- else:
546
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
547
- causal = self.is_causal and query_length != 1
548
-
549
- # Contains at least one padding token in the sequence
550
- if attention_mask is not None:
551
- batch_size = query_states.shape[0]
552
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
553
- query_states, key_states, value_states, attention_mask, query_length
554
- )
555
-
556
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
557
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
558
-
559
- attn_output_unpad = flash_attn_varlen_func(
560
- query_states,
561
- key_states,
562
- value_states,
563
- cu_seqlens_q=cu_seqlens_q,
564
- cu_seqlens_k=cu_seqlens_k,
565
- max_seqlen_q=max_seqlen_in_batch_q,
566
- max_seqlen_k=max_seqlen_in_batch_k,
567
- dropout_p=dropout,
568
- softmax_scale=softmax_scale,
569
- causal=causal,
570
- )
571
-
572
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
573
- else:
574
- attn_output = flash_attn_func(
575
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
576
- )
577
-
578
- return attn_output
579
-
580
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
581
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
582
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
583
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
584
-
585
- key_layer = index_first_axis(
586
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
587
- )
588
- value_layer = index_first_axis(
589
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
590
- )
591
- if query_length == kv_seq_len:
592
- query_layer = index_first_axis(
593
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
594
- )
595
- cu_seqlens_q = cu_seqlens_k
596
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
597
- indices_q = indices_k
598
- elif query_length == 1:
599
- max_seqlen_in_batch_q = 1
600
- cu_seqlens_q = torch.arange(
601
- batch_size + 1, dtype=torch.int32, device=query_layer.device
602
- ) # There is a memcpy here, that is very bad.
603
- indices_q = cu_seqlens_q[:-1]
604
- query_layer = query_layer.squeeze(1)
605
- else:
606
- # The -q_len: slice assumes left padding.
607
- attention_mask = attention_mask[:, -query_length:]
608
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
609
-
610
- return (
611
- query_layer,
612
- key_layer,
613
- value_layer,
614
- indices_q,
615
- (cu_seqlens_q, cu_seqlens_k),
616
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
617
- )
618
-
619
-
620
- PHI_ATTENTION_CLASSES = {
621
- "eager": PhiAttention,
622
- "flash_attention_2": PhiFlashAttention2,
623
- }
624
-
625
-
626
- class PhiDecoderLayer(nn.Module):
627
- def __init__(self, config: PhiConfig, layer_idx: int):
628
- super().__init__()
629
- self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
630
- self.mlp = PhiMLP(config)
631
- self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
632
- self.resid_dropout = nn.Dropout(config.resid_pdrop)
633
-
634
- def forward(
635
- self,
636
- hidden_states: torch.Tensor,
637
- attention_mask: Optional[torch.Tensor] = None,
638
- position_ids: Optional[torch.LongTensor] = None,
639
- output_attentions: Optional[bool] = False,
640
- use_cache: Optional[bool] = False,
641
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
642
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
643
- """
644
- Args:
645
- hidden_states (`torch.FloatTensor`):
646
- input to the layer of shape `(batch, seq_len, embed_dim)`
647
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
648
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
649
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
650
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
651
- `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
652
- output_attentions (`bool`, *optional*):
653
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
654
- returned tensors for more detail.
655
- use_cache (`bool`, *optional*):
656
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
657
- (see `past_key_values`).
658
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
659
- """
660
-
661
- residual = hidden_states
662
-
663
- hidden_states = self.input_layernorm(hidden_states)
664
-
665
- # Self Attention
666
- attn_outputs, self_attn_weights, present_key_value = self.self_attn(
667
- hidden_states=hidden_states,
668
- attention_mask=attention_mask,
669
- position_ids=position_ids,
670
- past_key_value=past_key_value,
671
- output_attentions=output_attentions,
672
- use_cache=use_cache,
673
- )
674
- attn_outputs = self.resid_dropout(attn_outputs)
675
-
676
- feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
677
- hidden_states = attn_outputs + feed_forward_hidden_states + residual
678
- outputs = (hidden_states,)
679
-
680
- if output_attentions:
681
- outputs += (self_attn_weights,)
682
-
683
- if use_cache:
684
- outputs += (present_key_value,)
685
-
686
- return outputs
687
-
688
-
689
- PHI_START_DOCSTRING = r"""
690
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
691
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
692
- etc.)
693
-
694
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
695
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
696
- and behavior.
697
-
698
- Parameters:
699
- config ([`PhiConfig`]):
700
- Model configuration class with all the parameters of the model. Initializing with a config file does not
701
- load the weights associated with the model, only the configuration. Check out the
702
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
703
- """
704
-
705
-
706
- @add_start_docstrings(
707
- "The bare Phi Model outputting raw hidden-states without any specific head on top.",
708
- PHI_START_DOCSTRING,
709
- )
710
- class PhiPreTrainedModel(PreTrainedModel):
711
- config_class = PhiConfig
712
- base_model_prefix = "model"
713
- supports_gradient_checkpointing = True
714
- _no_split_modules = ["PhiDecoderLayer"]
715
- _skip_keys_device_placement = "past_key_values"
716
- _supports_flash_attn_2 = True
717
- _supports_cache_class = True
718
-
719
- def _init_weights(self, module):
720
- std = self.config.initializer_range
721
- if isinstance(module, nn.Linear):
722
- module.weight.data.normal_(mean=0.0, std=std)
723
- if module.bias is not None:
724
- module.bias.data.zero_()
725
- elif isinstance(module, nn.Embedding):
726
- module.weight.data.normal_(mean=0.0, std=std)
727
- if module.padding_idx is not None:
728
- module.weight.data[module.padding_idx].zero_()
729
-
730
-
731
- PHI_INPUTS_DOCSTRING = r"""
732
- Args:
733
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
734
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
735
- it.
736
-
737
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
738
- [`PreTrainedTokenizer.__call__`] for details.
739
-
740
- [What are input IDs?](../glossary#input-ids)
741
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
742
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
743
-
744
- - 1 for tokens that are **not masked**,
745
- - 0 for tokens that are **masked**.
746
-
747
- [What are attention masks?](../glossary#attention-mask)
748
-
749
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
750
- [`PreTrainedTokenizer.__call__`] for details.
751
-
752
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
753
- `past_key_values`).
754
-
755
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
756
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
757
- information on the default strategy.
758
-
759
- - 1 indicates the head is **not masked**,
760
- - 0 indicates the head is **masked**.
761
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
762
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
763
- config.n_positions - 1]`.
764
-
765
- [What are position IDs?](../glossary#position-ids)
766
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
767
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
768
- blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
769
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
770
-
771
- Two formats are allowed:
772
- - a [`~cache_utils.Cache`] instance;
773
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
774
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
775
- cache format.
776
-
777
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
778
- legacy cache format will be returned.
779
-
780
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
781
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
782
- of shape `(batch_size, sequence_length)`.
783
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
784
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
785
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
786
- model's internal embedding lookup matrix.
787
- use_cache (`bool`, *optional*):
788
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
789
- `past_key_values`).
790
- output_attentions (`bool`, *optional*):
791
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
792
- tensors for more detail.
793
- output_hidden_states (`bool`, *optional*):
794
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
795
- more detail.
796
- return_dict (`bool`, *optional*):
797
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
798
- """
799
-
800
-
801
- @add_start_docstrings(
802
- "The bare Phi Model outputting raw hidden-states without any specific head on top.",
803
- PHI_START_DOCSTRING,
804
- )
805
- class PhiModel(PhiPreTrainedModel):
806
- """
807
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
808
-
809
- Args:
810
- config: PhiConfig
811
- """
812
-
813
- def __init__(self, config: PhiConfig):
814
- super().__init__(config)
815
- self.padding_idx = config.pad_token_id
816
- self.vocab_size = config.vocab_size
817
-
818
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
819
- self.embed_dropout = nn.Dropout(config.embd_pdrop)
820
- self.layers = nn.ModuleList(
821
- [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
822
- )
823
- self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
824
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
825
-
826
- self.gradient_checkpointing = False
827
- # Initialize weights and apply final processing
828
- self.post_init()
829
-
830
- def get_input_embeddings(self):
831
- return self.embed_tokens
832
-
833
- def set_input_embeddings(self, value):
834
- self.embed_tokens = value
835
-
836
- @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
837
- def forward(
838
- self,
839
- input_ids: torch.LongTensor = None,
840
- attention_mask: Optional[torch.Tensor] = None,
841
- position_ids: Optional[torch.LongTensor] = None,
842
- past_key_values: Optional[List[torch.FloatTensor]] = None,
843
- inputs_embeds: Optional[torch.FloatTensor] = None,
844
- use_cache: Optional[bool] = None,
845
- output_attentions: Optional[bool] = None,
846
- output_hidden_states: Optional[bool] = None,
847
- return_dict: Optional[bool] = None,
848
- ) -> Union[Tuple, BaseModelOutputWithPast]:
849
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
850
- output_hidden_states = (
851
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
852
- )
853
- use_cache = use_cache if use_cache is not None else self.config.use_cache
854
-
855
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
856
-
857
- # retrieve input_ids and inputs_embeds
858
- if input_ids is not None and inputs_embeds is not None:
859
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
860
- elif input_ids is not None:
861
- batch_size, seq_length = input_ids.shape[:2]
862
- elif inputs_embeds is not None:
863
- batch_size, seq_length = inputs_embeds.shape[:2]
864
- else:
865
- raise ValueError("You have to specify either input_ids or inputs_embeds")
866
-
867
- past_key_values_length = 0
868
-
869
- if self.gradient_checkpointing and self.training:
870
- if use_cache:
871
- logger.warning_once(
872
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
873
- )
874
- use_cache = False
875
-
876
- if use_cache:
877
- use_legacy_cache = not isinstance(past_key_values, Cache)
878
- if use_legacy_cache:
879
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
880
- past_key_values_length = past_key_values.get_usable_length(seq_length)
881
-
882
- if position_ids is None:
883
- device = input_ids.device if input_ids is not None else inputs_embeds.device
884
- position_ids = torch.arange(
885
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
886
- )
887
- position_ids = position_ids.unsqueeze(0)
888
-
889
- if inputs_embeds is None:
890
- inputs_embeds = self.embed_tokens(input_ids)
891
-
892
- inputs_embeds = self.embed_dropout(inputs_embeds)
893
-
894
- # Attention mask.
895
- if self._use_flash_attention_2:
896
- # 2d mask is passed through the layers
897
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
898
- else:
899
- # 4d mask is passed through the layers
900
- attention_mask = _prepare_4d_causal_attention_mask(
901
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
902
- )
903
-
904
- hidden_states = inputs_embeds
905
-
906
- # decoder layers
907
- all_hidden_states = () if output_hidden_states else None
908
- all_self_attns = () if output_attentions else None
909
- next_decoder_cache = None
910
-
911
- for decoder_layer in self.layers:
912
- if output_hidden_states:
913
- all_hidden_states += (hidden_states,)
914
-
915
- if self.gradient_checkpointing and self.training:
916
- layer_outputs = self._gradient_checkpointing_func(
917
- decoder_layer.__call__,
918
- hidden_states,
919
- attention_mask,
920
- position_ids,
921
- past_key_values,
922
- output_attentions,
923
- )
924
- else:
925
- layer_outputs = decoder_layer(
926
- hidden_states,
927
- attention_mask=attention_mask,
928
- position_ids=position_ids,
929
- past_key_value=past_key_values,
930
- output_attentions=output_attentions,
931
- use_cache=use_cache,
932
- )
933
-
934
- hidden_states = layer_outputs[0]
935
-
936
- if use_cache:
937
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
938
-
939
- if output_attentions:
940
- all_self_attns += (layer_outputs[1],)
941
-
942
- hidden_states = self.final_layernorm(hidden_states)
943
-
944
- # add hidden states from the last decoder layer
945
- if output_hidden_states:
946
- all_hidden_states += (hidden_states,)
947
-
948
- next_cache = None
949
- if use_cache:
950
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
951
- if not return_dict:
952
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
953
- return BaseModelOutputWithPast(
954
- last_hidden_state=hidden_states,
955
- past_key_values=next_cache,
956
- hidden_states=all_hidden_states,
957
- attentions=all_self_attns,
958
- )
959
-
960
-
961
- class PhiForCausalLM(PhiPreTrainedModel):
962
- _tied_weights_keys = ["lm_head.weight"]
963
-
964
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
965
- def __init__(self, config):
966
- super().__init__(config)
967
- self.model = PhiModel(config)
968
- self.vocab_size = config.vocab_size
969
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
970
-
971
- # Initialize weights and apply final processing
972
- self.post_init()
973
-
974
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
975
- def get_input_embeddings(self):
976
- return self.model.embed_tokens
977
-
978
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
979
- def set_input_embeddings(self, value):
980
- self.model.embed_tokens = value
981
-
982
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
983
- def get_output_embeddings(self):
984
- return self.lm_head
985
-
986
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
987
- def set_output_embeddings(self, new_embeddings):
988
- self.lm_head = new_embeddings
989
-
990
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
991
- def set_decoder(self, decoder):
992
- self.model = decoder
993
-
994
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
995
- def get_decoder(self):
996
- return self.model
997
-
998
- @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
999
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1000
- def forward(
1001
- self,
1002
- input_ids: torch.LongTensor = None,
1003
- attention_mask: Optional[torch.Tensor] = None,
1004
- position_ids: Optional[torch.LongTensor] = None,
1005
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1006
- inputs_embeds: Optional[torch.FloatTensor] = None,
1007
- labels: Optional[torch.LongTensor] = None,
1008
- use_cache: Optional[bool] = None,
1009
- output_attentions: Optional[bool] = None,
1010
- output_hidden_states: Optional[bool] = None,
1011
- return_dict: Optional[bool] = None,
1012
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1013
- r"""
1014
- Args:
1015
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1016
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1017
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1018
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1019
-
1020
- Returns:
1021
-
1022
- Example:
1023
-
1024
- ```python
1025
- >>> from transformers import AutoTokenizer, PhiForCausalLM
1026
-
1027
- >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1028
- >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1029
-
1030
- >>> prompt = "This is an example script ."
1031
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1032
-
1033
- >>> # Generate
1034
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1035
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1036
- 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1037
- ```"""
1038
-
1039
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1040
- output_hidden_states = (
1041
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1042
- )
1043
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1044
-
1045
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1046
- outputs = self.model(
1047
- input_ids=input_ids,
1048
- attention_mask=attention_mask,
1049
- position_ids=position_ids,
1050
- past_key_values=past_key_values,
1051
- inputs_embeds=inputs_embeds,
1052
- use_cache=use_cache,
1053
- output_attentions=output_attentions,
1054
- output_hidden_states=output_hidden_states,
1055
- return_dict=return_dict,
1056
- )
1057
-
1058
- hidden_states = outputs[0]
1059
- logits = self.lm_head(hidden_states)
1060
- logits = logits.float()
1061
-
1062
- loss = None
1063
- if labels is not None:
1064
- # Shift so that tokens < n predict n
1065
- shift_logits = logits[..., :-1, :].contiguous()
1066
- shift_labels = labels[..., 1:].contiguous()
1067
- # Flatten the tokens
1068
- loss_fct = CrossEntropyLoss()
1069
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1070
- shift_labels = shift_labels.view(-1)
1071
- # Enable model parallelism
1072
- shift_labels = shift_labels.to(shift_logits.device)
1073
- loss = loss_fct(shift_logits, shift_labels)
1074
-
1075
- if not return_dict:
1076
- output = (logits,) + outputs[1:]
1077
- return (loss,) + output if loss is not None else output
1078
-
1079
- return CausalLMOutputWithPast(
1080
- loss=loss,
1081
- logits=logits,
1082
- past_key_values=outputs.past_key_values,
1083
- hidden_states=outputs.hidden_states,
1084
- attentions=outputs.attentions,
1085
- )
1086
-
1087
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1088
- def prepare_inputs_for_generation(
1089
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1090
- ):
1091
- if past_key_values is not None:
1092
- if isinstance(past_key_values, Cache):
1093
- cache_length = past_key_values.get_seq_length()
1094
- past_length = past_key_values.seen_tokens
1095
- max_cache_length = past_key_values.get_max_length()
1096
- else:
1097
- cache_length = past_length = past_key_values[0][0].shape[2]
1098
- max_cache_length = None
1099
-
1100
- # Keep only the unprocessed tokens:
1101
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1102
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1103
- # input)
1104
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1105
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1106
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1107
- # input_ids based on the past_length.
1108
- elif past_length < input_ids.shape[1]:
1109
- input_ids = input_ids[:, past_length:]
1110
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1111
- else:
1112
- remove_prefix_length = input_ids.shape[1] - 1
1113
- input_ids = input_ids[:, remove_prefix_length:]
1114
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1115
- if (
1116
- max_cache_length is not None
1117
- and attention_mask is not None
1118
- and cache_length + input_ids.shape[1] > max_cache_length
1119
- ):
1120
- attention_mask = attention_mask[:, -max_cache_length:]
1121
-
1122
- position_ids = kwargs.get("position_ids", None)
1123
- if attention_mask is not None and position_ids is None:
1124
- # create position_ids on the fly for batch generation
1125
- position_ids = attention_mask.long().cumsum(-1) - 1
1126
- position_ids.masked_fill_(attention_mask == 0, 1)
1127
- if past_key_values:
1128
- position_ids = position_ids[:, -input_ids.shape[1] :]
1129
-
1130
- if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
1131
- # generation with static cache
1132
- seen_tokens = past_key_value.get_seq_length()
1133
- input_ids = input_ids[:, seen_tokens:]
1134
- position_ids = position_ids[:, seen_tokens:]
1135
-
1136
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1137
- if inputs_embeds is not None and past_key_values is None:
1138
- model_inputs = {"inputs_embeds": inputs_embeds}
1139
- else:
1140
- model_inputs = {"input_ids": input_ids}
1141
-
1142
- model_inputs.update(
1143
- {
1144
- "position_ids": position_ids,
1145
- "past_key_values": past_key_values,
1146
- "use_cache": kwargs.get("use_cache"),
1147
- "attention_mask": attention_mask,
1148
- }
1149
- )
1150
- return model_inputs
1151
-
1152
- @staticmethod
1153
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1154
- def _reorder_cache(past_key_values, beam_idx):
1155
- reordered_past = ()
1156
- for layer_past in past_key_values:
1157
- reordered_past += (
1158
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1159
- )
1160
- return reordered_past
1161
-
1162
-
1163
- @add_start_docstrings(
1164
- """
1165
- The PhiModel with a sequence classification head on top (linear layer).
1166
-
1167
- [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1168
- (e.g. GPT-2) do.
1169
-
1170
- Since it does classification on the last token, it needs to know the position of the last token. If a
1171
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1172
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1173
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1174
- each row of the batch).
1175
- """,
1176
- PHI_START_DOCSTRING,
1177
- )
1178
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
1179
- class PhiForSequenceClassification(PhiPreTrainedModel):
1180
- def __init__(self, config):
1181
- super().__init__(config)
1182
- self.num_labels = config.num_labels
1183
- self.model = PhiModel(config)
1184
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1185
-
1186
- # Initialize weights and apply final processing
1187
- self.post_init()
1188
-
1189
- def get_input_embeddings(self):
1190
- return self.model.embed_tokens
1191
-
1192
- def set_input_embeddings(self, value):
1193
- self.model.embed_tokens = value
1194
-
1195
- @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1196
- def forward(
1197
- self,
1198
- input_ids: torch.LongTensor = None,
1199
- attention_mask: Optional[torch.Tensor] = None,
1200
- position_ids: Optional[torch.LongTensor] = None,
1201
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1202
- inputs_embeds: Optional[torch.FloatTensor] = None,
1203
- labels: Optional[torch.LongTensor] = None,
1204
- use_cache: Optional[bool] = None,
1205
- output_attentions: Optional[bool] = None,
1206
- output_hidden_states: Optional[bool] = None,
1207
- return_dict: Optional[bool] = None,
1208
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1209
- r"""
1210
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1211
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1212
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1213
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1214
- """
1215
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1216
-
1217
- model_outputs = self.model(
1218
- input_ids,
1219
- attention_mask=attention_mask,
1220
- position_ids=position_ids,
1221
- past_key_values=past_key_values,
1222
- inputs_embeds=inputs_embeds,
1223
- use_cache=use_cache,
1224
- output_attentions=output_attentions,
1225
- output_hidden_states=output_hidden_states,
1226
- return_dict=return_dict,
1227
- )
1228
- hidden_states = model_outputs[0]
1229
- logits = self.score(hidden_states)
1230
-
1231
- if input_ids is not None:
1232
- batch_size = input_ids.shape[0]
1233
- else:
1234
- batch_size = inputs_embeds.shape[0]
1235
-
1236
- if self.config.pad_token_id is None and batch_size != 1:
1237
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1238
- if self.config.pad_token_id is None:
1239
- sequence_lengths = -1
1240
- else:
1241
- if input_ids is not None:
1242
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1243
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1244
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1245
- sequence_lengths = sequence_lengths.to(logits.device)
1246
- else:
1247
- sequence_lengths = -1
1248
-
1249
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1250
-
1251
- loss = None
1252
- if labels is not None:
1253
- labels = labels.to(logits.device)
1254
- if self.config.problem_type is None:
1255
- if self.num_labels == 1:
1256
- self.config.problem_type = "regression"
1257
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1258
- self.config.problem_type = "single_label_classification"
1259
- else:
1260
- self.config.problem_type = "multi_label_classification"
1261
-
1262
- if self.config.problem_type == "regression":
1263
- loss_fct = MSELoss()
1264
- if self.num_labels == 1:
1265
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1266
- else:
1267
- loss = loss_fct(pooled_logits, labels)
1268
- elif self.config.problem_type == "single_label_classification":
1269
- loss_fct = CrossEntropyLoss()
1270
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1271
- elif self.config.problem_type == "multi_label_classification":
1272
- loss_fct = BCEWithLogitsLoss()
1273
- loss = loss_fct(pooled_logits, labels)
1274
- if not return_dict:
1275
- output = (pooled_logits,) + model_outputs[1:]
1276
- return ((loss,) + output) if loss is not None else output
1277
-
1278
- return SequenceClassifierOutputWithPast(
1279
- loss=loss,
1280
- logits=pooled_logits,
1281
- past_key_values=model_outputs.past_key_values,
1282
- hidden_states=model_outputs.hidden_states,
1283
- attentions=model_outputs.attentions,
1284
- )
1285
-
1286
-
1287
- @add_start_docstrings(
1288
- """
1289
- PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1290
- Named-Entity-Recognition (NER) tasks.
1291
- """,
1292
- PHI_START_DOCSTRING,
1293
- )
1294
- # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
1295
- class PhiForTokenClassification(PhiPreTrainedModel):
1296
- def __init__(self, config: PhiConfig):
1297
- super().__init__(config)
1298
- self.num_labels = config.num_labels
1299
-
1300
- self.model = PhiModel(config)
1301
- if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1302
- classifier_dropout = config.classifier_dropout
1303
- elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1304
- classifier_dropout = config.hidden_dropout
1305
- else:
1306
- classifier_dropout = 0.1
1307
- self.dropout = nn.Dropout(classifier_dropout)
1308
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1309
-
1310
- # Initialize weights and apply final processing
1311
- self.post_init()
1312
-
1313
- @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1314
- @add_code_sample_docstrings(
1315
- checkpoint=_CHECKPOINT_FOR_DOC,
1316
- output_type=TokenClassifierOutput,
1317
- config_class=_CONFIG_FOR_DOC,
1318
- )
1319
- def forward(
1320
- self,
1321
- input_ids: Optional[torch.LongTensor] = None,
1322
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1323
- attention_mask: Optional[torch.Tensor] = None,
1324
- inputs_embeds: Optional[torch.Tensor] = None,
1325
- labels: Optional[torch.Tensor] = None,
1326
- use_cache: Optional[bool] = None,
1327
- output_attentions: Optional[bool] = None,
1328
- output_hidden_states: Optional[bool] = None,
1329
- return_dict: Optional[bool] = None,
1330
- **deprecated_arguments,
1331
- ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1332
- r"""
1333
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1334
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1335
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1336
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1337
- """
1338
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1339
-
1340
- model_outputs = self.model(
1341
- input_ids,
1342
- past_key_values=past_key_values,
1343
- attention_mask=attention_mask,
1344
- inputs_embeds=inputs_embeds,
1345
- use_cache=use_cache,
1346
- output_attentions=output_attentions,
1347
- output_hidden_states=output_hidden_states,
1348
- return_dict=return_dict,
1349
- )
1350
-
1351
- hidden_states = model_outputs[0]
1352
- hidden_states = self.dropout(hidden_states)
1353
- logits = self.classifier(hidden_states)
1354
-
1355
- loss = None
1356
- if labels is not None:
1357
- # move labels to correct device to enable model parallelism
1358
- labels = labels.to(logits.device)
1359
- batch_size, seq_length = labels.shape
1360
- loss_fct = CrossEntropyLoss()
1361
- loss = loss_fct(
1362
- logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1363
- )
1364
-
1365
- if not return_dict:
1366
- output = (logits,) + model_outputs[2:]
1367
- return ((loss,) + output) if loss is not None else output
1368
-
1369
- return TokenClassifierOutput(
1370
- loss=loss,
1371
- logits=logits,
1372
- hidden_states=model_outputs.hidden_states,
1373
- attentions=model_outputs.attentions,
1374
- )
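Since `modeling_phi.py` is removed wholesale here, the pooling step in `PhiForSequenceClassification` above is worth a standalone illustration. The sketch below uses made-up tensors and a hypothetical `pad_token_id` (it is not code from this repository); it reproduces only that step and shows why the pad-token `argmax` minus one, combined with the modulo, lands on the last non-padding token even for rows without padding:

```python
import torch

# Illustrative inputs only: two sequences, the first right-padded with pad_token_id = 0.
pad_token_id = 0
input_ids = torch.tensor([[5, 8, 2, 0, 0, 0],
                          [7, 1, 4, 9, 3, 6]])
batch_size, seq_len = input_ids.shape
num_labels = 3
logits = torch.randn(batch_size, seq_len, num_labels)  # stand-in for self.score(hidden_states)

# First pad position minus one = last real token. For a row with no padding,
# argmax over an all-False mask returns 0, so 0 - 1 = -1, and the modulo maps
# it to seq_len - 1 (the ONNX-friendly alternative to negative indexing).
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % seq_len

pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
print(sequence_lengths.tolist())  # [2, 5]
print(pooled_logits.shape)        # torch.Size([2, 3])
```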
bunny/model/language_model/phi3/__init__.py DELETED
@@ -1,69 +0,0 @@
1
- # Copyright 2024 Microsoft and The HuggingFace Inc. team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from typing import TYPE_CHECKING
17
-
18
- from transformers.utils import (
19
- OptionalDependencyNotAvailable,
20
- _LazyModule,
21
- is_sentencepiece_available,
22
- is_tokenizers_available,
23
- is_torch_available,
24
- )
25
-
26
-
27
- _import_structure = {
28
- "configuration_phi3": ["PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP", "Phi3Config"],
29
- }
30
-
31
- try:
32
- if not is_torch_available():
33
- raise OptionalDependencyNotAvailable()
34
- except OptionalDependencyNotAvailable:
35
- pass
36
- else:
37
- _import_structure["modeling_phi3"] = [
38
- "PHI3_PRETRAINED_MODEL_ARCHIVE_LIST",
39
- "Phi3PreTrainedModel",
40
- "Phi3Model",
41
- "Phi3ForCausalLM",
42
- "Phi3ForSequenceClassification",
43
- "Phi3ForTokenClassification",
44
- ]
45
-
46
-
47
- if TYPE_CHECKING:
48
- from .configuration_phi3 import PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP, Phi3Config
49
-
50
- try:
51
- if not is_torch_available():
52
- raise OptionalDependencyNotAvailable()
53
- except OptionalDependencyNotAvailable:
54
- pass
55
- else:
56
- from .modeling_phi3 import (
57
- PHI3_PRETRAINED_MODEL_ARCHIVE_LIST,
58
- Phi3ForCausalLM,
59
- Phi3ForSequenceClassification,
60
- Phi3ForTokenClassification,
61
- Phi3Model,
62
- Phi3PreTrainedModel,
63
- )
64
-
65
-
66
- else:
67
- import sys
68
-
69
- sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
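The deleted `__init__.py` above wires everything through `transformers.utils._LazyModule`, which defers the expensive `modeling_phi3` import until one of the exported names is actually accessed. A rough, stdlib-only sketch of that idea (an illustration of the pattern, not the real `_LazyModule` implementation):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Toy stand-in for the lazy-import pattern: resolve attributes to submodules on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute to the submodule that defines it.
        self._attr_to_module = {
            attr: submodule
            for submodule, attrs in import_structure.items()
            for attr in attrs
        }
        self.__all__ = list(self._attr_to_module)

    def __getattr__(self, attr):
        try:
            submodule = self._attr_to_module[attr]
        except KeyError:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}") from None
        module = importlib.import_module(f"{self.__name__}.{submodule}")
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups bypass __getattr__
        return value


# Usage sketch (assumes a hypothetical package "mypkg" with a submodule "modeling" defining MyModel):
# sys.modules["mypkg"] = LazyModule("mypkg", {"modeling": ["MyModel"]})
```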
bunny/model/language_model/phi3/configuration_phi3.py DELETED
@@ -1,213 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """ Phi-3 model configuration"""
17
-
18
-
19
- from transformers.configuration_utils import PretrainedConfig
20
- from transformers.utils import logging
21
-
22
-
23
- logger = logging.get_logger(__name__)
24
-
25
- PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
26
- "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json",
27
- "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json",
28
- }
29
-
30
-
31
- class Phi3Config(PretrainedConfig):
32
- r"""
33
- This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
34
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
35
- defaults will yield a similar configuration to that of the
36
- [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
37
-
38
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
- documentation from [`PretrainedConfig`] for more information.
40
-
41
- Args:
42
- vocab_size (`int`, *optional*, defaults to 32064):
43
- Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
44
- `inputs_ids` passed when calling [`Phi3Model`].
45
- hidden_size (`int`, *optional*, defaults to 3072):
46
- Dimension of the hidden representations.
47
- intermediate_size (`int`, *optional*, defaults to 8192):
48
- Dimension of the MLP representations.
49
- num_hidden_layers (`int`, *optional*, defaults to 32):
50
- Number of hidden layers in the Transformer decoder.
51
- num_attention_heads (`int`, *optional*, defaults to 32):
52
- Number of attention heads for each attention layer in the Transformer decoder.
53
- num_key_value_heads (`int`, *optional*):
54
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
55
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
56
- `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
57
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
58
- by mean-pooling all the original heads within that group. For more details, check out [this
59
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
60
- `num_attention_heads`.
61
- resid_pdrop (`float`, *optional*, defaults to 0.0):
62
- Dropout probability for mlp outputs.
63
- embd_pdrop (`int`, *optional*, defaults to 0.0):
64
- The dropout ratio for the embeddings.
65
- attention_dropout (`float`, *optional*, defaults to 0.0):
66
- The dropout ratio after computing the attention scores.
67
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
68
- The non-linear activation function (function or string) in the decoder.
69
- max_position_embeddings (`int`, *optional*, defaults to 4096):
70
- The maximum sequence length that this model might ever be used with.
71
- original_max_position_embeddings (`int`, *optional*, defaults to 4096):
72
- The maximum sequence length that this model was trained with. This is used to determine the size of the
73
- original RoPE embeddings when using long scaling.
74
- initializer_range (`float`, *optional*, defaults to 0.02):
75
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
76
- rms_norm_eps (`float`, *optional*, defaults to 1e-05):
77
- The epsilon value used for the RMSNorm.
78
- use_cache (`bool`, *optional*, defaults to `True`):
79
- Whether or not the model should return the last key/values attentions (not used by all models). Only
80
- relevant if `config.is_decoder=True`.
81
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
82
- Whether to tie weight embeddings
83
- rope_theta (`float`, *optional*, defaults to 10000.0):
84
- The base period of the RoPE embeddings.
85
- rope_scaling (`dict`, *optional*):
86
- The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
87
- contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
88
- the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
89
- divided by the number of attention heads divided by 2.
90
- bos_token_id (`int`, *optional*, defaults to 1):
91
- The id of the "beginning-of-sequence" token.
92
- eos_token_id (`int`, *optional*, defaults to 32000):
93
- The id of the "end-of-sequence" token.
94
- pad_token_id (`int`, *optional*, defaults to 32000):
95
- The id of the padding token.
96
- sliding_window (`int`, *optional*):
97
- Sliding window attention window size. If `None`, no sliding window is applied.
98
-
99
- Example:
100
-
101
- ```python
102
- >>> from transformers import Phi3Model, Phi3Config
103
-
104
- >>> # Initializing a Phi-3 style configuration
105
- >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
106
-
107
- >>> # Initializing a model from the configuration
108
- >>> model = Phi3Model(configuration)
109
-
110
- >>> # Accessing the model configuration
111
- >>> configuration = model.config
112
- ```"""
113
-
114
- model_type = "phi3"
115
- keys_to_ignore_at_inference = ["past_key_values"]
116
-
117
- def __init__(
118
- self,
119
- vocab_size=32064,
120
- hidden_size=3072,
121
- intermediate_size=8192,
122
- num_hidden_layers=32,
123
- num_attention_heads=32,
124
- num_key_value_heads=None,
125
- resid_pdrop=0.0,
126
- embd_pdrop=0.0,
127
- attention_dropout=0.0,
128
- hidden_act="silu",
129
- max_position_embeddings=4096,
130
- original_max_position_embeddings=4096,
131
- initializer_range=0.02,
132
- rms_norm_eps=1e-5,
133
- use_cache=True,
134
- tie_word_embeddings=False,
135
- rope_theta=10000.0,
136
- rope_scaling=None,
137
- bos_token_id=1,
138
- eos_token_id=32000,
139
- pad_token_id=32000,
140
- sliding_window=None,
141
- **kwargs,
142
- ):
143
- self.vocab_size = vocab_size
144
- self.hidden_size = hidden_size
145
- self.intermediate_size = intermediate_size
146
- self.num_hidden_layers = num_hidden_layers
147
- self.num_attention_heads = num_attention_heads
148
-
149
- if num_key_value_heads is None:
150
- num_key_value_heads = num_attention_heads
151
-
152
- self.num_key_value_heads = num_key_value_heads
153
- self.resid_pdrop = resid_pdrop
154
- self.embd_pdrop = embd_pdrop
155
- self.attention_dropout = attention_dropout
156
- self.hidden_act = hidden_act
157
- self.max_position_embeddings = max_position_embeddings
158
- self.original_max_position_embeddings = original_max_position_embeddings
159
- self.initializer_range = initializer_range
160
- self.rms_norm_eps = rms_norm_eps
161
- self.use_cache = use_cache
162
- self.rope_theta = rope_theta
163
- self.rope_scaling = rope_scaling
164
- self._rope_scaling_validation()
165
- self.sliding_window = sliding_window
166
-
167
- super().__init__(
168
- bos_token_id=bos_token_id,
169
- eos_token_id=eos_token_id,
170
- pad_token_id=pad_token_id,
171
- tie_word_embeddings=tie_word_embeddings,
172
- **kwargs,
173
- )
174
-
175
- def _rope_scaling_validation(self):
176
- """
177
- Validate the `rope_scaling` configuration.
178
- """
179
- if self.rope_scaling is None:
180
- return
181
-
182
- if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
183
- raise ValueError(
184
- "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
185
- f"got {self.rope_scaling}"
186
- )
187
- rope_scaling_type = self.rope_scaling.get("type", None)
188
- rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
189
- rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
190
- if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
191
- raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
192
- if not (
193
- isinstance(rope_scaling_short_factor, list)
194
- and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
195
- ):
196
- raise ValueError(
197
- f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
198
- )
199
- if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
200
- raise ValueError(
201
- f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
202
- )
203
- if not (
204
- isinstance(rope_scaling_long_factor, list)
205
- and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
206
- ):
207
- raise ValueError(
208
- f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
209
- )
210
- if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
211
- raise ValueError(
212
- f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
213
- )
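For orientation, the length constraint enforced by `_rope_scaling_validation` above is purely arithmetic: each factor list needs one entry per rotary frequency, i.e. `hidden_size // num_attention_heads // 2` entries. A minimal check with the default Phi-3-mini sizes and placeholder factor values (illustrative only, not a configuration shipped by this repository):

```python
# Sanity check of the rope_scaling length rule from the deleted config above.
hidden_size = 3072
num_attention_heads = 32
head_dim = hidden_size // num_attention_heads   # 96
expected_len = head_dim // 2                    # 48 rotary frequencies per head

rope_scaling = {
    "type": "su",
    "short_factor": [1.0] * expected_len,   # placeholder values
    "long_factor": [1.05] * expected_len,   # placeholder values
}

assert len(rope_scaling["short_factor"]) == expected_len
assert len(rope_scaling["long_factor"]) == expected_len
print(expected_len)  # 48
```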
bunny/model/language_model/phi3/modeling_phi3.py DELETED
@@ -1,1597 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """ PyTorch Phi-3 model."""
17
-
18
- import inspect
19
- import math
20
- import warnings
21
- from typing import List, Optional, Tuple, Union
22
-
23
- import torch
24
- import torch.nn.functional as F
25
- import torch.utils.checkpoint
26
- from torch import nn
27
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
-
29
- from transformers.activations import ACT2FN
30
- from transformers.cache_utils import Cache, DynamicCache
31
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
32
- from transformers.modeling_outputs import (
33
- BaseModelOutputWithPast,
34
- CausalLMOutputWithPast,
35
- SequenceClassifierOutputWithPast,
36
- TokenClassifierOutput,
37
- )
38
- from transformers.modeling_utils import PreTrainedModel
39
- from transformers.utils import (
40
- add_code_sample_docstrings,
41
- add_start_docstrings,
42
- add_start_docstrings_to_model_forward,
43
- is_flash_attn_2_available,
44
- is_flash_attn_greater_or_equal_2_10,
45
- logging,
46
- replace_return_docstrings,
47
- )
48
- from .configuration_phi3 import Phi3Config
49
-
50
-
51
- if is_flash_attn_2_available():
52
- from flash_attn import flash_attn_func, flash_attn_varlen_func
53
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
54
-
55
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
56
-
57
- logger = logging.get_logger(__name__)
58
-
59
- _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
60
- _CONFIG_FOR_DOC = "Phi3Config"
61
-
62
- PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
- "microsoft/Phi-3-mini-4k-instruct",
64
- "microsoft/Phi-3-mini-128k-instruct",
65
- # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
66
- ]
67
-
68
-
69
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
70
- class Phi3RMSNorm(nn.Module):
71
- def __init__(self, hidden_size, eps=1e-6):
72
- """
73
- Phi3RMSNorm is equivalent to T5LayerNorm
74
- """
75
- super().__init__()
76
- self.weight = nn.Parameter(torch.ones(hidden_size))
77
- self.variance_epsilon = eps
78
-
79
- def forward(self, hidden_states):
80
- input_dtype = hidden_states.dtype
81
- hidden_states = hidden_states.to(torch.float32)
82
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
83
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
84
- return self.weight * hidden_states.to(input_dtype)
85
-
86
-
87
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
88
- def _get_unpad_data(attention_mask):
89
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
90
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
91
- max_seqlen_in_batch = seqlens_in_batch.max().item()
92
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
93
- return (
94
- indices,
95
- cu_seqlens,
96
- max_seqlen_in_batch,
97
- )
98
-
99
-
100
- # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
101
- class Phi3RotaryEmbedding(nn.Module):
102
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
103
- super().__init__()
104
-
105
- self.dim = dim
106
- self.max_position_embeddings = max_position_embeddings
107
- self.base = base
108
- self.register_buffer("inv_freq", None, persistent=False)
109
-
110
- @torch.no_grad()
111
- def forward(self, x, position_ids, seq_len=None):
112
- # x: [bs, num_attention_heads, seq_len, head_size]
113
- if self.inv_freq is None:
114
- self.inv_freq = 1.0 / (
115
- self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
116
- )
117
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
118
- position_ids_expanded = position_ids[:, None, :].float()
119
- # Force float32 since bfloat16 loses precision on long contexts
120
- # See https://github.com/huggingface/transformers/pull/29285
121
- device_type = x.device.type
122
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
123
- with torch.autocast(device_type=device_type, enabled=False):
124
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
125
- emb = torch.cat((freqs, freqs), dim=-1)
126
- cos = emb.cos()
127
- sin = emb.sin()
128
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
129
-
130
-
131
- class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
132
- def __init__(self, dim, config, device=None):
133
- super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
134
-
135
- self.short_factor = config.rope_scaling["short_factor"]
136
- self.long_factor = config.rope_scaling["long_factor"]
137
- self.original_max_position_embeddings = config.original_max_position_embeddings
138
-
139
- @torch.no_grad()
140
- def forward(self, x, position_ids, seq_len=None):
141
- seq_len = torch.max(position_ids) + 1
142
- if seq_len > self.original_max_position_embeddings:
143
- ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
144
- else:
145
- ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
146
-
147
- inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
148
- self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
149
-
150
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
151
- position_ids_expanded = position_ids[:, None, :].float()
152
-
153
- # Force float32 since bfloat16 loses precision on long contexts
154
- # See https://github.com/huggingface/transformers/pull/29285
155
- device_type = x.device.type
156
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
157
- with torch.autocast(device_type=device_type, enabled=False):
158
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
159
- emb = torch.cat((freqs, freqs), dim=-1)
160
-
161
- scale = self.max_position_embeddings / self.original_max_position_embeddings
162
- if scale <= 1.0:
163
- scaling_factor = 1.0
164
- else:
165
- scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
166
-
167
- cos = emb.cos() * scaling_factor
168
- sin = emb.sin() * scaling_factor
169
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
170
-
171
-
172
- class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
173
- def __init__(self, dim, config, device=None):
174
- super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
175
-
176
- self.short_factor = config.rope_scaling["short_factor"]
177
- self.long_factor = config.rope_scaling["long_factor"]
178
- self.original_max_position_embeddings = config.original_max_position_embeddings
179
-
180
- @torch.no_grad()
181
- def forward(self, x, position_ids, seq_len=None):
182
- seq_len = torch.max(position_ids) + 1
183
- if seq_len > self.original_max_position_embeddings:
184
- ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
185
- else:
186
- ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
187
-
188
- inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
189
- self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
190
-
191
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
192
- position_ids_expanded = position_ids[:, None, :].float()
193
-
194
- # Force float32 since bfloat16 loses precision on long contexts
195
- # See https://github.com/huggingface/transformers/pull/29285
196
- device_type = x.device.type
197
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
198
- with torch.autocast(device_type=device_type, enabled=False):
199
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
200
- emb = torch.cat((freqs, freqs), dim=-1)
201
-
202
- scale = self.max_position_embeddings / self.original_max_position_embeddings
203
- if scale <= 1.0:
204
- scaling_factor = 1.0
205
- else:
206
- scaling_factor = 0.1 * math.log(scale) + 1.0
207
-
208
- cos = emb.cos() * scaling_factor
209
- sin = emb.sin() * scaling_factor
210
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
211
-
212
-
213
- # Copied from transformers.models.llama.modeling_llama.rotate_half
214
- def rotate_half(x):
215
- """Rotates half the hidden dims of the input."""
216
- x1 = x[..., : x.shape[-1] // 2]
217
- x2 = x[..., x.shape[-1] // 2 :]
218
- return torch.cat((-x2, x1), dim=-1)
219
-
220
-
221
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
222
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
223
- """Applies Rotary Position Embedding to the query and key tensors.
224
-
225
- Args:
226
- q (`torch.Tensor`): The query tensor.
227
- k (`torch.Tensor`): The key tensor.
228
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
229
- sin (`torch.Tensor`): The sine part of the rotary embedding.
230
- position_ids (`torch.Tensor`, *optional*):
231
- Deprecated and unused.
232
- unsqueeze_dim (`int`, *optional*, defaults to 1):
233
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
234
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
235
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
236
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
237
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
238
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
239
- Returns:
240
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
241
- """
242
- cos = cos.unsqueeze(unsqueeze_dim)
243
- sin = sin.unsqueeze(unsqueeze_dim)
244
- q_embed = (q * cos) + (rotate_half(q) * sin)
245
- k_embed = (k * cos) + (rotate_half(k) * sin)
246
- return q_embed, k_embed
247
-
248
-
249
- class Phi3MLP(nn.Module):
250
- def __init__(self, config):
251
- super().__init__()
252
-
253
- self.config = config
254
- self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
255
- self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
256
-
257
- self.activation_fn = ACT2FN[config.hidden_act]
258
-
259
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
260
- up_states = self.gate_up_proj(hidden_states)
261
-
262
- gate, up_states = up_states.chunk(2, dim=-1)
263
- up_states = up_states * self.activation_fn(gate)
264
-
265
- return self.down_proj(up_states)
266
-
267
-
268
- # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
269
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
270
- """
271
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
272
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
273
- """
274
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
275
- if n_rep == 1:
276
- return hidden_states
277
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
278
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
279
-
280
-
281
- class Phi3Attention(nn.Module):
282
- """Multi-headed attention from 'Attention Is All You Need' paper"""
283
-
284
- def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
285
- super().__init__()
286
- self.config = config
287
- self.layer_idx = layer_idx
288
- if layer_idx is None:
289
- logger.warning_once(
290
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
291
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
292
- "when creating this class."
293
- )
294
-
295
- self.attention_dropout = config.attention_dropout
296
- self.hidden_size = config.hidden_size
297
- self.num_heads = config.num_attention_heads
298
- self.head_dim = self.hidden_size // self.num_heads
299
- self.num_key_value_heads = config.num_key_value_heads
300
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
301
- self.max_position_embeddings = config.max_position_embeddings
302
- self.original_max_position_embeddings = config.original_max_position_embeddings
303
- self.rope_theta = config.rope_theta
304
- self.rope_scaling = config.rope_scaling
305
- self.is_causal = True
306
-
307
- if (self.head_dim * self.num_heads) != self.hidden_size:
308
- raise ValueError(
309
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
310
- f" and `num_heads`: {self.num_heads})."
311
- )
312
-
313
- op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
314
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
315
- self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
316
- self._init_rope()
317
-
318
- def _init_rope(self):
319
- if self.rope_scaling is None:
320
- self.rotary_emb = Phi3RotaryEmbedding(
321
- self.head_dim,
322
- max_position_embeddings=self.max_position_embeddings,
323
- base=self.rope_theta,
324
- )
325
- else:
326
- scaling_type = self.config.rope_scaling["type"]
327
- if scaling_type == "su":
328
- self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
329
- elif scaling_type == "yarn":
330
- self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
331
- else:
332
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
333
-
334
- def forward(
335
- self,
336
- hidden_states: torch.Tensor,
337
- attention_mask: Optional[torch.Tensor] = None,
338
- position_ids: Optional[torch.LongTensor] = None,
339
- past_key_value: Optional[Cache] = None,
340
- output_attentions: bool = False,
341
- use_cache: bool = False,
342
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
343
- logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.")
344
-
345
- bsz, q_len, _ = hidden_states.size()
346
-
347
- qkv = self.qkv_proj(hidden_states)
348
- query_pos = self.num_heads * self.head_dim
349
- query_states = qkv[..., :query_pos]
350
- key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
351
- value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
352
-
353
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
354
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
355
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
356
-
357
- kv_seq_len = key_states.shape[-2]
358
- if past_key_value is not None:
359
- if self.layer_idx is None:
360
- raise ValueError(
361
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
362
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
363
- "with a layer index."
364
- )
365
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
366
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
367
-
368
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
369
-
370
- if past_key_value is not None:
371
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
372
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
373
-
374
- # repeat k/v heads if n_kv_heads < n_heads
375
- key_states = repeat_kv(key_states, self.num_key_value_groups)
376
- value_states = repeat_kv(value_states, self.num_key_value_groups)
377
-
378
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
379
-
380
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
381
- raise ValueError(
382
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
383
- f" {attn_weights.size()}"
384
- )
385
-
386
- if attention_mask is not None:
387
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
388
- raise ValueError(
389
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
390
- )
391
- attn_weights = attn_weights + attention_mask
392
-
393
- # upcast attention to fp32
394
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
395
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
396
-
397
- attn_output = torch.matmul(attn_weights, value_states)
398
-
399
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
400
- raise ValueError(
401
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
402
- f" {attn_output.size()}"
403
- )
404
-
405
- attn_output = attn_output.transpose(1, 2).contiguous()
406
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
407
-
408
- attn_output = self.o_proj(attn_output)
409
-
410
- if not output_attentions:
411
- attn_weights = None
412
-
413
- return attn_output, attn_weights, past_key_value
414
-
415
-
416
- class Phi3FlashAttention2(Phi3Attention):
417
- """
418
- Phi-3 flash attention module. This module inherits from `Phi3Attention`, as the weights of the module stay
419
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
420
- flash attention and deal with padding tokens in case the input contains any of them.
421
- """
422
-
423
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
424
- def __init__(self, *args, **kwargs):
425
- super().__init__(*args, **kwargs)
426
-
427
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
428
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
429
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
430
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
431
-
432
- def forward(
433
- self,
434
- hidden_states: torch.Tensor,
435
- attention_mask: Optional[torch.LongTensor] = None,
436
- position_ids: Optional[torch.LongTensor] = None,
437
- past_key_value: Optional[Cache] = None,
438
- output_attentions: bool = False,
439
- use_cache: bool = False,
440
- **kwargs,
441
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
442
- # Phi3FlashAttention2 attention does not support output_attentions
443
-
444
- if not _flash_supports_window_size:
445
- logger.warning_once(
446
- "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
447
- )
448
- raise ValueError("The current flash attention version does not support sliding window attention.")
449
-
450
- output_attentions = False
451
-
452
- if "padding_mask" in kwargs:
453
- warnings.warn(
454
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
455
- )
456
-
457
- # overwrite attention_mask with padding_mask
458
- attention_mask = kwargs.pop("padding_mask")
459
-
460
- bsz, q_len, _ = hidden_states.size()
461
-
462
- qkv = self.qkv_proj(hidden_states)
463
- query_pos = self.num_heads * self.head_dim
464
- query_states = qkv[..., :query_pos]
465
- key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
466
- value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
467
-
468
- # Flash attention requires the input to have the shape
469
- # batch_size x seq_length x head_dim x hidden_dim
470
- # therefore we just need to keep the original shape
471
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
472
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
473
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
474
-
475
- kv_seq_len = key_states.shape[-2]
476
- if past_key_value is not None:
477
- if self.layer_idx is None:
478
- raise ValueError(
479
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
480
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
481
- "with a layer index."
482
- )
483
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
484
-
485
- # Because the input can be padded, the absolute sequence length depends on the max position id.
486
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
487
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
488
-
489
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
490
-
491
- use_sliding_windows = (
492
- _flash_supports_window_size
493
- and getattr(self.config, "sliding_window", None) is not None
494
- and kv_seq_len > self.config.sliding_window
495
- )
496
-
497
- if past_key_value is not None:
498
- # Activate cache slicing only if the config has a `sliding_window` attribute
499
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
500
- if (
501
- getattr(self.config, "sliding_window", None) is not None
502
- and kv_seq_len > self.config.sliding_window
503
- and cache_has_contents
504
- ):
505
- slicing_tokens = 1 - self.config.sliding_window
506
-
507
- past_key = past_key_value[self.layer_idx][0]
508
- past_value = past_key_value[self.layer_idx][1]
509
-
510
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
511
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
512
-
513
- if past_key.shape[-2] != self.config.sliding_window - 1:
514
- raise ValueError(
515
- f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
516
- f" {past_key.shape}"
517
- )
518
-
519
- if attention_mask is not None:
520
- attention_mask = attention_mask[:, slicing_tokens:]
521
- attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
522
-
523
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
524
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
525
-
526
- # repeat k/v heads if n_kv_heads < n_heads
527
- key_states = repeat_kv(key_states, self.num_key_value_groups)
528
- value_states = repeat_kv(value_states, self.num_key_value_groups)
529
-
530
- attn_dropout = self.attention_dropout if self.training else 0.0
531
-
532
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
533
- # therefore the input hidden states get silently cast to float32. Hence, we need
534
- # cast them back in the correct dtype just to be sure everything works as expected.
535
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
536
- # in fp32.
537
-
538
- if query_states.dtype == torch.float32:
539
- if torch.is_autocast_enabled():
540
- target_dtype = torch.get_autocast_gpu_dtype()
541
- # Handle the case where the model is quantized
542
- elif hasattr(self.config, "_pre_quantization_dtype"):
543
- target_dtype = self.config._pre_quantization_dtype
544
- else:
545
- target_dtype = self.qkv_proj.weight.dtype
546
-
547
- logger.warning_once(
548
- f"The input hidden states seems to be silently casted in float32, this might be related to"
549
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
550
- f" {target_dtype}."
551
- )
552
-
553
- query_states = query_states.to(target_dtype)
554
- key_states = key_states.to(target_dtype)
555
- value_states = value_states.to(target_dtype)
556
-
557
- # Reshape to the expected shape for Flash Attention
558
- query_states = query_states.transpose(1, 2)
559
- key_states = key_states.transpose(1, 2)
560
- value_states = value_states.transpose(1, 2)
561
-
562
- attn_output = self._flash_attention_forward(
563
- query_states,
564
- key_states,
565
- value_states,
566
- attention_mask,
567
- q_len,
568
- dropout=attn_dropout,
569
- use_sliding_windows=use_sliding_windows,
570
- )
571
-
572
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
573
- attn_output = self.o_proj(attn_output)
574
-
575
- if not output_attentions:
576
- attn_weights = None
577
-
578
- return attn_output, attn_weights, past_key_value
579
-
580
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
581
- def _flash_attention_forward(
582
- self,
583
- query_states,
584
- key_states,
585
- value_states,
586
- attention_mask,
587
- query_length,
588
- dropout=0.0,
589
- softmax_scale=None,
590
- use_sliding_windows=False,
591
- ):
592
- """
593
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
594
- first unpad the input, then computes the attention scores and pad the final attention scores.
595
-
596
- Args:
597
- query_states (`torch.Tensor`):
598
- Input query states to be passed to Flash Attention API
599
- key_states (`torch.Tensor`):
600
- Input key states to be passed to Flash Attention API
601
- value_states (`torch.Tensor`):
602
- Input value states to be passed to Flash Attention API
603
- attention_mask (`torch.Tensor`):
604
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
605
- position of padding tokens and 1 for the position of non-padding tokens.
606
- dropout (`float`):
607
- Attention dropout
608
- softmax_scale (`float`, *optional*):
609
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
610
- use_sliding_windows (`bool`, *optional*):
611
- Whether to activate sliding window attention.
612
- """
613
- if not self._flash_attn_uses_top_left_mask:
614
- causal = self.is_causal
615
- else:
616
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
617
- causal = self.is_causal and query_length != 1
618
-
619
- # Contains at least one padding token in the sequence
620
- if attention_mask is not None:
621
- batch_size = query_states.shape[0]
622
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
623
- query_states, key_states, value_states, attention_mask, query_length
624
- )
625
-
626
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
627
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
628
-
629
- if not use_sliding_windows:
630
- attn_output_unpad = flash_attn_varlen_func(
631
- query_states,
632
- key_states,
633
- value_states,
634
- cu_seqlens_q=cu_seqlens_q,
635
- cu_seqlens_k=cu_seqlens_k,
636
- max_seqlen_q=max_seqlen_in_batch_q,
637
- max_seqlen_k=max_seqlen_in_batch_k,
638
- dropout_p=dropout,
639
- softmax_scale=softmax_scale,
640
- causal=causal,
641
- )
642
- else:
643
- attn_output_unpad = flash_attn_varlen_func(
644
- query_states,
645
- key_states,
646
- value_states,
647
- cu_seqlens_q=cu_seqlens_q,
648
- cu_seqlens_k=cu_seqlens_k,
649
- max_seqlen_q=max_seqlen_in_batch_q,
650
- max_seqlen_k=max_seqlen_in_batch_k,
651
- dropout_p=dropout,
652
- softmax_scale=softmax_scale,
653
- causal=causal,
654
- window_size=(self.config.sliding_window, self.config.sliding_window),
655
- )
656
-
657
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
658
- else:
659
- if not use_sliding_windows:
660
- attn_output = flash_attn_func(
661
- query_states,
662
- key_states,
663
- value_states,
664
- dropout,
665
- softmax_scale=softmax_scale,
666
- causal=causal,
667
- )
668
- else:
669
- attn_output = flash_attn_func(
670
- query_states,
671
- key_states,
672
- value_states,
673
- dropout,
674
- softmax_scale=softmax_scale,
675
- causal=causal,
676
- window_size=(self.config.sliding_window, self.config.sliding_window),
677
- )
678
-
679
- return attn_output
680
-
681
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
682
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
683
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
684
-
685
- # On the first iteration we need to properly re-create the padding mask
686
- # by slicing it on the proper place
687
- if kv_seq_len != attention_mask.shape[-1]:
688
- attention_mask_num_tokens = attention_mask.shape[-1]
689
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
690
-
691
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
692
-
693
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
694
- value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
695
-
696
- if query_length == kv_seq_len:
697
- query_layer = index_first_axis(
698
- query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
699
- )
700
- cu_seqlens_q = cu_seqlens_k
701
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
702
- indices_q = indices_k
703
- elif query_length == 1:
704
- max_seqlen_in_batch_q = 1
705
- cu_seqlens_q = torch.arange(
706
- batch_size + 1, dtype=torch.int32, device=query_layer.device
707
- ) # There is a memcpy here, that is very bad.
708
- indices_q = cu_seqlens_q[:-1]
709
- query_layer = query_layer.squeeze(1)
710
- else:
711
- # The -q_len: slice assumes left padding.
712
- attention_mask = attention_mask[:, -query_length:]
713
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
714
-
715
- return (
716
- query_layer,
717
- key_layer,
718
- value_layer,
719
- indices_q,
720
- (cu_seqlens_q, cu_seqlens_k),
721
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
722
- )
723
-
724
-
725
- # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
726
- # TODO @Arthur no longer copied from LLama after static cache
727
- class Phi3SdpaAttention(Phi3Attention):
728
- """
729
- Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
730
- `Phi3Attention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
731
- SDPA API.
732
- """
733
-
734
- # Adapted from Phi3Attention.forward
735
- def forward(
736
- self,
737
- hidden_states: torch.Tensor,
738
- attention_mask: Optional[torch.Tensor] = None,
739
- position_ids: Optional[torch.LongTensor] = None,
740
- past_key_value: Optional[Cache] = None,
741
- output_attentions: bool = False,
742
- use_cache: bool = False,
743
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
744
- if output_attentions:
745
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
746
- logger.warning_once(
747
- "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
748
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
749
- )
750
- return super().forward(
751
- hidden_states=hidden_states,
752
- attention_mask=attention_mask,
753
- position_ids=position_ids,
754
- past_key_value=past_key_value,
755
- output_attentions=output_attentions,
756
- use_cache=use_cache,
757
- )
758
-
759
- bsz, q_len, _ = hidden_states.size()
760
-
761
- qkv = self.qkv_proj(hidden_states)
762
- query_pos = self.num_heads * self.head_dim
763
- query_states = qkv[..., :query_pos]
764
- key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
765
- value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
766
-
767
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
768
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
769
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
770
-
771
- kv_seq_len = key_states.shape[-2]
772
- if past_key_value is not None:
773
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
774
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
775
-
776
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
777
-
778
- if past_key_value is not None:
779
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
780
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
781
-
782
- key_states = repeat_kv(key_states, self.num_key_value_groups)
783
- value_states = repeat_kv(value_states, self.num_key_value_groups)
784
-
785
- if attention_mask is not None:
786
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
787
- raise ValueError(
788
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
789
- )
790
-
791
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
792
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
793
- if query_states.device.type == "cuda" and attention_mask is not None:
794
- query_states = query_states.contiguous()
795
- key_states = key_states.contiguous()
796
- value_states = value_states.contiguous()
797
-
798
- attn_output = torch.nn.functional.scaled_dot_product_attention(
799
- query_states,
800
- key_states,
801
- value_states,
802
- attn_mask=attention_mask,
803
- dropout_p=self.attention_dropout if self.training else 0.0,
804
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
805
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
806
- )
807
-
808
- attn_output = attn_output.transpose(1, 2).contiguous()
809
- attn_output = attn_output.view(bsz, q_len, self.hidden_size)
810
-
811
- attn_output = self.o_proj(attn_output)
812
-
813
- return attn_output, None, past_key_value
814
-
815
-
816
- PHI3_ATTENTION_CLASSES = {
817
- "eager": Phi3Attention,
818
- "flash_attention_2": Phi3FlashAttention2,
819
- "sdpa": Phi3SdpaAttention,
820
- }
821
-
822
-
823
- class Phi3DecoderLayer(nn.Module):
824
- def __init__(self, config: Phi3Config, layer_idx: int):
825
- super().__init__()
826
-
827
- self.config = config
828
- self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
829
-
830
- self.mlp = Phi3MLP(config)
831
- self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
832
-
833
- self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
834
- self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
835
- self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
836
-
837
- def forward(
838
- self,
839
- hidden_states: torch.Tensor,
840
- attention_mask: Optional[torch.Tensor] = None,
841
- position_ids: Optional[torch.LongTensor] = None,
842
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
843
- output_attentions: Optional[bool] = False,
844
- use_cache: Optional[bool] = False,
845
- **kwargs,
846
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
847
- if "padding_mask" in kwargs:
848
- warnings.warn(
849
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
850
- )
851
- """
852
- Args:
853
- hidden_states (`torch.FloatTensor`):
854
- input to the layer of shape `(batch, seq_len, embed_dim)`
855
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
856
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
857
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
858
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
859
- `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
860
- output_attentions (`bool`, *optional*):
861
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
862
- returned tensors for more detail.
863
- use_cache (`bool`, *optional*):
864
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
865
- (see `past_key_values`).
866
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
867
- """
868
-
869
- residual = hidden_states
870
-
871
- hidden_states = self.input_layernorm(hidden_states)
872
-
873
- # Self Attention
874
- attn_outputs, self_attn_weights, present_key_value = self.self_attn(
875
- hidden_states=hidden_states,
876
- attention_mask=attention_mask,
877
- position_ids=position_ids,
878
- past_key_value=past_key_value,
879
- output_attentions=output_attentions,
880
- use_cache=use_cache,
881
- )
882
-
883
- hidden_states = residual + self.resid_attn_dropout(attn_outputs)
884
-
885
- residual = hidden_states
886
- hidden_states = self.post_attention_layernorm(hidden_states)
887
- hidden_states = self.mlp(hidden_states)
888
- hidden_states = residual + self.resid_mlp_dropout(hidden_states)
889
-
890
- outputs = (hidden_states,)
891
-
892
- if output_attentions:
893
- outputs += (self_attn_weights,)
894
-
895
- if use_cache:
896
- outputs += (present_key_value,)
897
-
898
- return outputs
899
-
900
-
901
- PHI3_START_DOCSTRING = r"""
902
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
903
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
904
- etc.)
905
-
906
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
907
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
908
- and behavior.
909
-
910
- Parameters:
911
- config ([`Phi3Config`]):
912
- Model configuration class with all the parameters of the model. Initializing with a config file does not
913
- load the weights associated with the model, only the configuration. Check out the
914
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
915
- """
916
-
917
-
918
- @add_start_docstrings(
919
- "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
920
- PHI3_START_DOCSTRING,
921
- )
922
- class Phi3PreTrainedModel(PreTrainedModel):
923
- config_class = Phi3Config
924
- base_model_prefix = "model"
925
- supports_gradient_checkpointing = True
926
- _no_split_modules = ["Phi3DecoderLayer"]
927
- _skip_keys_device_placement = "past_key_values"
928
- _supports_flash_attn_2 = True
929
- _supports_sdpa = False
930
- _supports_cache_class = True
931
-
932
- _version = "0.0.5"
933
-
934
- def _init_weights(self, module):
935
- std = self.config.initializer_range
936
- if isinstance(module, nn.Linear):
937
- module.weight.data.normal_(mean=0.0, std=std)
938
- if module.bias is not None:
939
- module.bias.data.zero_()
940
- elif isinstance(module, nn.Embedding):
941
- module.weight.data.normal_(mean=0.0, std=std)
942
- if module.padding_idx is not None:
943
- module.weight.data[module.padding_idx].zero_()
944
-
945
-
946
- PHI3_INPUTS_DOCSTRING = r"""
947
- Args:
948
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
949
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
950
- it.
951
-
952
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
953
- [`PreTrainedTokenizer.__call__`] for details.
954
-
955
- [What are input IDs?](../glossary#input-ids)
956
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
957
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
958
-
959
- - 1 for tokens that are **not masked**,
960
- - 0 for tokens that are **masked**.
961
-
962
- [What are attention masks?](../glossary#attention-mask)
963
-
964
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
965
- [`PreTrainedTokenizer.__call__`] for details.
966
-
967
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
968
- `past_key_values`).
969
-
970
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
971
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
972
- information on the default strategy.
973
-
974
- - 1 indicates the head is **not masked**,
975
- - 0 indicates the head is **masked**.
976
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
977
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
978
- config.n_positions - 1]`.
979
-
980
- [What are position IDs?](../glossary#position-ids)
981
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
982
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
983
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
984
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
985
-
986
- Two formats are allowed:
987
- - a [`~cache_utils.Cache`] instance;
988
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
989
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
990
- cache format.
991
-
992
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
993
- legacy cache format will be returned.
994
-
995
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
996
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
997
- of shape `(batch_size, sequence_length)`.
998
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
999
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1000
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1001
- model's internal embedding lookup matrix.
1002
- use_cache (`bool`, *optional*):
1003
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1004
- `past_key_values`).
1005
- output_attentions (`bool`, *optional*):
1006
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1007
- tensors for more detail.
1008
- output_hidden_states (`bool`, *optional*):
1009
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1010
- more detail.
1011
- return_dict (`bool`, *optional*):
1012
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1013
- """
1014
-
1015
-
1016
- @add_start_docstrings(
1017
- "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
1018
- PHI3_START_DOCSTRING,
1019
- )
1020
- class Phi3Model(Phi3PreTrainedModel):
1021
- """
1022
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
1023
-
1024
- Args:
1025
- config: Phi3Config
1026
- """
1027
-
1028
- def __init__(self, config: Phi3Config):
1029
- super().__init__(config)
1030
- self.padding_idx = config.pad_token_id
1031
- self.vocab_size = config.vocab_size
1032
-
1033
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1034
- self.embed_dropout = nn.Dropout(config.embd_pdrop)
1035
- self.layers = nn.ModuleList(
1036
- [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1037
- )
1038
- self._attn_implementation = config._attn_implementation
1039
- self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1040
-
1041
- self.gradient_checkpointing = False
1042
- # Initialize weights and apply final processing
1043
- self.post_init()
1044
-
1045
- def get_input_embeddings(self):
1046
- return self.embed_tokens
1047
-
1048
- def set_input_embeddings(self, value):
1049
- self.embed_tokens = value
1050
-
1051
- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1052
- def forward(
1053
- self,
1054
- input_ids: torch.LongTensor = None,
1055
- attention_mask: Optional[torch.Tensor] = None,
1056
- position_ids: Optional[torch.LongTensor] = None,
1057
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1058
- inputs_embeds: Optional[torch.FloatTensor] = None,
1059
- use_cache: Optional[bool] = None,
1060
- output_attentions: Optional[bool] = None,
1061
- output_hidden_states: Optional[bool] = None,
1062
- return_dict: Optional[bool] = None,
1063
- ) -> Union[Tuple, BaseModelOutputWithPast]:
1064
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1065
- output_hidden_states = (
1066
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1067
- )
1068
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1069
-
1070
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1071
-
1072
- # retrieve input_ids and inputs_embeds
1073
- if input_ids is not None and inputs_embeds is not None:
1074
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1075
- elif input_ids is not None:
1076
- batch_size, seq_length = input_ids.shape[:2]
1077
- elif inputs_embeds is not None:
1078
- batch_size, seq_length = inputs_embeds.shape[:2]
1079
- else:
1080
- raise ValueError("You have to specify either input_ids or inputs_embeds")
1081
-
1082
- past_key_values_length = 0
1083
-
1084
- if self.gradient_checkpointing and self.training:
1085
- if use_cache:
1086
- logger.warning_once(
1087
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1088
- )
1089
- use_cache = False
1090
-
1091
- if use_cache:
1092
- use_legacy_cache = not isinstance(past_key_values, Cache)
1093
- if use_legacy_cache:
1094
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1095
- past_key_values_length = past_key_values.get_usable_length(seq_length)
1096
-
1097
- if position_ids is None:
1098
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1099
- position_ids = torch.arange(
1100
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1101
- )
1102
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1103
- else:
1104
- position_ids = position_ids.view(-1, seq_length).long()
1105
-
1106
- if inputs_embeds is None:
1107
- inputs_embeds = self.embed_tokens(input_ids)
1108
-
1109
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1110
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1111
- if is_padding_right:
1112
- raise ValueError(
1113
- "You are attempting to perform batched generation with padding_side='right'"
1114
- " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
1115
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1116
- )
1117
-
1118
- if self._attn_implementation == "flash_attention_2":
1119
- # 2d mask is passed through the layers
1120
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1121
- else:
1122
- # 4d mask is passed through the layers
1123
- attention_mask = _prepare_4d_causal_attention_mask(
1124
- attention_mask,
1125
- (batch_size, seq_length),
1126
- inputs_embeds,
1127
- past_key_values_length,
1128
- sliding_window=self.config.sliding_window,
1129
- )
1130
-
1131
- hidden_states = inputs_embeds
1132
-
1133
- # decoder layers
1134
- all_hidden_states = () if output_hidden_states else None
1135
- all_self_attns = () if output_attentions else None
1136
- next_decoder_cache = None
1137
-
1138
- for decoder_layer in self.layers:
1139
- if output_hidden_states:
1140
- all_hidden_states += (hidden_states,)
1141
-
1142
- if self.gradient_checkpointing and self.training:
1143
- layer_outputs = self._gradient_checkpointing_func(
1144
- decoder_layer.__call__,
1145
- hidden_states,
1146
- attention_mask,
1147
- position_ids,
1148
- past_key_values,
1149
- output_attentions,
1150
- use_cache,
1151
- )
1152
- else:
1153
- layer_outputs = decoder_layer(
1154
- hidden_states,
1155
- attention_mask=attention_mask,
1156
- position_ids=position_ids,
1157
- past_key_value=past_key_values,
1158
- output_attentions=output_attentions,
1159
- use_cache=use_cache,
1160
- )
1161
-
1162
- hidden_states = layer_outputs[0]
1163
-
1164
- if use_cache:
1165
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1166
-
1167
- if output_attentions:
1168
- all_self_attns += (layer_outputs[1],)
1169
-
1170
- hidden_states = self.norm(hidden_states)
1171
-
1172
- # add hidden states from the last decoder layer
1173
- if output_hidden_states:
1174
- all_hidden_states += (hidden_states,)
1175
-
1176
- next_cache = None
1177
- if use_cache:
1178
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1179
- if not return_dict:
1180
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1181
- return BaseModelOutputWithPast(
1182
- last_hidden_state=hidden_states,
1183
- past_key_values=next_cache,
1184
- hidden_states=all_hidden_states,
1185
- attentions=all_self_attns,
1186
- )
1187
-
1188
-
1189
- class Phi3ForCausalLM(Phi3PreTrainedModel):
1190
- _tied_weights_keys = ["lm_head.weight"]
1191
-
1192
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
1193
- def __init__(self, config):
1194
- super().__init__(config)
1195
- self.model = Phi3Model(config)
1196
- self.vocab_size = config.vocab_size
1197
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1198
-
1199
- # Initialize weights and apply final processing
1200
- self.post_init()
1201
-
1202
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1203
- def get_input_embeddings(self):
1204
- return self.model.embed_tokens
1205
-
1206
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1207
- def set_input_embeddings(self, value):
1208
- self.model.embed_tokens = value
1209
-
1210
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1211
- def get_output_embeddings(self):
1212
- return self.lm_head
1213
-
1214
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1215
- def set_output_embeddings(self, new_embeddings):
1216
- self.lm_head = new_embeddings
1217
-
1218
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1219
- def set_decoder(self, decoder):
1220
- self.model = decoder
1221
-
1222
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1223
- def get_decoder(self):
1224
- return self.model
1225
-
1226
- # Ignore copy
1227
- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1228
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1229
- def forward(
1230
- self,
1231
- input_ids: torch.LongTensor = None,
1232
- attention_mask: Optional[torch.Tensor] = None,
1233
- position_ids: Optional[torch.LongTensor] = None,
1234
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1235
- inputs_embeds: Optional[torch.FloatTensor] = None,
1236
- labels: Optional[torch.LongTensor] = None,
1237
- use_cache: Optional[bool] = None,
1238
- output_attentions: Optional[bool] = None,
1239
- output_hidden_states: Optional[bool] = None,
1240
- return_dict: Optional[bool] = None,
1241
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1242
- r"""
1243
- Args:
1244
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1245
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1246
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1247
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1248
-
1249
- Returns:
1250
-
1251
- Example:
1252
-
1253
- ```python
1254
- >>> from transformers import AutoTokenizer, Phi3ForCausalLM
1255
-
1256
- >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1257
- >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1258
-
1259
- >>> prompt = "This is an example script ."
1260
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1261
-
1262
- >>> # Generate
1263
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1264
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1265
- 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
1266
- ```"""
1267
-
1268
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1269
- output_hidden_states = (
1270
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1271
- )
1272
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1273
-
1274
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1275
- outputs = self.model(
1276
- input_ids=input_ids,
1277
- attention_mask=attention_mask,
1278
- position_ids=position_ids,
1279
- past_key_values=past_key_values,
1280
- inputs_embeds=inputs_embeds,
1281
- use_cache=use_cache,
1282
- output_attentions=output_attentions,
1283
- output_hidden_states=output_hidden_states,
1284
- return_dict=return_dict,
1285
- )
1286
-
1287
- hidden_states = outputs[0]
1288
- logits = self.lm_head(hidden_states)
1289
- logits = logits.float()
1290
-
1291
- loss = None
1292
- if labels is not None:
1293
- # Shift so that tokens < n predict n
1294
- shift_logits = logits[..., :-1, :].contiguous()
1295
- shift_labels = labels[..., 1:].contiguous()
1296
- # Flatten the tokens
1297
- loss_fct = CrossEntropyLoss()
1298
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1299
- shift_labels = shift_labels.view(-1)
1300
- # Enable model parallelism
1301
- shift_labels = shift_labels.to(shift_logits.device)
1302
- loss = loss_fct(shift_logits, shift_labels)
1303
-
1304
- if not return_dict:
1305
- output = (logits,) + outputs[1:]
1306
- return (loss,) + output if loss is not None else output
1307
-
1308
- return CausalLMOutputWithPast(
1309
- loss=loss,
1310
- logits=logits,
1311
- past_key_values=outputs.past_key_values,
1312
- hidden_states=outputs.hidden_states,
1313
- attentions=outputs.attentions,
1314
- )
1315
-
1316
- # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
1317
- def prepare_inputs_for_generation(
1318
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1319
- ):
1320
- if past_key_values is not None:
1321
- if isinstance(past_key_values, Cache):
1322
- cache_length = past_key_values.get_seq_length()
1323
- past_length = past_key_values.seen_tokens
1324
- max_cache_length = past_key_values.get_max_length()
1325
- else:
1326
- cache_length = past_length = past_key_values[0][0].shape[2]
1327
- max_cache_length = None
1328
-
1329
- # Keep only the unprocessed tokens:
1330
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1331
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1332
- # input)
1333
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1334
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1335
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1336
- # input_ids based on the past_length.
1337
- elif past_length < input_ids.shape[1]:
1338
- input_ids = input_ids[:, past_length:]
1339
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1340
- else:
1341
- remove_prefix_length = input_ids.shape[1] - 1
1342
- input_ids = input_ids[:, remove_prefix_length:]
1343
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1344
- if (
1345
- max_cache_length is not None
1346
- and attention_mask is not None
1347
- and cache_length + input_ids.shape[1] > max_cache_length
1348
- ):
1349
- attention_mask = attention_mask[:, -max_cache_length:]
1350
-
1351
- position_ids = kwargs.get("position_ids", None)
1352
- if attention_mask is not None and position_ids is None:
1353
- # create position_ids on the fly for batch generation
1354
- position_ids = attention_mask.long().cumsum(-1) - 1
1355
- position_ids.masked_fill_(attention_mask == 0, 1)
1356
- if past_key_values:
1357
- position_ids = position_ids[:, -input_ids.shape[1] :]
1358
-
1359
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1360
- if inputs_embeds is not None and past_key_values is None:
1361
- model_inputs = {"inputs_embeds": inputs_embeds}
1362
- else:
1363
- model_inputs = {"input_ids": input_ids}
1364
-
1365
- model_inputs.update(
1366
- {
1367
- "position_ids": position_ids,
1368
- "past_key_values": past_key_values,
1369
- "use_cache": kwargs.get("use_cache"),
1370
- "attention_mask": attention_mask,
1371
- }
1372
- )
1373
- return model_inputs
1374
-
1375
- @staticmethod
1376
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1377
- def _reorder_cache(past_key_values, beam_idx):
1378
- reordered_past = ()
1379
- for layer_past in past_key_values:
1380
- reordered_past += (
1381
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1382
- )
1383
- return reordered_past
1384
-
1385
-
1386
- @add_start_docstrings(
1387
- """
1388
- The [`Phi3Model`] with a sequence classification head on top (linear layer).
1389
-
1390
- [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1391
- (e.g. GPT-2) do.
1392
-
1393
- Since it does classification on the last token, it requires to know the position of the last token. If a
1394
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1395
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1396
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1397
- each row of the batch).
1398
- """,
1399
- PHI3_START_DOCSTRING,
1400
- )
1401
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
1402
- class Phi3ForSequenceClassification(Phi3PreTrainedModel):
1403
- def __init__(self, config):
1404
- super().__init__(config)
1405
- self.num_labels = config.num_labels
1406
- self.model = Phi3Model(config)
1407
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1408
-
1409
- # Initialize weights and apply final processing
1410
- self.post_init()
1411
-
1412
- def get_input_embeddings(self):
1413
- return self.model.embed_tokens
1414
-
1415
- def set_input_embeddings(self, value):
1416
- self.model.embed_tokens = value
1417
-
1418
- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1419
- def forward(
1420
- self,
1421
- input_ids: torch.LongTensor = None,
1422
- attention_mask: Optional[torch.Tensor] = None,
1423
- position_ids: Optional[torch.LongTensor] = None,
1424
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1425
- inputs_embeds: Optional[torch.FloatTensor] = None,
1426
- labels: Optional[torch.LongTensor] = None,
1427
- use_cache: Optional[bool] = None,
1428
- output_attentions: Optional[bool] = None,
1429
- output_hidden_states: Optional[bool] = None,
1430
- return_dict: Optional[bool] = None,
1431
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1432
- r"""
1433
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1434
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1435
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1436
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1437
- """
1438
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1439
-
1440
- model_outputs = self.model(
1441
- input_ids,
1442
- attention_mask=attention_mask,
1443
- position_ids=position_ids,
1444
- past_key_values=past_key_values,
1445
- inputs_embeds=inputs_embeds,
1446
- use_cache=use_cache,
1447
- output_attentions=output_attentions,
1448
- output_hidden_states=output_hidden_states,
1449
- return_dict=return_dict,
1450
- )
1451
- hidden_states = model_outputs[0]
1452
- logits = self.score(hidden_states)
1453
-
1454
- if input_ids is not None:
1455
- batch_size = input_ids.shape[0]
1456
- else:
1457
- batch_size = inputs_embeds.shape[0]
1458
-
1459
- if self.config.pad_token_id is None and batch_size != 1:
1460
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1461
- if self.config.pad_token_id is None:
1462
- sequence_lengths = -1
1463
- else:
1464
- if input_ids is not None:
1465
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1466
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1467
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1468
- sequence_lengths = sequence_lengths.to(logits.device)
1469
- else:
1470
- sequence_lengths = -1
1471
-
1472
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1473
-
1474
- loss = None
1475
- if labels is not None:
1476
- labels = labels.to(logits.device)
1477
- if self.config.problem_type is None:
1478
- if self.num_labels == 1:
1479
- self.config.problem_type = "regression"
1480
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1481
- self.config.problem_type = "single_label_classification"
1482
- else:
1483
- self.config.problem_type = "multi_label_classification"
1484
-
1485
- if self.config.problem_type == "regression":
1486
- loss_fct = MSELoss()
1487
- if self.num_labels == 1:
1488
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1489
- else:
1490
- loss = loss_fct(pooled_logits, labels)
1491
- elif self.config.problem_type == "single_label_classification":
1492
- loss_fct = CrossEntropyLoss()
1493
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1494
- elif self.config.problem_type == "multi_label_classification":
1495
- loss_fct = BCEWithLogitsLoss()
1496
- loss = loss_fct(pooled_logits, labels)
1497
- if not return_dict:
1498
- output = (pooled_logits,) + model_outputs[1:]
1499
- return ((loss,) + output) if loss is not None else output
1500
-
1501
- return SequenceClassifierOutputWithPast(
1502
- loss=loss,
1503
- logits=pooled_logits,
1504
- past_key_values=model_outputs.past_key_values,
1505
- hidden_states=model_outputs.hidden_states,
1506
- attentions=model_outputs.attentions,
1507
- )
1508
-
1509
-
1510
- @add_start_docstrings(
1511
- """
1512
- [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1513
- Named-Entity-Recognition (NER) tasks.
1514
- """,
1515
- PHI3_START_DOCSTRING,
1516
- )
1517
- # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
1518
- class Phi3ForTokenClassification(Phi3PreTrainedModel):
1519
- def __init__(self, config: Phi3Config):
1520
- super().__init__(config)
1521
- self.num_labels = config.num_labels
1522
-
1523
- self.model = Phi3Model(config)
1524
- if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1525
- classifier_dropout = config.classifier_dropout
1526
- elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1527
- classifier_dropout = config.hidden_dropout
1528
- else:
1529
- classifier_dropout = 0.1
1530
- self.dropout = nn.Dropout(classifier_dropout)
1531
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1532
-
1533
- # Initialize weights and apply final processing
1534
- self.post_init()
1535
-
1536
- @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1537
- @add_code_sample_docstrings(
1538
- checkpoint=_CHECKPOINT_FOR_DOC,
1539
- output_type=TokenClassifierOutput,
1540
- config_class=_CONFIG_FOR_DOC,
1541
- )
1542
- def forward(
1543
- self,
1544
- input_ids: Optional[torch.LongTensor] = None,
1545
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1546
- attention_mask: Optional[torch.Tensor] = None,
1547
- inputs_embeds: Optional[torch.Tensor] = None,
1548
- labels: Optional[torch.Tensor] = None,
1549
- use_cache: Optional[bool] = None,
1550
- output_attentions: Optional[bool] = None,
1551
- output_hidden_states: Optional[bool] = None,
1552
- return_dict: Optional[bool] = None,
1553
- **deprecated_arguments,
1554
- ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1555
- r"""
1556
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1557
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1558
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1559
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1560
- """
1561
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1562
-
1563
- model_outputs = self.model(
1564
- input_ids,
1565
- past_key_values=past_key_values,
1566
- attention_mask=attention_mask,
1567
- inputs_embeds=inputs_embeds,
1568
- use_cache=use_cache,
1569
- output_attentions=output_attentions,
1570
- output_hidden_states=output_hidden_states,
1571
- return_dict=return_dict,
1572
- )
1573
-
1574
- hidden_states = model_outputs[0]
1575
- hidden_states = self.dropout(hidden_states)
1576
- logits = self.classifier(hidden_states)
1577
-
1578
- loss = None
1579
- if labels is not None:
1580
- # move labels to correct device to enable model parallelism
1581
- labels = labels.to(logits.device)
1582
- batch_size, seq_length = labels.shape
1583
- loss_fct = CrossEntropyLoss()
1584
- loss = loss_fct(
1585
- logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1586
- )
1587
-
1588
- if not return_dict:
1589
- output = (logits,) + model_outputs[2:]
1590
- return ((loss,) + output) if loss is not None else output
1591
-
1592
- return TokenClassifierOutput(
1593
- loss=loss,
1594
- logits=logits,
1595
- hidden_states=model_outputs.hidden_states,
1596
- attentions=model_outputs.attentions,
1597
- )
 
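The removed `Phi3ForCausalLM.forward` above computes the language-modeling loss by shifting logits and labels one position apart before flattening them for cross-entropy, so the logit at position t is scored against the token at t+1 and `-100` labels are ignored. A minimal, self-contained sketch of that step (plain PyTorch only; the batch size, sequence length, and vocabulary size are made-up illustration values, not taken from any checkpoint):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 8, 32
logits = torch.randn(batch_size, seq_len, vocab_size)        # stand-in for the lm_head output
labels = torch.randint(0, vocab_size, (batch_size, seq_len))  # stand-in for token ids
labels[:, :2] = -100                                          # -100 positions are ignored by the loss

# Shift so that tokens < n predict n, as in the removed forward().
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())
```
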
bunny/model/language_model/qwen2/__init__.py DELETED
@@ -1,80 +0,0 @@
1
- # Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import TYPE_CHECKING
15
-
16
- from transformers.utils import (
17
- OptionalDependencyNotAvailable,
18
- _LazyModule,
19
- is_tokenizers_available,
20
- is_torch_available,
21
- )
22
-
23
-
24
- _import_structure = {
25
- "configuration_qwen2": ["QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Qwen2Config"],
26
- "tokenization_qwen2": ["Qwen2Tokenizer"],
27
- }
28
-
29
- try:
30
- if not is_tokenizers_available():
31
- raise OptionalDependencyNotAvailable()
32
- except OptionalDependencyNotAvailable:
33
- pass
34
- else:
35
- _import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
36
-
37
- try:
38
- if not is_torch_available():
39
- raise OptionalDependencyNotAvailable()
40
- except OptionalDependencyNotAvailable:
41
- pass
42
- else:
43
- _import_structure["modeling_qwen2"] = [
44
- "Qwen2ForCausalLM",
45
- "Qwen2Model",
46
- "Qwen2PreTrainedModel",
47
- "Qwen2ForSequenceClassification",
48
- ]
49
-
50
-
51
- if TYPE_CHECKING:
52
- from .configuration_qwen2 import QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP, Qwen2Config
53
- from .tokenization_qwen2 import Qwen2Tokenizer
54
-
55
- try:
56
- if not is_tokenizers_available():
57
- raise OptionalDependencyNotAvailable()
58
- except OptionalDependencyNotAvailable:
59
- pass
60
- else:
61
- from .tokenization_qwen2_fast import Qwen2TokenizerFast
62
-
63
- try:
64
- if not is_torch_available():
65
- raise OptionalDependencyNotAvailable()
66
- except OptionalDependencyNotAvailable:
67
- pass
68
- else:
69
- from .modeling_qwen2 import (
70
- Qwen2ForCausalLM,
71
- Qwen2ForSequenceClassification,
72
- Qwen2Model,
73
- Qwen2PreTrainedModel,
74
- )
75
-
76
-
77
- else:
78
- import sys
79
-
80
- sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
 
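The deleted `__init__.py` above registers the package with transformers' `_LazyModule`, so the torch- and tokenizers-dependent submodules are imported only when first accessed. The snippet below is not that helper; it is a rough standard-library sketch of the same idea using a PEP 562 module-level `__getattr__`, with a simplified `_import_structure` as a placeholder, meant to live in a package's `__init__.py`:

```python
import importlib

# Simplified stand-in for the real _import_structure mapping.
_import_structure = {
    "configuration_qwen2": ["Qwen2Config"],
    "modeling_qwen2": ["Qwen2ForCausalLM", "Qwen2Model"],
}
_attr_to_module = {attr: mod for mod, attrs in _import_structure.items() for attr in attrs}


def __getattr__(name):
    # Called only when `name` is not already defined in this module, so the
    # heavy submodule is imported lazily on first attribute access.
    module_name = _attr_to_module.get(name)
    if module_name is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    module = importlib.import_module(f".{module_name}", __name__)
    return getattr(module, name)
```

With a layout like this, `import package` stays cheap and `from package import Qwen2Config` only pulls in `configuration_qwen2` at that point, which is the behaviour the `_LazyModule` wiring above provides in a more general form.
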
bunny/model/language_model/qwen2/configuration_qwen2.py DELETED
@@ -1,144 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Qwen2 model configuration"""
16
-
17
- from transformers.configuration_utils import PretrainedConfig
18
- from transformers.utils import logging
19
-
20
-
21
- logger = logging.get_logger(__name__)
22
-
23
- QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
- "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json",
25
- }
26
-
27
-
28
- class Qwen2Config(PretrainedConfig):
29
- r"""
30
- This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
31
- Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
32
- with the defaults will yield a similar configuration to that of
33
- Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
34
-
35
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
- documentation from [`PretrainedConfig`] for more information.
37
-
38
-
39
- Args:
40
- vocab_size (`int`, *optional*, defaults to 151936):
41
- Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
42
- `inputs_ids` passed when calling [`Qwen2Model`]
43
- hidden_size (`int`, *optional*, defaults to 4096):
44
- Dimension of the hidden representations.
45
- intermediate_size (`int`, *optional*, defaults to 22016):
46
- Dimension of the MLP representations.
47
- num_hidden_layers (`int`, *optional*, defaults to 32):
48
- Number of hidden layers in the Transformer encoder.
49
- num_attention_heads (`int`, *optional*, defaults to 32):
50
- Number of attention heads for each attention layer in the Transformer encoder.
51
- num_key_value_heads (`int`, *optional*, defaults to 32):
52
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
55
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
- by meanpooling all the original heads within that group. For more details checkout [this
57
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
58
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
- The non-linear activation function (function or string) in the decoder.
60
- max_position_embeddings (`int`, *optional*, defaults to 32768):
61
- The maximum sequence length that this model might ever be used with.
62
- initializer_range (`float`, *optional*, defaults to 0.02):
63
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65
- The epsilon used by the rms normalization layers.
66
- use_cache (`bool`, *optional*, defaults to `True`):
67
- Whether or not the model should return the last key/values attentions (not used by all models). Only
68
- relevant if `config.is_decoder=True`.
69
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
70
- Whether the model's input and output word embeddings should be tied.
71
- rope_theta (`float`, *optional*, defaults to 10000.0):
72
- The base period of the RoPE embeddings.
73
- use_sliding_window (`bool`, *optional*, defaults to `False`):
74
- Whether to use sliding window attention.
75
- sliding_window (`int`, *optional*, defaults to 4096):
76
- Sliding window attention (SWA) window size. If not specified, will default to `4096`.
77
- max_window_layers (`int`, *optional*, defaults to 28):
78
- The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
79
- attention_dropout (`float`, *optional*, defaults to 0.0):
80
- The dropout ratio for the attention probabilities.
81
-
82
- ```python
83
- >>> from transformers import Qwen2Model, Qwen2Config
84
-
85
- >>> # Initializing a Qwen2 style configuration
86
- >>> configuration = Qwen2Config()
87
-
88
- >>> # Initializing a model from the Qwen2-7B style configuration
89
- >>> model = Qwen2Model(configuration)
90
-
91
- >>> # Accessing the model configuration
92
- >>> configuration = model.config
93
- ```"""
94
-
95
- model_type = "qwen2"
96
- keys_to_ignore_at_inference = ["past_key_values"]
97
-
98
- def __init__(
99
- self,
100
- vocab_size=151936,
101
- hidden_size=4096,
102
- intermediate_size=22016,
103
- num_hidden_layers=32,
104
- num_attention_heads=32,
105
- num_key_value_heads=32,
106
- hidden_act="silu",
107
- max_position_embeddings=32768,
108
- initializer_range=0.02,
109
- rms_norm_eps=1e-6,
110
- use_cache=True,
111
- tie_word_embeddings=False,
112
- rope_theta=10000.0,
113
- use_sliding_window=False,
114
- sliding_window=4096,
115
- max_window_layers=28,
116
- attention_dropout=0.0,
117
- **kwargs,
118
- ):
119
- self.vocab_size = vocab_size
120
- self.max_position_embeddings = max_position_embeddings
121
- self.hidden_size = hidden_size
122
- self.intermediate_size = intermediate_size
123
- self.num_hidden_layers = num_hidden_layers
124
- self.num_attention_heads = num_attention_heads
125
- self.use_sliding_window = use_sliding_window
126
- self.sliding_window = sliding_window
127
- self.max_window_layers = max_window_layers
128
-
129
- # for backward compatibility
130
- if num_key_value_heads is None:
131
- num_key_value_heads = num_attention_heads
132
-
133
- self.num_key_value_heads = num_key_value_heads
134
- self.hidden_act = hidden_act
135
- self.initializer_range = initializer_range
136
- self.rms_norm_eps = rms_norm_eps
137
- self.use_cache = use_cache
138
- self.rope_theta = rope_theta
139
- self.attention_dropout = attention_dropout
140
-
141
- super().__init__(
142
- tie_word_embeddings=tie_word_embeddings,
143
- **kwargs,
144
- )
 
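The `num_key_value_heads` option documented in the removed configuration above is what enables grouped-query attention: the key/value projections produce fewer heads than the query projection, and the modeling code (see `repeat_kv` in the next file) expands them back before the attention product. A shape-only sketch of that expansion, with made-up head counts:

```python
import torch

batch, seq_len, head_dim = 2, 16, 64
num_attention_heads, num_key_value_heads = 8, 2  # every 4 query heads share one KV head
n_rep = num_attention_heads // num_key_value_heads

key_states = torch.randn(batch, num_key_value_heads, seq_len, head_dim)

# Same expand-then-reshape trick used by repeat_kv below; equivalent to
# torch.repeat_interleave(key_states, n_rep, dim=1).
expanded = key_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
expanded = expanded.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)

assert expanded.shape == (batch, num_attention_heads, seq_len, head_dim)
assert torch.equal(expanded, torch.repeat_interleave(key_states, n_rep, dim=1))
```
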
bunny/model/language_model/qwen2/modeling_qwen2.py DELETED
@@ -1,1403 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- """ PyTorch Qwen2 model."""
21
- import inspect
22
- import math
23
- import warnings
24
- from typing import List, Optional, Tuple, Union
25
-
26
- import torch
27
- import torch.nn.functional as F
28
- import torch.utils.checkpoint
29
- from torch import nn
30
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
-
32
- from transformers.activations import ACT2FN
33
- from transformers.cache_utils import Cache, DynamicCache
34
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
35
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
36
- from transformers.modeling_utils import PreTrainedModel
37
- from transformers.utils import (
38
- add_start_docstrings,
39
- add_start_docstrings_to_model_forward,
40
- is_flash_attn_2_available,
41
- is_flash_attn_greater_or_equal_2_10,
42
- logging,
43
- replace_return_docstrings,
44
- )
45
- from .configuration_qwen2 import Qwen2Config
46
-
47
-
48
- if is_flash_attn_2_available():
49
- from flash_attn import flash_attn_func, flash_attn_varlen_func
50
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51
-
52
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
53
-
54
-
55
- logger = logging.get_logger(__name__)
56
-
57
-
58
- _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
59
- _CONFIG_FOR_DOC = "Qwen2Config"
60
-
61
- QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [
62
- "Qwen/Qwen2-7B-beta",
63
- # See all Qwen2 models at https://huggingface.co/models?filter=qwen2
64
- ]
65
-
66
-
67
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
- def _get_unpad_data(attention_mask):
69
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
70
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
- max_seqlen_in_batch = seqlens_in_batch.max().item()
72
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
73
- return (
74
- indices,
75
- cu_seqlens,
76
- max_seqlen_in_batch,
77
- )
78
-
79
-
80
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
81
- class Qwen2RMSNorm(nn.Module):
82
- def __init__(self, hidden_size, eps=1e-6):
83
- """
84
- Qwen2RMSNorm is equivalent to T5LayerNorm
85
- """
86
- super().__init__()
87
- self.weight = nn.Parameter(torch.ones(hidden_size))
88
- self.variance_epsilon = eps
89
-
90
- def forward(self, hidden_states):
91
- input_dtype = hidden_states.dtype
92
- hidden_states = hidden_states.to(torch.float32)
93
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
94
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
95
- return self.weight * hidden_states.to(input_dtype)
96
-
97
-
98
- # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
99
- class Qwen2RotaryEmbedding(nn.Module):
100
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
101
- super().__init__()
102
-
103
- self.dim = dim
104
- self.max_position_embeddings = max_position_embeddings
105
- self.base = base
106
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
107
- self.register_buffer("inv_freq", inv_freq, persistent=False)
108
-
109
- # Build here to make `torch.jit.trace` work.
110
- self._set_cos_sin_cache(
111
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
112
- )
113
-
114
- def _set_cos_sin_cache(self, seq_len, device, dtype):
115
- self.max_seq_len_cached = seq_len
116
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
117
-
118
- freqs = torch.outer(t, self.inv_freq)
119
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
120
- emb = torch.cat((freqs, freqs), dim=-1)
121
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
122
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
123
-
124
- def forward(self, x, seq_len=None):
125
- # x: [bs, num_attention_heads, seq_len, head_size]
126
- if seq_len > self.max_seq_len_cached:
127
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
128
-
129
- return (
130
- self.cos_cached[:seq_len].to(dtype=x.dtype),
131
- self.sin_cached[:seq_len].to(dtype=x.dtype),
132
- )
133
-
134
-
135
- # Copied from transformers.models.llama.modeling_llama.rotate_half
136
- def rotate_half(x):
137
- """Rotates half the hidden dims of the input."""
138
- x1 = x[..., : x.shape[-1] // 2]
139
- x2 = x[..., x.shape[-1] // 2 :]
140
- return torch.cat((-x2, x1), dim=-1)
141
-
142
-
143
- # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
144
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
145
- """Applies Rotary Position Embedding to the query and key tensors.
146
-
147
- Args:
148
- q (`torch.Tensor`): The query tensor.
149
- k (`torch.Tensor`): The key tensor.
150
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
151
- sin (`torch.Tensor`): The sine part of the rotary embedding.
152
- position_ids (`torch.Tensor`):
153
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
154
- used to pass offsetted position ids when working with a KV-cache.
155
- unsqueeze_dim (`int`, *optional*, defaults to 1):
156
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
157
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
158
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
159
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
160
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
161
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
162
- Returns:
163
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
164
- """
165
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
166
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
167
- q_embed = (q * cos) + (rotate_half(q) * sin)
168
- k_embed = (k * cos) + (rotate_half(k) * sin)
169
- return q_embed, k_embed
170
-
171
-
172
- # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
173
- class Qwen2MLP(nn.Module):
174
- def __init__(self, config):
175
- super().__init__()
176
- self.config = config
177
- self.hidden_size = config.hidden_size
178
- self.intermediate_size = config.intermediate_size
179
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
180
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
181
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
182
- self.act_fn = ACT2FN[config.hidden_act]
183
-
184
- def forward(self, x):
185
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
186
-
187
-
188
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
189
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
190
- """
191
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
192
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
193
- """
194
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
195
- if n_rep == 1:
196
- return hidden_states
197
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
198
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
199
-
200
-
201
- class Qwen2Attention(nn.Module):
202
- """
203
- Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
204
- and "Generating Long Sequences with Sparse Transformers".
205
- """
206
-
207
- def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
208
- super().__init__()
209
- self.config = config
210
- self.layer_idx = layer_idx
211
- if layer_idx is None:
212
- logger.warning_once(
213
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
214
- "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
215
- "when creating this class."
216
- )
217
-
218
- self.hidden_size = config.hidden_size
219
- self.num_heads = config.num_attention_heads
220
- self.head_dim = self.hidden_size // self.num_heads
221
- self.num_key_value_heads = config.num_key_value_heads
222
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
223
- self.max_position_embeddings = config.max_position_embeddings
224
- self.rope_theta = config.rope_theta
225
- self.is_causal = True
226
- self.attention_dropout = config.attention_dropout
227
-
228
- if (self.head_dim * self.num_heads) != self.hidden_size:
229
- raise ValueError(
230
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
231
- f" and `num_heads`: {self.num_heads})."
232
- )
233
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
234
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
235
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
236
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
237
-
238
- self.rotary_emb = Qwen2RotaryEmbedding(
239
- self.head_dim,
240
- max_position_embeddings=self.max_position_embeddings,
241
- base=self.rope_theta,
242
- )
243
-
244
- def forward(
245
- self,
246
- hidden_states: torch.Tensor,
247
- attention_mask: Optional[torch.Tensor] = None,
248
- position_ids: Optional[torch.LongTensor] = None,
249
- past_key_value: Optional[Cache] = None,
250
- output_attentions: bool = False,
251
- use_cache: bool = False,
252
- **kwargs,
253
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
254
- if "padding_mask" in kwargs:
255
- warnings.warn(
256
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
257
- )
258
- bsz, q_len, _ = hidden_states.size()
259
-
260
- query_states = self.q_proj(hidden_states)
261
- key_states = self.k_proj(hidden_states)
262
- value_states = self.v_proj(hidden_states)
263
-
264
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
265
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
266
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
267
-
268
- kv_seq_len = key_states.shape[-2]
269
- if past_key_value is not None:
270
- if self.layer_idx is None:
271
- raise ValueError(
272
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
273
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
274
- "with a layer index."
275
- )
276
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
277
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
278
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
279
-
280
- if past_key_value is not None:
281
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
282
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
283
-
284
- # repeat k/v heads if n_kv_heads < n_heads
285
- key_states = repeat_kv(key_states, self.num_key_value_groups)
286
- value_states = repeat_kv(value_states, self.num_key_value_groups)
287
-
288
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
289
-
290
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
291
- raise ValueError(
292
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
293
- f" {attn_weights.size()}"
294
- )
295
-
296
- if attention_mask is not None:
297
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
298
- raise ValueError(
299
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
300
- )
301
-
302
- attn_weights = attn_weights + attention_mask
303
-
304
- # upcast attention to fp32
305
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
306
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
307
- attn_output = torch.matmul(attn_weights, value_states)
308
-
309
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
310
- raise ValueError(
311
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
312
- f" {attn_output.size()}"
313
- )
314
-
315
- attn_output = attn_output.transpose(1, 2).contiguous()
316
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
317
-
318
- attn_output = self.o_proj(attn_output)
319
-
320
- if not output_attentions:
321
- attn_weights = None
322
-
323
- return attn_output, attn_weights, past_key_value
324
-
325
-
326
- class Qwen2FlashAttention2(Qwen2Attention):
327
- """
328
- Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
329
- as the weights of the module stay untouched. The only required change would be on the forward pass
330
- where it needs to correctly call the public API of flash attention and deal with padding tokens
331
- in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
332
- config.max_window_layers layers.
333
- """
334
-
335
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
336
- def __init__(self, *args, **kwargs):
337
- super().__init__(*args, **kwargs)
338
-
339
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
340
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
341
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
342
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
343
-
344
- def forward(
345
- self,
346
- hidden_states: torch.Tensor,
347
- attention_mask: Optional[torch.Tensor] = None,
348
- position_ids: Optional[torch.LongTensor] = None,
349
- past_key_value: Optional[Cache] = None,
350
- output_attentions: bool = False,
351
- use_cache: bool = False,
352
- **kwargs,
353
- ):
354
- if "padding_mask" in kwargs:
355
- warnings.warn(
356
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
357
- )
358
-
359
- # overwrite attention_mask with padding_mask
360
- attention_mask = kwargs.pop("padding_mask")
361
- bsz, q_len, _ = hidden_states.size()
362
-
363
- query_states = self.q_proj(hidden_states)
364
- key_states = self.k_proj(hidden_states)
365
- value_states = self.v_proj(hidden_states)
366
-
367
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
368
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
369
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
370
-
371
- kv_seq_len = key_states.shape[-2]
372
- if past_key_value is not None:
373
- if self.layer_idx is None:
374
- raise ValueError(
375
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
376
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
377
- "with a layer index."
378
- )
379
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
380
-
381
- # Because the input can be padded, the absolute sequence length depends on the max position id.
382
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
383
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
384
-
385
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
386
-
387
- use_sliding_windows = (
388
- _flash_supports_window_size
389
- and getattr(self.config, "sliding_window", None) is not None
390
- and kv_seq_len > self.config.sliding_window
391
- and self.config.use_sliding_window
392
- )
393
-
394
- if not _flash_supports_window_size:
395
- logger.warning_once(
396
- "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
397
- " make sure to upgrade flash-attn library."
398
- )
399
-
400
- if past_key_value is not None:
401
- # Activate cache slicing only if the config has a `sliding_window` attribute
402
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
403
- if (
404
- getattr(self.config, "sliding_window", None) is not None
405
- and kv_seq_len > self.config.sliding_window
406
- and cache_has_contents
407
- ):
408
- slicing_tokens = 1 - self.config.sliding_window
409
-
410
- past_key = past_key_value[self.layer_idx][0]
411
- past_value = past_key_value[self.layer_idx][1]
412
-
413
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
414
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
415
-
416
- if past_key.shape[-2] != self.config.sliding_window - 1:
417
- raise ValueError(
418
- f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
419
- f" {past_key.shape}"
420
- )
421
-
422
- if attention_mask is not None:
423
- attention_mask = attention_mask[:, slicing_tokens:]
424
- attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
425
-
426
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
427
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
428
-
429
- # repeat k/v heads if n_kv_heads < n_heads
430
- key_states = repeat_kv(key_states, self.num_key_value_groups)
431
- value_states = repeat_kv(value_states, self.num_key_value_groups)
432
- dropout_rate = 0.0 if not self.training else self.attention_dropout
433
-
434
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
435
- # therefore the input hidden states gets silently casted in float32. Hence, we need
436
- # cast them back in float16 just to be sure everything works as expected.
437
- input_dtype = query_states.dtype
438
- if input_dtype == torch.float32:
439
- if torch.is_autocast_enabled():
440
- target_dtype = torch.get_autocast_gpu_dtype()
441
- # Handle the case where the model is quantized
442
- elif hasattr(self.config, "_pre_quantization_dtype"):
443
- target_dtype = self.config._pre_quantization_dtype
444
- else:
445
- target_dtype = self.q_proj.weight.dtype
446
-
447
- logger.warning_once(
448
- f"The input hidden states seems to be silently casted in float32, this might be related to"
449
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
450
- f" {target_dtype}."
451
- )
452
-
453
- query_states = query_states.to(target_dtype)
454
- key_states = key_states.to(target_dtype)
455
- value_states = value_states.to(target_dtype)
456
-
457
- # Reshape to the expected shape for Flash Attention
458
- query_states = query_states.transpose(1, 2)
459
- key_states = key_states.transpose(1, 2)
460
- value_states = value_states.transpose(1, 2)
461
-
462
- attn_output = self._flash_attention_forward(
463
- query_states,
464
- key_states,
465
- value_states,
466
- attention_mask,
467
- q_len,
468
- dropout=dropout_rate,
469
- use_sliding_windows=use_sliding_windows,
470
- )
471
-
472
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
473
- attn_output = self.o_proj(attn_output)
474
-
475
- if not output_attentions:
476
- attn_weights = None
477
-
478
- return attn_output, attn_weights, past_key_value
479
-
480
- def _flash_attention_forward(
481
- self,
482
- query_states,
483
- key_states,
484
- value_states,
485
- attention_mask,
486
- query_length,
487
- dropout=0.0,
488
- softmax_scale=None,
489
- use_sliding_windows=False,
490
- ):
491
- """
492
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
493
- first unpads the input, then computes the attention scores and pads the final attention scores.
494
-
495
- Args:
496
- query_states (`torch.Tensor`):
497
- Input query states to be passed to Flash Attention API
498
- key_states (`torch.Tensor`):
499
- Input key states to be passed to Flash Attention API
500
- value_states (`torch.Tensor`):
501
- Input value states to be passed to Flash Attention API
502
- attention_mask (`torch.Tensor`):
503
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
504
- position of padding tokens and 1 for the position of non-padding tokens.
505
- dropout (`float`):
506
- Attention dropout
507
- softmax_scale (`float`, *optional*):
508
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
509
- use_sliding_windows (`bool`, *optional*):
510
- Whether to activate sliding window attention.
511
- """
512
- if not self._flash_attn_uses_top_left_mask:
513
- causal = self.is_causal
514
- else:
515
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
516
- causal = self.is_causal and query_length != 1
517
-
518
- # Decide whether to use SWA or not by layer index.
519
- if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
520
- use_sliding_windows = False
521
-
522
- # Contains at least one padding token in the sequence
523
- if attention_mask is not None:
524
- batch_size = query_states.shape[0]
525
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
526
- query_states, key_states, value_states, attention_mask, query_length
527
- )
528
-
529
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
530
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
531
-
532
- if not use_sliding_windows:
533
- attn_output_unpad = flash_attn_varlen_func(
534
- query_states,
535
- key_states,
536
- value_states,
537
- cu_seqlens_q=cu_seqlens_q,
538
- cu_seqlens_k=cu_seqlens_k,
539
- max_seqlen_q=max_seqlen_in_batch_q,
540
- max_seqlen_k=max_seqlen_in_batch_k,
541
- dropout_p=dropout,
542
- softmax_scale=softmax_scale,
543
- causal=causal,
544
- )
545
- else:
546
- attn_output_unpad = flash_attn_varlen_func(
547
- query_states,
548
- key_states,
549
- value_states,
550
- cu_seqlens_q=cu_seqlens_q,
551
- cu_seqlens_k=cu_seqlens_k,
552
- max_seqlen_q=max_seqlen_in_batch_q,
553
- max_seqlen_k=max_seqlen_in_batch_k,
554
- dropout_p=dropout,
555
- softmax_scale=softmax_scale,
556
- causal=causal,
557
- window_size=(self.config.sliding_window, self.config.sliding_window),
558
- )
559
-
560
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
561
- else:
562
- if not use_sliding_windows:
563
- attn_output = flash_attn_func(
564
- query_states,
565
- key_states,
566
- value_states,
567
- dropout,
568
- softmax_scale=softmax_scale,
569
- causal=causal,
570
- )
571
- else:
572
- attn_output = flash_attn_func(
573
- query_states,
574
- key_states,
575
- value_states,
576
- dropout,
577
- softmax_scale=softmax_scale,
578
- causal=causal,
579
- window_size=(self.config.sliding_window, self.config.sliding_window),
580
- )
581
-
582
- return attn_output
583
-
584
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
585
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
586
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
587
-
588
- # On the first iteration we need to properly re-create the padding mask
589
- # by slicing it on the proper place
590
- if kv_seq_len != attention_mask.shape[-1]:
591
- attention_mask_num_tokens = attention_mask.shape[-1]
592
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
593
-
594
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
595
-
596
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
597
- value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
598
-
599
- if query_length == kv_seq_len:
600
- query_layer = index_first_axis(
601
- query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
602
- )
603
- cu_seqlens_q = cu_seqlens_k
604
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
605
- indices_q = indices_k
606
- elif query_length == 1:
607
- max_seqlen_in_batch_q = 1
608
- cu_seqlens_q = torch.arange(
609
- batch_size + 1, dtype=torch.int32, device=query_layer.device
610
- ) # There is a memcpy here, that is very bad.
611
- indices_q = cu_seqlens_q[:-1]
612
- query_layer = query_layer.squeeze(1)
613
- else:
614
- # The -q_len: slice assumes left padding.
615
- attention_mask = attention_mask[:, -query_length:]
616
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
617
-
618
- return (
619
- query_layer,
620
- key_layer,
621
- value_layer,
622
- indices_q,
623
- (cu_seqlens_q, cu_seqlens_k),
624
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
625
- )
626
-
627
-
628
- # Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2
629
- class Qwen2SdpaAttention(Qwen2Attention):
630
- """
631
- Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
632
- `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
633
- SDPA API.
634
- """
635
-
636
- # Adapted from Qwen2Attention.forward
637
- def forward(
638
- self,
639
- hidden_states: torch.Tensor,
640
- attention_mask: Optional[torch.Tensor] = None,
641
- position_ids: Optional[torch.LongTensor] = None,
642
- past_key_value: Optional[Cache] = None,
643
- output_attentions: bool = False,
644
- use_cache: bool = False,
645
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
646
- if output_attentions:
647
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
648
- logger.warning_once(
649
- "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
650
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
651
- )
652
- return super().forward(
653
- hidden_states=hidden_states,
654
- attention_mask=attention_mask,
655
- position_ids=position_ids,
656
- past_key_value=past_key_value,
657
- output_attentions=output_attentions,
658
- use_cache=use_cache,
659
- )
660
-
661
- bsz, q_len, _ = hidden_states.size()
662
-
663
- query_states = self.q_proj(hidden_states)
664
- key_states = self.k_proj(hidden_states)
665
- value_states = self.v_proj(hidden_states)
666
-
667
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
668
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
669
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
670
-
671
- kv_seq_len = key_states.shape[-2]
672
- if past_key_value is not None:
673
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
674
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
675
-
676
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
677
-
678
- if past_key_value is not None:
679
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
680
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
681
-
682
- key_states = repeat_kv(key_states, self.num_key_value_groups)
683
- value_states = repeat_kv(value_states, self.num_key_value_groups)
684
-
685
- if attention_mask is not None:
686
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
687
- raise ValueError(
688
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
689
- )
690
-
691
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
692
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
693
- if query_states.device.type == "cuda" and attention_mask is not None:
694
- query_states = query_states.contiguous()
695
- key_states = key_states.contiguous()
696
- value_states = value_states.contiguous()
697
-
698
- attn_output = torch.nn.functional.scaled_dot_product_attention(
699
- query_states,
700
- key_states,
701
- value_states,
702
- attn_mask=attention_mask,
703
- dropout_p=self.attention_dropout if self.training else 0.0,
704
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
705
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
706
- )
707
-
708
- attn_output = attn_output.transpose(1, 2).contiguous()
709
- attn_output = attn_output.view(bsz, q_len, self.hidden_size)
710
-
711
- attn_output = self.o_proj(attn_output)
712
-
713
- return attn_output, None, past_key_value
714
-
715
-
716
- QWEN2_ATTENTION_CLASSES = {
717
- "eager": Qwen2Attention,
718
- "flash_attention_2": Qwen2FlashAttention2,
719
- "sdpa": Qwen2SdpaAttention,
720
- }
721
-
722
-
723
- class Qwen2DecoderLayer(nn.Module):
724
- def __init__(self, config: Qwen2Config, layer_idx: int):
725
- super().__init__()
726
- self.hidden_size = config.hidden_size
727
-
728
- if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
729
- logger.warning_once(
730
- f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
731
- "unexpected results may be encountered."
732
- )
733
- self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
734
-
735
- self.mlp = Qwen2MLP(config)
736
- self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
737
- self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
738
-
739
- def forward(
740
- self,
741
- hidden_states: torch.Tensor,
742
- attention_mask: Optional[torch.Tensor] = None,
743
- position_ids: Optional[torch.LongTensor] = None,
744
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
745
- output_attentions: Optional[bool] = False,
746
- use_cache: Optional[bool] = False,
747
- **kwargs,
748
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
749
- if "padding_mask" in kwargs:
750
- warnings.warn(
751
- "Passing `padding_mask` is deprecated and will be removed in v4.37. "
752
- "Please make sure use `attention_mask` instead.`"
753
- )
754
- """
755
- Args:
756
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
757
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
758
- `(batch, sequence_length)` where padding elements are indicated by 0.
759
- output_attentions (`bool`, *optional*):
760
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
761
- returned tensors for more detail.
762
- use_cache (`bool`, *optional*):
763
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
764
- (see `past_key_values`).
765
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
766
- """
767
-
768
- residual = hidden_states
769
-
770
- hidden_states = self.input_layernorm(hidden_states)
771
-
772
- # Self Attention
773
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
774
- hidden_states=hidden_states,
775
- attention_mask=attention_mask,
776
- position_ids=position_ids,
777
- past_key_value=past_key_value,
778
- output_attentions=output_attentions,
779
- use_cache=use_cache,
780
- )
781
- hidden_states = residual + hidden_states
782
-
783
- # Fully Connected
784
- residual = hidden_states
785
- hidden_states = self.post_attention_layernorm(hidden_states)
786
- hidden_states = self.mlp(hidden_states)
787
- hidden_states = residual + hidden_states
788
-
789
- outputs = (hidden_states,)
790
-
791
- if output_attentions:
792
- outputs += (self_attn_weights,)
793
-
794
- if use_cache:
795
- outputs += (present_key_value,)
796
-
797
- return outputs
798
-
799
-
800
- QWEN2_START_DOCSTRING = r"""
801
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
802
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
803
- etc.)
804
-
805
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
806
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
807
- and behavior.
808
-
809
- Parameters:
810
- config ([`Qwen2Config`]):
811
- Model configuration class with all the parameters of the model. Initializing with a config file does not
812
- load the weights associated with the model, only the configuration. Check out the
813
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
814
- """
815
-
816
-
817
- @add_start_docstrings(
818
- "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
819
- QWEN2_START_DOCSTRING,
820
- )
821
- class Qwen2PreTrainedModel(PreTrainedModel):
822
- config_class = Qwen2Config
823
- base_model_prefix = "model"
824
- supports_gradient_checkpointing = True
825
- _no_split_modules = ["Qwen2DecoderLayer"]
826
- _skip_keys_device_placement = "past_key_values"
827
- _supports_flash_attn_2 = True
828
- _supports_sdpa = True
829
- _supports_cache_class = True
830
-
831
- def _init_weights(self, module):
832
- std = self.config.initializer_range
833
- if isinstance(module, nn.Linear):
834
- module.weight.data.normal_(mean=0.0, std=std)
835
- if module.bias is not None:
836
- module.bias.data.zero_()
837
- elif isinstance(module, nn.Embedding):
838
- module.weight.data.normal_(mean=0.0, std=std)
839
- if module.padding_idx is not None:
840
- module.weight.data[module.padding_idx].zero_()
841
-
842
-
843
- QWEN2_INPUTS_DOCSTRING = r"""
844
- Args:
845
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
846
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
847
- it.
848
-
849
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
850
- [`PreTrainedTokenizer.__call__`] for details.
851
-
852
- [What are input IDs?](../glossary#input-ids)
853
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
854
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
855
-
856
- - 1 for tokens that are **not masked**,
857
- - 0 for tokens that are **masked**.
858
-
859
- [What are attention masks?](../glossary#attention-mask)
860
-
861
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
862
- [`PreTrainedTokenizer.__call__`] for details.
863
-
864
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
865
- `past_key_values`).
866
-
867
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
868
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
869
- information on the default strategy.
870
-
871
- - 1 indicates the head is **not masked**,
872
- - 0 indicates the head is **masked**.
873
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
874
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
875
- config.n_positions - 1]`.
876
-
877
- [What are position IDs?](../glossary#position-ids)
878
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
879
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
880
- blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
881
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
882
-
883
- Two formats are allowed:
884
- - a [`~cache_utils.Cache`] instance;
885
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
886
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
887
- cache format.
888
-
889
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
890
- legacy cache format will be returned.
891
-
892
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
893
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
894
- of shape `(batch_size, sequence_length)`.
895
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
896
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
897
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
898
- model's internal embedding lookup matrix.
899
- use_cache (`bool`, *optional*):
900
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
901
- `past_key_values`).
902
- output_attentions (`bool`, *optional*):
903
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
904
- tensors for more detail.
905
- output_hidden_states (`bool`, *optional*):
906
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
907
- more detail.
908
- return_dict (`bool`, *optional*):
909
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
910
- """
911
-
912
-
913
- @add_start_docstrings(
914
- "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
915
- QWEN2_START_DOCSTRING,
916
- )
917
- class Qwen2Model(Qwen2PreTrainedModel):
918
- """
919
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
920
-
921
- Args:
922
- config: Qwen2Config
923
- """
924
-
925
- def __init__(self, config: Qwen2Config):
926
- super().__init__(config)
927
- self.padding_idx = config.pad_token_id
928
- self.vocab_size = config.vocab_size
929
-
930
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
931
- self.layers = nn.ModuleList(
932
- [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
933
- )
934
- self._attn_implementation = config._attn_implementation
935
- self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
936
-
937
- self.gradient_checkpointing = False
938
- # Initialize weights and apply final processing
939
- self.post_init()
940
-
941
- def get_input_embeddings(self):
942
- return self.embed_tokens
943
-
944
- def set_input_embeddings(self, value):
945
- self.embed_tokens = value
946
-
947
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
948
- def forward(
949
- self,
950
- input_ids: torch.LongTensor = None,
951
- attention_mask: Optional[torch.Tensor] = None,
952
- position_ids: Optional[torch.LongTensor] = None,
953
- past_key_values: Optional[List[torch.FloatTensor]] = None,
954
- inputs_embeds: Optional[torch.FloatTensor] = None,
955
- use_cache: Optional[bool] = None,
956
- output_attentions: Optional[bool] = None,
957
- output_hidden_states: Optional[bool] = None,
958
- return_dict: Optional[bool] = None,
959
- ) -> Union[Tuple, BaseModelOutputWithPast]:
960
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
961
- output_hidden_states = (
962
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
963
- )
964
- use_cache = use_cache if use_cache is not None else self.config.use_cache
965
-
966
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
967
-
968
- # retrieve input_ids and inputs_embeds
969
- if input_ids is not None and inputs_embeds is not None:
970
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
971
- elif input_ids is not None:
972
- batch_size, seq_length = input_ids.shape
973
- elif inputs_embeds is not None:
974
- batch_size, seq_length, _ = inputs_embeds.shape
975
- else:
976
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
977
-
978
- if self.gradient_checkpointing and self.training:
979
- if use_cache:
980
- logger.warning_once(
981
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
982
- )
983
- use_cache = False
984
-
985
- past_key_values_length = 0
986
-
987
- if use_cache:
988
- use_legacy_cache = not isinstance(past_key_values, Cache)
989
- if use_legacy_cache:
990
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
991
- past_key_values_length = past_key_values.get_usable_length(seq_length)
992
-
993
- if position_ids is None:
994
- device = input_ids.device if input_ids is not None else inputs_embeds.device
995
- position_ids = torch.arange(
996
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
997
- )
998
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
999
- else:
1000
- position_ids = position_ids.view(-1, seq_length).long()
1001
-
1002
- if inputs_embeds is None:
1003
- inputs_embeds = self.embed_tokens(input_ids)
1004
-
1005
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1006
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1007
- if is_padding_right:
1008
- raise ValueError(
1009
- "You are attempting to perform batched generation with padding_side='right'"
1010
- " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
1011
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1012
- )
1013
-
1014
- if self._attn_implementation == "flash_attention_2":
1015
- # 2d mask is passed through the layers
1016
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1017
- elif self._attn_implementation == "sdpa" and not output_attentions:
1018
- # output_attentions=True can not be supported when using SDPA, and we fall back on
1019
- # the manual implementation that requires a 4D causal mask in all cases.
1020
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1021
- attention_mask,
1022
- (batch_size, seq_length),
1023
- inputs_embeds,
1024
- past_key_values_length,
1025
- )
1026
- else:
1027
- # 4d mask is passed through the layers
1028
- attention_mask = _prepare_4d_causal_attention_mask(
1029
- attention_mask,
1030
- (batch_size, seq_length),
1031
- inputs_embeds,
1032
- past_key_values_length,
1033
- sliding_window=self.config.sliding_window,
1034
- )
1035
-
1036
- hidden_states = inputs_embeds
1037
-
1038
- # decoder layers
1039
- all_hidden_states = () if output_hidden_states else None
1040
- all_self_attns = () if output_attentions else None
1041
- next_decoder_cache = None
1042
-
1043
- for decoder_layer in self.layers:
1044
- if output_hidden_states:
1045
- all_hidden_states += (hidden_states,)
1046
-
1047
- if self.gradient_checkpointing and self.training:
1048
- layer_outputs = self._gradient_checkpointing_func(
1049
- decoder_layer.__call__,
1050
- hidden_states,
1051
- attention_mask,
1052
- position_ids,
1053
- past_key_values,
1054
- output_attentions,
1055
- use_cache,
1056
- )
1057
- else:
1058
- layer_outputs = decoder_layer(
1059
- hidden_states,
1060
- attention_mask=attention_mask,
1061
- position_ids=position_ids,
1062
- past_key_value=past_key_values,
1063
- output_attentions=output_attentions,
1064
- use_cache=use_cache,
1065
- )
1066
-
1067
- hidden_states = layer_outputs[0]
1068
-
1069
- if use_cache:
1070
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1071
-
1072
- if output_attentions:
1073
- all_self_attns += (layer_outputs[1],)
1074
-
1075
- hidden_states = self.norm(hidden_states)
1076
-
1077
- # add hidden states from the last decoder layer
1078
- if output_hidden_states:
1079
- all_hidden_states += (hidden_states,)
1080
-
1081
- next_cache = None
1082
- if use_cache:
1083
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1084
-
1085
- if not return_dict:
1086
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1087
- return BaseModelOutputWithPast(
1088
- last_hidden_state=hidden_states,
1089
- past_key_values=next_cache,
1090
- hidden_states=all_hidden_states,
1091
- attentions=all_self_attns,
1092
- )
1093
-
1094
-
1095
- class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1096
- _tied_weights_keys = ["lm_head.weight"]
1097
-
1098
- def __init__(self, config):
1099
- super().__init__(config)
1100
- self.model = Qwen2Model(config)
1101
- self.vocab_size = config.vocab_size
1102
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1103
-
1104
- # Initialize weights and apply final processing
1105
- self.post_init()
1106
-
1107
- def get_input_embeddings(self):
1108
- return self.model.embed_tokens
1109
-
1110
- def set_input_embeddings(self, value):
1111
- self.model.embed_tokens = value
1112
-
1113
- def get_output_embeddings(self):
1114
- return self.lm_head
1115
-
1116
- def set_output_embeddings(self, new_embeddings):
1117
- self.lm_head = new_embeddings
1118
-
1119
- def set_decoder(self, decoder):
1120
- self.model = decoder
1121
-
1122
- def get_decoder(self):
1123
- return self.model
1124
-
1125
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1126
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1127
- def forward(
1128
- self,
1129
- input_ids: torch.LongTensor = None,
1130
- attention_mask: Optional[torch.Tensor] = None,
1131
- position_ids: Optional[torch.LongTensor] = None,
1132
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1133
- inputs_embeds: Optional[torch.FloatTensor] = None,
1134
- labels: Optional[torch.LongTensor] = None,
1135
- use_cache: Optional[bool] = None,
1136
- output_attentions: Optional[bool] = None,
1137
- output_hidden_states: Optional[bool] = None,
1138
- return_dict: Optional[bool] = None,
1139
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1140
- r"""
1141
- Args:
1142
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1143
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1144
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1145
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1146
-
1147
- Returns:
1148
-
1149
- Example:
1150
-
1151
- ```python
1152
- >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
1153
-
1154
- >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1155
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1156
-
1157
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1158
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1159
-
1160
- >>> # Generate
1161
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1162
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1163
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1164
- ```"""
1165
-
1166
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1167
- output_hidden_states = (
1168
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1169
- )
1170
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1171
-
1172
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1173
- outputs = self.model(
1174
- input_ids=input_ids,
1175
- attention_mask=attention_mask,
1176
- position_ids=position_ids,
1177
- past_key_values=past_key_values,
1178
- inputs_embeds=inputs_embeds,
1179
- use_cache=use_cache,
1180
- output_attentions=output_attentions,
1181
- output_hidden_states=output_hidden_states,
1182
- return_dict=return_dict,
1183
- )
1184
-
1185
- hidden_states = outputs[0]
1186
- logits = self.lm_head(hidden_states)
1187
- logits = logits.float()
1188
-
1189
- loss = None
1190
- if labels is not None:
1191
- # Shift so that tokens < n predict n
1192
- shift_logits = logits[..., :-1, :].contiguous()
1193
- shift_labels = labels[..., 1:].contiguous()
1194
- # Flatten the tokens
1195
- loss_fct = CrossEntropyLoss()
1196
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1197
- shift_labels = shift_labels.view(-1)
1198
- # Enable model parallelism
1199
- shift_labels = shift_labels.to(shift_logits.device)
1200
- loss = loss_fct(shift_logits, shift_labels)
1201
-
1202
- if not return_dict:
1203
- output = (logits,) + outputs[1:]
1204
- return (loss,) + output if loss is not None else output
1205
-
1206
- return CausalLMOutputWithPast(
1207
- loss=loss,
1208
- logits=logits,
1209
- past_key_values=outputs.past_key_values,
1210
- hidden_states=outputs.hidden_states,
1211
- attentions=outputs.attentions,
1212
- )
1213
-
1214
- def prepare_inputs_for_generation(
1215
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1216
- ):
1217
- # Omit tokens covered by past_key_values
1218
- if past_key_values is not None:
1219
- if isinstance(past_key_values, Cache):
1220
- cache_length = past_key_values.get_seq_length()
1221
- past_length = past_key_values.seen_tokens
1222
- max_cache_length = past_key_values.get_max_length()
1223
- else:
1224
- cache_length = past_length = past_key_values[0][0].shape[2]
1225
- max_cache_length = None
1226
-
1227
- # Keep only the unprocessed tokens:
1228
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1229
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1230
- # input)
1231
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1232
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1233
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1234
- # input_ids based on the past_length.
1235
- elif past_length < input_ids.shape[1]:
1236
- input_ids = input_ids[:, past_length:]
1237
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1238
- else:
1239
- remove_prefix_length = input_ids.shape[1] - 1
1240
- input_ids = input_ids[:, remove_prefix_length:]
1241
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1242
- if (
1243
- max_cache_length is not None
1244
- and attention_mask is not None
1245
- and cache_length + input_ids.shape[1] > max_cache_length
1246
- ):
1247
- attention_mask = attention_mask[:, -max_cache_length:]
1248
-
1249
- position_ids = kwargs.get("position_ids", None)
1250
- if attention_mask is not None and position_ids is None:
1251
- # create position_ids on the fly for batch generation
1252
- position_ids = attention_mask.long().cumsum(-1) - 1
1253
- position_ids.masked_fill_(attention_mask == 0, 1)
1254
- if past_key_values:
1255
- position_ids = position_ids[:, -input_ids.shape[1] :]
1256
-
1257
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1258
- if inputs_embeds is not None and past_key_values is None:
1259
- model_inputs = {"inputs_embeds": inputs_embeds}
1260
- else:
1261
- model_inputs = {"input_ids": input_ids}
1262
-
1263
- model_inputs.update(
1264
- {
1265
- "position_ids": position_ids,
1266
- "past_key_values": past_key_values,
1267
- "use_cache": kwargs.get("use_cache"),
1268
- "attention_mask": attention_mask,
1269
- }
1270
- )
1271
- return model_inputs
1272
-
1273
- @staticmethod
1274
- def _reorder_cache(past_key_values, beam_idx):
1275
- reordered_past = ()
1276
- for layer_past in past_key_values:
1277
- reordered_past += (
1278
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1279
- )
1280
- return reordered_past
1281
-
1282
-
1283
- @add_start_docstrings(
1284
- """
1285
- The Qwen2 Model transformer with a sequence classification head on top (linear layer).
1286
-
1287
- [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1288
- (e.g. GPT-2) do.
1289
-
1290
- Since it does classification on the last token, it needs to know the position of the last token. If a
1291
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1292
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1293
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1294
- each row of the batch).
1295
- """,
1296
- QWEN2_START_DOCSTRING,
1297
- )
1298
- class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1299
- def __init__(self, config):
1300
- super().__init__(config)
1301
- self.num_labels = config.num_labels
1302
- self.model = Qwen2Model(config)
1303
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1304
-
1305
- # Initialize weights and apply final processing
1306
- self.post_init()
1307
-
1308
- def get_input_embeddings(self):
1309
- return self.model.embed_tokens
1310
-
1311
- def set_input_embeddings(self, value):
1312
- self.model.embed_tokens = value
1313
-
1314
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1315
- def forward(
1316
- self,
1317
- input_ids: torch.LongTensor = None,
1318
- attention_mask: Optional[torch.Tensor] = None,
1319
- position_ids: Optional[torch.LongTensor] = None,
1320
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1321
- inputs_embeds: Optional[torch.FloatTensor] = None,
1322
- labels: Optional[torch.LongTensor] = None,
1323
- use_cache: Optional[bool] = None,
1324
- output_attentions: Optional[bool] = None,
1325
- output_hidden_states: Optional[bool] = None,
1326
- return_dict: Optional[bool] = None,
1327
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1328
- r"""
1329
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1330
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1331
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1332
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1333
- """
1334
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1335
-
1336
- transformer_outputs = self.model(
1337
- input_ids,
1338
- attention_mask=attention_mask,
1339
- position_ids=position_ids,
1340
- past_key_values=past_key_values,
1341
- inputs_embeds=inputs_embeds,
1342
- use_cache=use_cache,
1343
- output_attentions=output_attentions,
1344
- output_hidden_states=output_hidden_states,
1345
- return_dict=return_dict,
1346
- )
1347
- hidden_states = transformer_outputs[0]
1348
- logits = self.score(hidden_states)
1349
-
1350
- if input_ids is not None:
1351
- batch_size = input_ids.shape[0]
1352
- else:
1353
- batch_size = inputs_embeds.shape[0]
1354
-
1355
- if self.config.pad_token_id is None and batch_size != 1:
1356
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1357
- if self.config.pad_token_id is None:
1358
- sequence_lengths = -1
1359
- else:
1360
- if input_ids is not None:
1361
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1362
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1363
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1364
- sequence_lengths = sequence_lengths.to(logits.device)
1365
- else:
1366
- sequence_lengths = -1
1367
-
1368
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1369
-
1370
- loss = None
1371
- if labels is not None:
1372
- labels = labels.to(logits.device)
1373
- if self.config.problem_type is None:
1374
- if self.num_labels == 1:
1375
- self.config.problem_type = "regression"
1376
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1377
- self.config.problem_type = "single_label_classification"
1378
- else:
1379
- self.config.problem_type = "multi_label_classification"
1380
-
1381
- if self.config.problem_type == "regression":
1382
- loss_fct = MSELoss()
1383
- if self.num_labels == 1:
1384
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1385
- else:
1386
- loss = loss_fct(pooled_logits, labels)
1387
- elif self.config.problem_type == "single_label_classification":
1388
- loss_fct = CrossEntropyLoss()
1389
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1390
- elif self.config.problem_type == "multi_label_classification":
1391
- loss_fct = BCEWithLogitsLoss()
1392
- loss = loss_fct(pooled_logits, labels)
1393
- if not return_dict:
1394
- output = (pooled_logits,) + transformer_outputs[1:]
1395
- return ((loss,) + output) if loss is not None else output
1396
-
1397
- return SequenceClassifierOutputWithPast(
1398
- loss=loss,
1399
- logits=pooled_logits,
1400
- past_key_values=transformer_outputs.past_key_values,
1401
- hidden_states=transformer_outputs.hidden_states,
1402
- attentions=transformer_outputs.attentions,
1403
- )
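The modeling file removed above implemented the full Qwen2 decoder stack, including grouped-query attention: `repeat_kv` expands the `num_key_value_heads` key/value heads so that all `num_attention_heads` query heads can share them before the usual scaled-dot-product step. The snippet below is a minimal, self-contained sketch of that eager attention path under toy shapes chosen only for illustration; it is not part of this commit and omits rotary embeddings, masking, and caching.

```python
# Minimal sketch of the grouped-query attention path from the deleted
# modeling_qwen2.py: repeat_kv mirrors the deleted helper, and the rest is
# plain scaled-dot-product attention. Toy sizes below are illustrative only.
import math
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep):
    # (batch, num_key_value_heads, seq, head_dim) ->
    # (batch, num_key_value_heads * n_rep, seq, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# Toy example: 8 query heads sharing 2 key/value heads (group size 4).
bsz, q_len, n_heads, n_kv_heads, head_dim = 1, 5, 8, 2, 16
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_kv_heads, q_len, head_dim)
v = torch.randn(bsz, n_kv_heads, q_len, head_dim)

# Expand K/V to match the number of query heads, then attend.
k = repeat_kv(k, n_heads // n_kv_heads)
v = repeat_kv(v, n_heads // n_kv_heads)

attn_weights = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(head_dim), dim=-1)
attn_output = attn_weights @ v  # (bsz, n_heads, q_len, head_dim)
print(attn_output.shape)        # torch.Size([1, 8, 5, 16])
```

Sharing a small number of key/value heads across many query heads keeps the projection weights identical across the eager, flash-attention, and SDPA subclasses that the deleted file defined, while shrinking the key/value cache relative to full multi-head attention.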
bunny/model/language_model/qwen2/tokenization_qwen2.py DELETED
@@ -1,345 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Tokenization classes for Qwen2."""
16
-
17
- import json
18
- import os
19
- import unicodedata
20
- from functools import lru_cache
21
- from typing import Optional, Tuple
22
-
23
- import regex as re
24
-
25
- from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
26
- from transformers.utils import logging
27
-
28
-
29
- logger = logging.get_logger(__name__)
30
-
31
- VOCAB_FILES_NAMES = {
32
- "vocab_file": "vocab.json",
33
- "merges_file": "merges.txt",
34
- }
35
-
36
- PRETRAINED_VOCAB_FILES_MAP = {
37
- "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
38
- "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
39
- }
40
-
41
- MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
42
-
43
- PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
44
-
45
-
46
- @lru_cache()
47
- # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
48
- def bytes_to_unicode():
49
- """
50
- Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
51
- characters the bpe code barfs on.
52
-
53
- The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
54
- if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
55
- decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
56
- tables between utf-8 bytes and unicode strings.
57
- """
58
- bs = (
59
- list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
60
- )
61
- cs = bs[:]
62
- n = 0
63
- for b in range(2**8):
64
- if b not in bs:
65
- bs.append(b)
66
- cs.append(2**8 + n)
67
- n += 1
68
- cs = [chr(n) for n in cs]
69
- return dict(zip(bs, cs))
70
-
71
-
72
- # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
73
- def get_pairs(word):
74
- """
75
- Return set of symbol pairs in a word.
76
-
77
- Word is represented as tuple of symbols (symbols being variable-length strings).
78
- """
79
- pairs = set()
80
- prev_char = word[0]
81
- for char in word[1:]:
82
- pairs.add((prev_char, char))
83
- prev_char = char
84
- return pairs
85
-
86
-
87
- class Qwen2Tokenizer(PreTrainedTokenizer):
88
- """
89
- Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
90
-
91
- Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
92
- be encoded differently whether it is at the beginning of the sentence (without space) or not:
93
-
94
- ```python
95
- >>> from transformers import Qwen2Tokenizer
96
-
97
- >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
98
- >>> tokenizer("Hello world")["input_ids"]
99
- [9707, 1879]
100
-
101
- >>> tokenizer(" Hello world")["input_ids"]
102
- [21927, 1879]
103
- ```
104
- This is expected.
105
-
106
- You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
107
-
108
- This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
109
- this superclass for more information regarding those methods.
110
-
111
- Args:
112
- vocab_file (`str`):
113
- Path to the vocabulary file.
114
- merges_file (`str`):
115
- Path to the merges file.
116
- errors (`str`, *optional*, defaults to `"replace"`):
117
- Paradigm to follow when decoding bytes to UTF-8. See
118
- [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
119
- unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
120
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
121
- token instead.
122
- bos_token (`str`, *optional*):
123
- The beginning of sequence token. Not applicable for this tokenizer.
124
- eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
125
- The end of sequence token.
126
- pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
127
- The token used for padding, for example when batching sequences of different lengths.
128
- clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
129
- Whether or not the model should cleanup the spaces that were added when splitting the input text during the
130
- tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
131
- split_special_tokens (`bool`, *optional*, defaults to `False`):
132
- Whether or not the special tokens should be split during the tokenization process. The default behavior is
133
- to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>")` =
134
- `['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
135
- '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
136
- """
137
-
138
- vocab_files_names = VOCAB_FILES_NAMES
139
- pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
140
- max_model_input_sizes = MAX_MODEL_INPUT_SIZES
141
- model_input_names = ["input_ids", "attention_mask"]
142
-
143
- def __init__(
144
- self,
145
- vocab_file,
146
- merges_file,
147
- errors="replace",
148
- unk_token="<|endoftext|>",
149
- bos_token=None,
150
- eos_token="<|endoftext|>",
151
- pad_token="<|endoftext|>",
152
- clean_up_tokenization_spaces=False,
153
- split_special_tokens=False,
154
- **kwargs,
155
- ):
156
- # Qwen vocab does not contain control tokens; added tokens need to be special
157
- bos_token = (
158
- AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
159
- if isinstance(bos_token, str)
160
- else bos_token
161
- )
162
- eos_token = (
163
- AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
164
- if isinstance(eos_token, str)
165
- else eos_token
166
- )
167
- unk_token = (
168
- AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
169
- if isinstance(unk_token, str)
170
- else unk_token
171
- )
172
- pad_token = (
173
- AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
174
- if isinstance(pad_token, str)
175
- else pad_token
176
- )
177
-
178
- with open(vocab_file, encoding="utf-8") as vocab_handle:
179
- self.encoder = json.load(vocab_handle)
180
- self.decoder = {v: k for k, v in self.encoder.items()}
181
- self.errors = errors # how to handle errors in decoding
182
- self.byte_encoder = bytes_to_unicode()
183
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
184
- bpe_merges = []
185
- with open(merges_file, encoding="utf-8") as merges_handle:
186
- for line in merges_handle:
187
- line = line.strip()
188
- if not line or line.startswith("#"):
189
- continue
190
- bpe_merges.append(tuple(line.split()))
191
- self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
192
- # NOTE: the cache can grow without bound and will get really large for long running processes
193
- # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
194
- # not a memory leak but appears as one.
195
- # GPT2Tokenizer has the same problem, so let's be consistent.
196
- self.cache = {}
197
-
198
- self.pat = re.compile(PRETOKENIZE_REGEX)
199
-
200
- if kwargs.get("add_prefix_space", False):
201
- logger.warning_once(
202
- f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
203
- )
204
-
205
- super().__init__(
206
- errors=errors,
207
- bos_token=bos_token,
208
- eos_token=eos_token,
209
- pad_token=pad_token,
210
- unk_token=unk_token,
211
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
212
- split_special_tokens=split_special_tokens,
213
- **kwargs,
214
- )
215
-
216
- @property
217
- def vocab_size(self) -> int:
218
- return len(self.encoder)
219
-
220
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
221
- def get_vocab(self):
222
- return dict(self.encoder, **self.added_tokens_encoder)
223
-
224
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
225
- def bpe(self, token):
226
- if token in self.cache:
227
- return self.cache[token]
228
- word = tuple(token)
229
- pairs = get_pairs(word)
230
-
231
- if not pairs:
232
- return token
233
-
234
- while True:
235
- bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
236
- if bigram not in self.bpe_ranks:
237
- break
238
- first, second = bigram
239
- new_word = []
240
- i = 0
241
- while i < len(word):
242
- try:
243
- j = word.index(first, i)
244
- except ValueError:
245
- new_word.extend(word[i:])
246
- break
247
- else:
248
- new_word.extend(word[i:j])
249
- i = j
250
-
251
- if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
252
- new_word.append(first + second)
253
- i += 2
254
- else:
255
- new_word.append(word[i])
256
- i += 1
257
- new_word = tuple(new_word)
258
- word = new_word
259
- if len(word) == 1:
260
- break
261
- else:
262
- pairs = get_pairs(word)
263
- word = " ".join(word)
264
- self.cache[token] = word
265
- return word
266
-
267
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
268
- def _tokenize(self, text):
269
- """Tokenize a string."""
270
- bpe_tokens = []
271
- for token in re.findall(self.pat, text):
272
- token = "".join(
273
- self.byte_encoder[b] for b in token.encode("utf-8")
274
- ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
275
- bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
276
- return bpe_tokens
277
-
278
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
279
- def _convert_token_to_id(self, token):
280
- """Converts a token (str) in an id using the vocab."""
281
- return self.encoder.get(token, self.encoder.get(self.unk_token))
282
-
283
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
284
- def _convert_id_to_token(self, index):
285
- """Converts an index (integer) in a token (str) using the vocab."""
286
- return self.decoder.get(index)
287
-
288
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
289
- def convert_tokens_to_string(self, tokens):
290
- """Converts a sequence of tokens (string) in a single string."""
291
- text = "".join(tokens)
292
- text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
293
- return text
294
-
295
- def decode(
296
- self,
297
- token_ids,
298
- skip_special_tokens: bool = False,
299
- clean_up_tokenization_spaces: Optional[bool] = False,
300
- spaces_between_special_tokens: bool = False,
301
- **kwargs,
302
- ) -> str:
303
- # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
304
- # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
305
- return super().decode(
306
- token_ids,
307
- skip_special_tokens=skip_special_tokens,
308
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
309
- spaces_between_special_tokens=spaces_between_special_tokens,
310
- **kwargs,
311
- )
312
-
313
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
314
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
315
- if not os.path.isdir(save_directory):
316
- logger.error(f"Vocabulary path ({save_directory}) should be a directory")
317
- return
318
- vocab_file = os.path.join(
319
- save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
320
- )
321
- merge_file = os.path.join(
322
- save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
323
- )
324
-
325
- with open(vocab_file, "w", encoding="utf-8") as f:
326
- f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
327
-
328
- index = 0
329
- with open(merge_file, "w", encoding="utf-8") as writer:
330
- writer.write("#version: 0.2\n")
331
- for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
332
- if index != token_index:
333
- logger.warning(
334
- f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
335
- " Please check that the tokenizer is not corrupted!"
336
- )
337
- index = token_index
338
- writer.write(" ".join(bpe_tokens) + "\n")
339
- index += 1
340
-
341
- return vocab_file, merge_file
342
-
343
- def prepare_for_tokenization(self, text, **kwargs):
344
- text = unicodedata.normalize("NFC", text)
345
- return (text, kwargs)
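
Note: the deleted slow tokenizer above is a standard GPT-2-style byte-level BPE. As a minimal standalone sketch (not part of this commit; names and the sample string are illustrative), the byte-to-unicode mapping it builds can be checked for reversibility, which is what lets it avoid UNK tokens entirely:

def bytes_to_unicode():
    # Same construction as in the deleted file: printable bytes map to themselves,
    # the remaining bytes are shifted into unused unicode code points.
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = "Hello 世界"
mapped = "".join(byte_encoder[b] for b in text.encode("utf-8"))
restored = bytes(byte_decoder[ch] for ch in mapped).decode("utf-8")
assert restored == text  # lossless round trip, so no UNK token is ever needed

The deleted `_tokenize` method applies this mapping before running BPE, and `convert_tokens_to_string` inverts it with `byte_decoder`.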
bunny/model/language_model/qwen2/tokenization_qwen2_fast.py DELETED
@@ -1,143 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Tokenization classes for Qwen2."""
16
-
17
- from typing import Optional, Tuple
18
-
19
- from transformers.tokenization_utils import AddedToken
20
- from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
21
- from transformers.utils import logging
22
- from .tokenization_qwen2 import Qwen2Tokenizer
23
-
24
-
25
- logger = logging.get_logger(__name__)
26
-
27
- VOCAB_FILES_NAMES = {
28
- "vocab_file": "vocab.json",
29
- "merges_file": "merges.txt",
30
- "tokenizer_file": "tokenizer.json",
31
- }
32
-
33
- PRETRAINED_VOCAB_FILES_MAP = {
34
- "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
35
- "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
36
- "tokenizer_file": {
37
- "qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/tokenizer.json"
38
- },
39
- }
40
-
41
- MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
42
-
43
-
44
- class Qwen2TokenizerFast(PreTrainedTokenizerFast):
45
- """
46
- Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
47
- Byte-Pair-Encoding.
48
-
49
- As with GPT2Tokenizer, this tokenizer has been trained to treat spaces as parts of the tokens, so a word will
50
- be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
51
-
52
- ```python
53
- >>> from transformers import Qwen2TokenizerFast
54
-
55
- >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
56
- >>> tokenizer("Hello world")["input_ids"]
57
- [9707, 1879]
58
-
59
- >>> tokenizer(" Hello world")["input_ids"]
60
- [21927, 1879]
61
- ```
62
- This is expected.
63
-
64
- This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
65
- refer to this superclass for more information regarding those methods.
66
-
67
- Args:
68
- vocab_file (`str`, *optional*):
69
- Path to the vocabulary file.
70
- merges_file (`str`, *optional*):
71
- Path to the merges file.
72
- tokenizer_file (`str`, *optional*):
73
- Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
74
- contains everything needed to load the tokenizer.
75
- unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
76
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
77
- token instead. Not applicable to this tokenizer.
78
- bos_token (`str`, *optional*):
79
- The beginning of sequence token. Not applicable for this tokenizer.
80
- eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
81
- The end of sequence token.
82
- pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
83
- The token used for padding, for example when batching sequences of different lengths.
84
- """
85
-
86
- vocab_files_names = VOCAB_FILES_NAMES
87
- pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
88
- max_model_input_sizes = MAX_MODEL_INPUT_SIZES
89
- model_input_names = ["input_ids", "attention_mask"]
90
- slow_tokenizer_class = Qwen2Tokenizer
91
-
92
- def __init__(
93
- self,
94
- vocab_file=None,
95
- merges_file=None,
96
- tokenizer_file=None,
97
- unk_token="<|endoftext|>",
98
- bos_token=None,
99
- eos_token="<|endoftext|>",
100
- pad_token="<|endoftext|>",
101
- **kwargs,
102
- ):
103
- # We need to at least pass vocab_file and merges_file to base class
104
- # in case a slow tokenizer needs to be initialized; the others can be
105
- # configured through files.
106
- # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
107
-
108
- bos_token = (
109
- AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
110
- if isinstance(bos_token, str)
111
- else bos_token
112
- )
113
- eos_token = (
114
- AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
115
- if isinstance(eos_token, str)
116
- else eos_token
117
- )
118
- unk_token = (
119
- AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
120
- if isinstance(unk_token, str)
121
- else unk_token
122
- )
123
- pad_token = (
124
- AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
125
- if isinstance(pad_token, str)
126
- else pad_token
127
- )
128
-
129
- super().__init__(
130
- vocab_file,
131
- merges_file,
132
- tokenizer_file=tokenizer_file,
133
- unk_token=unk_token,
134
- bos_token=bos_token,
135
- eos_token=eos_token,
136
- pad_token=pad_token,
137
- **kwargs,
138
- )
139
-
140
- # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
141
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
142
- files = self._tokenizer.model.save(save_directory, name=filename_prefix)
143
- return tuple(files)
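
Note: with this vendored copy removed, the equivalent `Qwen2TokenizerFast` is available in recent transformers releases. A minimal usage sketch, assuming the `Qwen/Qwen-tokenizer` checkpoint referenced in the docstring above is reachable:

from transformers import AutoTokenizer

# Loads the Rust-backed tokenizer from tokenizer.json on the Hub.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-tokenizer", use_fast=True)

print(tokenizer("Hello world")["input_ids"])   # [9707, 1879] per the docstring above
print(tokenizer(" Hello world")["input_ids"])  # a leading space changes the first token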
bunny/model/language_model/stable_lm/configuration_stablelm_epoch.py DELETED
@@ -1,113 +0,0 @@
1
- # Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- """ StableLM Epoch model configuration"""
15
- from transformers import PretrainedConfig
16
- from transformers.utils import logging
17
-
18
-
19
- logger = logging.get_logger(__name__)
20
-
21
-
22
- class StableLMEpochConfig(PretrainedConfig):
23
- r"""
24
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
25
- documentation from [`PretrainedConfig`] for more information.
26
-
27
- Args:
28
- vocab_size (`int`, *optional*, defaults to 50_304):
29
- Vocabulary size of the StableLM model. Defines the number of different tokens that
30
- can be represented by the `input_ids` passed when calling [`StableLMEpochModel`].
31
- intermediate_size (`int`, *optional*, defaults to 6912):
32
- Dimension of the MLP representations.
33
- hidden_size (`int`, *optional*, defaults to 2560):
34
- Dimension of the decoder layers and the pooler layer.
35
- num_hidden_layers (`int`, *optional*, defaults to 32):
36
- Number of hidden layers in the Transformer decoder.
37
- num_attention_heads (`int`, *optional*, defaults to 32):
38
- Number of attention heads for each attention layer in the Transformer decoder.
39
- num_key_value_heads (`int`, *optional*):
40
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
41
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
42
- `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
43
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
44
- by meanpooling all the original heads within that group. For more details checkout [this
45
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
46
- `num_attention_heads`.
47
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
48
- The non-linear activation function (function or string).
49
- rope_pct (`float`, *optional*, defaults to 0.25):
50
- Percentage of hidden dimensions to allocate to rotary embeddings.
51
- rope_theta (`float`, *optional*, defaults to 10000.0):
52
- The base period of the RoPE embeddings.
53
- max_position_embeddings (`int`, *optional*, defaults to 4096):
54
- The maximum sequence length that this model might ever be used with.
55
- Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
56
- initializer_range (`float`, *optional*, defaults to 0.02):
57
- The standard deviation of the truncated_normal_initializer for initializing
58
- all weight matrices.
59
- norm_eps (`float`, *optional*, defaults to 1e-5):
60
- The epsilon used by the normalization layers.
61
- use_cache (`bool`, *optional*, defaults to `True`):
62
- Whether or not the model should return the last key/values attentions
63
- (not used by all models). Only relevant if `config.is_decoder=True`.
64
- use_qkv_bias (`bool`, *optional*, defaults to `True`):
65
- Whether or not the model should use bias for qkv layers.
66
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
67
- Whether to tie the input and output word embeddings.
68
- """
69
- model_type = "stablelm_epoch"
70
- keys_to_ignore_at_inference = ["past_key_values"]
71
-
72
- def __init__(
73
- self,
74
- vocab_size=50_304,
75
- intermediate_size=6912,
76
- hidden_size=2560,
77
- num_hidden_layers=32,
78
- num_attention_heads=32,
79
- num_key_value_heads=32,
80
- hidden_act="silu",
81
- rope_pct=0.25,
82
- rope_theta=10_000,
83
- max_position_embeddings=4096,
84
- initializer_range=0.02,
85
- norm_eps=1.0e-5,
86
- use_cache=True,
87
- use_qkv_bias=True,
88
- bos_token_id=0,
89
- eos_token_id=2,
90
- tie_word_embeddings=False,
91
- **kwargs,
92
- ):
93
- self.vocab_size = vocab_size
94
- self.max_position_embeddings = max_position_embeddings
95
- self.intermediate_size = intermediate_size
96
- self.hidden_size = hidden_size
97
- self.num_hidden_layers = num_hidden_layers
98
- self.num_attention_heads = num_attention_heads
99
- self.num_key_value_heads = num_key_value_heads
100
- self.hidden_act = hidden_act
101
- self.rope_pct = rope_pct
102
- self.rope_theta = rope_theta
103
- self.initializer_range = initializer_range
104
- self.norm_eps = norm_eps
105
- self.use_cache = use_cache
106
- self.use_qkv_bias = use_qkv_bias
107
- self.tie_word_embeddings = tie_word_embeddings
108
- super().__init__(
109
- bos_token_id=bos_token_id,
110
- eos_token_id=eos_token_id,
111
- tie_word_embeddings=tie_word_embeddings,
112
- **kwargs,
113
- )
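
Note: a minimal sketch of how this config was meant to be used (illustrative values only; it assumes a local copy of the deleted module, since this commit removes it):

# Assumes configuration_stablelm_epoch.py is importable locally.
from configuration_stablelm_epoch import StableLMEpochConfig

config = StableLMEpochConfig(
    hidden_size=2560,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=8,  # fewer KV heads than query heads -> grouped-query attention
    rope_pct=0.25,          # rotary embeddings applied to 25% of each head dimension
)

# Each KV head is shared by num_attention_heads // num_key_value_heads query heads.
print(config.num_attention_heads // config.num_key_value_heads)  # 4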
bunny/model/language_model/stable_lm/modeling_stablelm_epoch.py DELETED
@@ -1,917 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- # This code is based off the following work:
17
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
18
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
19
- """ PyTorch StableLM Epoch model. """
20
- from typing import Optional, Tuple, Union
21
- import math
22
- import warnings
23
-
24
- import torch
25
- import torch.nn.functional as F
26
- import torch.utils.checkpoint
27
- from torch import nn
28
- from torch.nn import CrossEntropyLoss
29
-
30
- from transformers.cache_utils import Cache
31
- from transformers.modeling_outputs import (
32
- BaseModelOutputWithPast,
33
- CausalLMOutputWithPast,
34
- )
35
- from transformers.modeling_utils import PreTrainedModel
36
- from transformers.utils import logging, is_flash_attn_greater_or_equal_2_10
37
-
38
- from .configuration_stablelm_epoch import StableLMEpochConfig
39
-
40
- try:
41
- from flash_attn import flash_attn_func, flash_attn_varlen_func
42
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
43
- except Exception:  # flash-attn is optional; fall back to eager attention if it cannot be imported
44
- flash_attn_func, flash_attn_varlen_func = None, None
45
- index_first_axis, pad_input, unpad_input = None, None, None
46
-
47
-
48
- logger = logging.get_logger(__name__)
49
-
50
-
51
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
52
- def _get_unpad_data(attention_mask):
53
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
54
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
55
- max_seqlen_in_batch = seqlens_in_batch.max().item()
56
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
57
- return (
58
- indices,
59
- cu_seqlens,
60
- max_seqlen_in_batch,
61
- )
62
-
63
-
64
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
65
- def _make_causal_mask(
66
- input_ids_shape: torch.Size,
67
- dtype: torch.dtype,
68
- device: torch.device,
69
- past_key_values_length: int = 0,
70
- ):
71
- """Make causal mask used for bi-directional self-attention."""
72
- batch_size, tgt_len = input_ids_shape
73
- mask = torch.full((tgt_len, tgt_len), torch.finfo(torch.float16).min, device=device)
74
- mask_cond = torch.arange(mask.size(-1), device=device)
75
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
76
- mask = mask.to(dtype)
77
- if past_key_values_length > 0:
78
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
79
- return mask[None, None, :, :].expand(batch_size, 1, tgt_len, tgt_len + past_key_values_length)
80
-
81
-
82
- # Copied from transformers.models.bart.modeling_bart._expand_mask
83
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
84
- """Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, tgt_seq_len, src_seq_len]`."""
85
- batch_size, src_len = mask.size()
86
- tgt_len = tgt_len if tgt_len is not None else src_len
87
-
88
- expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
89
- inverted_mask = 1.0 - expanded_mask
90
-
91
- return inverted_mask.masked_fill(
92
- inverted_mask.to(torch.bool), torch.finfo(dtype).min
93
- )
94
-
95
-
96
- class RotaryEmbedding(nn.Module):
97
- def __init__(
98
- self,
99
- dim: int,
100
- max_position_embeddings: int,
101
- base: int = 10_000,
102
- device: Optional[torch.device] = None,
103
- ):
104
- super().__init__()
105
-
106
- self.dim = dim
107
- self.max_position_embeddings = max_position_embeddings
108
- self.base = base
109
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
110
- self.register_buffer("inv_freq", inv_freq, persistent=False)
111
-
112
- # Build here to make `torch.jit.trace` work.
113
- self._set_cos_sin_cache(
114
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype(),
115
- )
116
-
117
- def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
118
- self.max_seq_len_cached = seq_len
119
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
120
-
121
- # Don't do einsum, it converts fp32 to fp16 under AMP
122
- # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
123
- freqs = torch.outer(t, self.inv_freq)
124
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
125
- emb = torch.cat((freqs, freqs), dim=-1)
126
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
127
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
128
-
129
- def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
130
- # x: [batch_size, num_heads, seq_len, head_size]
131
- if seq_len > self.max_seq_len_cached:
132
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.get_default_dtype())
133
- return (
134
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
135
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
136
- )
137
-
138
-
139
- def rotate_half(x: torch.Tensor):
140
- """Rotates half the hidden dims of the input."""
141
- x1, x2 = torch.chunk(x, 2, dim=-1)
142
- return torch.cat((-x2, x1), dim=-1)
143
-
144
-
145
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
146
- # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
147
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
148
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
149
- cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
150
- sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
151
- q_embed = (q * cos) + (rotate_half(q) * sin)
152
- k_embed = (k * cos) + (rotate_half(k) * sin)
153
- return q_embed, k_embed
154
-
155
-
156
- class MLP(nn.Module):
157
- def __init__(self, config: StableLMEpochConfig):
158
- super().__init__()
159
- self.config = config
160
- self.hidden_size = config.hidden_size
161
- self.intermediate_size = config.intermediate_size
162
- self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
163
- self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
164
- self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
165
- self.act_fn = nn.SiLU()
166
-
167
- def forward(self, x: torch.Tensor) -> torch.Tensor:
168
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
169
-
170
-
171
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
172
- """
173
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
174
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
175
- """
176
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
177
- if n_rep == 1:
178
- return hidden_states
179
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
180
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
181
-
182
-
183
- class Attention(nn.Module):
184
- def __init__(self, config: StableLMEpochConfig):
185
- super().__init__()
186
- self.config = config
187
- self.hidden_size = config.hidden_size
188
- self.num_heads = config.num_attention_heads
189
- self.head_dim = self.hidden_size // self.num_heads
190
- self.num_key_value_heads = config.num_key_value_heads
191
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
192
- self.max_position_embeddings = config.max_position_embeddings
193
- self.is_causal = True
194
-
195
- if (self.head_dim * self.num_heads) != self.hidden_size:
196
- raise ValueError(
197
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
198
- f" and `num_heads`: {self.num_heads})."
199
- )
200
-
201
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias)
202
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
203
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
204
- self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
205
-
206
- self._init_rope()
207
-
208
- def _init_rope(self):
209
- self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
210
- self.rotary_emb = RotaryEmbedding(
211
- self.rotary_ndims,
212
- max_position_embeddings=self.config.max_position_embeddings,
213
- base=self.config.rope_theta,
214
- )
215
-
216
- def forward(
217
- self,
218
- hidden_states: torch.FloatTensor,
219
- attention_mask: torch.FloatTensor,
220
- position_ids: torch.LongTensor,
221
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
222
- output_attentions: Optional[bool] = False,
223
- use_cache: Optional[bool] = False,
224
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
225
- bsz, q_len, _ = hidden_states.size()
226
-
227
- query_states = self.q_proj(hidden_states)
228
- key_states = self.k_proj(hidden_states)
229
- value_states = self.v_proj(hidden_states)
230
-
231
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
232
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
233
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
234
-
235
- query_rot = query_states[..., : self.rotary_ndims]
236
- query_pass = query_states[..., self.rotary_ndims :]
237
- key_rot = key_states[..., : self.rotary_ndims]
238
- key_pass = key_states[..., self.rotary_ndims :]
239
-
240
- kv_seq_len = key_states.shape[-2]
241
- if past_key_value is not None:
242
- kv_seq_len += past_key_value[0].shape[-2]
243
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
244
- query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
245
-
246
- # [batch_size, num_heads, seq_len, head_dim]
247
- query_states = torch.cat((query_states, query_pass), dim=-1)
248
- key_states = torch.cat((key_states, key_pass), dim=-1)
249
-
250
- if past_key_value is not None:
251
- # Reuse k, v, self_attention
252
- key_states = torch.cat((past_key_value[0], key_states), dim=2)
253
- value_states = torch.cat((past_key_value[1], value_states), dim=2)
254
-
255
- past_key_value = (key_states, value_states) if use_cache else None
256
-
257
- # Repeat k/v heads if n_kv_heads < n_heads
258
- key_states = repeat_kv(key_states, self.num_key_value_groups)
259
- value_states = repeat_kv(value_states, self.num_key_value_groups)
260
-
261
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
262
-
263
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
264
- raise ValueError(
265
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
266
- f" {attn_weights.size()}"
267
- )
268
-
269
- if attention_mask is not None:
270
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
271
- raise ValueError(
272
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
273
- )
274
- attn_weights = attn_weights + attention_mask
275
-
276
- # Upcast attention to fp32
277
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
278
- attn_output = torch.matmul(attn_weights, value_states)
279
-
280
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
281
- raise ValueError(
282
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
283
- f" {attn_output.size()}"
284
- )
285
-
286
- # Merge heads
287
- attn_output = attn_output.transpose(1, 2).contiguous()
288
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
289
-
290
- # Final linear projection
291
- attn_output = self.o_proj(attn_output)
292
-
293
- if not output_attentions:
294
- attn_weights = None
295
-
296
- return attn_output, attn_weights, past_key_value
297
-
298
-
299
- class FlashAttention2(Attention):
300
- """
301
- Reference: https://github.com/huggingface/transformers/blob/5d36025ca13d05151b7a0c761e90d429c4644a30/src/transformers/models/llama/modeling_llama.py#L456
302
- """
303
-
304
- def __init__(self, *args, **kwargs):
305
- super().__init__(*args, **kwargs)
306
-
307
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
308
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
309
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
310
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
311
-
312
- def forward(
313
- self,
314
- hidden_states: torch.Tensor,
315
- attention_mask: Optional[torch.LongTensor] = None,
316
- position_ids: Optional[torch.LongTensor] = None,
317
- past_key_value: Optional[Cache] = None,
318
- output_attentions: bool = False,
319
- use_cache: bool = False,
320
- **kwargs,
321
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
322
- # FlashAttention2 attention does not support output_attentions
323
- if "padding_mask" in kwargs:
324
- warnings.warn(
325
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
326
- )
327
-
328
- # overwrite attention_mask with padding_mask
329
- attention_mask = kwargs.pop("padding_mask")
330
-
331
- output_attentions = False
332
-
333
- bsz, q_len, _ = hidden_states.size()
334
-
335
- query_states = self.q_proj(hidden_states)
336
- key_states = self.k_proj(hidden_states)
337
- value_states = self.v_proj(hidden_states)
338
-
339
- # Flash attention requires the input to have the shape
340
- # batch_size x seq_length x head_dim x hidden_dim
341
- # therefore we just need to keep the original shape
342
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
343
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
344
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
345
-
346
- query_rot = query_states[..., : self.rotary_ndims]
347
- query_pass = query_states[..., self.rotary_ndims :]
348
- key_rot = key_states[..., : self.rotary_ndims]
349
- key_pass = key_states[..., self.rotary_ndims :]
350
-
351
- kv_seq_len = key_states.shape[-2]
352
- if past_key_value is not None:
353
- kv_seq_len += past_key_value[0].shape[-2]
354
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
355
- query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
356
-
357
- # [batch_size, num_heads, seq_len, head_dim]
358
- query_states = torch.cat((query_states, query_pass), dim=-1)
359
- key_states = torch.cat((key_states, key_pass), dim=-1)
360
-
361
- if past_key_value is not None:
362
- # Reuse k, v, self_attention
363
- key_states = torch.cat((past_key_value[0], key_states), dim=2)
364
- value_states = torch.cat((past_key_value[1], value_states), dim=2)
365
-
366
- past_key_value = (key_states, value_states) if use_cache else None
367
-
368
- # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
369
- # to be able to avoid many of these transpose/reshape/view.
370
- query_states = query_states.transpose(1, 2)
371
- key_states = key_states.transpose(1, 2)
372
- value_states = value_states.transpose(1, 2)
373
-
374
- dropout_rate = getattr(self.config, "attention_dropout", 0.0) if self.training else 0.0  # neither Attention nor the config class defines attention_dropout, so default to no dropout
375
-
376
- attn_output = self._flash_attention_forward(
377
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
378
- )
379
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
380
- attn_output = self.o_proj(attn_output)
381
-
382
- if not output_attentions:
383
- attn_weights = None
384
-
385
- return attn_output, attn_weights, past_key_value
386
-
387
- def _flash_attention_forward(
388
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
389
- ):
390
- """
391
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
392
- first unpads the input, then computes the attention scores and pads the final attention scores.
393
-
394
- Args:
395
- query_states (`torch.Tensor`):
396
- Input query states to be passed to Flash Attention API
397
- key_states (`torch.Tensor`):
398
- Input key states to be passed to Flash Attention API
399
- value_states (`torch.Tensor`):
400
- Input value states to be passed to Flash Attention API
401
- attention_mask (`torch.Tensor`):
402
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
403
- position of padding tokens and 1 for the position of non-padding tokens.
404
- dropout (`float`, *optional*):
405
- Attention dropout
406
- softmax_scale (`float`, *optional*):
407
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
408
- """
409
- if not self._flash_attn_uses_top_left_mask:
410
- causal = self.is_causal
411
- else:
412
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in FlashAttention2 __init__.
413
- causal = self.is_causal and query_length != 1
414
-
415
- # Contains at least one padding token in the sequence
416
- if attention_mask is not None:
417
- batch_size = query_states.shape[0]
418
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
419
- query_states, key_states, value_states, attention_mask, query_length
420
- )
421
-
422
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
423
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
424
-
425
- attn_output_unpad = flash_attn_varlen_func(
426
- query_states,
427
- key_states,
428
- value_states,
429
- cu_seqlens_q=cu_seqlens_q,
430
- cu_seqlens_k=cu_seqlens_k,
431
- max_seqlen_q=max_seqlen_in_batch_q,
432
- max_seqlen_k=max_seqlen_in_batch_k,
433
- dropout_p=dropout,
434
- softmax_scale=softmax_scale,
435
- causal=causal,
436
- )
437
-
438
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
439
- else:
440
- attn_output = flash_attn_func(
441
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
442
- )
443
-
444
- return attn_output
445
-
446
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
447
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
448
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
449
-
450
- key_layer = index_first_axis(
451
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
452
- )
453
- value_layer = index_first_axis(
454
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
455
- )
456
- if query_length == kv_seq_len:
457
- query_layer = index_first_axis(
458
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
459
- )
460
- cu_seqlens_q = cu_seqlens_k
461
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
462
- indices_q = indices_k
463
- elif query_length == 1:
464
- max_seqlen_in_batch_q = 1
465
- cu_seqlens_q = torch.arange(
466
- batch_size + 1, dtype=torch.int32, device=query_layer.device
467
- ) # There is a memcpy here, that is very bad.
468
- indices_q = cu_seqlens_q[:-1]
469
- query_layer = query_layer.squeeze(1)
470
- else:
471
- # The -q_len: slice assumes left padding.
472
- attention_mask = attention_mask[:, -query_length:]
473
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
474
-
475
- return (
476
- query_layer,
477
- key_layer,
478
- value_layer,
479
- indices_q,
480
- (cu_seqlens_q, cu_seqlens_k),
481
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
482
- )
483
-
484
-
485
- ATTENTION_CLASSES = {
486
- "eager": Attention,
487
- "flash_attention_2": FlashAttention2,
488
- }
489
-
490
-
491
- class DecoderLayer(nn.Module):
492
- def __init__(self, config: StableLMEpochConfig):
493
- super().__init__()
494
- self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config)
495
- self.mlp = MLP(config)
496
- self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
497
- self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
498
-
499
- def forward(
500
- self,
501
- hidden_states: Optional[torch.FloatTensor],
502
- attention_mask: Optional[torch.FloatTensor] = None,
503
- position_ids: Optional[torch.LongTensor] = None,
504
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
505
- output_attentions: Optional[bool] = False,
506
- use_cache: Optional[bool] = False,
507
- ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
508
- residual = hidden_states
509
-
510
- hidden_states = self.input_layernorm(hidden_states)
511
-
512
- # Self Attention
513
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
514
- hidden_states=hidden_states,
515
- attention_mask=attention_mask,
516
- position_ids=position_ids,
517
- past_key_value=past_key_value,
518
- output_attentions=output_attentions,
519
- use_cache=use_cache,
520
- )
521
- hidden_states = residual + hidden_states
522
-
523
- # Fully Connected
524
- residual = hidden_states
525
- hidden_states = self.post_attention_layernorm(hidden_states)
526
- hidden_states = self.mlp(hidden_states)
527
- hidden_states = residual + hidden_states
528
-
529
- outputs = (hidden_states,)
530
-
531
- if output_attentions:
532
- outputs += (self_attn_weights,)
533
-
534
- if use_cache:
535
- outputs += (present_key_value,)
536
-
537
- return outputs
538
-
539
-
540
- class StableLMEpochPreTrainedModel(PreTrainedModel):
541
- """An abstract class to handle weights initialization and a simple interface
542
- for downloading and loading pretrained models.
543
- """
544
-
545
- config_class = StableLMEpochConfig
546
- base_model_prefix = "transformer"
547
- supports_gradient_checkpointing = True
548
- _no_split_modules = ["DecoderLayer"]
549
- _skip_keys_device_placement = "past_key_values"
550
- _supports_flash_attn_2 = True
551
-
552
- def _init_weights(self, module: nn.Module):
553
- """Initialize the weights"""
554
- if isinstance(module, nn.Linear):
555
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
556
- if module.bias is not None:
557
- module.bias.data.zero_()
558
- elif isinstance(module, nn.Embedding):
559
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
560
- if module.padding_idx is not None:
561
- module.weight.data[module.padding_idx].zero_()
562
- elif isinstance(module, nn.LayerNorm):
563
- module.bias.data.zero_()
564
- module.weight.data.fill_(1.0)
565
-
566
- def _set_gradient_checkpointing(self, module: nn.Module, value=False):
567
- if isinstance(module, StableLMEpochModel):
568
- module.gradient_checkpointing = value
569
-
570
-
571
- class StableLMEpochModel(StableLMEpochPreTrainedModel):
572
- def __init__(self, config: StableLMEpochConfig):
573
- super().__init__(config)
574
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
575
- self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
576
- self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
577
-
578
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
579
- self.gradient_checkpointing = False
580
- # Initialize weights and apply final processing
581
- self.post_init()
582
-
583
- def get_input_embeddings(self):
584
- return self.embed_tokens
585
-
586
- def set_input_embeddings(self, value: nn.Module):
587
- self.embed_tokens = value
588
-
589
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
590
- def _prepare_decoder_attention_mask(
591
- self,
592
- attention_mask: torch.Tensor,
593
- input_shape: torch.Size,
594
- inputs_embeds: torch.Tensor,
595
- past_key_values_length: int,
596
- ):
597
- # Create causal mask
598
- # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
599
- combined_attention_mask = None
600
- if input_shape[-1] > 1:
601
- combined_attention_mask = _make_causal_mask(
602
- input_shape,
603
- inputs_embeds.dtype,
604
- device=inputs_embeds.device,
605
- past_key_values_length=past_key_values_length,
606
- )
607
-
608
- if attention_mask is not None:
609
- # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
610
- expanded_attn_mask = _expand_mask(
611
- attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
612
- ).to(inputs_embeds.device)
613
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
614
-
615
- return combined_attention_mask
616
-
617
- def forward(
618
- self,
619
- input_ids: Optional[torch.LongTensor] = None,
620
- attention_mask: Optional[torch.FloatTensor] = None,
621
- position_ids: Optional[torch.LongTensor] = None,
622
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
623
- inputs_embeds: Optional[torch.FloatTensor] = None,
624
- use_cache: Optional[bool] = None,
625
- output_attentions: Optional[bool] = None,
626
- output_hidden_states: Optional[bool] = None,
627
- return_dict: Optional[bool] = None,
628
- ) -> Union[Tuple, BaseModelOutputWithPast]:
629
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
630
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
631
- use_cache = use_cache if use_cache is not None else self.config.use_cache
632
-
633
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
634
-
635
- # Retrieve input_ids and inputs_embeds
636
- if input_ids is not None and inputs_embeds is not None:
637
- raise ValueError(
638
- "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
639
- )
640
- elif input_ids is not None:
641
- batch_size, seq_length = input_ids.shape
642
- elif inputs_embeds is not None:
643
- batch_size, seq_length, _ = inputs_embeds.shape
644
- else:
645
- raise ValueError(
646
- "You have to specify either decoder_input_ids or decoder_inputs_embeds"
647
- )
648
-
649
- seq_length_with_past = seq_length
650
- past_key_values_length = 0
651
-
652
- if position_ids is None:
653
- device = input_ids.device if input_ids is not None else inputs_embeds.device
654
- position_ids = torch.arange(
655
- past_key_values_length,
656
- seq_length + past_key_values_length,
657
- dtype=torch.long,
658
- device=device,
659
- )
660
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
661
- else:
662
- position_ids = position_ids.view(-1, seq_length).long()
663
-
664
- if inputs_embeds is None:
665
- inputs_embeds = self.embed_tokens(input_ids)
666
- # Embed positions
667
- if self._use_flash_attention_2:
668
- # 2d mask is passed through the layers
669
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
670
- else:
671
- if attention_mask is None:
672
- attention_mask = torch.ones(
673
- (batch_size, seq_length_with_past),
674
- dtype=torch.bool,
675
- device=inputs_embeds.device,
676
- )
677
- attention_mask = self._prepare_decoder_attention_mask(
678
- attention_mask,
679
- (batch_size, seq_length),
680
- inputs_embeds,
681
- past_key_values_length,
682
- )
683
-
684
- hidden_states = inputs_embeds
685
-
686
- if self.gradient_checkpointing and self.training:
687
- if use_cache:
688
- logger.warning(
689
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
690
- )
691
- use_cache = False
692
-
693
- # Decoder layers
694
- all_hidden_states = () if output_hidden_states else None
695
- all_self_attns = () if output_attentions else None
696
- next_decoder_cache = () if use_cache else None
697
-
698
- for idx, decoder_layer in enumerate(self.layers):
699
- if output_hidden_states:
700
- all_hidden_states += (hidden_states,)
701
-
702
- past_key_value = (
703
- past_key_values[idx] if past_key_values is not None else None
704
- )
705
-
706
- if self.gradient_checkpointing and self.training:
707
-
708
- def create_custom_forward(module):
709
- def custom_forward(*inputs):
710
- # None for past_key_value
711
- return module(*inputs, past_key_value, output_attentions)
712
-
713
- return custom_forward
714
-
715
- layer_outputs = torch.utils.checkpoint.checkpoint(
716
- create_custom_forward(decoder_layer),
717
- hidden_states,
718
- attention_mask,
719
- position_ids,
720
- )
721
- else:
722
- layer_outputs = decoder_layer(
723
- hidden_states,
724
- attention_mask=attention_mask,
725
- position_ids=position_ids,
726
- past_key_value=past_key_value,
727
- output_attentions=output_attentions,
728
- use_cache=use_cache,
729
- )
730
-
731
- hidden_states = layer_outputs[0]
732
-
733
- if use_cache:
734
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
735
-
736
- if output_attentions:
737
- all_self_attns += (layer_outputs[1],)
738
-
739
- hidden_states = self.norm(hidden_states)
740
-
741
- # Add hidden states from the last decoder layer
742
- if output_hidden_states:
743
- all_hidden_states += (hidden_states,)
744
-
745
- next_cache = next_decoder_cache if use_cache else None
746
- if not return_dict:
747
- return tuple(
748
- v
749
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
750
- if v is not None
751
- )
752
- return BaseModelOutputWithPast(
753
- last_hidden_state=hidden_states,
754
- past_key_values=next_cache,
755
- hidden_states=all_hidden_states,
756
- attentions=all_self_attns,
757
- )
758
-
759
-
760
- class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
761
- _tied_weights_keys = ["lm_head.weight"]
762
-
763
- def __init__(self, config: StableLMEpochConfig):
764
- super().__init__(config)
765
-
766
- self.model = StableLMEpochModel(config)
767
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
768
-
769
- # Initialize weights and apply final processing
770
- self.post_init()
771
-
772
- def get_input_embeddings(self):
773
- return self.model.embed_tokens
774
-
775
- def set_input_embeddings(self, value):
776
- self.model.embed_tokens = value
777
-
778
- def get_output_embeddings(self):
779
- return self.lm_head
780
-
781
- def set_output_embeddings(self, new_embeddings: nn.Module):
782
- self.lm_head = new_embeddings
783
-
784
- def get_decoder(self):
785
- return self.model
786
-
787
- def set_decoder(self, decoder):
788
- self.model = decoder
789
-
790
- def forward(
791
- self,
792
- input_ids: Optional[torch.LongTensor] = None,
793
- attention_mask: Optional[torch.FloatTensor] = None,
794
- position_ids: Optional[torch.LongTensor] = None,
795
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
796
- inputs_embeds: Optional[torch.FloatTensor] = None,
797
- labels: Optional[torch.LongTensor] = None,
798
- use_cache: Optional[bool] = None,
799
- output_attentions: Optional[bool] = None,
800
- output_hidden_states: Optional[bool] = None,
801
- return_dict: Optional[bool] = None,
802
- ) -> Union[Tuple, CausalLMOutputWithPast]:
803
- output_attentions = (
804
- output_attentions
805
- if output_attentions is not None
806
- else self.config.output_attentions
807
- )
808
- output_hidden_states = (
809
- output_hidden_states
810
- if output_hidden_states is not None
811
- else self.config.output_hidden_states
812
- )
813
- return_dict = (
814
- return_dict if return_dict is not None else self.config.use_return_dict
815
- )
816
-
817
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
818
- outputs = self.model(
819
- input_ids,
820
- attention_mask=attention_mask,
821
- position_ids=position_ids,
822
- past_key_values=past_key_values,
823
- inputs_embeds=inputs_embeds,
824
- use_cache=use_cache,
825
- output_attentions=output_attentions,
826
- output_hidden_states=output_hidden_states,
827
- return_dict=return_dict,
828
- )
829
-
830
- hidden_states = outputs[0]
831
- logits = self.lm_head(hidden_states).float()
832
-
833
- loss = None
834
- if labels is not None:
835
- # Shift so that tokens < n predict n
836
- shift_logits = logits[..., :-1, :].contiguous()
837
- shift_labels = labels[..., 1:].contiguous()
838
- # Flatten the tokens
839
- loss_fct = CrossEntropyLoss()
840
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
841
- shift_labels = shift_labels.view(-1)
842
- # Enable model parallelism
843
- shift_labels = shift_labels.to(shift_logits.device)
844
- loss = loss_fct(shift_logits, shift_labels)
845
-
846
- if not return_dict:
847
- output = (logits,) + outputs[1:]
848
- return (loss,) + output if loss is not None else output
849
-
850
- return CausalLMOutputWithPast(
851
- loss=loss,
852
- logits=logits,
853
- past_key_values=outputs.past_key_values,
854
- hidden_states=outputs.hidden_states,
855
- attentions=outputs.attentions,
856
- )
857
-
858
- def prepare_inputs_for_generation(
859
- self,
860
- input_ids,
861
- past_key_values: Optional[torch.Tensor] = None,
862
- attention_mask: Optional[torch.Tensor] = None,
863
- inputs_embeds: Optional[torch.Tensor] = None,
864
- **kwargs,
865
- ):
866
- # Trim decoder_input_ids if past is used
867
- if past_key_values is not None:
868
- past_length = past_key_values[0][0].shape[2]
869
-
870
- # Some generation methods already pass only the last input ID
871
- if input_ids.shape[1] > past_length:
872
- remove_prefix_length = past_length
873
- else:
874
- # Default to old behavior: keep only final ID
875
- remove_prefix_length = input_ids.shape[1] - 1
876
-
877
- input_ids = input_ids[:, remove_prefix_length:]
878
-
879
- position_ids = kwargs.get("position_ids", None)
880
- if attention_mask is not None and position_ids is None:
881
- # Create position_ids on the fly for batch generation
882
- position_ids = attention_mask.long().cumsum(-1) - 1
883
- position_ids.masked_fill_(attention_mask == 0, 1)
884
- if past_key_values:
885
- position_ids = position_ids[:, -1].unsqueeze(-1)
886
-
887
- # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
888
- if inputs_embeds is not None and past_key_values is None:
889
- model_inputs = {"inputs_embeds": inputs_embeds}
890
- else:
891
- model_inputs = {"input_ids": input_ids}
892
-
893
- model_inputs.update(
894
- {
895
- "attention_mask": attention_mask,
896
- "past_key_values": past_key_values,
897
- "use_cache": kwargs.get("use_cache"),
898
- "position_ids": position_ids,
899
- }
900
- )
901
- return model_inputs
902
-
903
- @staticmethod
904
- def _reorder_cache(past_key_values, beam_idx):
905
- reordered_past = ()
906
- for layer_past in past_key_values:
907
- reordered_past += (
908
- tuple(
909
- past_state.index_select(0, beam_idx.to(past_state.device))
910
- for past_state in layer_past
911
- ),
912
- )
913
- return reordered_past
914
-
915
-
916
- StableLMEpochConfig.register_for_auto_class()
917
- StableLMEpochForCausalLM.register_for_auto_class("AutoModelForCausalLM")
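For context, the deleted forward() above computes the usual causal-LM loss by shifting logits and labels one position apart. A minimal standalone sketch of that alignment with toy sizes (illustration only, not the deleted model):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 4, vocab_size)   # [batch, seq_len, vocab]
labels = torch.tensor([[3, 5, 2, 7]])    # [batch, seq_len]

# position i of the logits is scored against the token at position i + 1
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())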
 
bunny/model/multimodal_encoder/builder.py DELETED
@@ -1,29 +0,0 @@
1
- import os
2
- from .eva_clip.eva_clip_encoder import EvaClipVisionTower
3
- from .siglip.siglip_encoder import SiglipVisionTower, SiglipVisionTowerS2
4
- from .clip.clip_encoder import CLIPVisionTower
5
-
6
-
7
- def build_vision_tower(vision_tower_cfg, **kwargs):
8
- vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
9
- use_s2 = getattr(vision_tower_cfg, 'use_s2', False)
10
-
11
- if 'sig' in vision_tower.lower():
12
- if use_s2:
13
- return SiglipVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
14
- else:
15
- return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
16
- elif 'eva' in vision_tower.lower():
17
- if use_s2:
18
- raise ValueError(f'Currently not supporting S2 for EVA-CLIP')
19
- else:
20
- return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
21
-
22
- elif 'clip' in vision_tower.lower():
23
- if use_s2:
24
- raise ValueError(f'Currently not supporting S2 for CLIP')
25
- else:
26
- return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
27
-
28
- else:
29
- raise ValueError(f'Unknown vision tower: {vision_tower}')
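The deleted factory above chooses a tower class by substring-matching the checkpoint name together with the use_s2 flag. A standalone sketch of that dispatch, returning class names as strings because the real classes are the ones removed in this commit (the SigLIP checkpoint name below is only an example):

from types import SimpleNamespace

def pick_tower_family(cfg):
    name = getattr(cfg, 'mm_vision_tower', getattr(cfg, 'vision_tower', None))
    use_s2 = getattr(cfg, 'use_s2', False)
    if 'sig' in name.lower():
        return 'SiglipVisionTowerS2' if use_s2 else 'SiglipVisionTower'
    if 'eva' in name.lower():
        if use_s2:
            raise ValueError('S2 is not supported for EVA-CLIP')
        return 'EvaClipVisionTower'
    if 'clip' in name.lower():
        if use_s2:
            raise ValueError('S2 is not supported for CLIP')
        return 'CLIPVisionTower'
    raise ValueError(f'Unknown vision tower: {name}')

cfg = SimpleNamespace(mm_vision_tower='google/siglip-so400m-patch14-384', use_s2=False)
print(pick_tower_family(cfg))   # SiglipVisionTower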
 
bunny/model/multimodal_encoder/clip/clip_encoder.py DELETED
@@ -1,76 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
-
6
-
7
- class CLIPVisionTower(nn.Module):
8
- def __init__(self, vision_tower, args, delay_load=False):
9
- super().__init__()
10
-
11
- self.is_loaded = False
12
-
13
- self.vision_tower_name = vision_tower
14
- self.select_layer = -2
15
-
16
- if not delay_load:
17
- self.load_model()
18
- else:
19
- self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
20
-
21
- def load_model(self):
22
- self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
23
- self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
24
- self.vision_tower.requires_grad_(False)
25
-
26
- self.is_loaded = True
27
-
28
- def feature_select(self, image_forward_outs):
29
- image_features = image_forward_outs.hidden_states[self.select_layer]
30
-
31
- image_features = image_features[:, 1:]
32
-
33
- return image_features
34
-
35
- @torch.no_grad()
36
- def forward(self, images):
37
- if type(images) is list:
38
- image_features = []
39
- for image in images:
40
- image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
41
- output_hidden_states=True)
42
- image_feature = self.feature_select(image_forward_out).to(image.dtype)
43
- image_features.append(image_feature)
44
- else:
45
- image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
46
- output_hidden_states=True)
47
- image_features = self.feature_select(image_forward_outs).to(images.dtype)
48
-
49
- return image_features
50
-
51
- @property
52
- def dummy_feature(self):
53
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
54
-
55
- @property
56
- def dtype(self):
57
- return self.vision_tower.dtype
58
-
59
- @property
60
- def device(self):
61
- return self.vision_tower.device
62
-
63
- @property
64
- def config(self):
65
- if self.is_loaded:
66
- return self.vision_tower.config
67
- else:
68
- return self.cfg_only
69
-
70
- @property
71
- def hidden_size(self):
72
- return self.config.hidden_size
73
-
74
- @property
75
- def num_patches(self):
76
- return (self.config.image_size // self.config.patch_size) ** 2
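feature_select above takes the second-to-last hidden layer and drops the leading CLS token, so only patch tokens reach the projector. A quick shape sketch, assuming a 336 px CLIP ViT-L/14 tower (24 x 24 = 576 patches, hidden size 1024):

import torch

hidden = torch.randn(1, 577, 1024)   # [batch, 1 CLS + 576 patch tokens, hidden] -- assumed sizes
patch_features = hidden[:, 1:]       # CLS dropped, mirroring feature_select above
print(patch_features.shape)          # torch.Size([1, 576, 1024])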
 
bunny/model/multimodal_encoder/eva_clip/eva_clip_encoder.py DELETED
@@ -1,63 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from .eva_clip_processors import EvaClipImageTrainProcessor
5
- from .eva_vit import Eva2LargePlusEncoder
6
-
7
-
8
- class EvaClipVisionTower(nn.Module):
9
- def __init__(self, vision_tower, args, delay_load=False):
10
- super().__init__()
11
-
12
- self.is_loaded = False
13
-
14
- self.vision_tower_path = vision_tower
15
- self.config = VisionTowerConfig()
16
-
17
- if not delay_load:
18
- self.load_model()
19
- else:
20
- self.cfg_only = self.config
21
-
22
- def load_model(self):
23
- self.image_processor = EvaClipImageTrainProcessor(self.config.image_size)
24
- self.vision_tower = Eva2LargePlusEncoder(self.vision_tower_path)
25
- self.vision_tower.requires_grad_(False)
26
-
27
- self.is_loaded = True
28
-
29
- @torch.no_grad()
30
- def forward(self, images):
31
- if type(images) is list:
32
- image_features = []
33
- for image in images:
34
- image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(
35
- image.dtype)
36
- image_features.append(image_feature)
37
- else:
38
- image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype)
39
-
40
- return image_features
41
-
42
- @property
43
- def dtype(self):
44
- return self.vision_tower.dtype
45
-
46
- @property
47
- def device(self):
48
- return self.vision_tower.device
49
-
50
- @property
51
- def hidden_size(self):
52
- return self.config.hidden_size
53
-
54
- @property
55
- def num_patches(self):
56
- return (self.config.image_size // self.config.patch_size) ** 2
57
-
58
-
59
- class VisionTowerConfig():
60
- def __init__(self):
61
- self.image_size = 336
62
- self.patch_size = 14
63
- self.hidden_size = 1024
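The hard-coded VisionTowerConfig above fixes the patch grid, so num_patches is a simple arithmetic consequence:

image_size, patch_size = 336, 14
print((image_size // patch_size) ** 2)   # 576 patch tokens per image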
 
bunny/model/multimodal_encoder/eva_clip/eva_clip_processors.py DELETED
@@ -1,68 +0,0 @@
1
- '''
2
- # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
3
- '''
4
-
5
- from torchvision import transforms
6
- from torchvision.transforms.functional import InterpolationMode
7
- from transformers.image_processing_utils import BatchFeature
8
- from PIL import Image
9
- from transformers.image_transforms import convert_to_rgb
10
-
11
-
12
- class BaseProcessor:
13
- def __init__(self):
14
- self.transform = lambda x: x
15
- return
16
-
17
- def __call__(self, item):
18
- return self.transform(item)
19
-
20
-
21
- class EvaClipImageBaseProcessor(BaseProcessor):
22
- def __init__(self, mean=None, std=None):
23
- self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
24
- self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
25
-
26
- self.normalize = transforms.Normalize(self.mean, self.std)
27
-
28
- @property
29
- def image_mean(self):
30
- return self.mean
31
-
32
-
33
- class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
34
- def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
35
- super().__init__(mean=mean, std=std)
36
-
37
- self.transform = transforms.Compose(
38
- [
39
- convert_to_rgb,
40
- transforms.Resize(
41
- image_size,
42
- interpolation=InterpolationMode.BICUBIC,
43
- ),
44
- transforms.CenterCrop(image_size),
45
- transforms.ToTensor(),
46
- self.normalize,
47
- ]
48
- )
49
-
50
- self.image_size = image_size
51
-
52
- def preprocess(self, images, return_tensors):
53
- if isinstance(images, Image.Image):
54
- images = [images]
55
- else:
56
- assert isinstance(images, list)
57
-
58
- transformed_images = [self.transform(image).numpy() for image in images]
59
- data = {"pixel_values": transformed_images}
60
-
61
- return BatchFeature(data=data, tensor_type=return_tensors)
62
-
63
- def __call__(self, item):
64
- return self.transform(item)
65
-
66
- @property
67
- def crop_size(self):
68
- return {'height': self.image_size, 'width': self.image_size}
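Usage sketch for the processor above; the call assumes the module deleted in this commit were still importable:

from PIL import Image

# EvaClipImageTrainProcessor is the class removed above (assumption: still importable)
processor = EvaClipImageTrainProcessor(image_size=336)
image = Image.new('RGB', (500, 400))
batch = processor.preprocess([image], return_tensors='pt')
print(batch['pixel_values'].shape)   # torch.Size([1, 3, 336, 336])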
 
bunny/model/multimodal_encoder/eva_clip/eva_vit.py DELETED
@@ -1,851 +0,0 @@
1
- '''
2
- # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
3
- '''
4
-
5
- from math import pi
6
- import torch
7
- from torch import nn
8
- from einops import rearrange, repeat
9
- import logging
10
-
11
-
12
- def broadcat(tensors, dim=-1):
13
- num_tensors = len(tensors)
14
- shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
15
- assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
16
- shape_len = list(shape_lens)[0]
17
- dim = (dim + shape_len) if dim < 0 else dim
18
- dims = list(zip(*map(lambda t: list(t.shape), tensors)))
19
- expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
20
- assert all(
21
- [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
22
- max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
23
- expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
24
- expanded_dims.insert(dim, (dim, dims[dim]))
25
- expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
26
- tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
27
- return torch.cat(tensors, dim=dim)
28
-
29
-
30
- def rotate_half(x):
31
- x = rearrange(x, '... (d r) -> ... d r', r=2)
32
- x1, x2 = x.unbind(dim=-1)
33
- x = torch.stack((-x2, x1), dim=-1)
34
- return rearrange(x, '... d r -> ... (d r)')
35
-
36
-
37
- class VisionRotaryEmbeddingFast(nn.Module):
38
- def __init__(
39
- self,
40
- dim,
41
- pt_seq_len,
42
- ft_seq_len=None,
43
- custom_freqs=None,
44
- freqs_for='lang',
45
- theta=10000,
46
- max_freq=10,
47
- num_freqs=1,
48
- patch_dropout=0.
49
- ):
50
- super().__init__()
51
- if custom_freqs:
52
- freqs = custom_freqs
53
- elif freqs_for == 'lang':
54
- freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
55
- elif freqs_for == 'pixel':
56
- freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
57
- elif freqs_for == 'constant':
58
- freqs = torch.ones(num_freqs).float()
59
- else:
60
- raise ValueError(f'unknown modality {freqs_for}')
61
-
62
- if ft_seq_len is None: ft_seq_len = pt_seq_len
63
- t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
64
-
65
- freqs = torch.einsum('..., f -> ... f', t, freqs)
66
- freqs = repeat(freqs, '... n -> ... (n r)', r=2)
67
- freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
68
-
69
- freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
70
- freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
71
-
72
- self.patch_dropout = patch_dropout
73
-
74
- self.register_buffer("freqs_cos", freqs_cos)
75
- self.register_buffer("freqs_sin", freqs_sin)
76
-
77
- logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
78
-
79
- def forward(self, t, patch_indices_keep=None):
80
- if patch_indices_keep is not None:
81
- batch = t.size()[0]
82
- batch_indices = torch.arange(batch)
83
- batch_indices = batch_indices[..., None]
84
-
85
- freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
86
- freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
87
-
88
- freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
89
- freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
90
- freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
91
- freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
92
-
93
- return t * freqs_cos + rotate_half(t) * freqs_sin
94
-
95
- return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
96
-
97
-
98
- class LayerNorm(nn.LayerNorm):
99
- """Subclass torch's LayerNorm (with cast back to input dtype)."""
100
-
101
- def forward(self, x: torch.Tensor):
102
- orig_type = x.dtype
103
- x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
104
- return x.to(orig_type)
105
-
106
-
107
- class PatchDropout(nn.Module):
108
- """
109
- https://arxiv.org/abs/2212.00794
110
- """
111
-
112
- def __init__(self, prob, exclude_first_token=True):
113
- super().__init__()
114
- assert 0 <= prob < 1.
115
- self.prob = prob
116
- self.exclude_first_token = exclude_first_token # exclude CLS token
117
- logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
118
-
119
- def forward(self, x):
120
- if not self.training or self.prob == 0.:
121
- return x
122
-
123
- if self.exclude_first_token:
124
- cls_tokens, x = x[:, :1], x[:, 1:]
125
- else:
126
- cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
127
-
128
- batch = x.size()[0]
129
- num_tokens = x.size()[1]
130
-
131
- batch_indices = torch.arange(batch)
132
- batch_indices = batch_indices[..., None]
133
-
134
- keep_prob = 1 - self.prob
135
- num_patches_keep = max(1, int(num_tokens * keep_prob))
136
-
137
- rand = torch.randn(batch, num_tokens)
138
- patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
139
-
140
- x = x[batch_indices, patch_indices_keep]
141
-
142
- if self.exclude_first_token:
143
- x = torch.cat((cls_tokens, x), dim=1)
144
-
145
- if self.training and os.getenv('RoPE') == '1':
146
- return x, patch_indices_keep
147
-
148
- return x
149
-
150
-
151
- # --------------------------------------------------------
152
- # Adapted from https://github.com/microsoft/unilm/tree/master/beit
153
- # --------------------------------------------------------
154
- import math
155
- import os
156
- from functools import partial
157
- import torch.nn as nn
158
- import torch.nn.functional as F
159
-
160
- try:
161
- from timm.models.layers import drop_path, to_2tuple, trunc_normal_
162
- except:
163
- from timm.layers import drop_path, to_2tuple, trunc_normal_
164
-
165
- if os.getenv('ENV_TYPE') == 'deepspeed':
166
- try:
167
- from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
168
- except:
169
- from torch.utils.checkpoint import checkpoint
170
- else:
171
- from torch.utils.checkpoint import checkpoint
172
-
173
- import xformers.ops as xops
174
-
175
-
176
- class DropPath(nn.Module):
177
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
178
- """
179
-
180
- def __init__(self, drop_prob=None):
181
- super(DropPath, self).__init__()
182
- self.drop_prob = drop_prob
183
-
184
- def forward(self, x):
185
- return drop_path(x, self.drop_prob, self.training)
186
-
187
- def extra_repr(self) -> str:
188
- return 'p={}'.format(self.drop_prob)
189
-
190
-
191
- class Mlp(nn.Module):
192
- def __init__(
193
- self,
194
- in_features,
195
- hidden_features=None,
196
- out_features=None,
197
- act_layer=nn.GELU,
198
- norm_layer=nn.LayerNorm,
199
- drop=0.,
200
- subln=False,
201
-
202
- ):
203
- super().__init__()
204
- out_features = out_features or in_features
205
- hidden_features = hidden_features or in_features
206
- self.fc1 = nn.Linear(in_features, hidden_features)
207
- self.act = act_layer()
208
-
209
- self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
210
-
211
- self.fc2 = nn.Linear(hidden_features, out_features)
212
- self.drop = nn.Dropout(drop)
213
-
214
- def forward(self, x):
215
- x = self.fc1(x)
216
- x = self.act(x)
217
- # x = self.drop(x)
218
- # commit this for the orignal BERT implement
219
- x = self.ffn_ln(x)
220
-
221
- x = self.fc2(x)
222
- x = self.drop(x)
223
- return x
224
-
225
-
226
- class SwiGLU(nn.Module):
227
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
228
- norm_layer=nn.LayerNorm, subln=False):
229
- super().__init__()
230
- out_features = out_features or in_features
231
- hidden_features = hidden_features or in_features
232
-
233
- self.w1 = nn.Linear(in_features, hidden_features)
234
- self.w2 = nn.Linear(in_features, hidden_features)
235
-
236
- self.act = act_layer()
237
- self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
238
- self.w3 = nn.Linear(hidden_features, out_features)
239
-
240
- self.drop = nn.Dropout(drop)
241
-
242
- def forward(self, x):
243
- x1 = self.w1(x)
244
- x2 = self.w2(x)
245
- hidden = self.act(x1) * x2
246
- x = self.ffn_ln(hidden)
247
- x = self.w3(x)
248
- x = self.drop(x)
249
- return x
250
-
251
-
252
- class Attention(nn.Module):
253
- def __init__(
254
- self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
255
- proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False,
256
- norm_layer=nn.LayerNorm):
257
- super().__init__()
258
- self.num_heads = num_heads
259
- head_dim = dim // num_heads
260
- if attn_head_dim is not None:
261
- head_dim = attn_head_dim
262
- all_head_dim = head_dim * self.num_heads
263
- self.scale = qk_scale or head_dim ** -0.5
264
-
265
- self.subln = subln
266
- if self.subln:
267
- self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
268
- self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
269
- self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
270
- else:
271
- self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
272
-
273
- if qkv_bias:
274
- self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
275
- self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
276
- else:
277
- self.q_bias = None
278
- self.v_bias = None
279
-
280
- if window_size:
281
- self.window_size = window_size
282
- self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
283
- self.relative_position_bias_table = nn.Parameter(
284
- torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
285
- # cls to token & token 2 cls & cls to cls
286
-
287
- # get pair-wise relative position index for each token inside the window
288
- coords_h = torch.arange(window_size[0])
289
- coords_w = torch.arange(window_size[1])
290
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
291
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
292
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
293
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
294
- relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
295
- relative_coords[:, :, 1] += window_size[1] - 1
296
- relative_coords[:, :, 0] *= 2 * window_size[1] - 1
297
- relative_position_index = \
298
- torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
299
- relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
300
- relative_position_index[0, 0:] = self.num_relative_distance - 3
301
- relative_position_index[0:, 0] = self.num_relative_distance - 2
302
- relative_position_index[0, 0] = self.num_relative_distance - 1
303
-
304
- self.register_buffer("relative_position_index", relative_position_index)
305
- else:
306
- self.window_size = None
307
- self.relative_position_bias_table = None
308
- self.relative_position_index = None
309
-
310
- self.attn_drop = nn.Dropout(attn_drop)
311
- self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
312
- # self.proj = nn.Linear(all_head_dim, all_head_dim)
313
- self.proj = nn.Linear(all_head_dim, dim)
314
- self.proj_drop = nn.Dropout(proj_drop)
315
- self.xattn = xattn
316
- self.xattn_drop = attn_drop
317
-
318
- self.rope = rope
319
-
320
- def forward(self, x, rel_pos_bias=None, attn_mask=None):
321
- B, N, C = x.shape
322
- if self.subln:
323
- q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
324
- k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
325
- v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
326
-
327
- q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
328
- k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
329
- v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
330
- else:
331
-
332
- qkv_bias = None
333
- if self.q_bias is not None:
334
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
335
-
336
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
337
- qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
338
- q, k, v = qkv[0], qkv[1], qkv[2]
339
-
340
- if self.rope:
341
- # slightly fast impl
342
- q_t = q[:, :, 1:, :]
343
- ro_q_t = self.rope(q_t)
344
- q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
345
-
346
- k_t = k[:, :, 1:, :]
347
- ro_k_t = self.rope(k_t)
348
- k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
349
-
350
- if self.xattn:
351
- q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
352
- k = k.permute(0, 2, 1, 3)
353
- v = v.permute(0, 2, 1, 3)
354
-
355
- x = xops.memory_efficient_attention(
356
- q, k, v,
357
- p=self.xattn_drop,
358
- scale=self.scale,
359
- )
360
- x = x.reshape(B, N, -1)
361
- x = self.inner_attn_ln(x)
362
- x = self.proj(x)
363
- x = self.proj_drop(x)
364
- else:
365
- q = q * self.scale
366
- attn = (q @ k.transpose(-2, -1))
367
-
368
- if self.relative_position_bias_table is not None:
369
- relative_position_bias = \
370
- self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
371
- self.window_size[0] * self.window_size[1] + 1,
372
- self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
373
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
374
- attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
375
-
376
- if rel_pos_bias is not None:
377
- attn = attn + rel_pos_bias.type_as(attn)
378
-
379
- if attn_mask is not None:
380
- attn_mask = attn_mask.bool()
381
- attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
382
-
383
- attn = attn.softmax(dim=-1)
384
- attn = self.attn_drop(attn)
385
-
386
- x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
387
- x = self.inner_attn_ln(x)
388
- x = self.proj(x)
389
- x = self.proj_drop(x)
390
- return x
391
-
392
-
393
- class Block(nn.Module):
394
-
395
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
396
- drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
397
- window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
398
- subln=False, naiveswiglu=False):
399
- super().__init__()
400
- self.norm1 = norm_layer(dim)
401
- self.attn = Attention(
402
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
403
- attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
404
- xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
405
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
406
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
407
- self.norm2 = norm_layer(dim)
408
- mlp_hidden_dim = int(dim * mlp_ratio)
409
-
410
- if naiveswiglu:
411
- self.mlp = SwiGLU(
412
- in_features=dim,
413
- hidden_features=mlp_hidden_dim,
414
- subln=subln,
415
- norm_layer=norm_layer,
416
- )
417
- else:
418
- self.mlp = Mlp(
419
- in_features=dim,
420
- hidden_features=mlp_hidden_dim,
421
- act_layer=act_layer,
422
- subln=subln,
423
- drop=drop
424
- )
425
-
426
- if init_values is not None and init_values > 0:
427
- self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
428
- self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
429
- else:
430
- self.gamma_1, self.gamma_2 = None, None
431
-
432
- self.postnorm = postnorm
433
-
434
- def forward(self, x, rel_pos_bias=None, attn_mask=None):
435
- if self.gamma_1 is None:
436
- if self.postnorm:
437
- x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
438
- x = x + self.drop_path(self.norm2(self.mlp(x)))
439
- else:
440
- x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
441
- x = x + self.drop_path(self.mlp(self.norm2(x)))
442
- else:
443
- if self.postnorm:
444
- x = x + self.drop_path(
445
- self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
446
- x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
447
- else:
448
- x = x + self.drop_path(
449
- self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
450
- x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
451
- return x
452
-
453
-
454
- class PatchEmbed(nn.Module):
455
- """ Image to Patch Embedding
456
- """
457
-
458
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
459
- super().__init__()
460
- img_size = to_2tuple(img_size)
461
- patch_size = to_2tuple(patch_size)
462
- num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
463
- self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
464
- self.img_size = img_size
465
- self.patch_size = patch_size
466
- self.num_patches = num_patches
467
-
468
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
469
-
470
- def forward(self, x, **kwargs):
471
- B, C, H, W = x.shape
472
- # FIXME look at relaxing size constraints
473
- assert H == self.img_size[0] and W == self.img_size[1], \
474
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
475
- x = self.proj(x).flatten(2).transpose(1, 2)
476
- return x
477
-
478
-
479
- class RelativePositionBias(nn.Module):
480
-
481
- def __init__(self, window_size, num_heads):
482
- super().__init__()
483
- self.window_size = window_size
484
- self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
485
- self.relative_position_bias_table = nn.Parameter(
486
- torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
487
- # cls to token & token 2 cls & cls to cls
488
-
489
- # get pair-wise relative position index for each token inside the window
490
- coords_h = torch.arange(window_size[0])
491
- coords_w = torch.arange(window_size[1])
492
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
493
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
494
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
495
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
496
- relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
497
- relative_coords[:, :, 1] += window_size[1] - 1
498
- relative_coords[:, :, 0] *= 2 * window_size[1] - 1
499
- relative_position_index = \
500
- torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
501
- relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
502
- relative_position_index[0, 0:] = self.num_relative_distance - 3
503
- relative_position_index[0:, 0] = self.num_relative_distance - 2
504
- relative_position_index[0, 0] = self.num_relative_distance - 1
505
-
506
- self.register_buffer("relative_position_index", relative_position_index)
507
-
508
- def forward(self):
509
- relative_position_bias = \
510
- self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
511
- self.window_size[0] * self.window_size[1] + 1,
512
- self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
513
- return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
514
-
515
-
516
- class EVAVisionTransformer(nn.Module):
517
- """ Vision Transformer with support for patch or hybrid CNN input stage
518
- """
519
-
520
- def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
521
- num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
522
- drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
523
- use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
524
- use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
525
- pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
526
- super().__init__()
527
- self.image_size = img_size
528
- self.num_classes = num_classes
529
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
530
-
531
- self.patch_embed = PatchEmbed(
532
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
533
- num_patches = self.patch_embed.num_patches
534
-
535
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
536
- # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
537
- if use_abs_pos_emb:
538
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
539
- else:
540
- self.pos_embed = None
541
- self.pos_drop = nn.Dropout(p=drop_rate)
542
-
543
- if use_shared_rel_pos_bias:
544
- self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
545
- else:
546
- self.rel_pos_bias = None
547
-
548
- if rope:
549
- half_head_dim = embed_dim // num_heads // 2
550
- hw_seq_len = img_size // patch_size
551
- self.rope = VisionRotaryEmbeddingFast(
552
- dim=half_head_dim,
553
- pt_seq_len=pt_hw_seq_len,
554
- ft_seq_len=hw_seq_len if intp_freq else None,
555
- # patch_dropout=patch_dropout
556
- )
557
- else:
558
- self.rope = None
559
-
560
- self.naiveswiglu = naiveswiglu
561
-
562
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
563
- self.use_rel_pos_bias = use_rel_pos_bias
564
- self.blocks = nn.ModuleList([
565
- Block(
566
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
567
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
568
- init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
569
- xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
570
- for i in range(depth)])
571
- self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
572
- self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
573
- self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
574
-
575
- if self.pos_embed is not None:
576
- trunc_normal_(self.pos_embed, std=.02)
577
-
578
- trunc_normal_(self.cls_token, std=.02)
579
- # trunc_normal_(self.mask_token, std=.02)
580
-
581
- self.apply(self._init_weights)
582
- self.fix_init_weight()
583
-
584
- if isinstance(self.head, nn.Linear):
585
- trunc_normal_(self.head.weight, std=.02)
586
- self.head.weight.data.mul_(init_scale)
587
- self.head.bias.data.mul_(init_scale)
588
-
589
- # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
590
- self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
591
-
592
- self.grad_checkpointing = grad_checkpointing
593
-
594
- def fix_init_weight(self):
595
- def rescale(param, layer_id):
596
- param.div_(math.sqrt(2.0 * layer_id))
597
-
598
- for layer_id, layer in enumerate(self.blocks):
599
- rescale(layer.attn.proj.weight.data, layer_id + 1)
600
- if self.naiveswiglu:
601
- rescale(layer.mlp.w3.weight.data, layer_id + 1)
602
- else:
603
- rescale(layer.mlp.fc2.weight.data, layer_id + 1)
604
-
605
- def get_cast_dtype(self) -> torch.dtype:
606
- return self.blocks[0].mlp.fc2.weight.dtype
607
-
608
- def _init_weights(self, m):
609
- if isinstance(m, nn.Linear):
610
- trunc_normal_(m.weight, std=.02)
611
- if m.bias is not None:
612
- nn.init.constant_(m.bias, 0)
613
- elif isinstance(m, nn.LayerNorm):
614
- nn.init.constant_(m.bias, 0)
615
- nn.init.constant_(m.weight, 1.0)
616
-
617
- def get_num_layers(self):
618
- return len(self.blocks)
619
-
620
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
621
- assert unlocked_groups == 0, 'partial locking not currently supported for this model'
622
- for param in self.parameters():
623
- param.requires_grad = False
624
-
625
- @torch.jit.ignore
626
- def set_grad_checkpointing(self, enable=True):
627
- self.grad_checkpointing = enable
628
-
629
- @torch.jit.ignore
630
- def no_weight_decay(self):
631
- return {'pos_embed', 'cls_token'}
632
-
633
- def get_classifier(self):
634
- return self.head
635
-
636
- def reset_classifier(self, num_classes, global_pool=''):
637
- self.num_classes = num_classes
638
- self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
639
-
640
- def forward_features(self, x, return_all_features=False):
641
-
642
- x = self.patch_embed(x)
643
- batch_size, seq_len, _ = x.size()
644
-
645
- cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
646
- x = torch.cat((cls_tokens, x), dim=1)
647
- if self.pos_embed is not None:
648
- x = x + self.pos_embed
649
- x = self.pos_drop(x)
650
-
651
- # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
652
- if os.getenv('RoPE') == '1':
653
- if self.training and not isinstance(self.patch_dropout, nn.Identity):
654
- x, patch_indices_keep = self.patch_dropout(x)
655
- self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
656
- else:
657
- self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
658
- x = self.patch_dropout(x)
659
- else:
660
- x = self.patch_dropout(x)
661
-
662
- rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
663
- for i, blk in enumerate(self.blocks):
664
- if i == len(self.blocks) - 1:
665
- continue
666
- if self.grad_checkpointing:
667
- x = checkpoint(blk, x, (rel_pos_bias,))
668
- else:
669
- x = blk(x, rel_pos_bias=rel_pos_bias)
670
-
671
- if not return_all_features:
672
- x = self.norm(x)
673
- if self.fc_norm is not None:
674
- return self.fc_norm(x.mean(1))
675
- else:
676
- return x[:, 0]
677
- return x
678
-
679
- def forward(self, x, return_all_features=False):
680
- if return_all_features:
681
- return self.forward_features(x, return_all_features)
682
- x = self.forward_features(x)
683
- x = self.head(x)
684
- return x
685
-
686
-
687
- def load_state_dict(checkpoint_path: str, map_location: str = 'cpu', model_key: str = 'model|module|state_dict',
688
- is_openai: bool = False, skip_list: list = []):
689
- if is_openai:
690
- model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
691
- state_dict = model.state_dict()
692
- for key in ["input_resolution", "context_length", "vocab_size"]:
693
- state_dict.pop(key, None)
694
- else:
695
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
696
- for mk in model_key.split('|'):
697
- if isinstance(checkpoint, dict) and mk in checkpoint:
698
- state_dict = checkpoint[mk]
699
- break
700
- else:
701
- state_dict = checkpoint
702
- if next(iter(state_dict.items()))[0].startswith('module'):
703
- state_dict = {k[7:]: v for k, v in state_dict.items()}
704
-
705
- for k in skip_list:
706
- if k in list(state_dict.keys()):
707
- logging.info(f"Removing key {k} from pretrained checkpoint")
708
- del state_dict[k]
709
-
710
- if os.getenv('RoPE') == '1':
711
- for k in list(state_dict.keys()):
712
- if 'freqs_cos' in k or 'freqs_sin' in k:
713
- del state_dict[k]
714
- return state_dict
715
-
716
-
717
- def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = 'cpu', is_openai: bool = False,
718
- skip_list: list = []):
719
- state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
720
-
721
- for k in list(state_dict.keys()):
722
- if not k.startswith('visual.'):
723
- del state_dict[k]
724
- for k in list(state_dict.keys()):
725
- if k.startswith('visual.'):
726
- new_k = k[7:]
727
- state_dict[new_k] = state_dict[k]
728
- del state_dict[k]
729
- return state_dict
730
-
731
-
732
- from dataclasses import dataclass
733
- from typing import Optional, Tuple, Union
734
-
735
- try:
736
- from apex.normalization import FusedLayerNorm
737
- except:
738
- FusedLayerNorm = LayerNorm
739
- print(
740
- "Please build and install Nvidia apex package with option '--cuda_ext' according to https://github.com/NVIDIA/apex#from-source .")
741
-
742
-
743
- @dataclass
744
- class CLIPVisionCfg:
745
- layers: Union[Tuple[int, int, int, int], int] = 12
746
- width: int = 768
747
- head_width: int = 64
748
- mlp_ratio: float = 4.0
749
- patch_size: int = 16
750
- image_size: Union[Tuple[int, int], int] = 224
751
- ls_init_value: Optional[float] = None # layer scale initial value
752
- patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
753
- global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
754
- drop_path_rate: Optional[float] = None # drop path rate
755
- timm_model_name: str = None # a valid model name overrides layers, width, patch_size
756
- timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
757
- timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
758
- timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
759
- timm_proj_bias: bool = False # enable bias final projection
760
- eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
761
- qkv_bias: bool = True
762
- fusedLN: bool = False
763
- xattn: bool = False
764
- postnorm: bool = False
765
- rope: bool = False
766
- pt_hw_seq_len: int = 16 # 224/14
767
- intp_freq: bool = False
768
- naiveswiglu: bool = False
769
- subln: bool = False
770
-
771
-
772
- def _build_vision_tower(
773
- vision_tower_path: str,
774
- embed_dim: int,
775
- vision_cfg: CLIPVisionCfg
776
- ):
777
- if isinstance(vision_cfg, dict):
778
- vision_cfg = CLIPVisionCfg(**vision_cfg)
779
-
780
- if vision_cfg.eva_model_name:
781
- vision_heads = vision_cfg.width // vision_cfg.head_width
782
- norm_layer = LayerNorm
783
-
784
- visual = EVAVisionTransformer(
785
- img_size=vision_cfg.image_size,
786
- patch_size=vision_cfg.patch_size,
787
- num_classes=embed_dim,
788
- use_mean_pooling=vision_cfg.global_average_pool, # False
789
- init_values=vision_cfg.ls_init_value,
790
- patch_dropout=vision_cfg.patch_dropout,
791
- embed_dim=vision_cfg.width,
792
- depth=vision_cfg.layers,
793
- num_heads=vision_heads,
794
- mlp_ratio=vision_cfg.mlp_ratio,
795
- qkv_bias=vision_cfg.qkv_bias,
796
- drop_path_rate=vision_cfg.drop_path_rate,
797
- norm_layer=partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
798
- xattn=vision_cfg.xattn,
799
- rope=vision_cfg.rope,
800
- postnorm=vision_cfg.postnorm,
801
- pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
802
- intp_freq=vision_cfg.intp_freq,
803
- naiveswiglu=vision_cfg.naiveswiglu,
804
- subln=vision_cfg.subln
805
- )
806
-
807
- state_dict = load_clip_visual_state_dict(vision_tower_path)
808
- incompatible_keys = visual.load_state_dict(state_dict, strict=False)
809
- print('EVA-CLIP incompatible_keys:', incompatible_keys)
810
-
811
- return visual
812
-
813
-
814
- class Eva2LargePlusEncoder(nn.Module):
815
- def __init__(self, vision_tower_path):
816
- super(Eva2LargePlusEncoder, self).__init__()
817
- self.config = {
818
- "embed_dim": 768,
819
- "vision_cfg": {
820
- "image_size": 336,
821
- "layers": 24,
822
- "width": 1024,
823
- "drop_path_rate": 0,
824
- "head_width": 64,
825
- "mlp_ratio": 2.6667,
826
- "patch_size": 14,
827
- "eva_model_name": "eva-clip-l-14-336",
828
- "xattn": True,
829
- "fusedLN": True,
830
- "rope": True,
831
- "pt_hw_seq_len": 16,
832
- "intp_freq": True,
833
- "naiveswiglu": True,
834
- "subln": True
835
- }
836
- }
837
-
838
- self.config['vision_tower_path'] = vision_tower_path
839
- self.model = _build_vision_tower(**self.config)
840
-
841
- def forward(self, image, **kwargs):
842
- encode = self.model(image, return_all_features=True)[:, 1:, :]
843
- return encode
844
-
845
- @property
846
- def dtype(self):
847
- return list(self.parameters())[-1].dtype
848
-
849
- @property
850
- def device(self):
851
- return list(self.parameters())[-1].device
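The rotary position embedding above ultimately applies t * cos + rotate_half(t) * sin to each feature pair. A plain-torch, einops-free sketch of that rotation, for illustration only (not the deleted implementation):

import torch

def rotate_half(x):
    # view the last dimension as (d/2, 2) pairs and rotate each pair by 90 degrees
    pairs = x.view(*x.shape[:-1], -1, 2)
    x1, x2 = pairs.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).reshape(*x.shape)

theta = torch.tensor(0.3)
t = torch.tensor([[1.0, 0.0]])                          # one 2-d feature pair
rotated = t * torch.cos(theta) + rotate_half(t) * torch.sin(theta)
print(rotated)                                          # ~[[0.9553, 0.2955]] = [cos 0.3, sin 0.3]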
 
bunny/model/multimodal_encoder/siglip/siglip_encoder.py DELETED
@@ -1,129 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from transformers import SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig
5
- from bunny.util.s2wrapper import forward as multiscale_forward
6
-
7
-
8
- class SiglipVisionTower(nn.Module):
9
- def __init__(self, vision_tower, args, delay_load=False):
10
- super().__init__()
11
-
12
- self.is_loaded = False
13
-
14
- self.vision_tower_name = vision_tower
15
- self.select_layer = -2
16
-
17
- if not delay_load:
18
- self.load_model()
19
- else:
20
- self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)
21
-
22
- def load_model(self):
23
- self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
24
- self.image_processor.crop_size = self.image_processor.size
25
- self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
26
- self.vision_tower.requires_grad_(False)
27
-
28
- self.is_loaded = True
29
-
30
- def feature_select(self, image_forward_outs):
31
- image_features = image_forward_outs.hidden_states[self.select_layer]
32
-
33
- return image_features
34
-
35
- @torch.no_grad()
36
- def forward(self, images):
37
- if type(images) is list:
38
- image_features = []
39
- for image in images:
40
- image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
41
- output_hidden_states=True)
42
- image_feature = self.feature_select(image_forward_out).to(image.dtype)
43
- image_features.append(image_feature)
44
- else:
45
- image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
46
- output_hidden_states=True)
47
- image_features = self.feature_select(image_forward_outs).to(images.dtype)
48
-
49
- return image_features
50
-
51
- @property
52
- def dummy_feature(self):
53
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
54
-
55
- @property
56
- def dtype(self):
57
- return self.vision_tower.dtype
58
-
59
- @property
60
- def device(self):
61
- return self.vision_tower.device
62
-
63
- @property
64
- def config(self):
65
- if self.is_loaded:
66
- return self.vision_tower.config
67
- else:
68
- return self.cfg_only
69
-
70
- @property
71
- def hidden_size(self):
72
- return self.config.hidden_size
73
-
74
- @property
75
- def num_patches(self):
76
- return (self.config.image_size // self.config.patch_size) ** 2
77
-
78
-
79
- class SiglipVisionTowerS2(SiglipVisionTower):
80
- def __init__(self, vision_tower, args, delay_load=False):
81
- self.s2_scales = getattr(args, 's2_scales', '384,768,1152')
82
- self.s2_scales = list(map(int, self.s2_scales.split(',')))
83
- self.s2_scales.sort()
84
- self.s2_split_size = self.s2_scales[0]
85
- self.s2_image_size = self.s2_scales[-1]
86
-
87
- super().__init__(vision_tower, args, delay_load)
88
-
89
- self.multiscale_forward = multiscale_forward
90
-
91
- if not delay_load:
92
- self.image_processor.size['height'] = self.image_processor.size['width'] = self.s2_image_size
93
- self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
94
-
95
- def load_model(self):
96
- self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
97
- self.image_processor.crop_size = self.image_processor.size
98
- self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
99
- self.vision_tower.requires_grad_(False)
100
-
101
- self.image_processor.size['height'] = self.image_processor.size['width'] = self.s2_image_size
102
- self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
103
-
104
- self.is_loaded = True
105
-
106
- @torch.no_grad()
107
- def forward_feature(self, images):
108
- image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
109
- output_hidden_states=True)
110
- image_features = self.feature_select(image_forward_outs).to(images.dtype)
111
- return image_features
112
-
113
- @torch.no_grad()
114
- def forward(self, images):
115
- if type(images) is list:
116
- image_features = []
117
- for image in images:
118
- image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0),
119
- img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
120
- image_features.append(image_feature)
121
- else:
122
- image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales,
123
- max_split_size=self.s2_split_size)
124
-
125
- return image_features
126
-
127
- @property
128
- def hidden_size(self):
129
- return self.config.hidden_size * len(self.s2_scales)
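The S2 variant above runs the same tower at several resolutions and concatenates the per-scale features, which is why hidden_size is multiplied by the number of scales. A quick sketch with the default scale string (the per-scale width 1152 is an assumed SigLIP-SO400M value):

s2_scales = sorted(map(int, '384,768,1152'.split(',')))   # [384, 768, 1152]
per_scale_hidden = 1152                                    # assumption: SigLIP-SO400M hidden size
print(per_scale_hidden * len(s2_scales))                   # 3456 channels after concatenation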