Commit
·
de4ade4
1
Parent(s):
82c3d93
Upload 303 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Perceptrix/__init__.py +2 -0
- Perceptrix/chat.py +125 -0
- Perceptrix/create_data/interface.py +152 -0
- Perceptrix/create_data/static/style.css +154 -0
- Perceptrix/create_data/templates/index.html +80 -0
- Perceptrix/engine.py +213 -0
- Perceptrix/finetune/Dockerfile +13 -0
- Perceptrix/finetune/Makefile +23 -0
- Perceptrix/finetune/README.md +265 -0
- Perceptrix/finetune/build/lib/inference/__init__.py +4 -0
- Perceptrix/finetune/build/lib/inference/convert_composer_mpt_to_ft.py +232 -0
- Perceptrix/finetune/build/lib/inference/convert_composer_to_hf.py +290 -0
- Perceptrix/finetune/build/lib/inference/convert_hf_mpt_to_ft.py +154 -0
- Perceptrix/finetune/build/lib/inference/convert_hf_to_onnx.py +229 -0
- Perceptrix/finetune/build/lib/inference/hf_chat.py +389 -0
- Perceptrix/finetune/build/lib/inference/hf_generate.py +372 -0
- Perceptrix/finetune/build/lib/inference/run_mpt_with_ft.py +480 -0
- Perceptrix/finetune/build/lib/llmfoundry/__init__.py +71 -0
- Perceptrix/finetune/build/lib/llmfoundry/callbacks/__init__.py +31 -0
- Perceptrix/finetune/build/lib/llmfoundry/callbacks/eval_gauntlet_callback.py +177 -0
- Perceptrix/finetune/build/lib/llmfoundry/callbacks/fdiff_callback.py +67 -0
- Perceptrix/finetune/build/lib/llmfoundry/callbacks/generate_callback.py +30 -0
- Perceptrix/finetune/build/lib/llmfoundry/callbacks/hf_checkpointer.py +167 -0
- Perceptrix/finetune/build/lib/llmfoundry/callbacks/model_gauntlet_callback.py +21 -0
- Perceptrix/finetune/build/lib/llmfoundry/callbacks/monolithic_ckpt_callback.py +115 -0
- Perceptrix/finetune/build/lib/llmfoundry/callbacks/resumption_callbacks.py +89 -0
- Perceptrix/finetune/build/lib/llmfoundry/callbacks/scheduled_gc_callback.py +75 -0
- Perceptrix/finetune/build/lib/llmfoundry/data/__init__.py +21 -0
- Perceptrix/finetune/build/lib/llmfoundry/data/data.py +117 -0
- Perceptrix/finetune/build/lib/llmfoundry/data/denoising.py +937 -0
- Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/__init__.py +7 -0
- Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/collator.py +343 -0
- Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/dataloader.py +516 -0
- Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/tasks.py +433 -0
- Perceptrix/finetune/build/lib/llmfoundry/data/packing.py +423 -0
- Perceptrix/finetune/build/lib/llmfoundry/data/text_data.py +367 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/__init__.py +18 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/hf/__init__.py +18 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_causal_lm.py +227 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_fsdp.py +257 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_prefix_lm.py +150 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_t5.py +134 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/hf/model_wrapper.py +108 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/__init__.py +13 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/interface.py +110 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py +243 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/layers/__init__.py +32 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/layers/attention.py +768 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/layers/blocks.py +117 -0
- Perceptrix/finetune/build/lib/llmfoundry/models/layers/custom_embedding.py +14 -0
Perceptrix/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# from Perceptrix.engine import robotix, identify_objects_from_text, search_keyword
|
2 |
+
# from Perceptrix.chat import perceptrix
|
Perceptrix/chat.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, BitsAndBytesConfig, GenerationConfig
|
2 |
+
from Perceptrix.streamer import TextStreamer
|
3 |
+
from utils import setup_device
|
4 |
+
import torch
|
5 |
+
import time
|
6 |
+
import os
|
7 |
+
|
8 |
+
model_name = os.environ.get('CHAT_MODEL')
|
9 |
+
|
10 |
+
model_path = "models/CRYSTAL-chat" if model_name == None else model_name
|
11 |
+
config = AutoConfig.from_pretrained(
|
12 |
+
model_path, trust_remote_code=True)
|
13 |
+
|
14 |
+
device = setup_device()
|
15 |
+
device = "mps"
|
16 |
+
|
17 |
+
bnb_config = BitsAndBytesConfig(
|
18 |
+
load_in_4bit=True,
|
19 |
+
bnb_4bit_use_double_quant=True,
|
20 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
21 |
+
)
|
22 |
+
|
23 |
+
model = AutoModelForCausalLM.from_pretrained(
|
24 |
+
model_path,
|
25 |
+
config=config,
|
26 |
+
low_cpu_mem_usage=True,
|
27 |
+
trust_remote_code=True,
|
28 |
+
device_map="auto",
|
29 |
+
torch_dtype=torch.float16,
|
30 |
+
# quantization_config=bnb_config,
|
31 |
+
offload_folder="offloads",
|
32 |
+
)
|
33 |
+
|
34 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
35 |
+
model_path,
|
36 |
+
trust_remote_code=True,
|
37 |
+
)
|
38 |
+
|
39 |
+
if tokenizer.pad_token_id is None:
|
40 |
+
tokenizer.pad_token = tokenizer.eos_token
|
41 |
+
|
42 |
+
tokenizer.padding_side = "left"
|
43 |
+
tokenizer = tokenizer
|
44 |
+
model.eval()
|
45 |
+
|
46 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True,
|
47 |
+
skip_special_tokens=True, save_file="reply.txt")
|
48 |
+
|
49 |
+
def evaluate(
|
50 |
+
prompt='',
|
51 |
+
temperature=0.4,
|
52 |
+
top_p=0.65,
|
53 |
+
top_k=35,
|
54 |
+
repetition_penalty=1.1,
|
55 |
+
max_new_tokens=512,
|
56 |
+
**kwargs,
|
57 |
+
):
|
58 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
59 |
+
input_ids = inputs["input_ids"].to(device)
|
60 |
+
generation_config = GenerationConfig(
|
61 |
+
temperature=temperature,
|
62 |
+
top_p=top_p,
|
63 |
+
top_k=top_k,
|
64 |
+
repetition_penalty=repetition_penalty,
|
65 |
+
**kwargs,
|
66 |
+
)
|
67 |
+
|
68 |
+
with torch.no_grad():
|
69 |
+
generation_output = model.generate(
|
70 |
+
input_ids=input_ids,
|
71 |
+
generation_config=generation_config,
|
72 |
+
return_dict_in_generate=True,
|
73 |
+
output_scores=True,
|
74 |
+
max_new_tokens=max_new_tokens,
|
75 |
+
eos_token_id=tokenizer.eos_token_id,
|
76 |
+
pad_token_id=tokenizer.pad_token_id,
|
77 |
+
streamer=streamer,
|
78 |
+
)
|
79 |
+
s = generation_output.sequences[0]
|
80 |
+
output = tokenizer.decode(s, skip_special_tokens=True)
|
81 |
+
yield output.split("### Response:")[-1].strip()
|
82 |
+
|
83 |
+
|
84 |
+
def predict(
|
85 |
+
inputs,
|
86 |
+
temperature=0.4,
|
87 |
+
top_p=0.65,
|
88 |
+
top_k=35,
|
89 |
+
repetition_penalty=1.1,
|
90 |
+
max_new_tokens=512,
|
91 |
+
):
|
92 |
+
now_prompt = inputs
|
93 |
+
|
94 |
+
response = evaluate(
|
95 |
+
now_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, do_sample=True
|
96 |
+
)
|
97 |
+
|
98 |
+
for i in response:
|
99 |
+
print(i)
|
100 |
+
response = i
|
101 |
+
|
102 |
+
return response
|
103 |
+
|
104 |
+
|
105 |
+
instructions = "You are Comprehensive Robotics Yielding Sophisticated Technology And Logistics (CRYSTAL), an AI robot developed by Vatsal Dutt to be the most advanced robot in the world. You will be provided with prompts and other information to help the user."
|
106 |
+
|
107 |
+
def perceptrix(prompt):
|
108 |
+
prompt = instructions+"\n"+prompt
|
109 |
+
response = predict(
|
110 |
+
inputs=prompt, temperature=0.2, top_p=0.9, max_new_tokens=512
|
111 |
+
)
|
112 |
+
spl_tokens = ["<|im_start|>", "<|im_end|>"]
|
113 |
+
clean_prompt = prompt.replace(spl_tokens[0], "").replace(spl_tokens[1], "")
|
114 |
+
return response[len(clean_prompt):]
|
115 |
+
|
116 |
+
|
117 |
+
if __name__ == "__main__":
|
118 |
+
history = ""
|
119 |
+
while True:
|
120 |
+
user_input = input("User: ")
|
121 |
+
start = time.time()
|
122 |
+
user_input = "<|im_start|>User\n"+user_input+"<|im_end|>\n<|im_start|>CRYSTAL\n"
|
123 |
+
result = perceptrix(history+user_input)
|
124 |
+
history += user_input + result + "<|im_end|>\n"
|
125 |
+
print("Answer completed in ~", round(time.time()-start), "seconds")
|
Perceptrix/create_data/interface.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, render_template, request
|
2 |
+
import random
|
3 |
+
import json
|
4 |
+
|
5 |
+
app = Flask(__name__)
|
6 |
+
|
7 |
+
data_file_path = "Perceptrix/finetune/finetune-data/crystal-finetune.json"
|
8 |
+
|
9 |
+
with open(data_file_path, 'r') as data_file:
|
10 |
+
data = json.loads(data_file.read())
|
11 |
+
file_data = data_file.read()
|
12 |
+
|
13 |
+
|
14 |
+
room_descriptions = ["The surroundings include a bed, a chair, and a dog. The bed is made up with a white blanket, and there are two people sitting on it, likely a man and a woman. The chair is positioned next to the bed, and the dog is sitting on the bed as well. The room appears to be a bedroom, and the atmosphere seems to be cozy and comfortable.",
|
15 |
+
"In the image, there is a living room with a couch, a chair, a coffee table, and a window. The room is well-decorated and filled with furniture, including a bed, a desk, and a dining table. The living room is situated next to the bedroom, and there is a window in the living room. The overall atmosphere of the room is cozy and inviting.",
|
16 |
+
"The surroundings include a living room with a couch, a chair, and a coffee table. There is also a television in the room.",
|
17 |
+
"In the image, there is a group of people sitting in chairs, likely in a classroom or a meeting room. They are engaged in a discussion or a presentation, with some of them looking at a screen.",
|
18 |
+
"The surroundings include a living room with a couch, a chair, and a window. The room is well-lit, and there are several potted plants in the space.",
|
19 |
+
"In the image, a woman is lying on a bed surrounded by a variety of stuffed animals. There are at least ten stuffed animals of different colors and sizes, including teddy bears, dolls, and other toys. The scene appears to be cozy and comfortable, with the woman resting peacefully in her bed.",
|
20 |
+
"The surroundings include a living room with a yellow couch, a table, and a potted plant. There are three people sitting on the couch, and a laptop is placed on the table.",
|
21 |
+
"The surroundings include a living room with a couch, a coffee table, and a TV. The room is filled with people, and they are sitting on the couch, engaging in a conversation.",
|
22 |
+
"The surroundings include a living room with a couch, a chair, and a window. There is also a woman standing in the room, possibly near the window. The room appears to be clean and well-maintained.",
|
23 |
+
"The surroundings include a living room with a bed, a couch, a chair, and a TV. There are also various items scattered around, such as a book, a bottle, and a cup. The room appears to be messy and disorganized, with some items like the book and bottle being placed on the floor.",
|
24 |
+
"In the image, there is a bedroom with a bed, a chair, and a window. A woman is sitting on the bed, and a dog is nearby. The woman is wearing a white shirt and appears to be engaged in a conversation with the AI assistant. The room appears to be clean and well-organized.",
|
25 |
+
"I am in a living room, sitting on a couch, and using a laptop.",
|
26 |
+
"The surroundings include a couch, a chair, a window, and a potted plant. There is also a person sitting on the couch, and a baby is laying on the person's lap.",
|
27 |
+
"In the image, there is a group of people standing in a living room, with a bed and a couch visible in the background. The room appears to be clean and well-organized.",
|
28 |
+
"In the image, there is a living room with a white couch, a chair, and a window. The room is well-lit and appears to be clean and organized.",
|
29 |
+
"The surroundings consist of a large, empty room with a hardwood floor. There is a man sitting on the floor, possibly in a corner or a cubicle, and he is holding a remote control.",
|
30 |
+
"The surroundings include a living room with a woman standing in front of a door, which is open. The room appears to be dimly lit, creating a somewhat dark atmosphere.",
|
31 |
+
"The surroundings include a living room with a couch, a chair, and a TV. There are three people sitting on the couch, and a baby is present. The living room appears to be a comfortable and cozy space for the family to spend time together.",
|
32 |
+
"In the image, there is a person sitting on a bed in a bedroom. The bed is surrounded by a colorful blanket, and there is a laptop on the bed. The room appears to be a small bedroom, and the bed is positioned near a window.",
|
33 |
+
"The surroundings include a bedroom with a bed, a nightstand, and a window. The bed is neatly made and has a white and gray color scheme. There are also potted plants in the room, adding a touch of greenery and a sense of freshness.",
|
34 |
+
"In the image, there is a large bedroom with a bed, a nightstand, and a window. The room is clean and well-organized, with a white color scheme and a minimalist design. The bed is neatly made, and there are pillows on it. The room also has a chair and a potted plant, adding a touch of warmth and natural elements to the space. The window provides natural light, and the room appears to be well-lit and inviting.",
|
35 |
+
"In the image, there is a living room with a couch, a chair, and a coffee table. The room is well-decorated and features a dark color scheme, with a black couch and a black coffee table. There is also a potted plant in the room, adding a touch of greenery to the space. The living room is well-lit, and there are several books scattered around the room, suggesting that the occupants enjoy reading.",
|
36 |
+
"In the image, there is a living room with a large window, a couch, and a chair. The room is filled with furniture, including a coffee table, a dining table, and a potted plant. The living room has a modern and clean design, with a white color scheme. The large window allows for natural light to enter the room, creating a bright and inviting atmosphere.",
|
37 |
+
"In the image, there is a bedroom with a large bed, a nightstand, and a window. The room is well-lit and clean, creating a comfortable and inviting atmosphere.",
|
38 |
+
"The surroundings include a living room with a couch, a coffee table, and a television. The room is filled with various items, such as books, a vase, and a potted plant. The living room is well-lit, and there is a window in the room. Additionally, there is a dining table and chairs, which suggests that the living room and dining area are combined.",
|
39 |
+
"In the image, I am surrounded by a living room with a piano, a couch, and a chair. The living room has a modern design, and the furniture is arranged in a way that creates a comfortable and inviting atmosphere.",
|
40 |
+
"The surroundings include a living room with a couch, a coffee table, and a vase. The living room is well-decorated and has a clean and organized appearance.",
|
41 |
+
"The surroundings include a large bedroom with a white color scheme, a bed with a white comforter, and a window. There is also a ceiling fan, which is a white fan, and a chair in the room. The room appears to be clean and well-maintained.",
|
42 |
+
"In the image, there is a green couch, a green chair, and a green ottoman in a living room. The room is filled with books, suggesting that it is a cozy and well-read space.",
|
43 |
+
"The surroundings include a large bedroom with a large bed, a chair, and a desk. The room also has a window, a lamp, and a mirror. The bed is neatly made, and there are pillows on it. The room is well-lit, with a lamp providing illumination.",
|
44 |
+
"In the image, there is a living room with a couch, a chair, and a table. The room has a modern design, featuring a large window and a chandelier. The living room is filled with furniture, including a couch, a chair, and a table. The room also has a potted plant, which adds a touch of greenery and a sense of freshness to the space. The living room is well-lit, with the large window providing ample natural light, and the chandelier adding a touch of elegance and sophistication.",
|
45 |
+
"I am in a living room with a fireplace, a couch, a chair, a dining table, and a potted plant. The room is filled with furniture and decorations, creating a cozy and inviting atmosphere.",
|
46 |
+
"In the image, there is a living room with a couch, two chairs, and a coffee table. The living room is well-lit and has a view of the city, which adds to the ambiance of the space. The room also features a potted plant and a vase, adding a touch of greenery and decoration to the area.",
|
47 |
+
"In the image, there is a neatly made bed in a bedroom, with a white comforter and a red blanket. The bed is situated next to a window, which allows natural light to enter the room. The room also has a nightstand with a lamp, providing additional lighting. The overall atmosphere of the room is clean and inviting.",
|
48 |
+
"In the image, I am surrounded by a large, clean living room with white walls, a fireplace, and a comfortable couch. There are also several chairs and a dining table in the room. The space is well-lit, and the furniture is arranged to create a cozy and inviting atmosphere.",
|
49 |
+
"The surroundings in the image include a living room with a couch, chairs, and a coffee table. The living room is filled with furniture, and there are multiple lamps and potted plants scattered throughout the space. The room also has a window, which allows natural light to enter the room.",
|
50 |
+
"The surroundings include a living room with a couch, a coffee table, and a lamp. The room also has a large window, which allows for natural light to enter. There are several chairs and a dining table in the room, suggesting that it is a multi-purpose space for relaxation and dining. The living room is well-lit and furnished with comfortable seating options, creating a welcoming atmosphere for people to gather and socialize.",
|
51 |
+
"In the image, I am surrounded by a living room filled with furniture, including a couch, chairs, and a coffee table. The living room is well-decorated, and there are several books and a vase present. The room also features a rug, which adds to the overall aesthetic and comfort of the space.",
|
52 |
+
"The surroundings include a messy room with a bed, a desk, and a chair. The room is filled with clothes, shoes, and other items, creating a cluttered and disorganized space.",
|
53 |
+
"The surroundings in the image include a cluttered room with a desk, a bed, and various items scattered around. The room appears to be messy and disorganized, with clothes and other belongings scattered on the floor.",
|
54 |
+
"The surroundings include a group of people sitting on a bed, with a laptop and a cell phone visible. The room appears to be a bedroom, and the individuals are engaged in a conversation.",
|
55 |
+
"I am in a living room, surrounded by several people sitting on a couch. They are all engaged in various activities, such as watching TV, using their cell phones, and possibly playing video games. The room is filled with furniture, including a couch, chairs, and a TV. The atmosphere appears to be casual and relaxed, with the people enjoying their time together in the living room.",
|
56 |
+
"The surroundings include a group of people sitting on a couch, with a wooden table in the background. The room appears to be a living room, and there is a window nearby.",
|
57 |
+
"In the image, there are several people sitting on a couch, using their cell phones. The couch is located in a living room, and the people are engaged in various activities on their devices.",
|
58 |
+
"The surroundings include a living room with a fireplace, where a group of people is sitting on couches and chairs. There are multiple books scattered around the room, suggesting that the individuals might be engaged in reading or studying. The room also has a dining table and a potted plant, which adds to the cozy atmosphere of the space.",
|
59 |
+
"The surroundings include a living room with a couch, a dining table, and a pizza on the table. The people in the room are sitting and enjoying their meal together.",
|
60 |
+
"In the image, there are four people sitting on a bed, with two of them facing the camera. They are all wearing blue shirts and are engaged in a conversation. The scene takes place in a bedroom, which is a comfortable and familiar setting for the group.",
|
61 |
+
"The surroundings include a group of people sitting on a couch, with some of them holding pizza boxes. The room appears to be a living room, and there is a book nearby.",
|
62 |
+
"In the image, there are four people sitting on a couch in a living room. They are all engaged in using their cell phones, with one of them holding a book. The room has a wooden floor and a table, and there are chairs nearby. The atmosphere appears to be casual and relaxed, with the group of friends enjoying their time together while using their devices.",
|
63 |
+
"I am in a living room, which is filled with furniture such as a couch, a chair, and a table. The room is well-lit and appears to be a comfortable space for relaxation and socializing.",
|
64 |
+
"The surroundings include a living room with a couch, a table, and a chair. There are people sitting on the couch and a man standing in the room.",
|
65 |
+
"The surroundings include a group of people sitting on a couch, likely in a living room or a similar space. They are engaged in a conversation or enjoying each other's company.",
|
66 |
+
"The surroundings include a living room with a couch, a chair, and a table. There are several people sitting on the couch and chairs, engaging in conversation and enjoying each other's company. The room appears to be well-lit, creating a comfortable atmosphere for socializing.",
|
67 |
+
"In the image, there is a group of people sitting around a dining table in a room. The table is covered with various items, including cups, bowls, and a vase. The room appears to be a living room or a dining area, with a couch and chairs nearby. The scene suggests a casual and comfortable setting where people are gathered for a meal or a social event.",
|
68 |
+
"The surroundings include a living room with a couch, a dining table, and a TV. The room is filled with people, some of whom are sitting on the couch, while others are standing around the table. There are also several bottles and cups, which might be used for drinking. The atmosphere appears to be relaxed and social, with people enjoying each other's company and engaging in conversations.",
|
69 |
+
"The surroundings include a living room with a couch, a coffee table, and a window. There are three people sitting on the couch, engaging in a conversation.",
|
70 |
+
"The surroundings include a living room with a couch, a table, and a couple of chairs. There are also several bottles and cups on the table, suggesting that the room is set up for a casual gathering or a social event.",
|
71 |
+
"The surroundings include a living room with a couch, chairs, and a coffee table. There are also books scattered around the room, suggesting that the group of people might be engaged in a discussion or reading. The room appears to be cozy and comfortable, with a relaxed atmosphere.",
|
72 |
+
"The surroundings include a living room with a couch, a television, and a group of people sitting together."]
|
73 |
+
|
74 |
+
cities = ["Wake Forest, NC", "Rocky Mount, NC", "San Francisco, CA", "New York City, NY", "Trenton, NJ", "Philadelphia, PA",
|
75 |
+
"Vikas Puri, New Delhi", "Jiugong, Beijing", "Les Halles, Paris", "Diemen, Amsterdam", "Al Shamkhah, Abu Dhabi",
|
76 |
+
"Cairo", "Idore, Madhya Pradesh", "Bangalore, Karnataka", "Toronoto, Ontario", "Brixton, London", "Charlotte, NC",
|
77 |
+
"Los Angeles, CA", "Las Vegas, NV", "Cupertino, CA", "Silicon Valley, CA", "Sham Shui Po, Hong Kong", "Danilovsky District, Moscow",
|
78 |
+
"Rochester, NY", "Manhattan, NY"]
|
79 |
+
weather_names = ["Sunny", "Rainy", "Windy", "Cloudy",
|
80 |
+
"Mostly Cloudy", "Partly Cloudy", "Light Rain", "Sleet"]
|
81 |
+
days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
|
82 |
+
months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
|
83 |
+
"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
|
84 |
+
|
85 |
+
dummy_time = f"{random.choice(days)} {random.randint(1, 30)} {random.choice(months)} {random.randint(2020, 2024)} 0{random.randint(1, 9)}:{random.randint(0, 59)} {random.choice(['AM', 'PM'])}"
|
86 |
+
dummy_weather = f"{random.choice(cities)} is {random.choice(weather_names)} with {random.randint(60, 95)}°F and Precipitation: {random.randint(0, 100)}%, Humidity: {random.randint(0, 100)}%, Wind: {random.randint(1, 15)} mph"
|
87 |
+
dummy_current_events = random.choice(room_descriptions)
|
88 |
+
|
89 |
+
|
90 |
+
@app.route("/")
|
91 |
+
def home():
|
92 |
+
vqa = file_data.count("> VQA")
|
93 |
+
robot = file_data.count("> VQA")
|
94 |
+
internet = file_data.count("> VQA")
|
95 |
+
cli = file_data.count("> VQA")
|
96 |
+
note = file_data.count("> VQA")
|
97 |
+
home_automation = file_data.count("> VQA")
|
98 |
+
header = f"""VQA: {vqa}
|
99 |
+
ROBOT: {robot}
|
100 |
+
INTERNET: {internet}
|
101 |
+
CLI: {cli}
|
102 |
+
NOTE: {note}
|
103 |
+
HOME AUTOMATION: {home_automation}
|
104 |
+
|
105 |
+
"""
|
106 |
+
|
107 |
+
return render_template("index.html", full_data=header + json.dumps(data, indent=4), data_entries=len(data), dummy_time=dummy_time, dummy_weather=dummy_weather, dummy_current_events=dummy_current_events[:-1])
|
108 |
+
|
109 |
+
|
110 |
+
@app.route('/record', methods=['GET', 'POST'])
|
111 |
+
def record():
|
112 |
+
if request.method == "POST":
|
113 |
+
with open(data_file_path, 'r') as data:
|
114 |
+
data = json.loads(data.read())
|
115 |
+
dummy_time = f"{random.choice(days)} {random.randint(1, 30)} {random.choice(months)} {random.randint(2020, 2024)} 0{random.randint(1, 9)}:{random.randint(0, 59)} {random.choice(['AM', 'PM'])}"
|
116 |
+
dummy_weather = f"{random.choice(cities)} is {random.choice(weather_names)} with {random.randint(60, 95)}°F and Precipitation: {random.randint(0, 100)}%, Humidity: {random.randint(0, 100)}%, Wind: {random.randint(1, 15)} mph"
|
117 |
+
dummy_current_events = random.choice(room_descriptions)
|
118 |
+
|
119 |
+
entry = request.form["current-data-preview"]
|
120 |
+
input_field = request.form["input"]
|
121 |
+
entry = {
|
122 |
+
"prompt": entry.split(entry.split(input_field)[-1])[0],
|
123 |
+
"response": entry.split(input_field)[-1][2:],
|
124 |
+
}
|
125 |
+
data.append(entry)
|
126 |
+
with open(data_file_path, 'w+') as file:
|
127 |
+
file.write(str(json.dumps(data, indent=4)).replace("\r\n", "\n"))
|
128 |
+
|
129 |
+
with open(data_file_path, "r") as data_file:
|
130 |
+
file_data = data_file.read()
|
131 |
+
|
132 |
+
vqa = file_data.count("> VQA")
|
133 |
+
robot = file_data.count("> Robot")
|
134 |
+
internet = file_data.count("> Internet")
|
135 |
+
cli = file_data.count("> CLI")
|
136 |
+
note = file_data.count("> Note")
|
137 |
+
home_automation = file_data.count("> Home Automation")
|
138 |
+
|
139 |
+
header = f"""VQA: {vqa}
|
140 |
+
ROBOT: {robot}
|
141 |
+
INTERNET: {internet}
|
142 |
+
CLI: {cli}
|
143 |
+
NOTE: {note}
|
144 |
+
HOME AUTOMATION: {home_automation}
|
145 |
+
|
146 |
+
"""
|
147 |
+
|
148 |
+
return render_template("index.html", full_data=header + json.dumps(data, indent=4), data_entries=len(data), dummy_time=dummy_time, dummy_weather=dummy_weather, dummy_current_events=dummy_current_events[:-1])
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == "__main__":
|
152 |
+
app.run(host="0.0.0.0")
|
Perceptrix/create_data/static/style.css
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@import url('https://fonts.cdnfonts.com/css/bitsumishi');
|
2 |
+
|
3 |
+
h1 {
|
4 |
+
font-family: "Bitsumishi", sans-serif;
|
5 |
+
font-size: 50px;
|
6 |
+
color: rgb(255, 255, 255);
|
7 |
+
margin: 10px;
|
8 |
+
letter-spacing: 3px;
|
9 |
+
}
|
10 |
+
|
11 |
+
body {
|
12 |
+
text-align: center;
|
13 |
+
background-color: rgb(0, 0, 0);
|
14 |
+
display: flex;
|
15 |
+
flex-direction: column;
|
16 |
+
justify-content: center;
|
17 |
+
margin: auto;
|
18 |
+
margin-top: 20px;
|
19 |
+
overflow: auto;
|
20 |
+
font-family: "Bitsumishi", sans-serif;
|
21 |
+
}
|
22 |
+
|
23 |
+
.entries {
|
24 |
+
display: flex;
|
25 |
+
flex-direction: column;
|
26 |
+
align-items: center;
|
27 |
+
justify-content: space-evenly;
|
28 |
+
height: 90vh;
|
29 |
+
width: fit-content;
|
30 |
+
}
|
31 |
+
|
32 |
+
.entry {
|
33 |
+
width: 20vw;
|
34 |
+
height: 25vh;
|
35 |
+
border: 2px solid rgb(32, 161, 236);
|
36 |
+
font-family: "Bitsumishi", sans-serif;
|
37 |
+
background-color: transparent;
|
38 |
+
border-top-left-radius: 30px;
|
39 |
+
border-bottom-right-radius: 30px;
|
40 |
+
padding: 25px;
|
41 |
+
resize: none;
|
42 |
+
outline: none;
|
43 |
+
color: white;
|
44 |
+
font-size: medium;
|
45 |
+
letter-spacing: 1px;
|
46 |
+
}
|
47 |
+
|
48 |
+
.main {
|
49 |
+
display: flex;
|
50 |
+
width: 90vw;
|
51 |
+
justify-content: space-between;
|
52 |
+
align-items: center;
|
53 |
+
margin: auto;
|
54 |
+
}
|
55 |
+
|
56 |
+
#submit {
|
57 |
+
background-color: transparent;
|
58 |
+
font-family: "Bitsumishi", sans-serif;
|
59 |
+
height: fit-content;
|
60 |
+
padding: 10px;
|
61 |
+
border: 2px solid rgb(32, 161, 236);
|
62 |
+
color: rgb(32, 161, 236);
|
63 |
+
font-size: medium;
|
64 |
+
cursor: pointer;
|
65 |
+
border-top-left-radius: 10px;
|
66 |
+
border-bottom-right-radius: 10px;
|
67 |
+
box-shadow: 0px 0px 100px 5px rgb(0, 0, 0) inset;
|
68 |
+
transition: 0.5s;
|
69 |
+
margin-top: 10px;
|
70 |
+
width: 50%;
|
71 |
+
}
|
72 |
+
|
73 |
+
#submit:hover {
|
74 |
+
box-shadow: 0px 0px 20px 5px rgb(33, 77, 255);
|
75 |
+
transition: 0.5s;
|
76 |
+
}
|
77 |
+
|
78 |
+
.other-inputs {
|
79 |
+
padding: 10px;
|
80 |
+
outline: none;
|
81 |
+
border: 2px solid rgb(32, 161, 236);
|
82 |
+
border-radius: 10px;
|
83 |
+
font-size: medium;
|
84 |
+
background-color: transparent;
|
85 |
+
font-family: "Bitsumishi", sans-serif;
|
86 |
+
width: 100%;
|
87 |
+
color: white;
|
88 |
+
}
|
89 |
+
|
90 |
+
.control {
|
91 |
+
display: flex;
|
92 |
+
flex-direction: column;
|
93 |
+
justify-content: space-evenly;
|
94 |
+
text-align: left;
|
95 |
+
margin-left: 100px;
|
96 |
+
}
|
97 |
+
|
98 |
+
.data-preview {
|
99 |
+
margin-top: 25px;
|
100 |
+
}
|
101 |
+
|
102 |
+
.data{
|
103 |
+
background-color:rgb(22, 22, 22);
|
104 |
+
padding: 25px;
|
105 |
+
width: 20vw;
|
106 |
+
height: 15vw;
|
107 |
+
color: rgb(255, 255, 255);
|
108 |
+
border-radius: 10px;
|
109 |
+
overflow: scroll;
|
110 |
+
}
|
111 |
+
|
112 |
+
.other-fields{
|
113 |
+
display: flex;
|
114 |
+
flex-direction: column;
|
115 |
+
width: 20%;
|
116 |
+
justify-content: space-around;
|
117 |
+
height: 32vh;
|
118 |
+
}
|
119 |
+
|
120 |
+
.all-inputs{
|
121 |
+
display: flex;
|
122 |
+
width: 63vw;
|
123 |
+
justify-content: space-between;
|
124 |
+
font-weight: 100;
|
125 |
+
}
|
126 |
+
|
127 |
+
.center-input{
|
128 |
+
padding: 10px;
|
129 |
+
outline: none;
|
130 |
+
border: 2px solid rgb(32, 161, 236);
|
131 |
+
border-radius: 10px;
|
132 |
+
font-size: medium;
|
133 |
+
background-color: transparent;
|
134 |
+
font-family: "Bitsumishi", sans-serif;
|
135 |
+
width: 45%;
|
136 |
+
color: white;
|
137 |
+
}
|
138 |
+
|
139 |
+
.center-inputs{
|
140 |
+
display: flex;
|
141 |
+
width: 63vw;
|
142 |
+
justify-content: space-between;
|
143 |
+
}
|
144 |
+
|
145 |
+
#output{
|
146 |
+
width: 60vw
|
147 |
+
}
|
148 |
+
|
149 |
+
#current-data-preview{
|
150 |
+
outline: none;
|
151 |
+
border: none;
|
152 |
+
resize: none;
|
153 |
+
font-size: 14px;
|
154 |
+
}
|
Perceptrix/create_data/templates/index.html
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
|
4 |
+
<head>
|
5 |
+
<meta charset="UTF-8">
|
6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
7 |
+
<title>Fine-tune Data CRYSTAL</title>
|
8 |
+
<link rel="stylesheet" href="{{url_for('static',filename='style.css')}}">
|
9 |
+
</head>
|
10 |
+
|
11 |
+
<body>
|
12 |
+
<h1>FINETUNE CRYSTAL</h1>
|
13 |
+
<form action="/record" method="post">
|
14 |
+
<div class="main">
|
15 |
+
<div class="entries">
|
16 |
+
<div class="all-inputs">
|
17 |
+
<textarea class="entry" name="notes" id="notes" placeholder="Enter Additional Notes Here:" oninput="updatePreview()">Notes- </textarea>
|
18 |
+
<textarea class="entry" name="input" id="input" placeholder="Enter Input Here:" oninput="updatePreview()"></textarea>
|
19 |
+
<div class="other-fields">
|
20 |
+
<input type="text" class="other-inputs" name="user" id="user" placeholder="User Name" oninput="updatePreview()">
|
21 |
+
<input type="text" class="other-inputs" name="time" id="time" placeholder="Time" oninput="updatePreview()">
|
22 |
+
<input type="text" class="other-inputs" name="weather" id="weather" placeholder="Weather" oninput="updatePreview()">
|
23 |
+
<input type="text" class="other-inputs" name="action" id="action" placeholder="Action" oninput="updatePreview()">
|
24 |
+
</div>
|
25 |
+
</div>
|
26 |
+
<div class="center-inputs">
|
27 |
+
<input type="text" class="center-input" name="current-events" id="current-events" placeholder="Current Events" oninput="updatePreview()">
|
28 |
+
<input type="text" class="center-input" name="speak" id="speak" placeholder="Speak" oninput="updatePreview()">
|
29 |
+
</div>
|
30 |
+
<textarea class="entry" name="output" id="output" placeholder="Enter Output Here:" oninput="updatePreview()"></textarea>
|
31 |
+
</div>
|
32 |
+
<div class="control">
|
33 |
+
<input id="submit" type="submit" value="ADD TO DATABASE">
|
34 |
+
<div class="data-preview">
|
35 |
+
<label style="color: white; font-size: x-large;" for="preview">Data Preview</label>
|
36 |
+
<textarea id="current-data-preview" name="current-data-preview" class="data"></textarea>
|
37 |
+
<p style="color: rgb(143, 143, 143); font-size: medium; margin: 5px;"> Data Entries: {{data_entries}}</p>
|
38 |
+
<pre class="data">{{full_data}}</pre>
|
39 |
+
</div>
|
40 |
+
</div>
|
41 |
+
</div>
|
42 |
+
</form>
|
43 |
+
<script>
|
44 |
+
var user = document.querySelector('#user');
|
45 |
+
var time = document.querySelector('#time');
|
46 |
+
var weather = document.querySelector('#weather');
|
47 |
+
var action = document.querySelector('#action');
|
48 |
+
var current_events = document.querySelector('#current-events');
|
49 |
+
var input = document.querySelector('#input');
|
50 |
+
var output = document.querySelector('#output');
|
51 |
+
var speak = document.querySelector('#speak');
|
52 |
+
var notes = document.querySelector('#notes');
|
53 |
+
var preview = document.querySelector('#current-data-preview');
|
54 |
+
|
55 |
+
time.value = "{{dummy_time}}";
|
56 |
+
weather.value = "{{dummy_weather}}";
|
57 |
+
current_events.value = "{{dummy_current_events}}";
|
58 |
+
|
59 |
+
function updatePreview() {
|
60 |
+
format = "Time- {time}\nWeather- {weather}\nSurroundings- {current_events}\n{notes}\n{user}: {input}\nCRYSTAL:<###CRYSTAL-INTERNAL###> Speak\n{speak}\n<###CRYSTAL-INTERNAL###> {action}\n{output}"
|
61 |
+
var formattedText = format.replace('{time}', time.value)
|
62 |
+
.replace('{weather}', weather.value)
|
63 |
+
.replace('{current_events}', current_events.value)
|
64 |
+
.replace('{user}', user.value)
|
65 |
+
.replace('{input}', input.value)
|
66 |
+
.replace('{action}', action.value)
|
67 |
+
.replace('{speak}', speak.value)
|
68 |
+
.replace('{notes}', notes.value)
|
69 |
+
.replace('{output}', output.value);
|
70 |
+
|
71 |
+
preview.textContent = formattedText;
|
72 |
+
}
|
73 |
+
|
74 |
+
// Initial preview update
|
75 |
+
updatePreview();
|
76 |
+
|
77 |
+
</script>
|
78 |
+
</body>
|
79 |
+
|
80 |
+
</html>
|
Perceptrix/engine.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig, BitsAndBytesConfig
|
2 |
+
from Perceptrix.streamer import TextStreamer
|
3 |
+
from utils import setup_device
|
4 |
+
import torch
|
5 |
+
import tqdm
|
6 |
+
import sys
|
7 |
+
import os
|
8 |
+
|
9 |
+
model_name = os.environ.get('LLM_MODEL')
|
10 |
+
|
11 |
+
model_id = "models/CRYSTAL-instruct" if model_name == None else model_name
|
12 |
+
|
13 |
+
device = setup_device()
|
14 |
+
|
15 |
+
bnb_config = BitsAndBytesConfig(
|
16 |
+
load_in_4bit=True,
|
17 |
+
bnb_4bit_use_double_quant=True,
|
18 |
+
bnb_4bit_compute_dtype=torch.bfloat16
|
19 |
+
)
|
20 |
+
|
21 |
+
tokenizer = LlamaTokenizer.from_pretrained(
|
22 |
+
model_id,
|
23 |
+
use_fast=True)
|
24 |
+
|
25 |
+
model = LlamaForCausalLM.from_pretrained(
|
26 |
+
model_id,
|
27 |
+
load_in_8bit=False,
|
28 |
+
device_map="auto",
|
29 |
+
torch_dtype=torch.float16,
|
30 |
+
low_cpu_mem_usage=True,
|
31 |
+
offload_folder="offload",
|
32 |
+
quantization_config=bnb_config,
|
33 |
+
)
|
34 |
+
|
35 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True,
|
36 |
+
skip_special_tokens=True, save_file="reply.txt")
|
37 |
+
|
38 |
+
PROMPT = '''### Instruction:
|
39 |
+
{}
|
40 |
+
### Input:
|
41 |
+
{}
|
42 |
+
### Response:'''
|
43 |
+
|
44 |
+
model.config.pad_token_id = tokenizer.pad_token_id = 0
|
45 |
+
model.config.bos_token_id = 1
|
46 |
+
model.config.eos_token_id = 2
|
47 |
+
|
48 |
+
model.eval()
|
49 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
50 |
+
model = torch.compile(model)
|
51 |
+
|
52 |
+
|
53 |
+
def evaluate(
|
54 |
+
prompt='',
|
55 |
+
temperature=0.4,
|
56 |
+
top_p=0.65,
|
57 |
+
top_k=35,
|
58 |
+
repetition_penalty=1.1,
|
59 |
+
max_new_tokens=512,
|
60 |
+
**kwargs,
|
61 |
+
):
|
62 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
63 |
+
input_ids = inputs["input_ids"].to(device)
|
64 |
+
generation_config = GenerationConfig(
|
65 |
+
temperature=temperature,
|
66 |
+
top_p=top_p,
|
67 |
+
top_k=top_k,
|
68 |
+
repetition_penalty=repetition_penalty,
|
69 |
+
**kwargs,
|
70 |
+
)
|
71 |
+
|
72 |
+
with torch.no_grad():
|
73 |
+
generation_output = model.generate(
|
74 |
+
input_ids=input_ids,
|
75 |
+
generation_config=generation_config,
|
76 |
+
return_dict_in_generate=True,
|
77 |
+
output_scores=True,
|
78 |
+
max_new_tokens=max_new_tokens,
|
79 |
+
streamer=streamer,
|
80 |
+
)
|
81 |
+
s = generation_output.sequences[0]
|
82 |
+
output = tokenizer.decode(s)
|
83 |
+
yield output.split("### Response:")[-1].strip()
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
def run_instruction(
|
88 |
+
instruction,
|
89 |
+
inputs,
|
90 |
+
temperature=0.4,
|
91 |
+
top_p=0.65,
|
92 |
+
top_k=35,
|
93 |
+
repetition_penalty=1.1,
|
94 |
+
max_new_tokens=512,
|
95 |
+
stop_tokens=None,
|
96 |
+
):
|
97 |
+
now_prompt = PROMPT.format(instruction+'\n', inputs)
|
98 |
+
|
99 |
+
response = evaluate(
|
100 |
+
now_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, stop_tokens=stop_tokens, do_sample=True
|
101 |
+
)
|
102 |
+
|
103 |
+
for i in response:
|
104 |
+
yield i
|
105 |
+
|
106 |
+
def search_keyword(prompt):
|
107 |
+
instructions = """Prompt:Time: Fri, 23 August 2023 2:30PM\nWeather: 73F\nHow many friends have I told you about?
|
108 |
+
Search Keyword:Friends
|
109 |
+
Prompt:Time: Thu, 27 September 2023 3:41PM\nWeather: 62F\nWhat was our very first conversation
|
110 |
+
Chat Index:0
|
111 |
+
Prompt:Time: Tue, 21 September 2023 2:30PM\nWeather: 67F\nWhat was the last thing I said to you
|
112 |
+
Chat Index:-1
|
113 |
+
Prompt:Time: Sun, 31 October 2023 7:33AM\nWeather: 59F\nWhat was the last thing I said to you before that
|
114 |
+
Chat Index:-2
|
115 |
+
Prompt:Time: Sat, 30 October 2023 8:21PM\nWeather: 65F\nDid I ever tell you about my math class?
|
116 |
+
Search Keyword:math
|
117 |
+
Prompt:Time: Mon, 13 November 2023 4:52PM\nWeather: 55F\nWhat was my 7th grade English teacher's name?
|
118 |
+
Search Keyword:English
|
119 |
+
Prompt:Time: Wed, 15 May 2023 6:19PM\nWeather: 80F\nWhere did I say my wallet was?
|
120 |
+
Search Keyword:Wallet
|
121 |
+
Prompt:Time: Fri, 24 June 2023 1:52PM\nWeather: 92F\nWhat did Alex tell you?
|
122 |
+
Search Keyword:Alex
|
123 |
+
Prompt:Time: Sat, 19 July 2023 2:44PM\nWeather: 91F\nWhat was my first conversation today
|
124 |
+
Search Keyword:24 June"""
|
125 |
+
answer = ''.join(run_instruction(
|
126 |
+
instructions,
|
127 |
+
"Prompt:"+prompt+"\n",
|
128 |
+
temperature=0.5,
|
129 |
+
top_p=0.5,
|
130 |
+
top_k=200,
|
131 |
+
repetition_penalty=1.1,
|
132 |
+
max_new_tokens=256,
|
133 |
+
))
|
134 |
+
return answer
|
135 |
+
|
136 |
+
|
137 |
+
def identify_objects_from_text(prompt):
|
138 |
+
instructions = """Input:The object that flies in the air from this picture is a toy helicopter
|
139 |
+
Output:Toy helicopter
|
140 |
+
Input:For the robot to be able to achieve the task, the robot needs to look for a white shirt
|
141 |
+
Output:White shirt
|
142 |
+
Input:To complete the task, the robot needs to remove the fruits from the wooden basket.
|
143 |
+
Output:fruits, wooden basket
|
144 |
+
Input:To clean up your desk, you need to gather and organize the various items scattered around it. This includes the laptop, cell phone, scissors, pens, and other objects. By putting these items back in their designated spaces or containers, you can create a more organized and clutter-free workspace.
|
145 |
+
Output:Laptop, cell phone, scissors, pens, containers
|
146 |
+
Input:The tree with a colorful sky background is the one to be looking for.
|
147 |
+
Output:Tree"""
|
148 |
+
answer = ''.join(run_instruction(
|
149 |
+
instructions,
|
150 |
+
prompt,
|
151 |
+
temperature=0.5,
|
152 |
+
top_p=0.5,
|
153 |
+
top_k=200,
|
154 |
+
repetition_penalty=1.1,
|
155 |
+
max_new_tokens=256,
|
156 |
+
))
|
157 |
+
return answer
|
158 |
+
|
159 |
+
|
160 |
+
|
161 |
+
def robotix(prompt, stop=None):
|
162 |
+
instructions = """#Get me some water
|
163 |
+
objects = [['water: 57%', (781, 592)]]
|
164 |
+
robot.target((781, 592))
|
165 |
+
object_distance = distance()
|
166 |
+
if object_distance > 10:
|
167 |
+
robot.go("forward", object_distance, track="water")
|
168 |
+
robot.grab()
|
169 |
+
if object_distance > 10:
|
170 |
+
robot.go("back", object_distance)
|
171 |
+
robot.release("here")
|
172 |
+
#Stand by the table
|
173 |
+
objects = [['table: 81%', (1489, 1173)], ['table: 75%', (1971, 1293)]]
|
174 |
+
robot.target((1489, 1173))
|
175 |
+
if distance() > 10:
|
176 |
+
robot.go(forward, distance())
|
177 |
+
#Put the apples in the basket
|
178 |
+
objects = [['basket: 77%', (89, 112)], ['apples: 72%', (222, 182)]]
|
179 |
+
robot.target((281, 189))
|
180 |
+
if distance() > 10:
|
181 |
+
robot.go("forward", distance(), track="apples")
|
182 |
+
robot.grab()
|
183 |
+
robot.target(robot.find("basket"))
|
184 |
+
robot.release(distance())
|
185 |
+
#Go to the sofa
|
186 |
+
objects=[['sofa: 81%', (1060, 931)]]
|
187 |
+
robot.target((1060, 931))
|
188 |
+
if distance() > 10:
|
189 |
+
robot.go("forward", distance())
|
190 |
+
#Go to that person over there and then come back
|
191 |
+
objects=[['person: 85%', (331, 354)]]
|
192 |
+
robot.target((331, 354))
|
193 |
+
object_distance = distance()
|
194 |
+
if object_distance > 10:
|
195 |
+
robot.go("forward", object_distance)
|
196 |
+
robot.go("backward", object_distance)
|
197 |
+
"""
|
198 |
+
|
199 |
+
answer = ''.join(run_instruction(
|
200 |
+
instructions,
|
201 |
+
prompt,
|
202 |
+
temperature=0.2,
|
203 |
+
top_p=0.5,
|
204 |
+
top_k=300,
|
205 |
+
repetition_penalty=1.1,
|
206 |
+
max_new_tokens=256,
|
207 |
+
stop_tokens=stop,
|
208 |
+
))
|
209 |
+
return answer
|
210 |
+
|
211 |
+
|
212 |
+
if __name__ == "__main__":
|
213 |
+
print(robotix("#Get me a glass of water\nobjects = [['water: 65%', (695, 234)]]"))
|
Perceptrix/finetune/Dockerfile
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
ARG BASE_IMAGE
|
5 |
+
FROM $BASE_IMAGE
|
6 |
+
|
7 |
+
ARG DEP_GROUPS
|
8 |
+
|
9 |
+
# Install and uninstall foundry to cache foundry requirements
|
10 |
+
RUN git clone -b main https://github.com/mosaicml/llm-foundry.git
|
11 |
+
RUN pip install --no-cache-dir "./llm-foundry${DEP_GROUPS}"
|
12 |
+
RUN pip uninstall -y llm-foundry
|
13 |
+
RUN rm -rf llm-foundry
|
Perceptrix/finetune/Makefile
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# several pytest settings
|
2 |
+
WORLD_SIZE ?= 1 # world size for launcher tests
|
3 |
+
MASTER_PORT ?= 26000 # port for distributed tests
|
4 |
+
PYTHON ?= python3 # Python command
|
5 |
+
PYTEST ?= pytest # Pytest command
|
6 |
+
PYRIGHT ?= pyright # Pyright command. Pyright must be installed seperately -- e.g. `node install -g pyright`
|
7 |
+
EXTRA_ARGS ?= # extra arguments for pytest
|
8 |
+
EXTRA_LAUNCHER_ARGS ?= # extra arguments for the composer cli launcher
|
9 |
+
|
10 |
+
test:
|
11 |
+
LOCAL_WORLD_SIZE=1 $(PYTHON) -m $(PYTEST) $(EXTRA_ARGS)
|
12 |
+
|
13 |
+
test-gpu:
|
14 |
+
LOCAL_WORLD_SIZE=1 $(PYTHON) -m $(PYTEST) -m gpu $(EXTRA_ARGS)
|
15 |
+
|
16 |
+
# runs tests with the launcher
|
17 |
+
test-dist:
|
18 |
+
$(PYTHON) -m composer.cli.launcher -n $(WORLD_SIZE) --master_port $(MASTER_PORT) $(EXTRA_LAUNCHER_ARGS) -m $(PYTEST) $(EXTRA_ARGS)
|
19 |
+
|
20 |
+
test-dist-gpu:
|
21 |
+
$(PYTHON) -m composer.cli.launcher -n $(WORLD_SIZE) --master_port $(MASTER_PORT) $(EXTRA_LAUNCHER_ARGS) -m $(PYTEST) -m gpu $(EXTRA_ARGS)
|
22 |
+
|
23 |
+
.PHONY: test test-gpu test-dist test-dist-gpu
|
Perceptrix/finetune/README.md
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_BEGIN -->
|
2 |
+
<p align="center">
|
3 |
+
<a href="https://github.com/mosaicml/llm-foundry">
|
4 |
+
<picture>
|
5 |
+
<img alt="LLM Foundry" src="./assets/llm-foundry.png" width="95%">
|
6 |
+
</picture>
|
7 |
+
</a>
|
8 |
+
</p>
|
9 |
+
<!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_END -->
|
10 |
+
|
11 |
+
<p align="center">
|
12 |
+
<a href="https://pypi.org/project/llm-foundry/">
|
13 |
+
<img alt="PyPi Version" src="https://img.shields.io/pypi/pyversions/llm-foundry">
|
14 |
+
</a>
|
15 |
+
<a href="https://pypi.org/project/llm-foundry/">
|
16 |
+
<img alt="PyPi Package Version" src="https://img.shields.io/pypi/v/llm-foundry">
|
17 |
+
</a>
|
18 |
+
<a href="https://mosaicml.me/slack">
|
19 |
+
<img alt="Chat @ Slack" src="https://img.shields.io/badge/slack-chat-2eb67d.svg?logo=slack">
|
20 |
+
</a>
|
21 |
+
<a href="https://github.com/mosaicml/llm-foundry/blob/main/LICENSE">
|
22 |
+
<img alt="License" src="https://img.shields.io/badge/License-Apache%202.0-green.svg">
|
23 |
+
</a>
|
24 |
+
</p>
|
25 |
+
<br />
|
26 |
+
|
27 |
+
# LLM Foundry
|
28 |
+
|
29 |
+
This repository contains code for training, finetuning, evaluating, and deploying LLMs for inference with [Composer](https://github.com/mosaicml/composer) and the [MosaicML platform](https://forms.mosaicml.com/demo?utm_source=github.com&utm_medium=referral&utm_campaign=llm-foundry). Designed to be easy-to-use, efficient _and_ flexible, this codebase is designed to enable rapid experimentation with the latest techniques.
|
30 |
+
|
31 |
+
You'll find in this repo:
|
32 |
+
* `llmfoundry/` - source code for models, datasets, callbacks, utilities, etc.
|
33 |
+
* `scripts/` - scripts to run LLM workloads
|
34 |
+
* `data_prep/` - convert text data from original sources to StreamingDataset format
|
35 |
+
* `train/` - train or finetune HuggingFace and MPT models from 125M - 70B parameters
|
36 |
+
* `train/benchmarking` - profile training throughput and MFU
|
37 |
+
* `inference/` - convert models to HuggingFace or ONNX format, and generate responses
|
38 |
+
* `inference/benchmarking` - profile inference latency and throughput
|
39 |
+
* `eval/` - evaluate LLMs on academic (or custom) in-context-learning tasks
|
40 |
+
* `mcli/` - launch any of these workloads using [MCLI](https://docs.mosaicml.com/projects/mcli/en/latest/) and the [MosaicML platform](https://www.mosaicml.com/platform)
|
41 |
+
* `TUTORIAL.md` - a deeper dive into the repo, example workflows, and FAQs
|
42 |
+
|
43 |
+
# MPT
|
44 |
+
|
45 |
+
Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models:
|
46 |
+
|
47 |
+
|
48 |
+
| Model | Context Length | Download | Demo | Commercial use? |
|
49 |
+
| ------------------ | -------------- | -------------------------------------------------- | ----------------------------------------------------------- | --------------- |
|
50 |
+
| MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | | Yes |
|
51 |
+
| MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | | Yes |
|
52 |
+
| MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No |
|
53 |
+
| MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | | Yes |
|
54 |
+
| MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | | Yes |
|
55 |
+
| MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat) | No |
|
56 |
+
| MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | | Yes |
|
57 |
+
|
58 |
+
To try out these models locally, [follow the instructions](https://github.com/mosaicml/llm-foundry/tree/main/scripts/inference#interactive-generation-with-modelgenerate) in `scripts/inference/README.md` to prompt HF models using our [hf_generate.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py) or [hf_chat.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_chat.py) scripts.
|
59 |
+
|
60 |
+
# MPT Community
|
61 |
+
|
62 |
+
We've been overwhelmed by all the amazing work the community has put into MPT! Here we provide a few links to some of them:
|
63 |
+
* [ReplitLM](https://github.com/replit/replitLM): `replit-code-v1-3b` is a 2.7B Causal Language Model focused on Code Completion. The model has been trained on a subset of the Stack Dedup v1.2 dataset covering 20 languages such as Java, Python, and C++
|
64 |
+
* [LLaVa-MPT](https://github.com/haotian-liu/LLaVA#LLaVA-MPT-7b): Visual instruction tuning to get MPT multimodal capabilities
|
65 |
+
* [ggml](https://github.com/ggerganov/ggml/tree/master): Optimized MPT version for efficient inference on consumer hardware
|
66 |
+
* [GPT4All](https://gpt4all.io/index.html): locally running chat system, now with MPT support!
|
67 |
+
* [Q8MPT-Chat](https://huggingface.co/spaces/Intel/Q8-Chat): 8-bit optimized MPT for CPU by our friends at Intel
|
68 |
+
|
69 |
+
Tutorial videos from the community:
|
70 |
+
* [Using MPT-7B with Langchain](https://www.youtube.com/watch?v=DXpk9K7DgMo&t=3s) by [@jamesbriggs](https://www.youtube.com/@jamesbriggs)
|
71 |
+
* [MPT-7B StoryWriter Intro](https://www.youtube.com/watch?v=O9Y_ZdsuKWQ) by [AItrepreneur](https://www.youtube.com/@Aitrepreneur)
|
72 |
+
* [Fine-tuning MPT-7B on a single GPU](https://www.youtube.com/watch?v=KSlWkrByc0o&t=9s) by [@AIology2022](https://www.youtube.com/@AIology2022)
|
73 |
+
* [How to Fine-tune MPT-7B-Instruct on Google Colab](https://youtu.be/3de0Utr9XnI) by [@VRSEN](https://www.youtube.com/@vrsen)
|
74 |
+
|
75 |
+
Something missing? Contribute with a PR!
|
76 |
+
|
77 |
+
# Latest News
|
78 |
+
* [Blog: MPT-30B: Raising the bar for open-source foundation models](https://www.mosaicml.com/blog/mpt-30b)
|
79 |
+
* [Blog: Introducing MPT-7B](https://www.mosaicml.com/blog/mpt-7b)
|
80 |
+
* [Blog: Benchmarking LLMs on H100](https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1)
|
81 |
+
* [Blog: Blazingly Fast LLM Evaluation](https://www.mosaicml.com/blog/llm-evaluation-for-icl)
|
82 |
+
* [Blog: GPT3 Quality for $500k](https://www.mosaicml.com/blog/gpt-3-quality-for-500k)
|
83 |
+
* [Blog: Billion parameter GPT training made easy](https://www.mosaicml.com/blog/billion-parameter-gpt-training-made-easy)
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
# Hardware and Software Requirements
|
88 |
+
This codebase has been tested with PyTorch 1.13.1 and PyTorch 2.0.1 on systems with NVIDIA A100s and H100s.
|
89 |
+
This codebase may also work on systems with other devices, such as consumer NVIDIA cards and AMD cards, but we are not actively testing these systems.
|
90 |
+
If you have success/failure using LLM Foundry on other systems, please let us know in a Github issue and we will update the support matrix!
|
91 |
+
|
92 |
+
| Device | Torch Version | Cuda Version | Status |
|
93 |
+
| -------------- | ------------- | ------------ | ---------------------------- |
|
94 |
+
| A100-40GB/80GB | 1.13.1 | 11.7 | :white_check_mark: Supported |
|
95 |
+
| A100-40GB/80GB | 2.0.1 | 11.7, 11.8 | :white_check_mark: Supported |
|
96 |
+
| A100-40GB/80GB | 2.1.0 | 11.8, 12.1 | :white_check_mark: Supported |
|
97 |
+
| H100-80GB | 1.13.1 | 11.7 | :x: Not Supported |
|
98 |
+
| H100-80GB | 2.0.1 | 11.8 | :white_check_mark: Supported |
|
99 |
+
| H100-80GB | 2.1.0 | 12.1 | :white_check_mark: Supported |
|
100 |
+
| A10-24GB | 1.13.1 | 11.7 | :construction: In Progress |
|
101 |
+
| A10-24GB | 2.0.1 | 11.7, 11.8 | :construction: In Progress |
|
102 |
+
| MI250 | 2.0.1 | ROCm 5.4 | :construction: In Progress |
|
103 |
+
|
104 |
+
## MosaicML Docker Images
|
105 |
+
We highly recommend using our prebuilt Docker images. You can find them here: https://hub.docker.com/orgs/mosaicml/repositories.
|
106 |
+
|
107 |
+
The `mosaicml/pytorch` images are pinned to specific PyTorch and CUDA versions, and are stable and rarely updated.
|
108 |
+
|
109 |
+
The `mosaicml/llm-foundry` images are built with new tags upon every commit to the `main` branch.
|
110 |
+
You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117-f678575` or take the latest one using `mosaicml/llm-foundry:1.13.1_cu117-latest`.
|
111 |
+
|
112 |
+
**Please Note:** The `mosaicml/llm-foundry` images do not come with the `llm-foundry` package preinstalled, just the dependencies. You will still need to `pip install llm-foundry` either from PyPi or from source.
|
113 |
+
|
114 |
+
| Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? |
|
115 |
+
| ------------------------------------------------------ | ------------- | ----------------- | ----------------------------------- |
|
116 |
+
| `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` | 1.13.1 | 11.7 (Infiniband) | No |
|
117 |
+
| `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` | 2.0.1 | 11.8 (Infiniband) | No |
|
118 |
+
| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 (Infiniband) | No |
|
119 |
+
| `mosaicml/llm-foundry:1.13.1_cu117-latest` | 1.13.1 | 11.7 (Infiniband) | Yes |
|
120 |
+
| `mosaicml/llm-foundry:2.0.1_cu118-latest` | 2.0.1 | 11.8 (Infiniband) | Yes |
|
121 |
+
| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v1) |
|
122 |
+
| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v2) |
|
123 |
+
| `mosaicml/llm-foundry:2.1.0_cu121_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v1) |
|
124 |
+
| `mosaicml/llm-foundry:2.1.0_cu121_flash2_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v2) |
|
125 |
+
|
126 |
+
|
127 |
+
# Installation
|
128 |
+
|
129 |
+
This assumes you already have PyTorch and CMake installed.
|
130 |
+
|
131 |
+
To get started, clone the repo and set up your environment. Instructions to do so differ slightly depending on whether you're using Docker.
|
132 |
+
### With Docker (recommended)
|
133 |
+
|
134 |
+
We *strongly* recommend working with LLM Foundry inside a Docker container (see our recommended Docker image above). If you are doing so, follow these steps to clone the repo and install the requirements.
|
135 |
+
|
136 |
+
<!--pytest.mark.skip-->
|
137 |
+
```bash
|
138 |
+
git clone https://github.com/mosaicml/llm-foundry.git
|
139 |
+
cd llm-foundry
|
140 |
+
pip install -e ".[gpu]" # or pip install -e . if no NVIDIA GPU
|
141 |
+
```
|
142 |
+
|
143 |
+
### Without Docker (not recommended)
|
144 |
+
|
145 |
+
If you choose not to use Docker, you should create and use a virtual environment.
|
146 |
+
|
147 |
+
<!--pytest.mark.skip-->
|
148 |
+
```bash
|
149 |
+
git clone https://github.com/mosaicml/llm-foundry.git
|
150 |
+
cd llm-foundry
|
151 |
+
|
152 |
+
# Creating and activate a virtual environment
|
153 |
+
python3 -m venv llmfoundry-venv
|
154 |
+
source llmfoundry-venv/bin/activate
|
155 |
+
|
156 |
+
pip install cmake packaging torch # setup.py requires these be installed
|
157 |
+
|
158 |
+
pip install -e ".[gpu]" # or pip install -e . if no NVIDIA GPU
|
159 |
+
```
|
160 |
+
|
161 |
+
### TransformerEngine and amp_fp8 support
|
162 |
+
NVIDIA H100 GPUs have FP8 support; this additionally requires the following installations:
|
163 |
+
<!--pytest.mark.skip-->
|
164 |
+
```bash
|
165 |
+
pip install flash-attn==1.0.7 --no-build-isolation
|
166 |
+
pip install git+https://github.com/NVIDIA/[email protected]
|
167 |
+
```
|
168 |
+
|
169 |
+
See [here](https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#TransformerEngine-and-amp_fp8-support) for more details on enabling TransformerEngine layers and amp_fp8.
|
170 |
+
|
171 |
+
### AMD (BETA support)
|
172 |
+
|
173 |
+
In [our testing of AMD GPUs](https://www.mosaicml.com/blog/amd-mi250), the env setup includes:
|
174 |
+
|
175 |
+
<!--pytest.mark.skip-->
|
176 |
+
```bash
|
177 |
+
git clone https://github.com/mosaicml/llm-foundry.git
|
178 |
+
cd llm-foundry
|
179 |
+
|
180 |
+
# Creating and activate a virtual environment
|
181 |
+
python3 -m venv llmfoundry-venv-amd
|
182 |
+
source llmfoundry-venv-amd/bin/activate
|
183 |
+
|
184 |
+
# installs
|
185 |
+
pip install cmake packaging torch
|
186 |
+
pip install -e . # This installs some things that are not needed but they don't hurt
|
187 |
+
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2
|
188 |
+
```
|
189 |
+
**Lastly**, install the ROCm enabled flash attention (instructions [here](https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm2#amd-gpurocm-support)).
|
190 |
+
|
191 |
+
Notes:
|
192 |
+
1. `attn_impl: triton` does not work.
|
193 |
+
1. We don't yet have a docker img where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue.
|
194 |
+
|
195 |
+
# Quickstart
|
196 |
+
|
197 |
+
> **Note**
|
198 |
+
> Make sure to go through the installation steps above before trying the quickstart!
|
199 |
+
|
200 |
+
Here is an end-to-end workflow for preparing a subset of the C4 dataset, training an MPT-125M model for 10 batches,
|
201 |
+
converting the model to HuggingFace format, evaluating the model on the Winograd challenge, and generating responses to prompts.
|
202 |
+
|
203 |
+
**(Remember this is a quickstart just to demonstrate the tools -- To get good quality, the LLM must be trained for longer than 10 batches 😄)**
|
204 |
+
|
205 |
+
<!--pytest.mark.skip-->
|
206 |
+
```bash
|
207 |
+
cd scripts
|
208 |
+
|
209 |
+
# Convert C4 dataset to StreamingDataset format
|
210 |
+
python data_prep/convert_dataset_hf.py \
|
211 |
+
--dataset c4 --data_subset en \
|
212 |
+
--out_root my-copy-c4 --splits train_small val_small \
|
213 |
+
--concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>'
|
214 |
+
|
215 |
+
# Train an MPT-125m model for 10 batches
|
216 |
+
composer train/train.py \
|
217 |
+
train/yamls/pretrain/mpt-125m.yaml \
|
218 |
+
data_local=my-copy-c4 \
|
219 |
+
train_loader.dataset.split=train_small \
|
220 |
+
eval_loader.dataset.split=val_small \
|
221 |
+
max_duration=10ba \
|
222 |
+
eval_interval=0 \
|
223 |
+
save_folder=mpt-125m
|
224 |
+
|
225 |
+
# Convert the model to HuggingFace format
|
226 |
+
python inference/convert_composer_to_hf.py \
|
227 |
+
--composer_path mpt-125m/ep0-ba10-rank0.pt \
|
228 |
+
--hf_output_path mpt-125m-hf \
|
229 |
+
--output_precision bf16 \
|
230 |
+
# --hf_repo_for_upload user-org/repo-name
|
231 |
+
|
232 |
+
# Evaluate the model on a subset of tasks
|
233 |
+
composer eval/eval.py \
|
234 |
+
eval/yamls/hf_eval.yaml \
|
235 |
+
icl_tasks=eval/yamls/copa.yaml \
|
236 |
+
model_name_or_path=mpt-125m-hf
|
237 |
+
|
238 |
+
# Generate responses to prompts
|
239 |
+
python inference/hf_generate.py \
|
240 |
+
--name_or_path mpt-125m-hf \
|
241 |
+
--max_new_tokens 256 \
|
242 |
+
--prompts \
|
243 |
+
"The answer to life, the universe, and happiness is" \
|
244 |
+
"Here's a quick recipe for baking chocolate chip cookies: Start by"
|
245 |
+
```
|
246 |
+
|
247 |
+
Note: the `composer` command used above to train the model refers to [Composer](https://github.com/mosaicml/composer) library's distributed launcher.
|
248 |
+
|
249 |
+
If you have a write-enabled [HuggingFace auth token](https://huggingface.co/docs/hub/security-tokens), you can optionally upload your model to the Hub! Just export your token like this:
|
250 |
+
|
251 |
+
```bash
|
252 |
+
export HUGGING_FACE_HUB_TOKEN=your-auth-token
|
253 |
+
```
|
254 |
+
|
255 |
+
and uncomment the line containing `--hf_repo_for_upload ...` in the above call to `inference/convert_composer_to_hf.py`.
|
256 |
+
|
257 |
+
# Learn more about LLM Foundry!
|
258 |
+
|
259 |
+
Check out [TUTORIAL.md](https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md) to keep learning about working with LLM Foundry. The tutorial highlights example workflows, points you to other resources throughout the repo, and answers frequently asked questions!
|
260 |
+
|
261 |
+
# Contact Us
|
262 |
+
|
263 |
+
If you run into any problems with the code, please file Github issues directly to this repo.
|
264 |
+
|
265 |
+
If you want to train LLMs on the MosaicML platform, reach out to us at [[email protected]](mailto:[email protected])!
|
Perceptrix/finetune/build/lib/inference/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
__all__ = []
|
Perceptrix/finetune/build/lib/inference/convert_composer_mpt_to_ft.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
# Note: This script is specifically for converting MPT Composer checkpoints to FasterTransformer format.
|
5 |
+
|
6 |
+
import configparser
|
7 |
+
import os
|
8 |
+
import tempfile
|
9 |
+
from argparse import ArgumentParser, Namespace
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Any, Dict, Optional, Union
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from composer.utils import get_file, safe_torch_load
|
15 |
+
from transformers import PreTrainedTokenizer
|
16 |
+
|
17 |
+
from llmfoundry.utils import (convert_and_save_ft_weights,
|
18 |
+
get_hf_tokenizer_from_composer_state_dict)
|
19 |
+
|
20 |
+
|
21 |
+
def save_ft_config(composer_config: Dict[str, Any],
|
22 |
+
tokenizer: PreTrainedTokenizer,
|
23 |
+
save_dir: str,
|
24 |
+
infer_gpu_num: int = 1,
|
25 |
+
weight_data_type: str = 'fp32',
|
26 |
+
force: bool = False):
|
27 |
+
|
28 |
+
config = configparser.ConfigParser()
|
29 |
+
config['gpt'] = {}
|
30 |
+
try:
|
31 |
+
config['gpt']['model_name'] = 'mpt'
|
32 |
+
config['gpt']['head_num'] = str(composer_config['n_heads'])
|
33 |
+
n_embd = composer_config['d_model']
|
34 |
+
config['gpt']['size_per_head'] = str(n_embd //
|
35 |
+
composer_config['n_heads'])
|
36 |
+
config['gpt']['inter_size'] = str(n_embd * composer_config['mlp_ratio'])
|
37 |
+
config['gpt']['max_pos_seq_len'] = str(composer_config['max_seq_len'])
|
38 |
+
config['gpt']['num_layer'] = str(composer_config['n_layers'])
|
39 |
+
config['gpt']['vocab_size'] = str(composer_config['vocab_size'])
|
40 |
+
config['gpt']['start_id'] = str(tokenizer.bos_token_id)
|
41 |
+
config['gpt']['end_id'] = str(tokenizer.eos_token_id)
|
42 |
+
config['gpt']['weight_data_type'] = weight_data_type
|
43 |
+
config['gpt']['tensor_para_size'] = str(infer_gpu_num)
|
44 |
+
# nn.LayerNorm default eps is 1e-5
|
45 |
+
config['gpt']['layernorm_eps'] = str(1e-5)
|
46 |
+
if composer_config['alibi']:
|
47 |
+
config['gpt']['has_positional_encoding'] = str(False)
|
48 |
+
config['gpt']['use_attention_linear_bias'] = str(True)
|
49 |
+
if composer_config['attn_clip_qkv'] and not force:
|
50 |
+
raise RuntimeError(
|
51 |
+
'clip_qkv is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.'
|
52 |
+
)
|
53 |
+
if composer_config['attn_qk_ln'] and not force:
|
54 |
+
raise RuntimeError(
|
55 |
+
'qk_ln is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.'
|
56 |
+
)
|
57 |
+
|
58 |
+
with open(os.path.join(save_dir, 'config.ini'), 'w') as configfile:
|
59 |
+
config.write(configfile)
|
60 |
+
return config
|
61 |
+
except:
|
62 |
+
print(f'Failed to save the config in config.ini.')
|
63 |
+
raise
|
64 |
+
|
65 |
+
|
66 |
+
def write_ft_checkpoint_from_composer_checkpoint(
|
67 |
+
checkpoint_path: Union[Path, str],
|
68 |
+
infer_gpu_num: int,
|
69 |
+
save_dir: str,
|
70 |
+
output_precision: str = 'fp32',
|
71 |
+
local_checkpoint_save_location: Optional[Union[Path,
|
72 |
+
str]] = None) -> None:
|
73 |
+
"""Convert a Composer checkpoint to a FasterTransformer checkpoint folder.
|
74 |
+
|
75 |
+
.. note:: This function may not work properly if you used surgery algorithms when you trained your model. In that case you may need to
|
76 |
+
edit the parameter conversion methods to properly convert your custom model.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend
|
80 |
+
supported by Composer.
|
81 |
+
infer_gpu_num (int): The number of gpus you are planning to use for inference.
|
82 |
+
save_dir (str): Path of the directory to save the checkpoint in FT format.
|
83 |
+
output_precision (str, optional): The precision of the output weights saved to the FasterTransformer model. Can be either ``fp32`` or ``fp16``.
|
84 |
+
local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally.
|
85 |
+
If the input ``checkpoint_path`` is already a local path, this will be a symlink.
|
86 |
+
Defaults to None, which will use a temporary file.
|
87 |
+
"""
|
88 |
+
dtype = {
|
89 |
+
'fp32': torch.float32,
|
90 |
+
'fp16': torch.float16,
|
91 |
+
}[output_precision]
|
92 |
+
|
93 |
+
# default local path to a tempfile if path is not provided
|
94 |
+
if local_checkpoint_save_location is None:
|
95 |
+
tmp_dir = tempfile.TemporaryDirectory()
|
96 |
+
local_checkpoint_save_location = Path(
|
97 |
+
tmp_dir.name) / 'local-composer-checkpoint.pt'
|
98 |
+
|
99 |
+
# download the checkpoint file
|
100 |
+
print(
|
101 |
+
f'Downloading checkpoint from {checkpoint_path} -> {local_checkpoint_save_location}'
|
102 |
+
)
|
103 |
+
get_file(str(checkpoint_path), str(local_checkpoint_save_location))
|
104 |
+
|
105 |
+
# Load the Composer checkpoint. Use it to get the
|
106 |
+
# Composer state dict and weights
|
107 |
+
print('Loading checkpoint into CPU RAM...')
|
108 |
+
composer_state_dict = safe_torch_load(local_checkpoint_save_location)
|
109 |
+
|
110 |
+
# Extract Composer config from state dict
|
111 |
+
if 'state' not in composer_state_dict:
|
112 |
+
raise RuntimeError(
|
113 |
+
f'"state" is not an available key in the provided composer checkpoint. Is {local_checkpoint_save_location} ill-formed?'
|
114 |
+
)
|
115 |
+
if 'integrations' not in composer_state_dict[
|
116 |
+
'state'] or 'huggingface' not in composer_state_dict['state'][
|
117 |
+
'integrations']:
|
118 |
+
raise RuntimeError(
|
119 |
+
'Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!'
|
120 |
+
)
|
121 |
+
composer_config = composer_state_dict['state']['integrations'][
|
122 |
+
'huggingface']['model']['config']['content']
|
123 |
+
|
124 |
+
# Extract the HF tokenizer
|
125 |
+
print('#' * 30)
|
126 |
+
print('Extracting HF Tokenizer...')
|
127 |
+
hf_tokenizer = get_hf_tokenizer_from_composer_state_dict(
|
128 |
+
composer_state_dict)
|
129 |
+
if hf_tokenizer is None:
|
130 |
+
print('Warning! No HF Tokenizer found!')
|
131 |
+
|
132 |
+
# Extract the model weights
|
133 |
+
weights_state_dict = composer_state_dict['state']['model']
|
134 |
+
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
|
135 |
+
weights_state_dict, prefix='model.')
|
136 |
+
|
137 |
+
# Converting weights to desired dtype
|
138 |
+
for k, v in weights_state_dict.items():
|
139 |
+
if isinstance(v, torch.Tensor):
|
140 |
+
weights_state_dict[k] = v.to(dtype=dtype)
|
141 |
+
|
142 |
+
# Convert the weights using the config and tokenizer to FasterTransformer format
|
143 |
+
print('#' * 30)
|
144 |
+
print('Saving FasterTransformer config...')
|
145 |
+
save_ft_config(composer_config,
|
146 |
+
tokenizer=hf_tokenizer,
|
147 |
+
save_dir=save_dir,
|
148 |
+
weight_data_type=output_precision)
|
149 |
+
print('#' * 30)
|
150 |
+
print('Converting weights to FasterTransformer format...')
|
151 |
+
convert_and_save_ft_weights(named_params=weights_state_dict,
|
152 |
+
config=composer_config,
|
153 |
+
infer_gpu_num=infer_gpu_num,
|
154 |
+
weight_data_type=output_precision,
|
155 |
+
save_dir=save_dir)
|
156 |
+
|
157 |
+
print('#' * 30)
|
158 |
+
print(
|
159 |
+
f'FasterTransformer checkpoint folder successfully created at {save_dir}.'
|
160 |
+
)
|
161 |
+
|
162 |
+
print('Done.')
|
163 |
+
print('#' * 30)
|
164 |
+
|
165 |
+
|
166 |
+
def parse_args() -> Namespace:
|
167 |
+
"""Parse commandline arguments."""
|
168 |
+
parser = ArgumentParser(
|
169 |
+
description=
|
170 |
+
'Convert an MPT Composer checkpoint into a standard FasterTransformer checkpoint folder.'
|
171 |
+
)
|
172 |
+
parser.add_argument(
|
173 |
+
'--composer_path',
|
174 |
+
'-i',
|
175 |
+
type=str,
|
176 |
+
help='Composer checkpoint path. Can be a local file path or cloud URI',
|
177 |
+
required=True)
|
178 |
+
parser.add_argument(
|
179 |
+
'--local_checkpoint_save_location',
|
180 |
+
type=str,
|
181 |
+
help='If specified, where to save the checkpoint file to locally. \
|
182 |
+
If the input ``checkpoint_path`` is already a local path, this will be a symlink. \
|
183 |
+
Defaults to None, which will use a temporary file.',
|
184 |
+
default=None)
|
185 |
+
parser.add_argument(
|
186 |
+
'--ft_save_dir',
|
187 |
+
'-o',
|
188 |
+
type=str,
|
189 |
+
help='Directory to save FasterTransformer converted checkpoint in',
|
190 |
+
required=True)
|
191 |
+
parser.add_argument('--infer_gpu_num',
|
192 |
+
'-i_g',
|
193 |
+
type=int,
|
194 |
+
help='How many gpus for inference?',
|
195 |
+
required=True)
|
196 |
+
parser.add_argument(
|
197 |
+
'--force',
|
198 |
+
action='store_true',
|
199 |
+
help=
|
200 |
+
'Force conversion to FT even if some features may not work as expected in FT'
|
201 |
+
)
|
202 |
+
parser.add_argument(
|
203 |
+
'--output_precision',
|
204 |
+
type=str,
|
205 |
+
help=
|
206 |
+
'Data type of weights in the FasterTransformer output model. Input checkpoint weights will be converted to this dtype.',
|
207 |
+
choices=['fp32', 'fp16'],
|
208 |
+
default='fp32')
|
209 |
+
|
210 |
+
return parser.parse_args()
|
211 |
+
|
212 |
+
|
213 |
+
if __name__ == '__main__':
|
214 |
+
args = parse_args()
|
215 |
+
print('\n=============== Argument ===============')
|
216 |
+
for key in vars(args):
|
217 |
+
print('{}: {}'.format(key, vars(args)[key]))
|
218 |
+
print('========================================')
|
219 |
+
|
220 |
+
save_dir = os.path.join(args.ft_save_dir, f'{args.infer_gpu_num}-gpu')
|
221 |
+
|
222 |
+
if os.path.exists(save_dir) == False:
|
223 |
+
os.makedirs(save_dir)
|
224 |
+
else:
|
225 |
+
raise RuntimeError(f'Output path {save_dir} already exists!')
|
226 |
+
|
227 |
+
write_ft_checkpoint_from_composer_checkpoint(
|
228 |
+
checkpoint_path=args.composer_path,
|
229 |
+
infer_gpu_num=args.infer_gpu_num,
|
230 |
+
save_dir=save_dir,
|
231 |
+
output_precision=args.output_precision,
|
232 |
+
local_checkpoint_save_location=args.local_checkpoint_save_location)
|
Perceptrix/finetune/build/lib/inference/convert_composer_to_hf.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import os
|
5 |
+
import tempfile
|
6 |
+
from argparse import ArgumentParser, Namespace
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import transformers
|
12 |
+
from composer.models.huggingface import get_hf_config_from_composer_state_dict
|
13 |
+
from composer.utils import (get_file, maybe_create_object_store_from_uri,
|
14 |
+
parse_uri, safe_torch_load)
|
15 |
+
from transformers import PretrainedConfig, PreTrainedTokenizerBase
|
16 |
+
|
17 |
+
from llmfoundry import MPTConfig, MPTForCausalLM
|
18 |
+
from llmfoundry.utils import get_hf_tokenizer_from_composer_state_dict
|
19 |
+
from llmfoundry.utils.huggingface_hub_utils import \
|
20 |
+
edit_files_for_hf_compatibility
|
21 |
+
|
22 |
+
|
23 |
+
def write_huggingface_pretrained_from_composer_checkpoint(
|
24 |
+
checkpoint_path: Union[Path, str],
|
25 |
+
output_path: Union[Path, str],
|
26 |
+
output_precision: str = 'fp32',
|
27 |
+
local_checkpoint_save_location: Optional[Union[Path, str]] = None
|
28 |
+
) -> Tuple[PretrainedConfig, Optional[PreTrainedTokenizerBase]]:
|
29 |
+
"""Convert a Composer checkpoint to a pretrained HF checkpoint folder.
|
30 |
+
|
31 |
+
Write a ``config.json`` and ``pytorch_model.bin``, like
|
32 |
+
:meth:`transformers.PreTrainedModel.from_pretrained` expects, from a
|
33 |
+
composer checkpoint.
|
34 |
+
|
35 |
+
.. note:: This function will not work properly if you used surgery algorithms when you trained your model. In that case you will want to
|
36 |
+
load the model weights using the Composer :class:`~composer.Trainer` with the ``load_path`` argument.
|
37 |
+
.. testsetup::
|
38 |
+
import torch
|
39 |
+
dataset = RandomTextClassificationDataset(size=16, use_keys=True)
|
40 |
+
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
|
41 |
+
eval_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
|
42 |
+
import transformers
|
43 |
+
from composer.models import HuggingFaceModel
|
44 |
+
from composer.trainer import Trainer
|
45 |
+
hf_model = transformers.AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-tiny', num_labels=2)
|
46 |
+
hf_tokenizer = transformers.AutoTokenizer.from_pretrained('prajjwal1/bert-tiny')
|
47 |
+
composer_model = HuggingFaceModel(hf_model, tokenizer=hf_tokenizer, metrics=[], use_logits=True)
|
48 |
+
trainer = Trainer(model=composer_model,
|
49 |
+
train_dataloader=train_dataloader,
|
50 |
+
save_filename='composer-hf-checkpoint.pt',
|
51 |
+
max_duration='1ep',
|
52 |
+
save_folder='./')
|
53 |
+
trainer.fit()
|
54 |
+
trainer.close()
|
55 |
+
|
56 |
+
Example:
|
57 |
+
.. testcode::
|
58 |
+
from composer.models import write_huggingface_pretrained_from_composer_checkpoint
|
59 |
+
write_huggingface_pretrained_from_composer_checkpoint('composer-hf-checkpoint.pt', './hf-save-pretrained-output')
|
60 |
+
loaded_model = transformers.AutoModelForSequenceClassification.from_pretrained('./hf-save-pretrained-output')
|
61 |
+
|
62 |
+
Args:
|
63 |
+
checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend
|
64 |
+
supported by :meth:`composer.utils.maybe_create_object_store_from_uri`.
|
65 |
+
output_path (Union[Path, str]): Path to the folder to write the output to.
|
66 |
+
output_precision (str, optional): The precision of the output weights saved to `pytorch_model.bin`. Can be one of ``fp32``, ``fp16``, or ``bf16``.
|
67 |
+
local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally.
|
68 |
+
If the input ``checkpoint_path`` is already a local path, this will be a symlink.
|
69 |
+
Defaults to None, which will use a temporary file.
|
70 |
+
"""
|
71 |
+
dtype = {
|
72 |
+
'fp32': torch.float32,
|
73 |
+
'fp16': torch.float16,
|
74 |
+
'bf16': torch.bfloat16,
|
75 |
+
}[output_precision]
|
76 |
+
|
77 |
+
# default local path to a tempfile if path is not provided
|
78 |
+
if local_checkpoint_save_location is None:
|
79 |
+
tmp_dir = tempfile.TemporaryDirectory()
|
80 |
+
local_checkpoint_save_location = Path(
|
81 |
+
tmp_dir.name) / 'local-composer-checkpoint.pt'
|
82 |
+
|
83 |
+
# create folder
|
84 |
+
os.makedirs(output_path)
|
85 |
+
|
86 |
+
# download the checkpoint file
|
87 |
+
print(
|
88 |
+
f'Downloading checkpoint from {checkpoint_path} -> {local_checkpoint_save_location}'
|
89 |
+
)
|
90 |
+
get_file(str(checkpoint_path), str(local_checkpoint_save_location))
|
91 |
+
|
92 |
+
# Load the Composer checkpoint state dict
|
93 |
+
print('Loading checkpoint into CPU RAM...')
|
94 |
+
composer_state_dict = safe_torch_load(local_checkpoint_save_location)
|
95 |
+
|
96 |
+
if 'state' not in composer_state_dict:
|
97 |
+
raise RuntimeError(
|
98 |
+
f'"state" is not an available key in the provided composer checkpoint. Is {local_checkpoint_save_location} ill-formed?'
|
99 |
+
)
|
100 |
+
|
101 |
+
# Build and save HF Config
|
102 |
+
print('#' * 30)
|
103 |
+
print('Saving HF Model Config...')
|
104 |
+
hf_config = get_hf_config_from_composer_state_dict(composer_state_dict)
|
105 |
+
hf_config.torch_dtype = dtype
|
106 |
+
hf_config.save_pretrained(output_path)
|
107 |
+
print(hf_config)
|
108 |
+
|
109 |
+
# Extract and save the HF tokenizer
|
110 |
+
print('#' * 30)
|
111 |
+
print('Saving HF Tokenizer...')
|
112 |
+
hf_tokenizer = get_hf_tokenizer_from_composer_state_dict(
|
113 |
+
composer_state_dict)
|
114 |
+
if hf_tokenizer is not None:
|
115 |
+
hf_tokenizer.save_pretrained(output_path)
|
116 |
+
print(hf_tokenizer)
|
117 |
+
else:
|
118 |
+
print('Warning! No HF Tokenizer found!')
|
119 |
+
|
120 |
+
# Extract the HF model weights
|
121 |
+
print('#' * 30)
|
122 |
+
print('Saving HF Model Weights...')
|
123 |
+
weights_state_dict = composer_state_dict
|
124 |
+
if 'state' in weights_state_dict:
|
125 |
+
weights_state_dict = weights_state_dict['state']['model']
|
126 |
+
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
|
127 |
+
weights_state_dict, prefix='model.')
|
128 |
+
|
129 |
+
# Convert weights to desired dtype
|
130 |
+
for k, v in weights_state_dict.items():
|
131 |
+
if isinstance(v, torch.Tensor):
|
132 |
+
weights_state_dict[k] = v.to(dtype=dtype)
|
133 |
+
|
134 |
+
# Save weights
|
135 |
+
torch.save(weights_state_dict, Path(output_path) / 'pytorch_model.bin')
|
136 |
+
|
137 |
+
print('#' * 30)
|
138 |
+
print(f'HF checkpoint folder successfully created at {output_path}.')
|
139 |
+
|
140 |
+
return hf_config, hf_tokenizer
|
141 |
+
|
142 |
+
|
143 |
+
def parse_args() -> Namespace:
|
144 |
+
"""Parse commandline arguments."""
|
145 |
+
parser = ArgumentParser(
|
146 |
+
description=
|
147 |
+
'Convert a HuggingFace causal LM in a Composer checkpoint into a standard HuggingFace checkpoint folder, and optionally upload to the hub.'
|
148 |
+
)
|
149 |
+
parser.add_argument('--composer_path', type=str, required=True)
|
150 |
+
parser.add_argument('--hf_output_path', type=str, required=True)
|
151 |
+
parser.add_argument('--local_checkpoint_save_location',
|
152 |
+
type=str,
|
153 |
+
default=None)
|
154 |
+
parser.add_argument('--output_precision',
|
155 |
+
type=str,
|
156 |
+
choices=['fp32', 'fp16', 'bf16'],
|
157 |
+
default='fp32')
|
158 |
+
parser.add_argument('--hf_repo_for_upload', type=str, default=None)
|
159 |
+
parser.add_argument('--test_uploaded_model', action='store_true')
|
160 |
+
|
161 |
+
return parser.parse_args()
|
162 |
+
|
163 |
+
|
164 |
+
def convert_composer_to_hf(args: Namespace) -> None:
|
165 |
+
print()
|
166 |
+
print('#' * 30)
|
167 |
+
print('Converting Composer checkpoint to HuggingFace checkpoint format...')
|
168 |
+
|
169 |
+
# Register MPT auto classes so that this script works with MPT
|
170 |
+
# This script will not work without modification for other custom models,
|
171 |
+
# but will work for other HuggingFace causal LMs
|
172 |
+
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
173 |
+
CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
|
174 |
+
MPTConfig.register_for_auto_class()
|
175 |
+
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
|
176 |
+
|
177 |
+
_, _, local_folder_path = parse_uri(args.hf_output_path)
|
178 |
+
|
179 |
+
config, tokenizer = write_huggingface_pretrained_from_composer_checkpoint(
|
180 |
+
checkpoint_path=args.composer_path,
|
181 |
+
output_path=local_folder_path,
|
182 |
+
output_precision=args.output_precision,
|
183 |
+
local_checkpoint_save_location=args.local_checkpoint_save_location)
|
184 |
+
|
185 |
+
dtype = {
|
186 |
+
'fp32': torch.float32,
|
187 |
+
'fp16': torch.float16,
|
188 |
+
'bf16': torch.bfloat16,
|
189 |
+
}[args.output_precision]
|
190 |
+
|
191 |
+
print(f'Loading model from {local_folder_path}')
|
192 |
+
if config.model_type == 'mpt':
|
193 |
+
config.attn_config['attn_impl'] = 'torch'
|
194 |
+
config.init_device = 'cpu'
|
195 |
+
|
196 |
+
if config.model_type == 'mpt':
|
197 |
+
loaded_hf_model = MPTForCausalLM.from_pretrained(local_folder_path,
|
198 |
+
config=config,
|
199 |
+
torch_dtype=dtype)
|
200 |
+
else:
|
201 |
+
loaded_hf_model = transformers.AutoModelForCausalLM.from_pretrained(
|
202 |
+
local_folder_path, config=config, torch_dtype=dtype)
|
203 |
+
|
204 |
+
delattr(loaded_hf_model.config, '_name_or_path')
|
205 |
+
|
206 |
+
loaded_hf_model.save_pretrained(local_folder_path)
|
207 |
+
|
208 |
+
print(f'Loading tokenizer from {local_folder_path}')
|
209 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(local_folder_path)
|
210 |
+
tokenizer.save_pretrained(local_folder_path)
|
211 |
+
|
212 |
+
# Only need to edit files for MPT because it has custom code
|
213 |
+
if config.model_type == 'mpt':
|
214 |
+
print('Editing files for HF compatibility...')
|
215 |
+
edit_files_for_hf_compatibility(local_folder_path)
|
216 |
+
|
217 |
+
object_store = maybe_create_object_store_from_uri(str(args.hf_output_path))
|
218 |
+
|
219 |
+
if object_store is not None:
|
220 |
+
print(
|
221 |
+
f'Uploading HF checkpoint folder from {local_folder_path} -> {args.hf_output_path}'
|
222 |
+
)
|
223 |
+
for file in os.listdir(local_folder_path):
|
224 |
+
remote_file = os.path.join(local_folder_path, file)
|
225 |
+
local_file = os.path.join(local_folder_path, file)
|
226 |
+
object_store.upload_object(remote_file, local_file)
|
227 |
+
|
228 |
+
if args.hf_repo_for_upload is not None:
|
229 |
+
from huggingface_hub import HfApi
|
230 |
+
api = HfApi()
|
231 |
+
|
232 |
+
print(
|
233 |
+
f'Uploading {args.hf_output_path} to HuggingFace Hub at {args.hf_repo_for_upload}'
|
234 |
+
)
|
235 |
+
api.create_repo(repo_id=args.hf_repo_for_upload,
|
236 |
+
use_auth_token=True,
|
237 |
+
repo_type='model',
|
238 |
+
private=True,
|
239 |
+
exist_ok=True)
|
240 |
+
print('Repo created.')
|
241 |
+
|
242 |
+
# ignore the full checkpoint file if we now have sharded checkpoint files
|
243 |
+
ignore_patterns = []
|
244 |
+
if any(
|
245 |
+
f.startswith('pytorch_model-00001')
|
246 |
+
for f in os.listdir(args.hf_output_path)):
|
247 |
+
ignore_patterns.append('pytorch_model.bin')
|
248 |
+
|
249 |
+
api.upload_folder(folder_path=args.hf_output_path,
|
250 |
+
repo_id=args.hf_repo_for_upload,
|
251 |
+
use_auth_token=True,
|
252 |
+
repo_type='model',
|
253 |
+
ignore_patterns=ignore_patterns)
|
254 |
+
print('Folder uploaded.')
|
255 |
+
|
256 |
+
if args.test_uploaded_model:
|
257 |
+
print('Testing uploaded model...')
|
258 |
+
hub_model = transformers.AutoModelForCausalLM.from_pretrained(
|
259 |
+
args.hf_repo_for_upload,
|
260 |
+
trust_remote_code=True,
|
261 |
+
use_auth_token=True,
|
262 |
+
torch_dtype=dtype)
|
263 |
+
hub_tokenizer = transformers.AutoTokenizer.from_pretrained(
|
264 |
+
args.hf_repo_for_upload,
|
265 |
+
trust_remote_code=True,
|
266 |
+
use_auth_token=True)
|
267 |
+
|
268 |
+
assert sum(p.numel() for p in hub_model.parameters()) == sum(
|
269 |
+
p.numel() for p in loaded_hf_model.parameters())
|
270 |
+
assert all(
|
271 |
+
str(type(module1)).split('.')[-2:] == str(type(module2)).split(
|
272 |
+
'.')[-2:] for module1, module2 in zip(
|
273 |
+
hub_model.modules(), loaded_hf_model.modules()))
|
274 |
+
|
275 |
+
assert next(
|
276 |
+
hub_model.parameters()
|
277 |
+
).dtype == dtype, f'Expected model dtype to be {dtype}, but got {next(hub_model.parameters()).dtype}'
|
278 |
+
print(
|
279 |
+
hub_tokenizer.batch_decode(
|
280 |
+
hub_model.generate(hub_tokenizer(
|
281 |
+
'MosaicML is', return_tensors='pt').input_ids,
|
282 |
+
max_new_tokens=10)))
|
283 |
+
|
284 |
+
print(
|
285 |
+
'Composer checkpoint successfully converted to HuggingFace checkpoint format.'
|
286 |
+
)
|
287 |
+
|
288 |
+
|
289 |
+
if __name__ == '__main__':
|
290 |
+
convert_composer_to_hf(parse_args())
|
Perceptrix/finetune/build/lib/inference/convert_hf_mpt_to_ft.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
"""Convert MPT model checkpoint to FT format.
|
19 |
+
|
20 |
+
It's a modified version of
|
21 |
+
https://github.com/NVIDIA/FasterTransformer/blob/main/examples/pytorch/gpt/utils/huggingface_gpt_convert.py
|
22 |
+
"""
|
23 |
+
|
24 |
+
import argparse
|
25 |
+
import configparser
|
26 |
+
import os
|
27 |
+
|
28 |
+
import transformers
|
29 |
+
|
30 |
+
from llmfoundry.utils import convert_and_save_ft_weights
|
31 |
+
|
32 |
+
|
33 |
+
def convert_mpt_to_ft(model_name_or_path: str,
|
34 |
+
output_dir: str,
|
35 |
+
infer_gpu_num: int = 1,
|
36 |
+
weight_data_type: str = 'fp32',
|
37 |
+
force: bool = False) -> None:
|
38 |
+
"""Convert an MPT checkpoint to a FasterTransformer compatible format.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
model_name_or_path (str): The HF hub name of the model (e.g., mosaicml/mpt-7b) or the path of a directory
|
42 |
+
containing an MPT checkpoint in a local dir.
|
43 |
+
output_dir (str): Path of the directory to save the checkpoint in FT format. The directory must not already exist.
|
44 |
+
infer_gpu_num (int): The number of gpus you are planning to use for inference.
|
45 |
+
weight_data_type (str): Data type of the weights in the input checkpoint.
|
46 |
+
force (bool): force conversion even with unsupported features in FT.
|
47 |
+
"""
|
48 |
+
save_dir = os.path.join(output_dir, f'{infer_gpu_num}-gpu')
|
49 |
+
|
50 |
+
if (os.path.exists(save_dir) == False):
|
51 |
+
os.makedirs(save_dir)
|
52 |
+
else:
|
53 |
+
raise RuntimeError(f'Output path {save_dir} already exists!')
|
54 |
+
|
55 |
+
# do conversion on cpu
|
56 |
+
torch_device = 'cpu'
|
57 |
+
|
58 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(
|
59 |
+
model_name_or_path, trust_remote_code=True).to(torch_device)
|
60 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
61 |
+
model_name_or_path, trust_remote_code=True)
|
62 |
+
|
63 |
+
hf_config = vars(model.config)
|
64 |
+
|
65 |
+
config = configparser.ConfigParser()
|
66 |
+
config['gpt'] = {}
|
67 |
+
try:
|
68 |
+
config['gpt']['model_name'] = 'mpt' if hf_config[
|
69 |
+
'_name_or_path'] == '' else hf_config['_name_or_path']
|
70 |
+
config['gpt']['head_num'] = str(hf_config['n_heads'])
|
71 |
+
n_embd = hf_config['d_model']
|
72 |
+
config['gpt']['size_per_head'] = str(n_embd // hf_config['n_heads'])
|
73 |
+
config['gpt']['inter_size'] = str(n_embd * hf_config['expansion_ratio'])
|
74 |
+
config['gpt']['max_pos_seq_len'] = str(hf_config['max_seq_len'])
|
75 |
+
config['gpt']['num_layer'] = str(hf_config['n_layers'])
|
76 |
+
config['gpt']['vocab_size'] = str(hf_config['vocab_size'])
|
77 |
+
config['gpt']['start_id'] = str(
|
78 |
+
hf_config['bos_token_id']
|
79 |
+
) if hf_config['bos_token_id'] != None else str(tokenizer.bos_token_id)
|
80 |
+
config['gpt']['end_id'] = str(
|
81 |
+
hf_config['eos_token_id']
|
82 |
+
) if hf_config['eos_token_id'] != None else str(tokenizer.eos_token_id)
|
83 |
+
config['gpt']['weight_data_type'] = weight_data_type
|
84 |
+
config['gpt']['tensor_para_size'] = str(infer_gpu_num)
|
85 |
+
# nn.LayerNorm default eps is 1e-5
|
86 |
+
config['gpt']['layernorm_eps'] = str(1e-5)
|
87 |
+
if hf_config['attn_config']['alibi']:
|
88 |
+
config['gpt']['has_positional_encoding'] = str(False)
|
89 |
+
config['gpt']['use_attention_linear_bias'] = str(True)
|
90 |
+
if hf_config['attn_config']['clip_qkv'] and not force:
|
91 |
+
raise RuntimeError(
|
92 |
+
'clip_qkv is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.'
|
93 |
+
)
|
94 |
+
if hf_config['attn_config']['qk_ln'] and not force:
|
95 |
+
raise RuntimeError(
|
96 |
+
'qk_ln is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.'
|
97 |
+
)
|
98 |
+
|
99 |
+
with open(os.path.join(save_dir, 'config.ini'), 'w') as configfile:
|
100 |
+
config.write(configfile)
|
101 |
+
except:
|
102 |
+
print(f'Failed to save the config in config.ini.')
|
103 |
+
raise
|
104 |
+
|
105 |
+
named_params_dict = {
|
106 |
+
name: param for name, param in model.named_parameters()
|
107 |
+
}
|
108 |
+
convert_and_save_ft_weights(named_params=named_params_dict,
|
109 |
+
config=hf_config,
|
110 |
+
infer_gpu_num=infer_gpu_num,
|
111 |
+
weight_data_type=weight_data_type,
|
112 |
+
save_dir=save_dir)
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == '__main__':
|
116 |
+
parser = argparse.ArgumentParser(
|
117 |
+
formatter_class=argparse.RawTextHelpFormatter)
|
118 |
+
parser.add_argument('--save_dir',
|
119 |
+
'-o',
|
120 |
+
type=str,
|
121 |
+
help='Directory to save converted checkpoint in',
|
122 |
+
required=True)
|
123 |
+
parser.add_argument(
|
124 |
+
'--name_or_dir',
|
125 |
+
'-i',
|
126 |
+
type=str,
|
127 |
+
help=
|
128 |
+
'HF hub Model name (e.g., mosaicml/mpt-7b) or local dir path to load checkpoint from',
|
129 |
+
required=True)
|
130 |
+
parser.add_argument('--infer_gpu_num',
|
131 |
+
'-i_g',
|
132 |
+
type=int,
|
133 |
+
help='How many gpus for inference?',
|
134 |
+
required=True)
|
135 |
+
parser.add_argument(
|
136 |
+
'--force',
|
137 |
+
action='store_true',
|
138 |
+
help=
|
139 |
+
'Force conversion to FT even if some features may not work as expected in FT'
|
140 |
+
)
|
141 |
+
parser.add_argument('--weight_data_type',
|
142 |
+
type=str,
|
143 |
+
help='Data type of weights in the input checkpoint',
|
144 |
+
default='fp32',
|
145 |
+
choices=['fp32', 'fp16'])
|
146 |
+
|
147 |
+
args = parser.parse_args()
|
148 |
+
print('\n=============== Argument ===============')
|
149 |
+
for key in vars(args):
|
150 |
+
print('{}: {}'.format(key, vars(args)[key]))
|
151 |
+
print('========================================')
|
152 |
+
|
153 |
+
convert_mpt_to_ft(args.name_or_dir, args.save_dir, args.infer_gpu_num,
|
154 |
+
args.weight_data_type, args.force)
|
Perceptrix/finetune/build/lib/inference/convert_hf_to_onnx.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Basic HuggingFace -> ONNX export script.
|
5 |
+
|
6 |
+
This scripts show a basic HuggingFace -> ONNX export workflow. This works for a MPT model
|
7 |
+
that has been saved using `MPT.save_pretrained`. For more details and examples
|
8 |
+
of exporting and working with HuggingFace models with ONNX, see https://huggingface.co/docs/transformers/serialization#export-to-onnx.
|
9 |
+
|
10 |
+
Example usage:
|
11 |
+
|
12 |
+
1) Local export
|
13 |
+
|
14 |
+
python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder local/folder
|
15 |
+
|
16 |
+
2) Remote export
|
17 |
+
|
18 |
+
python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder s3://bucket/remote/folder
|
19 |
+
|
20 |
+
3) Verify the exported model
|
21 |
+
|
22 |
+
python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder local/folder --verify_export
|
23 |
+
|
24 |
+
4) Change the batch size or max sequence length
|
25 |
+
|
26 |
+
python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder local/folder --export_batch_size 1 --max_seq_len 32000
|
27 |
+
"""
|
28 |
+
|
29 |
+
import argparse
|
30 |
+
import os
|
31 |
+
from argparse import ArgumentTypeError
|
32 |
+
from pathlib import Path
|
33 |
+
from typing import Any, Dict, Optional, Union
|
34 |
+
|
35 |
+
import torch
|
36 |
+
from composer.utils import (maybe_create_object_store_from_uri, parse_uri,
|
37 |
+
reproducibility)
|
38 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
39 |
+
|
40 |
+
|
41 |
+
def str2bool(v: Union[str, bool]):
|
42 |
+
if isinstance(v, bool):
|
43 |
+
return v
|
44 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
45 |
+
return True
|
46 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
47 |
+
return False
|
48 |
+
else:
|
49 |
+
raise ArgumentTypeError('Boolean value expected.')
|
50 |
+
|
51 |
+
|
52 |
+
def str_or_bool(v: Union[str, bool]):
|
53 |
+
if isinstance(v, bool):
|
54 |
+
return v
|
55 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
56 |
+
return True
|
57 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
58 |
+
return False
|
59 |
+
else:
|
60 |
+
return v
|
61 |
+
|
62 |
+
|
63 |
+
def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int):
|
64 |
+
# generate input batch of random data
|
65 |
+
batch = {
|
66 |
+
'input_ids':
|
67 |
+
torch.randint(
|
68 |
+
low=0,
|
69 |
+
high=vocab_size,
|
70 |
+
size=(batch_size, max_seq_len),
|
71 |
+
dtype=torch.int64,
|
72 |
+
),
|
73 |
+
'attention_mask':
|
74 |
+
torch.ones(size=(batch_size, max_seq_len), dtype=torch.bool)
|
75 |
+
}
|
76 |
+
return batch
|
77 |
+
|
78 |
+
|
79 |
+
def export_to_onnx(
|
80 |
+
pretrained_model_name_or_path: str,
|
81 |
+
output_folder: str,
|
82 |
+
export_batch_size: int,
|
83 |
+
max_seq_len: Optional[int],
|
84 |
+
verify_export: bool,
|
85 |
+
from_pretrained_kwargs: Dict[str, Any],
|
86 |
+
):
|
87 |
+
reproducibility.seed_all(42)
|
88 |
+
save_object_store = maybe_create_object_store_from_uri(output_folder)
|
89 |
+
_, _, parsed_save_path = parse_uri(output_folder)
|
90 |
+
|
91 |
+
print('Loading HF config/model/tokenizer...')
|
92 |
+
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
|
93 |
+
**from_pretrained_kwargs)
|
94 |
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
|
95 |
+
**from_pretrained_kwargs)
|
96 |
+
|
97 |
+
# specifically for MPT, switch to the torch version of attention for ONNX export
|
98 |
+
if hasattr(config, 'attn_config'):
|
99 |
+
config.attn_config['attn_impl'] = 'torch'
|
100 |
+
|
101 |
+
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path,
|
102 |
+
config=config,
|
103 |
+
**from_pretrained_kwargs)
|
104 |
+
model.eval()
|
105 |
+
|
106 |
+
if max_seq_len is None and not hasattr(model.config, 'max_seq_len'):
|
107 |
+
raise ValueError(
|
108 |
+
'max_seq_len must be specified in either the model config or as an argument to this function.'
|
109 |
+
)
|
110 |
+
elif max_seq_len is None:
|
111 |
+
max_seq_len = model.config.max_seq_len
|
112 |
+
|
113 |
+
assert isinstance(max_seq_len, int) # pyright
|
114 |
+
|
115 |
+
print('Creating random batch...')
|
116 |
+
sample_input = gen_random_batch(
|
117 |
+
export_batch_size,
|
118 |
+
len(tokenizer),
|
119 |
+
max_seq_len,
|
120 |
+
)
|
121 |
+
|
122 |
+
with torch.no_grad():
|
123 |
+
model(**sample_input)
|
124 |
+
|
125 |
+
output_file = Path(parsed_save_path) / 'model.onnx'
|
126 |
+
os.makedirs(parsed_save_path, exist_ok=True)
|
127 |
+
print('Exporting the model with ONNX...')
|
128 |
+
torch.onnx.export(
|
129 |
+
model,
|
130 |
+
(sample_input,),
|
131 |
+
str(output_file),
|
132 |
+
input_names=['input_ids', 'attention_mask'],
|
133 |
+
output_names=['output'],
|
134 |
+
opset_version=16,
|
135 |
+
)
|
136 |
+
|
137 |
+
if verify_export:
|
138 |
+
with torch.no_grad():
|
139 |
+
orig_out = model(**sample_input)
|
140 |
+
|
141 |
+
import onnx
|
142 |
+
import onnx.checker
|
143 |
+
import onnxruntime as ort
|
144 |
+
|
145 |
+
_ = onnx.load(str(output_file))
|
146 |
+
|
147 |
+
onnx.checker.check_model(str(output_file))
|
148 |
+
|
149 |
+
ort_session = ort.InferenceSession(str(output_file))
|
150 |
+
|
151 |
+
for key, value in sample_input.items():
|
152 |
+
sample_input[key] = value.cpu().numpy()
|
153 |
+
|
154 |
+
loaded_model_out = ort_session.run(None, sample_input)
|
155 |
+
|
156 |
+
torch.testing.assert_close(
|
157 |
+
orig_out.logits.detach().numpy(),
|
158 |
+
loaded_model_out[0],
|
159 |
+
rtol=1e-2,
|
160 |
+
atol=1e-2,
|
161 |
+
msg=f'output mismatch between the orig and onnx exported model',
|
162 |
+
)
|
163 |
+
print('exported model ouptut matches with unexported model!!')
|
164 |
+
|
165 |
+
if save_object_store is not None:
|
166 |
+
print('Uploading files to object storage...')
|
167 |
+
for filename in os.listdir(parsed_save_path):
|
168 |
+
full_path = str(Path(parsed_save_path) / filename)
|
169 |
+
save_object_store.upload_object(full_path, full_path)
|
170 |
+
|
171 |
+
|
172 |
+
def parse_args():
|
173 |
+
parser = argparse.ArgumentParser(description='Convert HF model to ONNX',)
|
174 |
+
parser.add_argument(
|
175 |
+
'--pretrained_model_name_or_path',
|
176 |
+
type=str,
|
177 |
+
required=True,
|
178 |
+
)
|
179 |
+
parser.add_argument(
|
180 |
+
'--output_folder',
|
181 |
+
type=str,
|
182 |
+
required=True,
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
'--export_batch_size',
|
186 |
+
type=int,
|
187 |
+
default=8,
|
188 |
+
)
|
189 |
+
parser.add_argument(
|
190 |
+
'--max_seq_len',
|
191 |
+
type=int,
|
192 |
+
default=None,
|
193 |
+
)
|
194 |
+
parser.add_argument(
|
195 |
+
'--verify_export',
|
196 |
+
action='store_true',
|
197 |
+
)
|
198 |
+
parser.add_argument('--trust_remote_code',
|
199 |
+
type=str2bool,
|
200 |
+
nargs='?',
|
201 |
+
const=True,
|
202 |
+
default=True)
|
203 |
+
parser.add_argument('--use_auth_token',
|
204 |
+
type=str_or_bool,
|
205 |
+
nargs='?',
|
206 |
+
const=True,
|
207 |
+
default=None)
|
208 |
+
parser.add_argument('--revision', type=str, default=None)
|
209 |
+
return parser.parse_args()
|
210 |
+
|
211 |
+
|
212 |
+
def main(args: argparse.Namespace):
|
213 |
+
from_pretrained_kwargs = {
|
214 |
+
'use_auth_token': args.use_auth_token,
|
215 |
+
'trust_remote_code': args.trust_remote_code,
|
216 |
+
'revision': args.revision,
|
217 |
+
}
|
218 |
+
|
219 |
+
export_to_onnx(
|
220 |
+
pretrained_model_name_or_path=args.pretrained_model_name_or_path,
|
221 |
+
output_folder=args.output_folder,
|
222 |
+
export_batch_size=args.export_batch_size,
|
223 |
+
max_seq_len=args.max_seq_len,
|
224 |
+
verify_export=args.verify_export,
|
225 |
+
from_pretrained_kwargs=from_pretrained_kwargs)
|
226 |
+
|
227 |
+
|
228 |
+
if __name__ == '__main__':
|
229 |
+
main(parse_args())
|
Perceptrix/finetune/build/lib/inference/hf_chat.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import time
|
5 |
+
import warnings
|
6 |
+
from argparse import ArgumentParser, ArgumentTypeError, Namespace
|
7 |
+
from contextlib import nullcontext
|
8 |
+
from typing import Any, Dict, List, Optional, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
12 |
+
PreTrainedModel, PreTrainedTokenizerBase,
|
13 |
+
StoppingCriteria, StoppingCriteriaList, TextStreamer)
|
14 |
+
|
15 |
+
|
16 |
+
class ChatFormatter:
|
17 |
+
"""A class for formatting the chat history.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
system: The system prompt. If None, a default ChatML-formatted prompt is used.
|
21 |
+
user: The user prompt. If None, a default ChatML value is used.
|
22 |
+
assistant: The assistant prompt. If None, a default ChatML value is used.
|
23 |
+
|
24 |
+
Attributes:
|
25 |
+
system: The system prompt.
|
26 |
+
user: The user prompt.
|
27 |
+
assistant: The assistant prompt.
|
28 |
+
response_prefix: The response prefix (anything before {} in the assistant format string)
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, system: str, user: str, assistant: str) -> None:
|
32 |
+
self.system = system if system else '<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>\n'
|
33 |
+
self.user = user if user else '<|im_start|>user\n{}<|im_end|>\n'
|
34 |
+
self.assistant = assistant if assistant else '<|im_start|>assistant\n{}<|im_end|>\n'
|
35 |
+
self.response_prefix = self.assistant.split('{}')[0]
|
36 |
+
|
37 |
+
|
38 |
+
class Conversation:
|
39 |
+
"""A class for interacting with a chat-tuned LLM.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
model: The model to use for inference.
|
43 |
+
tokenizer: The tokenizer to use for inference.
|
44 |
+
chat_format: The chat format to use for the conversation.
|
45 |
+
generate_kwargs: The keyword arguments to pass to `model.generate`.
|
46 |
+
stop_tokens: The tokens to stop generation on.
|
47 |
+
|
48 |
+
Attributes:
|
49 |
+
model: The model to use for inference.
|
50 |
+
tokenizer: The tokenizer to use for inference.
|
51 |
+
chat_format: The chat format to use for the conversation.
|
52 |
+
streamer: The streamer to use for inference.
|
53 |
+
generate_kwargs: The keyword arguments to pass to `model.generate`.
|
54 |
+
history: The conversation history.
|
55 |
+
cli_instructions: The instructions to display to the user.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self,
|
59 |
+
model: PreTrainedModel,
|
60 |
+
tokenizer: PreTrainedTokenizerBase,
|
61 |
+
chat_format: ChatFormatter,
|
62 |
+
generate_kwargs: Dict[str, Any],
|
63 |
+
stop_tokens: Optional[List[str]] = None) -> None:
|
64 |
+
if stop_tokens is None:
|
65 |
+
stop_tokens = ['<|endoftext|>', '<|im_end|>']
|
66 |
+
self.model = model
|
67 |
+
self.tokenizer = tokenizer
|
68 |
+
self.chat_format = chat_format
|
69 |
+
|
70 |
+
stop_token_ids = self.tokenizer.convert_tokens_to_ids(stop_tokens)
|
71 |
+
if len(stop_token_ids) != len(stop_tokens):
|
72 |
+
warnings.warn(
|
73 |
+
f'Not all stop tokens were found in the tokenizer vocabulary: {stop_tokens}\n'
|
74 |
+
+ 'Generation may stop or continue unexpectedly.')
|
75 |
+
|
76 |
+
class StopOnTokens(StoppingCriteria):
|
77 |
+
|
78 |
+
def __call__(self, input_ids: torch.LongTensor,
|
79 |
+
scores: torch.FloatTensor, **kwargs: Any) -> bool:
|
80 |
+
del kwargs # unused
|
81 |
+
for stop_id in stop_token_ids:
|
82 |
+
if input_ids[0][-1] == stop_id:
|
83 |
+
return True
|
84 |
+
return False
|
85 |
+
|
86 |
+
self.streamer = TextStreamer(tokenizer,
|
87 |
+
skip_prompt=True,
|
88 |
+
skip_special_tokens=True)
|
89 |
+
self.generate_kwargs = {
|
90 |
+
**generate_kwargs,
|
91 |
+
'stopping_criteria':
|
92 |
+
StoppingCriteriaList([StopOnTokens()]),
|
93 |
+
'streamer':
|
94 |
+
self.streamer,
|
95 |
+
}
|
96 |
+
self.history = []
|
97 |
+
self.cli_instructions = (
|
98 |
+
'Enter your message below.\n- Hit return twice to send input to the model\n'
|
99 |
+
+
|
100 |
+
"- Type 'clear' to restart the conversation\n- Type 'history' to see the conversation\n"
|
101 |
+
+
|
102 |
+
"- Type 'quit' to end\n- Type 'system' to change the system prompt\n"
|
103 |
+
)
|
104 |
+
|
105 |
+
def _history_as_formatted_str(self) -> str:
|
106 |
+
text = self.chat_format.system + ''.join([
|
107 |
+
'\n'.join([
|
108 |
+
self.chat_format.user.format(item[0]),
|
109 |
+
self.chat_format.assistant.format(item[1]),
|
110 |
+
]) for item in self.history[:-1]
|
111 |
+
])
|
112 |
+
text += self.chat_format.user.format(self.history[-1][0])
|
113 |
+
text += self.chat_format.response_prefix
|
114 |
+
return text
|
115 |
+
|
116 |
+
def turn(self, user_inp: str) -> None:
|
117 |
+
self.history.append([user_inp, ''])
|
118 |
+
conversation = self._history_as_formatted_str()
|
119 |
+
input_ids = self.tokenizer(conversation, return_tensors='pt').input_ids
|
120 |
+
input_ids = input_ids.to(self.model.device)
|
121 |
+
# also stream to stdout
|
122 |
+
maybe_synchronize()
|
123 |
+
start = time.time()
|
124 |
+
print('Assistant:')
|
125 |
+
gkwargs = {**self.generate_kwargs, 'input_ids': input_ids}
|
126 |
+
# this will stream to stdout, but we need to keep track of the output_ids for saving history
|
127 |
+
output_ids = self.model.generate(**gkwargs)
|
128 |
+
maybe_synchronize()
|
129 |
+
end = time.time()
|
130 |
+
print(f'took {end - start:.2f} seconds')
|
131 |
+
new_tokens = output_ids[0, len(input_ids[0]):]
|
132 |
+
assistant_response = self.tokenizer.decode(new_tokens,
|
133 |
+
skip_special_tokens=True)
|
134 |
+
self.history[-1][-1] = assistant_response
|
135 |
+
|
136 |
+
def __call__(self) -> None:
|
137 |
+
print(self.cli_instructions)
|
138 |
+
while True:
|
139 |
+
print('User:')
|
140 |
+
user_inp_lines = []
|
141 |
+
while True:
|
142 |
+
line = input()
|
143 |
+
if line.strip() == '':
|
144 |
+
break
|
145 |
+
user_inp_lines.append(line)
|
146 |
+
user_inp = '\n'.join(user_inp_lines)
|
147 |
+
if user_inp.lower() == 'quit':
|
148 |
+
break
|
149 |
+
elif user_inp.lower() == 'clear':
|
150 |
+
self.history = []
|
151 |
+
continue
|
152 |
+
elif user_inp == 'history':
|
153 |
+
print(f'history: {self.history}')
|
154 |
+
continue
|
155 |
+
elif user_inp == 'history_fmt':
|
156 |
+
print(f'history: {self._history_as_formatted_str()}')
|
157 |
+
continue
|
158 |
+
elif user_inp == 'system':
|
159 |
+
print('Enter a new system prompt:')
|
160 |
+
new_system = input()
|
161 |
+
sys = f'<|im_start|>system\n{new_system.strip()}.<|im_end|>\n'
|
162 |
+
self.chat_format.system = sys
|
163 |
+
continue
|
164 |
+
self.turn(user_inp)
|
165 |
+
|
166 |
+
|
167 |
+
def get_dtype(dtype: str):
|
168 |
+
if dtype == 'fp32':
|
169 |
+
return torch.float32
|
170 |
+
elif dtype == 'fp16':
|
171 |
+
return torch.float16
|
172 |
+
elif dtype == 'bf16':
|
173 |
+
return torch.bfloat16
|
174 |
+
else:
|
175 |
+
raise NotImplementedError(
|
176 |
+
f'dtype {dtype} is not supported. ' +
|
177 |
+
'We only support fp32, fp16, and bf16 currently')
|
178 |
+
|
179 |
+
|
180 |
+
def str2bool(v: Union[str, bool]):
|
181 |
+
if isinstance(v, bool):
|
182 |
+
return v
|
183 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
184 |
+
return True
|
185 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
186 |
+
return False
|
187 |
+
else:
|
188 |
+
raise ArgumentTypeError('Boolean value expected.')
|
189 |
+
|
190 |
+
|
191 |
+
def str_or_bool(v: Union[str, bool]):
|
192 |
+
if isinstance(v, bool):
|
193 |
+
return v
|
194 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
195 |
+
return True
|
196 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
197 |
+
return False
|
198 |
+
else:
|
199 |
+
return v
|
200 |
+
|
201 |
+
|
202 |
+
def parse_args() -> Namespace:
|
203 |
+
"""Parse commandline arguments."""
|
204 |
+
parser = ArgumentParser(
|
205 |
+
description='Load a HF CausalLM Model and use it to generate text.')
|
206 |
+
parser.add_argument('-n', '--name_or_path', type=str, required=True)
|
207 |
+
parser.add_argument('--max_new_tokens', type=int, default=512)
|
208 |
+
parser.add_argument('--max_seq_len', type=int, default=None)
|
209 |
+
parser.add_argument('--temperature', type=float, default=1.0)
|
210 |
+
parser.add_argument('--top_k', type=int, default=50)
|
211 |
+
parser.add_argument('--top_p', type=float, default=1.0)
|
212 |
+
parser.add_argument('--do_sample',
|
213 |
+
type=str2bool,
|
214 |
+
nargs='?',
|
215 |
+
const=True,
|
216 |
+
default=True)
|
217 |
+
parser.add_argument('--use_cache',
|
218 |
+
type=str2bool,
|
219 |
+
nargs='?',
|
220 |
+
const=True,
|
221 |
+
default=True)
|
222 |
+
parser.add_argument('--eos_token_id', type=str, default=None)
|
223 |
+
parser.add_argument('--pad_token_id', type=str, default=None)
|
224 |
+
parser.add_argument('--model_dtype',
|
225 |
+
type=str,
|
226 |
+
choices=['fp32', 'fp16', 'bf16'],
|
227 |
+
default=None)
|
228 |
+
parser.add_argument('--autocast_dtype',
|
229 |
+
type=str,
|
230 |
+
choices=['fp32', 'fp16', 'bf16'],
|
231 |
+
default=None)
|
232 |
+
parser.add_argument('--warmup',
|
233 |
+
type=str2bool,
|
234 |
+
nargs='?',
|
235 |
+
const=True,
|
236 |
+
default=True)
|
237 |
+
parser.add_argument('--trust_remote_code',
|
238 |
+
type=str2bool,
|
239 |
+
nargs='?',
|
240 |
+
const=True,
|
241 |
+
default=True)
|
242 |
+
parser.add_argument('--use_auth_token',
|
243 |
+
type=str_or_bool,
|
244 |
+
nargs='?',
|
245 |
+
const=True,
|
246 |
+
default=None)
|
247 |
+
parser.add_argument('--revision', type=str, default=None)
|
248 |
+
parser.add_argument('--device', type=str, default=None)
|
249 |
+
parser.add_argument('--device_map', type=str, default=None)
|
250 |
+
parser.add_argument('--attn_impl', type=str, default=None)
|
251 |
+
parser.add_argument('--seed', type=int, default=42)
|
252 |
+
parser.add_argument('--system_prompt', type=str, default=None)
|
253 |
+
parser.add_argument('--user_msg_fmt', type=str, default=None)
|
254 |
+
parser.add_argument('--assistant_msg_fmt', type=str, default=None)
|
255 |
+
parser.add_argument(
|
256 |
+
'--stop_tokens',
|
257 |
+
type=str,
|
258 |
+
default='<|endoftext|> <|im_end|>',
|
259 |
+
help='A string of tokens to stop generation on; will be split on spaces.'
|
260 |
+
)
|
261 |
+
return parser.parse_args()
|
262 |
+
|
263 |
+
|
264 |
+
def maybe_synchronize():
|
265 |
+
if torch.cuda.is_available():
|
266 |
+
torch.cuda.synchronize()
|
267 |
+
|
268 |
+
|
269 |
+
def main(args: Namespace) -> None:
|
270 |
+
# Set device or device_map
|
271 |
+
if args.device and args.device_map:
|
272 |
+
raise ValueError('You can only set one of `device` and `device_map`.')
|
273 |
+
if args.device is not None:
|
274 |
+
device = args.device
|
275 |
+
device_map = None
|
276 |
+
else:
|
277 |
+
device = None
|
278 |
+
device_map = args.device_map or 'auto'
|
279 |
+
print(f'Using {device=} and {device_map=}')
|
280 |
+
|
281 |
+
# Set model_dtype
|
282 |
+
if args.model_dtype is not None:
|
283 |
+
model_dtype = get_dtype(args.model_dtype)
|
284 |
+
else:
|
285 |
+
model_dtype = torch.float32
|
286 |
+
print(f'Using {model_dtype=}')
|
287 |
+
|
288 |
+
# Grab config first
|
289 |
+
print(f'Loading HF Config...')
|
290 |
+
from_pretrained_kwargs = {
|
291 |
+
'use_auth_token': args.use_auth_token,
|
292 |
+
'trust_remote_code': args.trust_remote_code,
|
293 |
+
'revision': args.revision,
|
294 |
+
}
|
295 |
+
try:
|
296 |
+
config = AutoConfig.from_pretrained(args.name_or_path,
|
297 |
+
**from_pretrained_kwargs)
|
298 |
+
if args.attn_impl is not None and hasattr(config, 'attn_config'):
|
299 |
+
config.attn_config['attn_impl'] = args.attn_impl
|
300 |
+
if hasattr(config, 'init_device') and device is not None:
|
301 |
+
config.init_device = device
|
302 |
+
if args.max_seq_len is not None and hasattr(config, 'max_seq_len'):
|
303 |
+
config.max_seq_len = args.max_seq_len
|
304 |
+
|
305 |
+
except Exception as e:
|
306 |
+
raise RuntimeError(
|
307 |
+
'If you are having auth problems, try logging in via `huggingface-cli login` '
|
308 |
+
+
|
309 |
+
'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... '
|
310 |
+
+
|
311 |
+
'using your access token from https://huggingface.co/settings/tokens.'
|
312 |
+
) from e
|
313 |
+
|
314 |
+
# Load HF Model
|
315 |
+
print(f'Loading HF model with dtype={model_dtype}...')
|
316 |
+
try:
|
317 |
+
model = AutoModelForCausalLM.from_pretrained(args.name_or_path,
|
318 |
+
config=config,
|
319 |
+
torch_dtype=model_dtype,
|
320 |
+
device_map=device_map,
|
321 |
+
**from_pretrained_kwargs)
|
322 |
+
model.eval()
|
323 |
+
print(f'n_params={sum(p.numel() for p in model.parameters())}')
|
324 |
+
if device is not None:
|
325 |
+
print(f'Placing model on {device=}...')
|
326 |
+
model.to(device)
|
327 |
+
except Exception as e:
|
328 |
+
raise RuntimeError(
|
329 |
+
'Unable to load HF model. ' +
|
330 |
+
'If you are having auth problems, try logging in via `huggingface-cli login` '
|
331 |
+
+
|
332 |
+
'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... '
|
333 |
+
+
|
334 |
+
'using your access token from https://huggingface.co/settings/tokens.'
|
335 |
+
) from e
|
336 |
+
|
337 |
+
print('\nLoading HF tokenizer...')
|
338 |
+
tokenizer = AutoTokenizer.from_pretrained(args.name_or_path,
|
339 |
+
**from_pretrained_kwargs)
|
340 |
+
if tokenizer.pad_token_id is None:
|
341 |
+
warnings.warn(
|
342 |
+
'pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.'
|
343 |
+
)
|
344 |
+
tokenizer.pad_token = tokenizer.eos_token
|
345 |
+
tokenizer.padding_side = 'left'
|
346 |
+
|
347 |
+
generate_kwargs = {
|
348 |
+
'max_new_tokens': args.max_new_tokens,
|
349 |
+
'temperature': args.temperature,
|
350 |
+
'top_p': args.top_p,
|
351 |
+
'top_k': args.top_k,
|
352 |
+
'use_cache': args.use_cache,
|
353 |
+
'do_sample': args.do_sample,
|
354 |
+
'eos_token_id': args.eos_token_id or tokenizer.eos_token_id,
|
355 |
+
'pad_token_id': args.pad_token_id or tokenizer.eos_token_id,
|
356 |
+
}
|
357 |
+
# Autocast
|
358 |
+
if args.autocast_dtype is not None:
|
359 |
+
autocast_dtype = get_dtype(args.autocast_dtype)
|
360 |
+
autocast_context = torch.autocast(model.device.type, autocast_dtype)
|
361 |
+
print(f'Using autocast with dtype={autocast_dtype}...')
|
362 |
+
else:
|
363 |
+
autocast_context = nullcontext()
|
364 |
+
print('NOT using autocast...')
|
365 |
+
|
366 |
+
chat_format = ChatFormatter(system=args.system_prompt,
|
367 |
+
user=args.user_msg_fmt,
|
368 |
+
assistant=args.assistant_msg_fmt)
|
369 |
+
|
370 |
+
conversation = Conversation(model=model,
|
371 |
+
tokenizer=tokenizer,
|
372 |
+
chat_format=chat_format,
|
373 |
+
generate_kwargs=generate_kwargs,
|
374 |
+
stop_tokens=args.stop_tokens.split())
|
375 |
+
|
376 |
+
# Warmup
|
377 |
+
if args.warmup:
|
378 |
+
print('Warming up...')
|
379 |
+
with autocast_context:
|
380 |
+
conversation.turn('Write a welcome message to the user.')
|
381 |
+
conversation.history = []
|
382 |
+
|
383 |
+
print('Starting conversation...')
|
384 |
+
with autocast_context:
|
385 |
+
conversation()
|
386 |
+
|
387 |
+
|
388 |
+
if __name__ == '__main__':
|
389 |
+
main(parse_args())
|
Perceptrix/finetune/build/lib/inference/hf_generate.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
import itertools
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
import warnings
|
8 |
+
from argparse import ArgumentParser, ArgumentTypeError, Namespace
|
9 |
+
from contextlib import nullcontext
|
10 |
+
from typing import Dict, Union
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
15 |
+
|
16 |
+
|
17 |
+
def get_dtype(dtype: str):
|
18 |
+
if dtype == 'fp32':
|
19 |
+
return torch.float32
|
20 |
+
elif dtype == 'fp16':
|
21 |
+
return torch.float16
|
22 |
+
elif dtype == 'bf16':
|
23 |
+
return torch.bfloat16
|
24 |
+
else:
|
25 |
+
raise NotImplementedError(
|
26 |
+
f'dtype {dtype} is not supported. ' +\
|
27 |
+
f'We only support fp32, fp16, and bf16 currently')
|
28 |
+
|
29 |
+
|
30 |
+
def str2bool(v: Union[str, bool]):
|
31 |
+
if isinstance(v, bool):
|
32 |
+
return v
|
33 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
34 |
+
return True
|
35 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
36 |
+
return False
|
37 |
+
else:
|
38 |
+
raise ArgumentTypeError('Boolean value expected.')
|
39 |
+
|
40 |
+
|
41 |
+
def str_or_bool(v: Union[str, bool]):
|
42 |
+
if isinstance(v, bool):
|
43 |
+
return v
|
44 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
45 |
+
return True
|
46 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
47 |
+
return False
|
48 |
+
else:
|
49 |
+
return v
|
50 |
+
|
51 |
+
|
52 |
+
def parse_args() -> Namespace:
|
53 |
+
"""Parse commandline arguments."""
|
54 |
+
parser = ArgumentParser(
|
55 |
+
description='Load a HF CausalLM Model and use it to generate text.')
|
56 |
+
parser.add_argument('-n', '--name_or_path', type=str, required=True)
|
57 |
+
parser.add_argument(
|
58 |
+
'-p',
|
59 |
+
'--prompts',
|
60 |
+
nargs='+',
|
61 |
+
default=[
|
62 |
+
'My name is',
|
63 |
+
'This is an explanation of deep learning to a five year old. Deep learning is',
|
64 |
+
],
|
65 |
+
help='Generation prompts. Use syntax "file::/path/to/prompt.txt" to load a ' +\
|
66 |
+
'prompt contained in a txt file.'
|
67 |
+
)
|
68 |
+
parser.add_argument('--max_seq_len', type=int, default=None)
|
69 |
+
parser.add_argument('--max_new_tokens', type=int, default=100)
|
70 |
+
parser.add_argument('--max_batch_size', type=int, default=None)
|
71 |
+
#####
|
72 |
+
# Note: Generation config defaults are set to match Hugging Face defaults
|
73 |
+
parser.add_argument('--temperature', type=float, nargs='+', default=[1.0])
|
74 |
+
parser.add_argument('--top_k', type=int, nargs='+', default=[50])
|
75 |
+
parser.add_argument('--top_p', type=float, nargs='+', default=[1.0])
|
76 |
+
parser.add_argument('--repetition_penalty',
|
77 |
+
type=float,
|
78 |
+
nargs='+',
|
79 |
+
default=[1.0])
|
80 |
+
parser.add_argument('--no_repeat_ngram_size',
|
81 |
+
type=int,
|
82 |
+
nargs='+',
|
83 |
+
default=[0])
|
84 |
+
#####
|
85 |
+
parser.add_argument('--seed', type=int, nargs='+', default=[42])
|
86 |
+
parser.add_argument('--do_sample',
|
87 |
+
type=str2bool,
|
88 |
+
nargs='?',
|
89 |
+
const=True,
|
90 |
+
default=True)
|
91 |
+
parser.add_argument('--use_cache',
|
92 |
+
type=str2bool,
|
93 |
+
nargs='?',
|
94 |
+
const=True,
|
95 |
+
default=True)
|
96 |
+
parser.add_argument('--eos_token_id', type=int, default=None)
|
97 |
+
parser.add_argument('--pad_token_id', type=int, default=None)
|
98 |
+
parser.add_argument('--model_dtype',
|
99 |
+
type=str,
|
100 |
+
choices=['fp32', 'fp16', 'bf16'],
|
101 |
+
default=None)
|
102 |
+
parser.add_argument('--autocast_dtype',
|
103 |
+
type=str,
|
104 |
+
choices=['fp32', 'fp16', 'bf16'],
|
105 |
+
default=None)
|
106 |
+
parser.add_argument('--warmup',
|
107 |
+
type=str2bool,
|
108 |
+
nargs='?',
|
109 |
+
const=True,
|
110 |
+
default=True)
|
111 |
+
parser.add_argument('--trust_remote_code',
|
112 |
+
type=str2bool,
|
113 |
+
nargs='?',
|
114 |
+
const=True,
|
115 |
+
default=True)
|
116 |
+
parser.add_argument('--use_auth_token',
|
117 |
+
type=str_or_bool,
|
118 |
+
nargs='?',
|
119 |
+
const=True,
|
120 |
+
default=None)
|
121 |
+
parser.add_argument('--revision', type=str, default=None)
|
122 |
+
parser.add_argument('--device', type=str, default=None)
|
123 |
+
parser.add_argument('--device_map', type=str, default=None)
|
124 |
+
parser.add_argument('--attn_impl', type=str, default=None)
|
125 |
+
return parser.parse_args()
|
126 |
+
|
127 |
+
|
128 |
+
def load_prompt_string_from_file(prompt_path_str: str):
|
129 |
+
if not prompt_path_str.startswith('file::'):
|
130 |
+
raise ValueError('prompt_path_str must start with "file::".')
|
131 |
+
_, prompt_file_path = prompt_path_str.split('file::', maxsplit=1)
|
132 |
+
prompt_file_path = os.path.expanduser(prompt_file_path)
|
133 |
+
if not os.path.isfile(prompt_file_path):
|
134 |
+
raise FileNotFoundError(
|
135 |
+
f'{prompt_file_path=} does not match any existing files.')
|
136 |
+
with open(prompt_file_path, 'r') as f:
|
137 |
+
prompt_string = ''.join(f.readlines())
|
138 |
+
return prompt_string
|
139 |
+
|
140 |
+
|
141 |
+
def maybe_synchronize():
|
142 |
+
if torch.cuda.is_available():
|
143 |
+
torch.cuda.synchronize()
|
144 |
+
|
145 |
+
|
146 |
+
def main(args: Namespace) -> None:
|
147 |
+
# Set device or device_map
|
148 |
+
if args.device and args.device_map:
|
149 |
+
raise ValueError('You can only set one of `device` and `device_map`.')
|
150 |
+
if args.device is not None:
|
151 |
+
device = args.device
|
152 |
+
device_map = None
|
153 |
+
else:
|
154 |
+
device = None
|
155 |
+
device_map = args.device_map or 'auto'
|
156 |
+
print(f'Using {device=} and {device_map=}')
|
157 |
+
|
158 |
+
# Set model_dtype
|
159 |
+
if args.model_dtype is not None:
|
160 |
+
model_dtype = get_dtype(args.model_dtype)
|
161 |
+
else:
|
162 |
+
model_dtype = torch.float32
|
163 |
+
print(f'Using {model_dtype=}')
|
164 |
+
|
165 |
+
# Load prompts
|
166 |
+
prompt_strings = []
|
167 |
+
for prompt in args.prompts:
|
168 |
+
if prompt.startswith('file::'):
|
169 |
+
prompt = load_prompt_string_from_file(prompt)
|
170 |
+
prompt_strings.append(prompt)
|
171 |
+
|
172 |
+
# Grab config first
|
173 |
+
print(f'Loading HF Config...')
|
174 |
+
from_pretrained_kwargs = {
|
175 |
+
'use_auth_token': args.use_auth_token,
|
176 |
+
'trust_remote_code': args.trust_remote_code,
|
177 |
+
'revision': args.revision,
|
178 |
+
}
|
179 |
+
try:
|
180 |
+
config = AutoConfig.from_pretrained(args.name_or_path,
|
181 |
+
**from_pretrained_kwargs)
|
182 |
+
if hasattr(config, 'init_device') and device is not None:
|
183 |
+
config.init_device = device
|
184 |
+
if args.attn_impl is not None and hasattr(config, 'attn_config'):
|
185 |
+
config.attn_config['attn_impl'] = args.attn_impl
|
186 |
+
if args.max_seq_len is not None and hasattr(config, 'max_seq_len'):
|
187 |
+
config.max_seq_len = args.max_seq_len
|
188 |
+
|
189 |
+
except Exception as e:
|
190 |
+
raise RuntimeError(
|
191 |
+
'If you are having auth problems, try logging in via `huggingface-cli login` ' +\
|
192 |
+
'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' +\
|
193 |
+
'using your access token from https://huggingface.co/settings/tokens.'
|
194 |
+
) from e
|
195 |
+
|
196 |
+
# Build tokenizer
|
197 |
+
print('\nLoading HF tokenizer...')
|
198 |
+
tokenizer = AutoTokenizer.from_pretrained(args.name_or_path,
|
199 |
+
**from_pretrained_kwargs)
|
200 |
+
if tokenizer.pad_token_id is None:
|
201 |
+
warnings.warn(
|
202 |
+
'pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.'
|
203 |
+
)
|
204 |
+
tokenizer.pad_token = tokenizer.eos_token
|
205 |
+
tokenizer.padding_side = 'left'
|
206 |
+
|
207 |
+
# Load HF Model
|
208 |
+
print(f'Loading HF model with dtype={model_dtype}...')
|
209 |
+
try:
|
210 |
+
model = AutoModelForCausalLM.from_pretrained(args.name_or_path,
|
211 |
+
config=config,
|
212 |
+
torch_dtype=model_dtype,
|
213 |
+
device_map=device_map,
|
214 |
+
**from_pretrained_kwargs)
|
215 |
+
model.eval()
|
216 |
+
print(f'n_params={sum(p.numel() for p in model.parameters())}')
|
217 |
+
if device is not None:
|
218 |
+
print(f'Placing model on {device=}...')
|
219 |
+
model.to(device)
|
220 |
+
except Exception as e:
|
221 |
+
raise RuntimeError(
|
222 |
+
'Unable to load HF model. ' +
|
223 |
+
'If you are having auth problems, try logging in via `huggingface-cli login` '
|
224 |
+
+
|
225 |
+
'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... '
|
226 |
+
+
|
227 |
+
'using your access token from https://huggingface.co/settings/tokens.'
|
228 |
+
) from e
|
229 |
+
|
230 |
+
# Autocast
|
231 |
+
if args.autocast_dtype is not None:
|
232 |
+
autocast_dtype = get_dtype(args.autocast_dtype)
|
233 |
+
autocast_context = torch.autocast(model.device.type, autocast_dtype)
|
234 |
+
print(f'Using autocast with dtype={autocast_dtype}...')
|
235 |
+
else:
|
236 |
+
autocast_context = nullcontext()
|
237 |
+
print('NOT using autocast...')
|
238 |
+
|
239 |
+
done_warmup = False
|
240 |
+
|
241 |
+
for temp, topp, topk, repp, nrnz, seed in itertools.product(
|
242 |
+
args.temperature, args.top_p, args.top_k, args.repetition_penalty,
|
243 |
+
args.no_repeat_ngram_size, args.seed):
|
244 |
+
|
245 |
+
# Seed randomness
|
246 |
+
random.seed(seed)
|
247 |
+
torch.manual_seed(seed)
|
248 |
+
print(f'\nGenerate seed:\n{seed}')
|
249 |
+
|
250 |
+
generate_kwargs = {
|
251 |
+
'max_new_tokens': args.max_new_tokens,
|
252 |
+
'temperature': temp,
|
253 |
+
'top_p': topp,
|
254 |
+
'top_k': topk,
|
255 |
+
'repetition_penalty': repp,
|
256 |
+
'no_repeat_ngram_size': nrnz,
|
257 |
+
'use_cache': args.use_cache,
|
258 |
+
'do_sample': False if temp == 0 else args.do_sample,
|
259 |
+
'eos_token_id': args.eos_token_id or tokenizer.eos_token_id,
|
260 |
+
'pad_token_id': args.pad_token_id or tokenizer.pad_token_id,
|
261 |
+
}
|
262 |
+
print(f'\nGenerate kwargs:\n{generate_kwargs}')
|
263 |
+
|
264 |
+
# Generate function with correct context managers
|
265 |
+
def _generate(encoded_inp: Dict[str, torch.Tensor]):
|
266 |
+
with torch.no_grad():
|
267 |
+
with autocast_context:
|
268 |
+
return model.generate(
|
269 |
+
input_ids=encoded_inp['input_ids'],
|
270 |
+
attention_mask=encoded_inp['attention_mask'],
|
271 |
+
**generate_kwargs,
|
272 |
+
)
|
273 |
+
|
274 |
+
# Split into prompt batches
|
275 |
+
batches = []
|
276 |
+
if args.max_batch_size:
|
277 |
+
bs = args.max_batch_size
|
278 |
+
batches = [
|
279 |
+
prompt_strings[i:i + bs]
|
280 |
+
for i in range(0, len(prompt_strings), bs)
|
281 |
+
]
|
282 |
+
|
283 |
+
else:
|
284 |
+
batches = [prompt_strings]
|
285 |
+
|
286 |
+
for batch in batches:
|
287 |
+
print(f'\nTokenizing prompts...')
|
288 |
+
maybe_synchronize()
|
289 |
+
encode_start = time.time()
|
290 |
+
encoded_inp = tokenizer(batch, return_tensors='pt', padding=True)
|
291 |
+
for key, value in encoded_inp.items():
|
292 |
+
encoded_inp[key] = value.to(model.device)
|
293 |
+
maybe_synchronize()
|
294 |
+
encode_end = time.time()
|
295 |
+
input_tokens = torch.sum(
|
296 |
+
encoded_inp['input_ids'] !=
|
297 |
+
tokenizer.pad_token_id, # type: ignore
|
298 |
+
axis=1).numpy(force=True)
|
299 |
+
|
300 |
+
# Warmup
|
301 |
+
if args.warmup and (not done_warmup):
|
302 |
+
print('Warming up...')
|
303 |
+
_ = _generate(encoded_inp)
|
304 |
+
done_warmup = True
|
305 |
+
|
306 |
+
# Run HF generate
|
307 |
+
print('Generating responses...')
|
308 |
+
maybe_synchronize()
|
309 |
+
gen_start = time.time()
|
310 |
+
encoded_gen = _generate(encoded_inp)
|
311 |
+
maybe_synchronize()
|
312 |
+
gen_end = time.time()
|
313 |
+
|
314 |
+
decode_start = time.time()
|
315 |
+
decoded_gen = tokenizer.batch_decode(encoded_gen,
|
316 |
+
skip_special_tokens=True)
|
317 |
+
maybe_synchronize()
|
318 |
+
decode_end = time.time()
|
319 |
+
gen_tokens = torch.sum(encoded_gen != tokenizer.pad_token_id,
|
320 |
+
axis=1).numpy(force=True) # type: ignore
|
321 |
+
|
322 |
+
# Print generations
|
323 |
+
delimiter = '#' * 100
|
324 |
+
# decode the encoded prompt to handle the case when the tokenizer
|
325 |
+
# trims extra spaces or does other pre-tokenization things
|
326 |
+
effective_prompts = tokenizer.batch_decode(encoded_inp['input_ids'],
|
327 |
+
skip_special_tokens=True)
|
328 |
+
for idx, (effective_prompt, prompt, gen) in enumerate(
|
329 |
+
zip(effective_prompts, batch, decoded_gen)):
|
330 |
+
continuation = gen[len(effective_prompt):]
|
331 |
+
print(delimiter)
|
332 |
+
if len(continuation) > 0:
|
333 |
+
print('\033[92m' + prompt + '\033[0m' + continuation)
|
334 |
+
else:
|
335 |
+
print('Warning. No non-special output tokens generated.')
|
336 |
+
print(
|
337 |
+
'This can happen if the generation only contains padding/eos tokens.'
|
338 |
+
)
|
339 |
+
print('Debug:')
|
340 |
+
full_generation = tokenizer.batch_decode(
|
341 |
+
encoded_gen, skip_special_tokens=False)[idx]
|
342 |
+
print('\033[92m' + 'Prompt:\n' + prompt + '\033[0m')
|
343 |
+
print('Full generation:\n' + full_generation)
|
344 |
+
|
345 |
+
print(delimiter)
|
346 |
+
|
347 |
+
# Print timing info
|
348 |
+
bs = len(batch)
|
349 |
+
# ensure that gen_tokens >= 1 in case model only generated padding tokens
|
350 |
+
gen_tokens = np.maximum(gen_tokens, np.ones_like(gen_tokens))
|
351 |
+
output_tokens = gen_tokens - input_tokens
|
352 |
+
total_input_tokens = input_tokens.sum()
|
353 |
+
total_output_tokens = output_tokens.sum()
|
354 |
+
|
355 |
+
encode_latency = 1000 * (encode_end - encode_start)
|
356 |
+
gen_latency = 1000 * (gen_end - gen_start)
|
357 |
+
decode_latency = 1000 * (decode_end - decode_start)
|
358 |
+
total_latency = encode_latency + gen_latency + decode_latency
|
359 |
+
|
360 |
+
latency_per_output_token = total_latency / total_output_tokens
|
361 |
+
output_tok_per_sec = 1000 / latency_per_output_token
|
362 |
+
print(f'{bs=}, {input_tokens=}, {output_tokens=}')
|
363 |
+
print(f'{total_input_tokens=}, {total_output_tokens=}')
|
364 |
+
print(
|
365 |
+
f'{encode_latency=:.2f}ms, {gen_latency=:.2f}ms, {decode_latency=:.2f}ms, {total_latency=:.2f}ms'
|
366 |
+
)
|
367 |
+
print(f'{latency_per_output_token=:.2f}ms/tok')
|
368 |
+
print(f'{output_tok_per_sec=:.2f}tok/sec')
|
369 |
+
|
370 |
+
|
371 |
+
if __name__ == '__main__':
|
372 |
+
main(parse_args())
|
Perceptrix/finetune/build/lib/inference/run_mpt_with_ft.py
ADDED
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
|
5 |
+
# Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
6 |
+
#
|
7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8 |
+
# you may not use this file except in compliance with the License.
|
9 |
+
# You may obtain a copy of the License at
|
10 |
+
#
|
11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12 |
+
#
|
13 |
+
# Unless required by applicable law or agreed to in writing, software
|
14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16 |
+
# See the License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
|
19 |
+
"""Run MPT model with FT.
|
20 |
+
|
21 |
+
This script is a modified version of
|
22 |
+
https://github.com/NVIDIA/FasterTransformer/blob/main/examples/pytorch/gpt/multi_gpu_gpt_example.py
|
23 |
+
"""
|
24 |
+
|
25 |
+
import argparse
|
26 |
+
import configparser
|
27 |
+
import os
|
28 |
+
import sys
|
29 |
+
import timeit
|
30 |
+
|
31 |
+
import torch
|
32 |
+
from torch.nn.utils.rnn import pad_sequence
|
33 |
+
from transformers import AutoTokenizer
|
34 |
+
|
35 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
36 |
+
sys.path.append(os.path.join(dir_path, '../../..'))
|
37 |
+
from examples.pytorch.gpt.utils import comm, gpt_decoder
|
38 |
+
from examples.pytorch.gpt.utils.parallel_gpt import ParallelGPT
|
39 |
+
|
40 |
+
|
41 |
+
@torch.no_grad()
|
42 |
+
def main():
|
43 |
+
parser = argparse.ArgumentParser()
|
44 |
+
parser.add_argument('--layer_num',
|
45 |
+
type=int,
|
46 |
+
default=32,
|
47 |
+
help='number of layers')
|
48 |
+
parser.add_argument('--input_len',
|
49 |
+
type=int,
|
50 |
+
default=128,
|
51 |
+
help='input sequence length to generate.')
|
52 |
+
parser.add_argument('--output_len',
|
53 |
+
type=int,
|
54 |
+
default=64,
|
55 |
+
help='output sequence length to generate.')
|
56 |
+
parser.add_argument('--head_num', type=int, default=32, help='head number')
|
57 |
+
parser.add_argument('--size_per_head',
|
58 |
+
type=int,
|
59 |
+
default=128,
|
60 |
+
help='size per head')
|
61 |
+
parser.add_argument('--vocab_size',
|
62 |
+
type=int,
|
63 |
+
default=50432,
|
64 |
+
help='vocab size')
|
65 |
+
parser.add_argument(
|
66 |
+
'--beam_width',
|
67 |
+
type=int,
|
68 |
+
default=1,
|
69 |
+
help='beam width for beam search. Using sampling when beam width is 1.')
|
70 |
+
parser.add_argument('--top_k',
|
71 |
+
type=int,
|
72 |
+
default=1,
|
73 |
+
help='top k candidate num')
|
74 |
+
parser.add_argument('--top_p',
|
75 |
+
type=float,
|
76 |
+
default=0.95,
|
77 |
+
help='top p probability threshold')
|
78 |
+
parser.add_argument('--temperature',
|
79 |
+
type=float,
|
80 |
+
default=0.8,
|
81 |
+
help='temperature')
|
82 |
+
parser.add_argument('--len_penalty',
|
83 |
+
type=float,
|
84 |
+
default=0.,
|
85 |
+
help='len_penalty')
|
86 |
+
parser.add_argument('--beam_search_diversity_rate',
|
87 |
+
type=float,
|
88 |
+
default=0.,
|
89 |
+
help='beam_search_diversity_rate')
|
90 |
+
parser.add_argument('--tensor_para_size',
|
91 |
+
type=int,
|
92 |
+
default=1,
|
93 |
+
help='tensor parallel size')
|
94 |
+
parser.add_argument('--pipeline_para_size',
|
95 |
+
type=int,
|
96 |
+
default=1,
|
97 |
+
help='pipeline parallel size')
|
98 |
+
parser.add_argument('--ckpt_path',
|
99 |
+
type=str,
|
100 |
+
default='mpt-ft-7b/1-gpu',
|
101 |
+
help='path to the FT checkpoint file.')
|
102 |
+
parser.add_argument(
|
103 |
+
'--tokenizer_name_or_path',
|
104 |
+
type=str,
|
105 |
+
default='EleutherAI/gpt-neox-20b',
|
106 |
+
help=
|
107 |
+
'Name of the tokenizer or the directory where the tokenizer file is located.'
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
'--lib_path',
|
111 |
+
type=str,
|
112 |
+
help=
|
113 |
+
'path to the libth_transformer dynamic lib file(.e.g., build/lib/libth_transformer.so.'
|
114 |
+
)
|
115 |
+
parser.add_argument('--start_id',
|
116 |
+
type=int,
|
117 |
+
default=0,
|
118 |
+
help='start token id.')
|
119 |
+
parser.add_argument('--end_id', type=int, default=0, help='end token id.')
|
120 |
+
parser.add_argument(
|
121 |
+
'--max_batch_size',
|
122 |
+
type=int,
|
123 |
+
default=8,
|
124 |
+
help=
|
125 |
+
'Max batch size. If sample_input_file is given, it is truncated to this max_batch_size, otherwise, this value is used as batch size.'
|
126 |
+
)
|
127 |
+
parser.add_argument('--repetition_penalty',
|
128 |
+
type=float,
|
129 |
+
default=5.,
|
130 |
+
help='repetition penalty')
|
131 |
+
parser.add_argument(
|
132 |
+
'--presence_penalty',
|
133 |
+
type=float,
|
134 |
+
default=0.,
|
135 |
+
help=
|
136 |
+
'presence penalty. Similar to repetition, but additive rather than multiplicative.'
|
137 |
+
)
|
138 |
+
parser.add_argument('--min_length',
|
139 |
+
type=int,
|
140 |
+
default=0,
|
141 |
+
help='A minimum number of tokens to generate')
|
142 |
+
parser.add_argument(
|
143 |
+
'--max_seq_len',
|
144 |
+
type=int,
|
145 |
+
default=2048,
|
146 |
+
help='max sequence length for position embedding table.')
|
147 |
+
parser.add_argument('--inference_data_type',
|
148 |
+
'--data_type',
|
149 |
+
type=str,
|
150 |
+
choices=['fp32', 'fp16', 'bf16'],
|
151 |
+
default='bf16')
|
152 |
+
parser.add_argument('--time',
|
153 |
+
action='store_true',
|
154 |
+
help='whether or not to measure time elapsed.')
|
155 |
+
parser.add_argument(
|
156 |
+
'--sample_input_file',
|
157 |
+
type=str,
|
158 |
+
default=None,
|
159 |
+
help=
|
160 |
+
'path to sample input file. If not set, it runs with no context inputs.'
|
161 |
+
)
|
162 |
+
parser.add_argument('--sample_output_file',
|
163 |
+
type=str,
|
164 |
+
default=None,
|
165 |
+
help='path to sample output file.')
|
166 |
+
parser.add_argument(
|
167 |
+
'--disable_random_seed',
|
168 |
+
dest='random_seed',
|
169 |
+
action='store_false',
|
170 |
+
help='Disable the use of random seed for sentences in a batch.')
|
171 |
+
parser.add_argument('--skip_end_tokens',
|
172 |
+
dest='skip_end_tokens',
|
173 |
+
action='store_false',
|
174 |
+
help='Whether to remove or not end tokens in outputs.')
|
175 |
+
parser.add_argument('--no_detokenize',
|
176 |
+
dest='detokenize',
|
177 |
+
action='store_false',
|
178 |
+
help='Skip detokenizing output token ids.')
|
179 |
+
parser.add_argument(
|
180 |
+
'--int8_mode',
|
181 |
+
type=int,
|
182 |
+
default=0,
|
183 |
+
choices=[0, 1],
|
184 |
+
help='The level of quantization to perform.' +
|
185 |
+
' 0: No quantization. All computation in data_type' +
|
186 |
+
' 1: Quantize weights to int8, all compute occurs in fp16/bf16. Not supported when data_type is fp32'
|
187 |
+
)
|
188 |
+
parser.add_argument(
|
189 |
+
'--weights_data_type',
|
190 |
+
type=str,
|
191 |
+
default='fp32',
|
192 |
+
choices=['fp32', 'fp16'],
|
193 |
+
help='Data type of FT checkpoint weights',
|
194 |
+
)
|
195 |
+
parser.add_argument(
|
196 |
+
'--return_cum_log_probs',
|
197 |
+
type=int,
|
198 |
+
default=0,
|
199 |
+
choices=[0, 1, 2],
|
200 |
+
help='Whether to compute the cumulative log probsbility of sentences.' +
|
201 |
+
' 0: do not return the cumulative log probs' +
|
202 |
+
' 1: return the cumulative log probs of generated sequences' +
|
203 |
+
' 2: return the cumulative log probs of sequences')
|
204 |
+
parser.add_argument('--shared_contexts_ratio',
|
205 |
+
type=float,
|
206 |
+
default=0.0,
|
207 |
+
help='Triggers the shared context optimization when ' +
|
208 |
+
'compact_size <= shared_contexts_ratio * batch_size ' +
|
209 |
+
'A value of 0.0 deactivate the optimization')
|
210 |
+
parser.add_argument(
|
211 |
+
'--use_gpt_decoder_ops',
|
212 |
+
action='store_true',
|
213 |
+
help='Use separate decoder FT operators instead of end-to-end model op.'
|
214 |
+
)
|
215 |
+
parser.add_argument(
|
216 |
+
'--no-alibi',
|
217 |
+
dest='alibi',
|
218 |
+
action='store_false',
|
219 |
+
help='Do not use ALiBi (aka use_attention_linear_bias).')
|
220 |
+
parser.add_argument(
|
221 |
+
'--layernorm_eps',
|
222 |
+
type=float,
|
223 |
+
default=1e-5,
|
224 |
+
help='layernorm eps in PyTorch, by default, is 1e-5 and 1e-6 in FT.')
|
225 |
+
args = parser.parse_args()
|
226 |
+
|
227 |
+
ckpt_config = configparser.ConfigParser()
|
228 |
+
ckpt_config_path = os.path.join(args.ckpt_path, 'config.ini')
|
229 |
+
if os.path.isfile(ckpt_config_path):
|
230 |
+
ckpt_config.read(ckpt_config_path)
|
231 |
+
if 'gpt' in ckpt_config.keys():
|
232 |
+
for args_key, config_key, func in [
|
233 |
+
('layer_num', 'num_layer', ckpt_config.getint),
|
234 |
+
('max_seq_len', 'max_pos_seq_len', ckpt_config.getint),
|
235 |
+
('weights_data_type', 'weight_data_type', ckpt_config.get),
|
236 |
+
('layernorm_eps', 'layernorm_eps', ckpt_config.getfloat),
|
237 |
+
('alibi', 'use_attention_linear_bias', ckpt_config.getboolean),
|
238 |
+
]:
|
239 |
+
if config_key in ckpt_config['gpt'].keys():
|
240 |
+
prev_val = args.__dict__[args_key]
|
241 |
+
args.__dict__[args_key] = func('gpt', config_key)
|
242 |
+
print(
|
243 |
+
'Loading {} from config.ini, previous: {}, current: {}'
|
244 |
+
.format(args_key, prev_val, args.__dict__[args_key]))
|
245 |
+
else:
|
246 |
+
print('Not loading {} from config.ini'.format(args_key))
|
247 |
+
for key in ['head_num', 'size_per_head', 'tensor_para_size']:
|
248 |
+
if key in args.__dict__:
|
249 |
+
prev_val = args.__dict__[key]
|
250 |
+
args.__dict__[key] = ckpt_config.getint('gpt', key)
|
251 |
+
print(
|
252 |
+
'Loading {} from config.ini, previous: {}, current: {}'
|
253 |
+
.format(key, prev_val, args.__dict__[key]))
|
254 |
+
else:
|
255 |
+
print('Not loading {} from config.ini'.format(key))
|
256 |
+
|
257 |
+
layer_num = args.layer_num
|
258 |
+
output_len = args.output_len
|
259 |
+
head_num = args.head_num
|
260 |
+
size_per_head = args.size_per_head
|
261 |
+
vocab_size = args.vocab_size
|
262 |
+
beam_width = args.beam_width
|
263 |
+
top_k = args.top_k
|
264 |
+
top_p = args.top_p
|
265 |
+
temperature = args.temperature
|
266 |
+
len_penalty = args.len_penalty
|
267 |
+
beam_search_diversity_rate = args.beam_search_diversity_rate
|
268 |
+
tensor_para_size = args.tensor_para_size
|
269 |
+
pipeline_para_size = args.pipeline_para_size
|
270 |
+
start_id = args.start_id
|
271 |
+
end_id = args.end_id
|
272 |
+
max_batch_size = args.max_batch_size
|
273 |
+
max_seq_len = args.max_seq_len
|
274 |
+
repetition_penalty = args.repetition_penalty
|
275 |
+
presence_penalty = args.presence_penalty
|
276 |
+
min_length = args.min_length
|
277 |
+
weights_data_type = args.weights_data_type
|
278 |
+
return_cum_log_probs = args.return_cum_log_probs
|
279 |
+
return_output_length = return_cum_log_probs > 0
|
280 |
+
shared_contexts_ratio = args.shared_contexts_ratio
|
281 |
+
layernorm_eps = args.layernorm_eps
|
282 |
+
use_attention_linear_bias = args.alibi
|
283 |
+
has_positional_encoding = not args.alibi
|
284 |
+
|
285 |
+
print('\n=================== Arguments ===================')
|
286 |
+
for k, v in vars(args).items():
|
287 |
+
print(f'{k.ljust(30, ".")}: {v}')
|
288 |
+
print('=================================================\n')
|
289 |
+
|
290 |
+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
|
291 |
+
torch.manual_seed(0)
|
292 |
+
|
293 |
+
comm.initialize_model_parallel(args.tensor_para_size,
|
294 |
+
args.pipeline_para_size)
|
295 |
+
rank = comm.get_rank()
|
296 |
+
device = comm.get_device()
|
297 |
+
|
298 |
+
# Inputs
|
299 |
+
contexts = []
|
300 |
+
if args.sample_input_file:
|
301 |
+
with open(args.sample_input_file, 'r') as f:
|
302 |
+
contexts = f.read().splitlines()
|
303 |
+
batch_size = min(len(contexts), max_batch_size)
|
304 |
+
contexts = contexts[:batch_size]
|
305 |
+
start_ids = [
|
306 |
+
torch.tensor(tokenizer.encode(c), dtype=torch.int32, device=device)
|
307 |
+
for c in contexts
|
308 |
+
]
|
309 |
+
else:
|
310 |
+
batch_size = max_batch_size
|
311 |
+
contexts = ['<|endoftext|>'] * batch_size
|
312 |
+
start_ids = [torch.IntTensor([end_id for _ in range(args.input_len)])
|
313 |
+
] * batch_size
|
314 |
+
|
315 |
+
start_lengths = [len(ids) for ids in start_ids]
|
316 |
+
|
317 |
+
start_ids = pad_sequence(start_ids, batch_first=True, padding_value=end_id)
|
318 |
+
start_lengths = torch.IntTensor(start_lengths)
|
319 |
+
|
320 |
+
# Prepare model.
|
321 |
+
if not args.use_gpt_decoder_ops:
|
322 |
+
gpt = ParallelGPT(head_num,
|
323 |
+
size_per_head,
|
324 |
+
vocab_size,
|
325 |
+
start_id,
|
326 |
+
end_id,
|
327 |
+
layer_num,
|
328 |
+
max_seq_len,
|
329 |
+
tensor_para_size,
|
330 |
+
pipeline_para_size,
|
331 |
+
lib_path=args.lib_path,
|
332 |
+
inference_data_type=args.inference_data_type,
|
333 |
+
int8_mode=args.int8_mode,
|
334 |
+
weights_data_type=weights_data_type,
|
335 |
+
layernorm_eps=layernorm_eps,
|
336 |
+
use_attention_linear_bias=use_attention_linear_bias,
|
337 |
+
has_positional_encoding=has_positional_encoding,
|
338 |
+
shared_contexts_ratio=shared_contexts_ratio)
|
339 |
+
if not gpt.load(ckpt_path=args.ckpt_path):
|
340 |
+
print(
|
341 |
+
'[WARNING] Checkpoint file not found. Model loading is skipped.'
|
342 |
+
)
|
343 |
+
else:
|
344 |
+
gpt = gpt_decoder.Gpt(num_heads=head_num,
|
345 |
+
size_per_head=size_per_head,
|
346 |
+
num_layers=layer_num,
|
347 |
+
vocab_size=vocab_size,
|
348 |
+
start_id=start_id,
|
349 |
+
end_id=end_id,
|
350 |
+
tensor_para_size=tensor_para_size,
|
351 |
+
pipeline_para_size=pipeline_para_size,
|
352 |
+
lib_path=args.lib_path,
|
353 |
+
max_seq_len=max_seq_len,
|
354 |
+
int8_mode=args.int8_mode,
|
355 |
+
weights_data_type=args.weights_data_type)
|
356 |
+
gpt.load(args.ckpt_path, args.inference_data_type)
|
357 |
+
|
358 |
+
if args.random_seed:
|
359 |
+
random_seed_tensor = torch.randint(0,
|
360 |
+
10000,
|
361 |
+
size=[batch_size],
|
362 |
+
dtype=torch.int64)
|
363 |
+
else:
|
364 |
+
random_seed_tensor = torch.zeros([batch_size], dtype=torch.int64)
|
365 |
+
|
366 |
+
repetition_penalty_vec = None if repetition_penalty == 1. else repetition_penalty * torch.ones(
|
367 |
+
batch_size, dtype=torch.float32)
|
368 |
+
presence_penalty_vec = None if presence_penalty == 0. else presence_penalty * torch.ones(
|
369 |
+
batch_size, dtype=torch.float32)
|
370 |
+
|
371 |
+
infer_decode_args = {
|
372 |
+
'beam_width':
|
373 |
+
beam_width,
|
374 |
+
'top_k':
|
375 |
+
top_k * torch.ones(batch_size, dtype=torch.int32),
|
376 |
+
'top_p':
|
377 |
+
top_p * torch.ones(batch_size, dtype=torch.float32),
|
378 |
+
'temperature':
|
379 |
+
temperature * torch.ones(batch_size, dtype=torch.float32),
|
380 |
+
'repetition_penalty':
|
381 |
+
repetition_penalty_vec,
|
382 |
+
'presence_penalty':
|
383 |
+
presence_penalty_vec,
|
384 |
+
'beam_search_diversity_rate':
|
385 |
+
beam_search_diversity_rate *
|
386 |
+
torch.ones(batch_size, dtype=torch.float32),
|
387 |
+
'len_penalty':
|
388 |
+
len_penalty * torch.ones(size=[batch_size], dtype=torch.float32),
|
389 |
+
'bad_words_list':
|
390 |
+
None,
|
391 |
+
'min_length':
|
392 |
+
min_length * torch.ones(size=[batch_size], dtype=torch.int32),
|
393 |
+
'random_seed':
|
394 |
+
random_seed_tensor
|
395 |
+
}
|
396 |
+
|
397 |
+
if not args.use_gpt_decoder_ops:
|
398 |
+
|
399 |
+
def gpt_generate_fn():
|
400 |
+
tokens_batch = gpt(start_ids,
|
401 |
+
start_lengths,
|
402 |
+
output_len,
|
403 |
+
return_output_length=return_output_length,
|
404 |
+
return_cum_log_probs=return_cum_log_probs,
|
405 |
+
**infer_decode_args)
|
406 |
+
return tokens_batch
|
407 |
+
else:
|
408 |
+
|
409 |
+
def gpt_generate_fn():
|
410 |
+
output_dict = gpt.generate(
|
411 |
+
input_token_ids=start_ids,
|
412 |
+
input_lengths=start_lengths,
|
413 |
+
gen_length=output_len,
|
414 |
+
eos_token_id=end_id,
|
415 |
+
return_output_length=return_output_length,
|
416 |
+
return_log_probs=return_cum_log_probs,
|
417 |
+
**infer_decode_args)
|
418 |
+
return output_dict
|
419 |
+
|
420 |
+
# Generate tokens.
|
421 |
+
gen_outputs = gpt_generate_fn()
|
422 |
+
|
423 |
+
if rank == 0:
|
424 |
+
if not args.use_gpt_decoder_ops:
|
425 |
+
if return_cum_log_probs > 0:
|
426 |
+
tokens_batch, _, cum_log_probs = gen_outputs
|
427 |
+
else:
|
428 |
+
tokens_batch, cum_log_probs = gen_outputs, None
|
429 |
+
else:
|
430 |
+
tokens_batch = gen_outputs['output_token_ids']
|
431 |
+
cum_log_probs = gen_outputs[
|
432 |
+
'cum_log_probs'] if return_cum_log_probs > 0 else None
|
433 |
+
if cum_log_probs is not None:
|
434 |
+
print('[INFO] Log probs of sentences:', cum_log_probs)
|
435 |
+
|
436 |
+
outputs = []
|
437 |
+
tokens_batch = tokens_batch.cpu().numpy()
|
438 |
+
for i, (context, tokens) in enumerate(zip(contexts, tokens_batch)):
|
439 |
+
for beam_id in range(beam_width):
|
440 |
+
token = tokens[beam_id][
|
441 |
+
start_lengths[i]:] # exclude context input from the output
|
442 |
+
if args.skip_end_tokens:
|
443 |
+
token = token[token != end_id]
|
444 |
+
output = tokenizer.decode(
|
445 |
+
token) if args.detokenize else ' '.join(
|
446 |
+
str(t) for t in token.tolist())
|
447 |
+
outputs.append(output)
|
448 |
+
print(
|
449 |
+
f'[INFO] batch {i}, beam {beam_id}:\n[Context]\n{context}\n\n[Output]\n{output}\n'
|
450 |
+
)
|
451 |
+
|
452 |
+
if args.sample_output_file:
|
453 |
+
with open(args.sample_output_file, 'w+') as f:
|
454 |
+
outputs = [o.replace('\n', '\\n') for o in outputs]
|
455 |
+
f.writelines('\n'.join(outputs))
|
456 |
+
|
457 |
+
# Measure inference time.
|
458 |
+
if args.time:
|
459 |
+
warmup_iterations = 10
|
460 |
+
for _ in range(warmup_iterations):
|
461 |
+
gpt_generate_fn()
|
462 |
+
torch.cuda.synchronize()
|
463 |
+
measurement_iterations = 10
|
464 |
+
time = timeit.default_timer()
|
465 |
+
for _ in range(measurement_iterations):
|
466 |
+
gpt_generate_fn()
|
467 |
+
torch.cuda.synchronize()
|
468 |
+
time_elapsed = timeit.default_timer() - time
|
469 |
+
if rank == 0:
|
470 |
+
print(f'[INFO] MPT time costs:')
|
471 |
+
print(
|
472 |
+
'model_name, gpu_type, gpu_count, batch_size, input_tokens, output_tokens, latency_ms'
|
473 |
+
)
|
474 |
+
print(
|
475 |
+
f'{ckpt_config.get("gpt", "model_name")}, {torch.cuda.get_device_name().replace(" ", "-")}, {torch.cuda.device_count()}, {batch_size}, {args.input_len}, {args.output_len}, {time_elapsed * 1000 / measurement_iterations:.2f}'
|
476 |
+
)
|
477 |
+
|
478 |
+
|
479 |
+
if __name__ == '__main__':
|
480 |
+
main()
|
Perceptrix/finetune/build/lib/llmfoundry/__init__.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
try:
|
7 |
+
from llmfoundry import optim, utils
|
8 |
+
from llmfoundry.data import (ConcatTokensDataset,
|
9 |
+
MixtureOfDenoisersCollator, NoConcatDataset,
|
10 |
+
Seq2SeqFinetuningCollator,
|
11 |
+
build_finetuning_dataloader,
|
12 |
+
build_text_denoising_dataloader)
|
13 |
+
from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
|
14 |
+
ComposerHFT5)
|
15 |
+
from llmfoundry.models.layers.attention import (
|
16 |
+
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
|
17 |
+
flash_attn_fn, scaled_multihead_dot_product_attention,
|
18 |
+
triton_flash_attn_fn)
|
19 |
+
from llmfoundry.models.layers.blocks import MPTBlock
|
20 |
+
from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
|
21 |
+
build_ffn)
|
22 |
+
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
|
23 |
+
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
|
24 |
+
MPTForCausalLM, MPTModel,
|
25 |
+
MPTPreTrainedModel)
|
26 |
+
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
|
27 |
+
|
28 |
+
except ImportError as e:
|
29 |
+
try:
|
30 |
+
is_cuda_available = torch.cuda.is_available()
|
31 |
+
except:
|
32 |
+
is_cuda_available = False
|
33 |
+
|
34 |
+
extras = '.[gpu]' if is_cuda_available else '.'
|
35 |
+
raise ImportError(
|
36 |
+
f'Please make sure to pip install {extras} to get the requirements for the LLM example.'
|
37 |
+
) from e
|
38 |
+
|
39 |
+
__all__ = [
|
40 |
+
'build_text_denoising_dataloader',
|
41 |
+
'build_finetuning_dataloader',
|
42 |
+
'MixtureOfDenoisersCollator',
|
43 |
+
'Seq2SeqFinetuningCollator',
|
44 |
+
'MPTBlock',
|
45 |
+
'FFN_CLASS_REGISTRY',
|
46 |
+
'MPTMLP',
|
47 |
+
'build_ffn',
|
48 |
+
'MPTConfig',
|
49 |
+
'MPTPreTrainedModel',
|
50 |
+
'MPTModel',
|
51 |
+
'MPTForCausalLM',
|
52 |
+
'ComposerMPTCausalLM',
|
53 |
+
'ComposerHFCausalLM',
|
54 |
+
'ComposerHFPrefixLM',
|
55 |
+
'ComposerHFT5',
|
56 |
+
'COMPOSER_MODEL_REGISTRY',
|
57 |
+
'scaled_multihead_dot_product_attention',
|
58 |
+
'flash_attn_fn',
|
59 |
+
'triton_flash_attn_fn',
|
60 |
+
'MultiheadAttention',
|
61 |
+
'NoConcatDataset',
|
62 |
+
'ConcatTokensDataset',
|
63 |
+
'attn_bias_shape',
|
64 |
+
'build_attn_bias',
|
65 |
+
'build_alibi_bias',
|
66 |
+
'optim',
|
67 |
+
'utils',
|
68 |
+
'TiktokenTokenizerWrapper',
|
69 |
+
]
|
70 |
+
|
71 |
+
__version__ = '0.3.0'
|
Perceptrix/finetune/build/lib/llmfoundry/callbacks/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
try:
|
5 |
+
from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
|
6 |
+
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
|
7 |
+
from llmfoundry.callbacks.generate_callback import Generate
|
8 |
+
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
|
9 |
+
from llmfoundry.callbacks.model_gauntlet_callback import ModelGauntlet
|
10 |
+
from llmfoundry.callbacks.monolithic_ckpt_callback import \
|
11 |
+
MonolithicCheckpointSaver
|
12 |
+
from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
|
13 |
+
LayerFreezing)
|
14 |
+
from llmfoundry.callbacks.scheduled_gc_callback import \
|
15 |
+
ScheduledGarbageCollector
|
16 |
+
except ImportError as e:
|
17 |
+
raise ImportError(
|
18 |
+
'Please make sure to pip install . to get requirements for llm-foundry.'
|
19 |
+
) from e
|
20 |
+
|
21 |
+
__all__ = [
|
22 |
+
'FDiffMetrics',
|
23 |
+
'Generate',
|
24 |
+
'MonolithicCheckpointSaver',
|
25 |
+
'GlobalLRScaling',
|
26 |
+
'LayerFreezing',
|
27 |
+
'ScheduledGarbageCollector',
|
28 |
+
'EvalGauntlet',
|
29 |
+
'ModelGauntlet',
|
30 |
+
'HuggingFaceCheckpointer',
|
31 |
+
]
|
Perceptrix/finetune/build/lib/llmfoundry/callbacks/eval_gauntlet_callback.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Aggregate ICL evals into composite scores."""
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import math
|
8 |
+
from enum import Enum
|
9 |
+
from typing import Dict, Optional
|
10 |
+
|
11 |
+
from composer.core import Callback, State
|
12 |
+
from composer.loggers import Logger
|
13 |
+
|
14 |
+
__all__ = ['EvalGauntlet']
|
15 |
+
|
16 |
+
log = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class Weighting(Enum):
|
20 |
+
EQUAL = 1
|
21 |
+
SAMPLE_SZ = 2
|
22 |
+
LOG_SAMPLE_SZ = 3
|
23 |
+
|
24 |
+
|
25 |
+
class EvalGauntlet(Callback):
|
26 |
+
"""The EvalGauntlet aggregates ICL eval results.
|
27 |
+
|
28 |
+
After `eval_end`, this callback inspects the logger for different ICL metrics and aggregates the scores according to the aggregation
|
29 |
+
specification provided in the constructor.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
logger_keys (list): These are the exact keys that the individual benchmark metrics will be
|
33 |
+
logged under in the logger after eval
|
34 |
+
tasks (dict): This contains the list of categories, as well as the subtasks within them, the
|
35 |
+
random baseline accuracy of each subtask, and the number of fewshot examples
|
36 |
+
used for the task. See `llmfoundry/scripts/eval/yamls/eval_gauntlet.yaml` to see the structure.
|
37 |
+
weighting (Weighting): The weighting scheme used to balance different tasks within each category.
|
38 |
+
Either assign them all equal weight, assign them weight proportional
|
39 |
+
to the dataset size, or assign them weight proportional to the log2 of the dataset size.
|
40 |
+
Options are 'EQUAL', 'SAMPLE_SZ', and 'LOG_SAMPLE_SZ'.
|
41 |
+
subtract_random_baseline (bool): Flag determining whether to subtract random baseline accuracy
|
42 |
+
from the performance on each individual benchmark before aggregating.
|
43 |
+
rescale_accuracy (bool): Flag determining whether to rescale the accuracy on each benchmark
|
44 |
+
by (1-random_baseline_accuracy) before aggregating. Using this ensures that all benchmarks max out at 1.0.
|
45 |
+
benchmark_sizes (Optional[dict]): Optional data on benchmark sizes, used when not relying on equal weighting.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(self,
|
49 |
+
logger_keys: list,
|
50 |
+
categories: dict,
|
51 |
+
weighting: str = 'EQUAL',
|
52 |
+
subtract_random_baseline: bool = True,
|
53 |
+
rescale_accuracy: bool = True,
|
54 |
+
benchmark_sizes: Optional[dict] = None):
|
55 |
+
if isinstance(logger_keys, dict):
|
56 |
+
raise ValueError(
|
57 |
+
'logger_keys now requires a list type as input, not a dict')
|
58 |
+
if weighting != Weighting.EQUAL and benchmark_sizes is None:
|
59 |
+
raise Exception(
|
60 |
+
'When not using equal weighting, you must provide the benchmark sizes.'
|
61 |
+
)
|
62 |
+
|
63 |
+
if rescale_accuracy and not subtract_random_baseline:
|
64 |
+
raise Exception(
|
65 |
+
'Only use accuracy rescaling in conjunction with subtracting random baseline accuracy.'
|
66 |
+
)
|
67 |
+
|
68 |
+
self.categories = categories
|
69 |
+
self.weighting = Weighting[weighting]
|
70 |
+
self.subtract_random_baseline = subtract_random_baseline
|
71 |
+
self.rescale_accuracy = rescale_accuracy
|
72 |
+
self.logger_keys = logger_keys
|
73 |
+
|
74 |
+
for category in self.categories:
|
75 |
+
|
76 |
+
for benchmark in category['benchmarks']:
|
77 |
+
bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
|
78 |
+
|
79 |
+
if self.weighting != Weighting.EQUAL:
|
80 |
+
assert benchmark_sizes is not None
|
81 |
+
cumulative_samples = max(
|
82 |
+
sum(count for name, count in benchmark_sizes.items()
|
83 |
+
if name.startswith(bench_name)), 1)
|
84 |
+
else:
|
85 |
+
cumulative_samples = -1 # pyright
|
86 |
+
|
87 |
+
weight = None
|
88 |
+
if self.weighting == Weighting.EQUAL:
|
89 |
+
weight = 1
|
90 |
+
elif self.weighting == Weighting.SAMPLE_SZ:
|
91 |
+
weight = cumulative_samples
|
92 |
+
elif self.weighting == Weighting.LOG_SAMPLE_SZ:
|
93 |
+
weight = max(math.log(cumulative_samples, 2), 1)
|
94 |
+
|
95 |
+
assert weight is not None
|
96 |
+
benchmark['weighting'] = weight
|
97 |
+
|
98 |
+
def compute_averages(self, state: State) -> Dict[str, float]:
|
99 |
+
results = {}
|
100 |
+
|
101 |
+
for key in self.logger_keys:
|
102 |
+
|
103 |
+
# starting at index 1 skips the "metric" part of the key which is superfluous
|
104 |
+
dl_name, metric_name = key.split('/')[1:-1], key.split('/')[-1]
|
105 |
+
if 'Accuracy' not in metric_name:
|
106 |
+
continue
|
107 |
+
|
108 |
+
metric = state.eval_metrics.get('/'.join(dl_name),
|
109 |
+
{}).get(metric_name, None)
|
110 |
+
if metric is None:
|
111 |
+
continue
|
112 |
+
val = metric.compute().item()
|
113 |
+
|
114 |
+
# ending at index 2 allows us to aggregate over dataloaders w/ subcategories
|
115 |
+
key = '/'.join(dl_name[0:2])
|
116 |
+
if key not in results:
|
117 |
+
results[key] = []
|
118 |
+
|
119 |
+
results[key].append(val)
|
120 |
+
|
121 |
+
return {k: sum(v) / len(v) for k, v in results.items()}
|
122 |
+
|
123 |
+
def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
|
124 |
+
new_metrics = self.compute_averages(state)
|
125 |
+
if len(new_metrics) == 0:
|
126 |
+
return {}
|
127 |
+
composite_scores = {}
|
128 |
+
|
129 |
+
for category in self.categories:
|
130 |
+
missing_metrics = []
|
131 |
+
composite_scores[category['name']] = []
|
132 |
+
for benchmark in category['benchmarks']:
|
133 |
+
key = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
|
134 |
+
|
135 |
+
if key not in new_metrics:
|
136 |
+
log.warning(
|
137 |
+
f'Could not find results for benchmark: {benchmark}.')
|
138 |
+
missing_metrics.append(key)
|
139 |
+
else:
|
140 |
+
score = new_metrics[key]
|
141 |
+
|
142 |
+
if self.subtract_random_baseline:
|
143 |
+
score -= benchmark['random_baseline']
|
144 |
+
|
145 |
+
if self.rescale_accuracy and self.subtract_random_baseline:
|
146 |
+
score /= 1.0 - benchmark['random_baseline']
|
147 |
+
|
148 |
+
composite_scores[category['name']].append({
|
149 |
+
'name': benchmark['name'],
|
150 |
+
'score': score,
|
151 |
+
'weighting': benchmark['weighting']
|
152 |
+
})
|
153 |
+
|
154 |
+
if len(missing_metrics) > 0:
|
155 |
+
log.warning(
|
156 |
+
f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}"
|
157 |
+
)
|
158 |
+
del composite_scores[category['name']]
|
159 |
+
continue
|
160 |
+
total_weight = sum(
|
161 |
+
k['weighting'] for k in composite_scores[category['name']])
|
162 |
+
composite_scores[category['name']] = sum(
|
163 |
+
k['score'] * (k['weighting'] / total_weight)
|
164 |
+
for k in composite_scores[category['name']])
|
165 |
+
|
166 |
+
composite_scores = {
|
167 |
+
f'icl/metrics/eval_gauntlet/{k}': v
|
168 |
+
for k, v in composite_scores.items()
|
169 |
+
}
|
170 |
+
|
171 |
+
composite_scores['icl/metrics/eval_gauntlet/average'] = sum(
|
172 |
+
composite_scores.values()) / len(composite_scores.values()) if len(
|
173 |
+
composite_scores.values()) > 0 else 0
|
174 |
+
if logger is not None:
|
175 |
+
logger.log_metrics(composite_scores)
|
176 |
+
|
177 |
+
return composite_scores
|
Perceptrix/finetune/build/lib/llmfoundry/callbacks/fdiff_callback.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Monitor rate of change of loss."""
|
5 |
+
from __future__ import annotations
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from composer.core import Callback, State
|
9 |
+
from composer.loggers import Logger
|
10 |
+
|
11 |
+
|
12 |
+
class FDiffMetrics(Callback):
|
13 |
+
"""Rate of change of metrics.
|
14 |
+
|
15 |
+
tracks and plots the rate of change of metrics effectively taking the
|
16 |
+
numerical derivative of the metrics
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self,
|
20 |
+
diff_train_metrics: bool = False,
|
21 |
+
diff_eval_metrics: bool = True):
|
22 |
+
self.diff_train_metrics = diff_train_metrics
|
23 |
+
self.diff_eval_metrics = diff_eval_metrics
|
24 |
+
|
25 |
+
self.train_prev_loss = None
|
26 |
+
self.train_prev_metric = {}
|
27 |
+
self.eval_prev_metric = {}
|
28 |
+
|
29 |
+
def batch_end(self, state: State, logger: Logger) -> None:
|
30 |
+
if self.diff_train_metrics:
|
31 |
+
if not isinstance(state.loss, torch.Tensor):
|
32 |
+
raise NotImplementedError('Multiple losses not supported yet')
|
33 |
+
loss = state.loss.item()
|
34 |
+
if self.train_prev_loss:
|
35 |
+
logger.log_metrics(
|
36 |
+
{'loss/train/total_fdiff': loss - self.train_prev_loss})
|
37 |
+
self.train_prev_loss = loss
|
38 |
+
|
39 |
+
for k in self.train_prev_metric.keys():
|
40 |
+
logger.log_metrics({
|
41 |
+
f'metrics/train/{k}_fdiff':
|
42 |
+
state.train_metric_values[k] - self.train_prev_metric[k]
|
43 |
+
})
|
44 |
+
|
45 |
+
for k in state.train_metric_values.keys():
|
46 |
+
value = state.train_metric_values[k]
|
47 |
+
self.train_prev_metric[k] = value
|
48 |
+
|
49 |
+
def eval_end(self, state: State, logger: Logger) -> None:
|
50 |
+
if self.diff_eval_metrics:
|
51 |
+
evaluator = state.dataloader_label
|
52 |
+
assert evaluator is not None, 'dataloader should have been set'
|
53 |
+
|
54 |
+
metrics = list(state.eval_metrics[evaluator].keys())
|
55 |
+
|
56 |
+
for k in metrics:
|
57 |
+
mkey = '/'.join(['metrics', evaluator, k])
|
58 |
+
if mkey in self.eval_prev_metric.keys():
|
59 |
+
logger.log_metrics({
|
60 |
+
f'{mkey}_fdiff':
|
61 |
+
state.eval_metric_values[k] -
|
62 |
+
self.eval_prev_metric[mkey]
|
63 |
+
})
|
64 |
+
|
65 |
+
for k in metrics:
|
66 |
+
mkey = '/'.join(['metrics', evaluator, k])
|
67 |
+
self.eval_prev_metric[mkey] = state.eval_metric_values[k]
|
Perceptrix/finetune/build/lib/llmfoundry/callbacks/generate_callback.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Deprecated Generate callback.
|
5 |
+
|
6 |
+
Please use composer.callbacks.Generate instead.
|
7 |
+
"""
|
8 |
+
import warnings
|
9 |
+
from typing import Any, List, Union
|
10 |
+
|
11 |
+
from composer.callbacks import Generate as ComposerGenerate
|
12 |
+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
13 |
+
|
14 |
+
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
15 |
+
|
16 |
+
|
17 |
+
class Generate(ComposerGenerate):
|
18 |
+
|
19 |
+
def __init__(self, prompts: List[str], batch_log_interval: int,
|
20 |
+
**kwargs: Any):
|
21 |
+
|
22 |
+
warnings.warn(
|
23 |
+
('Accessing llmfoundry.callbacks.generate_callback.Generate '
|
24 |
+
'is deprecated and will be removed in a future release. '
|
25 |
+
'Please use composer.callbacks.Generate instead.'),
|
26 |
+
DeprecationWarning,
|
27 |
+
)
|
28 |
+
|
29 |
+
interval = f'{batch_log_interval}ba'
|
30 |
+
super().__init__(prompts=prompts, interval=interval, **kwargs)
|
Perceptrix/finetune/build/lib/llmfoundry/callbacks/hf_checkpointer.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import contextlib
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
import tempfile
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Optional, Union
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from composer.callbacks.utils import create_interval_scheduler
|
14 |
+
from composer.core import Callback, Event, State, Time
|
15 |
+
from composer.core.state import fsdp_state_dict_type_context
|
16 |
+
from composer.loggers import Logger
|
17 |
+
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
|
18 |
+
from composer.models import HuggingFaceModel
|
19 |
+
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
|
20 |
+
from transformers import PreTrainedTokenizerBase
|
21 |
+
|
22 |
+
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
|
23 |
+
from llmfoundry.utils.huggingface_hub_utils import \
|
24 |
+
edit_files_for_hf_compatibility
|
25 |
+
|
26 |
+
log = logging.getLogger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
class HuggingFaceCheckpointer(Callback):
|
30 |
+
"""Save a huggingface formatted checkpoint during training.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that
|
34 |
+
this would be the same as your save_folder.
|
35 |
+
save_interval: Union[str, int, Time]: The interval describing how often checkpoints should be
|
36 |
+
saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
|
37 |
+
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
|
38 |
+
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
|
39 |
+
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
|
40 |
+
precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
|
41 |
+
overwrite (bool): Whether to overwrite previous checkpoints.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
save_folder: str,
|
47 |
+
save_interval: Union[str, int, Time],
|
48 |
+
huggingface_folder_name: str = 'ba{batch}',
|
49 |
+
precision: str = 'float32',
|
50 |
+
overwrite: bool = False,
|
51 |
+
):
|
52 |
+
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
|
53 |
+
save_folder)
|
54 |
+
self.overwrite = overwrite
|
55 |
+
self.precision = precision
|
56 |
+
self.dtype = {
|
57 |
+
'float32': torch.float32,
|
58 |
+
'float16': torch.float16,
|
59 |
+
'bfloat16': torch.bfloat16,
|
60 |
+
}[precision]
|
61 |
+
self.huggingface_folder_name_fstr = os.path.join(
|
62 |
+
'huggingface', huggingface_folder_name)
|
63 |
+
self.check_interval = create_interval_scheduler(
|
64 |
+
save_interval, include_end_of_training=True)
|
65 |
+
self.upload_to_object_store = (self.backend != '')
|
66 |
+
if self.upload_to_object_store:
|
67 |
+
self.remote_ud = RemoteUploaderDownloader(
|
68 |
+
bucket_uri=f'{self.backend}://{self.bucket_name}',
|
69 |
+
num_concurrent_uploads=4)
|
70 |
+
else:
|
71 |
+
self.remote_ud = None
|
72 |
+
|
73 |
+
self.last_checkpoint_batch: Optional[Time] = None
|
74 |
+
|
75 |
+
def run_event(self, event: Event, state: State, logger: Logger) -> None:
|
76 |
+
# The interval scheduler handles only returning True for the appropriate events
|
77 |
+
if state.get_elapsed_duration() is not None and self.check_interval(
|
78 |
+
state,
|
79 |
+
event) and self.last_checkpoint_batch != state.timestamp.batch:
|
80 |
+
self._save_checkpoint(state, logger)
|
81 |
+
elif event == Event.INIT:
|
82 |
+
if not isinstance(state.model, HuggingFaceModel):
|
83 |
+
raise ValueError(
|
84 |
+
f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
|
85 |
+
+ f'Got {type(state.model)} instead.')
|
86 |
+
if self.upload_to_object_store and self.remote_ud is not None:
|
87 |
+
self.remote_ud.init(state, logger)
|
88 |
+
state.callbacks.append(self.remote_ud)
|
89 |
+
|
90 |
+
def _save_checkpoint(self, state: State, logger: Logger):
|
91 |
+
del logger # unused
|
92 |
+
|
93 |
+
self.last_checkpoint_batch = state.timestamp.batch
|
94 |
+
|
95 |
+
log.info('Saving HuggingFace formatted checkpoint')
|
96 |
+
|
97 |
+
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
98 |
+
CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
|
99 |
+
MPTConfig.register_for_auto_class()
|
100 |
+
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
|
101 |
+
|
102 |
+
assert isinstance(state.model, HuggingFaceModel)
|
103 |
+
|
104 |
+
save_dir = format_name_with_dist_and_time(
|
105 |
+
str(
|
106 |
+
Path(self.save_dir_format_str) /
|
107 |
+
self.huggingface_folder_name_fstr), state.run_name,
|
108 |
+
state.timestamp)
|
109 |
+
dir_context_mgr = tempfile.TemporaryDirectory(
|
110 |
+
) if self.upload_to_object_store else contextlib.nullcontext(
|
111 |
+
enter_result=save_dir)
|
112 |
+
|
113 |
+
with dir_context_mgr as temp_save_dir:
|
114 |
+
assert isinstance(temp_save_dir,
|
115 |
+
str) # pyright doesn't know about enter_result
|
116 |
+
|
117 |
+
with fsdp_state_dict_type_context(state.model.model,
|
118 |
+
state_dict_type='full'):
|
119 |
+
state_dict = state.model.model.state_dict()
|
120 |
+
|
121 |
+
# convert the state dict to the requested precision
|
122 |
+
for k, v in state_dict.items():
|
123 |
+
if isinstance(v, torch.Tensor):
|
124 |
+
state_dict[k] = v.to(dtype=self.dtype)
|
125 |
+
|
126 |
+
if dist.get_global_rank() == 0:
|
127 |
+
# We raise above if the model is not a HuggingFaceModel, so this assert is safe
|
128 |
+
assert hasattr(state.model.model, 'save_pretrained')
|
129 |
+
state.model.model.save_pretrained(temp_save_dir,
|
130 |
+
state_dict=state_dict)
|
131 |
+
|
132 |
+
if state.model.tokenizer is not None:
|
133 |
+
assert isinstance(state.model.tokenizer,
|
134 |
+
PreTrainedTokenizerBase)
|
135 |
+
state.model.tokenizer.save_pretrained(temp_save_dir)
|
136 |
+
|
137 |
+
# Only need to edit files for MPT because it has custom code
|
138 |
+
if state.model.model.config.model_type == 'mpt':
|
139 |
+
edit_files_for_hf_compatibility(temp_save_dir)
|
140 |
+
|
141 |
+
with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
|
142 |
+
edited_config = json.load(f)
|
143 |
+
|
144 |
+
if state.model.model.config.model_type == 'mpt':
|
145 |
+
edited_config['attn_config']['attn_impl'] = 'torch'
|
146 |
+
edited_config['init_device'] = 'cpu'
|
147 |
+
|
148 |
+
edited_config['torch_dtype'] = self.precision
|
149 |
+
with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
|
150 |
+
json.dump(edited_config, f, indent=4)
|
151 |
+
|
152 |
+
if self.upload_to_object_store:
|
153 |
+
assert self.remote_ud is not None
|
154 |
+
# TODO change to log after other pr
|
155 |
+
log.info(
|
156 |
+
f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
|
157 |
+
)
|
158 |
+
for filename in os.listdir(temp_save_dir):
|
159 |
+
self.remote_ud.upload_file(
|
160 |
+
state=state,
|
161 |
+
remote_file_name=os.path.join(save_dir, filename),
|
162 |
+
file_path=Path(os.path.join(temp_save_dir,
|
163 |
+
filename)),
|
164 |
+
overwrite=self.overwrite,
|
165 |
+
)
|
166 |
+
|
167 |
+
dist.barrier()
|
Perceptrix/finetune/build/lib/llmfoundry/callbacks/model_gauntlet_callback.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from composer.core import Callback
|
5 |
+
|
6 |
+
__all__ = ['ModelGauntlet']
|
7 |
+
|
8 |
+
|
9 |
+
class ModelGauntlet(Callback):
|
10 |
+
"""The ModelGauntlet callback has been renamed to EvalGauntlet.
|
11 |
+
|
12 |
+
We've created this dummy class, in order to alert anyone who may have been
|
13 |
+
importing ModelGauntlet.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
*args, # pyright: ignore [reportMissingParameterType]
|
19 |
+
**kwargs): # pyright: ignore [reportMissingParameterType]
|
20 |
+
raise ImportError(
|
21 |
+
'ModelGauntlet class is deprecated, please use EvalGauntlet')
|
Perceptrix/finetune/build/lib/llmfoundry/callbacks/monolithic_ckpt_callback.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import contextlib
|
5 |
+
import os
|
6 |
+
import tempfile
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from composer.core import Callback, State
|
11 |
+
from composer.core.state import (fsdp_get_optim_state_dict,
|
12 |
+
fsdp_state_dict_type_context)
|
13 |
+
from composer.loggers import Logger
|
14 |
+
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
|
15 |
+
from composer.utils import (dist, format_name_with_dist_and_time, parse_uri,
|
16 |
+
reproducibility)
|
17 |
+
|
18 |
+
|
19 |
+
class MonolithicCheckpointSaver(Callback):
|
20 |
+
"""Save a monolithic checkpoint every N batches.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
save_folder (str): Folder to save checkpoints to (can be a URI)
|
24 |
+
filename (str): Filename to save checkpoints to.
|
25 |
+
batch_interval (int): Number of batches between checkpoints.
|
26 |
+
overwrite (bool): Whether to overwrite previous checkpoints.
|
27 |
+
keep_optimizer(bool): Whether to save the optimizer state in the monolithic checkpoint.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self,
|
31 |
+
save_folder: str,
|
32 |
+
batch_interval: int,
|
33 |
+
filename: str = 'ep{epoch}-ba{batch}.pt',
|
34 |
+
overwrite: bool = False,
|
35 |
+
keep_optimizers: bool = False):
|
36 |
+
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
|
37 |
+
save_folder)
|
38 |
+
self.filename_format_str = filename
|
39 |
+
self.batch_interval = batch_interval
|
40 |
+
self.upload_to_object_store = (self.backend != '')
|
41 |
+
self.overwrite = overwrite
|
42 |
+
self.keep_optimizers = keep_optimizers
|
43 |
+
if self.upload_to_object_store:
|
44 |
+
self.remote_ud = RemoteUploaderDownloader(
|
45 |
+
bucket_uri=f'{self.backend}://{self.bucket_name}')
|
46 |
+
else:
|
47 |
+
self.remote_ud = None
|
48 |
+
|
49 |
+
def init(self, state: State, logger: Logger) -> None:
|
50 |
+
if self.upload_to_object_store and self.remote_ud is not None:
|
51 |
+
self.remote_ud.init(state, logger)
|
52 |
+
# updated_logger_destinations = [*logger.destinations, new_remote_ud]
|
53 |
+
# logger.destinations = tuple(updated_logger_destinations)
|
54 |
+
state.callbacks.append(self.remote_ud)
|
55 |
+
|
56 |
+
def batch_checkpoint(self, state: State, logger: Logger) -> None:
|
57 |
+
if state.timestamp.batch.value % self.batch_interval == 0:
|
58 |
+
self._save_checkpoint(state, logger)
|
59 |
+
|
60 |
+
def fit_end(self, state: State, logger: Logger) -> None:
|
61 |
+
if state.timestamp.batch.value % self.batch_interval != 0:
|
62 |
+
self._save_checkpoint(state, logger)
|
63 |
+
|
64 |
+
def _save_checkpoint(self, state: State, logger: Logger) -> None:
|
65 |
+
del logger # unused
|
66 |
+
|
67 |
+
filename = format_name_with_dist_and_time(self.filename_format_str,
|
68 |
+
state.run_name,
|
69 |
+
state.timestamp)
|
70 |
+
save_dir = format_name_with_dist_and_time(self.save_dir_format_str,
|
71 |
+
state.run_name,
|
72 |
+
state.timestamp)
|
73 |
+
dir_context_mgr = tempfile.TemporaryDirectory(
|
74 |
+
) if self.upload_to_object_store else contextlib.nullcontext(
|
75 |
+
enter_result=save_dir)
|
76 |
+
with dir_context_mgr as temp_save_dir:
|
77 |
+
# pyright doesn't know about enter_result
|
78 |
+
assert isinstance(temp_save_dir, str)
|
79 |
+
|
80 |
+
save_path = str(Path(temp_save_dir) / Path(filename))
|
81 |
+
dirname = os.path.dirname(save_path)
|
82 |
+
if dirname:
|
83 |
+
os.makedirs(dirname, exist_ok=True)
|
84 |
+
state_dict = {
|
85 |
+
'state': state.state_dict(),
|
86 |
+
'rng': reproducibility.get_rng_state()
|
87 |
+
}
|
88 |
+
# Remove sharded model and optimizer state dicts
|
89 |
+
state_dict['state'].pop('optimizers')
|
90 |
+
state_dict['state'].pop('model')
|
91 |
+
|
92 |
+
# Add in unsharded model params.
|
93 |
+
with fsdp_state_dict_type_context(state.model,
|
94 |
+
state_dict_type='full'):
|
95 |
+
state_dict['state']['model'] = state.model.state_dict()
|
96 |
+
|
97 |
+
# Add in unsharded optimizer state dict.
|
98 |
+
if self.keep_optimizers:
|
99 |
+
optimizer = state.optimizers[0]
|
100 |
+
state_dict['state']['optimizers'] = {
|
101 |
+
type(optimizer).__qualname__:
|
102 |
+
fsdp_get_optim_state_dict(state.model,
|
103 |
+
optimizer,
|
104 |
+
state_dict_type='full')
|
105 |
+
}
|
106 |
+
if dist.get_global_rank() == 0:
|
107 |
+
torch.save(state_dict, save_path)
|
108 |
+
|
109 |
+
if self.upload_to_object_store and self.remote_ud is not None and dist.get_global_rank(
|
110 |
+
) == 0:
|
111 |
+
remote_file_name = str(Path(save_dir) / Path(filename))
|
112 |
+
self.remote_ud.upload_file(state=state,
|
113 |
+
remote_file_name=remote_file_name,
|
114 |
+
file_path=Path(save_path),
|
115 |
+
overwrite=self.overwrite)
|
Perceptrix/finetune/build/lib/llmfoundry/callbacks/resumption_callbacks.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import logging
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
from composer.core import Callback, State
|
8 |
+
from composer.loggers import Logger
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
'GlobalLRScaling',
|
12 |
+
'LayerFreezing',
|
13 |
+
]
|
14 |
+
|
15 |
+
log = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class GlobalLRScaling(Callback):
|
19 |
+
"""GlobalLRScaling.
|
20 |
+
|
21 |
+
This callback can be applied upon resuming a model checkpoint. Upon
|
22 |
+
fit_start it will multiply the base LR by `lr_scale` and set the WD to be.
|
23 |
+
|
24 |
+
`wd_pct` * `lr`.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
lr_scale (float): Multiplicative factor to scale LR by
|
28 |
+
wd_pct (float): Percentage of LR to set weight decay to.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, lr_scale: float, wd_pct: float = 0.0):
|
32 |
+
self.lr_scale = lr_scale
|
33 |
+
self.wd_pct = wd_pct
|
34 |
+
|
35 |
+
def fit_start(self, state: State, logger: Logger) -> None:
|
36 |
+
del logger # unused
|
37 |
+
|
38 |
+
if hasattr(state, 'optimizer') and state.optimizers is None:
|
39 |
+
raise Exception('No optimizers defined')
|
40 |
+
for optimizer in state.optimizers:
|
41 |
+
for group in optimizer.param_groups:
|
42 |
+
group['lr'] *= self.lr_scale
|
43 |
+
group['weight_decay'] = group['lr'] * self.wd_pct
|
44 |
+
if 'initial_lr' in group:
|
45 |
+
group['initial_lr'] *= self.lr_scale
|
46 |
+
log.info(
|
47 |
+
f"Set LR and WD to {group['lr']}, {group['weight_decay']}")
|
48 |
+
|
49 |
+
for scheduler in state.schedulers:
|
50 |
+
scheduler.base_lrs = [
|
51 |
+
self.lr_scale * lr for lr in scheduler.base_lrs
|
52 |
+
]
|
53 |
+
|
54 |
+
|
55 |
+
class LayerFreezing(Callback):
|
56 |
+
"""LayerFreezing.
|
57 |
+
|
58 |
+
This callback can be applied upon resuming a model checkpoint. Upon
|
59 |
+
fit_start it freeze the layers specified in `layer_names`. If using
|
60 |
+
activation checkpointing, please set the
|
61 |
+
`activation_checkpointing_reentrant` flag in `fsdp_config` to false.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
layer_names (float): Names of layers to freeze.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, layer_names: List[str]):
|
68 |
+
self.layer_names = set(layer_names)
|
69 |
+
|
70 |
+
def fit_start(self, state: State, logger: Logger) -> None:
|
71 |
+
del logger # unused
|
72 |
+
|
73 |
+
model_layers = set(name for name, _ in state.model.named_parameters())
|
74 |
+
for layer in self.layer_names:
|
75 |
+
if layer not in model_layers:
|
76 |
+
raise Exception(
|
77 |
+
f'Attempted to freeze layer not found in model: {layer}\nAvailable layers: {model_layers}'
|
78 |
+
)
|
79 |
+
|
80 |
+
successful_freeze = False
|
81 |
+
for name, p in state.model.named_parameters():
|
82 |
+
if p.requires_grad and name in self.layer_names:
|
83 |
+
p.requires_grad = False
|
84 |
+
log.debug(f'Froze layer: {name}\nParam: {p}')
|
85 |
+
successful_freeze = True
|
86 |
+
|
87 |
+
if not successful_freeze:
|
88 |
+
raise Exception(
|
89 |
+
f"Tried to run LayerFreezing but didn't freeze any layers")
|
Perceptrix/finetune/build/lib/llmfoundry/callbacks/scheduled_gc_callback.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import gc
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from composer.core import Callback, State
|
8 |
+
from composer.loggers import Logger
|
9 |
+
|
10 |
+
|
11 |
+
def gc_cuda():
|
12 |
+
"""Garbage collect Torch (CUDA) memory."""
|
13 |
+
gc.collect()
|
14 |
+
if torch.cuda.is_available():
|
15 |
+
torch.cuda.empty_cache()
|
16 |
+
|
17 |
+
|
18 |
+
class ScheduledGarbageCollector(Callback):
|
19 |
+
"""Disable automatic garbage collection and collect garbage at interval.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
batch_interval (int): Number of batches between checkpoints call to gc.collect()
|
23 |
+
eval_keep_disabled (bool): keep gc disabled during eval (default: False)
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
batch_interval: int,
|
29 |
+
eval_keep_disabled: bool = False,
|
30 |
+
):
|
31 |
+
self.batch_interval = batch_interval
|
32 |
+
self.eval_keep_disabled = eval_keep_disabled
|
33 |
+
self.gc_init_state = None
|
34 |
+
|
35 |
+
def fit_start(self, state: State, logger: Logger) -> None:
|
36 |
+
del state, logger # unused
|
37 |
+
|
38 |
+
# cache if automatic garbage collection is enabled; reset at fit_end
|
39 |
+
self.gc_init_state = gc.isenabled()
|
40 |
+
|
41 |
+
# disable automatic garbage collection
|
42 |
+
gc.disable()
|
43 |
+
gc_cuda()
|
44 |
+
|
45 |
+
def fit_end(self, state: State, logger: Logger) -> None:
|
46 |
+
del state, logger # unused
|
47 |
+
|
48 |
+
gc_cuda()
|
49 |
+
|
50 |
+
# reset automatic garbage collection at fit_end
|
51 |
+
if self.gc_init_state:
|
52 |
+
gc.enable()
|
53 |
+
else:
|
54 |
+
gc.disable()
|
55 |
+
|
56 |
+
def before_dataloader(self, state: State, logger: Logger) -> None:
|
57 |
+
del logger # unused
|
58 |
+
|
59 |
+
if state.timestamp.batch.value % self.batch_interval == 0:
|
60 |
+
gc_cuda()
|
61 |
+
|
62 |
+
def eval_start(self, state: State, logger: Logger) -> None:
|
63 |
+
del state, logger # unused
|
64 |
+
|
65 |
+
gc_cuda()
|
66 |
+
if not self.eval_keep_disabled:
|
67 |
+
gc.enable()
|
68 |
+
|
69 |
+
def eval_end(self, state: State, logger: Logger) -> None:
|
70 |
+
del state, logger # unused
|
71 |
+
|
72 |
+
if not self.eval_keep_disabled:
|
73 |
+
gc.disable()
|
74 |
+
|
75 |
+
gc_cuda()
|
Perceptrix/finetune/build/lib/llmfoundry/data/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
|
5 |
+
from llmfoundry.data.denoising import (MixtureOfDenoisersCollator,
|
6 |
+
build_text_denoising_dataloader)
|
7 |
+
from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator,
|
8 |
+
build_finetuning_dataloader)
|
9 |
+
from llmfoundry.data.text_data import (StreamingTextDataset,
|
10 |
+
build_text_dataloader)
|
11 |
+
|
12 |
+
__all__ = [
|
13 |
+
'MixtureOfDenoisersCollator',
|
14 |
+
'build_text_denoising_dataloader',
|
15 |
+
'Seq2SeqFinetuningCollator',
|
16 |
+
'build_finetuning_dataloader',
|
17 |
+
'StreamingTextDataset',
|
18 |
+
'build_text_dataloader',
|
19 |
+
'NoConcatDataset',
|
20 |
+
'ConcatTokensDataset',
|
21 |
+
]
|
Perceptrix/finetune/build/lib/llmfoundry/data/data.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Datasets for converting to MDS Shards."""
|
5 |
+
import os
|
6 |
+
import warnings
|
7 |
+
from typing import Dict, Iterable, Union
|
8 |
+
|
9 |
+
import datasets as hf_datasets
|
10 |
+
import numpy as np
|
11 |
+
from torch.utils.data import IterableDataset
|
12 |
+
from transformers import PreTrainedTokenizerBase
|
13 |
+
|
14 |
+
|
15 |
+
class NoConcatDataset(IterableDataset):
|
16 |
+
"""An IterableDataset that returns text samples for MDSWriter.
|
17 |
+
|
18 |
+
Returns dicts of {'text': bytes}
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset,
|
22 |
+
hf_datasets.Dataset]):
|
23 |
+
self.hf_dataset = hf_dataset
|
24 |
+
|
25 |
+
def __iter__(self) -> Iterable[Dict[str, bytes]]:
|
26 |
+
for sample in self.hf_dataset:
|
27 |
+
# convert to bytes to store in MDS binary format
|
28 |
+
yield {'text': sample['text'].encode('utf-8')}
|
29 |
+
|
30 |
+
|
31 |
+
class ConcatTokensDataset(IterableDataset):
|
32 |
+
"""An IterableDataset that returns token samples for MDSWriter.
|
33 |
+
|
34 |
+
Returns dicts of {'tokens': bytes}
|
35 |
+
|
36 |
+
To use data created by this class and written to MDS format:
|
37 |
+
|
38 |
+
```python
|
39 |
+
import torch
|
40 |
+
from streaming.base import StreamingDataset
|
41 |
+
from transformers import AutoTokenizer
|
42 |
+
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
|
44 |
+
ds = StreamingDataset(local='mds-data-folder', split='val')
|
45 |
+
|
46 |
+
# note, you need to copy the numpy array because the original is non-writeable
|
47 |
+
# and torch does not support non-writeable tensors, so you get a scary warning and
|
48 |
+
# if you do try to write to the tensor you get undefined behavior
|
49 |
+
tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
|
50 |
+
print(tokenizer.decode(tokens))
|
51 |
+
```
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
|
57 |
+
tokenizer: PreTrainedTokenizerBase,
|
58 |
+
max_length: int,
|
59 |
+
bos_text: str,
|
60 |
+
eos_text: str,
|
61 |
+
no_wrap: bool,
|
62 |
+
):
|
63 |
+
self.hf_dataset = hf_dataset
|
64 |
+
self.tokenizer = tokenizer
|
65 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
66 |
+
self.max_length = max_length
|
67 |
+
self.bos_text = bos_text
|
68 |
+
self.eos_text = eos_text
|
69 |
+
self.should_wrap = not no_wrap
|
70 |
+
|
71 |
+
self.bos_tokens = self.tokenizer(self.bos_text,
|
72 |
+
truncation=False,
|
73 |
+
padding=False,
|
74 |
+
add_special_tokens=False)['input_ids']
|
75 |
+
if len(self.bos_tokens) > 1:
|
76 |
+
warnings.warn(
|
77 |
+
f'You specified --concat_tokens with --bos_text, but your BOS text is not tokenizing to one token\
|
78 |
+
, instead we got {self.bos_tokens}. Quit if this was in error.')
|
79 |
+
|
80 |
+
self.eos_tokens = self.tokenizer(self.eos_text,
|
81 |
+
truncation=False,
|
82 |
+
padding=False,
|
83 |
+
add_special_tokens=False)['input_ids']
|
84 |
+
if len(self.eos_tokens) > 1:
|
85 |
+
warnings.warn(
|
86 |
+
f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token\
|
87 |
+
, instead we got {self.eos_tokens}. Quit if this was in error.')
|
88 |
+
|
89 |
+
eos_text_provided = self.eos_text != ''
|
90 |
+
bos_text_provided = self.bos_text != ''
|
91 |
+
test_text = self.tokenizer('')
|
92 |
+
if len(test_text['input_ids']) > 0 and (eos_text_provided or
|
93 |
+
bos_text_provided):
|
94 |
+
message = 'both eos and bos' if eos_text_provided and bos_text_provided else (
|
95 |
+
'eos_text' if eos_text_provided else 'bos_text')
|
96 |
+
warnings.warn(
|
97 |
+
f'The provided tokenizer adds special tokens, but you also specified {message}. This may result '
|
98 |
+
+
|
99 |
+
'in duplicated special tokens. Please be sure this is what you intend.'
|
100 |
+
)
|
101 |
+
|
102 |
+
def __iter__(self) -> Iterable[Dict[str, bytes]]:
|
103 |
+
|
104 |
+
buffer = []
|
105 |
+
for sample in self.hf_dataset:
|
106 |
+
encoded = self.tokenizer(sample['text'],
|
107 |
+
truncation=False,
|
108 |
+
padding=False)
|
109 |
+
iids = encoded['input_ids']
|
110 |
+
buffer = buffer + self.bos_tokens + iids + self.eos_tokens
|
111 |
+
while len(buffer) >= self.max_length:
|
112 |
+
concat_sample = buffer[:self.max_length]
|
113 |
+
buffer = buffer[self.max_length:] if self.should_wrap else []
|
114 |
+
yield {
|
115 |
+
# convert to bytes to store in MDS binary format
|
116 |
+
'tokens': np.asarray(concat_sample).tobytes()
|
117 |
+
}
|
Perceptrix/finetune/build/lib/llmfoundry/data/denoising.py
ADDED
@@ -0,0 +1,937 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Streaming dataloader for (mixture of) denoising task(s)."""
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import random
|
8 |
+
import sys
|
9 |
+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from omegaconf import DictConfig
|
14 |
+
from omegaconf import OmegaConf as om
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
from transformers import PreTrainedTokenizerBase
|
17 |
+
|
18 |
+
from llmfoundry.data.packing import BinPackWrapper
|
19 |
+
from llmfoundry.data.text_data import StreamingTextDataset
|
20 |
+
from llmfoundry.models import utils
|
21 |
+
|
22 |
+
__all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
|
23 |
+
|
24 |
+
log = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
# HuggingFace hardcodes the ignore index to -100
|
27 |
+
_HF_IGNORE_INDEX = -100
|
28 |
+
|
29 |
+
# Required signature of any `prefix_function` (see below)
|
30 |
+
PREFIX_FUNCTION = Callable[[float, Optional[float], PreTrainedTokenizerBase],
|
31 |
+
Sequence[int]]
|
32 |
+
|
33 |
+
|
34 |
+
def ul2_prefix_function(
|
35 |
+
mask_ratio: float,
|
36 |
+
mean_length: Optional[float],
|
37 |
+
tokenizer: PreTrainedTokenizerBase,
|
38 |
+
) -> Sequence[int]:
|
39 |
+
"""Generates prefixes based on UL2 paper.
|
40 |
+
|
41 |
+
See: http://arxiv.org/abs/2205.05131
|
42 |
+
"""
|
43 |
+
if mean_length is None:
|
44 |
+
# This is the case for "sequence to sequence"
|
45 |
+
prefix = '[S2S]' if mask_ratio < 1.0 else '[CLM]'
|
46 |
+
elif mean_length >= 12 or mask_ratio >= 0.3:
|
47 |
+
# UL2 tags this corruption rate "extreme"
|
48 |
+
prefix = '[NLG]'
|
49 |
+
else:
|
50 |
+
# UL2 tags this corruption rate as "regular"
|
51 |
+
prefix = '[NLU]'
|
52 |
+
return tokenizer(prefix, add_special_tokens=False).input_ids
|
53 |
+
|
54 |
+
|
55 |
+
class MixtureOfDenoisersCollator:
|
56 |
+
"""Data collator for mixture of span-corruption denoisers, as in UL2.
|
57 |
+
|
58 |
+
This collator supports a variety of tasks used to pre-train an
|
59 |
+
encoder-decoder model or a (prefix LM) decoder-only model. This is meant
|
60 |
+
to be used with a dataset that yields tokenized text sequences. It is not
|
61 |
+
required that the token sequences are already padded or truncate, as this
|
62 |
+
collator will internally truncate and pad as needed.
|
63 |
+
|
64 |
+
For the denoising mixture recommended in the original UL2 paper,
|
65 |
+
http://arxiv.org/abs/2205.05131, use:
|
66 |
+
.. python:
|
67 |
+
MixtureOfDenoisersCollator(
|
68 |
+
...,
|
69 |
+
span_mean_lengths_and_ratios=[
|
70 |
+
[3, .15],
|
71 |
+
[8, .15],
|
72 |
+
[3, .50],
|
73 |
+
[8, .50],
|
74 |
+
[64, .15],
|
75 |
+
[64, .50],
|
76 |
+
],
|
77 |
+
sequence_mask_ratios=0.25
|
78 |
+
)
|
79 |
+
|
80 |
+
Args:
|
81 |
+
tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
|
82 |
+
prepare the data from raw text. Any missing sentinel tokens will
|
83 |
+
be added by the collator.
|
84 |
+
max_seq_length (int): The maximum length of sequences produced by this
|
85 |
+
collator. Incoming sequences may be truncated to accommodate this
|
86 |
+
limit.
|
87 |
+
Note that when formatting for decoder-only models, the context
|
88 |
+
tokens and target tokens are concatenated, and max_seq_length
|
89 |
+
applies to their combined length. For encoder-decoder models, both
|
90 |
+
the encoder and decoder will see up to max_seq_length tokens.
|
91 |
+
decoder_only_format (bool, optional): Whether to format the batches
|
92 |
+
for a decoder-only model (i.e. a prefix LM) or, if ``False``, an
|
93 |
+
encoder-decoder model. Default: ``False``.
|
94 |
+
span_mean_lengths_and_rations (optional): A length-2 list of a
|
95 |
+
``[mean_length, mask_ratio]`` pair, or a list of such pairs. Each
|
96 |
+
pair adds a span corruption denoising task to the task mixture. For
|
97 |
+
example, ``[3, 0.15]`` adds the original span corruption task used
|
98 |
+
for pre-training a T5 model as in http://arxiv.org/abs/1910.10683,
|
99 |
+
which trained with a single span corruption task that used a mean
|
100 |
+
span length of 3 and a mask ratio of 15%.
|
101 |
+
Default: ``None`` does not add any span corruption tasks.
|
102 |
+
sequence_mask_ratios (optional): A float or list of floats, one for each
|
103 |
+
sequence corruption denoising task to add to the task mixture. Each
|
104 |
+
sequence mask ratio must be greater than 0.0 and at most 1.0.
|
105 |
+
This type of task is a special instance of span corruption, with
|
106 |
+
exactly one masked span take from the end of the sequence. The
|
107 |
+
length of the span is sampled uniformly such that the average
|
108 |
+
portion of masked tokens equals sequence_mask_ratio.
|
109 |
+
Note: A value of 1.0 essentially yields causal LM examples.
|
110 |
+
Default: ``None` does not add any sequence corruption tasks.
|
111 |
+
allow_pad_trimming (bool, optional): Whether to allow the collator to
|
112 |
+
trim away sequence regions that are entirely padding (i.e. padding
|
113 |
+
for each example in the batch). If ``True``, shorter sequences may
|
114 |
+
improve throughput but at a potentially higher memory cost
|
115 |
+
owing to variable sequence lengths from batch to batch.
|
116 |
+
Default: ``False`` yields batches that are always padded to
|
117 |
+
max_seq_length.
|
118 |
+
prefix_function (callable, optional): A function that maps denoising
|
119 |
+
task parameters (e.g. mean_length=3, mask_ratio=0.15) to a prefix
|
120 |
+
that will be added to sequences when the associated "noiser" is
|
121 |
+
applied.
|
122 |
+
To disable these prefixes, use a value of ``None``.
|
123 |
+
Default: :func:`ul2_prefix_function` applies the prefix scheme
|
124 |
+
suggested in the UL2 paper: http://arxiv.org/abs/2205.05131.
|
125 |
+
context_eos (bool, optional): Whether to attach an EOS token to the end
|
126 |
+
of the context sequence, marking the transition from context to
|
127 |
+
target sequence. Only applicable if decoder_only_format is True.
|
128 |
+
Context EOS tokens are always added for encoder-decoder format.
|
129 |
+
Default: ``False`` does not attach context EOS.
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(
|
133 |
+
self,
|
134 |
+
tokenizer: PreTrainedTokenizerBase,
|
135 |
+
max_seq_length: int,
|
136 |
+
decoder_only_format: bool = False,
|
137 |
+
span_mean_lengths_and_ratios: Optional[List] = None,
|
138 |
+
sequence_mask_ratios: Optional[Union[List[float], float]] = None,
|
139 |
+
allow_pad_trimming: bool = False,
|
140 |
+
prefix_function: Optional[PREFIX_FUNCTION] = ul2_prefix_function,
|
141 |
+
context_eos: Optional[bool] = None,
|
142 |
+
):
|
143 |
+
# Prepare the tokenizer for denoising tasks
|
144 |
+
utils.adapt_tokenizer_for_denoising(tokenizer)
|
145 |
+
|
146 |
+
self.tokenizer = tokenizer
|
147 |
+
self.max_seq_length = max_seq_length
|
148 |
+
self.decoder_only_format = decoder_only_format
|
149 |
+
self._sentinel_token_ids = np.array(self.tokenizer.sentinel_token_ids)
|
150 |
+
|
151 |
+
# Trimming will always be skipped on at least the first __call__
|
152 |
+
self._allow_pad_trimming = allow_pad_trimming
|
153 |
+
self._seen_first_batch = False
|
154 |
+
|
155 |
+
self.context_eos = bool(context_eos) if decoder_only_format else True
|
156 |
+
|
157 |
+
# Process the span_mean_lengths_and_ratios argument
|
158 |
+
if span_mean_lengths_and_ratios is None:
|
159 |
+
# In this case, there are no span corruption tasks
|
160 |
+
self.span_mean_lengths_and_ratios = []
|
161 |
+
elif isinstance(span_mean_lengths_and_ratios[0], (int, float)):
|
162 |
+
# In this case, there is one span corruption task
|
163 |
+
if not len(span_mean_lengths_and_ratios) == 2:
|
164 |
+
raise ValueError('`span_mean_lengths_and_ratios` must be a ' + \
|
165 |
+
'pair of [mean_length, mask_ratio], a list ' + \
|
166 |
+
f'of such pairs, or None. Got {span_mean_lengths_and_ratios}.')
|
167 |
+
self.span_mean_lengths_and_ratios = [span_mean_lengths_and_ratios]
|
168 |
+
else:
|
169 |
+
# In this case, there are one or more span corruption tasks
|
170 |
+
span_mean_lengths_and_ratios = list(span_mean_lengths_and_ratios)
|
171 |
+
for spec_pair in span_mean_lengths_and_ratios:
|
172 |
+
if len(spec_pair) != 2:
|
173 |
+
raise ValueError('`span_mean_lengths_and_ratios` must be a ' + \
|
174 |
+
'pair of [mean_length, mask_ratio], a list ' + \
|
175 |
+
f'of such pairs, or None. Got {span_mean_lengths_and_ratios}.')
|
176 |
+
self.span_mean_lengths_and_ratios = span_mean_lengths_and_ratios
|
177 |
+
|
178 |
+
# Process the sequence_mask_ratios argument
|
179 |
+
if sequence_mask_ratios is None:
|
180 |
+
# In this case, there are no sequence corruption tasks
|
181 |
+
self.sequence_mask_ratios = []
|
182 |
+
elif isinstance(sequence_mask_ratios, float):
|
183 |
+
# In this case, there is one sequence corruption task
|
184 |
+
self.sequence_mask_ratios = [sequence_mask_ratios]
|
185 |
+
else:
|
186 |
+
# In this case, there is one or more sequence corruption tasks
|
187 |
+
for ratio in sequence_mask_ratios:
|
188 |
+
if not (0 < ratio <= 1.0):
|
189 |
+
raise ValueError('`sequence_mask_ratios` must be a float (or list '+\
|
190 |
+
'of floats) that are each >0.0 and <=1.0, or None. '+\
|
191 |
+
f'Got {sequence_mask_ratios}.')
|
192 |
+
self.sequence_mask_ratios = sequence_mask_ratios
|
193 |
+
|
194 |
+
# Populate the noisers so we can learn to denoise them!
|
195 |
+
self._noisers = []
|
196 |
+
self._smallest_max_raw_length = self.max_seq_length * 100
|
197 |
+
self._largest_max_raw_length = 0
|
198 |
+
self._uses_span_corruption = False
|
199 |
+
|
200 |
+
# Add "noisers" for any span corruption denoising tasks
|
201 |
+
# Each mean_length / mask_ratio combo becomes one of the span
|
202 |
+
# corruption denoising tasks
|
203 |
+
for span_mean_length, span_mask_ratio in self.span_mean_lengths_and_ratios:
|
204 |
+
self._uses_span_corruption = True
|
205 |
+
if span_mean_length < 0:
|
206 |
+
raise ValueError('All span mean lengths must be positive.')
|
207 |
+
if not 0 < span_mask_ratio < 1.0:
|
208 |
+
raise ValueError(
|
209 |
+
'All span masking ratios must be between 0.0 and 1.0.')
|
210 |
+
|
211 |
+
if prefix_function is not None:
|
212 |
+
prefix_tokens = prefix_function(span_mask_ratio,
|
213 |
+
span_mean_length,
|
214 |
+
self.tokenizer)
|
215 |
+
else:
|
216 |
+
prefix_tokens = None
|
217 |
+
|
218 |
+
max_raw_length = _get_max_starting_length(
|
219 |
+
max_length=self.max_seq_length,
|
220 |
+
mask_ratio=span_mask_ratio,
|
221 |
+
mean_span_length=span_mean_length,
|
222 |
+
n_prefix_tokens=len(prefix_tokens or []),
|
223 |
+
decoder_only_format=self.decoder_only_format,
|
224 |
+
context_eos=self.context_eos)
|
225 |
+
if max_raw_length < self._smallest_max_raw_length:
|
226 |
+
self._smallest_max_raw_length = max_raw_length
|
227 |
+
if max_raw_length > self._largest_max_raw_length:
|
228 |
+
self._largest_max_raw_length = max_raw_length
|
229 |
+
|
230 |
+
kwargs = {
|
231 |
+
'mean_span_length': span_mean_length,
|
232 |
+
'mask_ratio': span_mask_ratio,
|
233 |
+
'prefix_tokens': prefix_tokens,
|
234 |
+
'max_raw_length': max_raw_length,
|
235 |
+
}
|
236 |
+
self._noisers.append(kwargs)
|
237 |
+
|
238 |
+
# Add "noisers" for any sequential denoising tasks
|
239 |
+
for sequence_mask_ratio in self.sequence_mask_ratios:
|
240 |
+
if prefix_function is not None:
|
241 |
+
prefix_tokens = prefix_function(sequence_mask_ratio, None,
|
242 |
+
self.tokenizer)
|
243 |
+
else:
|
244 |
+
prefix_tokens = None
|
245 |
+
|
246 |
+
max_raw_length = self.max_seq_length - len(prefix_tokens or []) - 1
|
247 |
+
if decoder_only_format and self.context_eos:
|
248 |
+
max_raw_length = max_raw_length - 1
|
249 |
+
|
250 |
+
if not self._uses_span_corruption and (
|
251 |
+
max_raw_length < self._smallest_max_raw_length):
|
252 |
+
# We choose not to count sequence denoising in the smallest
|
253 |
+
# unless there is only sequence denoising.
|
254 |
+
self._smallest_max_raw_length = max_raw_length
|
255 |
+
if max_raw_length > self._largest_max_raw_length:
|
256 |
+
self._largest_max_raw_length = max_raw_length
|
257 |
+
|
258 |
+
kwargs = {
|
259 |
+
'mean_span_length': None,
|
260 |
+
'mask_ratio': sequence_mask_ratio,
|
261 |
+
'prefix_tokens': prefix_tokens,
|
262 |
+
'max_raw_length': max_raw_length,
|
263 |
+
}
|
264 |
+
self._noisers.append(kwargs)
|
265 |
+
|
266 |
+
if not self._noisers:
|
267 |
+
raise ValueError(
|
268 |
+
'No denoising tasks were included. Make sure to set ' + \
|
269 |
+
'`span_mean_lengths_and_ratios` and/or `sequence_mask_ratios`.')
|
270 |
+
|
271 |
+
@property
|
272 |
+
def smallest_max_raw_length(self) -> int:
|
273 |
+
return int(self._smallest_max_raw_length)
|
274 |
+
|
275 |
+
@property
|
276 |
+
def largest_max_raw_length(self) -> int:
|
277 |
+
return int(self._largest_max_raw_length)
|
278 |
+
|
279 |
+
def __call__(self, examples: List[Dict[str,
|
280 |
+
Any]]) -> Dict[str, torch.Tensor]:
|
281 |
+
"""Batch examples processed by the span corrupter."""
|
282 |
+
processed_examples = []
|
283 |
+
for example in examples:
|
284 |
+
# Randomly pick a "noiser" to apply to this example
|
285 |
+
noiser = random.choice(self._noisers)
|
286 |
+
# Apply it
|
287 |
+
processed_examples.append(
|
288 |
+
noise_token_sequence(
|
289 |
+
example,
|
290 |
+
mask_ratio=noiser['mask_ratio'],
|
291 |
+
mean_span_length=noiser['mean_span_length'],
|
292 |
+
prefix_tokens=noiser['prefix_tokens'],
|
293 |
+
max_raw_length=noiser['max_raw_length'],
|
294 |
+
max_seq_length=self.max_seq_length,
|
295 |
+
tokenizer=self.tokenizer,
|
296 |
+
sentinel_token_ids=self._sentinel_token_ids,
|
297 |
+
decoder_only_format=self.decoder_only_format,
|
298 |
+
context_eos=self.context_eos))
|
299 |
+
batch = self.tokenizer.pad(processed_examples)
|
300 |
+
|
301 |
+
# This logic prevents trimming on at least the first batch
|
302 |
+
if not (self._allow_pad_trimming and self._seen_first_batch):
|
303 |
+
self._seen_first_batch = True
|
304 |
+
return batch
|
305 |
+
self._seen_first_batch = True
|
306 |
+
|
307 |
+
# Truncate portions of the inputs that are purely padding
|
308 |
+
# (up to a multiple of 8)
|
309 |
+
multiple_of = 8
|
310 |
+
n_examples_per_length = batch['attention_mask'].sum(0)
|
311 |
+
keep_tokens = torch.sum(n_examples_per_length > 0)
|
312 |
+
keep_tokens = int(multiple_of * torch.ceil(keep_tokens / multiple_of))
|
313 |
+
|
314 |
+
# Note: EncDec formatting will always produce a right-padded batch
|
315 |
+
if self.tokenizer.padding_side == 'left' and self.decoder_only_format:
|
316 |
+
batch['input_ids'] = batch['input_ids'][:, -keep_tokens:]
|
317 |
+
batch['attention_mask'] = batch['attention_mask'][:, -keep_tokens:]
|
318 |
+
else:
|
319 |
+
batch['input_ids'] = batch['input_ids'][:, :keep_tokens]
|
320 |
+
batch['attention_mask'] = batch['attention_mask'][:, :keep_tokens]
|
321 |
+
|
322 |
+
if self.decoder_only_format:
|
323 |
+
if self.tokenizer.padding_side == 'left':
|
324 |
+
batch['labels'] = batch['labels'][:, -keep_tokens:]
|
325 |
+
batch['bidirectional_mask'] = batch[
|
326 |
+
'bidirectional_mask'][:, -keep_tokens:]
|
327 |
+
else:
|
328 |
+
batch['labels'] = batch['labels'][:, :keep_tokens]
|
329 |
+
batch['bidirectional_mask'] = batch[
|
330 |
+
'bidirectional_mask'][:, :keep_tokens]
|
331 |
+
|
332 |
+
else:
|
333 |
+
# Truncate portions of the decoder inputs that are purely padding
|
334 |
+
n_examples_per_length = batch['decoder_attention_mask'].sum(0)
|
335 |
+
keep_tokens = torch.sum(n_examples_per_length > 0)
|
336 |
+
keep_tokens = int(multiple_of *
|
337 |
+
torch.ceil(keep_tokens / multiple_of))
|
338 |
+
|
339 |
+
batch['labels'] = batch['labels'][:, :keep_tokens]
|
340 |
+
batch['decoder_attention_mask'] = batch[
|
341 |
+
'decoder_attention_mask'][:, :keep_tokens]
|
342 |
+
batch['decoder_input_ids'] = batch[
|
343 |
+
'decoder_input_ids'][:, :keep_tokens]
|
344 |
+
|
345 |
+
# This slicing can produce non-contiguous tensors, so use .contiguous
|
346 |
+
# to prevent related problems
|
347 |
+
batch = {k: v.contiguous() for k, v in batch.items()}
|
348 |
+
|
349 |
+
return batch
|
350 |
+
|
351 |
+
|
352 |
+
def build_text_denoising_dataloader(
|
353 |
+
cfg: DictConfig,
|
354 |
+
tokenizer: PreTrainedTokenizerBase,
|
355 |
+
device_batch_size: int,
|
356 |
+
) -> DataLoader[Dict]:
|
357 |
+
"""Constructor function for a Mixture of Denoisers dataloader.
|
358 |
+
|
359 |
+
This function constructs a dataloader that can be used to train an
|
360 |
+
encoder-decoder model or a (prefix LM) decoder-only model on a text
|
361 |
+
denoising task mixture (e.g. span corruption, or UL2).
|
362 |
+
|
363 |
+
The underlying dataset is a :class:`StreamingTextDataset`, allowing you to
|
364 |
+
stream raw text data or pre-tokenized text data.
|
365 |
+
|
366 |
+
The dataloader uses a :class:`MixtureOfDenoisersCollator` to prepare the
|
367 |
+
tokenized examples into training batches.
|
368 |
+
|
369 |
+
Args:
|
370 |
+
cfg (DictConfig): An omegaconf dictionary used to configure the loader:
|
371 |
+
cfg.name (str): The type of dataloader to build. Must = "text_denoising".
|
372 |
+
---
|
373 |
+
cfg.dataset.max_seq_len (int): The maximum length of sequences
|
374 |
+
in the batch. See :class:`MixtureOfDenoisersCollator` docstring
|
375 |
+
for details.
|
376 |
+
cfg.dataset.packing_ratio (float, optional): If provided, this invokes
|
377 |
+
a collator wrapper that packs device_batch_size*packing_ratio
|
378 |
+
raw examples into device_batch_size packed examples. This helps
|
379 |
+
minimize padding while preserving sequence integrity.
|
380 |
+
This adds `sequence_id` to the batch, which indicates which unique
|
381 |
+
sequence each token belongs to.
|
382 |
+
Note: Using this feature will not change device_batch_size but it
|
383 |
+
will determine the number of raw examples consumed by the dataloader
|
384 |
+
per batch. Some examples may be discarded if they do not fit when
|
385 |
+
packing.
|
386 |
+
Select packing_ratio **carefully** based on the dataset
|
387 |
+
statistics, max_seq_len, and tolerance for discarding samples!
|
388 |
+
The packing code in `./packing.py` provides a script that can help
|
389 |
+
you choose the best packing_ratio.
|
390 |
+
See :class:`StreamingTextDataset` for info on other standard config
|
391 |
+
options within `cfg.dataset`.
|
392 |
+
---
|
393 |
+
cfg.mixture_of_denoisers.decoder_only_format (bool): Whether the
|
394 |
+
batches should use the format required for training a decoder-only
|
395 |
+
model (if ``True``) or an encoder-decoder model (if ``False``).
|
396 |
+
cfg.mixture_of_denoisers.span_mean_lengths_and_ratios (optional): The
|
397 |
+
parameters for any span corruption denoising tasks to include in
|
398 |
+
the task mixture.
|
399 |
+
See :class:`MixtureOfDenoisersCollator` docstring for details.
|
400 |
+
cfg.mixture_of_denoisers.sequence_mask_ratios (optional): The
|
401 |
+
parameters for any sequence denoising tasks to include in the
|
402 |
+
task mixture.
|
403 |
+
See :class:`MixtureOfDenoisersCollator` docstring for details.
|
404 |
+
cfg.mixture_of_denoisers.allow_pad_trimming (optional): Whether to
|
405 |
+
allow the collator to trim padding when possible (if ``True``).
|
406 |
+
Defaults to ``False``.
|
407 |
+
cfg.mixture_of_denoisers.prefix_function (optional): Set to ``None``
|
408 |
+
to disable the UL2-style prefixes that will be automatically
|
409 |
+
added by default.
|
410 |
+
---
|
411 |
+
See :class:`DataLoader` for standard argument options to the pytorch
|
412 |
+
dataloader, such as `cfg.drop_last`, `cfg.num_workers`, etc.
|
413 |
+
tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
|
414 |
+
prepare the data from raw text. Any missing sentinel tokens will
|
415 |
+
be added by the collator.
|
416 |
+
device_batch_size (int): The size of the batches (number of examples)
|
417 |
+
that the dataloader will produce.
|
418 |
+
|
419 |
+
Note:
|
420 |
+
You can run the script inside `./packing.py` to quickly test the
|
421 |
+
padding/waste rates for different `cfg.dataset.packing_ratio` choices,
|
422 |
+
given a starting workload YAML.
|
423 |
+
"""
|
424 |
+
assert cfg.name == 'text_denoising', f'Tried to build_denoising text dataloader with cfg.name={cfg.name}'
|
425 |
+
|
426 |
+
collate_fn = MixtureOfDenoisersCollator(
|
427 |
+
tokenizer=tokenizer,
|
428 |
+
max_seq_length=cfg.dataset.max_seq_len,
|
429 |
+
decoder_only_format=cfg.mixture_of_denoisers.decoder_only_format,
|
430 |
+
span_mean_lengths_and_ratios=cfg.mixture_of_denoisers.get(
|
431 |
+
'span_mean_lengths_and_ratios'),
|
432 |
+
sequence_mask_ratios=cfg.mixture_of_denoisers.get(
|
433 |
+
'sequence_mask_ratios'),
|
434 |
+
allow_pad_trimming=cfg.mixture_of_denoisers.get('allow_pad_trimming',
|
435 |
+
False),
|
436 |
+
prefix_function=cfg.mixture_of_denoisers.get('prefix_function',
|
437 |
+
ul2_prefix_function),
|
438 |
+
context_eos=cfg.mixture_of_denoisers.get('context_eos'))
|
439 |
+
|
440 |
+
truncate_to = cfg.mixture_of_denoisers.get('truncate_raw_tokens_to')
|
441 |
+
if truncate_to is None:
|
442 |
+
# By default, truncate to the largest max raw length of the denoisers
|
443 |
+
truncate_to = collate_fn.largest_max_raw_length
|
444 |
+
elif isinstance(truncate_to, str):
|
445 |
+
if truncate_to.lower() == 'min':
|
446 |
+
# Truncate to the smallest max raw length of the denoisers
|
447 |
+
truncate_to = collate_fn.smallest_max_raw_length
|
448 |
+
elif truncate_to.lower() == 'max':
|
449 |
+
# Truncate to the largest max raw length of the denoisers
|
450 |
+
truncate_to = collate_fn.largest_max_raw_length
|
451 |
+
else:
|
452 |
+
raise ValueError(
|
453 |
+
f'truncate_raw_tokens_to(="{truncate_to.lower()}") must be "min", "max", a positive int, or None.'
|
454 |
+
)
|
455 |
+
else:
|
456 |
+
if not isinstance(truncate_to, int):
|
457 |
+
ValueError(
|
458 |
+
f'truncate_raw_tokens_to(={truncate_to}) must be "min", "max", a positive int, or None.'
|
459 |
+
)
|
460 |
+
if truncate_to < 0:
|
461 |
+
ValueError(
|
462 |
+
f'truncate_raw_tokens_to(={truncate_to}) must be "min", "max", a positive int, or None.'
|
463 |
+
)
|
464 |
+
|
465 |
+
dataset = StreamingTextDataset(
|
466 |
+
local=cfg.dataset.local,
|
467 |
+
tokenizer=tokenizer,
|
468 |
+
max_seq_len=truncate_to,
|
469 |
+
remote=cfg.dataset.get('remote'),
|
470 |
+
split=cfg.dataset.get('split'),
|
471 |
+
shuffle=cfg.dataset.get('shuffle', False),
|
472 |
+
predownload=cfg.dataset.get('predownload', 100_000),
|
473 |
+
keep_zip=cfg.dataset.get('keep_zip', False),
|
474 |
+
download_retry=cfg.dataset.get('download_retry', 2),
|
475 |
+
download_timeout=cfg.dataset.get('download_timeout', 60),
|
476 |
+
validate_hash=cfg.dataset.get('validate_hash'),
|
477 |
+
shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
|
478 |
+
num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', 128),
|
479 |
+
batch_size=device_batch_size,
|
480 |
+
)
|
481 |
+
|
482 |
+
if dataset.tokenizer.pad_token is None:
|
483 |
+
dataset.tokenizer.pad_token = dataset.tokenizer.eos_token
|
484 |
+
|
485 |
+
if cfg.dataset.get('packing_ratio'):
|
486 |
+
n_examples_to_pack = int(device_batch_size * cfg.dataset.packing_ratio)
|
487 |
+
if n_examples_to_pack < device_batch_size:
|
488 |
+
raise ValueError('packing_ratio must be >= 1, if supplied')
|
489 |
+
if not cfg.mixture_of_denoisers.decoder_only_format:
|
490 |
+
raise NotImplementedError(
|
491 |
+
'On-the-fly packing is currently only supported for decoder-only formats.'
|
492 |
+
)
|
493 |
+
collate_fn = BinPackWrapper(
|
494 |
+
collator=collate_fn,
|
495 |
+
target_batch_size=device_batch_size,
|
496 |
+
max_seq_len=cfg.dataset.max_seq_len,
|
497 |
+
pad_token_id=dataset.tokenizer.pad_token_id,
|
498 |
+
padding_side=dataset.tokenizer.padding_side,
|
499 |
+
max_leftover_bins_to_keep=cfg.dataset.get(
|
500 |
+
'max_leftover_bins_to_keep'),
|
501 |
+
)
|
502 |
+
device_batch_size = n_examples_to_pack
|
503 |
+
elif cfg.dataset.get('max_leftover_bins_to_keep') is not None:
|
504 |
+
raise ValueError(
|
505 |
+
'cfg.dataset.max_leftover_bins_to_keep has been defined, ' +\
|
506 |
+
'but cfg.dataset.packing_ratio has not been set. Please set ' +\
|
507 |
+
'the latter to turn on packing or remove the former from the config.')
|
508 |
+
|
509 |
+
return DataLoader(
|
510 |
+
dataset,
|
511 |
+
collate_fn=collate_fn,
|
512 |
+
batch_size=device_batch_size,
|
513 |
+
drop_last=cfg.drop_last,
|
514 |
+
num_workers=cfg.num_workers,
|
515 |
+
pin_memory=cfg.get('pin_memory', True),
|
516 |
+
prefetch_factor=cfg.get('prefetch_factor', 2),
|
517 |
+
persistent_workers=cfg.get('persistent_workers', False),
|
518 |
+
timeout=cfg.get('timeout', 0),
|
519 |
+
)
|
520 |
+
|
521 |
+
|
522 |
+
def noise_token_sequence(
|
523 |
+
example: Union[torch.Tensor, Mapping[str, Any]],
|
524 |
+
mask_ratio: float,
|
525 |
+
mean_span_length: Optional[float],
|
526 |
+
prefix_tokens: Optional[Sequence[int]],
|
527 |
+
max_raw_length: int,
|
528 |
+
max_seq_length: int,
|
529 |
+
tokenizer: PreTrainedTokenizerBase,
|
530 |
+
sentinel_token_ids: np.ndarray,
|
531 |
+
decoder_only_format: bool,
|
532 |
+
context_eos: bool,
|
533 |
+
) -> Dict[str, torch.Tensor]:
|
534 |
+
"""Span corruption applicable to all UL2 denoising tasks."""
|
535 |
+
# Extract the raw text tokens (trim if we need to)
|
536 |
+
if isinstance(example, torch.Tensor):
|
537 |
+
# If the example is a tensor, assume is the raw tokens with no padding
|
538 |
+
tokens = example
|
539 |
+
length = len(tokens)
|
540 |
+
else:
|
541 |
+
tokens = example['input_ids']
|
542 |
+
length = sum(example['attention_mask'])
|
543 |
+
if length > max_raw_length:
|
544 |
+
length = max_raw_length
|
545 |
+
if tokenizer.padding_side == 'left':
|
546 |
+
tokens = tokens[-length:]
|
547 |
+
else:
|
548 |
+
tokens = tokens[:length]
|
549 |
+
|
550 |
+
prefix_tokens = prefix_tokens or []
|
551 |
+
|
552 |
+
if length < 1:
|
553 |
+
raise ValueError('Example cannot be empty but token length <1.')
|
554 |
+
|
555 |
+
# mean_span_length==None is a special case for "sequential" denoising
|
556 |
+
# (where a single span at the end of the sequence is masked)
|
557 |
+
if mean_span_length is None:
|
558 |
+
# This ensures that exactly 1 span will be produced and that
|
559 |
+
# trimming to max_seq_length won't cut off any <EOS> token.
|
560 |
+
# In the decoder-only case, this won't insert new tokens.
|
561 |
+
if mask_ratio <= 0.5:
|
562 |
+
u = np.random.uniform(low=0.0, high=mask_ratio * 2)
|
563 |
+
else:
|
564 |
+
u = np.random.uniform(low=(mask_ratio * 2) - 1, high=1.0)
|
565 |
+
mean_span_length = float(np.round(1 + u * (length - 1)))
|
566 |
+
mask_ratio = mean_span_length / length
|
567 |
+
use_sentinels = False
|
568 |
+
else:
|
569 |
+
use_sentinels = True
|
570 |
+
|
571 |
+
# Generate the mask
|
572 |
+
# Note: this function can be used for all the UL2 noising functions
|
573 |
+
mask = _sample_mask_array(length, mask_ratio, mean_span_length)
|
574 |
+
# The sequence should always be unmasked at the beginning
|
575 |
+
assert mask[0] == 0
|
576 |
+
|
577 |
+
# Generate the input/label sequences given the raw tokens and the mask
|
578 |
+
tokens_inputs = _apply_mask(tokens,
|
579 |
+
mask,
|
580 |
+
use_sentinels,
|
581 |
+
tokenizer.eos_token_id,
|
582 |
+
sentinel_token_ids,
|
583 |
+
ensure_eos=context_eos)
|
584 |
+
tokens_labels = _apply_mask(tokens,
|
585 |
+
1 - mask,
|
586 |
+
use_sentinels,
|
587 |
+
tokenizer.eos_token_id,
|
588 |
+
sentinel_token_ids,
|
589 |
+
ensure_eos=True)
|
590 |
+
|
591 |
+
# Tag the inputs with any prefix
|
592 |
+
if prefix_tokens:
|
593 |
+
tokens_inputs = np.concatenate([prefix_tokens, tokens_inputs])
|
594 |
+
|
595 |
+
# Trim if necessary
|
596 |
+
if len(tokens_inputs) > max_seq_length:
|
597 |
+
raise ValueError('This should not exceed the max length')
|
598 |
+
if len(tokens_labels) > max_seq_length:
|
599 |
+
raise ValueError('This should not exceed the max length')
|
600 |
+
|
601 |
+
tokens_inputs = torch.LongTensor(tokens_inputs)
|
602 |
+
tokens_labels = torch.LongTensor(tokens_labels)
|
603 |
+
|
604 |
+
if decoder_only_format:
|
605 |
+
return _format_tokens_for_decoder_only(tokens_inputs, tokens_labels,
|
606 |
+
max_seq_length,
|
607 |
+
tokenizer.pad_token_id,
|
608 |
+
tokenizer.padding_side)
|
609 |
+
return _format_tokens_for_encoder_decoder(tokens_inputs, tokens_labels,
|
610 |
+
max_seq_length,
|
611 |
+
tokenizer.pad_token_id)
|
612 |
+
|
613 |
+
|
614 |
+
def _get_max_starting_length(max_length: int, mask_ratio: float,
|
615 |
+
mean_span_length: float, n_prefix_tokens: int,
|
616 |
+
decoder_only_format: bool,
|
617 |
+
context_eos: bool) -> int:
|
618 |
+
"""Get max num raw tokens that will fit max_length."""
|
619 |
+
|
620 |
+
def sequence_stats(length: int):
|
621 |
+
length = np.maximum(length, 2)
|
622 |
+
num_noise_tokens = int(np.round(mask_ratio * float(length)))
|
623 |
+
num_noise_tokens = np.minimum(np.maximum(num_noise_tokens, 1),
|
624 |
+
length - 1)
|
625 |
+
num_spans = int(np.round(float(num_noise_tokens) / mean_span_length))
|
626 |
+
num_noise_spans = np.maximum(num_spans, 1)
|
627 |
+
num_nonnoise_tokens = length - num_noise_tokens
|
628 |
+
# Prefix, sentinel, and EOS added to input for Enc-Dec
|
629 |
+
extra_inp_tokens = n_prefix_tokens + num_noise_spans + int(context_eos)
|
630 |
+
# Sentinel and EOS added to target
|
631 |
+
extra_targ_tokens = num_noise_spans + 1
|
632 |
+
# Sequence totals after corruption
|
633 |
+
total_inp_tokens = num_nonnoise_tokens + extra_inp_tokens
|
634 |
+
total_targ_tokens = num_noise_tokens + extra_targ_tokens
|
635 |
+
return total_inp_tokens, total_targ_tokens
|
636 |
+
|
637 |
+
def length_fits(length: int) -> bool:
|
638 |
+
total_inp_tokens, total_targ_tokens = sequence_stats(length)
|
639 |
+
if decoder_only_format:
|
640 |
+
return (total_inp_tokens + total_targ_tokens) <= max_length
|
641 |
+
return (total_inp_tokens <= max_length) and (total_targ_tokens <=
|
642 |
+
max_length)
|
643 |
+
|
644 |
+
# Start with a definitely too-long sequence and reduce until it fits
|
645 |
+
num_raw_tokens = max_length * 2
|
646 |
+
while num_raw_tokens > 0:
|
647 |
+
if length_fits(num_raw_tokens):
|
648 |
+
return num_raw_tokens
|
649 |
+
num_raw_tokens -= 1
|
650 |
+
raise ValueError(
|
651 |
+
'Unable to find a starting sequence length that can fit given the corruption and max_length parameters.'
|
652 |
+
)
|
653 |
+
|
654 |
+
|
655 |
+
def _sample_mask_array(length: int, mask_ratio: float,
|
656 |
+
mean_span_length: float) -> np.ndarray:
|
657 |
+
"""Samples a span corruption mask."""
|
658 |
+
if mask_ratio == 0.0:
|
659 |
+
return np.zeros(length)
|
660 |
+
# This first block computes the number of noise/non-noise spans and the
|
661 |
+
# total tokens in each. Extra steps are taken to handle edge cases that
|
662 |
+
# cause degeneracy.
|
663 |
+
starting_length = length
|
664 |
+
length = np.maximum(length, 2)
|
665 |
+
num_noise_tokens = int(np.round(mask_ratio * float(length)))
|
666 |
+
num_noise_tokens = np.minimum(np.maximum(num_noise_tokens, 1), length - 1)
|
667 |
+
num_spans = int(np.round(float(num_noise_tokens) / mean_span_length))
|
668 |
+
num_noise_spans = np.maximum(num_spans, 1)
|
669 |
+
num_nonnoise_tokens = length - num_noise_tokens
|
670 |
+
|
671 |
+
# Sample the noise/non-noise span lengths and interleave them to
|
672 |
+
# generate the mask array.
|
673 |
+
# Note: We always start with a non-noise span.
|
674 |
+
def _sample_span_lengths(total_tokens: int, num_spans: int) -> np.ndarray:
|
675 |
+
"""Samples lengths of num_spans segments.
|
676 |
+
|
677 |
+
Note: the combined length of segments equals total_tokens.
|
678 |
+
"""
|
679 |
+
span_markers = np.less(np.arange(total_tokens - 1), num_spans -
|
680 |
+
1)[np.random.permutation(total_tokens - 1)]
|
681 |
+
span_start_indicator = np.concatenate([np.array([0]), span_markers])
|
682 |
+
span_id = np.cumsum(span_start_indicator).reshape(-1, 1)
|
683 |
+
spans = np.arange(num_spans).reshape(1, -1)
|
684 |
+
span_lengths = np.sum(span_id == spans, axis=0)
|
685 |
+
return span_lengths
|
686 |
+
|
687 |
+
noise_span_lengths = _sample_span_lengths(num_noise_tokens, num_noise_spans)
|
688 |
+
nonnoise_span_lengths = _sample_span_lengths(num_nonnoise_tokens,
|
689 |
+
num_noise_spans)
|
690 |
+
interleaved_span_lengths = np.reshape(
|
691 |
+
np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
|
692 |
+
[num_noise_spans * 2])
|
693 |
+
|
694 |
+
span_starts = np.cumsum(interleaved_span_lengths)[:-1]
|
695 |
+
span_start_indicator = np.zeros(length)
|
696 |
+
span_start_indicator[span_starts] = 1
|
697 |
+
span_id = np.cumsum(span_start_indicator)
|
698 |
+
is_noise = np.equal(np.mod(span_id, 2), 1)
|
699 |
+
|
700 |
+
mask = is_noise[:starting_length]
|
701 |
+
|
702 |
+
return mask
|
703 |
+
|
704 |
+
|
705 |
+
def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],
|
706 |
+
mask: np.ndarray,
|
707 |
+
use_sentinels: bool,
|
708 |
+
eos_token_id: int,
|
709 |
+
sentinel_token_ids: np.ndarray,
|
710 |
+
ensure_eos: bool = True) -> np.ndarray:
|
711 |
+
"""Remove or replace masked portions from token sequence."""
|
712 |
+
if not use_sentinels:
|
713 |
+
# The logic is simple if we don't use sentinel tokens
|
714 |
+
noised_tokens = np.array(tokens)[np.logical_not(mask)]
|
715 |
+
|
716 |
+
# Ensure there's an end-of-sentence token at the end
|
717 |
+
if ensure_eos and (noised_tokens[-1] != eos_token_id):
|
718 |
+
noised_tokens = np.concatenate(
|
719 |
+
[noised_tokens, np.array([eos_token_id])])
|
720 |
+
|
721 |
+
return noised_tokens
|
722 |
+
|
723 |
+
# Masking at previous token
|
724 |
+
prev_token_mask = np.concatenate([np.array([0]), mask[:-1]])
|
725 |
+
|
726 |
+
# Decompose mask into start-of-span mask and non-start-of-span mask
|
727 |
+
start_of_noise_span_token = np.logical_and(mask,
|
728 |
+
np.logical_not(prev_token_mask))
|
729 |
+
nonstart_noise_span_token = np.logical_and(mask, prev_token_mask)
|
730 |
+
|
731 |
+
# Replace tokens at the start of each noise span with its corresponding
|
732 |
+
# sentinel token
|
733 |
+
sentinel_idx = np.minimum(len(sentinel_token_ids),
|
734 |
+
np.cumsum(start_of_noise_span_token)) - 1
|
735 |
+
tokens = np.where(start_of_noise_span_token,
|
736 |
+
sentinel_token_ids[sentinel_idx], tokens)
|
737 |
+
|
738 |
+
# Remove masked tokens (but preserving the sentinel tokens)
|
739 |
+
noised_tokens = tokens[np.logical_not(nonstart_noise_span_token)]
|
740 |
+
|
741 |
+
# Ensure there's an end-of-sentence token at the end
|
742 |
+
if ensure_eos and (noised_tokens[-1] != eos_token_id):
|
743 |
+
noised_tokens = np.concatenate(
|
744 |
+
[noised_tokens, np.array([eos_token_id])])
|
745 |
+
return noised_tokens
|
746 |
+
|
747 |
+
|
748 |
+
def _format_tokens_for_encoder_decoder(
|
749 |
+
tokens_inputs: torch.LongTensor,
|
750 |
+
tokens_labels: torch.LongTensor,
|
751 |
+
max_seq_length: int,
|
752 |
+
pad_token_id: int,
|
753 |
+
) -> Dict[str, torch.Tensor]:
|
754 |
+
"""Package the input/label sequence for an EncDec model."""
|
755 |
+
example = {}
|
756 |
+
# Re-populate with an empty, padded example
|
757 |
+
example['input_ids'] = torch.full((max_seq_length,),
|
758 |
+
pad_token_id,
|
759 |
+
dtype=torch.int32)
|
760 |
+
example['labels'] = torch.full((max_seq_length,),
|
761 |
+
_HF_IGNORE_INDEX,
|
762 |
+
dtype=torch.int32)
|
763 |
+
example['attention_mask'] = torch.zeros_like(example['input_ids'])
|
764 |
+
example['decoder_attention_mask'] = torch.zeros_like(example['labels'])
|
765 |
+
|
766 |
+
# Fill in with processed results (Note: EncDec format is right-padded)
|
767 |
+
example['input_ids'][:len(tokens_inputs)] = tokens_inputs
|
768 |
+
example['labels'][:len(tokens_labels)] = tokens_labels
|
769 |
+
example['attention_mask'][:len(tokens_inputs)] = 1
|
770 |
+
example['decoder_attention_mask'][:len(tokens_labels)] = 1
|
771 |
+
|
772 |
+
# Best practice is to include decoder_input_ids (= right-shifted labels)
|
773 |
+
example['decoder_input_ids'] = torch.full_like(example['labels'],
|
774 |
+
pad_token_id)
|
775 |
+
example['decoder_input_ids'][1:len(tokens_labels)] = tokens_labels[:-1]
|
776 |
+
return example
|
777 |
+
|
778 |
+
|
779 |
+
def _format_tokens_for_decoder_only(
|
780 |
+
tokens_inputs: torch.LongTensor,
|
781 |
+
tokens_labels: torch.LongTensor,
|
782 |
+
max_seq_length: int,
|
783 |
+
pad_token_id: int,
|
784 |
+
padding_side: str,
|
785 |
+
) -> Dict[str, torch.Tensor]:
|
786 |
+
"""Package the input/label sequence for an decoder-only model."""
|
787 |
+
example = {}
|
788 |
+
# Re-populate with an empty, padded example
|
789 |
+
example['input_ids'] = torch.full((max_seq_length,),
|
790 |
+
pad_token_id,
|
791 |
+
dtype=torch.int32)
|
792 |
+
example['labels'] = torch.full((max_seq_length,),
|
793 |
+
_HF_IGNORE_INDEX,
|
794 |
+
dtype=torch.int32)
|
795 |
+
example['attention_mask'] = torch.full((max_seq_length,),
|
796 |
+
0,
|
797 |
+
dtype=torch.bool)
|
798 |
+
example['bidirectional_mask'] = torch.full((max_seq_length,),
|
799 |
+
0,
|
800 |
+
dtype=torch.bool)
|
801 |
+
|
802 |
+
n_input = len(tokens_inputs)
|
803 |
+
n_label = len(tokens_labels)
|
804 |
+
n_concat = n_input + n_label
|
805 |
+
assert n_concat <= max_seq_length, f'{n_concat=}, {n_input=}, {n_label=}'
|
806 |
+
|
807 |
+
tokens_concat = torch.concat([tokens_inputs, tokens_labels], dim=0)
|
808 |
+
|
809 |
+
# Fill in with the processed results
|
810 |
+
if padding_side == 'left':
|
811 |
+
example['input_ids'][-n_concat:] = tokens_concat
|
812 |
+
# `labels` copies `input_ids` but with -100 at
|
813 |
+
# non-loss-generating tokens. `labels` will be shifted in the
|
814 |
+
# model code when computing loss.
|
815 |
+
example['labels'][-n_concat:] = tokens_concat
|
816 |
+
example['labels'][-n_concat:-n_label] = _HF_IGNORE_INDEX
|
817 |
+
example['attention_mask'][-n_concat:] = 1
|
818 |
+
example['bidirectional_mask'][-n_concat:-n_label] = 1
|
819 |
+
else:
|
820 |
+
example['input_ids'][:n_concat] = tokens_concat
|
821 |
+
# See above comment regarding `labels`
|
822 |
+
example['labels'][:n_concat] = tokens_concat
|
823 |
+
example['labels'][:n_input] = _HF_IGNORE_INDEX
|
824 |
+
example['attention_mask'][:n_concat] = 1
|
825 |
+
example['bidirectional_mask'][:n_input] = 1
|
826 |
+
return example
|
827 |
+
|
828 |
+
|
829 |
+
# Helpful to test if your dataloader is working locally
|
830 |
+
# Run `python denoising.py [local] [remote, optional]` and verify that batches
|
831 |
+
# are printed out
|
832 |
+
if __name__ == '__main__':
|
833 |
+
from llmfoundry.utils.builders import build_tokenizer
|
834 |
+
|
835 |
+
local = sys.argv[1]
|
836 |
+
if len(sys.argv) > 2:
|
837 |
+
remote = sys.argv[2]
|
838 |
+
else:
|
839 |
+
remote = local
|
840 |
+
print(f'Reading val split from {remote} -> {local}')
|
841 |
+
|
842 |
+
decoder_only = True
|
843 |
+
|
844 |
+
cfg = {
|
845 |
+
'name': 'text_denoising',
|
846 |
+
'dataset': {
|
847 |
+
'local': local,
|
848 |
+
'remote': remote,
|
849 |
+
'split': 'val', # 'val_small',
|
850 |
+
'shuffle': False,
|
851 |
+
'max_seq_len': 2048 if decoder_only else 1024,
|
852 |
+
'packing_ratio': 4.5,
|
853 |
+
'predownload': 1000,
|
854 |
+
'keep_zip': True, # in case we need compressed files after testing
|
855 |
+
},
|
856 |
+
'mixture_of_denoisers': {
|
857 |
+
'decoder_only_format': decoder_only,
|
858 |
+
'span_mean_lengths_and_ratios': [[3, .15], [8, .5]],
|
859 |
+
'sequence_mask_ratios': 0.25,
|
860 |
+
},
|
861 |
+
'drop_last': False,
|
862 |
+
'num_workers': 0,
|
863 |
+
}
|
864 |
+
cfg = om.create(cfg)
|
865 |
+
device_batch_size = 2
|
866 |
+
|
867 |
+
tokenizer_name = 'EleutherAI/gpt-neox-20b' if decoder_only else 't5-base'
|
868 |
+
tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len}
|
869 |
+
tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
|
870 |
+
tokenizer_kwargs=tokenizer_kwargs)
|
871 |
+
|
872 |
+
loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
|
873 |
+
assert isinstance(loader.dataset, StreamingTextDataset)
|
874 |
+
|
875 |
+
print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
|
876 |
+
|
877 |
+
packing = cfg.dataset.get('packing_ratio') is not None
|
878 |
+
if packing:
|
879 |
+
tokenizer = loader.collate_fn.base_collator.tokenizer
|
880 |
+
else:
|
881 |
+
tokenizer = loader.collate_fn.tokenizer
|
882 |
+
batch_ix = 0
|
883 |
+
for batch in loader:
|
884 |
+
if batch_ix >= 50:
|
885 |
+
batch_ix += 1
|
886 |
+
break
|
887 |
+
if batch_ix >= 5:
|
888 |
+
if not packing:
|
889 |
+
break
|
890 |
+
batch_ix += 1
|
891 |
+
continue
|
892 |
+
print('\n')
|
893 |
+
print('#' * 20, f'Batch {batch_ix}', '#' * 20)
|
894 |
+
for k, v in batch.items():
|
895 |
+
print(k, v.shape, v.dtype)
|
896 |
+
for sample_ix, token_sample in enumerate(batch['input_ids']):
|
897 |
+
if cfg.mixture_of_denoisers.decoder_only_format:
|
898 |
+
labels = batch['labels'][sample_ix]
|
899 |
+
attn_inputs = batch['bidirectional_mask'][sample_ix].to(
|
900 |
+
torch.bool)
|
901 |
+
attn_full = batch['attention_mask'][sample_ix].to(torch.bool)
|
902 |
+
attn_labels = torch.logical_xor(attn_inputs, attn_full)
|
903 |
+
print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
|
904 |
+
if packing:
|
905 |
+
for subseq in range(
|
906 |
+
int(batch['sequence_id'][sample_ix].max()) + 1):
|
907 |
+
is_subseq = batch['sequence_id'][sample_ix] == subseq
|
908 |
+
print(
|
909 |
+
'\033[93m{}\033[00m\n'.format('Input: '),
|
910 |
+
tokenizer.decode(token_sample[torch.logical_and(
|
911 |
+
is_subseq, attn_inputs)]))
|
912 |
+
print(
|
913 |
+
'\033[92m{}\033[00m\n'.format('Target: '),
|
914 |
+
tokenizer.decode(labels[torch.logical_and(
|
915 |
+
is_subseq, attn_labels)]))
|
916 |
+
else:
|
917 |
+
print('\033[91m{}\033[00m\n'.format('Full: '),
|
918 |
+
tokenizer.decode(token_sample[attn_full]))
|
919 |
+
print('\033[93m{}\033[00m\n'.format('Input: '),
|
920 |
+
tokenizer.decode(token_sample[attn_inputs]))
|
921 |
+
print('\033[92m{}\033[00m\n'.format('Target: '),
|
922 |
+
tokenizer.decode(labels[attn_labels]))
|
923 |
+
else:
|
924 |
+
labels = batch['labels'][sample_ix]
|
925 |
+
attn_inputs = batch['attention_mask'][sample_ix].to(torch.bool)
|
926 |
+
attn_labels = batch['decoder_attention_mask'][sample_ix].to(
|
927 |
+
torch.bool)
|
928 |
+
print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
|
929 |
+
print('\033[93m{}\033[00m\n'.format('Input: '),
|
930 |
+
tokenizer.decode(token_sample[attn_inputs]))
|
931 |
+
print('\033[92m{}\033[00m\n'.format('Target: '),
|
932 |
+
tokenizer.decode(labels[attn_labels]))
|
933 |
+
batch_ix += 1
|
934 |
+
|
935 |
+
if packing:
|
936 |
+
print(f'Padding = {100*(1-loader.collate_fn.efficiency):5.2f}%')
|
937 |
+
print(f'Waste = {100*loader.collate_fn.waste:5.2f}%')
|
Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
|
5 |
+
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
|
6 |
+
|
7 |
+
__all__ = ['Seq2SeqFinetuningCollator', 'build_finetuning_dataloader']
|
Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/collator.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import logging
|
5 |
+
import warnings
|
6 |
+
from typing import Any, Dict, List, Optional, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
10 |
+
|
11 |
+
log = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
# HuggingFace hardcodes the ignore index to -100
|
14 |
+
_HF_IGNORE_INDEX = -100
|
15 |
+
|
16 |
+
|
17 |
+
class Seq2SeqFinetuningCollator:
|
18 |
+
"""A general-purpose collator for sequence-to-sequence training/evaluation.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
|
22 |
+
max_seq_len (int): The maximum sequence length of the combined
|
23 |
+
context/target sequence (decoder-only format) or of each the
|
24 |
+
context sequence and target sequence (encoder-decoder format).
|
25 |
+
decoder_only_format (bool): Whether to format the batches for a
|
26 |
+
decoder-only model (if True) or an encoder-decoder model (if False).
|
27 |
+
allow_pad_trimming (bool, optional): Whether to allow the collator
|
28 |
+
to trim padding, which may result in smaller but inconsistent batch
|
29 |
+
sizes. Default: ``False`` ensures that all sequences are max_seq_len.
|
30 |
+
separator_text (str | bool, optional): If a string is provided, it will
|
31 |
+
be used to separate the context and target sequences (appended to end
|
32 |
+
of context). If ``True``, will use the tokenizer's sep_token, which must
|
33 |
+
be defined. Only applicable for decoder-only formatting.
|
34 |
+
format_for_generation (bool, optional): Whether to format the batch such
|
35 |
+
that context and target sequences remain separated, which is useful
|
36 |
+
when using the context to generate text which should be compared to the
|
37 |
+
target (e.g., during evaluation). Default: ``False``.
|
38 |
+
batch_metadata (dict, optional): A dictionary of metadata which will be added
|
39 |
+
to the batch.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
45 |
+
max_seq_len: int,
|
46 |
+
decoder_only_format: bool,
|
47 |
+
allow_pad_trimming: bool = False,
|
48 |
+
separator_text: Optional[Union[str, bool]] = None,
|
49 |
+
format_for_generation: bool = False,
|
50 |
+
batch_metadata: Optional[Dict[str, Any]] = None,
|
51 |
+
):
|
52 |
+
self.tokenizer = tokenizer
|
53 |
+
self.max_seq_len = max_seq_len
|
54 |
+
self.decoder_only_format = decoder_only_format
|
55 |
+
self.format_for_generation = format_for_generation
|
56 |
+
self.batch_metadata = batch_metadata or {}
|
57 |
+
|
58 |
+
# Trimming will always be skipped on at least the first __call__
|
59 |
+
self._allow_pad_trimming = allow_pad_trimming
|
60 |
+
self._seen_first_batch = False
|
61 |
+
|
62 |
+
illegal_keys = [
|
63 |
+
'input_ids', 'labels', 'attention_mask', 'decoder_input_ids',
|
64 |
+
'decoder_attention_mask', 'generate_output'
|
65 |
+
]
|
66 |
+
found_keys = []
|
67 |
+
for illegal_key in illegal_keys:
|
68 |
+
if illegal_key in self.batch_metadata:
|
69 |
+
found_keys.append(illegal_key)
|
70 |
+
if found_keys:
|
71 |
+
raise ValueError(
|
72 |
+
f'The following keys are in batch_metadata but are not allowed: {", ".join(found_keys)}.\n' +\
|
73 |
+
f'You cannot use keys that are used directly by the models. The prohibited keys are:\n' +\
|
74 |
+
f'{", ".join(illegal_keys)}'
|
75 |
+
)
|
76 |
+
if self.format_for_generation:
|
77 |
+
self.batch_metadata['generate_output'] = True
|
78 |
+
|
79 |
+
if (max_seq_len % 8) != 0:
|
80 |
+
log.warning(
|
81 |
+
'For performance, a max_seq_len as a multiple of 8 is recommended.'
|
82 |
+
)
|
83 |
+
|
84 |
+
if self.tokenizer.pad_token_id is None:
|
85 |
+
raise ValueError(
|
86 |
+
f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None'
|
87 |
+
)
|
88 |
+
|
89 |
+
self.separator_tokens = []
|
90 |
+
if separator_text and decoder_only_format:
|
91 |
+
if separator_text == True:
|
92 |
+
# Use the tokenizer's sep token or throw an error if undefined
|
93 |
+
if self.tokenizer.sep_token_id is None:
|
94 |
+
raise ValueError(
|
95 |
+
'Setting separator_text=True requires that the tokenizer has sep_token_id but it has not been set. ' +\
|
96 |
+
'Please pass a string argument for separator_text or set sep_token_id in the tokenizer.'
|
97 |
+
)
|
98 |
+
self.separator_tokens = [self.tokenizer.sep_token_id]
|
99 |
+
else:
|
100 |
+
# Convert the string separator_text into token(s)
|
101 |
+
self.separator_tokens = tokenizer(
|
102 |
+
separator_text, add_special_tokens=False).input_ids
|
103 |
+
|
104 |
+
self._warned_context = False
|
105 |
+
self._warned_target = False
|
106 |
+
|
107 |
+
def __call__(self, examples: List[Dict[str,
|
108 |
+
Any]]) -> Dict[str, torch.Tensor]:
|
109 |
+
for check_key in ['input_ids', 'labels', 'attention_mask']:
|
110 |
+
if check_key not in examples[0]:
|
111 |
+
raise KeyError(
|
112 |
+
f'Examples returned by dataset do not include required key: {check_key}'
|
113 |
+
)
|
114 |
+
|
115 |
+
if self.decoder_only_format:
|
116 |
+
batch = self._process_and_batch_decoder_only(examples)
|
117 |
+
else:
|
118 |
+
batch = self._process_and_batch_encoder_decoder(examples)
|
119 |
+
|
120 |
+
# Add any batch_metadata
|
121 |
+
batch_size = batch['input_ids'].shape[0]
|
122 |
+
batch.update({
|
123 |
+
k: torch.tensor([v] * batch_size)
|
124 |
+
for k, v in self.batch_metadata.items()
|
125 |
+
})
|
126 |
+
|
127 |
+
return batch
|
128 |
+
|
129 |
+
def _process_and_batch_decoder_only(
|
130 |
+
self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
131 |
+
# Steps explained in comments
|
132 |
+
processed_examples = []
|
133 |
+
for example in examples:
|
134 |
+
context = ensure_list(example['input_ids'])
|
135 |
+
target = ensure_list(example['labels'])
|
136 |
+
# First, get rid of any padding tokens
|
137 |
+
context = [t for t in context if t != self.tokenizer.pad_token_id]
|
138 |
+
target = [t for t in target if t != self.tokenizer.pad_token_id]
|
139 |
+
# Second, append any separator tokens to the context tokens
|
140 |
+
if self.separator_tokens:
|
141 |
+
context = context + self.separator_tokens
|
142 |
+
# Third, ensure that the target text ends with an eos tag
|
143 |
+
if target[-1] != self.tokenizer.eos_token_id:
|
144 |
+
target = target + [self.tokenizer.eos_token_id]
|
145 |
+
|
146 |
+
n_context = len(context)
|
147 |
+
n_target = len(target)
|
148 |
+
|
149 |
+
if n_context >= self.max_seq_len:
|
150 |
+
if not self._warned_context:
|
151 |
+
warnings.warn(
|
152 |
+
f'Skipping example because CONTEXT length={n_context} leaves no room ' +\
|
153 |
+
f'for TARGET tokens because max_seq_len={self.max_seq_len}. ' +\
|
154 |
+
f'If this causes downstream issues because of inconsistent batch sizes, ' +\
|
155 |
+
f'consider increasing max_seq_len or using example packing.'
|
156 |
+
)
|
157 |
+
self._warned_context = True
|
158 |
+
continue
|
159 |
+
|
160 |
+
if self.format_for_generation:
|
161 |
+
# When formatting for generation, we need to keep input_ids and
|
162 |
+
# labels separate. The input_ids (context) will be fed into the
|
163 |
+
# generator and the labels will be used by the eval metric.
|
164 |
+
input_ids = context[-self.max_seq_len:]
|
165 |
+
n_context = len(input_ids)
|
166 |
+
attention_mask = [1] * n_context
|
167 |
+
bidirectional_mask = [1] * n_context
|
168 |
+
# Annoyingly, we need to pad the everything but input_ids
|
169 |
+
# and attention_mask ourselves
|
170 |
+
i_pad = [self.tokenizer.pad_token_id
|
171 |
+
] * (self.max_seq_len - n_target)
|
172 |
+
z_pad = [0] * (self.max_seq_len - n_context)
|
173 |
+
if self.tokenizer.padding_side == 'left':
|
174 |
+
labels = i_pad + target
|
175 |
+
bidirectional_mask = z_pad + bidirectional_mask
|
176 |
+
else:
|
177 |
+
labels = target + i_pad
|
178 |
+
bidirectional_mask = bidirectional_mask + z_pad
|
179 |
+
|
180 |
+
else:
|
181 |
+
# We need to concatenate the context and target to get the
|
182 |
+
# full input sequence, cutting off any excess tokens from the
|
183 |
+
# end of the target
|
184 |
+
if n_context + n_target > self.max_seq_len:
|
185 |
+
old_n_target = int(n_target)
|
186 |
+
n_target = self.max_seq_len - n_context
|
187 |
+
if not self._warned_target:
|
188 |
+
warnings.warn(
|
189 |
+
f'Truncating TARGET sequence of length={old_n_target} to length={n_target}, ' +\
|
190 |
+
f'so context+target fit max_seq_len={self.max_seq_len}. If truncation is ' +\
|
191 |
+
f'a problem, consider increasing max_seq_len.')
|
192 |
+
self._warned_target = True
|
193 |
+
target = target[-n_target:]
|
194 |
+
target[-1] = self.tokenizer.eos_token_id
|
195 |
+
n_total = n_context + n_target
|
196 |
+
|
197 |
+
input_ids = context + target
|
198 |
+
labels = ([_HF_IGNORE_INDEX] * n_context) + target
|
199 |
+
attention_mask = [1] * n_total
|
200 |
+
# bidirectional_mask is used by our prefix lm model variants
|
201 |
+
bidirectional_mask = ([1] * n_context) + ([0] * n_target)
|
202 |
+
|
203 |
+
# Annoyingly, we need to pad the everything but input_ids
|
204 |
+
# and attention_mask ourselves
|
205 |
+
i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
|
206 |
+
z_pad = [0] * (self.max_seq_len - n_total)
|
207 |
+
if self.tokenizer.padding_side == 'left':
|
208 |
+
labels = i_pad + labels
|
209 |
+
bidirectional_mask = z_pad + bidirectional_mask
|
210 |
+
else:
|
211 |
+
labels = labels + i_pad
|
212 |
+
bidirectional_mask = bidirectional_mask + z_pad
|
213 |
+
|
214 |
+
# Update the example
|
215 |
+
example['input_ids'] = input_ids
|
216 |
+
example['labels'] = labels
|
217 |
+
example['attention_mask'] = attention_mask
|
218 |
+
example['bidirectional_mask'] = bidirectional_mask
|
219 |
+
|
220 |
+
processed_examples.append(example)
|
221 |
+
|
222 |
+
batch = self.tokenizer.pad(
|
223 |
+
processed_examples,
|
224 |
+
padding='max_length',
|
225 |
+
max_length=self.max_seq_len,
|
226 |
+
return_tensors='pt',
|
227 |
+
)
|
228 |
+
|
229 |
+
# This logic prevents trimming on at least the first batch
|
230 |
+
if not (self._allow_pad_trimming and self._seen_first_batch):
|
231 |
+
self._seen_first_batch = True
|
232 |
+
return batch
|
233 |
+
self._seen_first_batch = True
|
234 |
+
|
235 |
+
# The batch is ready, but we can trim padding for efficiency
|
236 |
+
multiple_of = 8
|
237 |
+
|
238 |
+
n_non_padding = batch['attention_mask'].sum(dim=1).max()
|
239 |
+
keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
|
240 |
+
for k, v in batch.items():
|
241 |
+
if len(v.shape) < 2:
|
242 |
+
continue
|
243 |
+
if k == 'labels' and self.format_for_generation:
|
244 |
+
continue
|
245 |
+
if self.tokenizer.padding_side == 'left':
|
246 |
+
batch[k] = v[:, -keep_tokens:].contiguous()
|
247 |
+
else:
|
248 |
+
batch[k] = v[:, :keep_tokens].contiguous()
|
249 |
+
|
250 |
+
return batch
|
251 |
+
|
252 |
+
def _process_and_batch_encoder_decoder(
|
253 |
+
self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
254 |
+
# The encoder-decoder case is has some gotchas.
|
255 |
+
# Steps are explained in comments.
|
256 |
+
processed_examples = []
|
257 |
+
for example in examples:
|
258 |
+
context = ensure_list(example['input_ids'])
|
259 |
+
target = ensure_list(example['labels'])
|
260 |
+
# ... first, get rid of any padding that was already applied
|
261 |
+
context = [t for t in context if t != self.tokenizer.pad_token_id]
|
262 |
+
target = [t for t in target if t != self.tokenizer.pad_token_id]
|
263 |
+
# ... second, ensure that the target text ends with an eos tag
|
264 |
+
if target[-1] != self.tokenizer.eos_token_id:
|
265 |
+
target = target + [self.tokenizer.eos_token_id]
|
266 |
+
# ... third, we need to pad labels ourselves. Because HF.
|
267 |
+
if len(target) < self.max_seq_len:
|
268 |
+
i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
|
269 |
+
target = target + i_pad
|
270 |
+
else:
|
271 |
+
if not self._warned_target:
|
272 |
+
warnings.warn(
|
273 |
+
f'Truncating TARGET sequence of length={len(target)} ' +\
|
274 |
+
f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
|
275 |
+
f'a problem, consider increasing max_seq_len.')
|
276 |
+
self._warned_target = True
|
277 |
+
target = target[:self.max_seq_len -
|
278 |
+
1] + [self.tokenizer.eos_token_id]
|
279 |
+
|
280 |
+
# We might need to truncate the context. Preserve the beginning.
|
281 |
+
if len(context) > self.max_seq_len:
|
282 |
+
if not self._warned_context:
|
283 |
+
warnings.warn(
|
284 |
+
f'Truncating CONTEXT sequence of length={len(context)} ' +\
|
285 |
+
f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
|
286 |
+
f'a problem, consider increasing max_seq_len.')
|
287 |
+
self._warned_context = True
|
288 |
+
context = context[:self.max_seq_len -
|
289 |
+
1] + [self.tokenizer.eos_token_id]
|
290 |
+
|
291 |
+
# Back into the example
|
292 |
+
example['input_ids'] = context
|
293 |
+
example['attention_mask'] = [1] * len(context)
|
294 |
+
example['labels'] = target
|
295 |
+
|
296 |
+
processed_examples.append(example)
|
297 |
+
|
298 |
+
# Batch examples into a single dict (this also pads)
|
299 |
+
batch = self.tokenizer.pad(
|
300 |
+
processed_examples,
|
301 |
+
padding='max_length',
|
302 |
+
max_length=self.max_seq_len,
|
303 |
+
return_tensors='pt',
|
304 |
+
)
|
305 |
+
# We're still missing decoder_input_ids and decoder_attention_mask
|
306 |
+
batch['decoder_input_ids'] = torch.cat([
|
307 |
+
torch.full((len(processed_examples), 1),
|
308 |
+
self.tokenizer.pad_token_id), batch['labels'][:, :-1]
|
309 |
+
],
|
310 |
+
dim=1)
|
311 |
+
batch['decoder_input_ids'].masked_fill_(
|
312 |
+
batch['decoder_input_ids'] == _HF_IGNORE_INDEX,
|
313 |
+
self.tokenizer.pad_token_id)
|
314 |
+
batch['decoder_attention_mask'] = torch.not_equal(
|
315 |
+
batch['labels'], _HF_IGNORE_INDEX)
|
316 |
+
|
317 |
+
# This logic prevents trimming on at least the first batch
|
318 |
+
if not (self._allow_pad_trimming and self._seen_first_batch):
|
319 |
+
self._seen_first_batch = True
|
320 |
+
return batch
|
321 |
+
self._seen_first_batch = True
|
322 |
+
|
323 |
+
# The batch is now valid, but we can trim padding for efficiency
|
324 |
+
multiple_of = 8
|
325 |
+
# (first for the encoder)
|
326 |
+
n_non_padding = batch['attention_mask'].sum(dim=1).max()
|
327 |
+
keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
|
328 |
+
for k in ['input_ids', 'attention_mask']:
|
329 |
+
batch[k] = batch[k][:, :keep_tokens].contiguous()
|
330 |
+
# (then for the decoder)
|
331 |
+
n_non_padding = batch['decoder_attention_mask'].sum(dim=1).max()
|
332 |
+
keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
|
333 |
+
for k in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
|
334 |
+
batch[k] = batch[k][:, :keep_tokens].contiguous()
|
335 |
+
|
336 |
+
return batch
|
337 |
+
|
338 |
+
|
339 |
+
def ensure_list(x: Union[List, torch.Tensor]) -> List:
|
340 |
+
if isinstance(x, torch.Tensor):
|
341 |
+
x = list(x.flatten())
|
342 |
+
assert isinstance(x, list)
|
343 |
+
return x
|
Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/dataloader.py
ADDED
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
from typing import Tuple, Union
|
6 |
+
|
7 |
+
import datasets as hf_datasets
|
8 |
+
import torch
|
9 |
+
from composer.utils import dist, get_file, parse_uri
|
10 |
+
from omegaconf import DictConfig
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from transformers import PreTrainedTokenizerBase
|
13 |
+
|
14 |
+
from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
|
15 |
+
from llmfoundry.data.finetuning.tasks import dataset_constructor
|
16 |
+
from llmfoundry.data.packing import BinPackWrapper
|
17 |
+
|
18 |
+
log = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
# HuggingFace hardcodes the ignore index to -100
|
21 |
+
_HF_IGNORE_INDEX = -100
|
22 |
+
|
23 |
+
|
24 |
+
def build_finetuning_dataloader(cfg: DictConfig,
|
25 |
+
tokenizer: PreTrainedTokenizerBase,
|
26 |
+
device_batch_size: int) -> DataLoader:
|
27 |
+
"""Builds a finetuning dataloader for training or evaluating.
|
28 |
+
|
29 |
+
The underlying dataset can be built through one of two code paths:
|
30 |
+
1. As a HuggingFace dataset, via `datasets.load_dataset(...)`
|
31 |
+
2. As a streaming dataset
|
32 |
+
You will need to set slightly different dataset config fields depending
|
33 |
+
on which you intend to use, as explained below.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
cfg (DictConfig): An omegaconf dictionary used to configure the loader:
|
37 |
+
cfg.name (str): The type of dataloader to build. Must = "finetuning".
|
38 |
+
---
|
39 |
+
*** HuggingFace dataset config fields ***
|
40 |
+
cfg.dataset.hf_name (str, optional): The name of the HuggingFace dataset
|
41 |
+
to use. Can also be a remote http(s) directory or object store bucket
|
42 |
+
containing the file {split}.jsonl in the format (prompt, response),
|
43 |
+
in which case the builder will create a HuggingFace dataset.
|
44 |
+
cfg.dataset.hf_kwargs (DictConfig, optional): Additional kwargs to
|
45 |
+
pass to `datasets.load_dataset`, which can be used to load
|
46 |
+
a dataset from local files.
|
47 |
+
cfg.dataset.preprocessing_fn (str, optional): The name/import path of
|
48 |
+
the preprocessing function to use for formatting the data examples.
|
49 |
+
If ``None`` (default), the builder will use the preprocessing function
|
50 |
+
registered under `hf_name` (see `tasks.py`), if one exists,
|
51 |
+
otherwise it will skip preprocessing.
|
52 |
+
If `preprocessing_fn` corresponds to a registered preprocessing
|
53 |
+
function in `tasks.py`, the builder will use that.
|
54 |
+
Otherwise, it will interpret `preprocessing_fn` as a
|
55 |
+
"import.path:function_name" import path; e.g., it will call
|
56 |
+
`from import.path import function_name` and use the imported
|
57 |
+
function as the preprocessing function.
|
58 |
+
*** Streaming dataset config fields ***
|
59 |
+
cfg.dataset.remote (str, optional): Location of a MDS-formatted
|
60 |
+
streaming dataset to use. Setting this will tell the builder
|
61 |
+
to create a streaming dataset rather than a HuggingFace dataset.
|
62 |
+
cfg.dataset.local (str, optional): Local path where remote data
|
63 |
+
will be streamed to. Only valid if `cfg.dataset.remote` has
|
64 |
+
also been set.
|
65 |
+
*** Shared dataset configs fields ***
|
66 |
+
cfg.dataset.max_seq_len (int): The maximum length of sequences
|
67 |
+
in the batch. See :class:`Seq2SeqFinetuningCollator` docstring
|
68 |
+
for details.
|
69 |
+
cfg.dataset.decoder_only_format (bool): Whether to format the
|
70 |
+
examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator`
|
71 |
+
docstring for details.
|
72 |
+
cfg.dataset.allow_pad_trimming (bool, optional): Whether to allow
|
73 |
+
the collator to trim padding. See :class:`Seq2SeqFinetuningCollator`
|
74 |
+
docstring for details. Default: ``False``.
|
75 |
+
cfg.dataset.packing_ratio (float, optional): If provided, this invokes
|
76 |
+
a collator wrapper that packs `device_batch_size*packing_ratio`
|
77 |
+
raw examples into `device_batch_size` packed examples. This helps
|
78 |
+
minimize padding while preserving sequence integrity.
|
79 |
+
This adds `sequence_id` to the batch, which indicates which unique
|
80 |
+
sequence each token belongs to.
|
81 |
+
Note: Using this feature will not change device_batch_size but it
|
82 |
+
will determine the number of raw examples consumed by the dataloader
|
83 |
+
per batch. Some examples may be discarded if they do not fit when
|
84 |
+
packing.
|
85 |
+
Select `packing_ratio` **carefully** based on the dataset
|
86 |
+
statistics, `max_seq_len`, and tolerance for discarding samples!
|
87 |
+
The packing code in `../packing.py` provides a script that can help
|
88 |
+
you choose the best `packing_ratio`.
|
89 |
+
cfg.dataset.shuffle (bool): Whether to shuffle the dataset.
|
90 |
+
___
|
91 |
+
See :class:`StreamingFinetuningDataset` for info on other standard config
|
92 |
+
options within `cfg.dataset` that will be passed as kwargs if
|
93 |
+
using the streaming codepath.
|
94 |
+
---
|
95 |
+
See :class:`DataLoader` for standard argument options to the pytorch
|
96 |
+
dataloader, such as `cfg.drop_last`, `cfg.num_workers`, etc.
|
97 |
+
tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
|
98 |
+
prepare the data from raw text. Any missing sentinel tokens will
|
99 |
+
be added by the collator.
|
100 |
+
device_batch_size (int): The size of the batches (number of examples)
|
101 |
+
that the dataloader will produce.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
A pytorch dataloader
|
105 |
+
|
106 |
+
Note:
|
107 |
+
You can run the script inside `../packing.py` to quickly test the
|
108 |
+
padding/waste rates for different `cfg.dataset.packing_ratio` choices,
|
109 |
+
given a starting workload YAML.
|
110 |
+
"""
|
111 |
+
_validate_config(cfg.dataset)
|
112 |
+
|
113 |
+
# Use EOS as the pad token if none exists
|
114 |
+
if tokenizer.pad_token is None:
|
115 |
+
tokenizer.pad_token = tokenizer.eos_token
|
116 |
+
|
117 |
+
dataset = None # for pyright
|
118 |
+
if cfg.dataset.get('remote') is not None:
|
119 |
+
dataset = dataset_constructor.build_from_streaming(
|
120 |
+
tokenizer=tokenizer,
|
121 |
+
local=cfg.dataset.local,
|
122 |
+
remote=cfg.dataset.get('remote', None),
|
123 |
+
split=cfg.dataset.get('split', None),
|
124 |
+
download_retry=cfg.dataset.get('download_retry', 2),
|
125 |
+
download_timeout=cfg.dataset.get('download_timeout', 60),
|
126 |
+
validate_hash=cfg.dataset.get('validate_hash', None),
|
127 |
+
keep_zip=cfg.dataset.get('keep_zip', False),
|
128 |
+
epoch_size=cfg.dataset.get('epoch_size', None),
|
129 |
+
predownload=cfg.dataset.get('predownload', None),
|
130 |
+
cache_limit=cfg.dataset.get('cache_limit', None),
|
131 |
+
partition_algo=cfg.dataset.get('partition_algo', 'orig'),
|
132 |
+
num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
|
133 |
+
batch_size=device_batch_size,
|
134 |
+
shuffle=cfg.dataset.get('shuffle', False),
|
135 |
+
shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1b'),
|
136 |
+
shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
|
137 |
+
shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18),
|
138 |
+
sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
|
139 |
+
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
|
140 |
+
batching_method=cfg.dataset.get('batching_method', 'random'),
|
141 |
+
)
|
142 |
+
|
143 |
+
collate_fn, dataloader_batch_size = _build_collate_fn(
|
144 |
+
cfg.dataset, tokenizer, device_batch_size)
|
145 |
+
|
146 |
+
return DataLoader(
|
147 |
+
dataset,
|
148 |
+
collate_fn=collate_fn,
|
149 |
+
batch_size=dataloader_batch_size,
|
150 |
+
drop_last=cfg.drop_last,
|
151 |
+
num_workers=cfg.num_workers,
|
152 |
+
pin_memory=cfg.get('pin_memory', True),
|
153 |
+
prefetch_factor=cfg.get('prefetch_factor', 2),
|
154 |
+
persistent_workers=cfg.get('persistent_workers', True),
|
155 |
+
timeout=cfg.get('timeout', 0),
|
156 |
+
)
|
157 |
+
|
158 |
+
else:
|
159 |
+
backend, _, _ = parse_uri(cfg.dataset.hf_name)
|
160 |
+
if backend not in ['', None]:
|
161 |
+
if cfg.dataset.get('split') is None:
|
162 |
+
raise ValueError(
|
163 |
+
'When using a HuggingFace dataset from a URL, you must set the ' + \
|
164 |
+
'`split` key in the dataset config.'
|
165 |
+
)
|
166 |
+
dataset = _build_hf_dataset_from_remote(cfg, tokenizer)
|
167 |
+
else:
|
168 |
+
dataset = dataset_constructor.build_from_hf(
|
169 |
+
cfg.dataset,
|
170 |
+
max_seq_len=cfg.dataset.max_seq_len,
|
171 |
+
tokenizer=tokenizer,
|
172 |
+
)
|
173 |
+
|
174 |
+
collate_fn, dataloader_batch_size = _build_collate_fn(
|
175 |
+
cfg.dataset, tokenizer, device_batch_size)
|
176 |
+
|
177 |
+
if cfg.drop_last:
|
178 |
+
world_size = dist.get_world_size()
|
179 |
+
minimum_dataset_size = world_size * dataloader_batch_size
|
180 |
+
if hasattr(dataset, '__len__'):
|
181 |
+
full_dataset_size = len(dataset)
|
182 |
+
if full_dataset_size < minimum_dataset_size:
|
183 |
+
raise ValueError(
|
184 |
+
f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) '
|
185 |
+
+
|
186 |
+
f'has {full_dataset_size} samples, but your minimum batch size '
|
187 |
+
+
|
188 |
+
f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
|
189 |
+
+
|
190 |
+
f'your per device batch size is {dataloader_batch_size}. Please increase the number '
|
191 |
+
+
|
192 |
+
f'of samples in your dataset to at least {minimum_dataset_size}.'
|
193 |
+
)
|
194 |
+
|
195 |
+
assert dataset is not None
|
196 |
+
return DataLoader(
|
197 |
+
dataset,
|
198 |
+
collate_fn=collate_fn,
|
199 |
+
batch_size=dataloader_batch_size,
|
200 |
+
drop_last=cfg.drop_last,
|
201 |
+
sampler=dist.get_sampler(dataset,
|
202 |
+
drop_last=cfg.drop_last,
|
203 |
+
shuffle=cfg.dataset.shuffle),
|
204 |
+
num_workers=cfg.num_workers,
|
205 |
+
pin_memory=cfg.get('pin_memory', True),
|
206 |
+
prefetch_factor=cfg.get('prefetch_factor', 2),
|
207 |
+
persistent_workers=cfg.get('persistent_workers', True),
|
208 |
+
timeout=cfg.get('timeout', 0),
|
209 |
+
)
|
210 |
+
|
211 |
+
|
212 |
+
def _validate_config(dataset_cfg: DictConfig) -> None:
|
213 |
+
"""Validates the dataset configuration.
|
214 |
+
|
215 |
+
Makes sure that the dataset is properly configured for either
|
216 |
+
a HuggingFace dataset or a streaming dataset. Must be valid for one or
|
217 |
+
the other.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
dataset_cfg (DictConfig): The dataset configuration to be validated.
|
221 |
+
|
222 |
+
Raises:
|
223 |
+
ValueError: If the dataset configuration does not meet the requirements.
|
224 |
+
"""
|
225 |
+
if dataset_cfg.get('hf_name') is not None:
|
226 |
+
# Using the HuggingFace dataset codepath
|
227 |
+
illegal_keys = ['local', 'remote']
|
228 |
+
discovered_illegal_keys = []
|
229 |
+
for key in illegal_keys:
|
230 |
+
if dataset_cfg.get(key) is not None:
|
231 |
+
discovered_illegal_keys.append('`' + key + '`')
|
232 |
+
if discovered_illegal_keys:
|
233 |
+
raise ValueError(
|
234 |
+
'The dataset config sets a value for `hf_name` as well as the ' +\
|
235 |
+
f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
|
236 |
+
'Those keys are used when building from a streaming dataset, but ' +\
|
237 |
+
'setting `hf_name` instructs the dataset to build from a HuggingFace dataset.'
|
238 |
+
)
|
239 |
+
elif dataset_cfg.get('remote') is not None:
|
240 |
+
# Using the streaming dataset codepath
|
241 |
+
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn']
|
242 |
+
discovered_illegal_keys = []
|
243 |
+
for key in illegal_keys:
|
244 |
+
if dataset_cfg.get(key) is not None:
|
245 |
+
discovered_illegal_keys.append('`' + key + '`')
|
246 |
+
if discovered_illegal_keys:
|
247 |
+
raise ValueError(
|
248 |
+
'The dataset config sets a value for `remote` as well as the ' +\
|
249 |
+
f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
|
250 |
+
'Those keys are used when building from a HuggingFace dataset, but ' +\
|
251 |
+
'setting `remote` instructs the dataset to build from a streaming dataset.'
|
252 |
+
)
|
253 |
+
if dataset_cfg.get('local') is None:
|
254 |
+
raise ValueError(
|
255 |
+
'Using a streaming dataset requires setting both `remote` and `local`, ' +\
|
256 |
+
'but dataset.local is None.'
|
257 |
+
)
|
258 |
+
else:
|
259 |
+
raise ValueError(
|
260 |
+
'In the dataset config, you must set either `hf_name` to use a ' +\
|
261 |
+
'HuggingFace dataset or set `remote` to use a streaming ' +\
|
262 |
+
'dataset, but both were None.'
|
263 |
+
)
|
264 |
+
|
265 |
+
|
266 |
+
def _build_hf_dataset_from_remote(
|
267 |
+
cfg: DictConfig, tokenizer: PreTrainedTokenizerBase
|
268 |
+
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
|
269 |
+
hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
|
270 |
+
"""Builds a dataset from a remote object store.
|
271 |
+
|
272 |
+
This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
|
273 |
+
the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this
|
274 |
+
dataset.
|
275 |
+
|
276 |
+
The function also ensures synchronicity across multiple processes during the file download. It creates a signal
|
277 |
+
file that is used to synchronize the start of the download across different processes. Once the download is
|
278 |
+
completed, the function removes the signal file.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
cfg (DictConfig): The configuration dictionary containing the necessary parameters to load the dataset.
|
282 |
+
This includes:
|
283 |
+
- dataset.hf_name: The path of the HuggingFace dataset to download.
|
284 |
+
- dataset.split: The dataset split to download (e.g., 'train', 'validation', 'test').
|
285 |
+
- dataset.max_seq_len: The maximum sequence length for tokenizing the dataset.
|
286 |
+
|
287 |
+
tokenizer (Tokenizer): The tokenizer to be used to tokenize the dataset.
|
288 |
+
|
289 |
+
Returns:
|
290 |
+
Dataset: A HuggingFace dataset built from the remote file, prepared and tokenized for fine-tuning the model.
|
291 |
+
|
292 |
+
Raises:
|
293 |
+
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
|
294 |
+
"""
|
295 |
+
supported_extensions = ['jsonl', 'csv', 'parquet']
|
296 |
+
# HF datasets does not support a split with dashes, so we replace dashes
|
297 |
+
# with underscores in the destination split.
|
298 |
+
destination_split = cfg.dataset.split.replace('-', '_')
|
299 |
+
finetune_dir = os.path.join(
|
300 |
+
os.path.dirname(
|
301 |
+
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
|
302 |
+
'downloaded_finetuning',
|
303 |
+
destination_split if destination_split != 'data' else 'data_not',
|
304 |
+
)
|
305 |
+
os.makedirs(finetune_dir, exist_ok=True)
|
306 |
+
for extension in supported_extensions:
|
307 |
+
name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
|
308 |
+
destination = str(
|
309 |
+
os.path.abspath(
|
310 |
+
os.path.join(
|
311 |
+
finetune_dir, 'data',
|
312 |
+
f'{destination_split}-00000-of-00001.{extension}')))
|
313 |
+
|
314 |
+
# Since we don't know exactly what the extension will be, since it is one of a list
|
315 |
+
# use a signal file to wait for instead of the desired file
|
316 |
+
signal_file_path = os.path.join(
|
317 |
+
finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed')
|
318 |
+
if dist.get_local_rank() == 0:
|
319 |
+
try:
|
320 |
+
get_file(path=name, destination=destination, overwrite=True)
|
321 |
+
except FileNotFoundError as e:
|
322 |
+
if extension == supported_extensions[-1]:
|
323 |
+
files_searched = [
|
324 |
+
f'{cfg.dataset.hf_name}/{cfg.dataset.split}.{ext}'
|
325 |
+
for ext in supported_extensions
|
326 |
+
]
|
327 |
+
raise FileNotFoundError(
|
328 |
+
f'Could not find a file with any of ' + \
|
329 |
+
f'the supported extensions: {supported_extensions}\n' + \
|
330 |
+
f'at {files_searched}'
|
331 |
+
) from e
|
332 |
+
else:
|
333 |
+
log.debug(
|
334 |
+
f'Could not find {name}, looking for another extension')
|
335 |
+
continue
|
336 |
+
|
337 |
+
os.makedirs(os.path.dirname(signal_file_path), exist_ok=True)
|
338 |
+
with open(signal_file_path, 'wb') as f:
|
339 |
+
f.write(b'local_rank0_completed_download')
|
340 |
+
|
341 |
+
# Avoid the collective call until the local rank zero has finished trying to download the checkpoint
|
342 |
+
# so that we don't timeout for large downloads. This syncs all processes on the node
|
343 |
+
with dist.local_rank_zero_download_and_wait(signal_file_path):
|
344 |
+
# Then, wait to ensure every node has finished downloading the checkpoint
|
345 |
+
dist.barrier()
|
346 |
+
|
347 |
+
# clean up signal file
|
348 |
+
if dist.get_local_rank() == 0:
|
349 |
+
os.remove(signal_file_path)
|
350 |
+
dist.barrier()
|
351 |
+
|
352 |
+
cfg.dataset.hf_name = finetune_dir
|
353 |
+
log.info(cfg.dataset)
|
354 |
+
dataset = dataset_constructor.build_from_hf(
|
355 |
+
cfg.dataset,
|
356 |
+
max_seq_len=cfg.dataset.max_seq_len,
|
357 |
+
tokenizer=tokenizer,
|
358 |
+
)
|
359 |
+
return dataset
|
360 |
+
|
361 |
+
|
362 |
+
def _build_collate_fn(
|
363 |
+
dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
|
364 |
+
device_batch_size: int
|
365 |
+
) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]:
|
366 |
+
collate_fn = Seq2SeqFinetuningCollator(
|
367 |
+
tokenizer=tokenizer,
|
368 |
+
max_seq_len=dataset_cfg.max_seq_len,
|
369 |
+
decoder_only_format=dataset_cfg.decoder_only_format,
|
370 |
+
allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False),
|
371 |
+
)
|
372 |
+
|
373 |
+
packing_ratio = dataset_cfg.get('packing_ratio')
|
374 |
+
if packing_ratio is None:
|
375 |
+
if dataset_cfg.get('max_leftover_bins_to_keep') is not None:
|
376 |
+
raise ValueError(
|
377 |
+
'dataset.max_leftover_bins_to_keep has been defined, ' +\
|
378 |
+
'but dataset.packing_ratio has not been set. Please set ' +\
|
379 |
+
'the latter to turn on packing or remove the former from the config.')
|
380 |
+
return collate_fn, device_batch_size
|
381 |
+
|
382 |
+
if packing_ratio == 1.0:
|
383 |
+
return collate_fn, device_batch_size
|
384 |
+
elif packing_ratio < 1.0:
|
385 |
+
raise ValueError('packing_ratio must be >= 1, if supplied')
|
386 |
+
|
387 |
+
if not dataset_cfg.decoder_only_format:
|
388 |
+
raise NotImplementedError(
|
389 |
+
'On-the-fly packing is currently only supported for decoder-only formats.'
|
390 |
+
)
|
391 |
+
|
392 |
+
collate_fn = BinPackWrapper(
|
393 |
+
collator=collate_fn,
|
394 |
+
target_batch_size=device_batch_size,
|
395 |
+
max_seq_len=dataset_cfg.max_seq_len,
|
396 |
+
pad_token_id=tokenizer.pad_token_id,
|
397 |
+
padding_side=tokenizer.padding_side,
|
398 |
+
max_leftover_bins_to_keep=dataset_cfg.get('max_leftover_bins_to_keep'),
|
399 |
+
)
|
400 |
+
n_examples_to_pack = int(device_batch_size * packing_ratio)
|
401 |
+
return collate_fn, n_examples_to_pack
|
402 |
+
|
403 |
+
|
404 |
+
if __name__ == '__main__':
|
405 |
+
import torch
|
406 |
+
from omegaconf import OmegaConf as om
|
407 |
+
|
408 |
+
from llmfoundry.utils import build_tokenizer
|
409 |
+
cfg = om.create({
|
410 |
+
'dataset': {
|
411 |
+
'hf_name':
|
412 |
+
'tatsu-lab/alpaca',
|
413 |
+
'preprocessing_fn':
|
414 |
+
'llmfoundry.data.finetuning.tasks:alpaca_preprocessing_function',
|
415 |
+
'split':
|
416 |
+
'train',
|
417 |
+
'packing_ratio':
|
418 |
+
18.0,
|
419 |
+
'max_seq_len':
|
420 |
+
2048,
|
421 |
+
'decoder_only_format':
|
422 |
+
True,
|
423 |
+
'separator_text':
|
424 |
+
False,
|
425 |
+
'allow_pad_trimming':
|
426 |
+
False,
|
427 |
+
'num_canonical_nodes':
|
428 |
+
472,
|
429 |
+
'shuffle':
|
430 |
+
True,
|
431 |
+
},
|
432 |
+
'drop_last': False,
|
433 |
+
'num_workers': 0,
|
434 |
+
'pin_memory': False,
|
435 |
+
'prefetch_factor': 2,
|
436 |
+
'persistent_workers': False,
|
437 |
+
'timeout': 0
|
438 |
+
})
|
439 |
+
|
440 |
+
tokenizer_name = 'EleutherAI/gpt-neox-20b'
|
441 |
+
tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len}
|
442 |
+
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
|
443 |
+
|
444 |
+
device_batch_size = 2
|
445 |
+
dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
|
446 |
+
|
447 |
+
packing = cfg.dataset.get('packing_ratio') is not None
|
448 |
+
|
449 |
+
for i, batch in enumerate(dataloader):
|
450 |
+
if i >= 5:
|
451 |
+
break
|
452 |
+
print(f'-----Batch {i}-----')
|
453 |
+
for k, v in batch.items():
|
454 |
+
if isinstance(v, torch.Tensor):
|
455 |
+
print(k, v.shape)
|
456 |
+
else:
|
457 |
+
print(k, v)
|
458 |
+
for j in range(device_batch_size):
|
459 |
+
print(f'--- Sample {j} ---')
|
460 |
+
if cfg.dataset.decoder_only_format:
|
461 |
+
if packing:
|
462 |
+
for subseq in range(int(batch['sequence_id'][j].max()) + 1):
|
463 |
+
is_subseq = batch['sequence_id'][j] == subseq
|
464 |
+
print(
|
465 |
+
'\033[93m{}\033[00m\n'.format('INPUT IDS:'),
|
466 |
+
tokenizer.decode(batch['input_ids'][
|
467 |
+
j,
|
468 |
+
torch.logical_and(
|
469 |
+
is_subseq, batch['attention_mask'][j] ==
|
470 |
+
1)],
|
471 |
+
skip_special_tokens=False))
|
472 |
+
print(
|
473 |
+
'\033[92m{}\033[00m\n'.format('CONTEXT: '),
|
474 |
+
tokenizer.decode(batch['input_ids'][
|
475 |
+
j,
|
476 |
+
torch.logical_and(
|
477 |
+
is_subseq, batch['bidirectional_mask'][j] ==
|
478 |
+
1)],
|
479 |
+
skip_special_tokens=False))
|
480 |
+
print(
|
481 |
+
'\033[91m{}\033[00m\n'.format('TARGET: '),
|
482 |
+
tokenizer.decode(batch['input_ids'][
|
483 |
+
j,
|
484 |
+
torch.logical_and(
|
485 |
+
is_subseq,
|
486 |
+
batch['labels'][j] != _HF_IGNORE_INDEX)],
|
487 |
+
skip_special_tokens=False))
|
488 |
+
else:
|
489 |
+
print(
|
490 |
+
'\033[93m{}\033[00m\n'.format('INPUT IDS:'),
|
491 |
+
tokenizer.decode(
|
492 |
+
batch['input_ids'][j,
|
493 |
+
batch['attention_mask'][j] == 1],
|
494 |
+
skip_special_tokens=False))
|
495 |
+
print(
|
496 |
+
'\033[92m{}\033[00m\n'.format('CONTEXT: '),
|
497 |
+
tokenizer.decode(batch['input_ids'][
|
498 |
+
j, batch['bidirectional_mask'][j] == 1],
|
499 |
+
skip_special_tokens=False))
|
500 |
+
print(
|
501 |
+
'\033[91m{}\033[00m\n'.format('TARGET: '),
|
502 |
+
tokenizer.decode(batch['input_ids'][
|
503 |
+
j, batch['labels'][j] != _HF_IGNORE_INDEX],
|
504 |
+
skip_special_tokens=False))
|
505 |
+
else:
|
506 |
+
print(
|
507 |
+
'\033[92m{}\033[00m\n'.format('CONTEXT: '),
|
508 |
+
tokenizer.decode(
|
509 |
+
batch['input_ids'][j, batch['attention_mask'][j] == 1],
|
510 |
+
skip_special_tokens=False))
|
511 |
+
print(
|
512 |
+
'\033[91m{}\033[00m\n'.format('TARGET: '),
|
513 |
+
tokenizer.decode(batch['labels'][
|
514 |
+
j, batch['decoder_attention_mask'][j] == 1],
|
515 |
+
skip_special_tokens=False))
|
516 |
+
print(' ')
|
Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/tasks.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Includes code for task-specific seq-to-seq data formatting.
|
5 |
+
|
6 |
+
This file provides some templates/examples of preprocessing functions
|
7 |
+
that format examples for use in seq-to-seq finetuning tasks.
|
8 |
+
These preprocessing functions take individual examples that contain raw
|
9 |
+
text and process them into formatted examples.
|
10 |
+
|
11 |
+
These functions have this basic structure:
|
12 |
+
|
13 |
+
def preprocessing_fn(example: Dict) -> Dict[str, str]:
|
14 |
+
# code to extract prompt/response from `example`
|
15 |
+
...
|
16 |
+
return {
|
17 |
+
'prompt': <prompt>,
|
18 |
+
'response': <response>,
|
19 |
+
}
|
20 |
+
|
21 |
+
where `<prompt>` is a placeholder for the prompt text string that you
|
22 |
+
extracted from the input example, and '<response>' is a placeholder for
|
23 |
+
the response text string.
|
24 |
+
|
25 |
+
Just to be clear, "prompt" represents the text you would give the model
|
26 |
+
at inference time, and "response" represents the text you are training
|
27 |
+
it to produce given the prompt.
|
28 |
+
|
29 |
+
The key requirement of these functions is that they return a dictionary
|
30 |
+
with "prompt" and "response" keys, and that the values associated with
|
31 |
+
those keys are strings (i.e. text).
|
32 |
+
"""
|
33 |
+
|
34 |
+
import importlib
|
35 |
+
import logging
|
36 |
+
import os
|
37 |
+
import warnings
|
38 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
39 |
+
|
40 |
+
import datasets as hf_datasets
|
41 |
+
from omegaconf import DictConfig
|
42 |
+
from streaming import StreamingDataset
|
43 |
+
from transformers import PreTrainedTokenizerBase
|
44 |
+
|
45 |
+
log = logging.getLogger(__name__)
|
46 |
+
|
47 |
+
__all__ = ['dataset_constructor']
|
48 |
+
|
49 |
+
|
50 |
+
def _tokenize_formatted_example(
|
51 |
+
example: Dict[str, Any],
|
52 |
+
tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
|
53 |
+
if ('prompt' not in example) or ('response' not in example):
|
54 |
+
raise KeyError(
|
55 |
+
'Unable to tokenize example because it has not been properly formatted. ' +\
|
56 |
+
'"prompt" and "response" are required keys but at least one was missing ' +\
|
57 |
+
f'from {example=}.'
|
58 |
+
)
|
59 |
+
if not isinstance(example['prompt'], str):
|
60 |
+
raise TypeError(
|
61 |
+
f'Unable to tokenize example because "prompt" was not a string. {example=}'
|
62 |
+
)
|
63 |
+
if not isinstance(example['response'], str):
|
64 |
+
raise TypeError(
|
65 |
+
f'Unable to tokenize example because "response" was not a string. {example=}'
|
66 |
+
)
|
67 |
+
return tokenizer(text=example['prompt'], text_target=example['response'])
|
68 |
+
|
69 |
+
|
70 |
+
class StreamingFinetuningDataset(StreamingDataset):
|
71 |
+
"""Finetuning dataset with flexible tokenization using StreamingDataset.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
|
75 |
+
tokenize samples.
|
76 |
+
local (str): Local dataset directory where shards are cached by split.
|
77 |
+
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
|
78 |
+
its data must exist locally. StreamingDataset uses either ``streams`` or
|
79 |
+
``remote``/``local``. Defaults to ``None``.
|
80 |
+
split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
|
81 |
+
the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
|
82 |
+
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
|
83 |
+
download_timeout (float): Number of seconds to wait for a shard to download before raising
|
84 |
+
an exception. Defaults to ``60``.
|
85 |
+
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
|
86 |
+
shards. Defaults to ``None``.
|
87 |
+
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
|
88 |
+
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
|
89 |
+
`False``.
|
90 |
+
epoch_size (int, optional): Number of samples to draw per epoch balanced across all
|
91 |
+
streams. If ``None``, takes its value from the total number of underlying samples.
|
92 |
+
Provide this field if you are weighting streams relatively to target a larger or
|
93 |
+
smaller epoch size. Defaults to ``None``.
|
94 |
+
predownload (int, optional): Target number of samples ahead to download the shards of while
|
95 |
+
iterating. Defaults to ``100_000``.
|
96 |
+
cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's
|
97 |
+
shard cache. Before downloading a shard, the least recently used resident shard(s) may
|
98 |
+
be evicted (deleted from the local cache) in order to stay under the limit. Set to None
|
99 |
+
to disable shard eviction. Supports integer bytes as well as string human-readable
|
100 |
+
bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
|
101 |
+
partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
|
102 |
+
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
|
103 |
+
resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
|
104 |
+
initial run.
|
105 |
+
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
|
106 |
+
partitioned over the workers. Defaults to ``None``.
|
107 |
+
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
|
108 |
+
``False``.
|
109 |
+
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
|
110 |
+
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
|
111 |
+
shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
|
112 |
+
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
|
113 |
+
Defaults to ``balanced``.
|
114 |
+
sampling_granularity (int): When picking samples for a stream's final partial repeat,
|
115 |
+
how many samples to pick from the same shard at a time (``1`` for evenly balanced
|
116 |
+
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
|
117 |
+
Defaults to ``1``.
|
118 |
+
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
|
119 |
+
``per_stream``. Defaults to ``random``.
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(self,
|
123 |
+
tokenizer: PreTrainedTokenizerBase,
|
124 |
+
local: str,
|
125 |
+
remote: Optional[str] = None,
|
126 |
+
split: Optional[str] = None,
|
127 |
+
download_retry: int = 2,
|
128 |
+
download_timeout: float = 60,
|
129 |
+
validate_hash: Optional[str] = None,
|
130 |
+
keep_zip: bool = False,
|
131 |
+
epoch_size: Optional[int] = None,
|
132 |
+
predownload: Optional[int] = None,
|
133 |
+
cache_limit: Optional[Union[int, str]] = None,
|
134 |
+
partition_algo: str = 'orig',
|
135 |
+
num_canonical_nodes: Optional[int] = None,
|
136 |
+
batch_size: Optional[int] = None,
|
137 |
+
shuffle: bool = False,
|
138 |
+
shuffle_algo: str = 'py1b',
|
139 |
+
shuffle_seed: int = 9176,
|
140 |
+
shuffle_block_size: int = 1 << 18,
|
141 |
+
sampling_method: str = 'balanced',
|
142 |
+
sampling_granularity: int = 1,
|
143 |
+
batching_method: str = 'random',
|
144 |
+
**kwargs: Any):
|
145 |
+
|
146 |
+
if len(kwargs) > 0:
|
147 |
+
raise ValueError(
|
148 |
+
f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}'
|
149 |
+
)
|
150 |
+
|
151 |
+
if remote is None or (local == remote):
|
152 |
+
if os.path.isdir(local):
|
153 |
+
contents = set(os.listdir(local))
|
154 |
+
if split not in contents:
|
155 |
+
raise ValueError(
|
156 |
+
f'local directory {local} does not contain split {split}'
|
157 |
+
)
|
158 |
+
|
159 |
+
# Build Dataset
|
160 |
+
super().__init__(
|
161 |
+
local=local,
|
162 |
+
remote=remote,
|
163 |
+
split=split,
|
164 |
+
download_retry=download_retry,
|
165 |
+
download_timeout=download_timeout,
|
166 |
+
validate_hash=validate_hash,
|
167 |
+
keep_zip=keep_zip,
|
168 |
+
epoch_size=epoch_size,
|
169 |
+
predownload=predownload,
|
170 |
+
cache_limit=cache_limit,
|
171 |
+
partition_algo=partition_algo,
|
172 |
+
num_canonical_nodes=num_canonical_nodes,
|
173 |
+
batch_size=batch_size,
|
174 |
+
shuffle=shuffle,
|
175 |
+
shuffle_algo=shuffle_algo,
|
176 |
+
shuffle_seed=shuffle_seed,
|
177 |
+
shuffle_block_size=shuffle_block_size,
|
178 |
+
sampling_method=sampling_method,
|
179 |
+
sampling_granularity=sampling_granularity,
|
180 |
+
batching_method=batching_method,
|
181 |
+
)
|
182 |
+
|
183 |
+
self.tokenizer = tokenizer
|
184 |
+
|
185 |
+
# How to process a sample
|
186 |
+
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
187 |
+
sample = super().__getitem__(idx)
|
188 |
+
return _tokenize_formatted_example(sample, tokenizer=self.tokenizer)
|
189 |
+
|
190 |
+
|
191 |
+
class DatasetConstructor:
|
192 |
+
|
193 |
+
def __init__(self):
|
194 |
+
self._task_preprocessing_registry: Dict[str, Callable] = {}
|
195 |
+
|
196 |
+
def register(self, *names: str) -> Callable[[Callable], Callable]:
|
197 |
+
"""Decorator for registering preprocessing functions."""
|
198 |
+
|
199 |
+
def _register_func(name: str, func: Callable) -> None:
|
200 |
+
if name in self._task_preprocessing_registry:
|
201 |
+
raise ValueError(
|
202 |
+
f'A tokenization function has already been registered with {name=}.'
|
203 |
+
)
|
204 |
+
self._task_preprocessing_registry[name] = func
|
205 |
+
return
|
206 |
+
|
207 |
+
def wrapper(func: Callable) -> Callable:
|
208 |
+
for name in names:
|
209 |
+
_register_func(name, func)
|
210 |
+
return func
|
211 |
+
|
212 |
+
return wrapper
|
213 |
+
|
214 |
+
def print_registered_tasks(self) -> None:
|
215 |
+
tasks = sorted(self._task_preprocessing_registry.keys())
|
216 |
+
print('\n'.join(tasks))
|
217 |
+
|
218 |
+
def get_preprocessing_fn_from_dict(
|
219 |
+
self, mapping: Union[Dict, DictConfig]
|
220 |
+
) -> Callable[[Dict[str, Any]], Dict[str, str]]:
|
221 |
+
"""Get a preprocessing function from a dictionary.
|
222 |
+
|
223 |
+
The dictionary maps column names in the dataset to "prompt" and "response".
|
224 |
+
For example,
|
225 |
+
```yaml
|
226 |
+
preprocessing_fn:
|
227 |
+
prompt: text
|
228 |
+
response: summary
|
229 |
+
```
|
230 |
+
would map the `text` column as to prompt and the `summary` column as the response.
|
231 |
+
|
232 |
+
Args:
|
233 |
+
mapping (dict): A dictionary mapping column names to "prompt" and "response".
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
Callable: The preprocessing function.
|
237 |
+
|
238 |
+
Raises:
|
239 |
+
ValueError: If the mapping does not have keys "prompt" and "response".
|
240 |
+
"""
|
241 |
+
|
242 |
+
def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]:
|
243 |
+
if list(mapping.keys()) != ['prompt', 'response']:
|
244 |
+
raise ValueError(
|
245 |
+
f'Expected {mapping=} to have keys "prompt" and "response".'
|
246 |
+
)
|
247 |
+
return {
|
248 |
+
'prompt': example[mapping['prompt']],
|
249 |
+
'response': example[mapping['response']]
|
250 |
+
}
|
251 |
+
|
252 |
+
return _preprocessor
|
253 |
+
|
254 |
+
def get_preprocessing_fn_from_str(
|
255 |
+
self,
|
256 |
+
preprocessor: Optional[str],
|
257 |
+
dataset_name: Optional[str] = None
|
258 |
+
) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]:
|
259 |
+
"""Get a preprocessing function from a string.
|
260 |
+
|
261 |
+
String can be either a registered function or an import path.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
preprocessor (Optional[str]): The name of the preprocessing function, or an import path.
|
265 |
+
dataset_name (Optional[str]): The dataset name to look up in the registry.
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
Callable: The preprocessing function or None if not found.
|
269 |
+
|
270 |
+
Raises:
|
271 |
+
ValueError: If the preprocessing function import from the provided string fails.
|
272 |
+
"""
|
273 |
+
if preprocessor is None:
|
274 |
+
if dataset_name is None:
|
275 |
+
return None
|
276 |
+
if dataset_name in self._task_preprocessing_registry:
|
277 |
+
log.info(
|
278 |
+
f'Re-formatting dataset with "{dataset_name}" preprocessing function.'
|
279 |
+
)
|
280 |
+
return self._task_preprocessing_registry[dataset_name]
|
281 |
+
else:
|
282 |
+
log.info('No preprocessor was supplied and no preprocessing function ' +\
|
283 |
+
f'is registered for dataset name "{dataset_name}". No additional ' +\
|
284 |
+
'preprocessing will be applied. If the dataset is already formatted ' +\
|
285 |
+
'correctly, you can ignore this message.')
|
286 |
+
return None
|
287 |
+
if preprocessor in self._task_preprocessing_registry:
|
288 |
+
log.info(
|
289 |
+
f'Re-formatting dataset with "{preprocessor}" preprocessing function.'
|
290 |
+
)
|
291 |
+
return self._task_preprocessing_registry[preprocessor]
|
292 |
+
|
293 |
+
try:
|
294 |
+
import_path, function_name = preprocessor.split(':', maxsplit=1)
|
295 |
+
module = importlib.import_module(import_path)
|
296 |
+
preprocessing_fn = getattr(module, function_name)
|
297 |
+
except Exception as e:
|
298 |
+
raise ValueError(
|
299 |
+
f'Failed to import preprocessing function from string = {preprocessor}.'
|
300 |
+
) from e
|
301 |
+
|
302 |
+
return preprocessing_fn
|
303 |
+
|
304 |
+
def build_from_hf(
|
305 |
+
self, cfg: DictConfig, max_seq_len: int,
|
306 |
+
tokenizer: PreTrainedTokenizerBase
|
307 |
+
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
|
308 |
+
hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
|
309 |
+
"""Load a HuggingFace Datasets, preprocess, and tokenize.
|
310 |
+
|
311 |
+
Note: This function will drop examples where the prompt is longer than the max_seq_len
|
312 |
+
|
313 |
+
Args:
|
314 |
+
cfg (DictConfig): The dataset configuration.
|
315 |
+
max_seq_len (int): The maximum sequence length. Examples with prompts longer than this will be dropped.
|
316 |
+
tokenizer (Tokenizer): The tokenizer to be used for tokenizing the dataset.
|
317 |
+
|
318 |
+
Returns:
|
319 |
+
Dataset: The tokenized dataset.
|
320 |
+
"""
|
321 |
+
dataset_name = cfg.hf_name
|
322 |
+
# HF datasets does not support a split with dashes,so we replace split
|
323 |
+
# dashes with underscore.
|
324 |
+
split = cfg.split.replace('-', '_')
|
325 |
+
kwargs = cfg.get('hf_kwargs', {})
|
326 |
+
proto_preprocessing_fn = cfg.get('preprocessing_fn')
|
327 |
+
if isinstance(proto_preprocessing_fn, dict) or isinstance(
|
328 |
+
proto_preprocessing_fn, DictConfig):
|
329 |
+
preprocessing_fn = self.get_preprocessing_fn_from_dict(
|
330 |
+
proto_preprocessing_fn)
|
331 |
+
else:
|
332 |
+
preprocessing_fn = self.get_preprocessing_fn_from_str(
|
333 |
+
proto_preprocessing_fn, dataset_name)
|
334 |
+
|
335 |
+
dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)
|
336 |
+
|
337 |
+
def dataset_mapper(example: Dict):
|
338 |
+
if preprocessing_fn is not None:
|
339 |
+
example = preprocessing_fn(example)
|
340 |
+
return _tokenize_formatted_example(example, tokenizer)
|
341 |
+
|
342 |
+
columns_to_remove = list(dataset[0].keys())
|
343 |
+
tokenized_dataset = dataset.map(
|
344 |
+
dataset_mapper,
|
345 |
+
batched=False,
|
346 |
+
remove_columns=columns_to_remove,
|
347 |
+
)
|
348 |
+
prompt_length_filtered_dataset = tokenized_dataset.filter(
|
349 |
+
lambda example: len(example['input_ids']) < max_seq_len)
|
350 |
+
|
351 |
+
examples_removed = len(tokenized_dataset) - len(
|
352 |
+
prompt_length_filtered_dataset)
|
353 |
+
if examples_removed > 0:
|
354 |
+
warnings.warn(
|
355 |
+
f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
|
356 |
+
)
|
357 |
+
|
358 |
+
empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
|
359 |
+
lambda example: len(example['input_ids']) > 0 and len(example[
|
360 |
+
'labels']) > 0 and any(token_id != tokenizer.pad_token_id
|
361 |
+
for token_id in example['labels']))
|
362 |
+
empty_examples_removed = len(prompt_length_filtered_dataset) - len(
|
363 |
+
empty_examples_dropped_dataset)
|
364 |
+
if empty_examples_removed > 0:
|
365 |
+
warnings.warn(
|
366 |
+
f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
|
367 |
+
+ 'or the response was only padding tokens.')
|
368 |
+
|
369 |
+
return empty_examples_dropped_dataset
|
370 |
+
|
371 |
+
def build_from_streaming(self, *args: Any,
|
372 |
+
**kwargs: Any) -> StreamingFinetuningDataset:
|
373 |
+
return StreamingFinetuningDataset(*args, **kwargs)
|
374 |
+
|
375 |
+
|
376 |
+
dataset_constructor = DatasetConstructor()
|
377 |
+
|
378 |
+
|
379 |
+
@dataset_constructor.register('tatsu-lab/alpaca')
|
380 |
+
def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
|
381 |
+
"""Split out prompt/response from text."""
|
382 |
+
try:
|
383 |
+
prompt, response = inp['text'].split('### Response:')
|
384 |
+
prompt += '### Response:'
|
385 |
+
except Exception as e:
|
386 |
+
raise ValueError(
|
387 |
+
f"Unable to extract prompt/response from 'text'={inp['text']}"
|
388 |
+
) from e
|
389 |
+
return {'prompt': prompt, 'response': response}
|
390 |
+
|
391 |
+
|
392 |
+
@dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k')
|
393 |
+
def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
|
394 |
+
"""Format the text string."""
|
395 |
+
PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
|
396 |
+
try:
|
397 |
+
if inp['input'] != '':
|
398 |
+
instruction = inp['instruction'] + '\n' + inp['input']
|
399 |
+
else:
|
400 |
+
instruction = inp['instruction']
|
401 |
+
prompt = PROMPT_FORMAT.format(instruction=instruction)
|
402 |
+
response = inp['output']
|
403 |
+
except Exception as e:
|
404 |
+
raise ValueError(
|
405 |
+
f'Unable to extract prompt/response from {inp=}') from e
|
406 |
+
return {'prompt': prompt, 'response': response}
|
407 |
+
|
408 |
+
|
409 |
+
@dataset_constructor.register('bigscience/P3')
|
410 |
+
def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
|
411 |
+
"""Format the already-split example."""
|
412 |
+
return {
|
413 |
+
'prompt': inp['inputs'] + ':',
|
414 |
+
'response': inp['targets'],
|
415 |
+
}
|
416 |
+
|
417 |
+
|
418 |
+
# Muennighoff's P3 and flan datasets share a similar convention
|
419 |
+
@dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan')
|
420 |
+
def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
|
421 |
+
"""Format the already-split example."""
|
422 |
+
try:
|
423 |
+
prompt: str = inp['inputs']
|
424 |
+
response: str = inp['targets']
|
425 |
+
# Put a space before the response if needed
|
426 |
+
transitions = (' ', '\n', '\t')
|
427 |
+
if not (prompt.endswith(transitions) or
|
428 |
+
response.startswith(transitions)):
|
429 |
+
response = ' ' + response
|
430 |
+
except Exception as e:
|
431 |
+
raise ValueError(
|
432 |
+
f'Unable to process prompt/response from {inp=}') from e
|
433 |
+
return {'prompt': prompt, 'response': response}
|
Perceptrix/finetune/build/lib/llmfoundry/data/packing.py
ADDED
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import os
|
5 |
+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from omegaconf import DictConfig
|
10 |
+
from transformers import PreTrainedTokenizerBase
|
11 |
+
|
12 |
+
|
13 |
+
class BinPackWrapper:
|
14 |
+
"""Utility collator for packing to reduce padding."""
|
15 |
+
|
16 |
+
def __init__(self,
|
17 |
+
collator: Callable,
|
18 |
+
target_batch_size: int,
|
19 |
+
max_seq_len: int,
|
20 |
+
pad_token_id: int,
|
21 |
+
padding_side: Literal['left', 'right'],
|
22 |
+
max_leftover_bins_to_keep: Optional[int] = None):
|
23 |
+
self.base_collator = collator
|
24 |
+
self.out_size = int(target_batch_size)
|
25 |
+
self.max_seq_len = int(max_seq_len)
|
26 |
+
self.pad_token_id = int(pad_token_id)
|
27 |
+
self.padding_side = padding_side
|
28 |
+
|
29 |
+
if self.out_size <= 0:
|
30 |
+
raise ValueError(f'{target_batch_size=} must be >0.')
|
31 |
+
if self.max_seq_len <= 0:
|
32 |
+
raise ValueError(f'{max_seq_len=} must be >0.')
|
33 |
+
if self.pad_token_id < 0:
|
34 |
+
raise ValueError(f'{pad_token_id=} must be >=0.')
|
35 |
+
|
36 |
+
if max_leftover_bins_to_keep is None:
|
37 |
+
self.max_leftover_bins_to_keep = int(10 * self.out_size)
|
38 |
+
elif max_leftover_bins_to_keep < 0:
|
39 |
+
raise ValueError(
|
40 |
+
f'{max_leftover_bins_to_keep=} must be >=0 or None.')
|
41 |
+
else:
|
42 |
+
self.max_leftover_bins_to_keep = int(max_leftover_bins_to_keep)
|
43 |
+
|
44 |
+
self.n_packed_tokens = 0
|
45 |
+
self.n_total_tokens = 0
|
46 |
+
self.n_packed_examples = 0
|
47 |
+
|
48 |
+
self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = []
|
49 |
+
|
50 |
+
@property
|
51 |
+
def waste(self) -> float:
|
52 |
+
return 1 - (self.n_packed_tokens / self.n_total_tokens)
|
53 |
+
|
54 |
+
@property
|
55 |
+
def efficiency(self) -> float:
|
56 |
+
return self.n_packed_tokens / (self.max_seq_len *
|
57 |
+
self.n_packed_examples)
|
58 |
+
|
59 |
+
def __call__(
|
60 |
+
self,
|
61 |
+
examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
62 |
+
batch = self.base_collator(examples)
|
63 |
+
|
64 |
+
assert 'attention_mask' in batch
|
65 |
+
assert 'input_ids' in batch
|
66 |
+
|
67 |
+
for key in batch.keys():
|
68 |
+
assert key in [
|
69 |
+
'input_ids',
|
70 |
+
'labels',
|
71 |
+
'attention_mask',
|
72 |
+
'bidirectional_mask',
|
73 |
+
]
|
74 |
+
|
75 |
+
# Cut everything down to size
|
76 |
+
sizes, trimmed_examples = [], []
|
77 |
+
for idx in range(batch['attention_mask'].shape[0]):
|
78 |
+
size, trimmed_example = extract_trim_batch_idx(batch, idx)
|
79 |
+
sizes.append(size)
|
80 |
+
trimmed_examples.append(trimmed_example)
|
81 |
+
|
82 |
+
# Apply our CS 101 bin packing algorithm.
|
83 |
+
packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = first_fit_bin_packing(
|
84 |
+
sizes=sizes,
|
85 |
+
examples=trimmed_examples,
|
86 |
+
num_bins=self.out_size,
|
87 |
+
max_bin_size=self.max_seq_len,
|
88 |
+
existing_bins=self._leftover_bins,
|
89 |
+
)
|
90 |
+
self.n_packed_tokens += n_packed_tokens
|
91 |
+
self.n_total_tokens += n_total_tokens
|
92 |
+
self.n_packed_examples += self.out_size
|
93 |
+
self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]
|
94 |
+
|
95 |
+
# Re-pad to max_seq_len and batch
|
96 |
+
batch = repad(packed_examples,
|
97 |
+
max_seq_len=self.max_seq_len,
|
98 |
+
pad_token_id=self.pad_token_id,
|
99 |
+
padding_side=self.padding_side)
|
100 |
+
return batch
|
101 |
+
|
102 |
+
|
103 |
+
def extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
|
104 |
+
idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
|
105 |
+
example = {k: v[idx] for k, v in batch.items()}
|
106 |
+
|
107 |
+
keep = example['attention_mask'] == 1
|
108 |
+
size = int(keep.sum())
|
109 |
+
trim_example = {k: v[keep] for k, v in example.items()}
|
110 |
+
trim_example['sequence_id'] = torch.zeros_like(trim_example['input_ids'])
|
111 |
+
|
112 |
+
return size, trim_example
|
113 |
+
|
114 |
+
|
115 |
+
def combine_in_place(
|
116 |
+
example: Dict[str, torch.Tensor],
|
117 |
+
add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
118 |
+
if 'labels' in add_on:
|
119 |
+
# Prevents the last token in example from being trained to
|
120 |
+
# predict the first token in add_on, which would make no sense.
|
121 |
+
add_on['labels'][0] = -100
|
122 |
+
|
123 |
+
for k in example.keys():
|
124 |
+
if k == 'sequence_id':
|
125 |
+
example[k] = torch.cat(
|
126 |
+
[example[k], add_on[k] + 1 + torch.max(example[k])])
|
127 |
+
else:
|
128 |
+
example[k] = torch.cat([example[k], add_on[k]])
|
129 |
+
return example
|
130 |
+
|
131 |
+
|
132 |
+
def first_fit_bin_packing(
|
133 |
+
sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int,
|
134 |
+
max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]]
|
135 |
+
) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[
|
136 |
+
str, torch.Tensor]]]]:
|
137 |
+
|
138 |
+
# Will contain tuples (bin_size_size, packed_example)
|
139 |
+
bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins
|
140 |
+
|
141 |
+
starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins])
|
142 |
+
|
143 |
+
sizes_and_examples = [
|
144 |
+
(size, example) for size, example in zip(sizes, examples)
|
145 |
+
]
|
146 |
+
sorted_sizes_and_examples = sorted(sizes_and_examples,
|
147 |
+
key=lambda x: x[0],
|
148 |
+
reverse=True)
|
149 |
+
|
150 |
+
required_num_examples = max(0, num_bins - len(bins))
|
151 |
+
num_examples = len(sizes)
|
152 |
+
if num_examples < required_num_examples:
|
153 |
+
for size, example in sorted_sizes_and_examples:
|
154 |
+
# Can't keep packing. All remaining items get their own bin.
|
155 |
+
bins.append((size, example))
|
156 |
+
|
157 |
+
total_bin_sizes = sum([bin_size for bin_size, _ in bins])
|
158 |
+
total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
|
159 |
+
total_example_sizes = sum(sizes)
|
160 |
+
if total_new_bin_sizes != total_example_sizes:
|
161 |
+
raise AssertionError(
|
162 |
+
f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.'
|
163 |
+
)
|
164 |
+
|
165 |
+
sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
|
166 |
+
bin_sizes, packed_examples = [], []
|
167 |
+
for bin_size, packed_example in sorted_bins:
|
168 |
+
bin_sizes.append(bin_size)
|
169 |
+
packed_examples.append(packed_example)
|
170 |
+
|
171 |
+
# Return:
|
172 |
+
# - the num_bins largest packed examples
|
173 |
+
# - the total tokens in those examples
|
174 |
+
# - the total size of all new examples
|
175 |
+
# - leftover bins
|
176 |
+
return packed_examples[:num_bins], sum(
|
177 |
+
bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]
|
178 |
+
|
179 |
+
# Go through each item from longest to shortest.
|
180 |
+
# Note: all items will either go into an existing or new bin.
|
181 |
+
for i, (size, example) in enumerate(sorted_sizes_and_examples):
|
182 |
+
# If we can't keep packing, all remaining items get their own bin.
|
183 |
+
required_num_examples = max(0, num_bins - len(bins))
|
184 |
+
n_remaining = num_examples - i
|
185 |
+
assert n_remaining >= required_num_examples
|
186 |
+
if n_remaining == required_num_examples:
|
187 |
+
# Can't keep packing. All remaining items get their own bin.
|
188 |
+
bins.append((size, example))
|
189 |
+
continue
|
190 |
+
|
191 |
+
# Add it to the first bin it fits in
|
192 |
+
added = False
|
193 |
+
for bidx in range(len(bins)):
|
194 |
+
if bins[bidx][0] + size <= max_bin_size:
|
195 |
+
bin_size, packed_example = bins.pop(bidx)
|
196 |
+
bin_size = bin_size + size
|
197 |
+
packed_example = combine_in_place(packed_example, example)
|
198 |
+
bins.append((bin_size, packed_example))
|
199 |
+
added = True
|
200 |
+
break
|
201 |
+
# If it didn't fit anywhere, open a new bin
|
202 |
+
if not added:
|
203 |
+
bins.append((size, example))
|
204 |
+
|
205 |
+
total_bin_sizes = sum([bin_size for bin_size, _ in bins])
|
206 |
+
total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
|
207 |
+
total_example_sizes = sum(sizes)
|
208 |
+
if total_new_bin_sizes != total_example_sizes:
|
209 |
+
raise AssertionError(
|
210 |
+
f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.'
|
211 |
+
)
|
212 |
+
|
213 |
+
sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
|
214 |
+
bin_sizes, packed_examples = [], []
|
215 |
+
for bin_size, packed_example in sorted_bins:
|
216 |
+
bin_sizes.append(bin_size)
|
217 |
+
packed_examples.append(packed_example)
|
218 |
+
|
219 |
+
# Return:
|
220 |
+
# - the num_bins largest packed examples
|
221 |
+
# - the total tokens in those examples
|
222 |
+
# - the total size of all new examples
|
223 |
+
# - leftover bins
|
224 |
+
return packed_examples[:num_bins], sum(
|
225 |
+
bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]
|
226 |
+
|
227 |
+
|
228 |
+
def repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int,
|
229 |
+
pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
|
230 |
+
|
231 |
+
def pad_tensor(tensor: torch.Tensor, pad_value: int):
|
232 |
+
if len(tensor) == max_seq_len:
|
233 |
+
return tensor
|
234 |
+
t = torch.full((max_seq_len,),
|
235 |
+
pad_value,
|
236 |
+
dtype=tensor.dtype,
|
237 |
+
device=tensor.device)
|
238 |
+
if padding_side == 'left':
|
239 |
+
t[-len(tensor):] = tensor
|
240 |
+
elif padding_side == 'right':
|
241 |
+
t[:len(tensor)] = tensor
|
242 |
+
else:
|
243 |
+
raise ValueError(f'Unknown {padding_side=}')
|
244 |
+
return t
|
245 |
+
|
246 |
+
pad_vals = {
|
247 |
+
'input_ids': pad_token_id,
|
248 |
+
'labels': -100,
|
249 |
+
'attention_mask': 0,
|
250 |
+
'bidirectional_mask': 0,
|
251 |
+
'sequence_id': -1,
|
252 |
+
}
|
253 |
+
keys = packed_examples[0].keys()
|
254 |
+
batch = {}
|
255 |
+
for key in keys:
|
256 |
+
batch[key] = torch.stack([
|
257 |
+
pad_tensor(example[key], pad_vals[key])
|
258 |
+
for example in packed_examples
|
259 |
+
])
|
260 |
+
return batch
|
261 |
+
|
262 |
+
|
263 |
+
if __name__ == '__main__':
|
264 |
+
from argparse import ArgumentParser, Namespace
|
265 |
+
|
266 |
+
from omegaconf import OmegaConf as om
|
267 |
+
|
268 |
+
from llmfoundry import (build_finetuning_dataloader,
|
269 |
+
build_text_denoising_dataloader)
|
270 |
+
from llmfoundry.data import build_text_dataloader
|
271 |
+
from llmfoundry.utils import build_tokenizer
|
272 |
+
|
273 |
+
def parse_args() -> Namespace:
|
274 |
+
"""Parse commandline arguments."""
|
275 |
+
parser = ArgumentParser(
|
276 |
+
description=
|
277 |
+
'Profile packing_ratio choices for a particular workload.')
|
278 |
+
parser.add_argument(
|
279 |
+
'--yaml-path',
|
280 |
+
type=str,
|
281 |
+
required=True,
|
282 |
+
help='Path to the YAML that defines the workload to profile.')
|
283 |
+
parser.add_argument('--num-devices',
|
284 |
+
type=int,
|
285 |
+
default=None,
|
286 |
+
help='How many devices your run will use.')
|
287 |
+
parser.add_argument('--min',
|
288 |
+
type=float,
|
289 |
+
required=True,
|
290 |
+
help='Smallest packing_ratio to test. Must be >=1.')
|
291 |
+
parser.add_argument(
|
292 |
+
'--max',
|
293 |
+
type=float,
|
294 |
+
required=True,
|
295 |
+
help='Largest packing_ratio to test. Must be larger than `min`.')
|
296 |
+
parser.add_argument(
|
297 |
+
'--num-packing-ratios',
|
298 |
+
type=int,
|
299 |
+
default=10,
|
300 |
+
help=
|
301 |
+
'Number of packing_ratio values (spaced between `min` and `max) to try.'
|
302 |
+
)
|
303 |
+
|
304 |
+
args = parser.parse_args()
|
305 |
+
|
306 |
+
if not os.path.isfile(args.yaml_path):
|
307 |
+
raise FileNotFoundError(
|
308 |
+
'`yaml_path` does not correspond to any existing file.')
|
309 |
+
if args.num_devices < 1:
|
310 |
+
raise ValueError('`num_devices` must be a positive integer.')
|
311 |
+
if args.min < 1.0:
|
312 |
+
raise ValueError('`min` must be >=1.0.')
|
313 |
+
if args.max < args.min:
|
314 |
+
raise ValueError('`max` cannot be less than `min`.')
|
315 |
+
if args.num_packing_ratios < 1:
|
316 |
+
raise ValueError('`num_packing_ratios` must be a positive integer.')
|
317 |
+
return args
|
318 |
+
|
319 |
+
def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
|
320 |
+
device_batch_size: int):
|
321 |
+
if cfg.name == 'text':
|
322 |
+
return build_text_dataloader(cfg, tokenizer, device_batch_size)
|
323 |
+
elif cfg.name == 'text_denoising':
|
324 |
+
return build_text_denoising_dataloader(cfg, tokenizer,
|
325 |
+
device_batch_size)
|
326 |
+
elif cfg.name == 'finetuning':
|
327 |
+
return build_finetuning_dataloader(cfg, tokenizer,
|
328 |
+
device_batch_size)
|
329 |
+
else:
|
330 |
+
raise ValueError(
|
331 |
+
f'Not sure how to build dataloader with config: {cfg}')
|
332 |
+
|
333 |
+
args = parse_args()
|
334 |
+
|
335 |
+
with open(args.yaml_path) as f:
|
336 |
+
cfg = om.load(f)
|
337 |
+
if 'parameters' in cfg:
|
338 |
+
cfg = om.to_container(cfg.parameters)
|
339 |
+
cfg = om.create(cfg)
|
340 |
+
device_batch_size = cfg.global_train_batch_size // args.num_devices
|
341 |
+
|
342 |
+
# Determine the packing_ratio values we'll try
|
343 |
+
packing_ratios, raw_batch_sizes = [], []
|
344 |
+
for packing_ratio in np.linspace(args.min,
|
345 |
+
args.max,
|
346 |
+
args.num_packing_ratios,
|
347 |
+
endpoint=True):
|
348 |
+
packing_ratio = np.round(10 * packing_ratio) / 10
|
349 |
+
raw_batch_size = int(packing_ratio * device_batch_size)
|
350 |
+
if raw_batch_size not in raw_batch_sizes:
|
351 |
+
packing_ratios.append(packing_ratio)
|
352 |
+
raw_batch_sizes.append(raw_batch_size)
|
353 |
+
|
354 |
+
# Fetch a bunch of raw examples once, which we'll re-use
|
355 |
+
if 'train_loader' not in cfg:
|
356 |
+
raise ValueError('config must define train_loader')
|
357 |
+
dataloader_cfg = cfg.train_loader
|
358 |
+
|
359 |
+
max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep',
|
360 |
+
None)
|
361 |
+
|
362 |
+
# build tokenizer
|
363 |
+
if 'tokenizer' not in cfg:
|
364 |
+
raise ValueError('config must define tokenizer')
|
365 |
+
|
366 |
+
resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True)
|
367 |
+
if not isinstance(resolved_tokenizer_cfg, Dict):
|
368 |
+
raise ValueError(
|
369 |
+
'tokenizer config needs to be resolved by omegaconf into a Dict.')
|
370 |
+
tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg
|
371 |
+
|
372 |
+
tokenizer_name = tokenizer_cfg['name']
|
373 |
+
tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
|
374 |
+
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
|
375 |
+
|
376 |
+
# Turn off packing for the dataloader (we want raw, pre-packed examples)
|
377 |
+
dataloader_cfg.dataset.packing_ratio = None
|
378 |
+
dataloader_cfg.dataset.max_leftovers_to_keep = None
|
379 |
+
train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
|
380 |
+
max(raw_batch_sizes) * 100)
|
381 |
+
|
382 |
+
# Get a bunch of raw examples
|
383 |
+
big_batch = next(iter(train_dataloader))
|
384 |
+
|
385 |
+
def split_big_batch(raw_batch_size: int) -> List:
|
386 |
+
input_ids = big_batch['input_ids'].split(raw_batch_size)
|
387 |
+
batches = [{'input_ids': x} for x in input_ids]
|
388 |
+
|
389 |
+
for key in big_batch.keys():
|
390 |
+
if key == 'input_ids':
|
391 |
+
continue
|
392 |
+
for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
|
393 |
+
batches[idx].update({key: split})
|
394 |
+
return batches
|
395 |
+
|
396 |
+
def profile_packing(raw_batch_size: int) -> Tuple[float, float]:
|
397 |
+
packer = BinPackWrapper(
|
398 |
+
collator=lambda x: x,
|
399 |
+
target_batch_size=device_batch_size,
|
400 |
+
max_seq_len=dataloader_cfg.dataset.max_seq_len,
|
401 |
+
pad_token_id=0, # <-- Doesn't need to be correct for profiling
|
402 |
+
padding_side='left', # <-- Doesn't need to be correct for profiling
|
403 |
+
max_leftover_bins_to_keep=max_leftovers_to_keep)
|
404 |
+
|
405 |
+
# Simulate feeding the packing collator a bunch of data
|
406 |
+
for batch in split_big_batch(raw_batch_size):
|
407 |
+
if batch['input_ids'].shape[0] < device_batch_size:
|
408 |
+
continue
|
409 |
+
_ = packer(batch)
|
410 |
+
|
411 |
+
# Return the padding / waste stats over that bunch of data
|
412 |
+
padding_percent = 100 * (1 - packer.efficiency)
|
413 |
+
waste_percent = 100 * packer.waste
|
414 |
+
return padding_percent, waste_percent
|
415 |
+
|
416 |
+
header = '\n\n\n packing_ratio | % PADDING | % WASTE'
|
417 |
+
fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%'
|
418 |
+
|
419 |
+
print(header)
|
420 |
+
print('-' * len(header))
|
421 |
+
for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
|
422 |
+
padding, waste = profile_packing(raw_batch_size)
|
423 |
+
print(fstr.format(packing_ratio, padding, waste))
|
Perceptrix/finetune/build/lib/llmfoundry/data/text_data.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Build a StreamingTextDataset dataset and dataloader for training."""
|
5 |
+
|
6 |
+
import os
|
7 |
+
from itertools import islice
|
8 |
+
from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence,
|
9 |
+
Union, cast)
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import transformers
|
14 |
+
from omegaconf import DictConfig
|
15 |
+
from omegaconf import OmegaConf as om
|
16 |
+
from streaming import Stream, StreamingDataset
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from transformers import PreTrainedTokenizerBase
|
19 |
+
|
20 |
+
|
21 |
+
class StreamingTextDataset(StreamingDataset):
|
22 |
+
"""Generic text dataset using MosaicML's StreamingDataset.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
tokenizer (Tokenizer): HuggingFace tokenizer to
|
26 |
+
tokenize samples.
|
27 |
+
max_seq_len (int): The max sequence length of each sample.
|
28 |
+
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
|
29 |
+
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
|
30 |
+
``remote``/``local``. Defaults to ``None``.
|
31 |
+
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
|
32 |
+
its data must exist locally. StreamingDataset uses either ``streams`` or
|
33 |
+
``remote``/``local``. Defaults to ``None``.
|
34 |
+
local (str, optional): Local working directory to download shards to. This is where shards
|
35 |
+
are cached while they are being used. Uses a temp directory if not set.
|
36 |
+
StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``.
|
37 |
+
split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
|
38 |
+
the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
|
39 |
+
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
|
40 |
+
download_timeout (float): Number of seconds to wait for a shard to download before raising
|
41 |
+
an exception. Defaults to ``60``.
|
42 |
+
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
|
43 |
+
shards. Defaults to ``None``.
|
44 |
+
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
|
45 |
+
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
|
46 |
+
`False``.
|
47 |
+
epoch_size (int, optional): Number of samples to draw per epoch balanced across all
|
48 |
+
streams. If ``None``, takes its value from the total number of underlying samples.
|
49 |
+
Provide this field if you are weighting streams relatively to target a larger or
|
50 |
+
smaller epoch size. Defaults to ``None``.
|
51 |
+
predownload (int, optional): Target number of samples ahead to download the shards of while
|
52 |
+
iterating. Defaults to ``100_000``.
|
53 |
+
cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's
|
54 |
+
shard cache. Before downloading a shard, the least recently used resident shard(s) may
|
55 |
+
be evicted (deleted from the local cache) in order to stay under the limit. Set to None
|
56 |
+
to disable shard eviction. Supports integer bytes as well as string human-readable
|
57 |
+
bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
|
58 |
+
partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
|
59 |
+
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
|
60 |
+
resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
|
61 |
+
initial run.
|
62 |
+
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
|
63 |
+
partitioned over the workers. Defaults to ``None``.
|
64 |
+
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
|
65 |
+
``False``.
|
66 |
+
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
|
67 |
+
shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
|
68 |
+
shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
|
69 |
+
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
|
70 |
+
Defaults to ``balanced``.
|
71 |
+
sampling_granularity (int): When picking samples for a stream's final partial repeat,
|
72 |
+
how many samples to pick from the same shard at a time (``1`` for evenly balanced
|
73 |
+
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
|
74 |
+
Defaults to ``1``.
|
75 |
+
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
|
76 |
+
``per_stream``. Defaults to ``random``.
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self,
|
80 |
+
tokenizer: PreTrainedTokenizerBase,
|
81 |
+
max_seq_len: int,
|
82 |
+
streams: Optional[Sequence[Stream]] = None,
|
83 |
+
remote: Optional[str] = None,
|
84 |
+
local: Optional[str] = None,
|
85 |
+
split: Optional[str] = None,
|
86 |
+
download_retry: int = 2,
|
87 |
+
download_timeout: float = 60,
|
88 |
+
validate_hash: Optional[str] = None,
|
89 |
+
keep_zip: bool = False,
|
90 |
+
epoch_size: Optional[int] = None,
|
91 |
+
predownload: int = 100_000,
|
92 |
+
cache_limit: Optional[Union[int, str]] = None,
|
93 |
+
partition_algo: str = 'orig',
|
94 |
+
num_canonical_nodes: Optional[int] = None,
|
95 |
+
batch_size: Optional[int] = None,
|
96 |
+
shuffle: bool = False,
|
97 |
+
shuffle_algo: str = 'py1b',
|
98 |
+
shuffle_seed: int = 9176,
|
99 |
+
shuffle_block_size: int = 1 << 18,
|
100 |
+
sampling_method: str = 'balanced',
|
101 |
+
sampling_granularity: int = 1,
|
102 |
+
batching_method: str = 'random',
|
103 |
+
**kwargs: Any):
|
104 |
+
|
105 |
+
group_method = kwargs.pop('group_method', None)
|
106 |
+
if group_method is not None:
|
107 |
+
raise NotImplementedError(
|
108 |
+
'group_method is deprecated and has been removed.\nTo ' +
|
109 |
+
'concatenate, use the --concat_tokens ' +
|
110 |
+
'argument when creating your MDS dataset with concat_c4.py')
|
111 |
+
|
112 |
+
if len(kwargs) > 0:
|
113 |
+
raise ValueError(
|
114 |
+
f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}'
|
115 |
+
)
|
116 |
+
|
117 |
+
if local is not None and (remote is None or (local == remote)):
|
118 |
+
if os.path.isdir(local):
|
119 |
+
contents = set(os.listdir(local))
|
120 |
+
if split not in contents:
|
121 |
+
raise ValueError(
|
122 |
+
f'local directory {local} does not contain split {split}'
|
123 |
+
)
|
124 |
+
|
125 |
+
# TODO: discover where yamls are being converted incorrect, but temporary workaround
|
126 |
+
if isinstance(shuffle_block_size, float):
|
127 |
+
shuffle_block_size = int(shuffle_block_size)
|
128 |
+
|
129 |
+
# Build Dataset
|
130 |
+
super().__init__(
|
131 |
+
streams=streams,
|
132 |
+
remote=remote,
|
133 |
+
local=local,
|
134 |
+
split=split,
|
135 |
+
download_retry=download_retry,
|
136 |
+
download_timeout=download_timeout,
|
137 |
+
validate_hash=validate_hash,
|
138 |
+
keep_zip=keep_zip,
|
139 |
+
epoch_size=epoch_size,
|
140 |
+
predownload=predownload,
|
141 |
+
cache_limit=cache_limit,
|
142 |
+
partition_algo=partition_algo,
|
143 |
+
num_canonical_nodes=num_canonical_nodes,
|
144 |
+
batch_size=batch_size,
|
145 |
+
shuffle=shuffle,
|
146 |
+
shuffle_algo=shuffle_algo,
|
147 |
+
shuffle_seed=shuffle_seed,
|
148 |
+
shuffle_block_size=shuffle_block_size,
|
149 |
+
sampling_method=sampling_method,
|
150 |
+
sampling_granularity=sampling_granularity,
|
151 |
+
batching_method=batching_method,
|
152 |
+
)
|
153 |
+
self.tokenizer = tokenizer
|
154 |
+
self.max_seq_len = max_seq_len
|
155 |
+
|
156 |
+
# How to tokenize a text sample to a token sample
|
157 |
+
def _tokenize(self, text_sample: Mapping) -> Dict[str, List[int]]:
|
158 |
+
if self.tokenizer._pad_token is None:
|
159 |
+
# Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs
|
160 |
+
raise RuntimeError(
|
161 |
+
'If tokenizing on-the-fly, tokenizer must have a pad_token_id')
|
162 |
+
|
163 |
+
return self.tokenizer(text_sample['text'],
|
164 |
+
truncation=True,
|
165 |
+
padding='max_length',
|
166 |
+
max_length=self.max_seq_len)
|
167 |
+
|
168 |
+
def _read_binary_tokenized_sample(self, sample: Dict[str,
|
169 |
+
Any]) -> torch.Tensor:
|
170 |
+
return torch.from_numpy(
|
171 |
+
np.frombuffer(sample['tokens'],
|
172 |
+
dtype=np.int64)[:self.max_seq_len].copy())
|
173 |
+
|
174 |
+
# How to process a sample
|
175 |
+
def __getitem__(self,
|
176 |
+
idx: int) -> Union[Dict[str, List[int]], torch.Tensor]:
|
177 |
+
sample = super().__getitem__(idx)
|
178 |
+
if 'text' in sample:
|
179 |
+
token_sample = self._tokenize(sample)
|
180 |
+
elif 'tokens' in sample:
|
181 |
+
token_sample = self._read_binary_tokenized_sample(sample)
|
182 |
+
else:
|
183 |
+
raise RuntimeError(
|
184 |
+
'StreamingTextDataset needs samples to have a `text` or `tokens` column'
|
185 |
+
)
|
186 |
+
return token_sample
|
187 |
+
|
188 |
+
|
189 |
+
class ConcatenatedSequenceCollatorWrapper:
|
190 |
+
"""Collator wrapper to add sequence_id to batch."""
|
191 |
+
|
192 |
+
def __init__(
|
193 |
+
self,
|
194 |
+
base_collator: Callable,
|
195 |
+
eos_token_id: Optional[int] = None,
|
196 |
+
bos_token_id: Optional[int] = None,
|
197 |
+
):
|
198 |
+
self.base_collator = base_collator
|
199 |
+
if (eos_token_id is None) and (bos_token_id is None):
|
200 |
+
raise ValueError(
|
201 |
+
'Must supply a value for either eos_token_id or bos_token_id, but got None for both.'
|
202 |
+
)
|
203 |
+
if (eos_token_id is not None) and (bos_token_id is not None):
|
204 |
+
raise ValueError(
|
205 |
+
'Cannot use *both* EOS and BOS tokens for detecting sequence boundaries. ' +\
|
206 |
+
'Please supply `eos_token_id` if sequences end with an EOS token, or use ' +\
|
207 |
+
'`bos_token_id` if sequences start with a BOS token.'
|
208 |
+
)
|
209 |
+
|
210 |
+
if eos_token_id is None:
|
211 |
+
self.split_token_id = cast(int, bos_token_id)
|
212 |
+
self.bos_mode = True
|
213 |
+
else:
|
214 |
+
self.split_token_id = eos_token_id
|
215 |
+
self.bos_mode = False
|
216 |
+
|
217 |
+
def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
|
218 |
+
batch = self.base_collator(examples)
|
219 |
+
batch['sequence_id'] = self.get_sequence_id_from_batch(batch)
|
220 |
+
return batch
|
221 |
+
|
222 |
+
def get_sequence_id_from_batch(
|
223 |
+
self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
224 |
+
is_separator = torch.eq(batch['input_ids'], self.split_token_id)
|
225 |
+
cumulative_sep = torch.cumsum(is_separator,
|
226 |
+
dim=1).to(batch['input_ids'].dtype)
|
227 |
+
# If separator token is bos, we're already done
|
228 |
+
if self.bos_mode:
|
229 |
+
return cumulative_sep
|
230 |
+
|
231 |
+
# If separator token is eos, right shift 1 space
|
232 |
+
left_zeros = cumulative_sep.new_zeros((cumulative_sep.shape[0], 1))
|
233 |
+
return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)
|
234 |
+
|
235 |
+
|
236 |
+
def build_text_dataloader(
|
237 |
+
cfg: DictConfig,
|
238 |
+
tokenizer: PreTrainedTokenizerBase,
|
239 |
+
device_batch_size: int,
|
240 |
+
) -> DataLoader:
|
241 |
+
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
|
242 |
+
if cfg.dataset.get('group_method', None) is not None:
|
243 |
+
raise NotImplementedError(
|
244 |
+
'group_method is deprecated and has been removed.\nTo ' +
|
245 |
+
'concatenate, use the --concat_tokens ' +
|
246 |
+
'argument when creating your MDS dataset with convert_dataset_hf.py'
|
247 |
+
)
|
248 |
+
|
249 |
+
# get kwargs
|
250 |
+
streams_dict = cfg.dataset.pop('streams', None)
|
251 |
+
mlm_probability = cfg.dataset.pop('mlm_probability', None)
|
252 |
+
eos_token_id = cfg.dataset.pop('eos_token_id', None)
|
253 |
+
bos_token_id = cfg.dataset.pop('bos_token_id', None)
|
254 |
+
|
255 |
+
# build streams
|
256 |
+
streams = None
|
257 |
+
if streams_dict is not None:
|
258 |
+
streams = []
|
259 |
+
for _, stream in streams_dict.items():
|
260 |
+
# stream is the streams kwargs
|
261 |
+
# fwd all kwargs with **stream allows streaming to check args
|
262 |
+
streams.append(Stream(**stream))
|
263 |
+
|
264 |
+
# build dataset potentially with streams
|
265 |
+
dataset = StreamingTextDataset(
|
266 |
+
tokenizer=tokenizer,
|
267 |
+
streams=streams,
|
268 |
+
batch_size=device_batch_size,
|
269 |
+
**cfg.dataset,
|
270 |
+
)
|
271 |
+
|
272 |
+
collate_fn = transformers.DataCollatorForLanguageModeling(
|
273 |
+
tokenizer=dataset.tokenizer,
|
274 |
+
mlm=mlm_probability is not None,
|
275 |
+
mlm_probability=mlm_probability)
|
276 |
+
|
277 |
+
if (eos_token_id is not None) or (bos_token_id is not None):
|
278 |
+
# Note: Will raise an error if both are non-None
|
279 |
+
collate_fn = ConcatenatedSequenceCollatorWrapper(
|
280 |
+
base_collator=collate_fn,
|
281 |
+
eos_token_id=eos_token_id,
|
282 |
+
bos_token_id=bos_token_id)
|
283 |
+
|
284 |
+
return DataLoader(
|
285 |
+
dataset,
|
286 |
+
collate_fn=collate_fn,
|
287 |
+
batch_size=device_batch_size,
|
288 |
+
drop_last=cfg.drop_last,
|
289 |
+
num_workers=cfg.num_workers,
|
290 |
+
pin_memory=cfg.get('pin_memory', True),
|
291 |
+
prefetch_factor=cfg.get('prefetch_factor', 2),
|
292 |
+
persistent_workers=cfg.get('persistent_workers', True),
|
293 |
+
timeout=cfg.get('timeout', 0),
|
294 |
+
)
|
295 |
+
|
296 |
+
|
297 |
+
# Helpful to test if your dataloader is working locally
|
298 |
+
# Run `python data.py --local_path [local] [--remote_path remote, optional]` and verify that batches are printed out
|
299 |
+
if __name__ == '__main__':
|
300 |
+
import argparse
|
301 |
+
|
302 |
+
from llmfoundry.utils.builders import build_tokenizer
|
303 |
+
|
304 |
+
parser = argparse.ArgumentParser()
|
305 |
+
parser.add_argument('--tokenizer',
|
306 |
+
type=str,
|
307 |
+
default='EleutherAI/gpt-neox-20b',
|
308 |
+
help='the name of the tokenizer to use')
|
309 |
+
parser.add_argument('--local_path',
|
310 |
+
type=str,
|
311 |
+
required=True,
|
312 |
+
help='the path to the local copy of the dataset')
|
313 |
+
parser.add_argument(
|
314 |
+
'--remote_path',
|
315 |
+
type=str,
|
316 |
+
default=None,
|
317 |
+
help='the path to the remote copy to stream from (optional)')
|
318 |
+
parser.add_argument('--split',
|
319 |
+
type=str,
|
320 |
+
default='val',
|
321 |
+
help='which split of the dataset to use')
|
322 |
+
parser.add_argument('--max_seq_len',
|
323 |
+
type=int,
|
324 |
+
default=32,
|
325 |
+
help='max sequence length to test')
|
326 |
+
|
327 |
+
args = parser.parse_args()
|
328 |
+
|
329 |
+
if args.remote_path is not None:
|
330 |
+
print(
|
331 |
+
f'Reading {args.split} split from {args.local_path} <- streamed from <- {args.remote_path}'
|
332 |
+
)
|
333 |
+
else:
|
334 |
+
print(f'Reading {args.split} split from {args.local_path}')
|
335 |
+
|
336 |
+
cfg = {
|
337 |
+
'name': 'text',
|
338 |
+
'dataset': {
|
339 |
+
'local': args.local_path,
|
340 |
+
'remote': args.remote_path,
|
341 |
+
'split': args.split,
|
342 |
+
'shuffle': False,
|
343 |
+
'max_seq_len': args.max_seq_len,
|
344 |
+
'keep_zip': True, # in case we need compressed files after testing
|
345 |
+
},
|
346 |
+
'drop_last': False,
|
347 |
+
'num_workers': 4,
|
348 |
+
}
|
349 |
+
cfg = om.create(cfg)
|
350 |
+
device_batch_size = 2
|
351 |
+
|
352 |
+
tokenizer_name = args.tokenizer
|
353 |
+
tokenizer_kwargs = {'model_max_length': args.max_seq_len}
|
354 |
+
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
|
355 |
+
|
356 |
+
loader = build_text_dataloader(cfg, tokenizer, device_batch_size)
|
357 |
+
assert isinstance(loader.dataset, StreamingTextDataset)
|
358 |
+
tokenizer = loader.dataset.tokenizer
|
359 |
+
|
360 |
+
for batch_ix, batch in enumerate(islice(loader, 5)):
|
361 |
+
print('\n')
|
362 |
+
print('#' * 20, f'Batch {batch_ix}', '#' * 20)
|
363 |
+
for k, v in batch.items():
|
364 |
+
print(k, v.shape, v.dtype)
|
365 |
+
for sample_ix, token_sample in enumerate(batch['input_ids']):
|
366 |
+
print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
|
367 |
+
print(tokenizer.decode(token_sample))
|
Perceptrix/finetune/build/lib/llmfoundry/models/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
|
5 |
+
ComposerHFT5)
|
6 |
+
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
|
7 |
+
MPTForCausalLM, MPTModel, MPTPreTrainedModel)
|
8 |
+
|
9 |
+
__all__ = [
|
10 |
+
'ComposerHFCausalLM',
|
11 |
+
'ComposerHFPrefixLM',
|
12 |
+
'ComposerHFT5',
|
13 |
+
'MPTConfig',
|
14 |
+
'MPTPreTrainedModel',
|
15 |
+
'MPTModel',
|
16 |
+
'MPTForCausalLM',
|
17 |
+
'ComposerMPTCausalLM',
|
18 |
+
]
|
Perceptrix/finetune/build/lib/llmfoundry/models/hf/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
|
5 |
+
from llmfoundry.models.hf.hf_fsdp import (prepare_hf_causal_lm_model_for_fsdp,
|
6 |
+
prepare_hf_enc_dec_model_for_fsdp,
|
7 |
+
prepare_hf_model_for_fsdp)
|
8 |
+
from llmfoundry.models.hf.hf_prefix_lm import ComposerHFPrefixLM
|
9 |
+
from llmfoundry.models.hf.hf_t5 import ComposerHFT5
|
10 |
+
|
11 |
+
__all__ = [
|
12 |
+
'ComposerHFCausalLM',
|
13 |
+
'ComposerHFPrefixLM',
|
14 |
+
'ComposerHFT5',
|
15 |
+
'prepare_hf_causal_lm_model_for_fsdp',
|
16 |
+
'prepare_hf_enc_dec_model_for_fsdp',
|
17 |
+
'prepare_hf_model_for_fsdp',
|
18 |
+
]
|
Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_causal_lm.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
from typing import Mapping, Union
|
9 |
+
|
10 |
+
# required for loading a python model into composer
|
11 |
+
import transformers
|
12 |
+
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
|
13 |
+
InContextLearningLMAccuracy,
|
14 |
+
InContextLearningLMExpectedCalibrationError,
|
15 |
+
InContextLearningMCExpectedCalibrationError,
|
16 |
+
InContextLearningMultipleChoiceAccuracy,
|
17 |
+
InContextLearningQAAccuracy,
|
18 |
+
LanguageCrossEntropy, LanguagePerplexity)
|
19 |
+
from composer.utils import dist
|
20 |
+
from omegaconf import DictConfig
|
21 |
+
from torch import nn
|
22 |
+
from transformers import (AutoConfig, AutoModelForCausalLM,
|
23 |
+
PreTrainedTokenizerBase)
|
24 |
+
|
25 |
+
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
|
26 |
+
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
|
27 |
+
from llmfoundry.models.layers.llama_attention_monkeypatch import \
|
28 |
+
get_llama_attention_patch_fn
|
29 |
+
from llmfoundry.models.utils import init_empty_weights
|
30 |
+
|
31 |
+
try:
|
32 |
+
from peft.peft_model import PeftModel
|
33 |
+
model_types = PeftModel, transformers.PreTrainedModel
|
34 |
+
|
35 |
+
except ImportError:
|
36 |
+
model_types = transformers.PreTrainedModel
|
37 |
+
|
38 |
+
__all__ = ['ComposerHFCausalLM']
|
39 |
+
|
40 |
+
log = logging.getLogger(__name__)
|
41 |
+
|
42 |
+
|
43 |
+
class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
|
44 |
+
"""Configures a :class:`.HuggingFaceModel` around a Causal LM.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
om_model_config (DictConfig | PeftModel | transformers.PreTrainedModel): either an omegaconf dictionary used to configure the model, or an instantiated model object from the peft or transformers library.
|
48 |
+
if DictConfig, the following keys are required:
|
49 |
+
cfg.pretrained_model_name_or_path (str): The name of or local path to
|
50 |
+
the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel).
|
51 |
+
cfg.config_overrides (dict, optional): An optional dictionary of keyword
|
52 |
+
arguments that override the default configuration associated with
|
53 |
+
cfg.pretrained_model_name_or_path.
|
54 |
+
cfg.pretrained (bool): Whether to instantiate the model with pre-trained
|
55 |
+
weights coming from cfg.pretrained_model_name_or_path. If ``True``,
|
56 |
+
cfg.config_overrides must be compatible with the pre-trained weights.
|
57 |
+
cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
|
58 |
+
initialize the model on. Currently, `meta` is only supported when
|
59 |
+
cfg.pretrained is ``False``. Default: ``'cpu'``.
|
60 |
+
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(self, om_model_config: Union[DictConfig,
|
64 |
+
transformers.PreTrainedModel,
|
65 |
+
nn.Module],
|
66 |
+
tokenizer: PreTrainedTokenizerBase):
|
67 |
+
# set up training and eval metrics
|
68 |
+
train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
|
69 |
+
eval_metrics = [
|
70 |
+
LanguageCrossEntropy(),
|
71 |
+
LanguagePerplexity(),
|
72 |
+
InContextLearningLMAccuracy(),
|
73 |
+
InContextLearningMultipleChoiceAccuracy(),
|
74 |
+
InContextLearningQAAccuracy(),
|
75 |
+
InContextLearningCodeEvalAccuracy(),
|
76 |
+
InContextLearningLMExpectedCalibrationError(),
|
77 |
+
InContextLearningMCExpectedCalibrationError()
|
78 |
+
]
|
79 |
+
|
80 |
+
# if we are passed a DictConfig, we need to instantiate the model
|
81 |
+
if isinstance(om_model_config, DictConfig):
|
82 |
+
if not om_model_config.get('trust_remote_code',
|
83 |
+
True) and om_model_config.get(
|
84 |
+
'pretrained_model_name_or_path',
|
85 |
+
None).startswith('mosaicml/mpt'):
|
86 |
+
raise ValueError(
|
87 |
+
'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, '
|
88 |
+
+
|
89 |
+
'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
|
90 |
+
)
|
91 |
+
|
92 |
+
if not om_model_config.get('use_train_metrics', True):
|
93 |
+
train_metrics = []
|
94 |
+
|
95 |
+
# load the model config
|
96 |
+
trust_remote_code = om_model_config.get('trust_remote_code', True)
|
97 |
+
use_auth_token = om_model_config.get('use_auth_token', False)
|
98 |
+
config = AutoConfig.from_pretrained(
|
99 |
+
om_model_config.pretrained_model_name_or_path,
|
100 |
+
trust_remote_code=trust_remote_code,
|
101 |
+
use_auth_token=use_auth_token,
|
102 |
+
)
|
103 |
+
|
104 |
+
# set config overrides
|
105 |
+
for k, v in om_model_config.get('config_overrides', {}).items():
|
106 |
+
if not hasattr(config, k):
|
107 |
+
raise ValueError(
|
108 |
+
f'config does not have attribute "{k}" to override ({k}: {v}).'
|
109 |
+
)
|
110 |
+
|
111 |
+
attr = getattr(config, k)
|
112 |
+
# attempt to disallow typos in nested configs
|
113 |
+
if isinstance(attr, Mapping):
|
114 |
+
extra_keys = [
|
115 |
+
_k for _k in v.keys() if _k not in attr.keys()
|
116 |
+
]
|
117 |
+
if extra_keys:
|
118 |
+
raise ValueError(
|
119 |
+
f'Config dict override got unknown keys. ' +
|
120 |
+
f'Extra keys: {extra_keys}. ' +
|
121 |
+
f'Expected (a subset of) keys: {list(attr.keys())}.'
|
122 |
+
)
|
123 |
+
getattr(config, k).update(v)
|
124 |
+
# necessary case to allow for rope_scaling to be overriden in llama config
|
125 |
+
elif attr is None and isinstance(v, Mapping):
|
126 |
+
setattr(config, k, {})
|
127 |
+
getattr(config, k).update(v)
|
128 |
+
else:
|
129 |
+
setattr(config, k, v)
|
130 |
+
|
131 |
+
load_in_8bit = om_model_config.get('load_in_8bit', False)
|
132 |
+
|
133 |
+
# below we set up the device to initialize the model on
|
134 |
+
init_device = om_model_config.get('init_device', 'cpu')
|
135 |
+
|
136 |
+
# Get the device we want to initialize, and use the
|
137 |
+
# reolved version to initialize the HF model
|
138 |
+
resolved_init_device = hf_get_init_device(init_device)
|
139 |
+
|
140 |
+
# We need to have all non-zero local ranks be not-pretrained
|
141 |
+
# Rank 0 will still be pretrained, and distribute the weights appropriately
|
142 |
+
if dist.get_local_rank() != 0 and init_device == 'mixed':
|
143 |
+
om_model_config.pretrained = False
|
144 |
+
|
145 |
+
# initialize the model on the correct device
|
146 |
+
if resolved_init_device == 'cpu':
|
147 |
+
if om_model_config.pretrained:
|
148 |
+
model = AutoModelForCausalLM.from_pretrained(
|
149 |
+
om_model_config.pretrained_model_name_or_path,
|
150 |
+
trust_remote_code=trust_remote_code,
|
151 |
+
use_auth_token=use_auth_token,
|
152 |
+
load_in_8bit=load_in_8bit,
|
153 |
+
config=config)
|
154 |
+
else:
|
155 |
+
model = AutoModelForCausalLM.from_config(
|
156 |
+
config,
|
157 |
+
trust_remote_code=trust_remote_code,
|
158 |
+
)
|
159 |
+
elif resolved_init_device == 'meta':
|
160 |
+
if om_model_config.pretrained:
|
161 |
+
raise ValueError(
|
162 |
+
'Setting cfg.pretrained=True is not supported when init_device="meta".'
|
163 |
+
)
|
164 |
+
with init_empty_weights(include_buffers=False):
|
165 |
+
model = AutoModelForCausalLM.from_config(
|
166 |
+
config,
|
167 |
+
trust_remote_code=trust_remote_code,
|
168 |
+
)
|
169 |
+
else:
|
170 |
+
raise ValueError(
|
171 |
+
f'init_device="{init_device}" must be either "cpu" or "meta".'
|
172 |
+
)
|
173 |
+
|
174 |
+
signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
|
175 |
+
if dist.get_local_rank() == 0:
|
176 |
+
with open(signal_file_path, 'wb') as f:
|
177 |
+
f.write(b'local_rank0_completed_download')
|
178 |
+
|
179 |
+
# Avoid the collective call until the local rank zero has finished trying to download the checkpoint
|
180 |
+
# so that we don't timeout for large downloads. This syncs all processes on the node
|
181 |
+
with dist.local_rank_zero_download_and_wait(signal_file_path):
|
182 |
+
# Then, wait to ensure every node has finished downloading the checkpoint
|
183 |
+
dist.barrier()
|
184 |
+
|
185 |
+
if dist.get_local_rank() == 0:
|
186 |
+
os.remove(signal_file_path)
|
187 |
+
|
188 |
+
z_loss = om_model_config.get('z_loss', 0.0)
|
189 |
+
|
190 |
+
attention_patch_type = om_model_config.get('attention_patch_type',
|
191 |
+
None)
|
192 |
+
if attention_patch_type is not None:
|
193 |
+
if model.config.model_type != 'llama':
|
194 |
+
raise ValueError(
|
195 |
+
f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
|
196 |
+
)
|
197 |
+
|
198 |
+
log.debug(
|
199 |
+
f'Patching llama attention with {attention_patch_type} attention'
|
200 |
+
)
|
201 |
+
from transformers.models.llama.modeling_llama import \
|
202 |
+
LlamaAttention
|
203 |
+
LlamaAttention.forward = get_llama_attention_patch_fn(
|
204 |
+
attention_patch_type)
|
205 |
+
model.config.use_cache = False
|
206 |
+
|
207 |
+
# elif the model is either a PeftModel or a PreTrainedModel
|
208 |
+
elif isinstance(om_model_config, model_types):
|
209 |
+
model = om_model_config
|
210 |
+
init_device = 'cpu'
|
211 |
+
z_loss = 0.0
|
212 |
+
|
213 |
+
# else, unsupported type
|
214 |
+
else:
|
215 |
+
raise ValueError(
|
216 |
+
f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
|
217 |
+
)
|
218 |
+
|
219 |
+
composer_model = super().__init__(model=model,
|
220 |
+
shift_labels=True,
|
221 |
+
tokenizer=tokenizer,
|
222 |
+
metrics=train_metrics,
|
223 |
+
eval_metrics=eval_metrics,
|
224 |
+
z_loss=z_loss,
|
225 |
+
init_device=init_device)
|
226 |
+
|
227 |
+
return composer_model
|
Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_fsdp.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
# helper functions from https://github.com/CarperAI/trlx/blob/main/trlx/utils/modeling.py
|
5 |
+
# which is MIT licensed
|
6 |
+
|
7 |
+
import functools
|
8 |
+
from typing import Any, Iterable, List, Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from transformers import PreTrainedModel
|
12 |
+
from transformers.models.opt.modeling_opt import OPTDecoder
|
13 |
+
|
14 |
+
|
15 |
+
# helper functions
|
16 |
+
def rhasattr(obj: Any, attr: str) -> bool:
|
17 |
+
"""A chain-able attribute version of hasattr.
|
18 |
+
|
19 |
+
For example, to check if
|
20 |
+
`obj` has the attribute `foo.bar.baz`, you can use:
|
21 |
+
`rhasattr(obj, "foo.bar.baz")`
|
22 |
+
Reference: https://stackoverflow.com/a/67303315
|
23 |
+
"""
|
24 |
+
_nested_attrs = attr.split('.')
|
25 |
+
_curr_obj = obj
|
26 |
+
for _a in _nested_attrs[:-1]:
|
27 |
+
if hasattr(_curr_obj, _a):
|
28 |
+
_curr_obj = getattr(_curr_obj, _a)
|
29 |
+
else:
|
30 |
+
return False
|
31 |
+
return hasattr(_curr_obj, _nested_attrs[-1])
|
32 |
+
|
33 |
+
|
34 |
+
def rgetattr(obj: Any, attr: str, *args: List[Any]) -> Any:
|
35 |
+
"""A chain-able attribute version of getattr.
|
36 |
+
|
37 |
+
For example, to get the attribute `foo.bar.baz` from `obj`, you can use:
|
38 |
+
`rgetattr(obj, "foo.bar.baz")`
|
39 |
+
Reference: https://stackoverflow.com/a/31174427
|
40 |
+
"""
|
41 |
+
|
42 |
+
def _getattr(obj: Any, attr: str):
|
43 |
+
return getattr(obj, attr, *args)
|
44 |
+
|
45 |
+
return functools.reduce(_getattr, [obj] + attr.split('.'))
|
46 |
+
|
47 |
+
|
48 |
+
def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]:
|
49 |
+
for attr in attrs:
|
50 |
+
if rhasattr(obj, attr):
|
51 |
+
return rgetattr(obj, attr)
|
52 |
+
return None
|
53 |
+
|
54 |
+
|
55 |
+
def hf_get_causal_base_model(model: PreTrainedModel) -> Any:
|
56 |
+
"""Returns the causal decoder backbone of the specified HuggingFace model.
|
57 |
+
|
58 |
+
Newer HF models have a `self.get_decoder()` method. Older models do not.
|
59 |
+
|
60 |
+
NOTE: Different model configurations have different causal decoder attribute
|
61 |
+
names.
|
62 |
+
- transformer: (GPT2LMHeadModel, GPTJConfig)
|
63 |
+
- model.decoder: (OPTConfig, BloomConfig)
|
64 |
+
- gpt_neox: (GPTNeoXConfig)
|
65 |
+
"""
|
66 |
+
if hasattr(model, 'get_decoder'):
|
67 |
+
return model.get_decoder()
|
68 |
+
|
69 |
+
decoder_attrs = ('transformer', 'model.decoder', 'gpt_neox')
|
70 |
+
causal_base_model = findattr(model, decoder_attrs)
|
71 |
+
if causal_base_model is None:
|
72 |
+
raise ValueError(
|
73 |
+
f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.'
|
74 |
+
)
|
75 |
+
return causal_base_model
|
76 |
+
|
77 |
+
|
78 |
+
def hf_get_hidden_layers(model: PreTrainedModel) -> Any:
|
79 |
+
"""Returns the hidden layers of the specified model.
|
80 |
+
|
81 |
+
NOTE: Different model configurations have different hidden layer attribute names.
|
82 |
+
- transformer.h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)
|
83 |
+
- model.decoder.layers: (OPTForCausalLM)
|
84 |
+
- gpt_neox.layers: (GPTNeoXForCausalLM)
|
85 |
+
- model.layers: (LlaMaForCausalLM)
|
86 |
+
- transformer.blocks: (MPTForCausalLM)
|
87 |
+
"""
|
88 |
+
hidden_layers_attrs = (
|
89 |
+
'transformer.h', # BLOOM, GPT2, GPTJ
|
90 |
+
'model.decoder.layers', # OPT
|
91 |
+
'gpt_neox.layers', # GPTNeoX
|
92 |
+
'block', # T5, BART, Pegasus (from encoder)
|
93 |
+
'layers', # ProphetNet, Marian (from encoder)
|
94 |
+
'model.layers', # LLaMa
|
95 |
+
'transformer.blocks', # MPT
|
96 |
+
)
|
97 |
+
layers = findattr(model, hidden_layers_attrs)
|
98 |
+
if layers is None:
|
99 |
+
raise ValueError(
|
100 |
+
f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}'
|
101 |
+
)
|
102 |
+
return layers
|
103 |
+
|
104 |
+
|
105 |
+
def hf_get_init_device(init_device: Optional[str]) -> Optional[str]:
|
106 |
+
"""Returns the appropriate device to initialize models."""
|
107 |
+
from composer.utils import dist
|
108 |
+
if init_device == 'mixed':
|
109 |
+
if dist.get_local_rank() == 0:
|
110 |
+
return 'cpu'
|
111 |
+
return 'meta'
|
112 |
+
return init_device
|
113 |
+
|
114 |
+
|
115 |
+
# /end helper functions
|
116 |
+
|
117 |
+
|
118 |
+
def prepare_hf_model_for_fsdp(model: PreTrainedModel,
|
119 |
+
init_device: Optional[str]) -> None:
|
120 |
+
"""FSDP wrap a HuggingFace model.
|
121 |
+
|
122 |
+
Call specific functions
|
123 |
+
"""
|
124 |
+
if model.config.is_encoder_decoder:
|
125 |
+
prepare_hf_enc_dec_model_for_fsdp(model, init_device)
|
126 |
+
else:
|
127 |
+
# many common decoder-only model do not set the flag
|
128 |
+
# model.config.is_decoder, so we can't trust it
|
129 |
+
prepare_hf_causal_lm_model_for_fsdp(model, init_device)
|
130 |
+
|
131 |
+
|
132 |
+
def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
|
133 |
+
init_device: Optional[str]) -> None:
|
134 |
+
"""FSDP wrap a HuggingFace decoder.
|
135 |
+
|
136 |
+
Wrap any model for FSDP which follows one of the 3 existing conventions from
|
137 |
+
HuggingFace for decoder-only LLMs.
|
138 |
+
"""
|
139 |
+
causal_base_model = hf_get_causal_base_model(model)
|
140 |
+
|
141 |
+
# OPT has an extra layer of wrapping, so special case here
|
142 |
+
if isinstance(causal_base_model, OPTDecoder):
|
143 |
+
model.model._fsdp_wrap = False
|
144 |
+
model_block = hf_get_hidden_layers(model)
|
145 |
+
lm_head = model.get_output_embeddings()
|
146 |
+
# some models (OPT) implement .get_input_embeddings for the causal subclass
|
147 |
+
# but all of them implement it for the base model
|
148 |
+
tied_embeddings = causal_base_model.get_input_embeddings()
|
149 |
+
modules = {
|
150 |
+
'base_model': causal_base_model,
|
151 |
+
'model_block': model_block,
|
152 |
+
'lm_head': lm_head,
|
153 |
+
'tied_embeddings': tied_embeddings
|
154 |
+
}
|
155 |
+
|
156 |
+
for mod_name, module in modules.items():
|
157 |
+
if module is None:
|
158 |
+
raise ValueError(
|
159 |
+
f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
|
160 |
+
'follow common layer/weight naming conventions.')
|
161 |
+
block_type = type(model_block[0])
|
162 |
+
if init_device == 'mixed':
|
163 |
+
# For FSDP with models with different device initializations, `mixed`, which
|
164 |
+
# initializes the model on rank 0 on `cpu` and on all other ranks on `meta,``
|
165 |
+
# we need to tag all child modules that are torch.nn.Modules with `_fsdp_wrap`.
|
166 |
+
for child in model.children():
|
167 |
+
if isinstance(child, type(causal_base_model)):
|
168 |
+
continue
|
169 |
+
if isinstance(child, torch.nn.Module):
|
170 |
+
child._fsdp_wrap = True
|
171 |
+
|
172 |
+
for child in causal_base_model.children():
|
173 |
+
if isinstance(child, torch.nn.ModuleList):
|
174 |
+
continue
|
175 |
+
if isinstance(child, torch.nn.Module):
|
176 |
+
child._fsdp_wrap = True
|
177 |
+
|
178 |
+
if model.config.tie_word_embeddings and not model.config.model_type == 'mpt':
|
179 |
+
raise ValueError(
|
180 |
+
'The passed in HuggingFaceModel has tied word embeddings ' +
|
181 |
+
'and the passed in initialization device is `mixed.` ' +
|
182 |
+
'In order to support this initialization scheme, we would need to break '
|
183 |
+
+
|
184 |
+
'the weight tying. As a result, either use a different initialization scheme '
|
185 |
+
+ 'or in the model config set `tie_word_embeddings=False.`')
|
186 |
+
else:
|
187 |
+
# When using the HF LM models,
|
188 |
+
# the weights of the self.lm_head and self.transformer.wte are tied.
|
189 |
+
# This tying occurs inside the `self.post_init()` function.
|
190 |
+
# This is a hurdle for FSDP because they need to be in the same FSDP block
|
191 |
+
# These lines ensures that both modules stay together in the top-most block when
|
192 |
+
# the model has this tying enabled (almost all do; this property defaults to True)
|
193 |
+
if model.config.tie_word_embeddings:
|
194 |
+
causal_base_model._fsdp_wrap = False
|
195 |
+
tied_embeddings._fsdp_wrap = False
|
196 |
+
lm_head._fsdp_wrap = False
|
197 |
+
|
198 |
+
# FSDP Wrap and Activation Checkpoint every model block
|
199 |
+
model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
|
200 |
+
model.activation_checkpointing_fn = lambda module: isinstance(
|
201 |
+
module, block_type)
|
202 |
+
|
203 |
+
|
204 |
+
def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel,
|
205 |
+
init_device: Optional[str]) -> None:
|
206 |
+
"""Wrap an encoder/decoder HF model.
|
207 |
+
|
208 |
+
This works for T5, BART, Pegasus, PegasusX, but not all enc/dec (ProphetNet)
|
209 |
+
You have model.shared, model.encoder, model.decoder and model.lm_head, where
|
210 |
+
model.shared are the embeddings which are tied to model.lm_head, and
|
211 |
+
model.shared == model.encoder.embed_tokens and model.shared ==
|
212 |
+
model.decoder.embed_tokens
|
213 |
+
"""
|
214 |
+
tied_embeddings = model.get_input_embeddings()
|
215 |
+
encoder = model.get_encoder()
|
216 |
+
decoder = model.get_decoder()
|
217 |
+
lm_head = model.get_output_embeddings()
|
218 |
+
# some encoder/decoders have different layers for encoder vs decoder
|
219 |
+
encoder_block = hf_get_hidden_layers(encoder)
|
220 |
+
decoder_block = hf_get_hidden_layers(decoder)
|
221 |
+
|
222 |
+
modules = {
|
223 |
+
'encoder': encoder,
|
224 |
+
'decoder': decoder,
|
225 |
+
'encoder_block': encoder_block,
|
226 |
+
'decoder_block': decoder_block,
|
227 |
+
'lm_head': lm_head,
|
228 |
+
'tied_embeddings': tied_embeddings
|
229 |
+
}
|
230 |
+
|
231 |
+
for mod_name, module in modules.items():
|
232 |
+
if module is None:
|
233 |
+
raise ValueError(
|
234 |
+
f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
|
235 |
+
'follow common layer/weight naming conventions.')
|
236 |
+
decoder_block_type = type(decoder_block[0])
|
237 |
+
encoder_block_type = type(encoder_block[0])
|
238 |
+
|
239 |
+
if model.config.tie_word_embeddings:
|
240 |
+
# it is possible to train an enc/dec without tied embeddings, hence the check
|
241 |
+
tied_embeddings._fsdp_wrap = False
|
242 |
+
encoder._fsdp_wrap = False
|
243 |
+
decoder._fsdp_wrap = False
|
244 |
+
lm_head._fsdp_wrap = False
|
245 |
+
|
246 |
+
# FSDP Wrap and Activation Checkpoint every decoder block
|
247 |
+
model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type)
|
248 |
+
model.activation_checkpointing_fn = lambda module: isinstance(
|
249 |
+
module, decoder_block_type)
|
250 |
+
|
251 |
+
if encoder_block_type == decoder_block_type:
|
252 |
+
return
|
253 |
+
|
254 |
+
# need to wrap encoder blocks separately for ProhpetNet and Marian
|
255 |
+
model.fsdp_wrap_fn = lambda module: isinstance(module, encoder_block_type)
|
256 |
+
model.activation_checkpointing_fn = lambda module: isinstance(
|
257 |
+
module, encoder_block_type)
|
Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_prefix_lm.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Implements a Hugging Prefix LM wrapped inside a :class:`.ComposerModel`."""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
|
8 |
+
from typing import Mapping, MutableMapping
|
9 |
+
|
10 |
+
from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
|
11 |
+
from composer.utils import dist
|
12 |
+
from omegaconf import DictConfig
|
13 |
+
from transformers import (AutoConfig, AutoModelForCausalLM,
|
14 |
+
PreTrainedTokenizerBase)
|
15 |
+
|
16 |
+
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
|
17 |
+
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
|
18 |
+
from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
|
19 |
+
add_bidirectional_mask_if_missing,
|
20 |
+
convert_hf_causal_lm_to_prefix_lm,
|
21 |
+
init_empty_weights)
|
22 |
+
|
23 |
+
__all__ = ['ComposerHFPrefixLM']
|
24 |
+
|
25 |
+
# HuggingFace hardcodes the ignore index to -100
|
26 |
+
_HF_IGNORE_INDEX = -100
|
27 |
+
|
28 |
+
|
29 |
+
class ComposerHFPrefixLM(HuggingFaceModelWithZLoss):
|
30 |
+
"""Configures a :class:`.HuggingFaceModel` around a Prefix LM.
|
31 |
+
|
32 |
+
Note: HuggingFace does not natively support Prefix LM-style models. This function uses
|
33 |
+
`transformers.AutoModelForCausalLM` to instantiate a Causal LM, then uses a conversion utility
|
34 |
+
to turn the model into a Prefix LM. Currently, that conversion utility only supports the
|
35 |
+
following HuggingFace Causal LM types:
|
36 |
+
- `GPT2LMHeadModel`
|
37 |
+
- `GPTNeoForCausalLM`
|
38 |
+
- `GPTNeoXForCausalLM`
|
39 |
+
- `GPTJForCausalLM`
|
40 |
+
- `BloomForCausalLM`
|
41 |
+
- `OPTForCausalLM`
|
42 |
+
|
43 |
+
Args:
|
44 |
+
cfg (DictConfig): An omegaconf dictionary used to configure the model:
|
45 |
+
cfg.pretrained_model_name_or_path (str): The name of or local path to
|
46 |
+
the HF model (e.g., `gpt2` to instantiate a GPT2LMHeadModel). The model
|
47 |
+
will be converted to a Prefix LM during initialization.
|
48 |
+
cfg.config_overrides (dict, optional): An optional dictionary of keyword
|
49 |
+
arguments that override the default configuration associated with
|
50 |
+
cfg.pretrained_model_name_or_path. Default: ``{}``.
|
51 |
+
cfg.pretrained (bool): Whether to instantiate the model with pre-trained
|
52 |
+
weights coming from cfg.pretrained_model_name_or_path. If ``True``,
|
53 |
+
cfg.config_overrides must be compatible with the pre-trained weights.
|
54 |
+
cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
|
55 |
+
initialize the model on. Currently, `meta` is only supported when
|
56 |
+
cfg.pretrained is ``False``. Default: ``'cpu'``.
|
57 |
+
cfg.z_loss (float, optional): The coefficient of the z-loss. If >0.0, this
|
58 |
+
the z-loss will be multiplied by this value before being added to the
|
59 |
+
standard loss term. Default: ``0.0``.
|
60 |
+
cfg.adapt_vocab_for_denoising (bool, optional): Whether to adapt the vocab
|
61 |
+
of the model/tokenizer to include sentinel tokens that are used in denoising
|
62 |
+
tasks like Span Corruption. If you intend to load from an existing Composer
|
63 |
+
checkpoint that was trained on such a task, set this to ``True`` to ensure
|
64 |
+
that the model vocab size matches your checkpoint's vocab size when loading
|
65 |
+
the weights. Default: ``False``.
|
66 |
+
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(self, om_model_config: DictConfig,
|
70 |
+
tokenizer: PreTrainedTokenizerBase):
|
71 |
+
config = AutoConfig.from_pretrained(
|
72 |
+
om_model_config.pretrained_model_name_or_path,
|
73 |
+
trust_remote_code=om_model_config.get('trust_remote_code', True),
|
74 |
+
use_auth_token=om_model_config.get('use_auth_token', False),
|
75 |
+
)
|
76 |
+
|
77 |
+
# set config overrides
|
78 |
+
for k, v in om_model_config.get('config_overrides', {}).items():
|
79 |
+
if not hasattr(config, k):
|
80 |
+
raise ValueError(
|
81 |
+
f'config does not have attribute "{k}" to override ({k}: {v}).'
|
82 |
+
)
|
83 |
+
|
84 |
+
attr = getattr(config, k)
|
85 |
+
if isinstance(attr, Mapping):
|
86 |
+
extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
|
87 |
+
if extra_keys:
|
88 |
+
raise ValueError(
|
89 |
+
f'Config dict override got unknown keys. ' +
|
90 |
+
f'Extra keys: {extra_keys}. ' +
|
91 |
+
f'Expected (a subset of) keys: {list(attr.keys())}.')
|
92 |
+
getattr(config, k).update(v)
|
93 |
+
else:
|
94 |
+
setattr(config, k, v)
|
95 |
+
|
96 |
+
# Set up the tokenizer (add tokens for denoising sentinels if needed)
|
97 |
+
if om_model_config.get('adapt_vocab_for_denoising', False):
|
98 |
+
adapt_tokenizer_for_denoising(tokenizer)
|
99 |
+
|
100 |
+
init_device = om_model_config.get('init_device', 'cpu')
|
101 |
+
|
102 |
+
# Get the device we want to initialize, and use the
|
103 |
+
# resolved version to initialize the HF model
|
104 |
+
resolved_init_device = hf_get_init_device(init_device)
|
105 |
+
|
106 |
+
# We need to have all non-zero local ranks be not-pretrained
|
107 |
+
# Rank 0 will still be pretrained, and distribute the weights appropriately
|
108 |
+
if dist.get_local_rank() != 0 and init_device == 'mixed':
|
109 |
+
om_model_config.pretrained = False
|
110 |
+
|
111 |
+
if resolved_init_device == 'cpu':
|
112 |
+
if om_model_config.pretrained:
|
113 |
+
model = AutoModelForCausalLM.from_pretrained(
|
114 |
+
om_model_config.pretrained_model_name_or_path,
|
115 |
+
config=config)
|
116 |
+
else:
|
117 |
+
model = AutoModelForCausalLM.from_config(config)
|
118 |
+
elif resolved_init_device == 'meta':
|
119 |
+
if om_model_config.pretrained:
|
120 |
+
raise ValueError(
|
121 |
+
'Setting cfg.pretrained=True is not supported when init_device="meta".'
|
122 |
+
)
|
123 |
+
with init_empty_weights(include_buffers=False):
|
124 |
+
model = AutoModelForCausalLM.from_config(config)
|
125 |
+
else:
|
126 |
+
raise ValueError(
|
127 |
+
f'init_device="{init_device}" must be either "cpu" or "meta".')
|
128 |
+
|
129 |
+
# Convert the Causal LM into a Prefix LM via our custom wrapper
|
130 |
+
model = convert_hf_causal_lm_to_prefix_lm(model)
|
131 |
+
|
132 |
+
metrics = [
|
133 |
+
LanguageCrossEntropy(ignore_index=_HF_IGNORE_INDEX),
|
134 |
+
MaskedAccuracy(ignore_index=_HF_IGNORE_INDEX)
|
135 |
+
]
|
136 |
+
|
137 |
+
composer_model = super().__init__(model=model,
|
138 |
+
shift_labels=True,
|
139 |
+
tokenizer=tokenizer,
|
140 |
+
metrics=metrics,
|
141 |
+
z_loss=om_model_config.get(
|
142 |
+
'z_loss', 0.0),
|
143 |
+
init_device=init_device)
|
144 |
+
|
145 |
+
return composer_model
|
146 |
+
|
147 |
+
def forward(self, batch: MutableMapping):
|
148 |
+
# Add bidirectional_mask if it is missing and can be constructed
|
149 |
+
add_bidirectional_mask_if_missing(batch)
|
150 |
+
return super().forward(batch)
|
Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_t5.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Implements a Hugging Face T5 wrapped inside a :class:`.ComposerModel`."""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
|
8 |
+
from typing import Mapping
|
9 |
+
|
10 |
+
from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
|
11 |
+
from composer.utils import dist
|
12 |
+
from omegaconf import DictConfig
|
13 |
+
from transformers import (AutoConfig, PreTrainedTokenizerBase,
|
14 |
+
T5ForConditionalGeneration)
|
15 |
+
|
16 |
+
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
|
17 |
+
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
|
18 |
+
from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
|
19 |
+
init_empty_weights)
|
20 |
+
|
21 |
+
__all__ = ['ComposerHFT5']
|
22 |
+
|
23 |
+
# HuggingFace hardcodes the ignore index to -100
|
24 |
+
_HF_IGNORE_INDEX = -100
|
25 |
+
|
26 |
+
|
27 |
+
class ComposerHFT5(HuggingFaceModelWithZLoss):
|
28 |
+
"""Configures a :class:`.HuggingFaceModel` around a T5.
|
29 |
+
|
30 |
+
Note: This function uses `transformers.T5ForConditionalGeneration`. Future releases
|
31 |
+
will expand support to more general classes of HF Encoder-Decoder models.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
cfg (DictConfig): An omegaconf dictionary used to configure the model:
|
35 |
+
cfg.pretrained_model_name_or_path (str): The name of or local path to
|
36 |
+
the HF model (e.g., `t5-base` to instantiate a T5 using the base config).
|
37 |
+
cfg.config_overrides (dict, optional): An optional dictionary of keyword
|
38 |
+
arguments that override the default configuration associated with
|
39 |
+
cfg.pretrained_model_name_or_path. Default: ``{}``.
|
40 |
+
cfg.pretrained (bool): Whether to instantiate the model with pre-trained
|
41 |
+
weights coming from cfg.pretrained_model_name_or_path. If ``True``,
|
42 |
+
cfg.config_overrides must be compatible with the pre-trained weights.
|
43 |
+
cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
|
44 |
+
initialize the model on. Currently, `meta` is only supported when
|
45 |
+
cfg.pretrained is ``False``. Default: ``'cpu'``.
|
46 |
+
cfg.z_loss (float, optional): The coefficient of the z-loss. If >0.0, this
|
47 |
+
the z-loss will be multiplied by this value before being added to the
|
48 |
+
standard loss term. Default: ``0.0``.
|
49 |
+
cfg.adapt_vocab_for_denoising (bool, optional): Whether to adapt the vocab
|
50 |
+
of the model/tokenizer to include sentinel tokens that are used in denoising
|
51 |
+
tasks like Span Corruption. If you intend to load from an existing Composer
|
52 |
+
checkpoint that was trained on such a task, set this to ``True`` to ensure
|
53 |
+
that the model vocab size matches your checkpoint's vocab size when loading
|
54 |
+
the weights. Default: ``False``.
|
55 |
+
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, om_model_config: DictConfig,
|
59 |
+
tokenizer: PreTrainedTokenizerBase):
|
60 |
+
config = AutoConfig.from_pretrained(
|
61 |
+
om_model_config.pretrained_model_name_or_path,
|
62 |
+
trust_remote_code=om_model_config.get('trust_remote_code', True),
|
63 |
+
use_auth_token=om_model_config.get('use_auth_token', False),
|
64 |
+
)
|
65 |
+
|
66 |
+
# set config overrides
|
67 |
+
for k, v in om_model_config.get('config_overrides', {}).items():
|
68 |
+
if not hasattr(config, k):
|
69 |
+
raise ValueError(
|
70 |
+
f'config does not have attribute "{k}" to override ({k}: {v}).'
|
71 |
+
)
|
72 |
+
|
73 |
+
attr = getattr(config, k)
|
74 |
+
if isinstance(attr, Mapping):
|
75 |
+
extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
|
76 |
+
if extra_keys:
|
77 |
+
raise ValueError(
|
78 |
+
f'Config dict override got unknown keys. ' +
|
79 |
+
f'Extra keys: {extra_keys}. ' +
|
80 |
+
f'Expected (a subset of) keys: {list(attr.keys())}.')
|
81 |
+
getattr(config, k).update(v)
|
82 |
+
else:
|
83 |
+
setattr(config, k, v)
|
84 |
+
|
85 |
+
if not config.is_encoder_decoder:
|
86 |
+
raise ValueError(f'Model type "hf_t5" currently only supports T5 models ' +\
|
87 |
+
f'using configs where `is_encoder_decoder` is ``True``.')
|
88 |
+
|
89 |
+
# Set up the tokenizer (add tokens for denoising sentinels if needed)
|
90 |
+
if om_model_config.get('adapt_vocab_for_denoising', False):
|
91 |
+
adapt_tokenizer_for_denoising(tokenizer)
|
92 |
+
|
93 |
+
init_device = om_model_config.get('init_device', 'cpu')
|
94 |
+
|
95 |
+
# Get the device we want to initialize, and use the
|
96 |
+
# resolved version to initialize the HF model
|
97 |
+
resolved_init_device = hf_get_init_device(init_device)
|
98 |
+
|
99 |
+
# We need to have all non-zero local ranks be not-pretrained
|
100 |
+
# Rank 0 will still be pretrained, and distribute the weights appropriately
|
101 |
+
if dist.get_local_rank() != 0 and init_device == 'mixed':
|
102 |
+
om_model_config.pretrained = False
|
103 |
+
|
104 |
+
if resolved_init_device == 'cpu':
|
105 |
+
if om_model_config.pretrained:
|
106 |
+
model = T5ForConditionalGeneration.from_pretrained(
|
107 |
+
om_model_config.pretrained_model_name_or_path,
|
108 |
+
config=config)
|
109 |
+
else:
|
110 |
+
model = T5ForConditionalGeneration(config)
|
111 |
+
elif resolved_init_device == 'meta':
|
112 |
+
if om_model_config.pretrained:
|
113 |
+
raise ValueError(
|
114 |
+
'Setting cfg.pretrained=True is not supported when init_device="meta".'
|
115 |
+
)
|
116 |
+
with init_empty_weights(include_buffers=False):
|
117 |
+
model = T5ForConditionalGeneration(config)
|
118 |
+
else:
|
119 |
+
raise ValueError(
|
120 |
+
f'init_device="{init_device}" must be either "cpu" or "meta".')
|
121 |
+
|
122 |
+
metrics = [
|
123 |
+
LanguageCrossEntropy(ignore_index=_HF_IGNORE_INDEX),
|
124 |
+
MaskedAccuracy(ignore_index=_HF_IGNORE_INDEX)
|
125 |
+
]
|
126 |
+
|
127 |
+
composer_model = super().__init__(model=model,
|
128 |
+
tokenizer=tokenizer,
|
129 |
+
metrics=metrics,
|
130 |
+
z_loss=om_model_config.get(
|
131 |
+
'z_loss', 0.0),
|
132 |
+
init_device=init_device)
|
133 |
+
|
134 |
+
return composer_model
|
Perceptrix/finetune/build/lib/llmfoundry/models/hf/model_wrapper.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Re-usable :class:`.ComposerModel` for LLM HF Models."""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
|
8 |
+
import inspect
|
9 |
+
from collections import UserDict
|
10 |
+
from typing import List, Mapping, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import transformers
|
14 |
+
from composer.models.huggingface import HuggingFaceModel
|
15 |
+
from torchmetrics import Metric
|
16 |
+
from transformers import PreTrainedTokenizerBase
|
17 |
+
from transformers.utils.generic import ModelOutput
|
18 |
+
|
19 |
+
from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
|
20 |
+
|
21 |
+
# HuggingFace hardcodes the ignore index to -100
|
22 |
+
_HF_IGNORE_INDEX = -100
|
23 |
+
|
24 |
+
|
25 |
+
class HuggingFaceModelWithZLoss(HuggingFaceModel):
|
26 |
+
"""Wrapper around HuggingFaceModel.
|
27 |
+
|
28 |
+
This adds z-loss, which is used in some training contexts,
|
29 |
+
and is a convenient way to patch features that are generically
|
30 |
+
useful for HF models.
|
31 |
+
See use of z_loss in PaLM: https://arxiv.org/abs/2204.02311v3, Section 5.
|
32 |
+
Also, from https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666:
|
33 |
+
Two uses of z_loss are:
|
34 |
+
- To keep the logits from drifting too far from zero, which can cause
|
35 |
+
unacceptable roundoff errors in bfloat16.
|
36 |
+
- To encourage the logits to be normalized log-probabilities.
|
37 |
+
|
38 |
+
Handles preparation for FSDP wrapping.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self,
|
42 |
+
model: transformers.PreTrainedModel,
|
43 |
+
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
44 |
+
metrics: Optional[List[Metric]] = None,
|
45 |
+
eval_metrics: Optional[List[Metric]] = None,
|
46 |
+
z_loss: float = 0.0,
|
47 |
+
shift_labels: bool = False,
|
48 |
+
init_device: Optional[str] = None):
|
49 |
+
super().__init__(model,
|
50 |
+
tokenizer,
|
51 |
+
use_logits=True,
|
52 |
+
metrics=metrics,
|
53 |
+
eval_metrics=eval_metrics,
|
54 |
+
shift_labels=shift_labels)
|
55 |
+
self.z_loss = float(z_loss)
|
56 |
+
if self.z_loss < 0.0:
|
57 |
+
raise ValueError(f'z_loss(={z_loss}) cannot be negative.')
|
58 |
+
|
59 |
+
self.model_forward_args = inspect.getfullargspec(
|
60 |
+
self.model.forward).args
|
61 |
+
# inspect.getfullargspec HuggingFace quantized model could not return args correctly
|
62 |
+
if not self.model_forward_args:
|
63 |
+
self.model_forward_args = inspect.signature(
|
64 |
+
self.model.forward).parameters.keys()
|
65 |
+
|
66 |
+
# Note: We need to add the FSDP related attributes to the model AFTER the super init,
|
67 |
+
# so that the (possible) embedding resizing doesn't destroy them
|
68 |
+
prepare_hf_model_for_fsdp(self.model, init_device)
|
69 |
+
|
70 |
+
# This provides support for meta initialization when using FSDP
|
71 |
+
self.model.param_init_fn = lambda module: self.model._init_weights(
|
72 |
+
module)
|
73 |
+
|
74 |
+
def forward(self, batch: Mapping):
|
75 |
+
if isinstance(batch, dict) or isinstance(batch, UserDict):
|
76 |
+
# Further input validation is left to the huggingface forward call
|
77 |
+
batch = {
|
78 |
+
k: v for k, v in batch.items() if k in self.model_forward_args
|
79 |
+
}
|
80 |
+
output = self.model(**batch) # type: ignore (thirdparty)
|
81 |
+
else:
|
82 |
+
raise ValueError(
|
83 |
+
'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model'
|
84 |
+
)
|
85 |
+
return output
|
86 |
+
|
87 |
+
def loss(self, outputs: ModelOutput, batch: Mapping):
|
88 |
+
if self.config.use_return_dict:
|
89 |
+
loss, logits = outputs['loss'], outputs['logits']
|
90 |
+
else:
|
91 |
+
# loss is at index 0 in the output tuple, logits are at index 1
|
92 |
+
loss, logits = outputs[:2]
|
93 |
+
if self.z_loss == 0.0:
|
94 |
+
return loss
|
95 |
+
|
96 |
+
# Add a z_loss to the standard loss
|
97 |
+
logits_flat = logits.view(-1, logits.size(-1))
|
98 |
+
labels_flat = batch['labels'].view(-1)
|
99 |
+
log_z = torch.logsumexp(logits_flat[labels_flat != _HF_IGNORE_INDEX],
|
100 |
+
dim=1)
|
101 |
+
log_z2 = log_z**2
|
102 |
+
z_loss = log_z2.mean() * self.z_loss
|
103 |
+
if self.config.use_return_dict:
|
104 |
+
outputs['loss'] += z_loss
|
105 |
+
return outputs['loss']
|
106 |
+
else:
|
107 |
+
outputs[0] += z_loss
|
108 |
+
return outputs[0]
|
Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from llmfoundry.models.inference_api_wrapper.interface import \
|
5 |
+
InferenceAPIEvalWrapper
|
6 |
+
from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
|
7 |
+
OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper)
|
8 |
+
|
9 |
+
__all__ = [
|
10 |
+
'OpenAICausalLMEvalWrapper',
|
11 |
+
'OpenAIChatAPIEvalWrapper',
|
12 |
+
'InferenceAPIEvalWrapper',
|
13 |
+
]
|
Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/interface.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from typing import Any, Dict, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from composer.core.types import Batch
|
8 |
+
from composer.metrics import InContextLearningMetric
|
9 |
+
from composer.metrics.nlp import (InContextLearningLMAccuracy,
|
10 |
+
InContextLearningLMExpectedCalibrationError,
|
11 |
+
InContextLearningMCExpectedCalibrationError,
|
12 |
+
InContextLearningMultipleChoiceAccuracy,
|
13 |
+
InContextLearningQAAccuracy,
|
14 |
+
LanguageCrossEntropy, LanguagePerplexity)
|
15 |
+
from composer.models import ComposerModel
|
16 |
+
from torchmetrics import Metric
|
17 |
+
from transformers import AutoTokenizer
|
18 |
+
|
19 |
+
|
20 |
+
class InferenceAPIEvalWrapper(ComposerModel):
|
21 |
+
|
22 |
+
def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
|
23 |
+
self.tokenizer = tokenizer
|
24 |
+
self.labels = None
|
25 |
+
# set up training and eval metrics
|
26 |
+
eval_metrics = [
|
27 |
+
LanguageCrossEntropy(),
|
28 |
+
LanguagePerplexity(),
|
29 |
+
InContextLearningLMAccuracy(),
|
30 |
+
InContextLearningMultipleChoiceAccuracy(),
|
31 |
+
InContextLearningQAAccuracy(),
|
32 |
+
InContextLearningLMExpectedCalibrationError(),
|
33 |
+
InContextLearningMCExpectedCalibrationError()
|
34 |
+
]
|
35 |
+
self.eval_metrics = {
|
36 |
+
metric.__class__.__name__: metric for metric in eval_metrics
|
37 |
+
}
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
def get_metrics(self, is_train: bool = False):
|
41 |
+
if is_train:
|
42 |
+
raise NotImplementedError(
|
43 |
+
'You cannot use inference wrappers for training')
|
44 |
+
else:
|
45 |
+
metrics = self.eval_metrics
|
46 |
+
|
47 |
+
return metrics if metrics else {}
|
48 |
+
|
49 |
+
def get_next_token_logit_tensor(self,
|
50 |
+
prompt: str) -> Optional[torch.Tensor]:
|
51 |
+
raise NotImplementedError
|
52 |
+
|
53 |
+
def rebatch(self, batch: Batch):
|
54 |
+
# default is a no-op, but Chat API modifies these
|
55 |
+
return batch
|
56 |
+
|
57 |
+
def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
|
58 |
+
# If the batch mode is generate, we will generate a requested number of tokens using the underlying
|
59 |
+
# model's generate function. Extra generation kwargs can be passed in via the batch. Strings will
|
60 |
+
# be returned from eval_forward
|
61 |
+
output_logits_batch = []
|
62 |
+
for tokens, cont_idxs in zip(batch['input_ids'],
|
63 |
+
batch['continuation_indices']):
|
64 |
+
|
65 |
+
seqlen = tokens.shape[0]
|
66 |
+
tokens = tokens.tolist()
|
67 |
+
cont_idxs = cont_idxs.tolist()
|
68 |
+
expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
|
69 |
+
output_logits = torch.nn.functional.one_hot(
|
70 |
+
torch.tensor(tokens[1:cont_idxs[0]]),
|
71 |
+
num_classes=self.tokenizer.vocab_size)
|
72 |
+
for i in range(len(expected_cont_tokens)):
|
73 |
+
# decode one token at a time
|
74 |
+
prompt = self.tokenizer.decode(tokens[:cont_idxs[0]] +
|
75 |
+
expected_cont_tokens[0:i])
|
76 |
+
next_logit_tensor = self.get_next_token_logit_tensor(prompt)
|
77 |
+
if next_logit_tensor is None:
|
78 |
+
continue
|
79 |
+
output_logits = torch.cat(
|
80 |
+
[output_logits,
|
81 |
+
next_logit_tensor.reshape(1, -1)])
|
82 |
+
padding = torch.nn.functional.one_hot(
|
83 |
+
torch.full((seqlen - output_logits.shape[0],),
|
84 |
+
self.tokenizer.pad_token_id),
|
85 |
+
num_classes=self.tokenizer.vocab_size)
|
86 |
+
output_logits = torch.cat([output_logits, padding])
|
87 |
+
output_logits_batch.append(output_logits)
|
88 |
+
|
89 |
+
return torch.stack(output_logits_batch).to(batch['input_ids'].device)
|
90 |
+
|
91 |
+
def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
|
92 |
+
batch = self.rebatch(batch)
|
93 |
+
self.labels = batch.pop('labels')
|
94 |
+
self.labels[:, :-1] = self.labels[:, 1:].clone()
|
95 |
+
self.labels[:, -1] = -100
|
96 |
+
if isinstance(metric, InContextLearningMetric) and batch.get(
|
97 |
+
'mode', None) == 'icl_task':
|
98 |
+
assert self.labels is not None
|
99 |
+
metric.update(batch, outputs, self.labels)
|
100 |
+
else:
|
101 |
+
raise NotImplementedError(
|
102 |
+
'Inference API wrapper only supports InContextLearningMetrics and mode=icl_task'
|
103 |
+
)
|
104 |
+
|
105 |
+
def forward(self):
|
106 |
+
raise NotImplementedError(
|
107 |
+
"Inference API wrapper doesn't support forward")
|
108 |
+
|
109 |
+
def loss(self):
|
110 |
+
raise NotImplementedError("Inference API wrapper doesn't support loss")
|
Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Implements a OpenAI chat and causal LM inference API wrappers."""
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
from time import sleep
|
9 |
+
from typing import Any, Dict, List, Optional, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from composer.core.types import Batch
|
13 |
+
from composer.utils.import_helpers import MissingConditionalImportError
|
14 |
+
from transformers import AutoTokenizer
|
15 |
+
|
16 |
+
log = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
from llmfoundry.models.inference_api_wrapper.interface import \
|
19 |
+
InferenceAPIEvalWrapper
|
20 |
+
|
21 |
+
__all__ = [
|
22 |
+
'OpenAICausalLMEvalWrapper',
|
23 |
+
'OpenAIChatAPIEvalWrapper',
|
24 |
+
]
|
25 |
+
|
26 |
+
MAX_RETRIES = 10
|
27 |
+
|
28 |
+
|
29 |
+
class OpenAIEvalInterface(InferenceAPIEvalWrapper):
|
30 |
+
|
31 |
+
def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
|
32 |
+
super().__init__(model_cfg, tokenizer)
|
33 |
+
try:
|
34 |
+
import openai
|
35 |
+
except ImportError as e:
|
36 |
+
raise MissingConditionalImportError(
|
37 |
+
extra_deps_group='openai',
|
38 |
+
conda_package='openai',
|
39 |
+
conda_channel='conda-forge') from e
|
40 |
+
openai.api_key = os.getenv('OPENAI_API_KEY')
|
41 |
+
self.model_name = model_cfg['version']
|
42 |
+
|
43 |
+
def generate_completion(self, prompt: str, num_tokens: int):
|
44 |
+
raise NotImplementedError()
|
45 |
+
|
46 |
+
def process_result(self, completion: Optional[dict]):
|
47 |
+
raise NotImplementedError()
|
48 |
+
|
49 |
+
def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1):
|
50 |
+
completion = self.try_generate_completion(prompt, num_tokens)
|
51 |
+
return self.process_result(completion)
|
52 |
+
|
53 |
+
def try_generate_completion(self, prompt: str, num_tokens: int):
|
54 |
+
try:
|
55 |
+
from openai.error import RateLimitError
|
56 |
+
except ImportError as e:
|
57 |
+
raise MissingConditionalImportError(
|
58 |
+
extra_deps_group='openai',
|
59 |
+
conda_package='openai',
|
60 |
+
conda_channel='conda-forge') from e
|
61 |
+
tries = 0
|
62 |
+
completion = None
|
63 |
+
while tries < MAX_RETRIES:
|
64 |
+
tries += 1
|
65 |
+
try:
|
66 |
+
|
67 |
+
completion = self.generate_completion(prompt, num_tokens)
|
68 |
+
break
|
69 |
+
except RateLimitError as e:
|
70 |
+
if 'You exceeded your current quota' in str(e._message):
|
71 |
+
raise e
|
72 |
+
sleep(60)
|
73 |
+
continue
|
74 |
+
except Exception:
|
75 |
+
continue
|
76 |
+
return completion
|
77 |
+
|
78 |
+
|
79 |
+
class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface):
|
80 |
+
|
81 |
+
def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
|
82 |
+
super().__init__(model_cfg, tokenizer)
|
83 |
+
try:
|
84 |
+
import openai
|
85 |
+
except ImportError as e:
|
86 |
+
raise MissingConditionalImportError(
|
87 |
+
extra_deps_group='openai',
|
88 |
+
conda_package='openai',
|
89 |
+
conda_channel='conda-forge') from e
|
90 |
+
|
91 |
+
self.generate_completion = lambda prompt, num_tokens: openai.ChatCompletion.create(
|
92 |
+
self.model_name,
|
93 |
+
messages=[{
|
94 |
+
'role': 'user',
|
95 |
+
'content': prompt
|
96 |
+
}],
|
97 |
+
max_tokens=num_tokens,
|
98 |
+
temperature=0.0)
|
99 |
+
|
100 |
+
def retokenize(self, tokens: List[int], cont_idxs: List[int]):
|
101 |
+
"""Chat API will never respond with a word-initial space.
|
102 |
+
|
103 |
+
If the continuation tokens begin with a word initial space, we need to
|
104 |
+
re-tokenize with the space removed.
|
105 |
+
"""
|
106 |
+
original_len = len(tokens)
|
107 |
+
retokenized_continuation = self.tokenizer(
|
108 |
+
self.tokenizer.decode(tokens[cont_idxs[0]:cont_idxs[-1] +
|
109 |
+
1]).strip())['input_ids']
|
110 |
+
|
111 |
+
# replace the original continuation with the retokenized continuation + padding
|
112 |
+
padding = [tokens[-1]] * (
|
113 |
+
len(tokens) - len(tokens[:cont_idxs[0]] + retokenized_continuation))
|
114 |
+
tokens = tokens[:cont_idxs[0]] + retokenized_continuation + padding
|
115 |
+
|
116 |
+
if len(tokens) > original_len:
|
117 |
+
# this only happens if we were already at max seq len and the continuation got LARGER
|
118 |
+
tokens = tokens[-original_len:]
|
119 |
+
cont_idxs = list(
|
120 |
+
range(original_len - len(retokenized_continuation),
|
121 |
+
original_len))
|
122 |
+
else:
|
123 |
+
cont_idxs = list(
|
124 |
+
range(cont_idxs[0],
|
125 |
+
cont_idxs[0] + len(retokenized_continuation)))
|
126 |
+
return torch.tensor(tokens), torch.tensor(cont_idxs)
|
127 |
+
|
128 |
+
def rebatch(self, batch: Batch):
|
129 |
+
"""Chat API tokenization has different behavior than GPT3.
|
130 |
+
|
131 |
+
Model responses will never begin with spaces even if the continuation is
|
132 |
+
expected to, so we need to retokenize the input to account for that.
|
133 |
+
"""
|
134 |
+
new_batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = {
|
135 |
+
'input_ids': [],
|
136 |
+
'continuation_indices': [],
|
137 |
+
'labels': []
|
138 |
+
}
|
139 |
+
for tokens, cont_idxs in zip(batch['input_ids'],
|
140 |
+
batch['continuation_indices']):
|
141 |
+
tokens, cont_idxs = self.retokenize(tokens.tolist(),
|
142 |
+
cont_idxs.tolist())
|
143 |
+
|
144 |
+
assert isinstance(new_batch['input_ids'], list)
|
145 |
+
new_batch['input_ids'].append(tokens)
|
146 |
+
assert isinstance(new_batch['labels'], list)
|
147 |
+
new_batch['labels'].append(tokens)
|
148 |
+
assert isinstance(new_batch['continuation_indices'], list)
|
149 |
+
new_batch['continuation_indices'].append(cont_idxs)
|
150 |
+
|
151 |
+
new_batch.update({
|
152 |
+
k: torch.stack(new_batch[k]) # pyright: ignore
|
153 |
+
for k in ['input_ids', 'labels']
|
154 |
+
})
|
155 |
+
|
156 |
+
new_batch.update({k: v for k, v in batch.items() if k not in new_batch})
|
157 |
+
|
158 |
+
return new_batch
|
159 |
+
|
160 |
+
def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
|
161 |
+
# Override the base class because Chat's API always strips spacing from model outputs resulting in different tokens
|
162 |
+
# than what the continuation would expect.
|
163 |
+
# Get around this issue by retokenizing the batch to remove spacing from the continuation as well as
|
164 |
+
# decoding the whole continuation at once.
|
165 |
+
output_logits_batch = []
|
166 |
+
batch = self.rebatch(batch)
|
167 |
+
for tokens, cont_idxs in zip(batch['input_ids'],
|
168 |
+
batch['continuation_indices']):
|
169 |
+
|
170 |
+
seqlen = tokens.shape[0]
|
171 |
+
tokens = tokens.tolist()
|
172 |
+
cont_idxs = cont_idxs.tolist()
|
173 |
+
expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
|
174 |
+
output_logits = torch.nn.functional.one_hot(
|
175 |
+
torch.tensor(tokens[1:cont_idxs[0]]),
|
176 |
+
num_classes=self.tokenizer.vocab_size)
|
177 |
+
|
178 |
+
prompt = self.tokenizer.decode(tokens[:cont_idxs[0]])
|
179 |
+
next_logit_tensor = self.get_next_token_logit_tensor(
|
180 |
+
prompt, num_tokens=len(expected_cont_tokens))
|
181 |
+
|
182 |
+
if next_logit_tensor is not None:
|
183 |
+
output_logits = torch.cat([output_logits, next_logit_tensor])
|
184 |
+
padding = torch.nn.functional.one_hot(
|
185 |
+
torch.full((seqlen - output_logits.shape[0],),
|
186 |
+
self.tokenizer.pad_token_id),
|
187 |
+
num_classes=self.tokenizer.vocab_size)
|
188 |
+
output_logits = torch.cat([output_logits, padding])
|
189 |
+
output_logits_batch.append(output_logits)
|
190 |
+
|
191 |
+
return torch.stack(output_logits_batch).to(batch['input_ids'].device)
|
192 |
+
|
193 |
+
def process_result(self, completion: Optional[dict]):
|
194 |
+
assert isinstance(completion, dict)
|
195 |
+
if len(completion['choices']) > 0:
|
196 |
+
tensors = []
|
197 |
+
for t in self.tokenizer(completion['choices'][0]['message']
|
198 |
+
['content'])['input_ids']:
|
199 |
+
tensors.append(
|
200 |
+
self.tokenizer.construct_logit_tensor(
|
201 |
+
{self.tokenizer.decode([t]): 0.0}))
|
202 |
+
|
203 |
+
if len(tensors) == 0:
|
204 |
+
return None
|
205 |
+
return torch.stack(tensors)
|
206 |
+
else:
|
207 |
+
# the model sometimes stops early even though we are still requesting tokens!
|
208 |
+
# not sure if there's a fix
|
209 |
+
return None
|
210 |
+
|
211 |
+
|
212 |
+
class OpenAICausalLMEvalWrapper(OpenAIEvalInterface):
|
213 |
+
|
214 |
+
def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
|
215 |
+
super().__init__(model_cfg, tokenizer)
|
216 |
+
try:
|
217 |
+
import openai
|
218 |
+
except ImportError as e:
|
219 |
+
raise MissingConditionalImportError(
|
220 |
+
extra_deps_group='openai',
|
221 |
+
conda_package='openai',
|
222 |
+
conda_channel='conda-forge') from e
|
223 |
+
|
224 |
+
self.generate_completion = lambda prompt, num_tokens: openai.Completion.create(
|
225 |
+
engine=self.model_name,
|
226 |
+
prompt=prompt,
|
227 |
+
max_tokens=1,
|
228 |
+
logprobs=5,
|
229 |
+
temperature=0.0)
|
230 |
+
|
231 |
+
def process_result(self, completion: Optional[dict]):
|
232 |
+
if completion is None:
|
233 |
+
raise ValueError("Couldn't generate model output")
|
234 |
+
|
235 |
+
assert isinstance(completion, dict)
|
236 |
+
if len(completion['choices'][0]['logprobs']['top_logprobs']) > 0:
|
237 |
+
tensor = self.tokenizer.construct_logit_tensor(
|
238 |
+
dict(completion['choices'][0]['logprobs']['top_logprobs'][0]))
|
239 |
+
return tensor
|
240 |
+
else:
|
241 |
+
# the model sometimes stops early even though we are still requesting tokens!
|
242 |
+
# not sure if there's a fix
|
243 |
+
return None
|
Perceptrix/finetune/build/lib/llmfoundry/models/layers/__init__.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from llmfoundry.models.layers.attention import (
|
5 |
+
ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention,
|
6 |
+
attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
|
7 |
+
scaled_multihead_dot_product_attention, triton_flash_attn_fn)
|
8 |
+
from llmfoundry.models.layers.blocks import MPTBlock
|
9 |
+
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
|
10 |
+
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
|
11 |
+
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
|
12 |
+
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm
|
13 |
+
|
14 |
+
__all__ = [
|
15 |
+
'scaled_multihead_dot_product_attention',
|
16 |
+
'flash_attn_fn',
|
17 |
+
'triton_flash_attn_fn',
|
18 |
+
'MultiheadAttention',
|
19 |
+
'MultiQueryAttention',
|
20 |
+
'attn_bias_shape',
|
21 |
+
'build_attn_bias',
|
22 |
+
'build_alibi_bias',
|
23 |
+
'ATTN_CLASS_REGISTRY',
|
24 |
+
'MPTMLP',
|
25 |
+
'MPTBlock',
|
26 |
+
'NORM_CLASS_REGISTRY',
|
27 |
+
'LPLayerNorm',
|
28 |
+
'FC_CLASS_REGISTRY',
|
29 |
+
'SharedEmbedding',
|
30 |
+
'FFN_CLASS_REGISTRY',
|
31 |
+
'build_ffn',
|
32 |
+
]
|
Perceptrix/finetune/build/lib/llmfoundry/models/layers/attention.py
ADDED
@@ -0,0 +1,768 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Attention layers."""
|
5 |
+
|
6 |
+
import math
|
7 |
+
import warnings
|
8 |
+
from typing import Any, List, Optional, Tuple
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from einops import rearrange
|
13 |
+
from packaging import version
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
|
17 |
+
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
|
18 |
+
|
19 |
+
|
20 |
+
def is_flash_v2_installed():
|
21 |
+
try:
|
22 |
+
import flash_attn as flash_attn
|
23 |
+
except:
|
24 |
+
return False
|
25 |
+
return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
|
26 |
+
|
27 |
+
|
28 |
+
def is_flash_v1_installed():
|
29 |
+
try:
|
30 |
+
import flash_attn as flash_attn
|
31 |
+
except:
|
32 |
+
return False
|
33 |
+
return version.parse(flash_attn.__version__) < version.parse('2.0.0')
|
34 |
+
|
35 |
+
|
36 |
+
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
|
37 |
+
original_is_causal: bool) -> bool:
|
38 |
+
# disable causal when it is not needed
|
39 |
+
# necessary for flash & triton for generation with kv_cache
|
40 |
+
if original_is_causal and num_query_tokens != num_key_tokens:
|
41 |
+
if num_query_tokens != 1:
|
42 |
+
raise NotImplementedError(
|
43 |
+
'MPT does not support query and key with different number of tokens, unless number of query tokens is 1.'
|
44 |
+
)
|
45 |
+
else:
|
46 |
+
return False
|
47 |
+
return original_is_causal
|
48 |
+
|
49 |
+
|
50 |
+
def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
|
51 |
+
"""Perform repeat of kv heads along a particular dimension.
|
52 |
+
|
53 |
+
hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
|
54 |
+
n_rep: amount of repetitions of kv_n_heads
|
55 |
+
Unlike torch.repeat_interleave, this function avoids allocating new memory.
|
56 |
+
"""
|
57 |
+
if n_rep == 1:
|
58 |
+
return hidden
|
59 |
+
|
60 |
+
b, s, kv_n_heads, d = hidden.shape
|
61 |
+
|
62 |
+
hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
|
63 |
+
|
64 |
+
return hidden.reshape(b, s, kv_n_heads * n_rep, d)
|
65 |
+
|
66 |
+
|
67 |
+
def scaled_multihead_dot_product_attention(
|
68 |
+
query: torch.Tensor,
|
69 |
+
key: torch.Tensor,
|
70 |
+
value: torch.Tensor,
|
71 |
+
n_heads: int,
|
72 |
+
kv_n_heads: Optional[int] = None,
|
73 |
+
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
74 |
+
softmax_scale: Optional[float] = None,
|
75 |
+
attn_bias: Optional[torch.Tensor] = None,
|
76 |
+
key_padding_mask: Optional[torch.Tensor] = None,
|
77 |
+
is_causal: bool = False,
|
78 |
+
dropout_p: float = 0.0,
|
79 |
+
training: bool = False,
|
80 |
+
needs_weights: bool = False,
|
81 |
+
multiquery: bool = False,
|
82 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
|
83 |
+
torch.Tensor]]]:
|
84 |
+
|
85 |
+
if multiquery:
|
86 |
+
warnings.warn(
|
87 |
+
DeprecationWarning(
|
88 |
+
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
|
89 |
+
))
|
90 |
+
kv_n_heads = 1
|
91 |
+
elif kv_n_heads is None:
|
92 |
+
warnings.warn(
|
93 |
+
DeprecationWarning(
|
94 |
+
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
|
95 |
+
))
|
96 |
+
kv_n_heads = n_heads
|
97 |
+
|
98 |
+
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
|
99 |
+
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
|
100 |
+
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
|
101 |
+
|
102 |
+
if past_key_value is not None:
|
103 |
+
# attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
|
104 |
+
# kv_cache is therefore stored using that shape.
|
105 |
+
# attn_impl: torch stores the kv_cache in the ordering which is most advantageous
|
106 |
+
# for its attn computation ie
|
107 |
+
# keys are stored as tensors with shape [b, h, d_head, s] and
|
108 |
+
# values are stored as tensors with shape [b, h, s, d_head]
|
109 |
+
if len(past_key_value) != 0:
|
110 |
+
k = torch.cat([past_key_value[0], k], dim=3)
|
111 |
+
v = torch.cat([past_key_value[1], v], dim=2)
|
112 |
+
|
113 |
+
past_key_value = (k, v)
|
114 |
+
|
115 |
+
b, _, s_q, d = q.shape
|
116 |
+
s_k = k.size(-1)
|
117 |
+
|
118 |
+
# grouped query case
|
119 |
+
if kv_n_heads > 1 and kv_n_heads < n_heads:
|
120 |
+
# necessary to do a transpose to swap (b h s d) -> (b s h d) for repeat_kv_for_gqa function
|
121 |
+
k = repeat_kv_for_gqa(k.transpose(1, 2),
|
122 |
+
n_heads // kv_n_heads).transpose(1, 2)
|
123 |
+
v = repeat_kv_for_gqa(v.transpose(1, 2),
|
124 |
+
n_heads // kv_n_heads).transpose(1, 2)
|
125 |
+
|
126 |
+
if softmax_scale is None:
|
127 |
+
softmax_scale = 1 / math.sqrt(d)
|
128 |
+
|
129 |
+
attn_weight = q.matmul(k) * softmax_scale
|
130 |
+
|
131 |
+
if attn_bias is not None:
|
132 |
+
# clamp to 0 necessary for torch 2.0 compile()
|
133 |
+
_s_q = max(0, attn_bias.size(2) - s_q)
|
134 |
+
_s_k = max(0, attn_bias.size(3) - s_k)
|
135 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
136 |
+
|
137 |
+
if (attn_bias.size(-1) != 1 and
|
138 |
+
attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and
|
139 |
+
attn_bias.size(-2) != s_q):
|
140 |
+
raise RuntimeError(
|
141 |
+
f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
|
142 |
+
)
|
143 |
+
attn_weight = attn_weight + attn_bias
|
144 |
+
|
145 |
+
min_val = torch.finfo(q.dtype).min
|
146 |
+
|
147 |
+
if key_padding_mask is not None:
|
148 |
+
if attn_bias is not None:
|
149 |
+
warnings.warn(
|
150 |
+
'Propagating key_padding_mask to the attention module ' +\
|
151 |
+
'and applying it within the attention module can cause ' +\
|
152 |
+
'unnecessary computation/memory usage. Consider integrating ' +\
|
153 |
+
'into attn_bias once and passing that to each attention ' +\
|
154 |
+
'module instead.'
|
155 |
+
)
|
156 |
+
attn_weight = attn_weight.masked_fill(
|
157 |
+
~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
158 |
+
|
159 |
+
if is_causal and (not q.size(2) == 1):
|
160 |
+
s = max(s_q, s_k)
|
161 |
+
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
|
162 |
+
causal_mask = causal_mask.tril()
|
163 |
+
causal_mask = causal_mask.to(torch.bool)
|
164 |
+
causal_mask = ~causal_mask
|
165 |
+
causal_mask = causal_mask[-s_q:, -s_k:]
|
166 |
+
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
|
167 |
+
min_val)
|
168 |
+
|
169 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
170 |
+
|
171 |
+
if dropout_p:
|
172 |
+
attn_weight = torch.nn.functional.dropout(attn_weight,
|
173 |
+
p=dropout_p,
|
174 |
+
training=training,
|
175 |
+
inplace=True)
|
176 |
+
|
177 |
+
out = attn_weight.to(v.dtype).matmul(v)
|
178 |
+
out = rearrange(out, 'b h s d -> b s (h d)')
|
179 |
+
|
180 |
+
if needs_weights:
|
181 |
+
return out, attn_weight, past_key_value
|
182 |
+
return out, None, past_key_value
|
183 |
+
|
184 |
+
|
185 |
+
def check_valid_inputs(*tensors: torch.Tensor,
|
186 |
+
valid_dtypes: Optional[List[torch.dtype]] = None):
|
187 |
+
if valid_dtypes is None:
|
188 |
+
valid_dtypes = [torch.float16, torch.bfloat16]
|
189 |
+
for tensor in tensors:
|
190 |
+
if tensor.dtype not in valid_dtypes:
|
191 |
+
raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.')
|
192 |
+
if not tensor.is_cuda:
|
193 |
+
raise TypeError(f'Inputs must be cuda tensors ({tensor.is_cuda=}).')
|
194 |
+
|
195 |
+
|
196 |
+
def flash_attn_fn(
|
197 |
+
query: torch.Tensor,
|
198 |
+
key: torch.Tensor,
|
199 |
+
value: torch.Tensor,
|
200 |
+
n_heads: int,
|
201 |
+
kv_n_heads: Optional[int] = None,
|
202 |
+
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
203 |
+
softmax_scale: Optional[float] = None,
|
204 |
+
attn_bias: Optional[torch.Tensor] = None,
|
205 |
+
key_padding_mask: Optional[torch.Tensor] = None,
|
206 |
+
is_causal: bool = False,
|
207 |
+
dropout_p: float = 0.0,
|
208 |
+
training: bool = False,
|
209 |
+
needs_weights: bool = False,
|
210 |
+
multiquery: bool = False,
|
211 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
|
212 |
+
torch.Tensor]]]:
|
213 |
+
try:
|
214 |
+
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
|
215 |
+
except:
|
216 |
+
raise RuntimeError(
|
217 |
+
'Please install flash-attn==1.0.9 or flash-attn==2.3.2')
|
218 |
+
|
219 |
+
check_valid_inputs(query, key, value)
|
220 |
+
|
221 |
+
if multiquery:
|
222 |
+
warnings.warn(
|
223 |
+
DeprecationWarning(
|
224 |
+
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
|
225 |
+
))
|
226 |
+
kv_n_heads = 1
|
227 |
+
elif kv_n_heads is None:
|
228 |
+
warnings.warn(
|
229 |
+
DeprecationWarning(
|
230 |
+
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
|
231 |
+
))
|
232 |
+
kv_n_heads = n_heads
|
233 |
+
|
234 |
+
if past_key_value is not None:
|
235 |
+
if len(past_key_value) != 0:
|
236 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
237 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
238 |
+
|
239 |
+
past_key_value = (key, value)
|
240 |
+
|
241 |
+
if attn_bias is not None:
|
242 |
+
# clamp to 0 necessary for torch 2.0 compile()
|
243 |
+
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
244 |
+
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
245 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
246 |
+
|
247 |
+
if attn_bias is not None:
|
248 |
+
raise NotImplementedError(f'attn_bias not implemented for flash attn.')
|
249 |
+
|
250 |
+
batch_size, seqlen = query.shape[:2]
|
251 |
+
|
252 |
+
if key_padding_mask is None:
|
253 |
+
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
|
254 |
+
query_padding_mask = key_padding_mask[:, -query.size(1):]
|
255 |
+
|
256 |
+
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
|
257 |
+
query, query_padding_mask)
|
258 |
+
query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
|
259 |
+
|
260 |
+
key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
|
261 |
+
key, key_padding_mask)
|
262 |
+
key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
|
263 |
+
|
264 |
+
value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
|
265 |
+
value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
|
266 |
+
|
267 |
+
# multi-query case
|
268 |
+
if kv_n_heads == 1:
|
269 |
+
# Expanding a tensor does not allocate new memory, but only creates a new
|
270 |
+
# view on the existing tensor where a dimension of size one is expanded
|
271 |
+
# to a larger size by setting the stride to 0.
|
272 |
+
# - pytorch docs
|
273 |
+
#
|
274 |
+
# hopefully the kernels can utilize this and we're jot just wasting BW here
|
275 |
+
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
|
276 |
+
key_unpad.size(-1))
|
277 |
+
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
|
278 |
+
value_unpad.size(-1))
|
279 |
+
# grouped query case
|
280 |
+
elif kv_n_heads < n_heads:
|
281 |
+
# Each query belong to a group of kv heads of group size n_heads // kv_n_heads
|
282 |
+
# We repeat each kv head by the group size number to use the underlying MHA kernels
|
283 |
+
|
284 |
+
# since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
|
285 |
+
# we use .view to modify {key, value}_unpad appropriately
|
286 |
+
|
287 |
+
key_unpad = repeat_kv_for_gqa(
|
288 |
+
key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
|
289 |
+
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
|
290 |
+
value_unpad = repeat_kv_for_gqa(
|
291 |
+
value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
|
292 |
+
n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
|
293 |
+
|
294 |
+
dropout_p = dropout_p if training else 0.0
|
295 |
+
|
296 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
297 |
+
|
298 |
+
if is_flash_v1_installed():
|
299 |
+
output_unpad = flash_attn_interface.flash_attn_unpadded_func(
|
300 |
+
q=query_unpad,
|
301 |
+
k=key_unpad,
|
302 |
+
v=value_unpad,
|
303 |
+
cu_seqlens_q=cu_seqlens_q,
|
304 |
+
cu_seqlens_k=cu_seqlens_k,
|
305 |
+
max_seqlen_q=max_seqlen_q,
|
306 |
+
max_seqlen_k=max_seqlen_k,
|
307 |
+
dropout_p=dropout_p,
|
308 |
+
softmax_scale=softmax_scale,
|
309 |
+
causal=reset_is_causal,
|
310 |
+
return_attn_probs=needs_weights)
|
311 |
+
elif is_flash_v2_installed():
|
312 |
+
output_unpad = flash_attn_interface.flash_attn_varlen_func(
|
313 |
+
q=query_unpad,
|
314 |
+
k=key_unpad,
|
315 |
+
v=value_unpad,
|
316 |
+
cu_seqlens_q=cu_seqlens_q,
|
317 |
+
cu_seqlens_k=cu_seqlens_k,
|
318 |
+
max_seqlen_q=max_seqlen_q,
|
319 |
+
max_seqlen_k=max_seqlen_k,
|
320 |
+
dropout_p=dropout_p,
|
321 |
+
softmax_scale=softmax_scale,
|
322 |
+
causal=reset_is_causal,
|
323 |
+
return_attn_probs=needs_weights)
|
324 |
+
else:
|
325 |
+
raise RuntimeError(
|
326 |
+
'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')
|
327 |
+
|
328 |
+
output = bert_padding.pad_input(
|
329 |
+
rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
|
330 |
+
seqlen)
|
331 |
+
return output, None, past_key_value
|
332 |
+
|
333 |
+
|
334 |
+
def triton_flash_attn_fn(
|
335 |
+
query: torch.Tensor,
|
336 |
+
key: torch.Tensor,
|
337 |
+
value: torch.Tensor,
|
338 |
+
n_heads: int,
|
339 |
+
kv_n_heads: Optional[int] = None,
|
340 |
+
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
341 |
+
softmax_scale: Optional[float] = None,
|
342 |
+
attn_bias: Optional[torch.Tensor] = None,
|
343 |
+
key_padding_mask: Optional[torch.Tensor] = None,
|
344 |
+
is_causal: bool = False,
|
345 |
+
dropout_p: float = 0.0,
|
346 |
+
training: bool = False,
|
347 |
+
needs_weights: bool = False,
|
348 |
+
multiquery: bool = False,
|
349 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
|
350 |
+
torch.Tensor]]]:
|
351 |
+
try:
|
352 |
+
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
|
353 |
+
except:
|
354 |
+
_installed = False
|
355 |
+
if version.parse(torch.__version__) < version.parse('2.0.0'):
|
356 |
+
_installed = True
|
357 |
+
# if torch1.13.1 revert to using triton flash attn from HazyResearch
|
358 |
+
# with flash-attn==1.0.9 and triton==2.0.0.dev20221202
|
359 |
+
try:
|
360 |
+
from flash_attn.flash_attn_triton import flash_attn_func
|
361 |
+
except:
|
362 |
+
_installed = False
|
363 |
+
if not _installed:
|
364 |
+
# installing triton-pre-mlir works for both torch1.13.1 and torch2.0+
|
365 |
+
# default recommendation is to install this variant
|
366 |
+
raise RuntimeError(
|
367 |
+
'Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU '
|
368 |
+
+
|
369 |
+
'and `pip install .[gpu]` if installing from llm-foundry source or '
|
370 |
+
+
|
371 |
+
'`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` '
|
372 |
+
+
|
373 |
+
'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). '
|
374 |
+
+
|
375 |
+
'Note: (1) requires you have CMake and PyTorch already installed.'
|
376 |
+
)
|
377 |
+
|
378 |
+
check_valid_inputs(query, key, value)
|
379 |
+
|
380 |
+
if multiquery:
|
381 |
+
warnings.warn(
|
382 |
+
DeprecationWarning(
|
383 |
+
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
|
384 |
+
))
|
385 |
+
kv_n_heads = 1
|
386 |
+
elif kv_n_heads is None:
|
387 |
+
warnings.warn(
|
388 |
+
DeprecationWarning(
|
389 |
+
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
|
390 |
+
))
|
391 |
+
kv_n_heads = n_heads
|
392 |
+
|
393 |
+
if past_key_value is not None:
|
394 |
+
if len(past_key_value) != 0:
|
395 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
396 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
397 |
+
|
398 |
+
past_key_value = (key, value)
|
399 |
+
|
400 |
+
if attn_bias is not None:
|
401 |
+
# clamp to 0 necessary for torch 2.0 compile()
|
402 |
+
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
403 |
+
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
404 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
405 |
+
|
406 |
+
if dropout_p:
|
407 |
+
raise NotImplementedError(
|
408 |
+
f'Dropout not implemented for attn_impl: triton.')
|
409 |
+
dropout_p = dropout_p if training else 0.0
|
410 |
+
|
411 |
+
if needs_weights:
|
412 |
+
raise NotImplementedError(
|
413 |
+
f'attn_impl: triton cannot return attn weights.')
|
414 |
+
|
415 |
+
if key_padding_mask is not None:
|
416 |
+
warnings.warn(
|
417 |
+
'Propagating key_padding_mask to the attention module ' +\
|
418 |
+
'and applying it within the attention module can cause ' +\
|
419 |
+
'unnecessary computation/memory usage. Consider integrating ' +\
|
420 |
+
'into attn_bias once and passing that to each attention ' +\
|
421 |
+
'module instead.'
|
422 |
+
)
|
423 |
+
b_size, s_k = key_padding_mask.shape[:2]
|
424 |
+
|
425 |
+
if attn_bias is None:
|
426 |
+
attn_bias = query.new_zeros(b_size, 1, 1, s_k)
|
427 |
+
|
428 |
+
attn_bias = attn_bias.masked_fill(
|
429 |
+
~key_padding_mask.view((b_size, 1, 1, s_k)),
|
430 |
+
torch.finfo(query.dtype).min)
|
431 |
+
|
432 |
+
query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
|
433 |
+
key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
|
434 |
+
value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
|
435 |
+
|
436 |
+
# multi-query case
|
437 |
+
if kv_n_heads == 1:
|
438 |
+
# necessary to repeat instead of expand tensor because
|
439 |
+
# output contains NaN in edge cases such as with head dimension = 8
|
440 |
+
key = key.repeat(1, 1, n_heads, 1)
|
441 |
+
value = value.repeat(1, 1, n_heads, 1)
|
442 |
+
# grouped query case
|
443 |
+
elif kv_n_heads < n_heads:
|
444 |
+
# Each query belong to a group of kv heads of group size n_heads // kv_n_heads
|
445 |
+
# We repeat each kv head by the group size number to use the underlying MHA kernels
|
446 |
+
key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
|
447 |
+
value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
|
448 |
+
|
449 |
+
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
450 |
+
attn_output = flash_attn_func( # type: ignore
|
451 |
+
query, key, value, attn_bias, reset_is_causal, softmax_scale)
|
452 |
+
|
453 |
+
output = attn_output.view(*attn_output.shape[:2], -1) # type: ignore
|
454 |
+
|
455 |
+
return output, None, past_key_value
|
456 |
+
|
457 |
+
|
458 |
+
class GroupedQueryAttention(nn.Module):
|
459 |
+
"""Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
|
460 |
+
|
461 |
+
and Multi-query attention (MQA).
|
462 |
+
|
463 |
+
This allows the user to set a variable of number of kv_n_heads, rather than
|
464 |
+
just n_heads or 1, as in MHA and MQA. Using torch or triton attention
|
465 |
+
implementation enables user to also use additive bias.
|
466 |
+
"""
|
467 |
+
|
468 |
+
def __init__(
|
469 |
+
self,
|
470 |
+
d_model: int,
|
471 |
+
n_heads: int,
|
472 |
+
kv_n_heads: int,
|
473 |
+
attn_impl: str = 'triton',
|
474 |
+
clip_qkv: Optional[float] = None,
|
475 |
+
qk_ln: bool = False,
|
476 |
+
softmax_scale: Optional[float] = None,
|
477 |
+
attn_pdrop: float = 0.0,
|
478 |
+
norm_type: str = 'low_precision_layernorm',
|
479 |
+
fc_type: str = 'torch',
|
480 |
+
device: Optional[str] = None,
|
481 |
+
bias: bool = True,
|
482 |
+
):
|
483 |
+
super().__init__()
|
484 |
+
|
485 |
+
self.attn_impl = attn_impl
|
486 |
+
self.clip_qkv = clip_qkv
|
487 |
+
self.qk_ln = qk_ln
|
488 |
+
|
489 |
+
self.d_model = d_model
|
490 |
+
self.n_heads = n_heads
|
491 |
+
self.kv_n_heads = kv_n_heads
|
492 |
+
|
493 |
+
self.head_dim = d_model // n_heads
|
494 |
+
|
495 |
+
if self.kv_n_heads <= 0:
|
496 |
+
raise ValueError('kv_n_heads should be greater than zero.')
|
497 |
+
|
498 |
+
if self.kv_n_heads > self.n_heads:
|
499 |
+
raise ValueError(
|
500 |
+
'The number of KV heads should be less than or equal to Q heads.'
|
501 |
+
)
|
502 |
+
|
503 |
+
if self.n_heads % self.kv_n_heads != 0:
|
504 |
+
raise ValueError(
|
505 |
+
'Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.'
|
506 |
+
)
|
507 |
+
|
508 |
+
self.softmax_scale = softmax_scale
|
509 |
+
if self.softmax_scale is None:
|
510 |
+
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
|
511 |
+
self.attn_dropout_p = attn_pdrop
|
512 |
+
|
513 |
+
fc_kwargs: dict[str, Any] = {
|
514 |
+
'bias': bias,
|
515 |
+
}
|
516 |
+
if fc_type != 'te':
|
517 |
+
fc_kwargs['device'] = device
|
518 |
+
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
|
519 |
+
self.d_model,
|
520 |
+
self.d_model + 2 * self.kv_n_heads * self.head_dim,
|
521 |
+
**fc_kwargs,
|
522 |
+
)
|
523 |
+
# for param init fn; enables shape based init of fused layers
|
524 |
+
fuse_splits = [
|
525 |
+
i * self.head_dim
|
526 |
+
for i in range(1, self.n_heads + 2 * self.kv_n_heads)
|
527 |
+
]
|
528 |
+
self.Wqkv._fused = (0, fuse_splits)
|
529 |
+
|
530 |
+
if self.qk_ln:
|
531 |
+
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
532 |
+
self.q_ln = norm_class(self.d_model, device=device)
|
533 |
+
self.k_ln = norm_class(self.kv_n_heads * self.head_dim,
|
534 |
+
device=device)
|
535 |
+
|
536 |
+
if self.attn_impl == 'flash':
|
537 |
+
self.attn_fn = flash_attn_fn
|
538 |
+
elif self.attn_impl == 'triton':
|
539 |
+
self.attn_fn = triton_flash_attn_fn
|
540 |
+
elif self.attn_impl == 'torch':
|
541 |
+
self.attn_fn = scaled_multihead_dot_product_attention
|
542 |
+
else:
|
543 |
+
raise ValueError(f'{attn_impl=} is an invalid setting.')
|
544 |
+
|
545 |
+
self.out_proj = FC_CLASS_REGISTRY[fc_type](
|
546 |
+
self.d_model,
|
547 |
+
self.d_model,
|
548 |
+
**fc_kwargs,
|
549 |
+
)
|
550 |
+
self.out_proj._is_residual = True
|
551 |
+
|
552 |
+
def forward(
|
553 |
+
self,
|
554 |
+
x: torch.Tensor,
|
555 |
+
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
556 |
+
attn_bias: Optional[torch.Tensor] = None,
|
557 |
+
attention_mask: Optional[torch.Tensor] = None,
|
558 |
+
is_causal: bool = True,
|
559 |
+
needs_weights: bool = False,
|
560 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
|
561 |
+
torch.Tensor, torch.Tensor]]]:
|
562 |
+
qkv = self.Wqkv(x)
|
563 |
+
|
564 |
+
if self.clip_qkv:
|
565 |
+
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
566 |
+
|
567 |
+
query, key, value = qkv.split(
|
568 |
+
[
|
569 |
+
self.d_model,
|
570 |
+
self.kv_n_heads * self.head_dim,
|
571 |
+
self.kv_n_heads * self.head_dim,
|
572 |
+
],
|
573 |
+
dim=2,
|
574 |
+
)
|
575 |
+
|
576 |
+
key_padding_mask = attention_mask
|
577 |
+
|
578 |
+
if self.qk_ln:
|
579 |
+
# Applying layernorm to qk
|
580 |
+
dtype = query.dtype
|
581 |
+
query = self.q_ln(query).to(dtype)
|
582 |
+
key = self.k_ln(key).to(dtype)
|
583 |
+
|
584 |
+
context, attn_weights, past_key_value = self.attn_fn(
|
585 |
+
query,
|
586 |
+
key,
|
587 |
+
value,
|
588 |
+
self.n_heads,
|
589 |
+
self.kv_n_heads,
|
590 |
+
past_key_value=past_key_value,
|
591 |
+
softmax_scale=self.softmax_scale,
|
592 |
+
attn_bias=attn_bias,
|
593 |
+
key_padding_mask=key_padding_mask,
|
594 |
+
is_causal=is_causal,
|
595 |
+
dropout_p=self.attn_dropout_p,
|
596 |
+
training=self.training,
|
597 |
+
needs_weights=needs_weights,
|
598 |
+
)
|
599 |
+
|
600 |
+
return self.out_proj(context), attn_weights, past_key_value
|
601 |
+
|
602 |
+
|
603 |
+
class MultiheadAttention(GroupedQueryAttention):
|
604 |
+
"""Multi-head self attention.
|
605 |
+
|
606 |
+
Using torch or triton attention implementation enables user to also use
|
607 |
+
additive bias.
|
608 |
+
"""
|
609 |
+
|
610 |
+
def __init__(
|
611 |
+
self,
|
612 |
+
d_model: int,
|
613 |
+
n_heads: int,
|
614 |
+
attn_impl: str = 'triton',
|
615 |
+
clip_qkv: Optional[float] = None,
|
616 |
+
qk_ln: bool = False,
|
617 |
+
softmax_scale: Optional[float] = None,
|
618 |
+
attn_pdrop: float = 0.0,
|
619 |
+
norm_type: str = 'low_precision_layernorm',
|
620 |
+
fc_type: str = 'torch',
|
621 |
+
device: Optional[str] = None,
|
622 |
+
bias: bool = True,
|
623 |
+
):
|
624 |
+
super().__init__(
|
625 |
+
d_model=d_model,
|
626 |
+
n_heads=n_heads,
|
627 |
+
kv_n_heads=n_heads, # for MHA, same # heads as kv groups
|
628 |
+
attn_impl=attn_impl,
|
629 |
+
clip_qkv=clip_qkv,
|
630 |
+
qk_ln=qk_ln,
|
631 |
+
softmax_scale=softmax_scale,
|
632 |
+
attn_pdrop=attn_pdrop,
|
633 |
+
norm_type=norm_type,
|
634 |
+
fc_type=fc_type,
|
635 |
+
device=device,
|
636 |
+
bias=bias,
|
637 |
+
)
|
638 |
+
|
639 |
+
|
640 |
+
class MultiQueryAttention(GroupedQueryAttention):
|
641 |
+
"""Multi-Query self attention.
|
642 |
+
|
643 |
+
Using torch or triton attention implementation enables user to also use
|
644 |
+
additive bias.
|
645 |
+
"""
|
646 |
+
|
647 |
+
def __init__(
|
648 |
+
self,
|
649 |
+
d_model: int,
|
650 |
+
n_heads: int,
|
651 |
+
attn_impl: str = 'triton',
|
652 |
+
clip_qkv: Optional[float] = None,
|
653 |
+
qk_ln: bool = False,
|
654 |
+
softmax_scale: Optional[float] = None,
|
655 |
+
attn_pdrop: float = 0.0,
|
656 |
+
norm_type: str = 'low_precision_layernorm',
|
657 |
+
fc_type: str = 'torch',
|
658 |
+
device: Optional[str] = None,
|
659 |
+
bias: bool = True,
|
660 |
+
):
|
661 |
+
super().__init__(
|
662 |
+
d_model=d_model,
|
663 |
+
n_heads=n_heads,
|
664 |
+
kv_n_heads=1, # for MQA, 1 head
|
665 |
+
attn_impl=attn_impl,
|
666 |
+
clip_qkv=clip_qkv,
|
667 |
+
qk_ln=qk_ln,
|
668 |
+
softmax_scale=softmax_scale,
|
669 |
+
attn_pdrop=attn_pdrop,
|
670 |
+
norm_type=norm_type,
|
671 |
+
fc_type=fc_type,
|
672 |
+
device=device,
|
673 |
+
bias=bias,
|
674 |
+
)
|
675 |
+
|
676 |
+
|
677 |
+
def attn_bias_shape(
|
678 |
+
attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
|
679 |
+
prefix_lm: bool, causal: bool,
|
680 |
+
use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
|
681 |
+
if attn_impl == 'flash':
|
682 |
+
return None
|
683 |
+
elif attn_impl in ['torch', 'triton']:
|
684 |
+
if alibi:
|
685 |
+
if (prefix_lm or not causal) or use_sequence_id:
|
686 |
+
return (1, n_heads, seq_len, seq_len)
|
687 |
+
return (1, n_heads, 1, seq_len)
|
688 |
+
elif prefix_lm or use_sequence_id:
|
689 |
+
return (1, 1, seq_len, seq_len)
|
690 |
+
return None
|
691 |
+
else:
|
692 |
+
raise ValueError(f'{attn_impl=} is an invalid setting.')
|
693 |
+
|
694 |
+
|
695 |
+
def build_attn_bias(
|
696 |
+
attn_impl: str,
|
697 |
+
attn_bias: torch.Tensor,
|
698 |
+
n_heads: int,
|
699 |
+
seq_len: int,
|
700 |
+
causal: bool = False,
|
701 |
+
alibi: bool = False,
|
702 |
+
alibi_bias_max: int = 8,
|
703 |
+
) -> Optional[torch.Tensor]:
|
704 |
+
if attn_impl == 'flash':
|
705 |
+
return None
|
706 |
+
elif attn_impl in ['torch', 'triton']:
|
707 |
+
if alibi:
|
708 |
+
# in place add alibi to attn bias
|
709 |
+
device, dtype = attn_bias.device, attn_bias.dtype
|
710 |
+
attn_bias = attn_bias.add(
|
711 |
+
build_alibi_bias(
|
712 |
+
n_heads,
|
713 |
+
seq_len,
|
714 |
+
full=not causal,
|
715 |
+
alibi_bias_max=alibi_bias_max,
|
716 |
+
device=device,
|
717 |
+
dtype=dtype,
|
718 |
+
))
|
719 |
+
return attn_bias
|
720 |
+
else:
|
721 |
+
raise ValueError(f'{attn_impl=} is an invalid setting.')
|
722 |
+
|
723 |
+
|
724 |
+
def gen_slopes(n_heads: int,
|
725 |
+
alibi_bias_max: int = 8,
|
726 |
+
device: Optional[torch.device] = None) -> torch.Tensor:
|
727 |
+
_n_heads = 2**math.ceil(math.log2(n_heads))
|
728 |
+
m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
|
729 |
+
m = m.mul(alibi_bias_max / _n_heads)
|
730 |
+
slopes = (1. / torch.pow(2, m))
|
731 |
+
|
732 |
+
if _n_heads != n_heads:
|
733 |
+
# if n_heads is not a power of two,
|
734 |
+
# Huggingface and FasterTransformer calculate slopes normally,
|
735 |
+
# then return this strided concatenation of slopes
|
736 |
+
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
|
737 |
+
|
738 |
+
return slopes.view(1, n_heads, 1, 1)
|
739 |
+
|
740 |
+
|
741 |
+
def build_alibi_bias(
|
742 |
+
n_heads: int,
|
743 |
+
seq_len: int,
|
744 |
+
full: bool = False,
|
745 |
+
alibi_bias_max: int = 8,
|
746 |
+
device: Optional[torch.device] = None,
|
747 |
+
dtype: Optional[torch.dtype] = None,
|
748 |
+
) -> torch.Tensor:
|
749 |
+
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32,
|
750 |
+
device=device).view(1, 1, 1, seq_len)
|
751 |
+
if full:
|
752 |
+
# generate 1 x Heads x SeqLen x SeqLen alibi bias mask
|
753 |
+
# otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast to the appropriate size)
|
754 |
+
alibi_bias = alibi_bias - torch.arange(
|
755 |
+
1 - seq_len, 1, dtype=torch.int32, device=device).view(
|
756 |
+
1, 1, seq_len, 1)
|
757 |
+
alibi_bias = alibi_bias.abs().mul(-1)
|
758 |
+
|
759 |
+
slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
|
760 |
+
alibi_bias = alibi_bias * slopes
|
761 |
+
return alibi_bias.to(dtype=dtype)
|
762 |
+
|
763 |
+
|
764 |
+
ATTN_CLASS_REGISTRY = {
|
765 |
+
'multihead_attention': MultiheadAttention,
|
766 |
+
'multiquery_attention': MultiQueryAttention,
|
767 |
+
'grouped_query_attention': GroupedQueryAttention
|
768 |
+
}
|
Perceptrix/finetune/build/lib/llmfoundry/models/layers/blocks.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""GPT Blocks used for the GPT Model."""
|
5 |
+
|
6 |
+
from typing import Any, Dict, Optional, Tuple
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
|
12 |
+
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
|
13 |
+
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
|
14 |
+
|
15 |
+
|
16 |
+
class MPTBlock(nn.Module):
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
d_model: int,
|
21 |
+
n_heads: int,
|
22 |
+
expansion_ratio: int,
|
23 |
+
attn_config: Optional[Dict] = None,
|
24 |
+
ffn_config: Optional[Dict] = None,
|
25 |
+
resid_pdrop: float = 0.0,
|
26 |
+
norm_type: str = 'low_precision_layernorm',
|
27 |
+
fc_type: str = 'torch',
|
28 |
+
device: Optional[str] = None,
|
29 |
+
no_bias: bool = False,
|
30 |
+
**kwargs: Any,
|
31 |
+
):
|
32 |
+
if attn_config is None:
|
33 |
+
attn_config = {
|
34 |
+
'attn_type': 'multihead_attention',
|
35 |
+
'attn_pdrop': 0.0,
|
36 |
+
'attn_impl': 'triton',
|
37 |
+
'qk_ln': False,
|
38 |
+
'clip_qkv': None,
|
39 |
+
'softmax_scale': None,
|
40 |
+
'prefix_lm': False,
|
41 |
+
'attn_uses_sequence_id': False,
|
42 |
+
'alibi': False,
|
43 |
+
'alibi_bias_max': 8,
|
44 |
+
}
|
45 |
+
|
46 |
+
if ffn_config is None:
|
47 |
+
ffn_config = {
|
48 |
+
'ffn_type': 'mptmlp',
|
49 |
+
}
|
50 |
+
|
51 |
+
del kwargs # unused, just to capture any extra args from the config
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
55 |
+
assert isinstance(attn_config['attn_type'], str)
|
56 |
+
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
|
57 |
+
|
58 |
+
# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
|
59 |
+
args_to_exclude_in_attn_class = {
|
60 |
+
'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id',
|
61 |
+
'alibi_bias_max'
|
62 |
+
}
|
63 |
+
attn_config_subset_for_attn_class = {
|
64 |
+
k: v
|
65 |
+
for k, v in attn_config.items()
|
66 |
+
if k not in args_to_exclude_in_attn_class
|
67 |
+
}
|
68 |
+
|
69 |
+
self.norm_1 = norm_class(d_model, device=device)
|
70 |
+
self.attn = attn_class(
|
71 |
+
d_model=d_model,
|
72 |
+
n_heads=n_heads,
|
73 |
+
fc_type=fc_type,
|
74 |
+
device=device,
|
75 |
+
**attn_config_subset_for_attn_class,
|
76 |
+
bias=not no_bias,
|
77 |
+
)
|
78 |
+
self.norm_2 = None
|
79 |
+
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
|
80 |
+
False):
|
81 |
+
self.norm_2 = norm_class(d_model, device=device)
|
82 |
+
self.ffn = build_ffn(
|
83 |
+
d_model=d_model,
|
84 |
+
expansion_ratio=expansion_ratio,
|
85 |
+
device=device,
|
86 |
+
bias=not no_bias,
|
87 |
+
**ffn_config,
|
88 |
+
)
|
89 |
+
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
|
90 |
+
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
|
91 |
+
|
92 |
+
def forward(
|
93 |
+
self,
|
94 |
+
x: torch.Tensor,
|
95 |
+
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
96 |
+
attn_bias: Optional[torch.Tensor] = None,
|
97 |
+
attention_mask: Optional[torch.ByteTensor] = None,
|
98 |
+
is_causal: bool = True,
|
99 |
+
output_attentions: bool = False,
|
100 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
|
101 |
+
torch.Tensor, torch.Tensor]]]:
|
102 |
+
a = self.norm_1(x)
|
103 |
+
b, attn_weights, past_key_value = self.attn(
|
104 |
+
a,
|
105 |
+
past_key_value=past_key_value,
|
106 |
+
attn_bias=attn_bias,
|
107 |
+
attention_mask=attention_mask,
|
108 |
+
is_causal=is_causal,
|
109 |
+
needs_weights=output_attentions,
|
110 |
+
)
|
111 |
+
x = x + self.resid_attn_dropout(b)
|
112 |
+
m = x
|
113 |
+
if self.norm_2 is not None:
|
114 |
+
m = self.norm_2(x)
|
115 |
+
n = self.ffn(m)
|
116 |
+
x = x + self.resid_ffn_dropout(n)
|
117 |
+
return x, attn_weights, past_key_value
|
Perceptrix/finetune/build/lib/llmfoundry/models/layers/custom_embedding.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML LLM Foundry authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
|
9 |
+
class SharedEmbedding(nn.Embedding):
|
10 |
+
|
11 |
+
def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
|
12 |
+
if unembed:
|
13 |
+
return F.linear(input, self.weight)
|
14 |
+
return super().forward(input)
|