crystal-technologies committed
Commit de4ade4 · 1 Parent(s): 82c3d93

Upload 303 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. Perceptrix/__init__.py +2 -0
  2. Perceptrix/chat.py +125 -0
  3. Perceptrix/create_data/interface.py +152 -0
  4. Perceptrix/create_data/static/style.css +154 -0
  5. Perceptrix/create_data/templates/index.html +80 -0
  6. Perceptrix/engine.py +213 -0
  7. Perceptrix/finetune/Dockerfile +13 -0
  8. Perceptrix/finetune/Makefile +23 -0
  9. Perceptrix/finetune/README.md +265 -0
  10. Perceptrix/finetune/build/lib/inference/__init__.py +4 -0
  11. Perceptrix/finetune/build/lib/inference/convert_composer_mpt_to_ft.py +232 -0
  12. Perceptrix/finetune/build/lib/inference/convert_composer_to_hf.py +290 -0
  13. Perceptrix/finetune/build/lib/inference/convert_hf_mpt_to_ft.py +154 -0
  14. Perceptrix/finetune/build/lib/inference/convert_hf_to_onnx.py +229 -0
  15. Perceptrix/finetune/build/lib/inference/hf_chat.py +389 -0
  16. Perceptrix/finetune/build/lib/inference/hf_generate.py +372 -0
  17. Perceptrix/finetune/build/lib/inference/run_mpt_with_ft.py +480 -0
  18. Perceptrix/finetune/build/lib/llmfoundry/__init__.py +71 -0
  19. Perceptrix/finetune/build/lib/llmfoundry/callbacks/__init__.py +31 -0
  20. Perceptrix/finetune/build/lib/llmfoundry/callbacks/eval_gauntlet_callback.py +177 -0
  21. Perceptrix/finetune/build/lib/llmfoundry/callbacks/fdiff_callback.py +67 -0
  22. Perceptrix/finetune/build/lib/llmfoundry/callbacks/generate_callback.py +30 -0
  23. Perceptrix/finetune/build/lib/llmfoundry/callbacks/hf_checkpointer.py +167 -0
  24. Perceptrix/finetune/build/lib/llmfoundry/callbacks/model_gauntlet_callback.py +21 -0
  25. Perceptrix/finetune/build/lib/llmfoundry/callbacks/monolithic_ckpt_callback.py +115 -0
  26. Perceptrix/finetune/build/lib/llmfoundry/callbacks/resumption_callbacks.py +89 -0
  27. Perceptrix/finetune/build/lib/llmfoundry/callbacks/scheduled_gc_callback.py +75 -0
  28. Perceptrix/finetune/build/lib/llmfoundry/data/__init__.py +21 -0
  29. Perceptrix/finetune/build/lib/llmfoundry/data/data.py +117 -0
  30. Perceptrix/finetune/build/lib/llmfoundry/data/denoising.py +937 -0
  31. Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/__init__.py +7 -0
  32. Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/collator.py +343 -0
  33. Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/dataloader.py +516 -0
  34. Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/tasks.py +433 -0
  35. Perceptrix/finetune/build/lib/llmfoundry/data/packing.py +423 -0
  36. Perceptrix/finetune/build/lib/llmfoundry/data/text_data.py +367 -0
  37. Perceptrix/finetune/build/lib/llmfoundry/models/__init__.py +18 -0
  38. Perceptrix/finetune/build/lib/llmfoundry/models/hf/__init__.py +18 -0
  39. Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_causal_lm.py +227 -0
  40. Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_fsdp.py +257 -0
  41. Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_prefix_lm.py +150 -0
  42. Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_t5.py +134 -0
  43. Perceptrix/finetune/build/lib/llmfoundry/models/hf/model_wrapper.py +108 -0
  44. Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/__init__.py +13 -0
  45. Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/interface.py +110 -0
  46. Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py +243 -0
  47. Perceptrix/finetune/build/lib/llmfoundry/models/layers/__init__.py +32 -0
  48. Perceptrix/finetune/build/lib/llmfoundry/models/layers/attention.py +768 -0
  49. Perceptrix/finetune/build/lib/llmfoundry/models/layers/blocks.py +117 -0
  50. Perceptrix/finetune/build/lib/llmfoundry/models/layers/custom_embedding.py +14 -0
Perceptrix/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # from Perceptrix.engine import robotix, identify_objects_from_text, search_keyword
+ # from Perceptrix.chat import perceptrix
Perceptrix/chat.py ADDED
@@ -0,0 +1,125 @@
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, BitsAndBytesConfig, GenerationConfig
+ from Perceptrix.streamer import TextStreamer
+ from utils import setup_device
+ import torch
+ import time
+ import os
+
+ model_name = os.environ.get('CHAT_MODEL')
+
+ model_path = "models/CRYSTAL-chat" if model_name is None else model_name
+ config = AutoConfig.from_pretrained(
+     model_path, trust_remote_code=True)
+
+ device = setup_device()
+ # device = "mps"  # uncomment to force Apple Silicon instead of the detected device
+
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     model_path,
+     config=config,
+     low_cpu_mem_usage=True,
+     trust_remote_code=True,
+     device_map="auto",
+     torch_dtype=torch.float16,
+     # quantization_config=bnb_config,
+     offload_folder="offloads",
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     model_path,
+     trust_remote_code=True,
+ )
+
+ if tokenizer.pad_token_id is None:
+     tokenizer.pad_token = tokenizer.eos_token
+
+ tokenizer.padding_side = "left"
+ model.eval()
+
+ streamer = TextStreamer(tokenizer, skip_prompt=True,
+                         skip_special_tokens=True, save_file="reply.txt")
+
+ def evaluate(
+     prompt='',
+     temperature=0.4,
+     top_p=0.65,
+     top_k=35,
+     repetition_penalty=1.1,
+     max_new_tokens=512,
+     **kwargs,
+ ):
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs["input_ids"].to(device)
+     generation_config = GenerationConfig(
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+         **kwargs,
+     )
+
+     with torch.no_grad():
+         generation_output = model.generate(
+             input_ids=input_ids,
+             generation_config=generation_config,
+             return_dict_in_generate=True,
+             output_scores=True,
+             max_new_tokens=max_new_tokens,
+             eos_token_id=tokenizer.eos_token_id,
+             pad_token_id=tokenizer.pad_token_id,
+             streamer=streamer,
+         )
+     s = generation_output.sequences[0]
+     output = tokenizer.decode(s, skip_special_tokens=True)
+     yield output.split("### Response:")[-1].strip()
+
+
+ def predict(
+     inputs,
+     temperature=0.4,
+     top_p=0.65,
+     top_k=35,
+     repetition_penalty=1.1,
+     max_new_tokens=512,
+ ):
+     now_prompt = inputs
+
+     response = evaluate(
+         now_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, do_sample=True
+     )
+
+     for i in response:
+         print(i)
+         response = i
+
+     return response
+
+
+ instructions = "You are Comprehensive Robotics Yielding Sophisticated Technology And Logistics (CRYSTAL), an AI robot developed by Vatsal Dutt to be the most advanced robot in the world. You will be provided with prompts and other information to help the user."
+
+ def perceptrix(prompt):
+     prompt = instructions+"\n"+prompt
+     response = predict(
+         inputs=prompt, temperature=0.2, top_p=0.9, max_new_tokens=512
+     )
+     spl_tokens = ["<|im_start|>", "<|im_end|>"]
+     clean_prompt = prompt.replace(spl_tokens[0], "").replace(spl_tokens[1], "")
+     return response[len(clean_prompt):]
+
+
+ if __name__ == "__main__":
+     history = ""
+     while True:
+         user_input = input("User: ")
+         start = time.time()
+         user_input = "<|im_start|>User\n"+user_input+"<|im_end|>\n<|im_start|>CRYSTAL\n"
+         result = perceptrix(history+user_input)
+         history += user_input + result + "<|im_end|>\n"
+         print("Answer completed in ~", round(time.time()-start), "seconds")
Perceptrix/create_data/interface.py ADDED
@@ -0,0 +1,152 @@
+ from flask import Flask, render_template, request
+ import random
+ import json
+
+ app = Flask(__name__)
+
+ data_file_path = "Perceptrix/finetune/finetune-data/crystal-finetune.json"
+
+ with open(data_file_path, 'r') as data_file:
+     file_data = data_file.read()
+     data = json.loads(file_data)
+
13
+
14
+ room_descriptions = ["The surroundings include a bed, a chair, and a dog. The bed is made up with a white blanket, and there are two people sitting on it, likely a man and a woman. The chair is positioned next to the bed, and the dog is sitting on the bed as well. The room appears to be a bedroom, and the atmosphere seems to be cozy and comfortable.",
15
+ "In the image, there is a living room with a couch, a chair, a coffee table, and a window. The room is well-decorated and filled with furniture, including a bed, a desk, and a dining table. The living room is situated next to the bedroom, and there is a window in the living room. The overall atmosphere of the room is cozy and inviting.",
16
+ "The surroundings include a living room with a couch, a chair, and a coffee table. There is also a television in the room.",
17
+ "In the image, there is a group of people sitting in chairs, likely in a classroom or a meeting room. They are engaged in a discussion or a presentation, with some of them looking at a screen.",
18
+ "The surroundings include a living room with a couch, a chair, and a window. The room is well-lit, and there are several potted plants in the space.",
19
+ "In the image, a woman is lying on a bed surrounded by a variety of stuffed animals. There are at least ten stuffed animals of different colors and sizes, including teddy bears, dolls, and other toys. The scene appears to be cozy and comfortable, with the woman resting peacefully in her bed.",
20
+ "The surroundings include a living room with a yellow couch, a table, and a potted plant. There are three people sitting on the couch, and a laptop is placed on the table.",
21
+ "The surroundings include a living room with a couch, a coffee table, and a TV. The room is filled with people, and they are sitting on the couch, engaging in a conversation.",
22
+ "The surroundings include a living room with a couch, a chair, and a window. There is also a woman standing in the room, possibly near the window. The room appears to be clean and well-maintained.",
23
+ "The surroundings include a living room with a bed, a couch, a chair, and a TV. There are also various items scattered around, such as a book, a bottle, and a cup. The room appears to be messy and disorganized, with some items like the book and bottle being placed on the floor.",
24
+ "In the image, there is a bedroom with a bed, a chair, and a window. A woman is sitting on the bed, and a dog is nearby. The woman is wearing a white shirt and appears to be engaged in a conversation with the AI assistant. The room appears to be clean and well-organized.",
25
+ "I am in a living room, sitting on a couch, and using a laptop.",
26
+ "The surroundings include a couch, a chair, a window, and a potted plant. There is also a person sitting on the couch, and a baby is laying on the person's lap.",
27
+ "In the image, there is a group of people standing in a living room, with a bed and a couch visible in the background. The room appears to be clean and well-organized.",
28
+ "In the image, there is a living room with a white couch, a chair, and a window. The room is well-lit and appears to be clean and organized.",
29
+ "The surroundings consist of a large, empty room with a hardwood floor. There is a man sitting on the floor, possibly in a corner or a cubicle, and he is holding a remote control.",
30
+ "The surroundings include a living room with a woman standing in front of a door, which is open. The room appears to be dimly lit, creating a somewhat dark atmosphere.",
31
+ "The surroundings include a living room with a couch, a chair, and a TV. There are three people sitting on the couch, and a baby is present. The living room appears to be a comfortable and cozy space for the family to spend time together.",
32
+ "In the image, there is a person sitting on a bed in a bedroom. The bed is surrounded by a colorful blanket, and there is a laptop on the bed. The room appears to be a small bedroom, and the bed is positioned near a window.",
33
+ "The surroundings include a bedroom with a bed, a nightstand, and a window. The bed is neatly made and has a white and gray color scheme. There are also potted plants in the room, adding a touch of greenery and a sense of freshness.",
34
+ "In the image, there is a large bedroom with a bed, a nightstand, and a window. The room is clean and well-organized, with a white color scheme and a minimalist design. The bed is neatly made, and there are pillows on it. The room also has a chair and a potted plant, adding a touch of warmth and natural elements to the space. The window provides natural light, and the room appears to be well-lit and inviting.",
35
+ "In the image, there is a living room with a couch, a chair, and a coffee table. The room is well-decorated and features a dark color scheme, with a black couch and a black coffee table. There is also a potted plant in the room, adding a touch of greenery to the space. The living room is well-lit, and there are several books scattered around the room, suggesting that the occupants enjoy reading.",
36
+ "In the image, there is a living room with a large window, a couch, and a chair. The room is filled with furniture, including a coffee table, a dining table, and a potted plant. The living room has a modern and clean design, with a white color scheme. The large window allows for natural light to enter the room, creating a bright and inviting atmosphere.",
37
+ "In the image, there is a bedroom with a large bed, a nightstand, and a window. The room is well-lit and clean, creating a comfortable and inviting atmosphere.",
38
+ "The surroundings include a living room with a couch, a coffee table, and a television. The room is filled with various items, such as books, a vase, and a potted plant. The living room is well-lit, and there is a window in the room. Additionally, there is a dining table and chairs, which suggests that the living room and dining area are combined.",
39
+ "In the image, I am surrounded by a living room with a piano, a couch, and a chair. The living room has a modern design, and the furniture is arranged in a way that creates a comfortable and inviting atmosphere.",
40
+ "The surroundings include a living room with a couch, a coffee table, and a vase. The living room is well-decorated and has a clean and organized appearance.",
41
+ "The surroundings include a large bedroom with a white color scheme, a bed with a white comforter, and a window. There is also a ceiling fan, which is a white fan, and a chair in the room. The room appears to be clean and well-maintained.",
42
+ "In the image, there is a green couch, a green chair, and a green ottoman in a living room. The room is filled with books, suggesting that it is a cozy and well-read space.",
43
+ "The surroundings include a large bedroom with a large bed, a chair, and a desk. The room also has a window, a lamp, and a mirror. The bed is neatly made, and there are pillows on it. The room is well-lit, with a lamp providing illumination.",
44
+ "In the image, there is a living room with a couch, a chair, and a table. The room has a modern design, featuring a large window and a chandelier. The living room is filled with furniture, including a couch, a chair, and a table. The room also has a potted plant, which adds a touch of greenery and a sense of freshness to the space. The living room is well-lit, with the large window providing ample natural light, and the chandelier adding a touch of elegance and sophistication.",
45
+ "I am in a living room with a fireplace, a couch, a chair, a dining table, and a potted plant. The room is filled with furniture and decorations, creating a cozy and inviting atmosphere.",
46
+ "In the image, there is a living room with a couch, two chairs, and a coffee table. The living room is well-lit and has a view of the city, which adds to the ambiance of the space. The room also features a potted plant and a vase, adding a touch of greenery and decoration to the area.",
47
+ "In the image, there is a neatly made bed in a bedroom, with a white comforter and a red blanket. The bed is situated next to a window, which allows natural light to enter the room. The room also has a nightstand with a lamp, providing additional lighting. The overall atmosphere of the room is clean and inviting.",
48
+ "In the image, I am surrounded by a large, clean living room with white walls, a fireplace, and a comfortable couch. There are also several chairs and a dining table in the room. The space is well-lit, and the furniture is arranged to create a cozy and inviting atmosphere.",
49
+ "The surroundings in the image include a living room with a couch, chairs, and a coffee table. The living room is filled with furniture, and there are multiple lamps and potted plants scattered throughout the space. The room also has a window, which allows natural light to enter the room.",
50
+ "The surroundings include a living room with a couch, a coffee table, and a lamp. The room also has a large window, which allows for natural light to enter. There are several chairs and a dining table in the room, suggesting that it is a multi-purpose space for relaxation and dining. The living room is well-lit and furnished with comfortable seating options, creating a welcoming atmosphere for people to gather and socialize.",
51
+ "In the image, I am surrounded by a living room filled with furniture, including a couch, chairs, and a coffee table. The living room is well-decorated, and there are several books and a vase present. The room also features a rug, which adds to the overall aesthetic and comfort of the space.",
52
+ "The surroundings include a messy room with a bed, a desk, and a chair. The room is filled with clothes, shoes, and other items, creating a cluttered and disorganized space.",
53
+ "The surroundings in the image include a cluttered room with a desk, a bed, and various items scattered around. The room appears to be messy and disorganized, with clothes and other belongings scattered on the floor.",
54
+ "The surroundings include a group of people sitting on a bed, with a laptop and a cell phone visible. The room appears to be a bedroom, and the individuals are engaged in a conversation.",
55
+ "I am in a living room, surrounded by several people sitting on a couch. They are all engaged in various activities, such as watching TV, using their cell phones, and possibly playing video games. The room is filled with furniture, including a couch, chairs, and a TV. The atmosphere appears to be casual and relaxed, with the people enjoying their time together in the living room.",
56
+ "The surroundings include a group of people sitting on a couch, with a wooden table in the background. The room appears to be a living room, and there is a window nearby.",
57
+ "In the image, there are several people sitting on a couch, using their cell phones. The couch is located in a living room, and the people are engaged in various activities on their devices.",
58
+ "The surroundings include a living room with a fireplace, where a group of people is sitting on couches and chairs. There are multiple books scattered around the room, suggesting that the individuals might be engaged in reading or studying. The room also has a dining table and a potted plant, which adds to the cozy atmosphere of the space.",
59
+ "The surroundings include a living room with a couch, a dining table, and a pizza on the table. The people in the room are sitting and enjoying their meal together.",
60
+ "In the image, there are four people sitting on a bed, with two of them facing the camera. They are all wearing blue shirts and are engaged in a conversation. The scene takes place in a bedroom, which is a comfortable and familiar setting for the group.",
61
+ "The surroundings include a group of people sitting on a couch, with some of them holding pizza boxes. The room appears to be a living room, and there is a book nearby.",
62
+ "In the image, there are four people sitting on a couch in a living room. They are all engaged in using their cell phones, with one of them holding a book. The room has a wooden floor and a table, and there are chairs nearby. The atmosphere appears to be casual and relaxed, with the group of friends enjoying their time together while using their devices.",
63
+ "I am in a living room, which is filled with furniture such as a couch, a chair, and a table. The room is well-lit and appears to be a comfortable space for relaxation and socializing.",
64
+ "The surroundings include a living room with a couch, a table, and a chair. There are people sitting on the couch and a man standing in the room.",
65
+ "The surroundings include a group of people sitting on a couch, likely in a living room or a similar space. They are engaged in a conversation or enjoying each other's company.",
66
+ "The surroundings include a living room with a couch, a chair, and a table. There are several people sitting on the couch and chairs, engaging in conversation and enjoying each other's company. The room appears to be well-lit, creating a comfortable atmosphere for socializing.",
67
+ "In the image, there is a group of people sitting around a dining table in a room. The table is covered with various items, including cups, bowls, and a vase. The room appears to be a living room or a dining area, with a couch and chairs nearby. The scene suggests a casual and comfortable setting where people are gathered for a meal or a social event.",
68
+ "The surroundings include a living room with a couch, a dining table, and a TV. The room is filled with people, some of whom are sitting on the couch, while others are standing around the table. There are also several bottles and cups, which might be used for drinking. The atmosphere appears to be relaxed and social, with people enjoying each other's company and engaging in conversations.",
69
+ "The surroundings include a living room with a couch, a coffee table, and a window. There are three people sitting on the couch, engaging in a conversation.",
70
+ "The surroundings include a living room with a couch, a table, and a couple of chairs. There are also several bottles and cups on the table, suggesting that the room is set up for a casual gathering or a social event.",
71
+ "The surroundings include a living room with a couch, chairs, and a coffee table. There are also books scattered around the room, suggesting that the group of people might be engaged in a discussion or reading. The room appears to be cozy and comfortable, with a relaxed atmosphere.",
72
+ "The surroundings include a living room with a couch, a television, and a group of people sitting together."]
73
+
74
+ cities = ["Wake Forest, NC", "Rocky Mount, NC", "San Francisco, CA", "New York City, NY", "Trenton, NJ", "Philadelphia, PA",
75
+ "Vikas Puri, New Delhi", "Jiugong, Beijing", "Les Halles, Paris", "Diemen, Amsterdam", "Al Shamkhah, Abu Dhabi",
76
+ "Cairo", "Idore, Madhya Pradesh", "Bangalore, Karnataka", "Toronoto, Ontario", "Brixton, London", "Charlotte, NC",
77
+ "Los Angeles, CA", "Las Vegas, NV", "Cupertino, CA", "Silicon Valley, CA", "Sham Shui Po, Hong Kong", "Danilovsky District, Moscow",
78
+ "Rochester, NY", "Manhattan, NY"]
79
+ weather_names = ["Sunny", "Rainy", "Windy", "Cloudy",
80
+ "Mostly Cloudy", "Partly Cloudy", "Light Rain", "Sleet"]
81
+ days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
82
+ months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
83
+ "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
84
+
85
+ dummy_time = f"{random.choice(days)} {random.randint(1, 30)} {random.choice(months)} {random.randint(2020, 2024)} 0{random.randint(1, 9)}:{random.randint(0, 59)} {random.choice(['AM', 'PM'])}"
86
+ dummy_weather = f"{random.choice(cities)} is {random.choice(weather_names)} with {random.randint(60, 95)}°F and Precipitation: {random.randint(0, 100)}%, Humidity: {random.randint(0, 100)}%, Wind: {random.randint(1, 15)} mph"
87
+ dummy_current_events = random.choice(room_descriptions)
88
+
89
+
+ @app.route("/")
+ def home():
+     vqa = file_data.count("> VQA")
+     robot = file_data.count("> Robot")
+     internet = file_data.count("> Internet")
+     cli = file_data.count("> CLI")
+     note = file_data.count("> Note")
+     home_automation = file_data.count("> Home Automation")
+     header = f"""VQA: {vqa}
+ ROBOT: {robot}
+ INTERNET: {internet}
+ CLI: {cli}
+ NOTE: {note}
+ HOME AUTOMATION: {home_automation}
+
+ """
+
+     return render_template("index.html", full_data=header + json.dumps(data, indent=4), data_entries=len(data), dummy_time=dummy_time, dummy_weather=dummy_weather, dummy_current_events=dummy_current_events[:-1])
+
+
+ @app.route('/record', methods=['GET', 'POST'])
+ def record():
+     if request.method == "POST":
+         with open(data_file_path, 'r') as data_file:
+             data = json.loads(data_file.read())
+         dummy_time = f"{random.choice(days)} {random.randint(1, 30)} {random.choice(months)} {random.randint(2020, 2024)} 0{random.randint(1, 9)}:{random.randint(0, 59)} {random.choice(['AM', 'PM'])}"
+         dummy_weather = f"{random.choice(cities)} is {random.choice(weather_names)} with {random.randint(60, 95)}°F and Precipitation: {random.randint(0, 100)}%, Humidity: {random.randint(0, 100)}%, Wind: {random.randint(1, 15)} mph"
+         dummy_current_events = random.choice(room_descriptions)
+
+         entry = request.form["current-data-preview"]
+         input_field = request.form["input"]
+         entry = {
+             "prompt": entry.split(entry.split(input_field)[-1])[0],
+             "response": entry.split(input_field)[-1][2:],
+         }
+         data.append(entry)
+         with open(data_file_path, 'w+') as file:
+             file.write(str(json.dumps(data, indent=4)).replace("\r\n", "\n"))
+
+         with open(data_file_path, "r") as data_file:
+             file_data = data_file.read()
+
+         vqa = file_data.count("> VQA")
+         robot = file_data.count("> Robot")
+         internet = file_data.count("> Internet")
+         cli = file_data.count("> CLI")
+         note = file_data.count("> Note")
+         home_automation = file_data.count("> Home Automation")
+
+         header = f"""VQA: {vqa}
+ ROBOT: {robot}
+ INTERNET: {internet}
+ CLI: {cli}
+ NOTE: {note}
+ HOME AUTOMATION: {home_automation}
+
+ """
+
+         return render_template("index.html", full_data=header + json.dumps(data, indent=4), data_entries=len(data), dummy_time=dummy_time, dummy_weather=dummy_weather, dummy_current_events=dummy_current_events[:-1])
+
+
+ if __name__ == "__main__":
+     app.run(host="0.0.0.0")
Perceptrix/create_data/static/style.css ADDED
@@ -0,0 +1,154 @@
1
+ @import url('https://fonts.cdnfonts.com/css/bitsumishi');
2
+
3
+ h1 {
4
+ font-family: "Bitsumishi", sans-serif;
5
+ font-size: 50px;
6
+ color: rgb(255, 255, 255);
7
+ margin: 10px;
8
+ letter-spacing: 3px;
9
+ }
10
+
11
+ body {
12
+ text-align: center;
13
+ background-color: rgb(0, 0, 0);
14
+ display: flex;
15
+ flex-direction: column;
16
+ justify-content: center;
17
+ margin: auto;
18
+ margin-top: 20px;
19
+ overflow: auto;
20
+ font-family: "Bitsumishi", sans-serif;
21
+ }
22
+
23
+ .entries {
24
+ display: flex;
25
+ flex-direction: column;
26
+ align-items: center;
27
+ justify-content: space-evenly;
28
+ height: 90vh;
29
+ width: fit-content;
30
+ }
31
+
32
+ .entry {
33
+ width: 20vw;
34
+ height: 25vh;
35
+ border: 2px solid rgb(32, 161, 236);
36
+ font-family: "Bitsumishi", sans-serif;
37
+ background-color: transparent;
38
+ border-top-left-radius: 30px;
39
+ border-bottom-right-radius: 30px;
40
+ padding: 25px;
41
+ resize: none;
42
+ outline: none;
43
+ color: white;
44
+ font-size: medium;
45
+ letter-spacing: 1px;
46
+ }
47
+
48
+ .main {
49
+ display: flex;
50
+ width: 90vw;
51
+ justify-content: space-between;
52
+ align-items: center;
53
+ margin: auto;
54
+ }
55
+
56
+ #submit {
57
+ background-color: transparent;
58
+ font-family: "Bitsumishi", sans-serif;
59
+ height: fit-content;
60
+ padding: 10px;
61
+ border: 2px solid rgb(32, 161, 236);
62
+ color: rgb(32, 161, 236);
63
+ font-size: medium;
64
+ cursor: pointer;
65
+ border-top-left-radius: 10px;
66
+ border-bottom-right-radius: 10px;
67
+ box-shadow: 0px 0px 100px 5px rgb(0, 0, 0) inset;
68
+ transition: 0.5s;
69
+ margin-top: 10px;
70
+ width: 50%;
71
+ }
72
+
73
+ #submit:hover {
74
+ box-shadow: 0px 0px 20px 5px rgb(33, 77, 255);
75
+ transition: 0.5s;
76
+ }
77
+
78
+ .other-inputs {
79
+ padding: 10px;
80
+ outline: none;
81
+ border: 2px solid rgb(32, 161, 236);
82
+ border-radius: 10px;
83
+ font-size: medium;
84
+ background-color: transparent;
85
+ font-family: "Bitsumishi", sans-serif;
86
+ width: 100%;
87
+ color: white;
88
+ }
89
+
90
+ .control {
91
+ display: flex;
92
+ flex-direction: column;
93
+ justify-content: space-evenly;
94
+ text-align: left;
95
+ margin-left: 100px;
96
+ }
97
+
98
+ .data-preview {
99
+ margin-top: 25px;
100
+ }
101
+
102
+ .data{
103
+ background-color:rgb(22, 22, 22);
104
+ padding: 25px;
105
+ width: 20vw;
106
+ height: 15vw;
107
+ color: rgb(255, 255, 255);
108
+ border-radius: 10px;
109
+ overflow: scroll;
110
+ }
111
+
112
+ .other-fields{
113
+ display: flex;
114
+ flex-direction: column;
115
+ width: 20%;
116
+ justify-content: space-around;
117
+ height: 32vh;
118
+ }
119
+
120
+ .all-inputs{
121
+ display: flex;
122
+ width: 63vw;
123
+ justify-content: space-between;
124
+ font-weight: 100;
125
+ }
126
+
127
+ .center-input{
128
+ padding: 10px;
129
+ outline: none;
130
+ border: 2px solid rgb(32, 161, 236);
131
+ border-radius: 10px;
132
+ font-size: medium;
133
+ background-color: transparent;
134
+ font-family: "Bitsumishi", sans-serif;
135
+ width: 45%;
136
+ color: white;
137
+ }
138
+
139
+ .center-inputs{
140
+ display: flex;
141
+ width: 63vw;
142
+ justify-content: space-between;
143
+ }
144
+
145
+ #output{
146
+ width: 60vw
147
+ }
148
+
149
+ #current-data-preview{
150
+ outline: none;
151
+ border: none;
152
+ resize: none;
153
+ font-size: 14px;
154
+ }
Perceptrix/create_data/templates/index.html ADDED
@@ -0,0 +1,80 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>Fine-tune Data CRYSTAL</title>
8
+ <link rel="stylesheet" href="{{url_for('static',filename='style.css')}}">
9
+ </head>
10
+
11
+ <body>
12
+ <h1>FINETUNE CRYSTAL</h1>
13
+ <form action="/record" method="post">
14
+ <div class="main">
15
+ <div class="entries">
16
+ <div class="all-inputs">
17
+ <textarea class="entry" name="notes" id="notes" placeholder="Enter Additional Notes Here:" oninput="updatePreview()">Notes- </textarea>
18
+ <textarea class="entry" name="input" id="input" placeholder="Enter Input Here:" oninput="updatePreview()"></textarea>
19
+ <div class="other-fields">
20
+ <input type="text" class="other-inputs" name="user" id="user" placeholder="User Name" oninput="updatePreview()">
21
+ <input type="text" class="other-inputs" name="time" id="time" placeholder="Time" oninput="updatePreview()">
22
+ <input type="text" class="other-inputs" name="weather" id="weather" placeholder="Weather" oninput="updatePreview()">
23
+ <input type="text" class="other-inputs" name="action" id="action" placeholder="Action" oninput="updatePreview()">
24
+ </div>
25
+ </div>
26
+ <div class="center-inputs">
27
+ <input type="text" class="center-input" name="current-events" id="current-events" placeholder="Current Events" oninput="updatePreview()">
28
+ <input type="text" class="center-input" name="speak" id="speak" placeholder="Speak" oninput="updatePreview()">
29
+ </div>
30
+ <textarea class="entry" name="output" id="output" placeholder="Enter Output Here:" oninput="updatePreview()"></textarea>
31
+ </div>
32
+ <div class="control">
33
+ <input id="submit" type="submit" value="ADD TO DATABASE">
34
+ <div class="data-preview">
35
+ <label style="color: white; font-size: x-large;" for="preview">Data Preview</label>
36
+ <textarea id="current-data-preview" name="current-data-preview" class="data"></textarea>
37
+ <p style="color: rgb(143, 143, 143); font-size: medium; margin: 5px;"> Data Entries: {{data_entries}}</p>
38
+ <pre class="data">{{full_data}}</pre>
39
+ </div>
40
+ </div>
41
+ </div>
42
+ </form>
43
+ <script>
44
+ var user = document.querySelector('#user');
45
+ var time = document.querySelector('#time');
46
+ var weather = document.querySelector('#weather');
47
+ var action = document.querySelector('#action');
48
+ var current_events = document.querySelector('#current-events');
49
+ var input = document.querySelector('#input');
50
+ var output = document.querySelector('#output');
51
+ var speak = document.querySelector('#speak');
52
+ var notes = document.querySelector('#notes');
53
+ var preview = document.querySelector('#current-data-preview');
54
+
55
+ time.value = "{{dummy_time}}";
56
+ weather.value = "{{dummy_weather}}";
57
+ current_events.value = "{{dummy_current_events}}";
58
+
59
+ function updatePreview() {
60
+ format = "Time- {time}\nWeather- {weather}\nSurroundings- {current_events}\n{notes}\n{user}: {input}\nCRYSTAL:<###CRYSTAL-INTERNAL###> Speak\n{speak}\n<###CRYSTAL-INTERNAL###> {action}\n{output}"
61
+ var formattedText = format.replace('{time}', time.value)
62
+ .replace('{weather}', weather.value)
63
+ .replace('{current_events}', current_events.value)
64
+ .replace('{user}', user.value)
65
+ .replace('{input}', input.value)
66
+ .replace('{action}', action.value)
67
+ .replace('{speak}', speak.value)
68
+ .replace('{notes}', notes.value)
69
+ .replace('{output}', output.value);
70
+
71
+ preview.textContent = formattedText;
72
+ }
73
+
74
+ // Initial preview update
75
+ updatePreview();
76
+
77
+ </script>
78
+ </body>
79
+
80
+ </html>
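The `updatePreview` format string in this template defines how one fine-tuning sample is laid out before the form posts it to `/record` in interface.py. A minimal Python sketch of the same assembly, reusing the template's format string; every field value below is an invented placeholder, not real training data:

```python
# Sketch only: mirrors the format string in templates/index.html.
# All field values are invented placeholders for illustration.
FORMAT = ("Time- {time}\nWeather- {weather}\nSurroundings- {current_events}\n"
          "{notes}\n{user}: {input}\nCRYSTAL:<###CRYSTAL-INTERNAL###> Speak\n"
          "{speak}\n<###CRYSTAL-INTERNAL###> {action}\n{output}")

sample = FORMAT.format(
    time="Mon 4 Mar 2024 09:15 AM",
    weather="Wake Forest, NC is Sunny with 72°F and Precipitation: 10%, Humidity: 40%, Wind: 5 mph",
    current_events="A living room with a couch, a chair, and a TV.",
    notes="Notes- ",
    user="User",
    input="Can you dim the lights?",
    speak="Dimming the lights now.",
    action="Home Automation",
    output="lights.dim(30)",  # hypothetical action text
)
print(sample)
```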
Perceptrix/engine.py ADDED
@@ -0,0 +1,213 @@
+ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig, BitsAndBytesConfig
+ from Perceptrix.streamer import TextStreamer
+ from utils import setup_device
+ import torch
+ import tqdm
+ import sys
+ import os
+
+ model_name = os.environ.get('LLM_MODEL')
+
+ model_id = "models/CRYSTAL-instruct" if model_name is None else model_name
+
+ device = setup_device()
+
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
+ tokenizer = LlamaTokenizer.from_pretrained(
+     model_id,
+     use_fast=True)
+
+ model = LlamaForCausalLM.from_pretrained(
+     model_id,
+     load_in_8bit=False,
+     device_map="auto",
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True,
+     offload_folder="offload",
+     quantization_config=bnb_config,
+ )
+
+ streamer = TextStreamer(tokenizer, skip_prompt=True,
+                         skip_special_tokens=True, save_file="reply.txt")
+
+ PROMPT = '''### Instruction:
+ {}
+ ### Input:
+ {}
+ ### Response:'''
+
+ model.config.pad_token_id = tokenizer.pad_token_id = 0
+ model.config.bos_token_id = 1
+ model.config.eos_token_id = 2
+
+ model.eval()
+ if torch.__version__ >= "2" and sys.platform != "win32":
+     model = torch.compile(model)
+
+
+ def evaluate(
+     prompt='',
+     temperature=0.4,
+     top_p=0.65,
+     top_k=35,
+     repetition_penalty=1.1,
+     max_new_tokens=512,
+     **kwargs,
+ ):
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs["input_ids"].to(device)
+     generation_config = GenerationConfig(
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+         **kwargs,
+     )
+
+     with torch.no_grad():
+         generation_output = model.generate(
+             input_ids=input_ids,
+             generation_config=generation_config,
+             return_dict_in_generate=True,
+             output_scores=True,
+             max_new_tokens=max_new_tokens,
+             streamer=streamer,
+         )
+     s = generation_output.sequences[0]
+     output = tokenizer.decode(s)
+     yield output.split("### Response:")[-1].strip()
+
+
+ def run_instruction(
+     instruction,
+     inputs,
+     temperature=0.4,
+     top_p=0.65,
+     top_k=35,
+     repetition_penalty=1.1,
+     max_new_tokens=512,
+     stop_tokens=None,
+ ):
+     now_prompt = PROMPT.format(instruction+'\n', inputs)
+
+     response = evaluate(
+         now_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, stop_tokens=stop_tokens, do_sample=True
+     )
+
+     for i in response:
+         yield i
+
+
+ def search_keyword(prompt):
+     instructions = """Prompt:Time: Fri, 23 August 2023 2:30PM\nWeather: 73F\nHow many friends have I told you about?
+ Search Keyword:Friends
+ Prompt:Time: Thu, 27 September 2023 3:41PM\nWeather: 62F\nWhat was our very first conversation
+ Chat Index:0
+ Prompt:Time: Tue, 21 September 2023 2:30PM\nWeather: 67F\nWhat was the last thing I said to you
+ Chat Index:-1
+ Prompt:Time: Sun, 31 October 2023 7:33AM\nWeather: 59F\nWhat was the last thing I said to you before that
+ Chat Index:-2
+ Prompt:Time: Sat, 30 October 2023 8:21PM\nWeather: 65F\nDid I ever tell you about my math class?
+ Search Keyword:math
+ Prompt:Time: Mon, 13 November 2023 4:52PM\nWeather: 55F\nWhat was my 7th grade English teacher's name?
+ Search Keyword:English
+ Prompt:Time: Wed, 15 May 2023 6:19PM\nWeather: 80F\nWhere did I say my wallet was?
+ Search Keyword:Wallet
+ Prompt:Time: Fri, 24 June 2023 1:52PM\nWeather: 92F\nWhat did Alex tell you?
+ Search Keyword:Alex
+ Prompt:Time: Sat, 19 July 2023 2:44PM\nWeather: 91F\nWhat was my first conversation today
+ Search Keyword:24 June"""
+     answer = ''.join(run_instruction(
+         instructions,
+         "Prompt:"+prompt+"\n",
+         temperature=0.5,
+         top_p=0.5,
+         top_k=200,
+         repetition_penalty=1.1,
+         max_new_tokens=256,
+     ))
+     return answer
+
+
+ def identify_objects_from_text(prompt):
+     instructions = """Input:The object that flies in the air from this picture is a toy helicopter
+ Output:Toy helicopter
+ Input:For the robot to be able to achieve the task, the robot needs to look for a white shirt
+ Output:White shirt
+ Input:To complete the task, the robot needs to remove the fruits from the wooden basket.
+ Output:fruits, wooden basket
+ Input:To clean up your desk, you need to gather and organize the various items scattered around it. This includes the laptop, cell phone, scissors, pens, and other objects. By putting these items back in their designated spaces or containers, you can create a more organized and clutter-free workspace.
+ Output:Laptop, cell phone, scissors, pens, containers
+ Input:The tree with a colorful sky background is the one to be looking for.
+ Output:Tree"""
+     answer = ''.join(run_instruction(
+         instructions,
+         prompt,
+         temperature=0.5,
+         top_p=0.5,
+         top_k=200,
+         repetition_penalty=1.1,
+         max_new_tokens=256,
+     ))
+     return answer
+
+
+ def robotix(prompt, stop=None):
+     instructions = """#Get me some water
+ objects = [['water: 57%', (781, 592)]]
+ robot.target((781, 592))
+ object_distance = distance()
+ if object_distance > 10:
+ robot.go("forward", object_distance, track="water")
+ robot.grab()
+ if object_distance > 10:
+ robot.go("back", object_distance)
+ robot.release("here")
+ #Stand by the table
+ objects = [['table: 81%', (1489, 1173)], ['table: 75%', (1971, 1293)]]
+ robot.target((1489, 1173))
+ if distance() > 10:
+ robot.go(forward, distance())
+ #Put the apples in the basket
+ objects = [['basket: 77%', (89, 112)], ['apples: 72%', (222, 182)]]
+ robot.target((281, 189))
+ if distance() > 10:
+ robot.go("forward", distance(), track="apples")
+ robot.grab()
+ robot.target(robot.find("basket"))
+ robot.release(distance())
+ #Go to the sofa
+ objects=[['sofa: 81%', (1060, 931)]]
+ robot.target((1060, 931))
+ if distance() > 10:
+ robot.go("forward", distance())
+ #Go to that person over there and then come back
+ objects=[['person: 85%', (331, 354)]]
+ robot.target((331, 354))
+ object_distance = distance()
+ if object_distance > 10:
+ robot.go("forward", object_distance)
+ robot.go("backward", object_distance)
+ """
+
+     answer = ''.join(run_instruction(
+         instructions,
+         prompt,
+         temperature=0.2,
+         top_p=0.5,
+         top_k=300,
+         repetition_penalty=1.1,
+         max_new_tokens=256,
+         stop_tokens=stop,
+     ))
+     return answer
+
+
+ if __name__ == "__main__":
+     print(robotix("#Get me a glass of water\nobjects = [['water: 65%', (695, 234)]]"))
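For completeness, a short usage sketch of the two text helpers defined above, in the spirit of the `robotix` demo in `__main__`. The prompts are invented examples and the commented expectations are only illustrative; actual output depends on the CRYSTAL-instruct weights that get loaded:

```python
# Invented example prompts; real outputs depend on the loaded model.
keyword = search_keyword(
    "Time: Mon, 4 March 2024 9:15AM\nWeather: 72F\nWhere did I say I left my keys?")
print(keyword)  # expected to resemble "Search Keyword:Keys"

objects = identify_objects_from_text(
    "To water the plants, the robot should pick up the green watering can near the window.")
print(objects)  # expected to resemble "Green watering can"
```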
Perceptrix/finetune/Dockerfile ADDED
@@ -0,0 +1,13 @@
+ # Copyright 2022 MosaicML LLM Foundry authors
+ # SPDX-License-Identifier: Apache-2.0
+
+ ARG BASE_IMAGE
+ FROM $BASE_IMAGE
+
+ ARG DEP_GROUPS
+
+ # Install and uninstall foundry to cache foundry requirements
+ RUN git clone -b main https://github.com/mosaicml/llm-foundry.git
+ RUN pip install --no-cache-dir "./llm-foundry${DEP_GROUPS}"
+ RUN pip uninstall -y llm-foundry
+ RUN rm -rf llm-foundry
Perceptrix/finetune/Makefile ADDED
@@ -0,0 +1,23 @@
+ # several pytest settings
+ WORLD_SIZE ?= 1 # world size for launcher tests
+ MASTER_PORT ?= 26000 # port for distributed tests
+ PYTHON ?= python3 # Python command
+ PYTEST ?= pytest # Pytest command
+ PYRIGHT ?= pyright # Pyright command. Pyright must be installed separately -- e.g. `npm install -g pyright`
+ EXTRA_ARGS ?= # extra arguments for pytest
+ EXTRA_LAUNCHER_ARGS ?= # extra arguments for the composer cli launcher
+
+ test:
+ 	LOCAL_WORLD_SIZE=1 $(PYTHON) -m $(PYTEST) $(EXTRA_ARGS)
+
+ test-gpu:
+ 	LOCAL_WORLD_SIZE=1 $(PYTHON) -m $(PYTEST) -m gpu $(EXTRA_ARGS)
+
+ # runs tests with the launcher
+ test-dist:
+ 	$(PYTHON) -m composer.cli.launcher -n $(WORLD_SIZE) --master_port $(MASTER_PORT) $(EXTRA_LAUNCHER_ARGS) -m $(PYTEST) $(EXTRA_ARGS)
+
+ test-dist-gpu:
+ 	$(PYTHON) -m composer.cli.launcher -n $(WORLD_SIZE) --master_port $(MASTER_PORT) $(EXTRA_LAUNCHER_ARGS) -m $(PYTEST) -m gpu $(EXTRA_ARGS)
+
+ .PHONY: test test-gpu test-dist test-dist-gpu
Perceptrix/finetune/README.md ADDED
@@ -0,0 +1,265 @@
1
+ <!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_BEGIN -->
2
+ <p align="center">
3
+ <a href="https://github.com/mosaicml/llm-foundry">
4
+ <picture>
5
+ <img alt="LLM Foundry" src="./assets/llm-foundry.png" width="95%">
6
+ </picture>
7
+ </a>
8
+ </p>
9
+ <!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_END -->
10
+
11
+ <p align="center">
12
+ <a href="https://pypi.org/project/llm-foundry/">
13
+ <img alt="PyPi Version" src="https://img.shields.io/pypi/pyversions/llm-foundry">
14
+ </a>
15
+ <a href="https://pypi.org/project/llm-foundry/">
16
+ <img alt="PyPi Package Version" src="https://img.shields.io/pypi/v/llm-foundry">
17
+ </a>
18
+ <a href="https://mosaicml.me/slack">
19
+ <img alt="Chat @ Slack" src="https://img.shields.io/badge/slack-chat-2eb67d.svg?logo=slack">
20
+ </a>
21
+ <a href="https://github.com/mosaicml/llm-foundry/blob/main/LICENSE">
22
+ <img alt="License" src="https://img.shields.io/badge/License-Apache%202.0-green.svg">
23
+ </a>
24
+ </p>
25
+ <br />
26
+
27
+ # LLM Foundry
28
+
29
+ This repository contains code for training, finetuning, evaluating, and deploying LLMs for inference with [Composer](https://github.com/mosaicml/composer) and the [MosaicML platform](https://forms.mosaicml.com/demo?utm_source=github.com&utm_medium=referral&utm_campaign=llm-foundry). Designed to be easy-to-use, efficient _and_ flexible, this codebase enables rapid experimentation with the latest techniques.
30
+
31
+ You'll find in this repo:
32
+ * `llmfoundry/` - source code for models, datasets, callbacks, utilities, etc.
33
+ * `scripts/` - scripts to run LLM workloads
34
+ * `data_prep/` - convert text data from original sources to StreamingDataset format
35
+ * `train/` - train or finetune HuggingFace and MPT models from 125M - 70B parameters
36
+ * `train/benchmarking` - profile training throughput and MFU
37
+ * `inference/` - convert models to HuggingFace or ONNX format, and generate responses
38
+ * `inference/benchmarking` - profile inference latency and throughput
39
+ * `eval/` - evaluate LLMs on academic (or custom) in-context-learning tasks
40
+ * `mcli/` - launch any of these workloads using [MCLI](https://docs.mosaicml.com/projects/mcli/en/latest/) and the [MosaicML platform](https://www.mosaicml.com/platform)
41
+ * `TUTORIAL.md` - a deeper dive into the repo, example workflows, and FAQs
42
+
43
+ # MPT
44
+
45
+ Mosaic Pretrained Transformers (MPT) are GPT-style models with some special features -- Flash Attention for efficiency, ALiBi for context length extrapolation, and stability improvements to mitigate loss spikes. As part of MosaicML's Foundation series, we have open-sourced several MPT models:
46
+
47
+
48
+ | Model | Context Length | Download | Demo | Commercial use? |
49
+ | ------------------ | -------------- | -------------------------------------------------- | ----------------------------------------------------------- | --------------- |
50
+ | MPT-30B | 8192 | https://huggingface.co/mosaicml/mpt-30b | | Yes |
51
+ | MPT-30B-Instruct | 8192 | https://huggingface.co/mosaicml/mpt-30b-instruct | | Yes |
52
+ | MPT-30B-Chat | 8192 | https://huggingface.co/mosaicml/mpt-30b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-30b-chat) | No |
53
+ | MPT-7B | 2048 | https://huggingface.co/mosaicml/mpt-7b | | Yes |
54
+ | MPT-7B-Instruct | 2048 | https://huggingface.co/mosaicml/mpt-7b-instruct | | Yes |
55
+ | MPT-7B-Chat | 2048 | https://huggingface.co/mosaicml/mpt-7b-chat | [Demo](https://huggingface.co/spaces/mosaicml/mpt-7b-chat) | No |
56
+ | MPT-7B-StoryWriter | 65536 | https://huggingface.co/mosaicml/mpt-7b-storywriter | | Yes |
57
+
58
+ To try out these models locally, [follow the instructions](https://github.com/mosaicml/llm-foundry/tree/main/scripts/inference#interactive-generation-with-modelgenerate) in `scripts/inference/README.md` to prompt HF models using our [hf_generate.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py) or [hf_chat.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_chat.py) scripts.
59
+
60
+ # MPT Community
61
+
62
+ We've been overwhelmed by all the amazing work the community has put into MPT! Here we provide a few links to some of them:
63
+ * [ReplitLM](https://github.com/replit/replitLM): `replit-code-v1-3b` is a 2.7B Causal Language Model focused on Code Completion. The model has been trained on a subset of the Stack Dedup v1.2 dataset covering 20 languages such as Java, Python, and C++
64
+ * [LLaVa-MPT](https://github.com/haotian-liu/LLaVA#LLaVA-MPT-7b): Visual instruction tuning to get MPT multimodal capabilities
65
+ * [ggml](https://github.com/ggerganov/ggml/tree/master): Optimized MPT version for efficient inference on consumer hardware
66
+ * [GPT4All](https://gpt4all.io/index.html): locally running chat system, now with MPT support!
67
+ * [Q8MPT-Chat](https://huggingface.co/spaces/Intel/Q8-Chat): 8-bit optimized MPT for CPU by our friends at Intel
68
+
69
+ Tutorial videos from the community:
70
+ * [Using MPT-7B with Langchain](https://www.youtube.com/watch?v=DXpk9K7DgMo&t=3s) by [@jamesbriggs](https://www.youtube.com/@jamesbriggs)
71
+ * [MPT-7B StoryWriter Intro](https://www.youtube.com/watch?v=O9Y_ZdsuKWQ) by [AItrepreneur](https://www.youtube.com/@Aitrepreneur)
72
+ * [Fine-tuning MPT-7B on a single GPU](https://www.youtube.com/watch?v=KSlWkrByc0o&t=9s) by [@AIology2022](https://www.youtube.com/@AIology2022)
73
+ * [How to Fine-tune MPT-7B-Instruct on Google Colab](https://youtu.be/3de0Utr9XnI) by [@VRSEN](https://www.youtube.com/@vrsen)
74
+
75
+ Something missing? Contribute with a PR!
76
+
77
+ # Latest News
78
+ * [Blog: MPT-30B: Raising the bar for open-source foundation models](https://www.mosaicml.com/blog/mpt-30b)
79
+ * [Blog: Introducing MPT-7B](https://www.mosaicml.com/blog/mpt-7b)
80
+ * [Blog: Benchmarking LLMs on H100](https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1)
81
+ * [Blog: Blazingly Fast LLM Evaluation](https://www.mosaicml.com/blog/llm-evaluation-for-icl)
82
+ * [Blog: GPT3 Quality for $500k](https://www.mosaicml.com/blog/gpt-3-quality-for-500k)
83
+ * [Blog: Billion parameter GPT training made easy](https://www.mosaicml.com/blog/billion-parameter-gpt-training-made-easy)
84
+
85
+
86
+
87
+ # Hardware and Software Requirements
88
+ This codebase has been tested with PyTorch 1.13.1 and PyTorch 2.0.1 on systems with NVIDIA A100s and H100s.
89
+ This codebase may also work on systems with other devices, such as consumer NVIDIA cards and AMD cards, but we are not actively testing these systems.
90
+ If you have success/failure using LLM Foundry on other systems, please let us know in a Github issue and we will update the support matrix!
91
+
92
+ | Device | Torch Version | Cuda Version | Status |
93
+ | -------------- | ------------- | ------------ | ---------------------------- |
94
+ | A100-40GB/80GB | 1.13.1 | 11.7 | :white_check_mark: Supported |
95
+ | A100-40GB/80GB | 2.0.1 | 11.7, 11.8 | :white_check_mark: Supported |
96
+ | A100-40GB/80GB | 2.1.0 | 11.8, 12.1 | :white_check_mark: Supported |
97
+ | H100-80GB | 1.13.1 | 11.7 | :x: Not Supported |
98
+ | H100-80GB | 2.0.1 | 11.8 | :white_check_mark: Supported |
99
+ | H100-80GB | 2.1.0 | 12.1 | :white_check_mark: Supported |
100
+ | A10-24GB | 1.13.1 | 11.7 | :construction: In Progress |
101
+ | A10-24GB | 2.0.1 | 11.7, 11.8 | :construction: In Progress |
102
+ | MI250 | 2.0.1 | ROCm 5.4 | :construction: In Progress |
103
+
104
+ ## MosaicML Docker Images
105
+ We highly recommend using our prebuilt Docker images. You can find them here: https://hub.docker.com/orgs/mosaicml/repositories.
106
+
107
+ The `mosaicml/pytorch` images are pinned to specific PyTorch and CUDA versions, and are stable and rarely updated.
108
+
109
+ The `mosaicml/llm-foundry` images are built with new tags upon every commit to the `main` branch.
110
+ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117-f678575` or take the latest one using `mosaicml/llm-foundry:1.13.1_cu117-latest`.
111
+
112
+ **Please Note:** The `mosaicml/llm-foundry` images do not come with the `llm-foundry` package preinstalled, just the dependencies. You will still need to `pip install llm-foundry` either from PyPi or from source.
113
+
114
+ | Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? |
115
+ | ------------------------------------------------------ | ------------- | ----------------- | ----------------------------------- |
116
+ | `mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04` | 1.13.1 | 11.7 (Infiniband) | No |
117
+ | `mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04` | 2.0.1 | 11.8 (Infiniband) | No |
118
+ | `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 (Infiniband) | No |
119
+ | `mosaicml/llm-foundry:1.13.1_cu117-latest` | 1.13.1 | 11.7 (Infiniband) | Yes |
120
+ | `mosaicml/llm-foundry:2.0.1_cu118-latest` | 2.0.1 | 11.8 (Infiniband) | Yes |
121
+ | `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v1) |
122
+ | `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v2) |
123
+ | `mosaicml/llm-foundry:2.1.0_cu121_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v1) |
124
+ | `mosaicml/llm-foundry:2.1.0_cu121_flash2_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v2) |
125
+
126
+
127
+ # Installation
128
+
129
+ This assumes you already have PyTorch and CMake installed.
130
+
131
+ To get started, clone the repo and set up your environment. Instructions to do so differ slightly depending on whether you're using Docker.
132
+ ### With Docker (recommended)
133
+
134
+ We *strongly* recommend working with LLM Foundry inside a Docker container (see our recommended Docker image above). If you are doing so, follow these steps to clone the repo and install the requirements.
135
+
136
+ <!--pytest.mark.skip-->
137
+ ```bash
138
+ git clone https://github.com/mosaicml/llm-foundry.git
139
+ cd llm-foundry
140
+ pip install -e ".[gpu]" # or pip install -e . if no NVIDIA GPU
141
+ ```
142
+
143
+ ### Without Docker (not recommended)
144
+
145
+ If you choose not to use Docker, you should create and use a virtual environment.
146
+
147
+ <!--pytest.mark.skip-->
148
+ ```bash
149
+ git clone https://github.com/mosaicml/llm-foundry.git
150
+ cd llm-foundry
151
+
152
+ # Create and activate a virtual environment
153
+ python3 -m venv llmfoundry-venv
154
+ source llmfoundry-venv/bin/activate
155
+
156
+ pip install cmake packaging torch # setup.py requires these be installed
157
+
158
+ pip install -e ".[gpu]" # or pip install -e . if no NVIDIA GPU
159
+ ```
160
+
161
+ ### TransformerEngine and amp_fp8 support
162
+ NVIDIA H100 GPUs have FP8 support; this additionally requires the following installations:
163
+ <!--pytest.mark.skip-->
164
+ ```bash
165
+ pip install flash-attn==1.0.7 --no-build-isolation
166
+ pip install git+https://github.com/NVIDIA/[email protected]
167
+ ```
168
+
169
+ See [here](https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#TransformerEngine-and-amp_fp8-support) for more details on enabling TransformerEngine layers and amp_fp8.
170
+
171
+ ### AMD (BETA support)
172
+
173
+ In [our testing of AMD GPUs](https://www.mosaicml.com/blog/amd-mi250), the env setup includes:
174
+
175
+ <!--pytest.mark.skip-->
176
+ ```bash
177
+ git clone https://github.com/mosaicml/llm-foundry.git
178
+ cd llm-foundry
179
+
180
+ # Create and activate a virtual environment
181
+ python3 -m venv llmfoundry-venv-amd
182
+ source llmfoundry-venv-amd/bin/activate
183
+
184
+ # installs
185
+ pip install cmake packaging torch
186
+ pip install -e . # This installs some things that are not needed but they don't hurt
187
+ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2
188
+ ```
189
+ **Lastly**, install the ROCm enabled flash attention (instructions [here](https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm2#amd-gpurocm-support)).
190
+
191
+ Notes:
192
+ 1. `attn_impl: triton` does not work.
193
+ 1. We don't yet have a Docker image where everything works perfectly. You might need to upgrade or downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`, as shown below) before everything runs without issue.
194
+
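+ For example, the NumPy downgrade mentioned above is just a version pin inside the AMD virtual environment:
+
+ <!--pytest.mark.skip-->
+ ```bash
+ # Pin NumPy to the version that worked in our AMD testing
+ pip install numpy==1.23.5
+ ```
+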
195
+ # Quickstart
196
+
197
+ > **Note**
198
+ > Make sure to go through the installation steps above before trying the quickstart!
199
+
200
+ Here is an end-to-end workflow for preparing a subset of the C4 dataset, training an MPT-125M model for 10 batches,
201
+ converting the model to HuggingFace format, evaluating it on the COPA in-context-learning task, and generating responses to prompts.
202
+
203
+ **(Remember, this is a quickstart just to demonstrate the tools; to get good quality, the LLM must be trained for far longer than 10 batches 😄)**
204
+
205
+ <!--pytest.mark.skip-->
206
+ ```bash
207
+ cd scripts
208
+
209
+ # Convert C4 dataset to StreamingDataset format
210
+ python data_prep/convert_dataset_hf.py \
211
+ --dataset c4 --data_subset en \
212
+ --out_root my-copy-c4 --splits train_small val_small \
213
+ --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>'
214
+
215
+ # Train an MPT-125m model for 10 batches
216
+ composer train/train.py \
217
+ train/yamls/pretrain/mpt-125m.yaml \
218
+ data_local=my-copy-c4 \
219
+ train_loader.dataset.split=train_small \
220
+ eval_loader.dataset.split=val_small \
221
+ max_duration=10ba \
222
+ eval_interval=0 \
223
+ save_folder=mpt-125m
224
+
225
+ # Convert the model to HuggingFace format
226
+ python inference/convert_composer_to_hf.py \
227
+ --composer_path mpt-125m/ep0-ba10-rank0.pt \
228
+ --hf_output_path mpt-125m-hf \
229
+ --output_precision bf16 \
230
+ # --hf_repo_for_upload user-org/repo-name
231
+
232
+ # Evaluate the model on a subset of tasks
233
+ composer eval/eval.py \
234
+ eval/yamls/hf_eval.yaml \
235
+ icl_tasks=eval/yamls/copa.yaml \
236
+ model_name_or_path=mpt-125m-hf
237
+
238
+ # Generate responses to prompts
239
+ python inference/hf_generate.py \
240
+ --name_or_path mpt-125m-hf \
241
+ --max_new_tokens 256 \
242
+ --prompts \
243
+ "The answer to life, the universe, and happiness is" \
244
+ "Here's a quick recipe for baking chocolate chip cookies: Start by"
245
+ ```
246
+
247
+ Note: the `composer` command used above to train the model refers to the [Composer](https://github.com/mosaicml/composer) library's distributed launcher.
248
+
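+ For example, on a single node with 8 GPUs the launcher spawns one training process per GPU. A minimal sketch (the `-n` process-count flag belongs to the launcher, not to `train.py`; adjust it to your hardware):
+
+ <!--pytest.mark.skip-->
+ ```bash
+ # Run the same quickstart training job across 8 GPUs with the Composer launcher
+ composer -n 8 train/train.py \
+   train/yamls/pretrain/mpt-125m.yaml \
+   data_local=my-copy-c4 \
+   train_loader.dataset.split=train_small \
+   eval_loader.dataset.split=val_small \
+   max_duration=10ba \
+   eval_interval=0 \
+   save_folder=mpt-125m
+ ```
+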
249
+ If you have a write-enabled [HuggingFace auth token](https://huggingface.co/docs/hub/security-tokens), you can optionally upload your model to the Hub! Just export your token like this:
250
+
251
+ ```bash
252
+ export HUGGING_FACE_HUB_TOKEN=your-auth-token
253
+ ```
254
+
255
+ and uncomment the line containing `--hf_repo_for_upload ...` in the above call to `inference/convert_composer_to_hf.py`.
256
+
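+ For reference, the conversion call from the quickstart with the upload flag uncommented looks like this (`user-org/repo-name` is a placeholder for your own Hub repo):
+
+ <!--pytest.mark.skip-->
+ ```bash
+ # Convert the Composer checkpoint and push the resulting HF folder to the Hub
+ python inference/convert_composer_to_hf.py \
+   --composer_path mpt-125m/ep0-ba10-rank0.pt \
+   --hf_output_path mpt-125m-hf \
+   --output_precision bf16 \
+   --hf_repo_for_upload user-org/repo-name
+ ```
+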
257
+ # Learn more about LLM Foundry!
258
+
259
+ Check out [TUTORIAL.md](https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md) to keep learning about working with LLM Foundry. The tutorial highlights example workflows, points you to other resources throughout the repo, and answers frequently asked questions!
260
+
261
+ # Contact Us
262
+
263
+ If you run into any problems with the code, please file GitHub issues directly on this repo.
264
+
265
+ If you want to train LLMs on the MosaicML platform, reach out to us at [[email protected]](mailto:[email protected])!
Perceptrix/finetune/build/lib/inference/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ __all__ = []
Perceptrix/finetune/build/lib/inference/convert_composer_mpt_to_ft.py ADDED
@@ -0,0 +1,232 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Note: This script is specifically for converting MPT Composer checkpoints to FasterTransformer format.
5
+
6
+ import configparser
7
+ import os
8
+ import tempfile
9
+ from argparse import ArgumentParser, Namespace
10
+ from pathlib import Path
11
+ from typing import Any, Dict, Optional, Union
12
+
13
+ import torch
14
+ from composer.utils import get_file, safe_torch_load
15
+ from transformers import PreTrainedTokenizer
16
+
17
+ from llmfoundry.utils import (convert_and_save_ft_weights,
18
+ get_hf_tokenizer_from_composer_state_dict)
19
+
20
+
21
+ def save_ft_config(composer_config: Dict[str, Any],
22
+ tokenizer: PreTrainedTokenizer,
23
+ save_dir: str,
24
+ infer_gpu_num: int = 1,
25
+ weight_data_type: str = 'fp32',
26
+ force: bool = False):
27
+
28
+ config = configparser.ConfigParser()
29
+ config['gpt'] = {}
30
+ try:
31
+ config['gpt']['model_name'] = 'mpt'
32
+ config['gpt']['head_num'] = str(composer_config['n_heads'])
33
+ n_embd = composer_config['d_model']
34
+ config['gpt']['size_per_head'] = str(n_embd //
35
+ composer_config['n_heads'])
36
+ config['gpt']['inter_size'] = str(n_embd * composer_config['mlp_ratio'])
37
+ config['gpt']['max_pos_seq_len'] = str(composer_config['max_seq_len'])
38
+ config['gpt']['num_layer'] = str(composer_config['n_layers'])
39
+ config['gpt']['vocab_size'] = str(composer_config['vocab_size'])
40
+ config['gpt']['start_id'] = str(tokenizer.bos_token_id)
41
+ config['gpt']['end_id'] = str(tokenizer.eos_token_id)
42
+ config['gpt']['weight_data_type'] = weight_data_type
43
+ config['gpt']['tensor_para_size'] = str(infer_gpu_num)
44
+ # nn.LayerNorm default eps is 1e-5
45
+ config['gpt']['layernorm_eps'] = str(1e-5)
46
+ if composer_config['alibi']:
47
+ config['gpt']['has_positional_encoding'] = str(False)
48
+ config['gpt']['use_attention_linear_bias'] = str(True)
49
+ if composer_config['attn_clip_qkv'] and not force:
50
+ raise RuntimeError(
51
+ 'clip_qkv is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.'
52
+ )
53
+ if composer_config['attn_qk_ln'] and not force:
54
+ raise RuntimeError(
55
+ 'qk_ln is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.'
56
+ )
57
+
58
+ with open(os.path.join(save_dir, 'config.ini'), 'w') as configfile:
59
+ config.write(configfile)
60
+ return config
61
+ except Exception:
62
+ print('Failed to save the config in config.ini.')
63
+ raise
64
+
65
+
66
+ def write_ft_checkpoint_from_composer_checkpoint(
67
+ checkpoint_path: Union[Path, str],
68
+ infer_gpu_num: int,
69
+ save_dir: str,
70
+ output_precision: str = 'fp32',
71
+ local_checkpoint_save_location: Optional[Union[Path,
72
+ str]] = None) -> None:
73
+ """Convert a Composer checkpoint to a FasterTransformer checkpoint folder.
74
+
75
+ .. note:: This function may not work properly if you used surgery algorithms when you trained your model. In that case you may need to
76
+ edit the parameter conversion methods to properly convert your custom model.
77
+
78
+ Args:
79
+ checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend
80
+ supported by Composer.
81
+ infer_gpu_num (int): The number of gpus you are planning to use for inference.
82
+ save_dir (str): Path of the directory to save the checkpoint in FT format.
83
+ output_precision (str, optional): The precision of the output weights saved to the FasterTransformer model. Can be either ``fp32`` or ``fp16``.
84
+ local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally.
85
+ If the input ``checkpoint_path`` is already a local path, this will be a symlink.
86
+ Defaults to None, which will use a temporary file.
87
+ """
88
+ dtype = {
89
+ 'fp32': torch.float32,
90
+ 'fp16': torch.float16,
91
+ }[output_precision]
92
+
93
+ # default local path to a tempfile if path is not provided
94
+ if local_checkpoint_save_location is None:
95
+ tmp_dir = tempfile.TemporaryDirectory()
96
+ local_checkpoint_save_location = Path(
97
+ tmp_dir.name) / 'local-composer-checkpoint.pt'
98
+
99
+ # download the checkpoint file
100
+ print(
101
+ f'Downloading checkpoint from {checkpoint_path} -> {local_checkpoint_save_location}'
102
+ )
103
+ get_file(str(checkpoint_path), str(local_checkpoint_save_location))
104
+
105
+ # Load the Composer checkpoint. Use it to get the
106
+ # Composer state dict and weights
107
+ print('Loading checkpoint into CPU RAM...')
108
+ composer_state_dict = safe_torch_load(local_checkpoint_save_location)
109
+
110
+ # Extract Composer config from state dict
111
+ if 'state' not in composer_state_dict:
112
+ raise RuntimeError(
113
+ f'"state" is not an available key in the provided composer checkpoint. Is {local_checkpoint_save_location} ill-formed?'
114
+ )
115
+ if 'integrations' not in composer_state_dict[
116
+ 'state'] or 'huggingface' not in composer_state_dict['state'][
117
+ 'integrations']:
118
+ raise RuntimeError(
119
+ 'Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!'
120
+ )
121
+ composer_config = composer_state_dict['state']['integrations'][
122
+ 'huggingface']['model']['config']['content']
123
+
124
+ # Extract the HF tokenizer
125
+ print('#' * 30)
126
+ print('Extracting HF Tokenizer...')
127
+ hf_tokenizer = get_hf_tokenizer_from_composer_state_dict(
128
+ composer_state_dict)
129
+ if hf_tokenizer is None:
130
+ print('Warning! No HF Tokenizer found!')
131
+
132
+ # Extract the model weights
133
+ weights_state_dict = composer_state_dict['state']['model']
134
+ torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
135
+ weights_state_dict, prefix='model.')
136
+
137
+ # Converting weights to desired dtype
138
+ for k, v in weights_state_dict.items():
139
+ if isinstance(v, torch.Tensor):
140
+ weights_state_dict[k] = v.to(dtype=dtype)
141
+
142
+ # Convert the weights using the config and tokenizer to FasterTransformer format
143
+ print('#' * 30)
144
+ print('Saving FasterTransformer config...')
145
+ save_ft_config(composer_config,
146
+ tokenizer=hf_tokenizer,
147
+ save_dir=save_dir,
148
+ weight_data_type=output_precision)
149
+ print('#' * 30)
150
+ print('Converting weights to FasterTransformer format...')
151
+ convert_and_save_ft_weights(named_params=weights_state_dict,
152
+ config=composer_config,
153
+ infer_gpu_num=infer_gpu_num,
154
+ weight_data_type=output_precision,
155
+ save_dir=save_dir)
156
+
157
+ print('#' * 30)
158
+ print(
159
+ f'FasterTransformer checkpoint folder successfully created at {save_dir}.'
160
+ )
161
+
162
+ print('Done.')
163
+ print('#' * 30)
164
+
165
+
166
+ def parse_args() -> Namespace:
167
+ """Parse commandline arguments."""
168
+ parser = ArgumentParser(
169
+ description=
170
+ 'Convert an MPT Composer checkpoint into a standard FasterTransformer checkpoint folder.'
171
+ )
172
+ parser.add_argument(
173
+ '--composer_path',
174
+ '-i',
175
+ type=str,
176
+ help='Composer checkpoint path. Can be a local file path or cloud URI',
177
+ required=True)
178
+ parser.add_argument(
179
+ '--local_checkpoint_save_location',
180
+ type=str,
181
+ help='If specified, where to save the checkpoint file to locally. \
182
+ If the input ``checkpoint_path`` is already a local path, this will be a symlink. \
183
+ Defaults to None, which will use a temporary file.',
184
+ default=None)
185
+ parser.add_argument(
186
+ '--ft_save_dir',
187
+ '-o',
188
+ type=str,
189
+ help='Directory to save FasterTransformer converted checkpoint in',
190
+ required=True)
191
+ parser.add_argument('--infer_gpu_num',
192
+ '-i_g',
193
+ type=int,
194
+ help='How many gpus for inference?',
195
+ required=True)
196
+ parser.add_argument(
197
+ '--force',
198
+ action='store_true',
199
+ help=
200
+ 'Force conversion to FT even if some features may not work as expected in FT'
201
+ )
202
+ parser.add_argument(
203
+ '--output_precision',
204
+ type=str,
205
+ help=
206
+ 'Data type of weights in the FasterTransformer output model. Input checkpoint weights will be converted to this dtype.',
207
+ choices=['fp32', 'fp16'],
208
+ default='fp32')
209
+
210
+ return parser.parse_args()
211
+
212
+
213
+ if __name__ == '__main__':
214
+ args = parse_args()
215
+ print('\n=============== Argument ===============')
216
+ for key in vars(args):
217
+ print('{}: {}'.format(key, vars(args)[key]))
218
+ print('========================================')
219
+
220
+ save_dir = os.path.join(args.ft_save_dir, f'{args.infer_gpu_num}-gpu')
221
+
222
+ if not os.path.exists(save_dir):
223
+ os.makedirs(save_dir)
224
+ else:
225
+ raise RuntimeError(f'Output path {save_dir} already exists!')
226
+
227
+ write_ft_checkpoint_from_composer_checkpoint(
228
+ checkpoint_path=args.composer_path,
229
+ infer_gpu_num=args.infer_gpu_num,
230
+ save_dir=save_dir,
231
+ output_precision=args.output_precision,
232
+ local_checkpoint_save_location=args.local_checkpoint_save_location)
Perceptrix/finetune/build/lib/inference/convert_composer_to_hf.py ADDED
@@ -0,0 +1,290 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import os
5
+ import tempfile
6
+ from argparse import ArgumentParser, Namespace
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import transformers
12
+ from composer.models.huggingface import get_hf_config_from_composer_state_dict
13
+ from composer.utils import (get_file, maybe_create_object_store_from_uri,
14
+ parse_uri, safe_torch_load)
15
+ from transformers import PretrainedConfig, PreTrainedTokenizerBase
16
+
17
+ from llmfoundry import MPTConfig, MPTForCausalLM
18
+ from llmfoundry.utils import get_hf_tokenizer_from_composer_state_dict
19
+ from llmfoundry.utils.huggingface_hub_utils import \
20
+ edit_files_for_hf_compatibility
21
+
22
+
23
+ def write_huggingface_pretrained_from_composer_checkpoint(
24
+ checkpoint_path: Union[Path, str],
25
+ output_path: Union[Path, str],
26
+ output_precision: str = 'fp32',
27
+ local_checkpoint_save_location: Optional[Union[Path, str]] = None
28
+ ) -> Tuple[PretrainedConfig, Optional[PreTrainedTokenizerBase]]:
29
+ """Convert a Composer checkpoint to a pretrained HF checkpoint folder.
30
+
31
+ Write a ``config.json`` and ``pytorch_model.bin``, like
32
+ :meth:`transformers.PreTrainedModel.from_pretrained` expects, from a
33
+ composer checkpoint.
34
+
35
+ .. note:: This function will not work properly if you used surgery algorithms when you trained your model. In that case you will want to
36
+ load the model weights using the Composer :class:`~composer.Trainer` with the ``load_path`` argument.
37
+ .. testsetup::
38
+ import torch
39
+ dataset = RandomTextClassificationDataset(size=16, use_keys=True)
40
+ train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
41
+ eval_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
42
+ import transformers
43
+ from composer.models import HuggingFaceModel
44
+ from composer.trainer import Trainer
45
+ hf_model = transformers.AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-tiny', num_labels=2)
46
+ hf_tokenizer = transformers.AutoTokenizer.from_pretrained('prajjwal1/bert-tiny')
47
+ composer_model = HuggingFaceModel(hf_model, tokenizer=hf_tokenizer, metrics=[], use_logits=True)
48
+ trainer = Trainer(model=composer_model,
49
+ train_dataloader=train_dataloader,
50
+ save_filename='composer-hf-checkpoint.pt',
51
+ max_duration='1ep',
52
+ save_folder='./')
53
+ trainer.fit()
54
+ trainer.close()
55
+
56
+ Example:
57
+ .. testcode::
58
+ from composer.models import write_huggingface_pretrained_from_composer_checkpoint
59
+ write_huggingface_pretrained_from_composer_checkpoint('composer-hf-checkpoint.pt', './hf-save-pretrained-output')
60
+ loaded_model = transformers.AutoModelForSequenceClassification.from_pretrained('./hf-save-pretrained-output')
61
+
62
+ Args:
63
+ checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend
64
+ supported by :meth:`composer.utils.maybe_create_object_store_from_uri`.
65
+ output_path (Union[Path, str]): Path to the folder to write the output to.
66
+ output_precision (str, optional): The precision of the output weights saved to `pytorch_model.bin`. Can be one of ``fp32``, ``fp16``, or ``bf16``.
67
+ local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally.
68
+ If the input ``checkpoint_path`` is already a local path, this will be a symlink.
69
+ Defaults to None, which will use a temporary file.
70
+ """
71
+ dtype = {
72
+ 'fp32': torch.float32,
73
+ 'fp16': torch.float16,
74
+ 'bf16': torch.bfloat16,
75
+ }[output_precision]
76
+
77
+ # default local path to a tempfile if path is not provided
78
+ if local_checkpoint_save_location is None:
79
+ tmp_dir = tempfile.TemporaryDirectory()
80
+ local_checkpoint_save_location = Path(
81
+ tmp_dir.name) / 'local-composer-checkpoint.pt'
82
+
83
+ # create folder
84
+ os.makedirs(output_path)
85
+
86
+ # download the checkpoint file
87
+ print(
88
+ f'Downloading checkpoint from {checkpoint_path} -> {local_checkpoint_save_location}'
89
+ )
90
+ get_file(str(checkpoint_path), str(local_checkpoint_save_location))
91
+
92
+ # Load the Composer checkpoint state dict
93
+ print('Loading checkpoint into CPU RAM...')
94
+ composer_state_dict = safe_torch_load(local_checkpoint_save_location)
95
+
96
+ if 'state' not in composer_state_dict:
97
+ raise RuntimeError(
98
+ f'"state" is not an available key in the provided composer checkpoint. Is {local_checkpoint_save_location} ill-formed?'
99
+ )
100
+
101
+ # Build and save HF Config
102
+ print('#' * 30)
103
+ print('Saving HF Model Config...')
104
+ hf_config = get_hf_config_from_composer_state_dict(composer_state_dict)
105
+ hf_config.torch_dtype = dtype
106
+ hf_config.save_pretrained(output_path)
107
+ print(hf_config)
108
+
109
+ # Extract and save the HF tokenizer
110
+ print('#' * 30)
111
+ print('Saving HF Tokenizer...')
112
+ hf_tokenizer = get_hf_tokenizer_from_composer_state_dict(
113
+ composer_state_dict)
114
+ if hf_tokenizer is not None:
115
+ hf_tokenizer.save_pretrained(output_path)
116
+ print(hf_tokenizer)
117
+ else:
118
+ print('Warning! No HF Tokenizer found!')
119
+
120
+ # Extract the HF model weights
121
+ print('#' * 30)
122
+ print('Saving HF Model Weights...')
123
+ weights_state_dict = composer_state_dict
124
+ if 'state' in weights_state_dict:
125
+ weights_state_dict = weights_state_dict['state']['model']
126
+ torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
127
+ weights_state_dict, prefix='model.')
128
+
129
+ # Convert weights to desired dtype
130
+ for k, v in weights_state_dict.items():
131
+ if isinstance(v, torch.Tensor):
132
+ weights_state_dict[k] = v.to(dtype=dtype)
133
+
134
+ # Save weights
135
+ torch.save(weights_state_dict, Path(output_path) / 'pytorch_model.bin')
136
+
137
+ print('#' * 30)
138
+ print(f'HF checkpoint folder successfully created at {output_path}.')
139
+
140
+ return hf_config, hf_tokenizer
141
+
142
+
143
+ def parse_args() -> Namespace:
144
+ """Parse commandline arguments."""
145
+ parser = ArgumentParser(
146
+ description=
147
+ 'Convert a HuggingFace causal LM in a Composer checkpoint into a standard HuggingFace checkpoint folder, and optionally upload to the hub.'
148
+ )
149
+ parser.add_argument('--composer_path', type=str, required=True)
150
+ parser.add_argument('--hf_output_path', type=str, required=True)
151
+ parser.add_argument('--local_checkpoint_save_location',
152
+ type=str,
153
+ default=None)
154
+ parser.add_argument('--output_precision',
155
+ type=str,
156
+ choices=['fp32', 'fp16', 'bf16'],
157
+ default='fp32')
158
+ parser.add_argument('--hf_repo_for_upload', type=str, default=None)
159
+ parser.add_argument('--test_uploaded_model', action='store_true')
160
+
161
+ return parser.parse_args()
162
+
163
+
164
+ def convert_composer_to_hf(args: Namespace) -> None:
165
+ print()
166
+ print('#' * 30)
167
+ print('Converting Composer checkpoint to HuggingFace checkpoint format...')
168
+
169
+ # Register MPT auto classes so that this script works with MPT
170
+ # This script will not work without modification for other custom models,
171
+ # but will work for other HuggingFace causal LMs
172
+ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
173
+ CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
174
+ MPTConfig.register_for_auto_class()
175
+ MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
176
+
177
+ _, _, local_folder_path = parse_uri(args.hf_output_path)
178
+
179
+ config, tokenizer = write_huggingface_pretrained_from_composer_checkpoint(
180
+ checkpoint_path=args.composer_path,
181
+ output_path=local_folder_path,
182
+ output_precision=args.output_precision,
183
+ local_checkpoint_save_location=args.local_checkpoint_save_location)
184
+
185
+ dtype = {
186
+ 'fp32': torch.float32,
187
+ 'fp16': torch.float16,
188
+ 'bf16': torch.bfloat16,
189
+ }[args.output_precision]
190
+
191
+ print(f'Loading model from {local_folder_path}')
192
+ if config.model_type == 'mpt':
193
+ config.attn_config['attn_impl'] = 'torch'
194
+ config.init_device = 'cpu'
195
+
196
+ if config.model_type == 'mpt':
197
+ loaded_hf_model = MPTForCausalLM.from_pretrained(local_folder_path,
198
+ config=config,
199
+ torch_dtype=dtype)
200
+ else:
201
+ loaded_hf_model = transformers.AutoModelForCausalLM.from_pretrained(
202
+ local_folder_path, config=config, torch_dtype=dtype)
203
+
204
+ delattr(loaded_hf_model.config, '_name_or_path')
205
+
206
+ loaded_hf_model.save_pretrained(local_folder_path)
207
+
208
+ print(f'Loading tokenizer from {local_folder_path}')
209
+ tokenizer = transformers.AutoTokenizer.from_pretrained(local_folder_path)
210
+ tokenizer.save_pretrained(local_folder_path)
211
+
212
+ # Only need to edit files for MPT because it has custom code
213
+ if config.model_type == 'mpt':
214
+ print('Editing files for HF compatibility...')
215
+ edit_files_for_hf_compatibility(local_folder_path)
216
+
217
+ object_store = maybe_create_object_store_from_uri(str(args.hf_output_path))
218
+
219
+ if object_store is not None:
220
+ print(
221
+ f'Uploading HF checkpoint folder from {local_folder_path} -> {args.hf_output_path}'
222
+ )
223
+ for file in os.listdir(local_folder_path):
224
+ remote_file = os.path.join(local_folder_path, file)
225
+ local_file = os.path.join(local_folder_path, file)
226
+ object_store.upload_object(remote_file, local_file)
227
+
228
+ if args.hf_repo_for_upload is not None:
229
+ from huggingface_hub import HfApi
230
+ api = HfApi()
231
+
232
+ print(
233
+ f'Uploading {args.hf_output_path} to HuggingFace Hub at {args.hf_repo_for_upload}'
234
+ )
235
+ api.create_repo(repo_id=args.hf_repo_for_upload,
236
+ use_auth_token=True,
237
+ repo_type='model',
238
+ private=True,
239
+ exist_ok=True)
240
+ print('Repo created.')
241
+
242
+ # ignore the full checkpoint file if we now have sharded checkpoint files
243
+ ignore_patterns = []
244
+ if any(
245
+ f.startswith('pytorch_model-00001')
246
+ for f in os.listdir(args.hf_output_path)):
247
+ ignore_patterns.append('pytorch_model.bin')
248
+
249
+ api.upload_folder(folder_path=args.hf_output_path,
250
+ repo_id=args.hf_repo_for_upload,
251
+ use_auth_token=True,
252
+ repo_type='model',
253
+ ignore_patterns=ignore_patterns)
254
+ print('Folder uploaded.')
255
+
256
+ if args.test_uploaded_model:
257
+ print('Testing uploaded model...')
258
+ hub_model = transformers.AutoModelForCausalLM.from_pretrained(
259
+ args.hf_repo_for_upload,
260
+ trust_remote_code=True,
261
+ use_auth_token=True,
262
+ torch_dtype=dtype)
263
+ hub_tokenizer = transformers.AutoTokenizer.from_pretrained(
264
+ args.hf_repo_for_upload,
265
+ trust_remote_code=True,
266
+ use_auth_token=True)
267
+
268
+ assert sum(p.numel() for p in hub_model.parameters()) == sum(
269
+ p.numel() for p in loaded_hf_model.parameters())
270
+ assert all(
271
+ str(type(module1)).split('.')[-2:] == str(type(module2)).split(
272
+ '.')[-2:] for module1, module2 in zip(
273
+ hub_model.modules(), loaded_hf_model.modules()))
274
+
275
+ assert next(
276
+ hub_model.parameters()
277
+ ).dtype == dtype, f'Expected model dtype to be {dtype}, but got {next(hub_model.parameters()).dtype}'
278
+ print(
279
+ hub_tokenizer.batch_decode(
280
+ hub_model.generate(hub_tokenizer(
281
+ 'MosaicML is', return_tensors='pt').input_ids,
282
+ max_new_tokens=10)))
283
+
284
+ print(
285
+ 'Composer checkpoint successfully converted to HuggingFace checkpoint format.'
286
+ )
287
+
288
+
289
+ if __name__ == '__main__':
290
+ convert_composer_to_hf(parse_args())
Perceptrix/finetune/build/lib/inference/convert_hf_mpt_to_ft.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Convert MPT model checkpoint to FT format.
19
+
20
+ It's a modified version of
21
+ https://github.com/NVIDIA/FasterTransformer/blob/main/examples/pytorch/gpt/utils/huggingface_gpt_convert.py
22
+ """
23
+
24
+ import argparse
25
+ import configparser
26
+ import os
27
+
28
+ import transformers
29
+
30
+ from llmfoundry.utils import convert_and_save_ft_weights
31
+
32
+
33
+ def convert_mpt_to_ft(model_name_or_path: str,
34
+ output_dir: str,
35
+ infer_gpu_num: int = 1,
36
+ weight_data_type: str = 'fp32',
37
+ force: bool = False) -> None:
38
+ """Convert an MPT checkpoint to a FasterTransformer compatible format.
39
+
40
+ Args:
41
+ model_name_or_path (str): The HF hub name of the model (e.g., mosaicml/mpt-7b) or the path of a directory
42
+ containing an MPT checkpoint in a local dir.
43
+ output_dir (str): Path of the directory to save the checkpoint in FT format. The directory must not already exist.
44
+ infer_gpu_num (int): The number of gpus you are planning to use for inference.
45
+ weight_data_type (str): Data type of the weights in the input checkpoint.
46
+ force (bool): force conversion even with unsupported features in FT.
47
+ """
48
+ save_dir = os.path.join(output_dir, f'{infer_gpu_num}-gpu')
49
+
50
+ if not os.path.exists(save_dir):
51
+ os.makedirs(save_dir)
52
+ else:
53
+ raise RuntimeError(f'Output path {save_dir} already exists!')
54
+
55
+ # do conversion on cpu
56
+ torch_device = 'cpu'
57
+
58
+ model = transformers.AutoModelForCausalLM.from_pretrained(
59
+ model_name_or_path, trust_remote_code=True).to(torch_device)
60
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
61
+ model_name_or_path, trust_remote_code=True)
62
+
63
+ hf_config = vars(model.config)
64
+
65
+ config = configparser.ConfigParser()
66
+ config['gpt'] = {}
67
+ try:
68
+ config['gpt']['model_name'] = 'mpt' if hf_config[
69
+ '_name_or_path'] == '' else hf_config['_name_or_path']
70
+ config['gpt']['head_num'] = str(hf_config['n_heads'])
71
+ n_embd = hf_config['d_model']
72
+ config['gpt']['size_per_head'] = str(n_embd // hf_config['n_heads'])
73
+ config['gpt']['inter_size'] = str(n_embd * hf_config['expansion_ratio'])
74
+ config['gpt']['max_pos_seq_len'] = str(hf_config['max_seq_len'])
75
+ config['gpt']['num_layer'] = str(hf_config['n_layers'])
76
+ config['gpt']['vocab_size'] = str(hf_config['vocab_size'])
77
+ config['gpt']['start_id'] = str(
78
+ hf_config['bos_token_id']
79
+ ) if hf_config['bos_token_id'] is not None else str(tokenizer.bos_token_id)
80
+ config['gpt']['end_id'] = str(
81
+ hf_config['eos_token_id']
82
+ ) if hf_config['eos_token_id'] is not None else str(tokenizer.eos_token_id)
83
+ config['gpt']['weight_data_type'] = weight_data_type
84
+ config['gpt']['tensor_para_size'] = str(infer_gpu_num)
85
+ # nn.LayerNorm default eps is 1e-5
86
+ config['gpt']['layernorm_eps'] = str(1e-5)
87
+ if hf_config['attn_config']['alibi']:
88
+ config['gpt']['has_positional_encoding'] = str(False)
89
+ config['gpt']['use_attention_linear_bias'] = str(True)
90
+ if hf_config['attn_config']['clip_qkv'] and not force:
91
+ raise RuntimeError(
92
+ 'clip_qkv is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.'
93
+ )
94
+ if hf_config['attn_config']['qk_ln'] and not force:
95
+ raise RuntimeError(
96
+ 'qk_ln is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.'
97
+ )
98
+
99
+ with open(os.path.join(save_dir, 'config.ini'), 'w') as configfile:
100
+ config.write(configfile)
101
+ except Exception:
102
+ print('Failed to save the config in config.ini.')
103
+ raise
104
+
105
+ named_params_dict = {
106
+ name: param for name, param in model.named_parameters()
107
+ }
108
+ convert_and_save_ft_weights(named_params=named_params_dict,
109
+ config=hf_config,
110
+ infer_gpu_num=infer_gpu_num,
111
+ weight_data_type=weight_data_type,
112
+ save_dir=save_dir)
113
+
114
+
115
+ if __name__ == '__main__':
116
+ parser = argparse.ArgumentParser(
117
+ formatter_class=argparse.RawTextHelpFormatter)
118
+ parser.add_argument('--save_dir',
119
+ '-o',
120
+ type=str,
121
+ help='Directory to save converted checkpoint in',
122
+ required=True)
123
+ parser.add_argument(
124
+ '--name_or_dir',
125
+ '-i',
126
+ type=str,
127
+ help=
128
+ 'HF hub Model name (e.g., mosaicml/mpt-7b) or local dir path to load checkpoint from',
129
+ required=True)
130
+ parser.add_argument('--infer_gpu_num',
131
+ '-i_g',
132
+ type=int,
133
+ help='How many gpus for inference?',
134
+ required=True)
135
+ parser.add_argument(
136
+ '--force',
137
+ action='store_true',
138
+ help=
139
+ 'Force conversion to FT even if some features may not work as expected in FT'
140
+ )
141
+ parser.add_argument('--weight_data_type',
142
+ type=str,
143
+ help='Data type of weights in the input checkpoint',
144
+ default='fp32',
145
+ choices=['fp32', 'fp16'])
146
+
147
+ args = parser.parse_args()
148
+ print('\n=============== Argument ===============')
149
+ for key in vars(args):
150
+ print('{}: {}'.format(key, vars(args)[key]))
151
+ print('========================================')
152
+
153
+ convert_mpt_to_ft(args.name_or_dir, args.save_dir, args.infer_gpu_num,
154
+ args.weight_data_type, args.force)
Perceptrix/finetune/build/lib/inference/convert_hf_to_onnx.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Basic HuggingFace -> ONNX export script.
5
+
6
+ This scripts show a basic HuggingFace -> ONNX export workflow. This works for a MPT model
7
+ that has been saved using `MPT.save_pretrained`. For more details and examples
8
+ of exporting and working with HuggingFace models with ONNX, see https://huggingface.co/docs/transformers/serialization#export-to-onnx.
9
+
10
+ Example usage:
11
+
12
+ 1) Local export
13
+
14
+ python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder local/folder
15
+
16
+ 2) Remote export
17
+
18
+ python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder s3://bucket/remote/folder
19
+
20
+ 3) Verify the exported model
21
+
22
+ python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder local/folder --verify_export
23
+
24
+ 4) Change the batch size or max sequence length
25
+
26
+ python convert_hf_to_onnx.py --pretrained_model_name_or_path local/path/to/huggingface/folder --output_folder local/folder --export_batch_size 1 --max_seq_len 32000
27
+ """
28
+
29
+ import argparse
30
+ import os
31
+ from argparse import ArgumentTypeError
32
+ from pathlib import Path
33
+ from typing import Any, Dict, Optional, Union
34
+
35
+ import torch
36
+ from composer.utils import (maybe_create_object_store_from_uri, parse_uri,
37
+ reproducibility)
38
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
39
+
40
+
41
+ def str2bool(v: Union[str, bool]):
42
+ if isinstance(v, bool):
43
+ return v
44
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
45
+ return True
46
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
47
+ return False
48
+ else:
49
+ raise ArgumentTypeError('Boolean value expected.')
50
+
51
+
52
+ def str_or_bool(v: Union[str, bool]):
53
+ if isinstance(v, bool):
54
+ return v
55
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
56
+ return True
57
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
58
+ return False
59
+ else:
60
+ return v
61
+
62
+
63
+ def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int):
64
+ # generate input batch of random data
65
+ batch = {
66
+ 'input_ids':
67
+ torch.randint(
68
+ low=0,
69
+ high=vocab_size,
70
+ size=(batch_size, max_seq_len),
71
+ dtype=torch.int64,
72
+ ),
73
+ 'attention_mask':
74
+ torch.ones(size=(batch_size, max_seq_len), dtype=torch.bool)
75
+ }
76
+ return batch
77
+
78
+
79
+ def export_to_onnx(
80
+ pretrained_model_name_or_path: str,
81
+ output_folder: str,
82
+ export_batch_size: int,
83
+ max_seq_len: Optional[int],
84
+ verify_export: bool,
85
+ from_pretrained_kwargs: Dict[str, Any],
86
+ ):
87
+ reproducibility.seed_all(42)
88
+ save_object_store = maybe_create_object_store_from_uri(output_folder)
89
+ _, _, parsed_save_path = parse_uri(output_folder)
90
+
91
+ print('Loading HF config/model/tokenizer...')
92
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
93
+ **from_pretrained_kwargs)
94
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
95
+ **from_pretrained_kwargs)
96
+
97
+ # specifically for MPT, switch to the torch version of attention for ONNX export
98
+ if hasattr(config, 'attn_config'):
99
+ config.attn_config['attn_impl'] = 'torch'
100
+
101
+ model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path,
102
+ config=config,
103
+ **from_pretrained_kwargs)
104
+ model.eval()
105
+
106
+ if max_seq_len is None and not hasattr(model.config, 'max_seq_len'):
107
+ raise ValueError(
108
+ 'max_seq_len must be specified in either the model config or as an argument to this function.'
109
+ )
110
+ elif max_seq_len is None:
111
+ max_seq_len = model.config.max_seq_len
112
+
113
+ assert isinstance(max_seq_len, int) # pyright
114
+
115
+ print('Creating random batch...')
116
+ sample_input = gen_random_batch(
117
+ export_batch_size,
118
+ len(tokenizer),
119
+ max_seq_len,
120
+ )
121
+
122
+ with torch.no_grad():
123
+ model(**sample_input)
124
+
125
+ output_file = Path(parsed_save_path) / 'model.onnx'
126
+ os.makedirs(parsed_save_path, exist_ok=True)
127
+ print('Exporting the model with ONNX...')
128
+ torch.onnx.export(
129
+ model,
130
+ (sample_input,),
131
+ str(output_file),
132
+ input_names=['input_ids', 'attention_mask'],
133
+ output_names=['output'],
134
+ opset_version=16,
135
+ )
136
+
137
+ if verify_export:
138
+ with torch.no_grad():
139
+ orig_out = model(**sample_input)
140
+
141
+ import onnx
142
+ import onnx.checker
143
+ import onnxruntime as ort
144
+
145
+ _ = onnx.load(str(output_file))
146
+
147
+ onnx.checker.check_model(str(output_file))
148
+
149
+ ort_session = ort.InferenceSession(str(output_file))
150
+
151
+ for key, value in sample_input.items():
152
+ sample_input[key] = value.cpu().numpy()
153
+
154
+ loaded_model_out = ort_session.run(None, sample_input)
155
+
156
+ torch.testing.assert_close(
157
+ orig_out.logits.detach().numpy(),
158
+ loaded_model_out[0],
159
+ rtol=1e-2,
160
+ atol=1e-2,
161
+ msg='Output mismatch between the original and ONNX-exported model',
162
+ )
163
+ print('Exported model output matches the unexported model!')
164
+
165
+ if save_object_store is not None:
166
+ print('Uploading files to object storage...')
167
+ for filename in os.listdir(parsed_save_path):
168
+ full_path = str(Path(parsed_save_path) / filename)
169
+ save_object_store.upload_object(full_path, full_path)
170
+
171
+
172
+ def parse_args():
173
+ parser = argparse.ArgumentParser(description='Convert HF model to ONNX',)
174
+ parser.add_argument(
175
+ '--pretrained_model_name_or_path',
176
+ type=str,
177
+ required=True,
178
+ )
179
+ parser.add_argument(
180
+ '--output_folder',
181
+ type=str,
182
+ required=True,
183
+ )
184
+ parser.add_argument(
185
+ '--export_batch_size',
186
+ type=int,
187
+ default=8,
188
+ )
189
+ parser.add_argument(
190
+ '--max_seq_len',
191
+ type=int,
192
+ default=None,
193
+ )
194
+ parser.add_argument(
195
+ '--verify_export',
196
+ action='store_true',
197
+ )
198
+ parser.add_argument('--trust_remote_code',
199
+ type=str2bool,
200
+ nargs='?',
201
+ const=True,
202
+ default=True)
203
+ parser.add_argument('--use_auth_token',
204
+ type=str_or_bool,
205
+ nargs='?',
206
+ const=True,
207
+ default=None)
208
+ parser.add_argument('--revision', type=str, default=None)
209
+ return parser.parse_args()
210
+
211
+
212
+ def main(args: argparse.Namespace):
213
+ from_pretrained_kwargs = {
214
+ 'use_auth_token': args.use_auth_token,
215
+ 'trust_remote_code': args.trust_remote_code,
216
+ 'revision': args.revision,
217
+ }
218
+
219
+ export_to_onnx(
220
+ pretrained_model_name_or_path=args.pretrained_model_name_or_path,
221
+ output_folder=args.output_folder,
222
+ export_batch_size=args.export_batch_size,
223
+ max_seq_len=args.max_seq_len,
224
+ verify_export=args.verify_export,
225
+ from_pretrained_kwargs=from_pretrained_kwargs)
226
+
227
+
228
+ if __name__ == '__main__':
229
+ main(parse_args())
Perceptrix/finetune/build/lib/inference/hf_chat.py ADDED
@@ -0,0 +1,389 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import time
5
+ import warnings
6
+ from argparse import ArgumentParser, ArgumentTypeError, Namespace
7
+ from contextlib import nullcontext
8
+ from typing import Any, Dict, List, Optional, Union
9
+
10
+ import torch
11
+ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
12
+ PreTrainedModel, PreTrainedTokenizerBase,
13
+ StoppingCriteria, StoppingCriteriaList, TextStreamer)
14
+
15
+
16
+ class ChatFormatter:
17
+ """A class for formatting the chat history.
18
+
19
+ Args:
20
+ system: The system prompt. If None, a default ChatML-formatted prompt is used.
21
+ user: The user prompt. If None, a default ChatML value is used.
22
+ assistant: The assistant prompt. If None, a default ChatML value is used.
23
+
24
+ Attributes:
25
+ system: The system prompt.
26
+ user: The user prompt.
27
+ assistant: The assistant prompt.
28
+ response_prefix: The response prefix (anything before {} in the assistant format string)
29
+ """
30
+
31
+ def __init__(self, system: str, user: str, assistant: str) -> None:
32
+ self.system = system if system else '<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|>\n'
33
+ self.user = user if user else '<|im_start|>user\n{}<|im_end|>\n'
34
+ self.assistant = assistant if assistant else '<|im_start|>assistant\n{}<|im_end|>\n'
35
+ self.response_prefix = self.assistant.split('{}')[0]
36
+
37
+
38
+ class Conversation:
39
+ """A class for interacting with a chat-tuned LLM.
40
+
41
+ Args:
42
+ model: The model to use for inference.
43
+ tokenizer: The tokenizer to use for inference.
44
+ chat_format: The chat format to use for the conversation.
45
+ generate_kwargs: The keyword arguments to pass to `model.generate`.
46
+ stop_tokens: The tokens to stop generation on.
47
+
48
+ Attributes:
49
+ model: The model to use for inference.
50
+ tokenizer: The tokenizer to use for inference.
51
+ chat_format: The chat format to use for the conversation.
52
+ streamer: The streamer to use for inference.
53
+ generate_kwargs: The keyword arguments to pass to `model.generate`.
54
+ history: The conversation history.
55
+ cli_instructions: The instructions to display to the user.
56
+ """
57
+
58
+ def __init__(self,
59
+ model: PreTrainedModel,
60
+ tokenizer: PreTrainedTokenizerBase,
61
+ chat_format: ChatFormatter,
62
+ generate_kwargs: Dict[str, Any],
63
+ stop_tokens: Optional[List[str]] = None) -> None:
64
+ if stop_tokens is None:
65
+ stop_tokens = ['<|endoftext|>', '<|im_end|>']
66
+ self.model = model
67
+ self.tokenizer = tokenizer
68
+ self.chat_format = chat_format
69
+
70
+ stop_token_ids = self.tokenizer.convert_tokens_to_ids(stop_tokens)
71
+ if len(stop_token_ids) != len(stop_tokens):
72
+ warnings.warn(
73
+ f'Not all stop tokens were found in the tokenizer vocabulary: {stop_tokens}\n'
74
+ + 'Generation may stop or continue unexpectedly.')
75
+
76
+ class StopOnTokens(StoppingCriteria):
77
+
78
+ def __call__(self, input_ids: torch.LongTensor,
79
+ scores: torch.FloatTensor, **kwargs: Any) -> bool:
80
+ del kwargs # unused
81
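+ # Only the most recently generated token needs to be compared against each stop id.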
+ for stop_id in stop_token_ids:
82
+ if input_ids[0][-1] == stop_id:
83
+ return True
84
+ return False
85
+
86
+ self.streamer = TextStreamer(tokenizer,
87
+ skip_prompt=True,
88
+ skip_special_tokens=True)
89
+ self.generate_kwargs = {
90
+ **generate_kwargs,
91
+ 'stopping_criteria':
92
+ StoppingCriteriaList([StopOnTokens()]),
93
+ 'streamer':
94
+ self.streamer,
95
+ }
96
+ self.history = []
97
+ self.cli_instructions = (
98
+ 'Enter your message below.\n- Hit return twice to send input to the model\n'
99
+ +
100
+ "- Type 'clear' to restart the conversation\n- Type 'history' to see the conversation\n"
101
+ +
102
+ "- Type 'quit' to end\n- Type 'system' to change the system prompt\n"
103
+ )
104
+
105
+ def _history_as_formatted_str(self) -> str:
106
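+ # Render every completed turn, then append the latest user message and the assistant response prefix so the model generates the next reply.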
+ text = self.chat_format.system + ''.join([
107
+ '\n'.join([
108
+ self.chat_format.user.format(item[0]),
109
+ self.chat_format.assistant.format(item[1]),
110
+ ]) for item in self.history[:-1]
111
+ ])
112
+ text += self.chat_format.user.format(self.history[-1][0])
113
+ text += self.chat_format.response_prefix
114
+ return text
115
+
116
+ def turn(self, user_inp: str) -> None:
117
+ self.history.append([user_inp, ''])
118
+ conversation = self._history_as_formatted_str()
119
+ input_ids = self.tokenizer(conversation, return_tensors='pt').input_ids
120
+ input_ids = input_ids.to(self.model.device)
121
+ # also stream to stdout
122
+ maybe_synchronize()
123
+ start = time.time()
124
+ print('Assistant:')
125
+ gkwargs = {**self.generate_kwargs, 'input_ids': input_ids}
126
+ # this will stream to stdout, but we need to keep track of the output_ids for saving history
127
+ output_ids = self.model.generate(**gkwargs)
128
+ maybe_synchronize()
129
+ end = time.time()
130
+ print(f'took {end - start:.2f} seconds')
131
+ new_tokens = output_ids[0, len(input_ids[0]):]
132
+ assistant_response = self.tokenizer.decode(new_tokens,
133
+ skip_special_tokens=True)
134
+ self.history[-1][-1] = assistant_response
135
+
136
+ def __call__(self) -> None:
137
+ print(self.cli_instructions)
138
+ while True:
139
+ print('User:')
140
+ user_inp_lines = []
141
+ while True:
142
+ line = input()
143
+ if line.strip() == '':
144
+ break
145
+ user_inp_lines.append(line)
146
+ user_inp = '\n'.join(user_inp_lines)
147
+ if user_inp.lower() == 'quit':
148
+ break
149
+ elif user_inp.lower() == 'clear':
150
+ self.history = []
151
+ continue
152
+ elif user_inp == 'history':
153
+ print(f'history: {self.history}')
154
+ continue
155
+ elif user_inp == 'history_fmt':
156
+ print(f'history: {self._history_as_formatted_str()}')
157
+ continue
158
+ elif user_inp == 'system':
159
+ print('Enter a new system prompt:')
160
+ new_system = input()
161
+ sys = f'<|im_start|>system\n{new_system.strip()}.<|im_end|>\n'
162
+ self.chat_format.system = sys
163
+ continue
164
+ self.turn(user_inp)
165
+
166
+
167
+ def get_dtype(dtype: str):
168
+ if dtype == 'fp32':
169
+ return torch.float32
170
+ elif dtype == 'fp16':
171
+ return torch.float16
172
+ elif dtype == 'bf16':
173
+ return torch.bfloat16
174
+ else:
175
+ raise NotImplementedError(
176
+ f'dtype {dtype} is not supported. ' +
177
+ 'We only support fp32, fp16, and bf16 currently')
178
+
179
+
180
+ def str2bool(v: Union[str, bool]):
181
+ if isinstance(v, bool):
182
+ return v
183
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
184
+ return True
185
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
186
+ return False
187
+ else:
188
+ raise ArgumentTypeError('Boolean value expected.')
189
+
190
+
191
+ def str_or_bool(v: Union[str, bool]):
192
+ if isinstance(v, bool):
193
+ return v
194
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
195
+ return True
196
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
197
+ return False
198
+ else:
199
+ return v
200
+
201
+
202
+ def parse_args() -> Namespace:
203
+ """Parse commandline arguments."""
204
+ parser = ArgumentParser(
205
+ description='Load a HF CausalLM Model and use it to generate text.')
206
+ parser.add_argument('-n', '--name_or_path', type=str, required=True)
207
+ parser.add_argument('--max_new_tokens', type=int, default=512)
208
+ parser.add_argument('--max_seq_len', type=int, default=None)
209
+ parser.add_argument('--temperature', type=float, default=1.0)
210
+ parser.add_argument('--top_k', type=int, default=50)
211
+ parser.add_argument('--top_p', type=float, default=1.0)
212
+ parser.add_argument('--do_sample',
213
+ type=str2bool,
214
+ nargs='?',
215
+ const=True,
216
+ default=True)
217
+ parser.add_argument('--use_cache',
218
+ type=str2bool,
219
+ nargs='?',
220
+ const=True,
221
+ default=True)
222
+ parser.add_argument('--eos_token_id', type=str, default=None)
223
+ parser.add_argument('--pad_token_id', type=str, default=None)
224
+ parser.add_argument('--model_dtype',
225
+ type=str,
226
+ choices=['fp32', 'fp16', 'bf16'],
227
+ default=None)
228
+ parser.add_argument('--autocast_dtype',
229
+ type=str,
230
+ choices=['fp32', 'fp16', 'bf16'],
231
+ default=None)
232
+ parser.add_argument('--warmup',
233
+ type=str2bool,
234
+ nargs='?',
235
+ const=True,
236
+ default=True)
237
+ parser.add_argument('--trust_remote_code',
238
+ type=str2bool,
239
+ nargs='?',
240
+ const=True,
241
+ default=True)
242
+ parser.add_argument('--use_auth_token',
243
+ type=str_or_bool,
244
+ nargs='?',
245
+ const=True,
246
+ default=None)
247
+ parser.add_argument('--revision', type=str, default=None)
248
+ parser.add_argument('--device', type=str, default=None)
249
+ parser.add_argument('--device_map', type=str, default=None)
250
+ parser.add_argument('--attn_impl', type=str, default=None)
251
+ parser.add_argument('--seed', type=int, default=42)
252
+ parser.add_argument('--system_prompt', type=str, default=None)
253
+ parser.add_argument('--user_msg_fmt', type=str, default=None)
254
+ parser.add_argument('--assistant_msg_fmt', type=str, default=None)
255
+ parser.add_argument(
256
+ '--stop_tokens',
257
+ type=str,
258
+ default='<|endoftext|> <|im_end|>',
259
+ help='A string of tokens to stop generation on; will be split on spaces.'
260
+ )
261
+ return parser.parse_args()
262
+
263
+
264
+ def maybe_synchronize():
265
+ if torch.cuda.is_available():
266
+ torch.cuda.synchronize()
267
+
268
+
269
+ def main(args: Namespace) -> None:
270
+ # Set device or device_map
271
+ if args.device and args.device_map:
272
+ raise ValueError('You can only set one of `device` and `device_map`.')
273
+ if args.device is not None:
274
+ device = args.device
275
+ device_map = None
276
+ else:
277
+ device = None
278
+ device_map = args.device_map or 'auto'
279
+ print(f'Using {device=} and {device_map=}')
280
+
281
+ # Set model_dtype
282
+ if args.model_dtype is not None:
283
+ model_dtype = get_dtype(args.model_dtype)
284
+ else:
285
+ model_dtype = torch.float32
286
+ print(f'Using {model_dtype=}')
287
+
288
+ # Grab config first
289
+ print(f'Loading HF Config...')
290
+ from_pretrained_kwargs = {
291
+ 'use_auth_token': args.use_auth_token,
292
+ 'trust_remote_code': args.trust_remote_code,
293
+ 'revision': args.revision,
294
+ }
295
+ try:
296
+ config = AutoConfig.from_pretrained(args.name_or_path,
297
+ **from_pretrained_kwargs)
298
+ if args.attn_impl is not None and hasattr(config, 'attn_config'):
299
+ config.attn_config['attn_impl'] = args.attn_impl
300
+ if hasattr(config, 'init_device') and device is not None:
301
+ config.init_device = device
302
+ if args.max_seq_len is not None and hasattr(config, 'max_seq_len'):
303
+ config.max_seq_len = args.max_seq_len
304
+
305
+ except Exception as e:
306
+ raise RuntimeError(
307
+ 'If you are having auth problems, try logging in via `huggingface-cli login` '
308
+ +
309
+ 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... '
310
+ +
311
+ 'using your access token from https://huggingface.co/settings/tokens.'
312
+ ) from e
313
+
314
+ # Load HF Model
315
+ print(f'Loading HF model with dtype={model_dtype}...')
316
+ try:
317
+ model = AutoModelForCausalLM.from_pretrained(args.name_or_path,
318
+ config=config,
319
+ torch_dtype=model_dtype,
320
+ device_map=device_map,
321
+ **from_pretrained_kwargs)
322
+ model.eval()
323
+ print(f'n_params={sum(p.numel() for p in model.parameters())}')
324
+ if device is not None:
325
+ print(f'Placing model on {device=}...')
326
+ model.to(device)
327
+ except Exception as e:
328
+ raise RuntimeError(
329
+ 'Unable to load HF model. ' +
330
+ 'If you are having auth problems, try logging in via `huggingface-cli login` '
331
+ +
332
+ 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... '
333
+ +
334
+ 'using your access token from https://huggingface.co/settings/tokens.'
335
+ ) from e
336
+
337
+ print('\nLoading HF tokenizer...')
338
+ tokenizer = AutoTokenizer.from_pretrained(args.name_or_path,
339
+ **from_pretrained_kwargs)
340
+ if tokenizer.pad_token_id is None:
341
+ warnings.warn(
342
+ 'pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.'
343
+ )
344
+ tokenizer.pad_token = tokenizer.eos_token
345
+ tokenizer.padding_side = 'left'
346
+
347
+ generate_kwargs = {
348
+ 'max_new_tokens': args.max_new_tokens,
349
+ 'temperature': args.temperature,
350
+ 'top_p': args.top_p,
351
+ 'top_k': args.top_k,
352
+ 'use_cache': args.use_cache,
353
+ 'do_sample': args.do_sample,
354
+ 'eos_token_id': args.eos_token_id or tokenizer.eos_token_id,
355
+ 'pad_token_id': args.pad_token_id or tokenizer.eos_token_id,
356
+ }
357
+ # Autocast
358
+ if args.autocast_dtype is not None:
359
+ autocast_dtype = get_dtype(args.autocast_dtype)
360
+ autocast_context = torch.autocast(model.device.type, autocast_dtype)
361
+ print(f'Using autocast with dtype={autocast_dtype}...')
362
+ else:
363
+ autocast_context = nullcontext()
364
+ print('NOT using autocast...')
365
+
366
+ chat_format = ChatFormatter(system=args.system_prompt,
367
+ user=args.user_msg_fmt,
368
+ assistant=args.assistant_msg_fmt)
369
+
370
+ conversation = Conversation(model=model,
371
+ tokenizer=tokenizer,
372
+ chat_format=chat_format,
373
+ generate_kwargs=generate_kwargs,
374
+ stop_tokens=args.stop_tokens.split())
375
+
376
+ # Warmup
377
+ if args.warmup:
378
+ print('Warming up...')
379
+ with autocast_context:
380
+ conversation.turn('Write a welcome message to the user.')
381
+ conversation.history = []
382
+
383
+ print('Starting conversation...')
384
+ with autocast_context:
385
+ conversation()
386
+
387
+
388
+ if __name__ == '__main__':
389
+ main(parse_args())
Perceptrix/finetune/build/lib/inference/hf_generate.py ADDED
@@ -0,0 +1,372 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import itertools
4
+ import os
5
+ import random
6
+ import time
7
+ import warnings
8
+ from argparse import ArgumentParser, ArgumentTypeError, Namespace
9
+ from contextlib import nullcontext
10
+ from typing import Dict, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
15
+
16
+
17
+ def get_dtype(dtype: str):
18
+ if dtype == 'fp32':
19
+ return torch.float32
20
+ elif dtype == 'fp16':
21
+ return torch.float16
22
+ elif dtype == 'bf16':
23
+ return torch.bfloat16
24
+ else:
25
+ raise NotImplementedError(
26
+ f'dtype {dtype} is not supported. ' +\
27
+ f'We only support fp32, fp16, and bf16 currently')
28
+
29
+
30
+ def str2bool(v: Union[str, bool]):
31
+ if isinstance(v, bool):
32
+ return v
33
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
34
+ return True
35
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
36
+ return False
37
+ else:
38
+ raise ArgumentTypeError('Boolean value expected.')
39
+
40
+
41
+ def str_or_bool(v: Union[str, bool]):
42
+ if isinstance(v, bool):
43
+ return v
44
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
45
+ return True
46
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
47
+ return False
48
+ else:
49
+ return v
50
+
51
+
52
+ def parse_args() -> Namespace:
53
+ """Parse commandline arguments."""
54
+ parser = ArgumentParser(
55
+ description='Load a HF CausalLM Model and use it to generate text.')
56
+ parser.add_argument('-n', '--name_or_path', type=str, required=True)
57
+ parser.add_argument(
58
+ '-p',
59
+ '--prompts',
60
+ nargs='+',
61
+ default=[
62
+ 'My name is',
63
+ 'This is an explanation of deep learning to a five year old. Deep learning is',
64
+ ],
65
+ help='Generation prompts. Use syntax "file::/path/to/prompt.txt" to load a ' +\
66
+ 'prompt contained in a txt file.'
67
+ )
68
+ parser.add_argument('--max_seq_len', type=int, default=None)
69
+ parser.add_argument('--max_new_tokens', type=int, default=100)
70
+ parser.add_argument('--max_batch_size', type=int, default=None)
71
+ #####
72
+ # Note: Generation config defaults are set to match Hugging Face defaults
73
+ parser.add_argument('--temperature', type=float, nargs='+', default=[1.0])
74
+ parser.add_argument('--top_k', type=int, nargs='+', default=[50])
75
+ parser.add_argument('--top_p', type=float, nargs='+', default=[1.0])
76
+ parser.add_argument('--repetition_penalty',
77
+ type=float,
78
+ nargs='+',
79
+ default=[1.0])
80
+ parser.add_argument('--no_repeat_ngram_size',
81
+ type=int,
82
+ nargs='+',
83
+ default=[0])
84
+ #####
85
+ parser.add_argument('--seed', type=int, nargs='+', default=[42])
86
+ parser.add_argument('--do_sample',
87
+ type=str2bool,
88
+ nargs='?',
89
+ const=True,
90
+ default=True)
91
+ parser.add_argument('--use_cache',
92
+ type=str2bool,
93
+ nargs='?',
94
+ const=True,
95
+ default=True)
96
+ parser.add_argument('--eos_token_id', type=int, default=None)
97
+ parser.add_argument('--pad_token_id', type=int, default=None)
98
+ parser.add_argument('--model_dtype',
99
+ type=str,
100
+ choices=['fp32', 'fp16', 'bf16'],
101
+ default=None)
102
+ parser.add_argument('--autocast_dtype',
103
+ type=str,
104
+ choices=['fp32', 'fp16', 'bf16'],
105
+ default=None)
106
+ parser.add_argument('--warmup',
107
+ type=str2bool,
108
+ nargs='?',
109
+ const=True,
110
+ default=True)
111
+ parser.add_argument('--trust_remote_code',
112
+ type=str2bool,
113
+ nargs='?',
114
+ const=True,
115
+ default=True)
116
+ parser.add_argument('--use_auth_token',
117
+ type=str_or_bool,
118
+ nargs='?',
119
+ const=True,
120
+ default=None)
121
+ parser.add_argument('--revision', type=str, default=None)
122
+ parser.add_argument('--device', type=str, default=None)
123
+ parser.add_argument('--device_map', type=str, default=None)
124
+ parser.add_argument('--attn_impl', type=str, default=None)
125
+ return parser.parse_args()
126
+
127
+
128
+ def load_prompt_string_from_file(prompt_path_str: str):
129
+ if not prompt_path_str.startswith('file::'):
130
+ raise ValueError('prompt_path_str must start with "file::".')
131
+ _, prompt_file_path = prompt_path_str.split('file::', maxsplit=1)
132
+ prompt_file_path = os.path.expanduser(prompt_file_path)
133
+ if not os.path.isfile(prompt_file_path):
134
+ raise FileNotFoundError(
135
+ f'{prompt_file_path=} does not match any existing files.')
136
+ with open(prompt_file_path, 'r') as f:
137
+ prompt_string = ''.join(f.readlines())
138
+ return prompt_string
139
+
140
+
141
+ def maybe_synchronize():
142
+ if torch.cuda.is_available():
143
+ torch.cuda.synchronize()
144
+
145
+
146
+ def main(args: Namespace) -> None:
147
+ # Set device or device_map
148
+ if args.device and args.device_map:
149
+ raise ValueError('You can only set one of `device` and `device_map`.')
150
+ if args.device is not None:
151
+ device = args.device
152
+ device_map = None
153
+ else:
154
+ device = None
155
+ device_map = args.device_map or 'auto'
156
+ print(f'Using {device=} and {device_map=}')
157
+
158
+ # Set model_dtype
159
+ if args.model_dtype is not None:
160
+ model_dtype = get_dtype(args.model_dtype)
161
+ else:
162
+ model_dtype = torch.float32
163
+ print(f'Using {model_dtype=}')
164
+
165
+ # Load prompts
166
+ prompt_strings = []
167
+ for prompt in args.prompts:
168
+ if prompt.startswith('file::'):
169
+ prompt = load_prompt_string_from_file(prompt)
170
+ prompt_strings.append(prompt)
171
+
172
+ # Grab config first
173
+ print(f'Loading HF Config...')
174
+ from_pretrained_kwargs = {
175
+ 'use_auth_token': args.use_auth_token,
176
+ 'trust_remote_code': args.trust_remote_code,
177
+ 'revision': args.revision,
178
+ }
179
+ try:
180
+ config = AutoConfig.from_pretrained(args.name_or_path,
181
+ **from_pretrained_kwargs)
182
+ if hasattr(config, 'init_device') and device is not None:
183
+ config.init_device = device
184
+ if args.attn_impl is not None and hasattr(config, 'attn_config'):
185
+ config.attn_config['attn_impl'] = args.attn_impl
186
+ if args.max_seq_len is not None and hasattr(config, 'max_seq_len'):
187
+ config.max_seq_len = args.max_seq_len
188
+
189
+ except Exception as e:
190
+ raise RuntimeError(
191
+ 'If you are having auth problems, try logging in via `huggingface-cli login` ' +\
192
+ 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=...` ' +\
193
+ 'using your access token from https://huggingface.co/settings/tokens.'
194
+ ) from e
195
+
196
+ # Build tokenizer
197
+ print('\nLoading HF tokenizer...')
198
+ tokenizer = AutoTokenizer.from_pretrained(args.name_or_path,
199
+ **from_pretrained_kwargs)
200
+ if tokenizer.pad_token_id is None:
201
+ warnings.warn(
202
+ 'pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.'
203
+ )
204
+ tokenizer.pad_token = tokenizer.eos_token
205
+ tokenizer.padding_side = 'left'
206
+
207
+ # Load HF Model
208
+ print(f'Loading HF model with dtype={model_dtype}...')
209
+ try:
210
+ model = AutoModelForCausalLM.from_pretrained(args.name_or_path,
211
+ config=config,
212
+ torch_dtype=model_dtype,
213
+ device_map=device_map,
214
+ **from_pretrained_kwargs)
215
+ model.eval()
216
+ print(f'n_params={sum(p.numel() for p in model.parameters())}')
217
+ if device is not None:
218
+ print(f'Placing model on {device=}...')
219
+ model.to(device)
220
+ except Exception as e:
221
+ raise RuntimeError(
222
+ 'Unable to load HF model. ' +
223
+ 'If you are having auth problems, try logging in via `huggingface-cli login` '
224
+ +
225
+ 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=...` '
226
+ +
227
+ 'using your access token from https://huggingface.co/settings/tokens.'
228
+ ) from e
229
+
230
+ # Autocast
231
+ if args.autocast_dtype is not None:
232
+ autocast_dtype = get_dtype(args.autocast_dtype)
233
+ autocast_context = torch.autocast(model.device.type, autocast_dtype)
234
+ print(f'Using autocast with dtype={autocast_dtype}...')
235
+ else:
236
+ autocast_context = nullcontext()
237
+ print('NOT using autocast...')
238
+
239
+ done_warmup = False
240
+
241
+ for temp, topp, topk, repp, nrnz, seed in itertools.product(
242
+ args.temperature, args.top_p, args.top_k, args.repetition_penalty,
243
+ args.no_repeat_ngram_size, args.seed):
244
+
245
+ # Seed randomness
246
+ random.seed(seed)
247
+ torch.manual_seed(seed)
248
+ print(f'\nGenerate seed:\n{seed}')
249
+
250
+ generate_kwargs = {
251
+ 'max_new_tokens': args.max_new_tokens,
252
+ 'temperature': temp,
253
+ 'top_p': topp,
254
+ 'top_k': topk,
255
+ 'repetition_penalty': repp,
256
+ 'no_repeat_ngram_size': nrnz,
257
+ 'use_cache': args.use_cache,
258
+ 'do_sample': False if temp == 0 else args.do_sample,
259
+ 'eos_token_id': args.eos_token_id or tokenizer.eos_token_id,
260
+ 'pad_token_id': args.pad_token_id or tokenizer.pad_token_id,
261
+ }
262
+ print(f'\nGenerate kwargs:\n{generate_kwargs}')
263
+
264
+ # Generate function with correct context managers
265
+ def _generate(encoded_inp: Dict[str, torch.Tensor]):
266
+ with torch.no_grad():
267
+ with autocast_context:
268
+ return model.generate(
269
+ input_ids=encoded_inp['input_ids'],
270
+ attention_mask=encoded_inp['attention_mask'],
271
+ **generate_kwargs,
272
+ )
273
+
274
+ # Split into prompt batches
275
+ batches = []
276
+ if args.max_batch_size:
277
+ bs = args.max_batch_size
278
+ batches = [
279
+ prompt_strings[i:i + bs]
280
+ for i in range(0, len(prompt_strings), bs)
281
+ ]
282
+
283
+ else:
284
+ batches = [prompt_strings]
285
+
286
+ for batch in batches:
287
+ print(f'\nTokenizing prompts...')
288
+ maybe_synchronize()
289
+ encode_start = time.time()
290
+ encoded_inp = tokenizer(batch, return_tensors='pt', padding=True)
291
+ for key, value in encoded_inp.items():
292
+ encoded_inp[key] = value.to(model.device)
293
+ maybe_synchronize()
294
+ encode_end = time.time()
295
+ input_tokens = torch.sum(
296
+ encoded_inp['input_ids'] !=
297
+ tokenizer.pad_token_id, # type: ignore
298
+ axis=1).numpy(force=True)
299
+
300
+ # Warmup
301
+ if args.warmup and (not done_warmup):
302
+ print('Warming up...')
303
+ _ = _generate(encoded_inp)
304
+ done_warmup = True
305
+
306
+ # Run HF generate
307
+ print('Generating responses...')
308
+ maybe_synchronize()
309
+ gen_start = time.time()
310
+ encoded_gen = _generate(encoded_inp)
311
+ maybe_synchronize()
312
+ gen_end = time.time()
313
+
314
+ decode_start = time.time()
315
+ decoded_gen = tokenizer.batch_decode(encoded_gen,
316
+ skip_special_tokens=True)
317
+ maybe_synchronize()
318
+ decode_end = time.time()
319
+ gen_tokens = torch.sum(encoded_gen != tokenizer.pad_token_id,
320
+ axis=1).numpy(force=True) # type: ignore
321
+
322
+ # Print generations
323
+ delimiter = '#' * 100
324
+ # decode the encoded prompt to handle the case when the tokenizer
325
+ # trims extra spaces or does other pre-tokenization things
326
+ effective_prompts = tokenizer.batch_decode(encoded_inp['input_ids'],
327
+ skip_special_tokens=True)
328
+ for idx, (effective_prompt, prompt, gen) in enumerate(
329
+ zip(effective_prompts, batch, decoded_gen)):
330
+ continuation = gen[len(effective_prompt):]
331
+ print(delimiter)
332
+ if len(continuation) > 0:
333
+ print('\033[92m' + prompt + '\033[0m' + continuation)
334
+ else:
335
+ print('Warning. No non-special output tokens generated.')
336
+ print(
337
+ 'This can happen if the generation only contains padding/eos tokens.'
338
+ )
339
+ print('Debug:')
340
+ full_generation = tokenizer.batch_decode(
341
+ encoded_gen, skip_special_tokens=False)[idx]
342
+ print('\033[92m' + 'Prompt:\n' + prompt + '\033[0m')
343
+ print('Full generation:\n' + full_generation)
344
+
345
+ print(delimiter)
346
+
347
+ # Print timing info
348
+ bs = len(batch)
349
+ # ensure that gen_tokens >= 1 in case model only generated padding tokens
350
+ gen_tokens = np.maximum(gen_tokens, np.ones_like(gen_tokens))
351
+ output_tokens = gen_tokens - input_tokens
352
+ total_input_tokens = input_tokens.sum()
353
+ total_output_tokens = output_tokens.sum()
354
+
355
+ encode_latency = 1000 * (encode_end - encode_start)
356
+ gen_latency = 1000 * (gen_end - gen_start)
357
+ decode_latency = 1000 * (decode_end - decode_start)
358
+ total_latency = encode_latency + gen_latency + decode_latency
359
+
360
+ latency_per_output_token = total_latency / total_output_tokens
361
+ output_tok_per_sec = 1000 / latency_per_output_token
362
+ print(f'{bs=}, {input_tokens=}, {output_tokens=}')
363
+ print(f'{total_input_tokens=}, {total_output_tokens=}')
364
+ print(
365
+ f'{encode_latency=:.2f}ms, {gen_latency=:.2f}ms, {decode_latency=:.2f}ms, {total_latency=:.2f}ms'
366
+ )
367
+ print(f'{latency_per_output_token=:.2f}ms/tok')
368
+ print(f'{output_tok_per_sec=:.2f}tok/sec')
369
+
370
+
371
+ if __name__ == '__main__':
372
+ main(parse_args())
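For reference, the batching and throughput bookkeeping in the script above reduces to a few lines. The prompt list, token counts, and latency below are made-up numbers used only to show the arithmetic, not measured results.

import numpy as np

# Prompts are chunked exactly as above: fixed-size slices when
# --max_batch_size is set, otherwise one batch with everything.
prompt_strings = ['My name is', 'Deep learning is', 'The weather today is'] * 3
max_batch_size = 4
batches = ([prompt_strings[i:i + max_batch_size]
            for i in range(0, len(prompt_strings), max_batch_size)]
           if max_batch_size else [prompt_strings])
print([len(b) for b in batches])  # -> [4, 4, 1]

# Throughput bookkeeping: latency per generated token and tokens/sec,
# computed from hypothetical token counts and a hypothetical latency.
input_tokens = np.array([5, 7, 6, 5])
gen_tokens = np.array([105, 107, 106, 105])  # prompt + generated tokens
total_output_tokens = (gen_tokens - input_tokens).sum()
total_latency_ms = 1850.0
latency_per_output_token = total_latency_ms / total_output_tokens
print(f'{latency_per_output_token:.2f} ms/tok, '
      f'{1000 / latency_per_output_token:.2f} tok/sec')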
Perceptrix/finetune/build/lib/inference/run_mpt_with_ft.py ADDED
@@ -0,0 +1,480 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
5
+ # Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """Run MPT model with FT.
20
+
21
+ This script is a modified version of
22
+ https://github.com/NVIDIA/FasterTransformer/blob/main/examples/pytorch/gpt/multi_gpu_gpt_example.py
23
+ """
24
+
25
+ import argparse
26
+ import configparser
27
+ import os
28
+ import sys
29
+ import timeit
30
+
31
+ import torch
32
+ from torch.nn.utils.rnn import pad_sequence
33
+ from transformers import AutoTokenizer
34
+
35
+ dir_path = os.path.dirname(os.path.realpath(__file__))
36
+ sys.path.append(os.path.join(dir_path, '../../..'))
37
+ from examples.pytorch.gpt.utils import comm, gpt_decoder
38
+ from examples.pytorch.gpt.utils.parallel_gpt import ParallelGPT
39
+
40
+
41
+ @torch.no_grad()
42
+ def main():
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument('--layer_num',
45
+ type=int,
46
+ default=32,
47
+ help='number of layers')
48
+ parser.add_argument('--input_len',
49
+ type=int,
50
+ default=128,
51
+ help='input sequence length to generate.')
52
+ parser.add_argument('--output_len',
53
+ type=int,
54
+ default=64,
55
+ help='output sequence length to generate.')
56
+ parser.add_argument('--head_num', type=int, default=32, help='head number')
57
+ parser.add_argument('--size_per_head',
58
+ type=int,
59
+ default=128,
60
+ help='size per head')
61
+ parser.add_argument('--vocab_size',
62
+ type=int,
63
+ default=50432,
64
+ help='vocab size')
65
+ parser.add_argument(
66
+ '--beam_width',
67
+ type=int,
68
+ default=1,
69
+ help='beam width for beam search. Using sampling when beam width is 1.')
70
+ parser.add_argument('--top_k',
71
+ type=int,
72
+ default=1,
73
+ help='top k candidate num')
74
+ parser.add_argument('--top_p',
75
+ type=float,
76
+ default=0.95,
77
+ help='top p probability threshold')
78
+ parser.add_argument('--temperature',
79
+ type=float,
80
+ default=0.8,
81
+ help='temperature')
82
+ parser.add_argument('--len_penalty',
83
+ type=float,
84
+ default=0.,
85
+ help='len_penalty')
86
+ parser.add_argument('--beam_search_diversity_rate',
87
+ type=float,
88
+ default=0.,
89
+ help='beam_search_diversity_rate')
90
+ parser.add_argument('--tensor_para_size',
91
+ type=int,
92
+ default=1,
93
+ help='tensor parallel size')
94
+ parser.add_argument('--pipeline_para_size',
95
+ type=int,
96
+ default=1,
97
+ help='pipeline parallel size')
98
+ parser.add_argument('--ckpt_path',
99
+ type=str,
100
+ default='mpt-ft-7b/1-gpu',
101
+ help='path to the FT checkpoint file.')
102
+ parser.add_argument(
103
+ '--tokenizer_name_or_path',
104
+ type=str,
105
+ default='EleutherAI/gpt-neox-20b',
106
+ help=
107
+ 'Name of the tokenizer or the directory where the tokenizer file is located.'
108
+ )
109
+ parser.add_argument(
110
+ '--lib_path',
111
+ type=str,
112
+ help=
113
+ 'path to the libth_transformer dynamic lib file (e.g., build/lib/libth_transformer.so).'
114
+ )
115
+ parser.add_argument('--start_id',
116
+ type=int,
117
+ default=0,
118
+ help='start token id.')
119
+ parser.add_argument('--end_id', type=int, default=0, help='end token id.')
120
+ parser.add_argument(
121
+ '--max_batch_size',
122
+ type=int,
123
+ default=8,
124
+ help=
125
+ 'Max batch size. If sample_input_file is given, it is truncated to this max_batch_size; otherwise, this value is used as the batch size.'
126
+ )
127
+ parser.add_argument('--repetition_penalty',
128
+ type=float,
129
+ default=5.,
130
+ help='repetition penalty')
131
+ parser.add_argument(
132
+ '--presence_penalty',
133
+ type=float,
134
+ default=0.,
135
+ help=
136
+ 'presence penalty. Similar to repetition, but additive rather than multiplicative.'
137
+ )
138
+ parser.add_argument('--min_length',
139
+ type=int,
140
+ default=0,
141
+ help='A minimum number of tokens to generate')
142
+ parser.add_argument(
143
+ '--max_seq_len',
144
+ type=int,
145
+ default=2048,
146
+ help='max sequence length for position embedding table.')
147
+ parser.add_argument('--inference_data_type',
148
+ '--data_type',
149
+ type=str,
150
+ choices=['fp32', 'fp16', 'bf16'],
151
+ default='bf16')
152
+ parser.add_argument('--time',
153
+ action='store_true',
154
+ help='whether or not to measure time elapsed.')
155
+ parser.add_argument(
156
+ '--sample_input_file',
157
+ type=str,
158
+ default=None,
159
+ help=
160
+ 'path to sample input file. If not set, it runs with no context inputs.'
161
+ )
162
+ parser.add_argument('--sample_output_file',
163
+ type=str,
164
+ default=None,
165
+ help='path to sample output file.')
166
+ parser.add_argument(
167
+ '--disable_random_seed',
168
+ dest='random_seed',
169
+ action='store_false',
170
+ help='Disable the use of random seed for sentences in a batch.')
171
+ parser.add_argument('--skip_end_tokens',
172
+ dest='skip_end_tokens',
173
+ action='store_false',
174
+ help='Whether or not to remove end tokens from outputs.')
175
+ parser.add_argument('--no_detokenize',
176
+ dest='detokenize',
177
+ action='store_false',
178
+ help='Skip detokenizing output token ids.')
179
+ parser.add_argument(
180
+ '--int8_mode',
181
+ type=int,
182
+ default=0,
183
+ choices=[0, 1],
184
+ help='The level of quantization to perform.' +
185
+ ' 0: No quantization. All computation in data_type' +
186
+ ' 1: Quantize weights to int8, all compute occurs in fp16/bf16. Not supported when data_type is fp32'
187
+ )
188
+ parser.add_argument(
189
+ '--weights_data_type',
190
+ type=str,
191
+ default='fp32',
192
+ choices=['fp32', 'fp16'],
193
+ help='Data type of FT checkpoint weights',
194
+ )
195
+ parser.add_argument(
196
+ '--return_cum_log_probs',
197
+ type=int,
198
+ default=0,
199
+ choices=[0, 1, 2],
200
+ help='Whether to compute the cumulative log probability of sentences.' +
201
+ ' 0: do not return the cumulative log probs' +
202
+ ' 1: return the cumulative log probs of generated sequences' +
203
+ ' 2: return the cumulative log probs of sequences')
204
+ parser.add_argument('--shared_contexts_ratio',
205
+ type=float,
206
+ default=0.0,
207
+ help='Triggers the shared context optimization when ' +
208
+ 'compact_size <= shared_contexts_ratio * batch_size. ' +
209
+ 'A value of 0.0 deactivates the optimization.')
210
+ parser.add_argument(
211
+ '--use_gpt_decoder_ops',
212
+ action='store_true',
213
+ help='Use separate decoder FT operators instead of end-to-end model op.'
214
+ )
215
+ parser.add_argument(
216
+ '--no-alibi',
217
+ dest='alibi',
218
+ action='store_false',
219
+ help='Do not use ALiBi (aka use_attention_linear_bias).')
220
+ parser.add_argument(
221
+ '--layernorm_eps',
222
+ type=float,
223
+ default=1e-5,
224
+ help='layernorm eps in PyTorch, by default, is 1e-5 and 1e-6 in FT.')
225
+ args = parser.parse_args()
226
+
227
+ ckpt_config = configparser.ConfigParser()
228
+ ckpt_config_path = os.path.join(args.ckpt_path, 'config.ini')
229
+ if os.path.isfile(ckpt_config_path):
230
+ ckpt_config.read(ckpt_config_path)
231
+ if 'gpt' in ckpt_config.keys():
232
+ for args_key, config_key, func in [
233
+ ('layer_num', 'num_layer', ckpt_config.getint),
234
+ ('max_seq_len', 'max_pos_seq_len', ckpt_config.getint),
235
+ ('weights_data_type', 'weight_data_type', ckpt_config.get),
236
+ ('layernorm_eps', 'layernorm_eps', ckpt_config.getfloat),
237
+ ('alibi', 'use_attention_linear_bias', ckpt_config.getboolean),
238
+ ]:
239
+ if config_key in ckpt_config['gpt'].keys():
240
+ prev_val = args.__dict__[args_key]
241
+ args.__dict__[args_key] = func('gpt', config_key)
242
+ print(
243
+ 'Loading {} from config.ini, previous: {}, current: {}'
244
+ .format(args_key, prev_val, args.__dict__[args_key]))
245
+ else:
246
+ print('Not loading {} from config.ini'.format(args_key))
247
+ for key in ['head_num', 'size_per_head', 'tensor_para_size']:
248
+ if key in args.__dict__:
249
+ prev_val = args.__dict__[key]
250
+ args.__dict__[key] = ckpt_config.getint('gpt', key)
251
+ print(
252
+ 'Loading {} from config.ini, previous: {}, current: {}'
253
+ .format(key, prev_val, args.__dict__[key]))
254
+ else:
255
+ print('Not loading {} from config.ini'.format(key))
256
+
257
+ layer_num = args.layer_num
258
+ output_len = args.output_len
259
+ head_num = args.head_num
260
+ size_per_head = args.size_per_head
261
+ vocab_size = args.vocab_size
262
+ beam_width = args.beam_width
263
+ top_k = args.top_k
264
+ top_p = args.top_p
265
+ temperature = args.temperature
266
+ len_penalty = args.len_penalty
267
+ beam_search_diversity_rate = args.beam_search_diversity_rate
268
+ tensor_para_size = args.tensor_para_size
269
+ pipeline_para_size = args.pipeline_para_size
270
+ start_id = args.start_id
271
+ end_id = args.end_id
272
+ max_batch_size = args.max_batch_size
273
+ max_seq_len = args.max_seq_len
274
+ repetition_penalty = args.repetition_penalty
275
+ presence_penalty = args.presence_penalty
276
+ min_length = args.min_length
277
+ weights_data_type = args.weights_data_type
278
+ return_cum_log_probs = args.return_cum_log_probs
279
+ return_output_length = return_cum_log_probs > 0
280
+ shared_contexts_ratio = args.shared_contexts_ratio
281
+ layernorm_eps = args.layernorm_eps
282
+ use_attention_linear_bias = args.alibi
283
+ has_positional_encoding = not args.alibi
284
+
285
+ print('\n=================== Arguments ===================')
286
+ for k, v in vars(args).items():
287
+ print(f'{k.ljust(30, ".")}: {v}')
288
+ print('=================================================\n')
289
+
290
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
291
+ torch.manual_seed(0)
292
+
293
+ comm.initialize_model_parallel(args.tensor_para_size,
294
+ args.pipeline_para_size)
295
+ rank = comm.get_rank()
296
+ device = comm.get_device()
297
+
298
+ # Inputs
299
+ contexts = []
300
+ if args.sample_input_file:
301
+ with open(args.sample_input_file, 'r') as f:
302
+ contexts = f.read().splitlines()
303
+ batch_size = min(len(contexts), max_batch_size)
304
+ contexts = contexts[:batch_size]
305
+ start_ids = [
306
+ torch.tensor(tokenizer.encode(c), dtype=torch.int32, device=device)
307
+ for c in contexts
308
+ ]
309
+ else:
310
+ batch_size = max_batch_size
311
+ contexts = ['<|endoftext|>'] * batch_size
312
+ start_ids = [torch.IntTensor([end_id for _ in range(args.input_len)])
313
+ ] * batch_size
314
+
315
+ start_lengths = [len(ids) for ids in start_ids]
316
+
317
+ start_ids = pad_sequence(start_ids, batch_first=True, padding_value=end_id)
318
+ start_lengths = torch.IntTensor(start_lengths)
319
+
320
+ # Prepare model.
321
+ if not args.use_gpt_decoder_ops:
322
+ gpt = ParallelGPT(head_num,
323
+ size_per_head,
324
+ vocab_size,
325
+ start_id,
326
+ end_id,
327
+ layer_num,
328
+ max_seq_len,
329
+ tensor_para_size,
330
+ pipeline_para_size,
331
+ lib_path=args.lib_path,
332
+ inference_data_type=args.inference_data_type,
333
+ int8_mode=args.int8_mode,
334
+ weights_data_type=weights_data_type,
335
+ layernorm_eps=layernorm_eps,
336
+ use_attention_linear_bias=use_attention_linear_bias,
337
+ has_positional_encoding=has_positional_encoding,
338
+ shared_contexts_ratio=shared_contexts_ratio)
339
+ if not gpt.load(ckpt_path=args.ckpt_path):
340
+ print(
341
+ '[WARNING] Checkpoint file not found. Model loading is skipped.'
342
+ )
343
+ else:
344
+ gpt = gpt_decoder.Gpt(num_heads=head_num,
345
+ size_per_head=size_per_head,
346
+ num_layers=layer_num,
347
+ vocab_size=vocab_size,
348
+ start_id=start_id,
349
+ end_id=end_id,
350
+ tensor_para_size=tensor_para_size,
351
+ pipeline_para_size=pipeline_para_size,
352
+ lib_path=args.lib_path,
353
+ max_seq_len=max_seq_len,
354
+ int8_mode=args.int8_mode,
355
+ weights_data_type=args.weights_data_type)
356
+ gpt.load(args.ckpt_path, args.inference_data_type)
357
+
358
+ if args.random_seed:
359
+ random_seed_tensor = torch.randint(0,
360
+ 10000,
361
+ size=[batch_size],
362
+ dtype=torch.int64)
363
+ else:
364
+ random_seed_tensor = torch.zeros([batch_size], dtype=torch.int64)
365
+
366
+ repetition_penalty_vec = None if repetition_penalty == 1. else repetition_penalty * torch.ones(
367
+ batch_size, dtype=torch.float32)
368
+ presence_penalty_vec = None if presence_penalty == 0. else presence_penalty * torch.ones(
369
+ batch_size, dtype=torch.float32)
370
+
371
+ infer_decode_args = {
372
+ 'beam_width':
373
+ beam_width,
374
+ 'top_k':
375
+ top_k * torch.ones(batch_size, dtype=torch.int32),
376
+ 'top_p':
377
+ top_p * torch.ones(batch_size, dtype=torch.float32),
378
+ 'temperature':
379
+ temperature * torch.ones(batch_size, dtype=torch.float32),
380
+ 'repetition_penalty':
381
+ repetition_penalty_vec,
382
+ 'presence_penalty':
383
+ presence_penalty_vec,
384
+ 'beam_search_diversity_rate':
385
+ beam_search_diversity_rate *
386
+ torch.ones(batch_size, dtype=torch.float32),
387
+ 'len_penalty':
388
+ len_penalty * torch.ones(size=[batch_size], dtype=torch.float32),
389
+ 'bad_words_list':
390
+ None,
391
+ 'min_length':
392
+ min_length * torch.ones(size=[batch_size], dtype=torch.int32),
393
+ 'random_seed':
394
+ random_seed_tensor
395
+ }
396
+
397
+ if not args.use_gpt_decoder_ops:
398
+
399
+ def gpt_generate_fn():
400
+ tokens_batch = gpt(start_ids,
401
+ start_lengths,
402
+ output_len,
403
+ return_output_length=return_output_length,
404
+ return_cum_log_probs=return_cum_log_probs,
405
+ **infer_decode_args)
406
+ return tokens_batch
407
+ else:
408
+
409
+ def gpt_generate_fn():
410
+ output_dict = gpt.generate(
411
+ input_token_ids=start_ids,
412
+ input_lengths=start_lengths,
413
+ gen_length=output_len,
414
+ eos_token_id=end_id,
415
+ return_output_length=return_output_length,
416
+ return_log_probs=return_cum_log_probs,
417
+ **infer_decode_args)
418
+ return output_dict
419
+
420
+ # Generate tokens.
421
+ gen_outputs = gpt_generate_fn()
422
+
423
+ if rank == 0:
424
+ if not args.use_gpt_decoder_ops:
425
+ if return_cum_log_probs > 0:
426
+ tokens_batch, _, cum_log_probs = gen_outputs
427
+ else:
428
+ tokens_batch, cum_log_probs = gen_outputs, None
429
+ else:
430
+ tokens_batch = gen_outputs['output_token_ids']
431
+ cum_log_probs = gen_outputs[
432
+ 'cum_log_probs'] if return_cum_log_probs > 0 else None
433
+ if cum_log_probs is not None:
434
+ print('[INFO] Log probs of sentences:', cum_log_probs)
435
+
436
+ outputs = []
437
+ tokens_batch = tokens_batch.cpu().numpy()
438
+ for i, (context, tokens) in enumerate(zip(contexts, tokens_batch)):
439
+ for beam_id in range(beam_width):
440
+ token = tokens[beam_id][
441
+ start_lengths[i]:] # exclude context input from the output
442
+ if args.skip_end_tokens:
443
+ token = token[token != end_id]
444
+ output = tokenizer.decode(
445
+ token) if args.detokenize else ' '.join(
446
+ str(t) for t in token.tolist())
447
+ outputs.append(output)
448
+ print(
449
+ f'[INFO] batch {i}, beam {beam_id}:\n[Context]\n{context}\n\n[Output]\n{output}\n'
450
+ )
451
+
452
+ if args.sample_output_file:
453
+ with open(args.sample_output_file, 'w+') as f:
454
+ outputs = [o.replace('\n', '\\n') for o in outputs]
455
+ f.writelines('\n'.join(outputs))
456
+
457
+ # Measure inference time.
458
+ if args.time:
459
+ warmup_iterations = 10
460
+ for _ in range(warmup_iterations):
461
+ gpt_generate_fn()
462
+ torch.cuda.synchronize()
463
+ measurement_iterations = 10
464
+ time = timeit.default_timer()
465
+ for _ in range(measurement_iterations):
466
+ gpt_generate_fn()
467
+ torch.cuda.synchronize()
468
+ time_elapsed = timeit.default_timer() - time
469
+ if rank == 0:
470
+ print(f'[INFO] MPT time costs:')
471
+ print(
472
+ 'model_name, gpu_type, gpu_count, batch_size, input_tokens, output_tokens, latency_ms'
473
+ )
474
+ print(
475
+ f'{ckpt_config.get("gpt", "model_name")}, {torch.cuda.get_device_name().replace(" ", "-")}, {torch.cuda.device_count()}, {batch_size}, {args.input_len}, {args.output_len}, {time_elapsed * 1000 / measurement_iterations:.2f}'
476
+ )
477
+
478
+
479
+ if __name__ == '__main__':
480
+ main()
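The input preparation above can be seen in isolation: variable-length token id tensors are right-padded into one batch with pad_sequence, and the sampling knobs are expanded into per-sentence 1-D tensors as FasterTransformer expects. The ids and settings below are toy values, not anything read from a real checkpoint.

import torch
from torch.nn.utils.rnn import pad_sequence

end_id = 0  # the EOS/pad id the script reads from --end_id
# Variable-length contexts, as produced by tokenizer.encode per prompt.
start_ids = [torch.tensor([12, 54, 7], dtype=torch.int32),
             torch.tensor([9, 2], dtype=torch.int32)]
start_lengths = torch.IntTensor([len(ids) for ids in start_ids])

# Right-pad into one rectangular batch, exactly as the script does.
start_ids = pad_sequence(start_ids, batch_first=True, padding_value=end_id)
print(start_ids)       # [[12, 54, 7], [9, 2, 0]]
print(start_lengths)   # [3, 2]

# FasterTransformer takes per-sentence sampling parameters as 1-D tensors.
batch_size = start_ids.shape[0]
top_k = 1 * torch.ones(batch_size, dtype=torch.int32)
top_p = 0.95 * torch.ones(batch_size, dtype=torch.float32)
temperature = 0.8 * torch.ones(batch_size, dtype=torch.float32)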
Perceptrix/finetune/build/lib/llmfoundry/__init__.py ADDED
@@ -0,0 +1,71 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+ try:
7
+ from llmfoundry import optim, utils
8
+ from llmfoundry.data import (ConcatTokensDataset,
9
+ MixtureOfDenoisersCollator, NoConcatDataset,
10
+ Seq2SeqFinetuningCollator,
11
+ build_finetuning_dataloader,
12
+ build_text_denoising_dataloader)
13
+ from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
14
+ ComposerHFT5)
15
+ from llmfoundry.models.layers.attention import (
16
+ MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
17
+ flash_attn_fn, scaled_multihead_dot_product_attention,
18
+ triton_flash_attn_fn)
19
+ from llmfoundry.models.layers.blocks import MPTBlock
20
+ from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
21
+ build_ffn)
22
+ from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
23
+ from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
24
+ MPTForCausalLM, MPTModel,
25
+ MPTPreTrainedModel)
26
+ from llmfoundry.tokenizers import TiktokenTokenizerWrapper
27
+
28
+ except ImportError as e:
29
+ try:
30
+ is_cuda_available = torch.cuda.is_available()
31
+ except:
32
+ is_cuda_available = False
33
+
34
+ extras = '.[gpu]' if is_cuda_available else '.'
35
+ raise ImportError(
36
+ f'Please make sure to pip install {extras} to get the requirements for the LLM example.'
37
+ ) from e
38
+
39
+ __all__ = [
40
+ 'build_text_denoising_dataloader',
41
+ 'build_finetuning_dataloader',
42
+ 'MixtureOfDenoisersCollator',
43
+ 'Seq2SeqFinetuningCollator',
44
+ 'MPTBlock',
45
+ 'FFN_CLASS_REGISTRY',
46
+ 'MPTMLP',
47
+ 'build_ffn',
48
+ 'MPTConfig',
49
+ 'MPTPreTrainedModel',
50
+ 'MPTModel',
51
+ 'MPTForCausalLM',
52
+ 'ComposerMPTCausalLM',
53
+ 'ComposerHFCausalLM',
54
+ 'ComposerHFPrefixLM',
55
+ 'ComposerHFT5',
56
+ 'COMPOSER_MODEL_REGISTRY',
57
+ 'scaled_multihead_dot_product_attention',
58
+ 'flash_attn_fn',
59
+ 'triton_flash_attn_fn',
60
+ 'MultiheadAttention',
61
+ 'NoConcatDataset',
62
+ 'ConcatTokensDataset',
63
+ 'attn_bias_shape',
64
+ 'build_attn_bias',
65
+ 'build_alibi_bias',
66
+ 'optim',
67
+ 'utils',
68
+ 'TiktokenTokenizerWrapper',
69
+ ]
70
+
71
+ __version__ = '0.3.0'
Perceptrix/finetune/build/lib/llmfoundry/callbacks/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ try:
5
+ from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
6
+ from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
7
+ from llmfoundry.callbacks.generate_callback import Generate
8
+ from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
9
+ from llmfoundry.callbacks.model_gauntlet_callback import ModelGauntlet
10
+ from llmfoundry.callbacks.monolithic_ckpt_callback import \
11
+ MonolithicCheckpointSaver
12
+ from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
13
+ LayerFreezing)
14
+ from llmfoundry.callbacks.scheduled_gc_callback import \
15
+ ScheduledGarbageCollector
16
+ except ImportError as e:
17
+ raise ImportError(
18
+ 'Please make sure to pip install . to get requirements for llm-foundry.'
19
+ ) from e
20
+
21
+ __all__ = [
22
+ 'FDiffMetrics',
23
+ 'Generate',
24
+ 'MonolithicCheckpointSaver',
25
+ 'GlobalLRScaling',
26
+ 'LayerFreezing',
27
+ 'ScheduledGarbageCollector',
28
+ 'EvalGauntlet',
29
+ 'ModelGauntlet',
30
+ 'HuggingFaceCheckpointer',
31
+ ]
Perceptrix/finetune/build/lib/llmfoundry/callbacks/eval_gauntlet_callback.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Aggregate ICL evals into composite scores."""
5
+
6
+ import logging
7
+ import math
8
+ from enum import Enum
9
+ from typing import Dict, Optional
10
+
11
+ from composer.core import Callback, State
12
+ from composer.loggers import Logger
13
+
14
+ __all__ = ['EvalGauntlet']
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+
19
+ class Weighting(Enum):
20
+ EQUAL = 1
21
+ SAMPLE_SZ = 2
22
+ LOG_SAMPLE_SZ = 3
23
+
24
+
25
+ class EvalGauntlet(Callback):
26
+ """The EvalGauntlet aggregates ICL eval results.
27
+
28
+ After `eval_end`, this callback inspects the logger for different ICL metrics and aggregates the scores according to the aggregation
29
+ specification provided in the constructor.
30
+
31
+ Args:
32
+ logger_keys (list): These are the exact keys that the individual benchmark metrics will be
33
+ logged under in the logger after eval
34
+ categories (dict): This contains the list of categories, as well as the subtasks within them, the
35
+ random baseline accuracy of each subtask, and the number of fewshot examples
36
+ used for the task. See `llmfoundry/scripts/eval/yamls/eval_gauntlet.yaml` to see the structure.
37
+ weighting (Weighting): The weighting scheme used to balance different tasks within each category.
38
+ Either assign them all equal weight, assign them weight proportional
39
+ to the dataset size, or assign them weight proportional to the log2 of the dataset size.
40
+ Options are 'EQUAL', 'SAMPLE_SZ', and 'LOG_SAMPLE_SZ'.
41
+ subtract_random_baseline (bool): Flag determining whether to subtract random baseline accuracy
42
+ from the performance on each individual benchmark before aggregating.
43
+ rescale_accuracy (bool): Flag determining whether to rescale the accuracy on each benchmark
44
+ by (1-random_baseline_accuracy) before aggregating. Using this ensures that all benchmarks max out at 1.0.
45
+ benchmark_sizes (Optional[dict]): Optional data on benchmark sizes, used when not relying on equal weighting.
46
+ """
47
+
48
+ def __init__(self,
49
+ logger_keys: list,
50
+ categories: dict,
51
+ weighting: str = 'EQUAL',
52
+ subtract_random_baseline: bool = True,
53
+ rescale_accuracy: bool = True,
54
+ benchmark_sizes: Optional[dict] = None):
55
+ if isinstance(logger_keys, dict):
56
+ raise ValueError(
57
+ 'logger_keys now requires a list type as input, not a dict')
58
+ if weighting != Weighting.EQUAL and benchmark_sizes is None:
59
+ raise Exception(
60
+ 'When not using equal weighting, you must provide the benchmark sizes.'
61
+ )
62
+
63
+ if rescale_accuracy and not subtract_random_baseline:
64
+ raise Exception(
65
+ 'Only use accuracy rescaling in conjunction with subtracting random baseline accuracy.'
66
+ )
67
+
68
+ self.categories = categories
69
+ self.weighting = Weighting[weighting]
70
+ self.subtract_random_baseline = subtract_random_baseline
71
+ self.rescale_accuracy = rescale_accuracy
72
+ self.logger_keys = logger_keys
73
+
74
+ for category in self.categories:
75
+
76
+ for benchmark in category['benchmarks']:
77
+ bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
78
+
79
+ if self.weighting != Weighting.EQUAL:
80
+ assert benchmark_sizes is not None
81
+ cumulative_samples = max(
82
+ sum(count for name, count in benchmark_sizes.items()
83
+ if name.startswith(bench_name)), 1)
84
+ else:
85
+ cumulative_samples = -1 # pyright
86
+
87
+ weight = None
88
+ if self.weighting == Weighting.EQUAL:
89
+ weight = 1
90
+ elif self.weighting == Weighting.SAMPLE_SZ:
91
+ weight = cumulative_samples
92
+ elif self.weighting == Weighting.LOG_SAMPLE_SZ:
93
+ weight = max(math.log(cumulative_samples, 2), 1)
94
+
95
+ assert weight is not None
96
+ benchmark['weighting'] = weight
97
+
98
+ def compute_averages(self, state: State) -> Dict[str, float]:
99
+ results = {}
100
+
101
+ for key in self.logger_keys:
102
+
103
+ # starting at index 1 skips the "metric" part of the key which is superfluous
104
+ dl_name, metric_name = key.split('/')[1:-1], key.split('/')[-1]
105
+ if 'Accuracy' not in metric_name:
106
+ continue
107
+
108
+ metric = state.eval_metrics.get('/'.join(dl_name),
109
+ {}).get(metric_name, None)
110
+ if metric is None:
111
+ continue
112
+ val = metric.compute().item()
113
+
114
+ # ending at index 2 allows us to aggregate over dataloaders w/ subcategories
115
+ key = '/'.join(dl_name[0:2])
116
+ if key not in results:
117
+ results[key] = []
118
+
119
+ results[key].append(val)
120
+
121
+ return {k: sum(v) / len(v) for k, v in results.items()}
122
+
123
+ def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
124
+ new_metrics = self.compute_averages(state)
125
+ if len(new_metrics) == 0:
126
+ return {}
127
+ composite_scores = {}
128
+
129
+ for category in self.categories:
130
+ missing_metrics = []
131
+ composite_scores[category['name']] = []
132
+ for benchmark in category['benchmarks']:
133
+ key = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
134
+
135
+ if key not in new_metrics:
136
+ log.warning(
137
+ f'Could not find results for benchmark: {benchmark}.')
138
+ missing_metrics.append(key)
139
+ else:
140
+ score = new_metrics[key]
141
+
142
+ if self.subtract_random_baseline:
143
+ score -= benchmark['random_baseline']
144
+
145
+ if self.rescale_accuracy and self.subtract_random_baseline:
146
+ score /= 1.0 - benchmark['random_baseline']
147
+
148
+ composite_scores[category['name']].append({
149
+ 'name': benchmark['name'],
150
+ 'score': score,
151
+ 'weighting': benchmark['weighting']
152
+ })
153
+
154
+ if len(missing_metrics) > 0:
155
+ log.warning(
156
+ f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}"
157
+ )
158
+ del composite_scores[category['name']]
159
+ continue
160
+ total_weight = sum(
161
+ k['weighting'] for k in composite_scores[category['name']])
162
+ composite_scores[category['name']] = sum(
163
+ k['score'] * (k['weighting'] / total_weight)
164
+ for k in composite_scores[category['name']])
165
+
166
+ composite_scores = {
167
+ f'icl/metrics/eval_gauntlet/{k}': v
168
+ for k, v in composite_scores.items()
169
+ }
170
+
171
+ composite_scores['icl/metrics/eval_gauntlet/average'] = sum(
172
+ composite_scores.values()) / len(composite_scores.values()) if len(
173
+ composite_scores.values()) > 0 else 0
174
+ if logger is not None:
175
+ logger.log_metrics(composite_scores)
176
+
177
+ return composite_scores
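A standalone sketch of the aggregation performed in eval_after_all: subtract the random baseline, rescale so a perfect benchmark scores 1.0, then take a weighted average (LOG_SAMPLE_SZ-style weights here). The two benchmarks and all of their numbers are invented purely for illustration.

import math

# Two invented benchmarks in one category, with LOG_SAMPLE_SZ-style weights.
benchmarks = [
    {'name': 'task_a', 'accuracy': 0.55, 'random_baseline': 0.25, 'samples': 1024},
    {'name': 'task_b', 'accuracy': 0.80, 'random_baseline': 0.50, 'samples': 256},
]
scored = []
for b in benchmarks:
    # Subtract the random baseline, then rescale so a perfect score is 1.0.
    score = (b['accuracy'] - b['random_baseline']) / (1.0 - b['random_baseline'])
    weight = max(math.log2(b['samples']), 1)
    scored.append({'score': score, 'weighting': weight})

# Weighted average, matching the per-category aggregation above.
total_weight = sum(s['weighting'] for s in scored)
composite = sum(s['score'] * (s['weighting'] / total_weight) for s in scored)
print(f'{composite:.4f}')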
Perceptrix/finetune/build/lib/llmfoundry/callbacks/fdiff_callback.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Monitor rate of change of loss."""
5
+ from __future__ import annotations
6
+
7
+ import torch
8
+ from composer.core import Callback, State
9
+ from composer.loggers import Logger
10
+
11
+
12
+ class FDiffMetrics(Callback):
13
+ """Rate of change of metrics.
14
+
15
+ Tracks and plots the rate of change of metrics, effectively taking the
16
+ numerical derivative of the metrics.
17
+ """
18
+
19
+ def __init__(self,
20
+ diff_train_metrics: bool = False,
21
+ diff_eval_metrics: bool = True):
22
+ self.diff_train_metrics = diff_train_metrics
23
+ self.diff_eval_metrics = diff_eval_metrics
24
+
25
+ self.train_prev_loss = None
26
+ self.train_prev_metric = {}
27
+ self.eval_prev_metric = {}
28
+
29
+ def batch_end(self, state: State, logger: Logger) -> None:
30
+ if self.diff_train_metrics:
31
+ if not isinstance(state.loss, torch.Tensor):
32
+ raise NotImplementedError('Multiple losses not supported yet')
33
+ loss = state.loss.item()
34
+ if self.train_prev_loss:
35
+ logger.log_metrics(
36
+ {'loss/train/total_fdiff': loss - self.train_prev_loss})
37
+ self.train_prev_loss = loss
38
+
39
+ for k in self.train_prev_metric.keys():
40
+ logger.log_metrics({
41
+ f'metrics/train/{k}_fdiff':
42
+ state.train_metric_values[k] - self.train_prev_metric[k]
43
+ })
44
+
45
+ for k in state.train_metric_values.keys():
46
+ value = state.train_metric_values[k]
47
+ self.train_prev_metric[k] = value
48
+
49
+ def eval_end(self, state: State, logger: Logger) -> None:
50
+ if self.diff_eval_metrics:
51
+ evaluator = state.dataloader_label
52
+ assert evaluator is not None, 'dataloader should have been set'
53
+
54
+ metrics = list(state.eval_metrics[evaluator].keys())
55
+
56
+ for k in metrics:
57
+ mkey = '/'.join(['metrics', evaluator, k])
58
+ if mkey in self.eval_prev_metric.keys():
59
+ logger.log_metrics({
60
+ f'{mkey}_fdiff':
61
+ state.eval_metric_values[k] -
62
+ self.eval_prev_metric[mkey]
63
+ })
64
+
65
+ for k in metrics:
66
+ mkey = '/'.join(['metrics', evaluator, k])
67
+ self.eval_prev_metric[mkey] = state.eval_metric_values[k]
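The quantity this callback logs is just a first difference between consecutive metric values; with a toy loss curve (the numbers are made up):

# A logged metric and its first differences, which is all the callback emits.
losses = [2.31, 2.10, 1.98, 1.93]
fdiffs = [round(curr - prev, 4) for prev, curr in zip(losses, losses[1:])]
print(fdiffs)  # [-0.21, -0.12, -0.05]: negative values mean the loss is falling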
Perceptrix/finetune/build/lib/llmfoundry/callbacks/generate_callback.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Deprecated Generate callback.
5
+
6
+ Please use composer.callbacks.Generate instead.
7
+ """
8
+ import warnings
9
+ from typing import Any, List, Union
10
+
11
+ from composer.callbacks import Generate as ComposerGenerate
12
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
13
+
14
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
15
+
16
+
17
+ class Generate(ComposerGenerate):
18
+
19
+ def __init__(self, prompts: List[str], batch_log_interval: int,
20
+ **kwargs: Any):
21
+
22
+ warnings.warn(
23
+ ('Accessing llmfoundry.callbacks.generate_callback.Generate '
24
+ 'is deprecated and will be removed in a future release. '
25
+ 'Please use composer.callbacks.Generate instead.'),
26
+ DeprecationWarning,
27
+ )
28
+
29
+ interval = f'{batch_log_interval}ba'
30
+ super().__init__(prompts=prompts, interval=interval, **kwargs)
Perceptrix/finetune/build/lib/llmfoundry/callbacks/hf_checkpointer.py ADDED
@@ -0,0 +1,167 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import contextlib
5
+ import json
6
+ import logging
7
+ import os
8
+ import tempfile
9
+ from pathlib import Path
10
+ from typing import Optional, Union
11
+
12
+ import torch
13
+ from composer.callbacks.utils import create_interval_scheduler
14
+ from composer.core import Callback, Event, State, Time
15
+ from composer.core.state import fsdp_state_dict_type_context
16
+ from composer.loggers import Logger
17
+ from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
18
+ from composer.models import HuggingFaceModel
19
+ from composer.utils import dist, format_name_with_dist_and_time, parse_uri
20
+ from transformers import PreTrainedTokenizerBase
21
+
22
+ from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
23
+ from llmfoundry.utils.huggingface_hub_utils import \
24
+ edit_files_for_hf_compatibility
25
+
26
+ log = logging.getLogger(__name__)
27
+
28
+
29
+ class HuggingFaceCheckpointer(Callback):
30
+ """Save a huggingface formatted checkpoint during training.
31
+
32
+ Args:
33
+ save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that
34
+ this would be the same as your save_folder.
35
+ save_interval: Union[str, int, Time]: The interval describing how often checkpoints should be
36
+ saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
37
+ Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
38
+ :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
39
+ huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
40
+ precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
41
+ overwrite (bool): Whether to overwrite previous checkpoints.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ save_folder: str,
47
+ save_interval: Union[str, int, Time],
48
+ huggingface_folder_name: str = 'ba{batch}',
49
+ precision: str = 'float32',
50
+ overwrite: bool = False,
51
+ ):
52
+ self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
53
+ save_folder)
54
+ self.overwrite = overwrite
55
+ self.precision = precision
56
+ self.dtype = {
57
+ 'float32': torch.float32,
58
+ 'float16': torch.float16,
59
+ 'bfloat16': torch.bfloat16,
60
+ }[precision]
61
+ self.huggingface_folder_name_fstr = os.path.join(
62
+ 'huggingface', huggingface_folder_name)
63
+ self.check_interval = create_interval_scheduler(
64
+ save_interval, include_end_of_training=True)
65
+ self.upload_to_object_store = (self.backend != '')
66
+ if self.upload_to_object_store:
67
+ self.remote_ud = RemoteUploaderDownloader(
68
+ bucket_uri=f'{self.backend}://{self.bucket_name}',
69
+ num_concurrent_uploads=4)
70
+ else:
71
+ self.remote_ud = None
72
+
73
+ self.last_checkpoint_batch: Optional[Time] = None
74
+
75
+ def run_event(self, event: Event, state: State, logger: Logger) -> None:
76
+ # The interval scheduler handles only returning True for the appropriate events
77
+ if state.get_elapsed_duration() is not None and self.check_interval(
78
+ state,
79
+ event) and self.last_checkpoint_batch != state.timestamp.batch:
80
+ self._save_checkpoint(state, logger)
81
+ elif event == Event.INIT:
82
+ if not isinstance(state.model, HuggingFaceModel):
83
+ raise ValueError(
84
+ f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
85
+ + f'Got {type(state.model)} instead.')
86
+ if self.upload_to_object_store and self.remote_ud is not None:
87
+ self.remote_ud.init(state, logger)
88
+ state.callbacks.append(self.remote_ud)
89
+
90
+ def _save_checkpoint(self, state: State, logger: Logger):
91
+ del logger # unused
92
+
93
+ self.last_checkpoint_batch = state.timestamp.batch
94
+
95
+ log.info('Saving HuggingFace formatted checkpoint')
96
+
97
+ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
98
+ CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
99
+ MPTConfig.register_for_auto_class()
100
+ MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
101
+
102
+ assert isinstance(state.model, HuggingFaceModel)
103
+
104
+ save_dir = format_name_with_dist_and_time(
105
+ str(
106
+ Path(self.save_dir_format_str) /
107
+ self.huggingface_folder_name_fstr), state.run_name,
108
+ state.timestamp)
109
+ dir_context_mgr = tempfile.TemporaryDirectory(
110
+ ) if self.upload_to_object_store else contextlib.nullcontext(
111
+ enter_result=save_dir)
112
+
113
+ with dir_context_mgr as temp_save_dir:
114
+ assert isinstance(temp_save_dir,
115
+ str) # pyright doesn't know about enter_result
116
+
117
+ with fsdp_state_dict_type_context(state.model.model,
118
+ state_dict_type='full'):
119
+ state_dict = state.model.model.state_dict()
120
+
121
+ # convert the state dict to the requested precision
122
+ for k, v in state_dict.items():
123
+ if isinstance(v, torch.Tensor):
124
+ state_dict[k] = v.to(dtype=self.dtype)
125
+
126
+ if dist.get_global_rank() == 0:
127
+ # We raise above if the model is not a HuggingFaceModel, so this assert is safe
128
+ assert hasattr(state.model.model, 'save_pretrained')
129
+ state.model.model.save_pretrained(temp_save_dir,
130
+ state_dict=state_dict)
131
+
132
+ if state.model.tokenizer is not None:
133
+ assert isinstance(state.model.tokenizer,
134
+ PreTrainedTokenizerBase)
135
+ state.model.tokenizer.save_pretrained(temp_save_dir)
136
+
137
+ # Only need to edit files for MPT because it has custom code
138
+ if state.model.model.config.model_type == 'mpt':
139
+ edit_files_for_hf_compatibility(temp_save_dir)
140
+
141
+ with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
142
+ edited_config = json.load(f)
143
+
144
+ if state.model.model.config.model_type == 'mpt':
145
+ edited_config['attn_config']['attn_impl'] = 'torch'
146
+ edited_config['init_device'] = 'cpu'
147
+
148
+ edited_config['torch_dtype'] = self.precision
149
+ with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
150
+ json.dump(edited_config, f, indent=4)
151
+
152
+ if self.upload_to_object_store:
153
+ assert self.remote_ud is not None
154
+ # TODO change to log after other pr
155
+ log.info(
156
+ f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
157
+ )
158
+ for filename in os.listdir(temp_save_dir):
159
+ self.remote_ud.upload_file(
160
+ state=state,
161
+ remote_file_name=os.path.join(save_dir, filename),
162
+ file_path=Path(os.path.join(temp_save_dir,
163
+ filename)),
164
+ overwrite=self.overwrite,
165
+ )
166
+
167
+ dist.barrier()
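The precision handling in _save_checkpoint boils down to casting every tensor in the state dict before save_pretrained is called. A self-contained version with a toy module and an assumed bfloat16 target precision:

import torch
import torch.nn as nn

# Toy module standing in for state.model.model; `precision` would be the
# callback argument ('bfloat16' in this sketch).
model = nn.Linear(4, 4)
state_dict = model.state_dict()
save_dtype = torch.bfloat16

# Cast every tensor to the save precision, as done before save_pretrained().
for k, v in state_dict.items():
    if isinstance(v, torch.Tensor):
        state_dict[k] = v.to(dtype=save_dtype)

print({k: v.dtype for k, v in state_dict.items()})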
Perceptrix/finetune/build/lib/llmfoundry/callbacks/model_gauntlet_callback.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from composer.core import Callback
5
+
6
+ __all__ = ['ModelGauntlet']
7
+
8
+
9
+ class ModelGauntlet(Callback):
10
+ """The ModelGauntlet callback has been renamed to EvalGauntlet.
11
+
12
+ We've created this dummy class, in order to alert anyone who may have been
13
+ importing ModelGauntlet.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ *args, # pyright: ignore [reportMissingParameterType]
19
+ **kwargs): # pyright: ignore [reportMissingParameterType]
20
+ raise ImportError(
21
+ 'ModelGauntlet class is deprecated, please use EvalGauntlet')
Perceptrix/finetune/build/lib/llmfoundry/callbacks/monolithic_ckpt_callback.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import contextlib
5
+ import os
6
+ import tempfile
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ from composer.core import Callback, State
11
+ from composer.core.state import (fsdp_get_optim_state_dict,
12
+ fsdp_state_dict_type_context)
13
+ from composer.loggers import Logger
14
+ from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
15
+ from composer.utils import (dist, format_name_with_dist_and_time, parse_uri,
16
+ reproducibility)
17
+
18
+
19
+ class MonolithicCheckpointSaver(Callback):
20
+ """Save a monolithic checkpoint every N batches.
21
+
22
+ Args:
23
+ save_folder (str): Folder to save checkpoints to (can be a URI)
24
+ filename (str): Filename to save checkpoints to.
25
+ batch_interval (int): Number of batches between checkpoints.
26
+ overwrite (bool): Whether to overwrite previous checkpoints.
27
+ keep_optimizers (bool): Whether to save the optimizer state in the monolithic checkpoint.
28
+ """
29
+
30
+ def __init__(self,
31
+ save_folder: str,
32
+ batch_interval: int,
33
+ filename: str = 'ep{epoch}-ba{batch}.pt',
34
+ overwrite: bool = False,
35
+ keep_optimizers: bool = False):
36
+ self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
37
+ save_folder)
38
+ self.filename_format_str = filename
39
+ self.batch_interval = batch_interval
40
+ self.upload_to_object_store = (self.backend != '')
41
+ self.overwrite = overwrite
42
+ self.keep_optimizers = keep_optimizers
43
+ if self.upload_to_object_store:
44
+ self.remote_ud = RemoteUploaderDownloader(
45
+ bucket_uri=f'{self.backend}://{self.bucket_name}')
46
+ else:
47
+ self.remote_ud = None
48
+
49
+ def init(self, state: State, logger: Logger) -> None:
50
+ if self.upload_to_object_store and self.remote_ud is not None:
51
+ self.remote_ud.init(state, logger)
52
+ # updated_logger_destinations = [*logger.destinations, new_remote_ud]
53
+ # logger.destinations = tuple(updated_logger_destinations)
54
+ state.callbacks.append(self.remote_ud)
55
+
56
+ def batch_checkpoint(self, state: State, logger: Logger) -> None:
57
+ if state.timestamp.batch.value % self.batch_interval == 0:
58
+ self._save_checkpoint(state, logger)
59
+
60
+ def fit_end(self, state: State, logger: Logger) -> None:
61
+ if state.timestamp.batch.value % self.batch_interval != 0:
62
+ self._save_checkpoint(state, logger)
63
+
64
+ def _save_checkpoint(self, state: State, logger: Logger) -> None:
65
+ del logger # unused
66
+
67
+ filename = format_name_with_dist_and_time(self.filename_format_str,
68
+ state.run_name,
69
+ state.timestamp)
70
+ save_dir = format_name_with_dist_and_time(self.save_dir_format_str,
71
+ state.run_name,
72
+ state.timestamp)
73
+ dir_context_mgr = tempfile.TemporaryDirectory(
74
+ ) if self.upload_to_object_store else contextlib.nullcontext(
75
+ enter_result=save_dir)
76
+ with dir_context_mgr as temp_save_dir:
77
+ # pyright doesn't know about enter_result
78
+ assert isinstance(temp_save_dir, str)
79
+
80
+ save_path = str(Path(temp_save_dir) / Path(filename))
81
+ dirname = os.path.dirname(save_path)
82
+ if dirname:
83
+ os.makedirs(dirname, exist_ok=True)
84
+ state_dict = {
85
+ 'state': state.state_dict(),
86
+ 'rng': reproducibility.get_rng_state()
87
+ }
88
+ # Remove sharded model and optimizer state dicts
89
+ state_dict['state'].pop('optimizers')
90
+ state_dict['state'].pop('model')
91
+
92
+ # Add in unsharded model params.
93
+ with fsdp_state_dict_type_context(state.model,
94
+ state_dict_type='full'):
95
+ state_dict['state']['model'] = state.model.state_dict()
96
+
97
+ # Add in unsharded optimizer state dict.
98
+ if self.keep_optimizers:
99
+ optimizer = state.optimizers[0]
100
+ state_dict['state']['optimizers'] = {
101
+ type(optimizer).__qualname__:
102
+ fsdp_get_optim_state_dict(state.model,
103
+ optimizer,
104
+ state_dict_type='full')
105
+ }
106
+ if dist.get_global_rank() == 0:
107
+ torch.save(state_dict, save_path)
108
+
109
+ if self.upload_to_object_store and self.remote_ud is not None and dist.get_global_rank(
110
+ ) == 0:
111
+ remote_file_name = str(Path(save_dir) / Path(filename))
112
+ self.remote_ud.upload_file(state=state,
113
+ remote_file_name=remote_file_name,
114
+ file_path=Path(save_path),
115
+ overwrite=self.overwrite)
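The save schedule above is a simple modulus check on the batch counter, plus a catch-up save at fit_end when training stops between intervals; the interval and batch indices below are arbitrary example values:

# Saves fire whenever the batch counter hits the interval; fit_end adds one
# last save only if training stopped off-interval.
batch_interval = 1000
for batch in (999, 1000, 2500, 3000):
    print(batch, batch % batch_interval == 0)

final_batch = 2500  # hypothetical stopping point
print('extra save at fit_end:', final_batch % batch_interval != 0)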
Perceptrix/finetune/build/lib/llmfoundry/callbacks/resumption_callbacks.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import logging
5
+ from typing import List
6
+
7
+ from composer.core import Callback, State
8
+ from composer.loggers import Logger
9
+
10
+ __all__ = [
11
+ 'GlobalLRScaling',
12
+ 'LayerFreezing',
13
+ ]
14
+
15
+ log = logging.getLogger(__name__)
16
+
17
+
18
+ class GlobalLRScaling(Callback):
19
+ """GlobalLRScaling.
20
+
21
+ This callback can be applied upon resuming a model checkpoint. Upon
22
+ fit_start it will multiply the base LR by `lr_scale` and set the WD to be
23
+
24
+ `wd_pct` * `lr`.
25
+
26
+ Args:
27
+ lr_scale (float): Multiplicative factor to scale LR by
28
+ wd_pct (float): Percentage of LR to set weight decay to.
29
+ """
30
+
31
+ def __init__(self, lr_scale: float, wd_pct: float = 0.0):
32
+ self.lr_scale = lr_scale
33
+ self.wd_pct = wd_pct
34
+
35
+ def fit_start(self, state: State, logger: Logger) -> None:
36
+ del logger # unused
37
+
38
+ if hasattr(state, 'optimizers') and state.optimizers is None:
39
+ raise Exception('No optimizers defined')
40
+ for optimizer in state.optimizers:
41
+ for group in optimizer.param_groups:
42
+ group['lr'] *= self.lr_scale
43
+ group['weight_decay'] = group['lr'] * self.wd_pct
44
+ if 'initial_lr' in group:
45
+ group['initial_lr'] *= self.lr_scale
46
+ log.info(
47
+ f"Set LR and WD to {group['lr']}, {group['weight_decay']}")
48
+
49
+ for scheduler in state.schedulers:
50
+ scheduler.base_lrs = [
51
+ self.lr_scale * lr for lr in scheduler.base_lrs
52
+ ]
53
+
54
+
55
+ class LayerFreezing(Callback):
56
+ """LayerFreezing.
57
+
58
+ This callback can be applied upon resuming a model checkpoint. Upon
59
+ fit_start it freezes the layers specified in `layer_names`. If using
60
+ activation checkpointing, please set the
61
+ `activation_checkpointing_reentrant` flag in `fsdp_config` to false.
62
+
63
+ Args:
64
+ layer_names (List[str]): Names of layers to freeze.
65
+ """
66
+
67
+ def __init__(self, layer_names: List[str]):
68
+ self.layer_names = set(layer_names)
69
+
70
+ def fit_start(self, state: State, logger: Logger) -> None:
71
+ del logger # unused
72
+
73
+ model_layers = set(name for name, _ in state.model.named_parameters())
74
+ for layer in self.layer_names:
75
+ if layer not in model_layers:
76
+ raise Exception(
77
+ f'Attempted to freeze layer not found in model: {layer}\nAvailable layers: {model_layers}'
78
+ )
79
+
80
+ successful_freeze = False
81
+ for name, p in state.model.named_parameters():
82
+ if p.requires_grad and name in self.layer_names:
83
+ p.requires_grad = False
84
+ log.debug(f'Froze layer: {name}\nParam: {p}')
85
+ successful_freeze = True
86
+
87
+ if not successful_freeze:
88
+ raise Exception(
89
+ f"Tried to run LayerFreezing but didn't freeze any layers")
Perceptrix/finetune/build/lib/llmfoundry/callbacks/scheduled_gc_callback.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import gc
5
+
6
+ import torch
7
+ from composer.core import Callback, State
8
+ from composer.loggers import Logger
9
+
10
+
11
+ def gc_cuda():
12
+ """Garbage collect Torch (CUDA) memory."""
13
+ gc.collect()
14
+ if torch.cuda.is_available():
15
+ torch.cuda.empty_cache()
16
+
17
+
18
+ class ScheduledGarbageCollector(Callback):
19
+ """Disable automatic garbage collection and collect garbage at interval.
20
+
21
+ Args:
22
+ batch_interval (int): Number of batches between calls to gc.collect()
23
+ eval_keep_disabled (bool): keep gc disabled during eval (default: False)
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ batch_interval: int,
29
+ eval_keep_disabled: bool = False,
30
+ ):
31
+ self.batch_interval = batch_interval
32
+ self.eval_keep_disabled = eval_keep_disabled
33
+ self.gc_init_state = None
34
+
35
+ def fit_start(self, state: State, logger: Logger) -> None:
36
+ del state, logger # unused
37
+
38
+ # cache if automatic garbage collection is enabled; reset at fit_end
39
+ self.gc_init_state = gc.isenabled()
40
+
41
+ # disable automatic garbage collection
42
+ gc.disable()
43
+ gc_cuda()
44
+
45
+ def fit_end(self, state: State, logger: Logger) -> None:
46
+ del state, logger # unused
47
+
48
+ gc_cuda()
49
+
50
+ # reset automatic garbage collection at fit_end
51
+ if self.gc_init_state:
52
+ gc.enable()
53
+ else:
54
+ gc.disable()
55
+
56
+ def before_dataloader(self, state: State, logger: Logger) -> None:
57
+ del logger # unused
58
+
59
+ if state.timestamp.batch.value % self.batch_interval == 0:
60
+ gc_cuda()
61
+
62
+ def eval_start(self, state: State, logger: Logger) -> None:
63
+ del state, logger # unused
64
+
65
+ gc_cuda()
66
+ if not self.eval_keep_disabled:
67
+ gc.enable()
68
+
69
+ def eval_end(self, state: State, logger: Logger) -> None:
70
+ del state, logger # unused
71
+
72
+ if not self.eval_keep_disabled:
73
+ gc.disable()
74
+
75
+ gc_cuda()
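The same disable-then-collect-on-a-schedule idea, sketched outside of Composer with a dummy training loop standing in for the trainer; `batch_interval` and the loop length are arbitrary.

```python
# Plain-Python sketch of ScheduledGarbageCollector: disable automatic garbage
# collection and collect manually every `batch_interval` steps, restoring the
# interpreter's original GC setting at the end, as fit_end does.
import gc

import torch


def gc_cuda():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


batch_interval = 100
gc_was_enabled = gc.isenabled()
gc.disable()
try:
    for step in range(1000):
        # ... forward/backward/optimizer step would go here ...
        if step % batch_interval == 0:
            gc_cuda()
finally:
    if gc_was_enabled:
        gc.enable()
    else:
        gc.disable()
    gc_cuda()
```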
Perceptrix/finetune/build/lib/llmfoundry/data/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
5
+ from llmfoundry.data.denoising import (MixtureOfDenoisersCollator,
6
+ build_text_denoising_dataloader)
7
+ from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator,
8
+ build_finetuning_dataloader)
9
+ from llmfoundry.data.text_data import (StreamingTextDataset,
10
+ build_text_dataloader)
11
+
12
+ __all__ = [
13
+ 'MixtureOfDenoisersCollator',
14
+ 'build_text_denoising_dataloader',
15
+ 'Seq2SeqFinetuningCollator',
16
+ 'build_finetuning_dataloader',
17
+ 'StreamingTextDataset',
18
+ 'build_text_dataloader',
19
+ 'NoConcatDataset',
20
+ 'ConcatTokensDataset',
21
+ ]
Perceptrix/finetune/build/lib/llmfoundry/data/data.py ADDED
@@ -0,0 +1,117 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Datasets for converting to MDS Shards."""
5
+ import os
6
+ import warnings
7
+ from typing import Dict, Iterable, Union
8
+
9
+ import datasets as hf_datasets
10
+ import numpy as np
11
+ from torch.utils.data import IterableDataset
12
+ from transformers import PreTrainedTokenizerBase
13
+
14
+
15
+ class NoConcatDataset(IterableDataset):
16
+ """An IterableDataset that returns text samples for MDSWriter.
17
+
18
+ Returns dicts of {'text': bytes}
19
+ """
20
+
21
+ def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset,
22
+ hf_datasets.Dataset]):
23
+ self.hf_dataset = hf_dataset
24
+
25
+ def __iter__(self) -> Iterable[Dict[str, bytes]]:
26
+ for sample in self.hf_dataset:
27
+ # convert to bytes to store in MDS binary format
28
+ yield {'text': sample['text'].encode('utf-8')}
29
+
30
+
31
+ class ConcatTokensDataset(IterableDataset):
32
+ """An IterableDataset that returns token samples for MDSWriter.
33
+
34
+ Returns dicts of {'tokens': bytes}
35
+
36
+ To use data created by this class and written to MDS format:
37
+
38
+ ```python
39
+ import numpy as np
+ import torch
40
+ from streaming.base import StreamingDataset
41
+ from transformers import AutoTokenizer
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
44
+ ds = StreamingDataset(local='mds-data-folder', split='val')
45
+
46
+ # note, you need to copy the numpy array because the original is non-writeable
47
+ # and torch does not support non-writeable tensors, so you get a scary warning and
48
+ # if you do try to write to the tensor you get undefined behavior
49
+ tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
50
+ print(tokenizer.decode(tokens))
51
+ ```
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
57
+ tokenizer: PreTrainedTokenizerBase,
58
+ max_length: int,
59
+ bos_text: str,
60
+ eos_text: str,
61
+ no_wrap: bool,
62
+ ):
63
+ self.hf_dataset = hf_dataset
64
+ self.tokenizer = tokenizer
65
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
66
+ self.max_length = max_length
67
+ self.bos_text = bos_text
68
+ self.eos_text = eos_text
69
+ self.should_wrap = not no_wrap
70
+
71
+ self.bos_tokens = self.tokenizer(self.bos_text,
72
+ truncation=False,
73
+ padding=False,
74
+ add_special_tokens=False)['input_ids']
75
+ if len(self.bos_tokens) > 1:
76
+ warnings.warn(
77
+ f'You specified --concat_tokens with --bos_text, but your BOS text is not tokenizing to one token\
78
+ , instead we got {self.bos_tokens}. Quit if this was in error.')
79
+
80
+ self.eos_tokens = self.tokenizer(self.eos_text,
81
+ truncation=False,
82
+ padding=False,
83
+ add_special_tokens=False)['input_ids']
84
+ if len(self.eos_tokens) > 1:
85
+ warnings.warn(
86
+ f'You specified --concat_tokens with --eos_text, but your EOS text is not tokenizing to one token\
87
+ , instead we got {self.eos_tokens}. Quit if this was in error.')
88
+
89
+ eos_text_provided = self.eos_text != ''
90
+ bos_text_provided = self.bos_text != ''
91
+ test_text = self.tokenizer('')
92
+ if len(test_text['input_ids']) > 0 and (eos_text_provided or
93
+ bos_text_provided):
94
+ message = 'both eos and bos' if eos_text_provided and bos_text_provided else (
95
+ 'eos_text' if eos_text_provided else 'bos_text')
96
+ warnings.warn(
97
+ f'The provided tokenizer adds special tokens, but you also specified {message}. This may result '
98
+ +
99
+ 'in duplicated special tokens. Please be sure this is what you intend.'
100
+ )
101
+
102
+ def __iter__(self) -> Iterable[Dict[str, bytes]]:
103
+
104
+ buffer = []
105
+ for sample in self.hf_dataset:
106
+ encoded = self.tokenizer(sample['text'],
107
+ truncation=False,
108
+ padding=False)
109
+ iids = encoded['input_ids']
110
+ buffer = buffer + self.bos_tokens + iids + self.eos_tokens
111
+ while len(buffer) >= self.max_length:
112
+ concat_sample = buffer[:self.max_length]
113
+ buffer = buffer[self.max_length:] if self.should_wrap else []
114
+ yield {
115
+ # convert to bytes to store in MDS binary format
116
+ 'tokens': np.asarray(concat_sample).tobytes()
117
+ }
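The docstring above shows how to read `ConcatTokensDataset` output back from MDS shards; the sketch below shows the writing side with `streaming`'s `MDSWriter`. The GPT-NeoX tokenizer, the streamed C4 corpus, the sample cap, and the output path are illustrative assumptions, not choices made by this repo's code.

```python
# Hedged sketch: tokenize-and-concatenate a HuggingFace dataset and write the
# resulting {'tokens': bytes} samples to MDS shards.
import datasets as hf_datasets
from streaming import MDSWriter
from transformers import AutoTokenizer

from llmfoundry.data import ConcatTokensDataset

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')  # assumed choice
hf_dataset = hf_datasets.load_dataset('c4', 'en', split='train', streaming=True)  # assumed corpus

ds = ConcatTokensDataset(
    hf_dataset=hf_dataset,
    tokenizer=tokenizer,
    max_length=2048,
    bos_text='',
    eos_text='<|endoftext|>',
    no_wrap=False,
)

# Each yielded sample is {'tokens': bytes}, matching the 'bytes' column type.
with MDSWriter(out='mds-data-folder/train', columns={'tokens': 'bytes'}) as out:
    for i, sample in enumerate(ds):
        out.write(sample)
        if i >= 999:  # keep the sketch small; remove this cap for a full conversion
            break
```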
Perceptrix/finetune/build/lib/llmfoundry/data/denoising.py ADDED
@@ -0,0 +1,937 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Streaming dataloader for (mixture of) denoising task(s)."""
5
+
6
+ import logging
7
+ import random
8
+ import sys
9
+ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ from omegaconf import DictConfig
14
+ from omegaconf import OmegaConf as om
15
+ from torch.utils.data import DataLoader
16
+ from transformers import PreTrainedTokenizerBase
17
+
18
+ from llmfoundry.data.packing import BinPackWrapper
19
+ from llmfoundry.data.text_data import StreamingTextDataset
20
+ from llmfoundry.models import utils
21
+
22
+ __all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
23
+
24
+ log = logging.getLogger(__name__)
25
+
26
+ # HuggingFace hardcodes the ignore index to -100
27
+ _HF_IGNORE_INDEX = -100
28
+
29
+ # Required signature of any `prefix_function` (see below)
30
+ PREFIX_FUNCTION = Callable[[float, Optional[float], PreTrainedTokenizerBase],
31
+ Sequence[int]]
32
+
33
+
34
+ def ul2_prefix_function(
35
+ mask_ratio: float,
36
+ mean_length: Optional[float],
37
+ tokenizer: PreTrainedTokenizerBase,
38
+ ) -> Sequence[int]:
39
+ """Generates prefixes based on UL2 paper.
40
+
41
+ See: http://arxiv.org/abs/2205.05131
42
+ """
43
+ if mean_length is None:
44
+ # This is the case for "sequence to sequence"
45
+ prefix = '[S2S]' if mask_ratio < 1.0 else '[CLM]'
46
+ elif mean_length >= 12 or mask_ratio >= 0.3:
47
+ # UL2 tags this corruption rate "extreme"
48
+ prefix = '[NLG]'
49
+ else:
50
+ # UL2 tags this corruption rate as "regular"
51
+ prefix = '[NLU]'
52
+ return tokenizer(prefix, add_special_tokens=False).input_ids
53
+
54
+
55
+ class MixtureOfDenoisersCollator:
56
+ """Data collator for mixture of span-corruption denoisers, as in UL2.
57
+
58
+ This collator supports a variety of tasks used to pre-train an
59
+ encoder-decoder model or a (prefix LM) decoder-only model. This is meant
60
+ to be used with a dataset that yields tokenized text sequences. It is not
61
+ required that the token sequences are already padded or truncated, as this
62
+ collator will internally truncate and pad as needed.
63
+
64
+ For the denoising mixture recommended in the original UL2 paper,
65
+ http://arxiv.org/abs/2205.05131, use:
66
+ .. python:
67
+ MixtureOfDenoisersCollator(
68
+ ...,
69
+ span_mean_lengths_and_ratios=[
70
+ [3, .15],
71
+ [8, .15],
72
+ [3, .50],
73
+ [8, .50],
74
+ [64, .15],
75
+ [64, .50],
76
+ ],
77
+ sequence_mask_ratios=0.25
78
+ )
79
+
80
+ Args:
81
+ tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
82
+ prepare the data from raw text. Any missing sentinel tokens will
83
+ be added by the collator.
84
+ max_seq_length (int): The maximum length of sequences produced by this
85
+ collator. Incoming sequences may be truncated to accommodate this
86
+ limit.
87
+ Note that when formatting for decoder-only models, the context
88
+ tokens and target tokens are concatenated, and max_seq_length
89
+ applies to their combined length. For encoder-decoder models, both
90
+ the encoder and decoder will see up to max_seq_length tokens.
91
+ decoder_only_format (bool, optional): Whether to format the batches
92
+ for a decoder-only model (i.e. a prefix LM) or, if ``False``, an
93
+ encoder-decoder model. Default: ``False``.
94
+ span_mean_lengths_and_ratios (optional): A length-2 list of a
95
+ ``[mean_length, mask_ratio]`` pair, or a list of such pairs. Each
96
+ pair adds a span corruption denoising task to the task mixture. For
97
+ example, ``[3, 0.15]`` adds the original span corruption task used
98
+ for pre-training a T5 model as in http://arxiv.org/abs/1910.10683,
99
+ which trained with a single span corruption task that used a mean
100
+ span length of 3 and a mask ratio of 15%.
101
+ Default: ``None`` does not add any span corruption tasks.
102
+ sequence_mask_ratios (optional): A float or list of floats, one for each
103
+ sequence corruption denoising task to add to the task mixture. Each
104
+ sequence mask ratio must be greater than 0.0 and at most 1.0.
105
+ This type of task is a special instance of span corruption, with
106
+ exactly one masked span taken from the end of the sequence. The
107
+ length of the span is sampled uniformly such that the average
108
+ portion of masked tokens equals sequence_mask_ratio.
109
+ Note: A value of 1.0 essentially yields causal LM examples.
110
+ Default: ``None`` does not add any sequence corruption tasks.
111
+ allow_pad_trimming (bool, optional): Whether to allow the collator to
112
+ trim away sequence regions that are entirely padding (i.e. padding
113
+ for each example in the batch). If ``True``, shorter sequences may
114
+ improve throughput but at a potentially higher memory cost
115
+ owing to variable sequence lengths from batch to batch.
116
+ Default: ``False`` yields batches that are always padded to
117
+ max_seq_length.
118
+ prefix_function (callable, optional): A function that maps denoising
119
+ task parameters (e.g. mean_length=3, mask_ratio=0.15) to a prefix
120
+ that will be added to sequences when the associated "noiser" is
121
+ applied.
122
+ To disable these prefixes, use a value of ``None``.
123
+ Default: :func:`ul2_prefix_function` applies the prefix scheme
124
+ suggested in the UL2 paper: http://arxiv.org/abs/2205.05131.
125
+ context_eos (bool, optional): Whether to attach an EOS token to the end
126
+ of the context sequence, marking the transition from context to
127
+ target sequence. Only applicable if decoder_only_format is True.
128
+ Context EOS tokens are always added for encoder-decoder format.
129
+ Default: ``False`` does not attach context EOS.
130
+ """
131
+
132
+ def __init__(
133
+ self,
134
+ tokenizer: PreTrainedTokenizerBase,
135
+ max_seq_length: int,
136
+ decoder_only_format: bool = False,
137
+ span_mean_lengths_and_ratios: Optional[List] = None,
138
+ sequence_mask_ratios: Optional[Union[List[float], float]] = None,
139
+ allow_pad_trimming: bool = False,
140
+ prefix_function: Optional[PREFIX_FUNCTION] = ul2_prefix_function,
141
+ context_eos: Optional[bool] = None,
142
+ ):
143
+ # Prepare the tokenizer for denoising tasks
144
+ utils.adapt_tokenizer_for_denoising(tokenizer)
145
+
146
+ self.tokenizer = tokenizer
147
+ self.max_seq_length = max_seq_length
148
+ self.decoder_only_format = decoder_only_format
149
+ self._sentinel_token_ids = np.array(self.tokenizer.sentinel_token_ids)
150
+
151
+ # Trimming will always be skipped on at least the first __call__
152
+ self._allow_pad_trimming = allow_pad_trimming
153
+ self._seen_first_batch = False
154
+
155
+ self.context_eos = bool(context_eos) if decoder_only_format else True
156
+
157
+ # Process the span_mean_lengths_and_ratios argument
158
+ if span_mean_lengths_and_ratios is None:
159
+ # In this case, there are no span corruption tasks
160
+ self.span_mean_lengths_and_ratios = []
161
+ elif isinstance(span_mean_lengths_and_ratios[0], (int, float)):
162
+ # In this case, there is one span corruption task
163
+ if not len(span_mean_lengths_and_ratios) == 2:
164
+ raise ValueError('`span_mean_lengths_and_ratios` must be a ' + \
165
+ 'pair of [mean_length, mask_ratio], a list ' + \
166
+ f'of such pairs, or None. Got {span_mean_lengths_and_ratios}.')
167
+ self.span_mean_lengths_and_ratios = [span_mean_lengths_and_ratios]
168
+ else:
169
+ # In this case, there are one or more span corruption tasks
170
+ span_mean_lengths_and_ratios = list(span_mean_lengths_and_ratios)
171
+ for spec_pair in span_mean_lengths_and_ratios:
172
+ if len(spec_pair) != 2:
173
+ raise ValueError('`span_mean_lengths_and_ratios` must be a ' + \
174
+ 'pair of [mean_length, mask_ratio], a list ' + \
175
+ f'of such pairs, or None. Got {span_mean_lengths_and_ratios}.')
176
+ self.span_mean_lengths_and_ratios = span_mean_lengths_and_ratios
177
+
178
+ # Process the sequence_mask_ratios argument
179
+ if sequence_mask_ratios is None:
180
+ # In this case, there are no sequence corruption tasks
181
+ self.sequence_mask_ratios = []
182
+ elif isinstance(sequence_mask_ratios, float):
183
+ # In this case, there is one sequence corruption task
184
+ self.sequence_mask_ratios = [sequence_mask_ratios]
185
+ else:
186
+ # In this case, there is one or more sequence corruption tasks
187
+ for ratio in sequence_mask_ratios:
188
+ if not (0 < ratio <= 1.0):
189
+ raise ValueError('`sequence_mask_ratios` must be a float (or list '+\
190
+ 'of floats) that are each >0.0 and <=1.0, or None. '+\
191
+ f'Got {sequence_mask_ratios}.')
192
+ self.sequence_mask_ratios = sequence_mask_ratios
193
+
194
+ # Populate the noisers so we can learn to denoise them!
195
+ self._noisers = []
196
+ self._smallest_max_raw_length = self.max_seq_length * 100
197
+ self._largest_max_raw_length = 0
198
+ self._uses_span_corruption = False
199
+
200
+ # Add "noisers" for any span corruption denoising tasks
201
+ # Each mean_length / mask_ratio combo becomes one of the span
202
+ # corruption denoising tasks
203
+ for span_mean_length, span_mask_ratio in self.span_mean_lengths_and_ratios:
204
+ self._uses_span_corruption = True
205
+ if span_mean_length < 0:
206
+ raise ValueError('All span mean lengths must be positive.')
207
+ if not 0 < span_mask_ratio < 1.0:
208
+ raise ValueError(
209
+ 'All span masking ratios must be between 0.0 and 1.0.')
210
+
211
+ if prefix_function is not None:
212
+ prefix_tokens = prefix_function(span_mask_ratio,
213
+ span_mean_length,
214
+ self.tokenizer)
215
+ else:
216
+ prefix_tokens = None
217
+
218
+ max_raw_length = _get_max_starting_length(
219
+ max_length=self.max_seq_length,
220
+ mask_ratio=span_mask_ratio,
221
+ mean_span_length=span_mean_length,
222
+ n_prefix_tokens=len(prefix_tokens or []),
223
+ decoder_only_format=self.decoder_only_format,
224
+ context_eos=self.context_eos)
225
+ if max_raw_length < self._smallest_max_raw_length:
226
+ self._smallest_max_raw_length = max_raw_length
227
+ if max_raw_length > self._largest_max_raw_length:
228
+ self._largest_max_raw_length = max_raw_length
229
+
230
+ kwargs = {
231
+ 'mean_span_length': span_mean_length,
232
+ 'mask_ratio': span_mask_ratio,
233
+ 'prefix_tokens': prefix_tokens,
234
+ 'max_raw_length': max_raw_length,
235
+ }
236
+ self._noisers.append(kwargs)
237
+
238
+ # Add "noisers" for any sequential denoising tasks
239
+ for sequence_mask_ratio in self.sequence_mask_ratios:
240
+ if prefix_function is not None:
241
+ prefix_tokens = prefix_function(sequence_mask_ratio, None,
242
+ self.tokenizer)
243
+ else:
244
+ prefix_tokens = None
245
+
246
+ max_raw_length = self.max_seq_length - len(prefix_tokens or []) - 1
247
+ if decoder_only_format and self.context_eos:
248
+ max_raw_length = max_raw_length - 1
249
+
250
+ if not self._uses_span_corruption and (
251
+ max_raw_length < self._smallest_max_raw_length):
252
+ # We choose not to count sequence denoising in the smallest
253
+ # unless there is only sequence denoising.
254
+ self._smallest_max_raw_length = max_raw_length
255
+ if max_raw_length > self._largest_max_raw_length:
256
+ self._largest_max_raw_length = max_raw_length
257
+
258
+ kwargs = {
259
+ 'mean_span_length': None,
260
+ 'mask_ratio': sequence_mask_ratio,
261
+ 'prefix_tokens': prefix_tokens,
262
+ 'max_raw_length': max_raw_length,
263
+ }
264
+ self._noisers.append(kwargs)
265
+
266
+ if not self._noisers:
267
+ raise ValueError(
268
+ 'No denoising tasks were included. Make sure to set ' + \
269
+ '`span_mean_lengths_and_ratios` and/or `sequence_mask_ratios`.')
270
+
271
+ @property
272
+ def smallest_max_raw_length(self) -> int:
273
+ return int(self._smallest_max_raw_length)
274
+
275
+ @property
276
+ def largest_max_raw_length(self) -> int:
277
+ return int(self._largest_max_raw_length)
278
+
279
+ def __call__(self, examples: List[Dict[str,
280
+ Any]]) -> Dict[str, torch.Tensor]:
281
+ """Batch examples processed by the span corrupter."""
282
+ processed_examples = []
283
+ for example in examples:
284
+ # Randomly pick a "noiser" to apply to this example
285
+ noiser = random.choice(self._noisers)
286
+ # Apply it
287
+ processed_examples.append(
288
+ noise_token_sequence(
289
+ example,
290
+ mask_ratio=noiser['mask_ratio'],
291
+ mean_span_length=noiser['mean_span_length'],
292
+ prefix_tokens=noiser['prefix_tokens'],
293
+ max_raw_length=noiser['max_raw_length'],
294
+ max_seq_length=self.max_seq_length,
295
+ tokenizer=self.tokenizer,
296
+ sentinel_token_ids=self._sentinel_token_ids,
297
+ decoder_only_format=self.decoder_only_format,
298
+ context_eos=self.context_eos))
299
+ batch = self.tokenizer.pad(processed_examples)
300
+
301
+ # This logic prevents trimming on at least the first batch
302
+ if not (self._allow_pad_trimming and self._seen_first_batch):
303
+ self._seen_first_batch = True
304
+ return batch
305
+ self._seen_first_batch = True
306
+
307
+ # Truncate portions of the inputs that are purely padding
308
+ # (up to a multiple of 8)
309
+ multiple_of = 8
310
+ n_examples_per_length = batch['attention_mask'].sum(0)
311
+ keep_tokens = torch.sum(n_examples_per_length > 0)
312
+ keep_tokens = int(multiple_of * torch.ceil(keep_tokens / multiple_of))
313
+
314
+ # Note: EncDec formatting will always produce a right-padded batch
315
+ if self.tokenizer.padding_side == 'left' and self.decoder_only_format:
316
+ batch['input_ids'] = batch['input_ids'][:, -keep_tokens:]
317
+ batch['attention_mask'] = batch['attention_mask'][:, -keep_tokens:]
318
+ else:
319
+ batch['input_ids'] = batch['input_ids'][:, :keep_tokens]
320
+ batch['attention_mask'] = batch['attention_mask'][:, :keep_tokens]
321
+
322
+ if self.decoder_only_format:
323
+ if self.tokenizer.padding_side == 'left':
324
+ batch['labels'] = batch['labels'][:, -keep_tokens:]
325
+ batch['bidirectional_mask'] = batch[
326
+ 'bidirectional_mask'][:, -keep_tokens:]
327
+ else:
328
+ batch['labels'] = batch['labels'][:, :keep_tokens]
329
+ batch['bidirectional_mask'] = batch[
330
+ 'bidirectional_mask'][:, :keep_tokens]
331
+
332
+ else:
333
+ # Truncate portions of the decoder inputs that are purely padding
334
+ n_examples_per_length = batch['decoder_attention_mask'].sum(0)
335
+ keep_tokens = torch.sum(n_examples_per_length > 0)
336
+ keep_tokens = int(multiple_of *
337
+ torch.ceil(keep_tokens / multiple_of))
338
+
339
+ batch['labels'] = batch['labels'][:, :keep_tokens]
340
+ batch['decoder_attention_mask'] = batch[
341
+ 'decoder_attention_mask'][:, :keep_tokens]
342
+ batch['decoder_input_ids'] = batch[
343
+ 'decoder_input_ids'][:, :keep_tokens]
344
+
345
+ # This slicing can produce non-contiguous tensors, so use .contiguous
346
+ # to prevent related problems
347
+ batch = {k: v.contiguous() for k, v in batch.items()}
348
+
349
+ return batch
350
+
351
+
352
+ def build_text_denoising_dataloader(
353
+ cfg: DictConfig,
354
+ tokenizer: PreTrainedTokenizerBase,
355
+ device_batch_size: int,
356
+ ) -> DataLoader[Dict]:
357
+ """Constructor function for a Mixture of Denoisers dataloader.
358
+
359
+ This function constructs a dataloader that can be used to train an
360
+ encoder-decoder model or a (prefix LM) decoder-only model on a text
361
+ denoising task mixture (e.g. span corruption, or UL2).
362
+
363
+ The underlying dataset is a :class:`StreamingTextDataset`, allowing you to
364
+ stream raw text data or pre-tokenized text data.
365
+
366
+ The dataloader uses a :class:`MixtureOfDenoisersCollator` to prepare the
367
+ tokenized examples into training batches.
368
+
369
+ Args:
370
+ cfg (DictConfig): An omegaconf dictionary used to configure the loader:
371
+ cfg.name (str): The type of dataloader to build. Must = "text_denoising".
372
+ ---
373
+ cfg.dataset.max_seq_len (int): The maximum length of sequences
374
+ in the batch. See :class:`MixtureOfDenoisersCollator` docstring
375
+ for details.
376
+ cfg.dataset.packing_ratio (float, optional): If provided, this invokes
377
+ a collator wrapper that packs device_batch_size*packing_ratio
378
+ raw examples into device_batch_size packed examples. This helps
379
+ minimize padding while preserving sequence integrity.
380
+ This adds `sequence_id` to the batch, which indicates which unique
381
+ sequence each token belongs to.
382
+ Note: Using this feature will not change device_batch_size but it
383
+ will determine the number of raw examples consumed by the dataloader
384
+ per batch. Some examples may be discarded if they do not fit when
385
+ packing.
386
+ Select packing_ratio **carefully** based on the dataset
387
+ statistics, max_seq_len, and tolerance for discarding samples!
388
+ The packing code in `./packing.py` provides a script that can help
389
+ you choose the best packing_ratio.
390
+ See :class:`StreamingTextDataset` for info on other standard config
391
+ options within `cfg.dataset`.
392
+ ---
393
+ cfg.mixture_of_denoisers.decoder_only_format (bool): Whether the
394
+ batches should use the format required for training a decoder-only
395
+ model (if ``True``) or an encoder-decoder model (if ``False``).
396
+ cfg.mixture_of_denoisers.span_mean_lengths_and_ratios (optional): The
397
+ parameters for any span corruption denoising tasks to include in
398
+ the task mixture.
399
+ See :class:`MixtureOfDenoisersCollator` docstring for details.
400
+ cfg.mixture_of_denoisers.sequence_mask_ratios (optional): The
401
+ parameters for any sequence denoising tasks to include in the
402
+ task mixture.
403
+ See :class:`MixtureOfDenoisersCollator` docstring for details.
404
+ cfg.mixture_of_denoisers.allow_pad_trimming (optional): Whether to
405
+ allow the collator to trim padding when possible (if ``True``).
406
+ Defaults to ``False``.
407
+ cfg.mixture_of_denoisers.prefix_function (optional): Set to ``None``
408
+ to disable the UL2-style prefixes that will be automatically
409
+ added by default.
410
+ ---
411
+ See :class:`DataLoader` for standard argument options to the pytorch
412
+ dataloader, such as `cfg.drop_last`, `cfg.num_workers`, etc.
413
+ tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
414
+ prepare the data from raw text. Any missing sentinel tokens will
415
+ be added by the collator.
416
+ device_batch_size (int): The size of the batches (number of examples)
417
+ that the dataloader will produce.
418
+
419
+ Note:
420
+ You can run the script inside `./packing.py` to quickly test the
421
+ padding/waste rates for different `cfg.dataset.packing_ratio` choices,
422
+ given a starting workload YAML.
423
+ """
424
+ assert cfg.name == 'text_denoising', f'Tried to build a text_denoising dataloader with cfg.name={cfg.name}'
425
+
426
+ collate_fn = MixtureOfDenoisersCollator(
427
+ tokenizer=tokenizer,
428
+ max_seq_length=cfg.dataset.max_seq_len,
429
+ decoder_only_format=cfg.mixture_of_denoisers.decoder_only_format,
430
+ span_mean_lengths_and_ratios=cfg.mixture_of_denoisers.get(
431
+ 'span_mean_lengths_and_ratios'),
432
+ sequence_mask_ratios=cfg.mixture_of_denoisers.get(
433
+ 'sequence_mask_ratios'),
434
+ allow_pad_trimming=cfg.mixture_of_denoisers.get('allow_pad_trimming',
435
+ False),
436
+ prefix_function=cfg.mixture_of_denoisers.get('prefix_function',
437
+ ul2_prefix_function),
438
+ context_eos=cfg.mixture_of_denoisers.get('context_eos'))
439
+
440
+ truncate_to = cfg.mixture_of_denoisers.get('truncate_raw_tokens_to')
441
+ if truncate_to is None:
442
+ # By default, truncate to the largest max raw length of the denoisers
443
+ truncate_to = collate_fn.largest_max_raw_length
444
+ elif isinstance(truncate_to, str):
445
+ if truncate_to.lower() == 'min':
446
+ # Truncate to the smallest max raw length of the denoisers
447
+ truncate_to = collate_fn.smallest_max_raw_length
448
+ elif truncate_to.lower() == 'max':
449
+ # Truncate to the largest max raw length of the denoisers
450
+ truncate_to = collate_fn.largest_max_raw_length
451
+ else:
452
+ raise ValueError(
453
+ f'truncate_raw_tokens_to(="{truncate_to.lower()}") must be "min", "max", a positive int, or None.'
454
+ )
455
+ else:
456
+ if not isinstance(truncate_to, int):
457
+ raise ValueError(
458
+ f'truncate_raw_tokens_to(={truncate_to}) must be "min", "max", a positive int, or None.'
459
+ )
460
+ if truncate_to < 0:
461
+ raise ValueError(
462
+ f'truncate_raw_tokens_to(={truncate_to}) must be "min", "max", a positive int, or None.'
463
+ )
464
+
465
+ dataset = StreamingTextDataset(
466
+ local=cfg.dataset.local,
467
+ tokenizer=tokenizer,
468
+ max_seq_len=truncate_to,
469
+ remote=cfg.dataset.get('remote'),
470
+ split=cfg.dataset.get('split'),
471
+ shuffle=cfg.dataset.get('shuffle', False),
472
+ predownload=cfg.dataset.get('predownload', 100_000),
473
+ keep_zip=cfg.dataset.get('keep_zip', False),
474
+ download_retry=cfg.dataset.get('download_retry', 2),
475
+ download_timeout=cfg.dataset.get('download_timeout', 60),
476
+ validate_hash=cfg.dataset.get('validate_hash'),
477
+ shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
478
+ num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', 128),
479
+ batch_size=device_batch_size,
480
+ )
481
+
482
+ if dataset.tokenizer.pad_token is None:
483
+ dataset.tokenizer.pad_token = dataset.tokenizer.eos_token
484
+
485
+ if cfg.dataset.get('packing_ratio'):
486
+ n_examples_to_pack = int(device_batch_size * cfg.dataset.packing_ratio)
487
+ if n_examples_to_pack < device_batch_size:
488
+ raise ValueError('packing_ratio must be >= 1, if supplied')
489
+ if not cfg.mixture_of_denoisers.decoder_only_format:
490
+ raise NotImplementedError(
491
+ 'On-the-fly packing is currently only supported for decoder-only formats.'
492
+ )
493
+ collate_fn = BinPackWrapper(
494
+ collator=collate_fn,
495
+ target_batch_size=device_batch_size,
496
+ max_seq_len=cfg.dataset.max_seq_len,
497
+ pad_token_id=dataset.tokenizer.pad_token_id,
498
+ padding_side=dataset.tokenizer.padding_side,
499
+ max_leftover_bins_to_keep=cfg.dataset.get(
500
+ 'max_leftover_bins_to_keep'),
501
+ )
502
+ device_batch_size = n_examples_to_pack
503
+ elif cfg.dataset.get('max_leftover_bins_to_keep') is not None:
504
+ raise ValueError(
505
+ 'cfg.dataset.max_leftover_bins_to_keep has been defined, ' +\
506
+ 'but cfg.dataset.packing_ratio has not been set. Please set ' +\
507
+ 'the latter to turn on packing or remove the former from the config.')
508
+
509
+ return DataLoader(
510
+ dataset,
511
+ collate_fn=collate_fn,
512
+ batch_size=device_batch_size,
513
+ drop_last=cfg.drop_last,
514
+ num_workers=cfg.num_workers,
515
+ pin_memory=cfg.get('pin_memory', True),
516
+ prefetch_factor=cfg.get('prefetch_factor', 2),
517
+ persistent_workers=cfg.get('persistent_workers', False),
518
+ timeout=cfg.get('timeout', 0),
519
+ )
520
+
521
+
522
+ def noise_token_sequence(
523
+ example: Union[torch.Tensor, Mapping[str, Any]],
524
+ mask_ratio: float,
525
+ mean_span_length: Optional[float],
526
+ prefix_tokens: Optional[Sequence[int]],
527
+ max_raw_length: int,
528
+ max_seq_length: int,
529
+ tokenizer: PreTrainedTokenizerBase,
530
+ sentinel_token_ids: np.ndarray,
531
+ decoder_only_format: bool,
532
+ context_eos: bool,
533
+ ) -> Dict[str, torch.Tensor]:
534
+ """Span corruption applicable to all UL2 denoising tasks."""
535
+ # Extract the raw text tokens (trim if we need to)
536
+ if isinstance(example, torch.Tensor):
537
+ # If the example is a tensor, assume it is the raw tokens with no padding
538
+ tokens = example
539
+ length = len(tokens)
540
+ else:
541
+ tokens = example['input_ids']
542
+ length = sum(example['attention_mask'])
543
+ if length > max_raw_length:
544
+ length = max_raw_length
545
+ if tokenizer.padding_side == 'left':
546
+ tokens = tokens[-length:]
547
+ else:
548
+ tokens = tokens[:length]
549
+
550
+ prefix_tokens = prefix_tokens or []
551
+
552
+ if length < 1:
553
+ raise ValueError('Example cannot be empty, but got token length < 1.')
554
+
555
+ # mean_span_length==None is a special case for "sequential" denoising
556
+ # (where a single span at the end of the sequence is masked)
557
+ if mean_span_length is None:
558
+ # This ensures that exactly 1 span will be produced and that
559
+ # trimming to max_seq_length won't cut off any <EOS> token.
560
+ # In the decoder-only case, this won't insert new tokens.
561
+ if mask_ratio <= 0.5:
562
+ u = np.random.uniform(low=0.0, high=mask_ratio * 2)
563
+ else:
564
+ u = np.random.uniform(low=(mask_ratio * 2) - 1, high=1.0)
565
+ mean_span_length = float(np.round(1 + u * (length - 1)))
566
+ mask_ratio = mean_span_length / length
567
+ use_sentinels = False
568
+ else:
569
+ use_sentinels = True
570
+
571
+ # Generate the mask
572
+ # Note: this function can be used for all the UL2 noising functions
573
+ mask = _sample_mask_array(length, mask_ratio, mean_span_length)
574
+ # The sequence should always be unmasked at the beginning
575
+ assert mask[0] == 0
576
+
577
+ # Generate the input/label sequences given the raw tokens and the mask
578
+ tokens_inputs = _apply_mask(tokens,
579
+ mask,
580
+ use_sentinels,
581
+ tokenizer.eos_token_id,
582
+ sentinel_token_ids,
583
+ ensure_eos=context_eos)
584
+ tokens_labels = _apply_mask(tokens,
585
+ 1 - mask,
586
+ use_sentinels,
587
+ tokenizer.eos_token_id,
588
+ sentinel_token_ids,
589
+ ensure_eos=True)
590
+
591
+ # Tag the inputs with any prefix
592
+ if prefix_tokens:
593
+ tokens_inputs = np.concatenate([prefix_tokens, tokens_inputs])
594
+
595
+ # Trim if necessary
596
+ if len(tokens_inputs) > max_seq_length:
597
+ raise ValueError('This should not exceed the max length')
598
+ if len(tokens_labels) > max_seq_length:
599
+ raise ValueError('This should not exceed the max length')
600
+
601
+ tokens_inputs = torch.LongTensor(tokens_inputs)
602
+ tokens_labels = torch.LongTensor(tokens_labels)
603
+
604
+ if decoder_only_format:
605
+ return _format_tokens_for_decoder_only(tokens_inputs, tokens_labels,
606
+ max_seq_length,
607
+ tokenizer.pad_token_id,
608
+ tokenizer.padding_side)
609
+ return _format_tokens_for_encoder_decoder(tokens_inputs, tokens_labels,
610
+ max_seq_length,
611
+ tokenizer.pad_token_id)
612
+
613
+
614
+ def _get_max_starting_length(max_length: int, mask_ratio: float,
615
+ mean_span_length: float, n_prefix_tokens: int,
616
+ decoder_only_format: bool,
617
+ context_eos: bool) -> int:
618
+ """Get max num raw tokens that will fit max_length."""
619
+
620
+ def sequence_stats(length: int):
621
+ length = np.maximum(length, 2)
622
+ num_noise_tokens = int(np.round(mask_ratio * float(length)))
623
+ num_noise_tokens = np.minimum(np.maximum(num_noise_tokens, 1),
624
+ length - 1)
625
+ num_spans = int(np.round(float(num_noise_tokens) / mean_span_length))
626
+ num_noise_spans = np.maximum(num_spans, 1)
627
+ num_nonnoise_tokens = length - num_noise_tokens
628
+ # Prefix, sentinel, and EOS added to input for Enc-Dec
629
+ extra_inp_tokens = n_prefix_tokens + num_noise_spans + int(context_eos)
630
+ # Sentinel and EOS added to target
631
+ extra_targ_tokens = num_noise_spans + 1
632
+ # Sequence totals after corruption
633
+ total_inp_tokens = num_nonnoise_tokens + extra_inp_tokens
634
+ total_targ_tokens = num_noise_tokens + extra_targ_tokens
635
+ return total_inp_tokens, total_targ_tokens
636
+
637
+ def length_fits(length: int) -> bool:
638
+ total_inp_tokens, total_targ_tokens = sequence_stats(length)
639
+ if decoder_only_format:
640
+ return (total_inp_tokens + total_targ_tokens) <= max_length
641
+ return (total_inp_tokens <= max_length) and (total_targ_tokens <=
642
+ max_length)
643
+
644
+ # Start with a definitely too-long sequence and reduce until it fits
645
+ num_raw_tokens = max_length * 2
646
+ while num_raw_tokens > 0:
647
+ if length_fits(num_raw_tokens):
648
+ return num_raw_tokens
649
+ num_raw_tokens -= 1
650
+ raise ValueError(
651
+ 'Unable to find a starting sequence length that can fit given the corruption and max_length parameters.'
652
+ )
653
+
654
+
655
+ def _sample_mask_array(length: int, mask_ratio: float,
656
+ mean_span_length: float) -> np.ndarray:
657
+ """Samples a span corruption mask."""
658
+ if mask_ratio == 0.0:
659
+ return np.zeros(length)
660
+ # This first block computes the number of noise/non-noise spans and the
661
+ # total tokens in each. Extra steps are taken to handle edge cases that
662
+ # cause degeneracy.
663
+ starting_length = length
664
+ length = np.maximum(length, 2)
665
+ num_noise_tokens = int(np.round(mask_ratio * float(length)))
666
+ num_noise_tokens = np.minimum(np.maximum(num_noise_tokens, 1), length - 1)
667
+ num_spans = int(np.round(float(num_noise_tokens) / mean_span_length))
668
+ num_noise_spans = np.maximum(num_spans, 1)
669
+ num_nonnoise_tokens = length - num_noise_tokens
670
+
671
+ # Sample the noise/non-noise span lengths and interleave them to
672
+ # generate the mask array.
673
+ # Note: We always start with a non-noise span.
674
+ def _sample_span_lengths(total_tokens: int, num_spans: int) -> np.ndarray:
675
+ """Samples lengths of num_spans segments.
676
+
677
+ Note: the combined length of segments equals total_tokens.
678
+ """
679
+ span_markers = np.less(np.arange(total_tokens - 1), num_spans -
680
+ 1)[np.random.permutation(total_tokens - 1)]
681
+ span_start_indicator = np.concatenate([np.array([0]), span_markers])
682
+ span_id = np.cumsum(span_start_indicator).reshape(-1, 1)
683
+ spans = np.arange(num_spans).reshape(1, -1)
684
+ span_lengths = np.sum(span_id == spans, axis=0)
685
+ return span_lengths
686
+
687
+ noise_span_lengths = _sample_span_lengths(num_noise_tokens, num_noise_spans)
688
+ nonnoise_span_lengths = _sample_span_lengths(num_nonnoise_tokens,
689
+ num_noise_spans)
690
+ interleaved_span_lengths = np.reshape(
691
+ np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
692
+ [num_noise_spans * 2])
693
+
694
+ span_starts = np.cumsum(interleaved_span_lengths)[:-1]
695
+ span_start_indicator = np.zeros(length)
696
+ span_start_indicator[span_starts] = 1
697
+ span_id = np.cumsum(span_start_indicator)
698
+ is_noise = np.equal(np.mod(span_id, 2), 1)
699
+
700
+ mask = is_noise[:starting_length]
701
+
702
+ return mask
703
+
704
+
705
+ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],
706
+ mask: np.ndarray,
707
+ use_sentinels: bool,
708
+ eos_token_id: int,
709
+ sentinel_token_ids: np.ndarray,
710
+ ensure_eos: bool = True) -> np.ndarray:
711
+ """Remove or replace masked portions from token sequence."""
712
+ if not use_sentinels:
713
+ # The logic is simple if we don't use sentinel tokens
714
+ noised_tokens = np.array(tokens)[np.logical_not(mask)]
715
+
716
+ # Ensure there's an end-of-sentence token at the end
717
+ if ensure_eos and (noised_tokens[-1] != eos_token_id):
718
+ noised_tokens = np.concatenate(
719
+ [noised_tokens, np.array([eos_token_id])])
720
+
721
+ return noised_tokens
722
+
723
+ # Masking at previous token
724
+ prev_token_mask = np.concatenate([np.array([0]), mask[:-1]])
725
+
726
+ # Decompose mask into start-of-span mask and non-start-of-span mask
727
+ start_of_noise_span_token = np.logical_and(mask,
728
+ np.logical_not(prev_token_mask))
729
+ nonstart_noise_span_token = np.logical_and(mask, prev_token_mask)
730
+
731
+ # Replace tokens at the start of each noise span with its corresponding
732
+ # sentinel token
733
+ sentinel_idx = np.minimum(len(sentinel_token_ids),
734
+ np.cumsum(start_of_noise_span_token)) - 1
735
+ tokens = np.where(start_of_noise_span_token,
736
+ sentinel_token_ids[sentinel_idx], tokens)
737
+
738
+ # Remove masked tokens (but preserving the sentinel tokens)
739
+ noised_tokens = tokens[np.logical_not(nonstart_noise_span_token)]
740
+
741
+ # Ensure there's an end-of-sentence token at the end
742
+ if ensure_eos and (noised_tokens[-1] != eos_token_id):
743
+ noised_tokens = np.concatenate(
744
+ [noised_tokens, np.array([eos_token_id])])
745
+ return noised_tokens
746
+
747
+
748
+ def _format_tokens_for_encoder_decoder(
749
+ tokens_inputs: torch.LongTensor,
750
+ tokens_labels: torch.LongTensor,
751
+ max_seq_length: int,
752
+ pad_token_id: int,
753
+ ) -> Dict[str, torch.Tensor]:
754
+ """Package the input/label sequence for an EncDec model."""
755
+ example = {}
756
+ # Re-populate with an empty, padded example
757
+ example['input_ids'] = torch.full((max_seq_length,),
758
+ pad_token_id,
759
+ dtype=torch.int32)
760
+ example['labels'] = torch.full((max_seq_length,),
761
+ _HF_IGNORE_INDEX,
762
+ dtype=torch.int32)
763
+ example['attention_mask'] = torch.zeros_like(example['input_ids'])
764
+ example['decoder_attention_mask'] = torch.zeros_like(example['labels'])
765
+
766
+ # Fill in with processed results (Note: EncDec format is right-padded)
767
+ example['input_ids'][:len(tokens_inputs)] = tokens_inputs
768
+ example['labels'][:len(tokens_labels)] = tokens_labels
769
+ example['attention_mask'][:len(tokens_inputs)] = 1
770
+ example['decoder_attention_mask'][:len(tokens_labels)] = 1
771
+
772
+ # Best practice is to include decoder_input_ids (= right-shifted labels)
773
+ example['decoder_input_ids'] = torch.full_like(example['labels'],
774
+ pad_token_id)
775
+ example['decoder_input_ids'][1:len(tokens_labels)] = tokens_labels[:-1]
776
+ return example
777
+
778
+
779
+ def _format_tokens_for_decoder_only(
780
+ tokens_inputs: torch.LongTensor,
781
+ tokens_labels: torch.LongTensor,
782
+ max_seq_length: int,
783
+ pad_token_id: int,
784
+ padding_side: str,
785
+ ) -> Dict[str, torch.Tensor]:
786
+ """Package the input/label sequence for an decoder-only model."""
787
+ example = {}
788
+ # Re-populate with an empty, padded example
789
+ example['input_ids'] = torch.full((max_seq_length,),
790
+ pad_token_id,
791
+ dtype=torch.int32)
792
+ example['labels'] = torch.full((max_seq_length,),
793
+ _HF_IGNORE_INDEX,
794
+ dtype=torch.int32)
795
+ example['attention_mask'] = torch.full((max_seq_length,),
796
+ 0,
797
+ dtype=torch.bool)
798
+ example['bidirectional_mask'] = torch.full((max_seq_length,),
799
+ 0,
800
+ dtype=torch.bool)
801
+
802
+ n_input = len(tokens_inputs)
803
+ n_label = len(tokens_labels)
804
+ n_concat = n_input + n_label
805
+ assert n_concat <= max_seq_length, f'{n_concat=}, {n_input=}, {n_label=}'
806
+
807
+ tokens_concat = torch.concat([tokens_inputs, tokens_labels], dim=0)
808
+
809
+ # Fill in with the processed results
810
+ if padding_side == 'left':
811
+ example['input_ids'][-n_concat:] = tokens_concat
812
+ # `labels` copies `input_ids` but with -100 at
813
+ # non-loss-generating tokens. `labels` will be shifted in the
814
+ # model code when computing loss.
815
+ example['labels'][-n_concat:] = tokens_concat
816
+ example['labels'][-n_concat:-n_label] = _HF_IGNORE_INDEX
817
+ example['attention_mask'][-n_concat:] = 1
818
+ example['bidirectional_mask'][-n_concat:-n_label] = 1
819
+ else:
820
+ example['input_ids'][:n_concat] = tokens_concat
821
+ # See above comment regarding `labels`
822
+ example['labels'][:n_concat] = tokens_concat
823
+ example['labels'][:n_input] = _HF_IGNORE_INDEX
824
+ example['attention_mask'][:n_concat] = 1
825
+ example['bidirectional_mask'][:n_input] = 1
826
+ return example
827
+
828
+
829
+ # Helpful to test if your dataloader is working locally
830
+ # Run `python denoising.py [local] [remote, optional]` and verify that batches
831
+ # are printed out
832
+ if __name__ == '__main__':
833
+ from llmfoundry.utils.builders import build_tokenizer
834
+
835
+ local = sys.argv[1]
836
+ if len(sys.argv) > 2:
837
+ remote = sys.argv[2]
838
+ else:
839
+ remote = local
840
+ print(f'Reading val split from {remote} -> {local}')
841
+
842
+ decoder_only = True
843
+
844
+ cfg = {
845
+ 'name': 'text_denoising',
846
+ 'dataset': {
847
+ 'local': local,
848
+ 'remote': remote,
849
+ 'split': 'val', # 'val_small',
850
+ 'shuffle': False,
851
+ 'max_seq_len': 2048 if decoder_only else 1024,
852
+ 'packing_ratio': 4.5,
853
+ 'predownload': 1000,
854
+ 'keep_zip': True, # in case we need compressed files after testing
855
+ },
856
+ 'mixture_of_denoisers': {
857
+ 'decoder_only_format': decoder_only,
858
+ 'span_mean_lengths_and_ratios': [[3, .15], [8, .5]],
859
+ 'sequence_mask_ratios': 0.25,
860
+ },
861
+ 'drop_last': False,
862
+ 'num_workers': 0,
863
+ }
864
+ cfg = om.create(cfg)
865
+ device_batch_size = 2
866
+
867
+ tokenizer_name = 'EleutherAI/gpt-neox-20b' if decoder_only else 't5-base'
868
+ tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len}
869
+ tokenizer = build_tokenizer(tokenizer_name=tokenizer_name,
870
+ tokenizer_kwargs=tokenizer_kwargs)
871
+
872
+ loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
873
+ assert isinstance(loader.dataset, StreamingTextDataset)
874
+
875
+ print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')
876
+
877
+ packing = cfg.dataset.get('packing_ratio') is not None
878
+ if packing:
879
+ tokenizer = loader.collate_fn.base_collator.tokenizer
880
+ else:
881
+ tokenizer = loader.collate_fn.tokenizer
882
+ batch_ix = 0
883
+ for batch in loader:
884
+ if batch_ix >= 50:
885
+ batch_ix += 1
886
+ break
887
+ if batch_ix >= 5:
888
+ if not packing:
889
+ break
890
+ batch_ix += 1
891
+ continue
892
+ print('\n')
893
+ print('#' * 20, f'Batch {batch_ix}', '#' * 20)
894
+ for k, v in batch.items():
895
+ print(k, v.shape, v.dtype)
896
+ for sample_ix, token_sample in enumerate(batch['input_ids']):
897
+ if cfg.mixture_of_denoisers.decoder_only_format:
898
+ labels = batch['labels'][sample_ix]
899
+ attn_inputs = batch['bidirectional_mask'][sample_ix].to(
900
+ torch.bool)
901
+ attn_full = batch['attention_mask'][sample_ix].to(torch.bool)
902
+ attn_labels = torch.logical_xor(attn_inputs, attn_full)
903
+ print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
904
+ if packing:
905
+ for subseq in range(
906
+ int(batch['sequence_id'][sample_ix].max()) + 1):
907
+ is_subseq = batch['sequence_id'][sample_ix] == subseq
908
+ print(
909
+ '\033[93m{}\033[00m\n'.format('Input: '),
910
+ tokenizer.decode(token_sample[torch.logical_and(
911
+ is_subseq, attn_inputs)]))
912
+ print(
913
+ '\033[92m{}\033[00m\n'.format('Target: '),
914
+ tokenizer.decode(labels[torch.logical_and(
915
+ is_subseq, attn_labels)]))
916
+ else:
917
+ print('\033[91m{}\033[00m\n'.format('Full: '),
918
+ tokenizer.decode(token_sample[attn_full]))
919
+ print('\033[93m{}\033[00m\n'.format('Input: '),
920
+ tokenizer.decode(token_sample[attn_inputs]))
921
+ print('\033[92m{}\033[00m\n'.format('Target: '),
922
+ tokenizer.decode(labels[attn_labels]))
923
+ else:
924
+ labels = batch['labels'][sample_ix]
925
+ attn_inputs = batch['attention_mask'][sample_ix].to(torch.bool)
926
+ attn_labels = batch['decoder_attention_mask'][sample_ix].to(
927
+ torch.bool)
928
+ print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
929
+ print('\033[93m{}\033[00m\n'.format('Input: '),
930
+ tokenizer.decode(token_sample[attn_inputs]))
931
+ print('\033[92m{}\033[00m\n'.format('Target: '),
932
+ tokenizer.decode(labels[attn_labels]))
933
+ batch_ix += 1
934
+
935
+ if packing:
936
+ print(f'Padding = {100*(1-loader.collate_fn.efficiency):5.2f}%')
937
+ print(f'Waste = {100*loader.collate_fn.waste:5.2f}%')
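Beyond the built-in `__main__` smoke test above, the span-corruption mask sampler can be checked in isolation. A quick sketch, assuming the module is importable as `llmfoundry.data.denoising`; the length and seed are arbitrary.

```python
# Sanity-check the span-corruption mask sampler defined above: for a length-512
# sequence with mask_ratio=0.15 and mean span length 3, roughly 15% of positions
# should come back masked, and position 0 is always unmasked.
import numpy as np

from llmfoundry.data.denoising import _sample_mask_array

np.random.seed(0)
mask = _sample_mask_array(length=512, mask_ratio=0.15, mean_span_length=3.0)
print(mask.shape)          # (512,)
print(bool(mask[0]))       # False -- sequences always start with a non-noise span
print(float(mask.mean()))  # ~0.15
```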
Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
5
+ from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
6
+
7
+ __all__ = ['Seq2SeqFinetuningCollator', 'build_finetuning_dataloader']
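A minimal, hedged usage sketch of the exported collator on a couple of hand-built, pre-tokenized examples; the GPT-NeoX tokenizer and the toy prompt are assumptions, and a real pipeline would get its examples from `build_finetuning_dataloader` instead.

```python
# Toy illustration of Seq2SeqFinetuningCollator in decoder-only mode.
from transformers import AutoTokenizer

from llmfoundry.data.finetuning import Seq2SeqFinetuningCollator

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
tokenizer.pad_token = tokenizer.eos_token  # the collator requires a pad token

collator = Seq2SeqFinetuningCollator(
    tokenizer=tokenizer,
    max_seq_len=128,
    decoder_only_format=True,
)

# Each example carries pre-tokenized context (input_ids) and target (labels).
prompt_ids = tokenizer('What is 2 + 2?', add_special_tokens=False).input_ids
target_ids = tokenizer(' 4', add_special_tokens=False).input_ids
examples = [{
    'input_ids': prompt_ids,
    'labels': target_ids,
    'attention_mask': [1] * len(prompt_ids),
}]

batch = collator(examples)
print({k: tuple(v.shape) for k, v in batch.items()})
```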
Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/collator.py ADDED
@@ -0,0 +1,343 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import logging
5
+ import warnings
6
+ from typing import Any, Dict, List, Optional, Union
7
+
8
+ import torch
9
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
10
+
11
+ log = logging.getLogger(__name__)
12
+
13
+ # HuggingFace hardcodes the ignore index to -100
14
+ _HF_IGNORE_INDEX = -100
15
+
16
+
17
+ class Seq2SeqFinetuningCollator:
18
+ """A general-purpose collator for sequence-to-sequence training/evaluation.
19
+
20
+ Args:
21
+ tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
22
+ max_seq_len (int): The maximum sequence length of the combined
23
+ context/target sequence (decoder-only format) or of each the
24
+ context sequence and target sequence (encoder-decoder format).
25
+ decoder_only_format (bool): Whether to format the batches for a
26
+ decoder-only model (if True) or an encoder-decoder model (if False).
27
+ allow_pad_trimming (bool, optional): Whether to allow the collator
28
+ to trim padding, which may result in smaller but inconsistent batch
29
+ sizes. Default: ``False`` ensures that all sequences are max_seq_len.
30
+ separator_text (str | bool, optional): If a string is provided, it will
31
+ be used to separate the context and target sequences (appended to end
32
+ of context). If ``True``, will use the tokenizer's sep_token, which must
33
+ be defined. Only applicable for decoder-only formatting.
34
+ format_for_generation (bool, optional): Whether to format the batch such
35
+ that context and target sequences remain separated, which is useful
36
+ when using the context to generate text which should be compared to the
37
+ target (e.g., during evaluation). Default: ``False``.
38
+ batch_metadata (dict, optional): A dictionary of metadata which will be added
39
+ to the batch.
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
45
+ max_seq_len: int,
46
+ decoder_only_format: bool,
47
+ allow_pad_trimming: bool = False,
48
+ separator_text: Optional[Union[str, bool]] = None,
49
+ format_for_generation: bool = False,
50
+ batch_metadata: Optional[Dict[str, Any]] = None,
51
+ ):
52
+ self.tokenizer = tokenizer
53
+ self.max_seq_len = max_seq_len
54
+ self.decoder_only_format = decoder_only_format
55
+ self.format_for_generation = format_for_generation
56
+ self.batch_metadata = batch_metadata or {}
57
+
58
+ # Trimming will always be skipped on at least the first __call__
59
+ self._allow_pad_trimming = allow_pad_trimming
60
+ self._seen_first_batch = False
61
+
62
+ illegal_keys = [
63
+ 'input_ids', 'labels', 'attention_mask', 'decoder_input_ids',
64
+ 'decoder_attention_mask', 'generate_output'
65
+ ]
66
+ found_keys = []
67
+ for illegal_key in illegal_keys:
68
+ if illegal_key in self.batch_metadata:
69
+ found_keys.append(illegal_key)
70
+ if found_keys:
71
+ raise ValueError(
72
+ f'The following keys are in batch_metadata but are not allowed: {", ".join(found_keys)}.\n' +\
73
+ f'You cannot use keys that are used directly by the models. The prohibited keys are:\n' +\
74
+ f'{", ".join(illegal_keys)}'
75
+ )
76
+ if self.format_for_generation:
77
+ self.batch_metadata['generate_output'] = True
78
+
79
+ if (max_seq_len % 8) != 0:
80
+ log.warning(
81
+ 'For performance, a max_seq_len as a multiple of 8 is recommended.'
82
+ )
83
+
84
+ if self.tokenizer.pad_token_id is None:
85
+ raise ValueError(
86
+ f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None'
87
+ )
88
+
89
+ self.separator_tokens = []
90
+ if separator_text and decoder_only_format:
91
+ if separator_text == True:
92
+ # Use the tokenizer's sep token or throw an error if undefined
93
+ if self.tokenizer.sep_token_id is None:
94
+ raise ValueError(
95
+ 'Setting separator_text=True requires that the tokenizer has sep_token_id but it has not been set. ' +\
96
+ 'Please pass a string argument for separator_text or set sep_token_id in the tokenizer.'
97
+ )
98
+ self.separator_tokens = [self.tokenizer.sep_token_id]
99
+ else:
100
+ # Convert the string separator_text into token(s)
101
+ self.separator_tokens = tokenizer(
102
+ separator_text, add_special_tokens=False).input_ids
103
+
104
+ self._warned_context = False
105
+ self._warned_target = False
106
+
107
+ def __call__(self, examples: List[Dict[str,
108
+ Any]]) -> Dict[str, torch.Tensor]:
109
+ for check_key in ['input_ids', 'labels', 'attention_mask']:
110
+ if check_key not in examples[0]:
111
+ raise KeyError(
112
+ f'Examples returned by dataset do not include required key: {check_key}'
113
+ )
114
+
115
+ if self.decoder_only_format:
116
+ batch = self._process_and_batch_decoder_only(examples)
117
+ else:
118
+ batch = self._process_and_batch_encoder_decoder(examples)
119
+
120
+ # Add any batch_metadata
121
+ batch_size = batch['input_ids'].shape[0]
122
+ batch.update({
123
+ k: torch.tensor([v] * batch_size)
124
+ for k, v in self.batch_metadata.items()
125
+ })
126
+
127
+ return batch
128
+
129
+ def _process_and_batch_decoder_only(
130
+ self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
131
+ # Steps explained in comments
132
+ processed_examples = []
133
+ for example in examples:
134
+ context = ensure_list(example['input_ids'])
135
+ target = ensure_list(example['labels'])
136
+ # First, get rid of any padding tokens
137
+ context = [t for t in context if t != self.tokenizer.pad_token_id]
138
+ target = [t for t in target if t != self.tokenizer.pad_token_id]
139
+ # Second, append any separator tokens to the context tokens
140
+ if self.separator_tokens:
141
+ context = context + self.separator_tokens
142
+ # Third, ensure that the target text ends with an eos tag
143
+ if target[-1] != self.tokenizer.eos_token_id:
144
+ target = target + [self.tokenizer.eos_token_id]
145
+
146
+ n_context = len(context)
147
+ n_target = len(target)
148
+
149
+ if n_context >= self.max_seq_len:
150
+ if not self._warned_context:
151
+ warnings.warn(
152
+ f'Skipping example because CONTEXT length={n_context} leaves no room ' +\
153
+ f'for TARGET tokens because max_seq_len={self.max_seq_len}. ' +\
154
+ f'If this causes downstream issues because of inconsistent batch sizes, ' +\
155
+ f'consider increasing max_seq_len or using example packing.'
156
+ )
157
+ self._warned_context = True
158
+ continue
159
+
160
+ if self.format_for_generation:
161
+ # When formatting for generation, we need to keep input_ids and
162
+ # labels separate. The input_ids (context) will be fed into the
163
+ # generator and the labels will be used by the eval metric.
164
+ input_ids = context[-self.max_seq_len:]
165
+ n_context = len(input_ids)
166
+ attention_mask = [1] * n_context
167
+ bidirectional_mask = [1] * n_context
168
+ # Annoyingly, we need to pad everything but input_ids
169
+ # and attention_mask ourselves
170
+ i_pad = [self.tokenizer.pad_token_id
171
+ ] * (self.max_seq_len - n_target)
172
+ z_pad = [0] * (self.max_seq_len - n_context)
173
+ if self.tokenizer.padding_side == 'left':
174
+ labels = i_pad + target
175
+ bidirectional_mask = z_pad + bidirectional_mask
176
+ else:
177
+ labels = target + i_pad
178
+ bidirectional_mask = bidirectional_mask + z_pad
179
+
180
+ else:
181
+ # We need to concatenate the context and target to get the
182
+ # full input sequence, cutting off any excess tokens from the
183
+ # end of the target
184
+ if n_context + n_target > self.max_seq_len:
185
+ old_n_target = int(n_target)
186
+ n_target = self.max_seq_len - n_context
187
+ if not self._warned_target:
188
+ warnings.warn(
189
+ f'Truncating TARGET sequence of length={old_n_target} to length={n_target}, ' +\
190
+ f'so context+target fit max_seq_len={self.max_seq_len}. If truncation is ' +\
191
+ f'a problem, consider increasing max_seq_len.')
192
+ self._warned_target = True
193
+ target = target[-n_target:]
194
+ target[-1] = self.tokenizer.eos_token_id
195
+ n_total = n_context + n_target
196
+
197
+ input_ids = context + target
198
+ labels = ([_HF_IGNORE_INDEX] * n_context) + target
199
+ attention_mask = [1] * n_total
200
+ # bidirectional_mask is used by our prefix lm model variants
201
+ bidirectional_mask = ([1] * n_context) + ([0] * n_target)
202
+
203
+ # Annoyingly, we need to pad everything but input_ids
204
+ # and attention_mask ourselves
205
+ i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
206
+ z_pad = [0] * (self.max_seq_len - n_total)
207
+ if self.tokenizer.padding_side == 'left':
208
+ labels = i_pad + labels
209
+ bidirectional_mask = z_pad + bidirectional_mask
210
+ else:
211
+ labels = labels + i_pad
212
+ bidirectional_mask = bidirectional_mask + z_pad
213
+
214
+ # Update the example
215
+ example['input_ids'] = input_ids
216
+ example['labels'] = labels
217
+ example['attention_mask'] = attention_mask
218
+ example['bidirectional_mask'] = bidirectional_mask
219
+
220
+ processed_examples.append(example)
221
+
222
+ batch = self.tokenizer.pad(
223
+ processed_examples,
224
+ padding='max_length',
225
+ max_length=self.max_seq_len,
226
+ return_tensors='pt',
227
+ )
228
+
229
+ # This logic prevents trimming on at least the first batch
230
+ if not (self._allow_pad_trimming and self._seen_first_batch):
231
+ self._seen_first_batch = True
232
+ return batch
233
+ self._seen_first_batch = True
234
+
235
+ # The batch is ready, but we can trim padding for efficiency
236
+ multiple_of = 8
237
+
238
+ n_non_padding = batch['attention_mask'].sum(dim=1).max()
239
+ keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
240
+ for k, v in batch.items():
241
+ if len(v.shape) < 2:
242
+ continue
243
+ if k == 'labels' and self.format_for_generation:
244
+ continue
245
+ if self.tokenizer.padding_side == 'left':
246
+ batch[k] = v[:, -keep_tokens:].contiguous()
247
+ else:
248
+ batch[k] = v[:, :keep_tokens].contiguous()
249
+
250
+ return batch
251
+
252
+ def _process_and_batch_encoder_decoder(
253
+ self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
254
+ # The encoder-decoder case has some gotchas.
255
+ # Steps are explained in comments.
256
+ processed_examples = []
257
+ for example in examples:
258
+ context = ensure_list(example['input_ids'])
259
+ target = ensure_list(example['labels'])
260
+ # ... first, get rid of any padding that was already applied
261
+ context = [t for t in context if t != self.tokenizer.pad_token_id]
262
+ target = [t for t in target if t != self.tokenizer.pad_token_id]
263
+ # ... second, ensure that the target text ends with an eos tag
264
+ if target[-1] != self.tokenizer.eos_token_id:
265
+ target = target + [self.tokenizer.eos_token_id]
266
+ # ... third, we need to pad labels ourselves, since the tokenizer's pad method does not pad them.
267
+ if len(target) < self.max_seq_len:
268
+ i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
269
+ target = target + i_pad
270
+ else:
271
+ if not self._warned_target:
272
+ warnings.warn(
273
+ f'Truncating TARGET sequence of length={len(target)} ' +\
274
+ f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
275
+ f'a problem, consider increasing max_seq_len.')
276
+ self._warned_target = True
277
+ target = target[:self.max_seq_len -
278
+ 1] + [self.tokenizer.eos_token_id]
279
+
280
+ # We might need to truncate the context. Preserve the beginning.
281
+ if len(context) > self.max_seq_len:
282
+ if not self._warned_context:
283
+ warnings.warn(
284
+ f'Truncating CONTEXT sequence of length={len(context)} ' +\
285
+ f'to max_seq_len={self.max_seq_len}. If truncation is ' +\
286
+ f'a problem, consider increasing max_seq_len.')
287
+ self._warned_context = True
288
+ context = context[:self.max_seq_len -
289
+ 1] + [self.tokenizer.eos_token_id]
290
+
291
+ # Back into the example
292
+ example['input_ids'] = context
293
+ example['attention_mask'] = [1] * len(context)
294
+ example['labels'] = target
295
+
296
+ processed_examples.append(example)
297
+
298
+ # Batch examples into a single dict (this also pads)
299
+ batch = self.tokenizer.pad(
300
+ processed_examples,
301
+ padding='max_length',
302
+ max_length=self.max_seq_len,
303
+ return_tensors='pt',
304
+ )
305
+ # We're still missing decoder_input_ids and decoder_attention_mask
306
+ batch['decoder_input_ids'] = torch.cat([
307
+ torch.full((len(processed_examples), 1),
308
+ self.tokenizer.pad_token_id), batch['labels'][:, :-1]
309
+ ],
310
+ dim=1)
311
+ batch['decoder_input_ids'].masked_fill_(
312
+ batch['decoder_input_ids'] == _HF_IGNORE_INDEX,
313
+ self.tokenizer.pad_token_id)
314
+ batch['decoder_attention_mask'] = torch.not_equal(
315
+ batch['labels'], _HF_IGNORE_INDEX)
316
+
317
+ # This logic prevents trimming on at least the first batch
318
+ if not (self._allow_pad_trimming and self._seen_first_batch):
319
+ self._seen_first_batch = True
320
+ return batch
321
+ self._seen_first_batch = True
322
+
323
+ # The batch is now valid, but we can trim padding for efficiency
324
+ multiple_of = 8
325
+ # (first for the encoder)
326
+ n_non_padding = batch['attention_mask'].sum(dim=1).max()
327
+ keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
328
+ for k in ['input_ids', 'attention_mask']:
329
+ batch[k] = batch[k][:, :keep_tokens].contiguous()
330
+ # (then for the decoder)
331
+ n_non_padding = batch['decoder_attention_mask'].sum(dim=1).max()
332
+ keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
333
+ for k in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
334
+ batch[k] = batch[k][:, :keep_tokens].contiguous()
335
+
336
+ return batch
337
+
338
+
339
+ def ensure_list(x: Union[List, torch.Tensor]) -> List:
340
+ if isinstance(x, torch.Tensor):
341
+ x = list(x.flatten())
342
+ assert isinstance(x, list)
343
+ return x
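
A minimal usage sketch of the Seq2SeqFinetuningCollator defined above, assuming a 'gpt2' tokenizer and a toy prompt/response pair (both illustrative, not values taken from this repository). Each example handed to the collator must already contain 'input_ids', 'labels', and 'attention_mask', which is exactly what _tokenize_formatted_example in tasks.py produces.

# Usage sketch (assumed tokenizer name and toy data; adjust to your setup).
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator

tokenizer = AutoTokenizer.from_pretrained('gpt2')
# The collator requires a pad token; reuse EOS, as build_finetuning_dataloader does.
tokenizer.pad_token = tokenizer.eos_token

collator = Seq2SeqFinetuningCollator(
    tokenizer=tokenizer,
    max_seq_len=64,
    decoder_only_format=True,
)

# Tokenize a prompt/response pair the same way tasks.py does.
example = tokenizer(text='What is 2 + 2?\n', text_target='4')
batch = collator([example, example])
print(batch['input_ids'].shape)  # torch.Size([2, 64]) while padding to max_seq_len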
Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/dataloader.py ADDED
@@ -0,0 +1,516 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import logging
4
+ import os
5
+ from typing import Tuple, Union
6
+
7
+ import datasets as hf_datasets
8
+ import torch
9
+ from composer.utils import dist, get_file, parse_uri
10
+ from omegaconf import DictConfig
11
+ from torch.utils.data import DataLoader
12
+ from transformers import PreTrainedTokenizerBase
13
+
14
+ from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
15
+ from llmfoundry.data.finetuning.tasks import dataset_constructor
16
+ from llmfoundry.data.packing import BinPackWrapper
17
+
18
+ log = logging.getLogger(__name__)
19
+
20
+ # HuggingFace hardcodes the ignore index to -100
21
+ _HF_IGNORE_INDEX = -100
22
+
23
+
24
+ def build_finetuning_dataloader(cfg: DictConfig,
25
+ tokenizer: PreTrainedTokenizerBase,
26
+ device_batch_size: int) -> DataLoader:
27
+ """Builds a finetuning dataloader for training or evaluating.
28
+
29
+ The underlying dataset can be built through one of two code paths:
30
+ 1. As a HuggingFace dataset, via `datasets.load_dataset(...)`
31
+ 2. As a streaming dataset
32
+ You will need to set slightly different dataset config fields depending
33
+ on which you intend to use, as explained below.
34
+
35
+ Args:
36
+ cfg (DictConfig): An omegaconf dictionary used to configure the loader:
37
+ cfg.name (str): The type of dataloader to build. Must = "finetuning".
38
+ ---
39
+ *** HuggingFace dataset config fields ***
40
+ cfg.dataset.hf_name (str, optional): The name of the HuggingFace dataset
41
+ to use. Can also be a remote http(s) directory or object store bucket
42
+ containing the file {split}.jsonl in the format (prompt, response),
43
+ in which case the builder will create a HuggingFace dataset.
44
+ cfg.dataset.hf_kwargs (DictConfig, optional): Additional kwargs to
45
+ pass to `datasets.load_dataset`, which can be used to load
46
+ a dataset from local files.
47
+ cfg.dataset.preprocessing_fn (str, optional): The name/import path of
48
+ the preprocessing function to use for formatting the data examples.
49
+ If ``None`` (default), the builder will use the preprocessing function
50
+ registered under `hf_name` (see `tasks.py`), if one exists,
51
+ otherwise it will skip preprocessing.
52
+ If `preprocessing_fn` corresponds to a registered preprocessing
53
+ function in `tasks.py`, the builder will use that.
54
+ Otherwise, it will interpret `preprocessing_fn` as an
55
+ "import.path:function_name" import path; e.g., it will call
56
+ `from import.path import function_name` and use the imported
57
+ function as the preprocessing function.
58
+ *** Streaming dataset config fields ***
59
+ cfg.dataset.remote (str, optional): Location of a MDS-formatted
60
+ streaming dataset to use. Setting this will tell the builder
61
+ to create a streaming dataset rather than a HuggingFace dataset.
62
+ cfg.dataset.local (str, optional): Local path where remote data
63
+ will be streamed to. Only valid if `cfg.dataset.remote` has
64
+ also been set.
65
+ *** Shared dataset configs fields ***
66
+ cfg.dataset.max_seq_len (int): The maximum length of sequences
67
+ in the batch. See :class:`Seq2SeqFinetuningCollator` docstring
68
+ for details.
69
+ cfg.dataset.decoder_only_format (bool): Whether to format the
70
+ examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator`
71
+ docstring for details.
72
+ cfg.dataset.allow_pad_trimming (bool, optional): Whether to allow
73
+ the collator to trim padding. See :class:`Seq2SeqFinetuningCollator`
74
+ docstring for details. Default: ``False``.
75
+ cfg.dataset.packing_ratio (float, optional): If provided, this invokes
76
+ a collator wrapper that packs `device_batch_size*packing_ratio`
77
+ raw examples into `device_batch_size` packed examples. This helps
78
+ minimize padding while preserving sequence integrity.
79
+ This adds `sequence_id` to the batch, which indicates which unique
80
+ sequence each token belongs to.
81
+ Note: Using this feature will not change device_batch_size but it
82
+ will determine the number of raw examples consumed by the dataloader
83
+ per batch. Some examples may be discarded if they do not fit when
84
+ packing.
85
+ Select `packing_ratio` **carefully** based on the dataset
86
+ statistics, `max_seq_len`, and tolerance for discarding samples!
87
+ The packing code in `../packing.py` provides a script that can help
88
+ you choose the best `packing_ratio`.
89
+ cfg.dataset.shuffle (bool): Whether to shuffle the dataset.
90
+ ___
91
+ See :class:`StreamingFinetuningDataset` for info on other standard config
92
+ options within `cfg.dataset` that will be passed as kwargs if
93
+ using the streaming codepath.
94
+ ---
95
+ See :class:`DataLoader` for standard argument options to the pytorch
96
+ dataloader, such as `cfg.drop_last`, `cfg.num_workers`, etc.
97
+ tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to
98
+ prepare the data from raw text. Any missing sentinel tokens will
99
+ be added by the collator.
100
+ device_batch_size (int): The size of the batches (number of examples)
101
+ that the dataloader will produce.
102
+
103
+ Returns:
104
+ A pytorch dataloader
105
+
106
+ Note:
107
+ You can run the script inside `../packing.py` to quickly test the
108
+ padding/waste rates for different `cfg.dataset.packing_ratio` choices,
109
+ given a starting workload YAML.
110
+ """
111
+ _validate_config(cfg.dataset)
112
+
113
+ # Use EOS as the pad token if none exists
114
+ if tokenizer.pad_token is None:
115
+ tokenizer.pad_token = tokenizer.eos_token
116
+
117
+ dataset = None # for pyright
118
+ if cfg.dataset.get('remote') is not None:
119
+ dataset = dataset_constructor.build_from_streaming(
120
+ tokenizer=tokenizer,
121
+ local=cfg.dataset.local,
122
+ remote=cfg.dataset.get('remote', None),
123
+ split=cfg.dataset.get('split', None),
124
+ download_retry=cfg.dataset.get('download_retry', 2),
125
+ download_timeout=cfg.dataset.get('download_timeout', 60),
126
+ validate_hash=cfg.dataset.get('validate_hash', None),
127
+ keep_zip=cfg.dataset.get('keep_zip', False),
128
+ epoch_size=cfg.dataset.get('epoch_size', None),
129
+ predownload=cfg.dataset.get('predownload', None),
130
+ cache_limit=cfg.dataset.get('cache_limit', None),
131
+ partition_algo=cfg.dataset.get('partition_algo', 'orig'),
132
+ num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
133
+ batch_size=device_batch_size,
134
+ shuffle=cfg.dataset.get('shuffle', False),
135
+ shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1b'),
136
+ shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
137
+ shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18),
138
+ sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
139
+ sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
140
+ batching_method=cfg.dataset.get('batching_method', 'random'),
141
+ )
142
+
143
+ collate_fn, dataloader_batch_size = _build_collate_fn(
144
+ cfg.dataset, tokenizer, device_batch_size)
145
+
146
+ return DataLoader(
147
+ dataset,
148
+ collate_fn=collate_fn,
149
+ batch_size=dataloader_batch_size,
150
+ drop_last=cfg.drop_last,
151
+ num_workers=cfg.num_workers,
152
+ pin_memory=cfg.get('pin_memory', True),
153
+ prefetch_factor=cfg.get('prefetch_factor', 2),
154
+ persistent_workers=cfg.get('persistent_workers', True),
155
+ timeout=cfg.get('timeout', 0),
156
+ )
157
+
158
+ else:
159
+ backend, _, _ = parse_uri(cfg.dataset.hf_name)
160
+ if backend not in ['', None]:
161
+ if cfg.dataset.get('split') is None:
162
+ raise ValueError(
163
+ 'When using a HuggingFace dataset from a URL, you must set the ' + \
164
+ '`split` key in the dataset config.'
165
+ )
166
+ dataset = _build_hf_dataset_from_remote(cfg, tokenizer)
167
+ else:
168
+ dataset = dataset_constructor.build_from_hf(
169
+ cfg.dataset,
170
+ max_seq_len=cfg.dataset.max_seq_len,
171
+ tokenizer=tokenizer,
172
+ )
173
+
174
+ collate_fn, dataloader_batch_size = _build_collate_fn(
175
+ cfg.dataset, tokenizer, device_batch_size)
176
+
177
+ if cfg.drop_last:
178
+ world_size = dist.get_world_size()
179
+ minimum_dataset_size = world_size * dataloader_batch_size
180
+ if hasattr(dataset, '__len__'):
181
+ full_dataset_size = len(dataset)
182
+ if full_dataset_size < minimum_dataset_size:
183
+ raise ValueError(
184
+ f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) '
185
+ +
186
+ f'has {full_dataset_size} samples, but your minimum batch size '
187
+ +
188
+ f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
189
+ +
190
+ f'your per device batch size is {dataloader_batch_size}. Please increase the number '
191
+ +
192
+ f'of samples in your dataset to at least {minimum_dataset_size}.'
193
+ )
194
+
195
+ assert dataset is not None
196
+ return DataLoader(
197
+ dataset,
198
+ collate_fn=collate_fn,
199
+ batch_size=dataloader_batch_size,
200
+ drop_last=cfg.drop_last,
201
+ sampler=dist.get_sampler(dataset,
202
+ drop_last=cfg.drop_last,
203
+ shuffle=cfg.dataset.shuffle),
204
+ num_workers=cfg.num_workers,
205
+ pin_memory=cfg.get('pin_memory', True),
206
+ prefetch_factor=cfg.get('prefetch_factor', 2),
207
+ persistent_workers=cfg.get('persistent_workers', True),
208
+ timeout=cfg.get('timeout', 0),
209
+ )
210
+
211
+
212
+ def _validate_config(dataset_cfg: DictConfig) -> None:
213
+ """Validates the dataset configuration.
214
+
215
+ Makes sure that the dataset is properly configured for either
216
+ a HuggingFace dataset or a streaming dataset. Must be valid for one or
217
+ the other.
218
+
219
+ Args:
220
+ dataset_cfg (DictConfig): The dataset configuration to be validated.
221
+
222
+ Raises:
223
+ ValueError: If the dataset configuration does not meet the requirements.
224
+ """
225
+ if dataset_cfg.get('hf_name') is not None:
226
+ # Using the HuggingFace dataset codepath
227
+ illegal_keys = ['local', 'remote']
228
+ discovered_illegal_keys = []
229
+ for key in illegal_keys:
230
+ if dataset_cfg.get(key) is not None:
231
+ discovered_illegal_keys.append('`' + key + '`')
232
+ if discovered_illegal_keys:
233
+ raise ValueError(
234
+ 'The dataset config sets a value for `hf_name` as well as the ' +\
235
+ f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
236
+ 'Those keys are used when building from a streaming dataset, but ' +\
237
+ 'setting `hf_name` instructs the dataset to build from a HuggingFace dataset.'
238
+ )
239
+ elif dataset_cfg.get('remote') is not None:
240
+ # Using the streaming dataset codepath
241
+ illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn']
242
+ discovered_illegal_keys = []
243
+ for key in illegal_keys:
244
+ if dataset_cfg.get(key) is not None:
245
+ discovered_illegal_keys.append('`' + key + '`')
246
+ if discovered_illegal_keys:
247
+ raise ValueError(
248
+ 'The dataset config sets a value for `remote` as well as the ' +\
249
+ f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
250
+ 'Those keys are used when building from a HuggingFace dataset, but ' +\
251
+ 'setting `remote` instructs the dataset to build from a streaming dataset.'
252
+ )
253
+ if dataset_cfg.get('local') is None:
254
+ raise ValueError(
255
+ 'Using a streaming dataset requires setting both `remote` and `local`, ' +\
256
+ 'but dataset.local is None.'
257
+ )
258
+ else:
259
+ raise ValueError(
260
+ 'In the dataset config, you must set either `hf_name` to use a ' +\
261
+ 'HuggingFace dataset or set `remote` to use a streaming ' +\
262
+ 'dataset, but both were None.'
263
+ )
264
+
265
+
266
+ def _build_hf_dataset_from_remote(
267
+ cfg: DictConfig, tokenizer: PreTrainedTokenizerBase
268
+ ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
269
+ hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
270
+ """Builds a dataset from a remote object store.
271
+
272
+ This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
273
+ the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this
274
+ dataset.
275
+
276
+ The function also ensures synchronicity across multiple processes during the file download. It creates a signal
277
+ file that is used to synchronize the start of the download across different processes. Once the download is
278
+ completed, the function removes the signal file.
279
+
280
+ Args:
281
+ cfg (DictConfig): The configuration dictionary containing the necessary parameters to load the dataset.
282
+ This includes:
283
+ - dataset.hf_name: The path of the HuggingFace dataset to download.
284
+ - dataset.split: The dataset split to download (e.g., 'train', 'validation', 'test').
285
+ - dataset.max_seq_len: The maximum sequence length for tokenizing the dataset.
286
+
287
+ tokenizer (Tokenizer): The tokenizer to be used to tokenize the dataset.
288
+
289
+ Returns:
290
+ Dataset: A HuggingFace dataset built from the remote file, prepared and tokenized for fine-tuning the model.
291
+
292
+ Raises:
293
+ FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
294
+ """
295
+ supported_extensions = ['jsonl', 'csv', 'parquet']
296
+ # HF datasets does not support a split with dashes, so we replace dashes
297
+ # with underscores in the destination split.
298
+ destination_split = cfg.dataset.split.replace('-', '_')
299
+ finetune_dir = os.path.join(
300
+ os.path.dirname(
301
+ os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
302
+ 'downloaded_finetuning',
303
+ destination_split if destination_split != 'data' else 'data_not',
304
+ )
305
+ os.makedirs(finetune_dir, exist_ok=True)
306
+ for extension in supported_extensions:
307
+ name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
308
+ destination = str(
309
+ os.path.abspath(
310
+ os.path.join(
311
+ finetune_dir, 'data',
312
+ f'{destination_split}-00000-of-00001.{extension}')))
313
+
314
+ # Since we don't know which of the supported extensions the file will have,
315
+ # wait on a signal file rather than on the desired file itself
316
+ signal_file_path = os.path.join(
317
+ finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed')
318
+ if dist.get_local_rank() == 0:
319
+ try:
320
+ get_file(path=name, destination=destination, overwrite=True)
321
+ except FileNotFoundError as e:
322
+ if extension == supported_extensions[-1]:
323
+ files_searched = [
324
+ f'{cfg.dataset.hf_name}/{cfg.dataset.split}.{ext}'
325
+ for ext in supported_extensions
326
+ ]
327
+ raise FileNotFoundError(
328
+ f'Could not find a file with any of ' + \
329
+ f'the supported extensions: {supported_extensions}\n' + \
330
+ f'at {files_searched}'
331
+ ) from e
332
+ else:
333
+ log.debug(
334
+ f'Could not find {name}, looking for another extension')
335
+ continue
336
+
337
+ os.makedirs(os.path.dirname(signal_file_path), exist_ok=True)
338
+ with open(signal_file_path, 'wb') as f:
339
+ f.write(b'local_rank0_completed_download')
340
+
341
+ # Avoid the collective call until the local rank zero has finished trying to download the dataset file
342
+ # so that we don't timeout for large downloads. This syncs all processes on the node
343
+ with dist.local_rank_zero_download_and_wait(signal_file_path):
344
+ # Then, wait to ensure every node has finished downloading the dataset file
345
+ dist.barrier()
346
+
347
+ # clean up signal file
348
+ if dist.get_local_rank() == 0:
349
+ os.remove(signal_file_path)
350
+ dist.barrier()
351
+
352
+ cfg.dataset.hf_name = finetune_dir
353
+ log.info(cfg.dataset)
354
+ dataset = dataset_constructor.build_from_hf(
355
+ cfg.dataset,
356
+ max_seq_len=cfg.dataset.max_seq_len,
357
+ tokenizer=tokenizer,
358
+ )
359
+ return dataset
360
+
361
+
362
+ def _build_collate_fn(
363
+ dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
364
+ device_batch_size: int
365
+ ) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]:
366
+ collate_fn = Seq2SeqFinetuningCollator(
367
+ tokenizer=tokenizer,
368
+ max_seq_len=dataset_cfg.max_seq_len,
369
+ decoder_only_format=dataset_cfg.decoder_only_format,
370
+ allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False),
371
+ )
372
+
373
+ packing_ratio = dataset_cfg.get('packing_ratio')
374
+ if packing_ratio is None:
375
+ if dataset_cfg.get('max_leftover_bins_to_keep') is not None:
376
+ raise ValueError(
377
+ 'dataset.max_leftover_bins_to_keep has been defined, ' +\
378
+ 'but dataset.packing_ratio has not been set. Please set ' +\
379
+ 'the latter to turn on packing or remove the former from the config.')
380
+ return collate_fn, device_batch_size
381
+
382
+ if packing_ratio == 1.0:
383
+ return collate_fn, device_batch_size
384
+ elif packing_ratio < 1.0:
385
+ raise ValueError('packing_ratio must be >= 1, if supplied')
386
+
387
+ if not dataset_cfg.decoder_only_format:
388
+ raise NotImplementedError(
389
+ 'On-the-fly packing is currently only supported for decoder-only formats.'
390
+ )
391
+
392
+ collate_fn = BinPackWrapper(
393
+ collator=collate_fn,
394
+ target_batch_size=device_batch_size,
395
+ max_seq_len=dataset_cfg.max_seq_len,
396
+ pad_token_id=tokenizer.pad_token_id,
397
+ padding_side=tokenizer.padding_side,
398
+ max_leftover_bins_to_keep=dataset_cfg.get('max_leftover_bins_to_keep'),
399
+ )
400
+ n_examples_to_pack = int(device_batch_size * packing_ratio)
401
+ return collate_fn, n_examples_to_pack
402
+
403
+
404
+ if __name__ == '__main__':
405
+ import torch
406
+ from omegaconf import OmegaConf as om
407
+
408
+ from llmfoundry.utils import build_tokenizer
409
+ cfg = om.create({
410
+ 'dataset': {
411
+ 'hf_name':
412
+ 'tatsu-lab/alpaca',
413
+ 'preprocessing_fn':
414
+ 'llmfoundry.data.finetuning.tasks:alpaca_preprocessing_function',
415
+ 'split':
416
+ 'train',
417
+ 'packing_ratio':
418
+ 18.0,
419
+ 'max_seq_len':
420
+ 2048,
421
+ 'decoder_only_format':
422
+ True,
423
+ 'separator_text':
424
+ False,
425
+ 'allow_pad_trimming':
426
+ False,
427
+ 'num_canonical_nodes':
428
+ 472,
429
+ 'shuffle':
430
+ True,
431
+ },
432
+ 'drop_last': False,
433
+ 'num_workers': 0,
434
+ 'pin_memory': False,
435
+ 'prefetch_factor': 2,
436
+ 'persistent_workers': False,
437
+ 'timeout': 0
438
+ })
439
+
440
+ tokenizer_name = 'EleutherAI/gpt-neox-20b'
441
+ tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len}
442
+ tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
443
+
444
+ device_batch_size = 2
445
+ dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
446
+
447
+ packing = cfg.dataset.get('packing_ratio') is not None
448
+
449
+ for i, batch in enumerate(dataloader):
450
+ if i >= 5:
451
+ break
452
+ print(f'-----Batch {i}-----')
453
+ for k, v in batch.items():
454
+ if isinstance(v, torch.Tensor):
455
+ print(k, v.shape)
456
+ else:
457
+ print(k, v)
458
+ for j in range(device_batch_size):
459
+ print(f'--- Sample {j} ---')
460
+ if cfg.dataset.decoder_only_format:
461
+ if packing:
462
+ for subseq in range(int(batch['sequence_id'][j].max()) + 1):
463
+ is_subseq = batch['sequence_id'][j] == subseq
464
+ print(
465
+ '\033[93m{}\033[00m\n'.format('INPUT IDS:'),
466
+ tokenizer.decode(batch['input_ids'][
467
+ j,
468
+ torch.logical_and(
469
+ is_subseq, batch['attention_mask'][j] ==
470
+ 1)],
471
+ skip_special_tokens=False))
472
+ print(
473
+ '\033[92m{}\033[00m\n'.format('CONTEXT: '),
474
+ tokenizer.decode(batch['input_ids'][
475
+ j,
476
+ torch.logical_and(
477
+ is_subseq, batch['bidirectional_mask'][j] ==
478
+ 1)],
479
+ skip_special_tokens=False))
480
+ print(
481
+ '\033[91m{}\033[00m\n'.format('TARGET: '),
482
+ tokenizer.decode(batch['input_ids'][
483
+ j,
484
+ torch.logical_and(
485
+ is_subseq,
486
+ batch['labels'][j] != _HF_IGNORE_INDEX)],
487
+ skip_special_tokens=False))
488
+ else:
489
+ print(
490
+ '\033[93m{}\033[00m\n'.format('INPUT IDS:'),
491
+ tokenizer.decode(
492
+ batch['input_ids'][j,
493
+ batch['attention_mask'][j] == 1],
494
+ skip_special_tokens=False))
495
+ print(
496
+ '\033[92m{}\033[00m\n'.format('CONTEXT: '),
497
+ tokenizer.decode(batch['input_ids'][
498
+ j, batch['bidirectional_mask'][j] == 1],
499
+ skip_special_tokens=False))
500
+ print(
501
+ '\033[91m{}\033[00m\n'.format('TARGET: '),
502
+ tokenizer.decode(batch['input_ids'][
503
+ j, batch['labels'][j] != _HF_IGNORE_INDEX],
504
+ skip_special_tokens=False))
505
+ else:
506
+ print(
507
+ '\033[92m{}\033[00m\n'.format('CONTEXT: '),
508
+ tokenizer.decode(
509
+ batch['input_ids'][j, batch['attention_mask'][j] == 1],
510
+ skip_special_tokens=False))
511
+ print(
512
+ '\033[91m{}\033[00m\n'.format('TARGET: '),
513
+ tokenizer.decode(batch['labels'][
514
+ j, batch['decoder_attention_mask'][j] == 1],
515
+ skip_special_tokens=False))
516
+ print(' ')
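
The __main__ demo above exercises the HuggingFace codepath (hf_name). For the streaming codepath described in the build_finetuning_dataloader docstring, the config instead sets `remote` and `local`. A sketch of such a config follows; the object store URI, local cache directory, worker count, and batch size are placeholders rather than values from this repository.

from omegaconf import OmegaConf as om

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.utils import build_tokenizer

streaming_cfg = om.create({
    'dataset': {
        # `remote`/`local` select the streaming codepath; `hf_name` must stay unset.
        'remote': 's3://my-bucket/my-finetuning-data',  # placeholder object store URI
        'local': '/tmp/my-finetuning-data',             # placeholder local cache dir
        'split': 'train',
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'shuffle': True,
    },
    'drop_last': False,
    'num_workers': 8,
})

tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b',
                            {'model_max_length': streaming_cfg.dataset.max_seq_len})
dataloader = build_finetuning_dataloader(streaming_cfg, tokenizer, device_batch_size=2)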
Perceptrix/finetune/build/lib/llmfoundry/data/finetuning/tasks.py ADDED
@@ -0,0 +1,433 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Includes code for task-specific seq-to-seq data formatting.
5
+
6
+ This file provides some templates/examples of preprocessing functions
7
+ that format examples for use in seq-to-seq finetuning tasks.
8
+ These preprocessing functions take individual examples that contain raw
9
+ text and process them into formatted examples.
10
+
11
+ These functions have this basic structure:
12
+
13
+ def preprocessing_fn(example: Dict) -> Dict[str, str]:
14
+ # code to extract prompt/response from `example`
15
+ ...
16
+ return {
17
+ 'prompt': <prompt>,
18
+ 'response': <response>,
19
+ }
20
+
21
+ where `<prompt>` is a placeholder for the prompt text string that you
22
+ extracted from the input example, and '<response>' is a placeholder for
23
+ the response text string.
24
+
25
+ Just to be clear, "prompt" represents the text you would give the model
26
+ at inference time, and "response" represents the text you are training
27
+ it to produce given the prompt.
28
+
29
+ The key requirement of these functions is that they return a dictionary
30
+ with "prompt" and "response" keys, and that the values associated with
31
+ those keys are strings (i.e. text).
32
+ """
33
+
34
+ import importlib
35
+ import logging
36
+ import os
37
+ import warnings
38
+ from typing import Any, Callable, Dict, List, Optional, Union
39
+
40
+ import datasets as hf_datasets
41
+ from omegaconf import DictConfig
42
+ from streaming import StreamingDataset
43
+ from transformers import PreTrainedTokenizerBase
44
+
45
+ log = logging.getLogger(__name__)
46
+
47
+ __all__ = ['dataset_constructor']
48
+
49
+
50
+ def _tokenize_formatted_example(
51
+ example: Dict[str, Any],
52
+ tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
53
+ if ('prompt' not in example) or ('response' not in example):
54
+ raise KeyError(
55
+ 'Unable to tokenize example because it has not been properly formatted. ' +\
56
+ '"prompt" and "response" are required keys but at least one was missing ' +\
57
+ f'from {example=}.'
58
+ )
59
+ if not isinstance(example['prompt'], str):
60
+ raise TypeError(
61
+ f'Unable to tokenize example because "prompt" was not a string. {example=}'
62
+ )
63
+ if not isinstance(example['response'], str):
64
+ raise TypeError(
65
+ f'Unable to tokenize example because "response" was not a string. {example=}'
66
+ )
67
+ return tokenizer(text=example['prompt'], text_target=example['response'])
68
+
69
+
70
+ class StreamingFinetuningDataset(StreamingDataset):
71
+ """Finetuning dataset with flexible tokenization using StreamingDataset.
72
+
73
+ Args:
74
+ tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
75
+ tokenize samples.
76
+ local (str): Local dataset directory where shards are cached by split.
77
+ remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
78
+ its data must exist locally. StreamingDataset uses either ``streams`` or
79
+ ``remote``/``local``. Defaults to ``None``.
80
+ split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
81
+ the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
82
+ download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
83
+ download_timeout (float): Number of seconds to wait for a shard to download before raising
84
+ an exception. Defaults to ``60``.
85
+ validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
86
+ shards. Defaults to ``None``.
87
+ keep_zip (bool): Whether to keep or delete the compressed form when decompressing
88
+ downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
89
+ ``False``.
90
+ epoch_size (int, optional): Number of samples to draw per epoch balanced across all
91
+ streams. If ``None``, takes its value from the total number of underlying samples.
92
+ Provide this field if you are weighting streams relatively to target a larger or
93
+ smaller epoch size. Defaults to ``None``.
94
+ predownload (int, optional): Target number of samples ahead to download the shards of while
95
+ iterating. Defaults to ``100_000``.
96
+ cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
97
+ shard cache. Before downloading a shard, the least recently used resident shard(s) may
98
+ be evicted (deleted from the local cache) in order to stay under the limit. Set to None
99
+ to disable shard eviction. Supports integer bytes as well as string human-readable
100
+ bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
101
+ partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
102
+ num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
103
+ resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
104
+ initial run.
105
+ batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
106
+ partitioned over the workers. Defaults to ``None``.
107
+ shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
108
+ ``False``.
109
+ shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
110
+ shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
111
+ shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
112
+ sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
113
+ Defaults to ``balanced``.
114
+ sampling_granularity (int): When picking samples for a stream's final partial repeat,
115
+ how many samples to pick from the same shard at a time (``1`` for evenly balanced
116
+ across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
117
+ Defaults to ``1``.
118
+ batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
119
+ ``per_stream``. Defaults to ``random``.
120
+ """
121
+
122
+ def __init__(self,
123
+ tokenizer: PreTrainedTokenizerBase,
124
+ local: str,
125
+ remote: Optional[str] = None,
126
+ split: Optional[str] = None,
127
+ download_retry: int = 2,
128
+ download_timeout: float = 60,
129
+ validate_hash: Optional[str] = None,
130
+ keep_zip: bool = False,
131
+ epoch_size: Optional[int] = None,
132
+ predownload: Optional[int] = None,
133
+ cache_limit: Optional[Union[int, str]] = None,
134
+ partition_algo: str = 'orig',
135
+ num_canonical_nodes: Optional[int] = None,
136
+ batch_size: Optional[int] = None,
137
+ shuffle: bool = False,
138
+ shuffle_algo: str = 'py1b',
139
+ shuffle_seed: int = 9176,
140
+ shuffle_block_size: int = 1 << 18,
141
+ sampling_method: str = 'balanced',
142
+ sampling_granularity: int = 1,
143
+ batching_method: str = 'random',
144
+ **kwargs: Any):
145
+
146
+ if len(kwargs) > 0:
147
+ raise ValueError(
148
+ f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}'
149
+ )
150
+
151
+ if remote is None or (local == remote):
152
+ if os.path.isdir(local):
153
+ contents = set(os.listdir(local))
154
+ if split not in contents:
155
+ raise ValueError(
156
+ f'local directory {local} does not contain split {split}'
157
+ )
158
+
159
+ # Build Dataset
160
+ super().__init__(
161
+ local=local,
162
+ remote=remote,
163
+ split=split,
164
+ download_retry=download_retry,
165
+ download_timeout=download_timeout,
166
+ validate_hash=validate_hash,
167
+ keep_zip=keep_zip,
168
+ epoch_size=epoch_size,
169
+ predownload=predownload,
170
+ cache_limit=cache_limit,
171
+ partition_algo=partition_algo,
172
+ num_canonical_nodes=num_canonical_nodes,
173
+ batch_size=batch_size,
174
+ shuffle=shuffle,
175
+ shuffle_algo=shuffle_algo,
176
+ shuffle_seed=shuffle_seed,
177
+ shuffle_block_size=shuffle_block_size,
178
+ sampling_method=sampling_method,
179
+ sampling_granularity=sampling_granularity,
180
+ batching_method=batching_method,
181
+ )
182
+
183
+ self.tokenizer = tokenizer
184
+
185
+ # How to process a sample
186
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
187
+ sample = super().__getitem__(idx)
188
+ return _tokenize_formatted_example(sample, tokenizer=self.tokenizer)
189
+
190
+
191
+ class DatasetConstructor:
192
+
193
+ def __init__(self):
194
+ self._task_preprocessing_registry: Dict[str, Callable] = {}
195
+
196
+ def register(self, *names: str) -> Callable[[Callable], Callable]:
197
+ """Decorator for registering preprocessing functions."""
198
+
199
+ def _register_func(name: str, func: Callable) -> None:
200
+ if name in self._task_preprocessing_registry:
201
+ raise ValueError(
202
+ f'A preprocessing function has already been registered with {name=}.'
203
+ )
204
+ self._task_preprocessing_registry[name] = func
205
+ return
206
+
207
+ def wrapper(func: Callable) -> Callable:
208
+ for name in names:
209
+ _register_func(name, func)
210
+ return func
211
+
212
+ return wrapper
213
+
214
+ def print_registered_tasks(self) -> None:
215
+ tasks = sorted(self._task_preprocessing_registry.keys())
216
+ print('\n'.join(tasks))
217
+
218
+ def get_preprocessing_fn_from_dict(
219
+ self, mapping: Union[Dict, DictConfig]
220
+ ) -> Callable[[Dict[str, Any]], Dict[str, str]]:
221
+ """Get a preprocessing function from a dictionary.
222
+
223
+ The dictionary maps column names in the dataset to "prompt" and "response".
224
+ For example,
225
+ ```yaml
226
+ preprocessing_fn:
227
+ prompt: text
228
+ response: summary
229
+ ```
230
+ would map the `text` column to the prompt and the `summary` column to the response.
231
+
232
+ Args:
233
+ mapping (dict): A dictionary mapping column names to "prompt" and "response".
234
+
235
+ Returns:
236
+ Callable: The preprocessing function.
237
+
238
+ Raises:
239
+ ValueError: If the mapping does not have keys "prompt" and "response".
240
+ """
241
+
242
+ def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]:
243
+ if list(mapping.keys()) != ['prompt', 'response']:
244
+ raise ValueError(
245
+ f'Expected {mapping=} to have keys "prompt" and "response".'
246
+ )
247
+ return {
248
+ 'prompt': example[mapping['prompt']],
249
+ 'response': example[mapping['response']]
250
+ }
251
+
252
+ return _preprocessor
253
+
254
+ def get_preprocessing_fn_from_str(
255
+ self,
256
+ preprocessor: Optional[str],
257
+ dataset_name: Optional[str] = None
258
+ ) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]:
259
+ """Get a preprocessing function from a string.
260
+
261
+ String can be either a registered function or an import path.
262
+
263
+ Args:
264
+ preprocessor (Optional[str]): The name of the preprocessing function, or an import path.
265
+ dataset_name (Optional[str]): The dataset name to look up in the registry.
266
+
267
+ Returns:
268
+ Callable: The preprocessing function or None if not found.
269
+
270
+ Raises:
271
+ ValueError: If the preprocessing function import from the provided string fails.
272
+ """
273
+ if preprocessor is None:
274
+ if dataset_name is None:
275
+ return None
276
+ if dataset_name in self._task_preprocessing_registry:
277
+ log.info(
278
+ f'Re-formatting dataset with "{dataset_name}" preprocessing function.'
279
+ )
280
+ return self._task_preprocessing_registry[dataset_name]
281
+ else:
282
+ log.info('No preprocessor was supplied and no preprocessing function ' +\
283
+ f'is registered for dataset name "{dataset_name}". No additional ' +\
284
+ 'preprocessing will be applied. If the dataset is already formatted ' +\
285
+ 'correctly, you can ignore this message.')
286
+ return None
287
+ if preprocessor in self._task_preprocessing_registry:
288
+ log.info(
289
+ f'Re-formatting dataset with "{preprocessor}" preprocessing function.'
290
+ )
291
+ return self._task_preprocessing_registry[preprocessor]
292
+
293
+ try:
294
+ import_path, function_name = preprocessor.split(':', maxsplit=1)
295
+ module = importlib.import_module(import_path)
296
+ preprocessing_fn = getattr(module, function_name)
297
+ except Exception as e:
298
+ raise ValueError(
299
+ f'Failed to import preprocessing function from string = {preprocessor}.'
300
+ ) from e
301
+
302
+ return preprocessing_fn
303
+
304
+ def build_from_hf(
305
+ self, cfg: DictConfig, max_seq_len: int,
306
+ tokenizer: PreTrainedTokenizerBase
307
+ ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
308
+ hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
309
+ """Load a HuggingFace Datasets, preprocess, and tokenize.
310
+
311
+ Note: This function will drop examples where the prompt is longer than the max_seq_len
312
+
313
+ Args:
314
+ cfg (DictConfig): The dataset configuration.
315
+ max_seq_len (int): The maximum sequence length. Examples with prompts longer than this will be dropped.
316
+ tokenizer (Tokenizer): The tokenizer to be used for tokenizing the dataset.
317
+
318
+ Returns:
319
+ Dataset: The tokenized dataset.
320
+ """
321
+ dataset_name = cfg.hf_name
322
+ # HF datasets does not support a split with dashes, so we replace split
323
+ # dashes with underscores.
324
+ split = cfg.split.replace('-', '_')
325
+ kwargs = cfg.get('hf_kwargs', {})
326
+ proto_preprocessing_fn = cfg.get('preprocessing_fn')
327
+ if isinstance(proto_preprocessing_fn, dict) or isinstance(
328
+ proto_preprocessing_fn, DictConfig):
329
+ preprocessing_fn = self.get_preprocessing_fn_from_dict(
330
+ proto_preprocessing_fn)
331
+ else:
332
+ preprocessing_fn = self.get_preprocessing_fn_from_str(
333
+ proto_preprocessing_fn, dataset_name)
334
+
335
+ dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)
336
+
337
+ def dataset_mapper(example: Dict):
338
+ if preprocessing_fn is not None:
339
+ example = preprocessing_fn(example)
340
+ return _tokenize_formatted_example(example, tokenizer)
341
+
342
+ columns_to_remove = list(dataset[0].keys())
343
+ tokenized_dataset = dataset.map(
344
+ dataset_mapper,
345
+ batched=False,
346
+ remove_columns=columns_to_remove,
347
+ )
348
+ prompt_length_filtered_dataset = tokenized_dataset.filter(
349
+ lambda example: len(example['input_ids']) < max_seq_len)
350
+
351
+ examples_removed = len(tokenized_dataset) - len(
352
+ prompt_length_filtered_dataset)
353
+ if examples_removed > 0:
354
+ warnings.warn(
355
+ f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
356
+ )
357
+
358
+ empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
359
+ lambda example: len(example['input_ids']) > 0 and len(example[
360
+ 'labels']) > 0 and any(token_id != tokenizer.pad_token_id
361
+ for token_id in example['labels']))
362
+ empty_examples_removed = len(prompt_length_filtered_dataset) - len(
363
+ empty_examples_dropped_dataset)
364
+ if empty_examples_removed > 0:
365
+ warnings.warn(
366
+ f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
367
+ + 'or the response was only padding tokens.')
368
+
369
+ return empty_examples_dropped_dataset
370
+
371
+ def build_from_streaming(self, *args: Any,
372
+ **kwargs: Any) -> StreamingFinetuningDataset:
373
+ return StreamingFinetuningDataset(*args, **kwargs)
374
+
375
+
376
+ dataset_constructor = DatasetConstructor()
377
+
378
+
379
+ @dataset_constructor.register('tatsu-lab/alpaca')
380
+ def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
381
+ """Split out prompt/response from text."""
382
+ try:
383
+ prompt, response = inp['text'].split('### Response:')
384
+ prompt += '### Response:'
385
+ except Exception as e:
386
+ raise ValueError(
387
+ f"Unable to extract prompt/response from 'text'={inp['text']}"
388
+ ) from e
389
+ return {'prompt': prompt, 'response': response}
390
+
391
+
392
+ @dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k')
393
+ def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
394
+ """Format the text string."""
395
+ PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
396
+ try:
397
+ if inp['input'] != '':
398
+ instruction = inp['instruction'] + '\n' + inp['input']
399
+ else:
400
+ instruction = inp['instruction']
401
+ prompt = PROMPT_FORMAT.format(instruction=instruction)
402
+ response = inp['output']
403
+ except Exception as e:
404
+ raise ValueError(
405
+ f'Unable to extract prompt/response from {inp=}') from e
406
+ return {'prompt': prompt, 'response': response}
407
+
408
+
409
+ @dataset_constructor.register('bigscience/P3')
410
+ def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
411
+ """Format the already-split example."""
412
+ return {
413
+ 'prompt': inp['inputs'] + ':',
414
+ 'response': inp['targets'],
415
+ }
416
+
417
+
418
+ # Muennighoff's P3 and flan datasets share a similar convention
419
+ @dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan')
420
+ def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
421
+ """Format the already-split example."""
422
+ try:
423
+ prompt: str = inp['inputs']
424
+ response: str = inp['targets']
425
+ # Put a space before the response if needed
426
+ transitions = (' ', '\n', '\t')
427
+ if not (prompt.endswith(transitions) or
428
+ response.startswith(transitions)):
429
+ response = ' ' + response
430
+ except Exception as e:
431
+ raise ValueError(
432
+ f'Unable to process prompt/response from {inp=}') from e
433
+ return {'prompt': prompt, 'response': response}
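
A custom dataset can plug into the same registry used by the functions above. The sketch below assumes a hypothetical dataset whose rows carry 'question' and 'answer' columns; the registry name and column names are illustrative only.

from typing import Dict

from llmfoundry.data.finetuning.tasks import dataset_constructor


@dataset_constructor.register('my-org/my-qa-dataset')  # hypothetical hf_name
def qa_preprocessing_function(inp: Dict) -> Dict[str, str]:
    """Map raw columns onto the required prompt/response keys."""
    return {
        'prompt': inp['question'] + '\n### Response:\n',
        'response': inp['answer'],
    }

# The same mapping can also be expressed declaratively in the dataset config:
#   preprocessing_fn:
#     prompt: question
#     response: answer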
Perceptrix/finetune/build/lib/llmfoundry/data/packing.py ADDED
@@ -0,0 +1,423 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import os
5
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ from omegaconf import DictConfig
10
+ from transformers import PreTrainedTokenizerBase
11
+
12
+
13
+ class BinPackWrapper:
14
+ """Utility collator for packing to reduce padding."""
15
+
16
+ def __init__(self,
17
+ collator: Callable,
18
+ target_batch_size: int,
19
+ max_seq_len: int,
20
+ pad_token_id: int,
21
+ padding_side: Literal['left', 'right'],
22
+ max_leftover_bins_to_keep: Optional[int] = None):
23
+ self.base_collator = collator
24
+ self.out_size = int(target_batch_size)
25
+ self.max_seq_len = int(max_seq_len)
26
+ self.pad_token_id = int(pad_token_id)
27
+ self.padding_side = padding_side
28
+
29
+ if self.out_size <= 0:
30
+ raise ValueError(f'{target_batch_size=} must be >0.')
31
+ if self.max_seq_len <= 0:
32
+ raise ValueError(f'{max_seq_len=} must be >0.')
33
+ if self.pad_token_id < 0:
34
+ raise ValueError(f'{pad_token_id=} must be >=0.')
35
+
36
+ if max_leftover_bins_to_keep is None:
37
+ self.max_leftover_bins_to_keep = int(10 * self.out_size)
38
+ elif max_leftover_bins_to_keep < 0:
39
+ raise ValueError(
40
+ f'{max_leftover_bins_to_keep=} must be >=0 or None.')
41
+ else:
42
+ self.max_leftover_bins_to_keep = int(max_leftover_bins_to_keep)
43
+
44
+ self.n_packed_tokens = 0
45
+ self.n_total_tokens = 0
46
+ self.n_packed_examples = 0
47
+
48
+ self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = []
49
+
50
+ @property
51
+ def waste(self) -> float:
52
+ return 1 - (self.n_packed_tokens / self.n_total_tokens)
53
+
54
+ @property
55
+ def efficiency(self) -> float:
56
+ return self.n_packed_tokens / (self.max_seq_len *
57
+ self.n_packed_examples)
58
+
59
+ def __call__(
60
+ self,
61
+ examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
62
+ batch = self.base_collator(examples)
63
+
64
+ assert 'attention_mask' in batch
65
+ assert 'input_ids' in batch
66
+
67
+ for key in batch.keys():
68
+ assert key in [
69
+ 'input_ids',
70
+ 'labels',
71
+ 'attention_mask',
72
+ 'bidirectional_mask',
73
+ ]
74
+
75
+ # Cut everything down to size
76
+ sizes, trimmed_examples = [], []
77
+ for idx in range(batch['attention_mask'].shape[0]):
78
+ size, trimmed_example = extract_trim_batch_idx(batch, idx)
79
+ sizes.append(size)
80
+ trimmed_examples.append(trimmed_example)
81
+
82
+ # Apply our CS 101 bin packing algorithm.
83
+ packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = first_fit_bin_packing(
84
+ sizes=sizes,
85
+ examples=trimmed_examples,
86
+ num_bins=self.out_size,
87
+ max_bin_size=self.max_seq_len,
88
+ existing_bins=self._leftover_bins,
89
+ )
90
+ self.n_packed_tokens += n_packed_tokens
91
+ self.n_total_tokens += n_total_tokens
92
+ self.n_packed_examples += self.out_size
93
+ self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]
94
+
95
+ # Re-pad to max_seq_len and batch
96
+ batch = repad(packed_examples,
97
+ max_seq_len=self.max_seq_len,
98
+ pad_token_id=self.pad_token_id,
99
+ padding_side=self.padding_side)
100
+ return batch
101
+
102
+
103
+ def extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
104
+ idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
105
+ example = {k: v[idx] for k, v in batch.items()}
106
+
107
+ keep = example['attention_mask'] == 1
108
+ size = int(keep.sum())
109
+ trim_example = {k: v[keep] for k, v in example.items()}
110
+ trim_example['sequence_id'] = torch.zeros_like(trim_example['input_ids'])
111
+
112
+ return size, trim_example
113
+
114
+
115
+ def combine_in_place(
116
+ example: Dict[str, torch.Tensor],
117
+ add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
118
+ if 'labels' in add_on:
119
+ # Prevents the last token in example from being trained to
120
+ # predict the first token in add_on, which would make no sense.
121
+ add_on['labels'][0] = -100
122
+
123
+ for k in example.keys():
124
+ if k == 'sequence_id':
125
+ example[k] = torch.cat(
126
+ [example[k], add_on[k] + 1 + torch.max(example[k])])
127
+ else:
128
+ example[k] = torch.cat([example[k], add_on[k]])
129
+ return example
130
+
131
+
132
+ def first_fit_bin_packing(
133
+ sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int,
134
+ max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]]
135
+ ) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[
136
+ str, torch.Tensor]]]]:
137
+
138
+ # Will contain tuples (bin_size, packed_example)
139
+ bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins
140
+
141
+ starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins])
142
+
143
+ sizes_and_examples = [
144
+ (size, example) for size, example in zip(sizes, examples)
145
+ ]
146
+ sorted_sizes_and_examples = sorted(sizes_and_examples,
147
+ key=lambda x: x[0],
148
+ reverse=True)
149
+
150
+ required_num_examples = max(0, num_bins - len(bins))
151
+ num_examples = len(sizes)
152
+ if num_examples < required_num_examples:
153
+ for size, example in sorted_sizes_and_examples:
154
+ # Can't keep packing. All remaining items get their own bin.
155
+ bins.append((size, example))
156
+
157
+ total_bin_sizes = sum([bin_size for bin_size, _ in bins])
158
+ total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
159
+ total_example_sizes = sum(sizes)
160
+ if total_new_bin_sizes != total_example_sizes:
161
+ raise AssertionError(
162
+ f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.'
163
+ )
164
+
165
+ sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
166
+ bin_sizes, packed_examples = [], []
167
+ for bin_size, packed_example in sorted_bins:
168
+ bin_sizes.append(bin_size)
169
+ packed_examples.append(packed_example)
170
+
171
+ # Return:
172
+ # - the num_bins largest packed examples
173
+ # - the total tokens in those examples
174
+ # - the total size of all new examples
175
+ # - leftover bins
176
+ return packed_examples[:num_bins], sum(
177
+ bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]
178
+
179
+ # Go through each item from longest to shortest.
180
+ # Note: all items will either go into an existing or new bin.
181
+ for i, (size, example) in enumerate(sorted_sizes_and_examples):
182
+ # If we can't keep packing, all remaining items get their own bin.
183
+ required_num_examples = max(0, num_bins - len(bins))
184
+ n_remaining = num_examples - i
185
+ assert n_remaining >= required_num_examples
186
+ if n_remaining == required_num_examples:
187
+ # Can't keep packing. All remaining items get their own bin.
188
+ bins.append((size, example))
189
+ continue
190
+
191
+ # Add it to the first bin it fits in
192
+ added = False
193
+ for bidx in range(len(bins)):
194
+ if bins[bidx][0] + size <= max_bin_size:
195
+ bin_size, packed_example = bins.pop(bidx)
196
+ bin_size = bin_size + size
197
+ packed_example = combine_in_place(packed_example, example)
198
+ bins.append((bin_size, packed_example))
199
+ added = True
200
+ break
201
+ # If it didn't fit anywhere, open a new bin
202
+ if not added:
203
+ bins.append((size, example))
204
+
205
+ total_bin_sizes = sum([bin_size for bin_size, _ in bins])
206
+ total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
207
+ total_example_sizes = sum(sizes)
208
+ if total_new_bin_sizes != total_example_sizes:
209
+ raise AssertionError(
210
+ f'Error in packing. {total_example_sizes=} does not equal {total_new_bin_sizes=}.'
211
+ )
212
+
213
+ sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
214
+ bin_sizes, packed_examples = [], []
215
+ for bin_size, packed_example in sorted_bins:
216
+ bin_sizes.append(bin_size)
217
+ packed_examples.append(packed_example)
218
+
219
+ # Return:
220
+ # - the num_bins largest packed examples
221
+ # - the total tokens in those examples
222
+ # - the total size of all new examples
223
+ # - leftover bins
224
+ return packed_examples[:num_bins], sum(
225
+ bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]
226
+
227
+
228
+ def repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int,
229
+ pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
230
+
231
+ def pad_tensor(tensor: torch.Tensor, pad_value: int):
232
+ if len(tensor) == max_seq_len:
233
+ return tensor
234
+ t = torch.full((max_seq_len,),
235
+ pad_value,
236
+ dtype=tensor.dtype,
237
+ device=tensor.device)
238
+ if padding_side == 'left':
239
+ t[-len(tensor):] = tensor
240
+ elif padding_side == 'right':
241
+ t[:len(tensor)] = tensor
242
+ else:
243
+ raise ValueError(f'Unknown {padding_side=}')
244
+ return t
245
+
246
+ pad_vals = {
247
+ 'input_ids': pad_token_id,
248
+ 'labels': -100,
249
+ 'attention_mask': 0,
250
+ 'bidirectional_mask': 0,
251
+ 'sequence_id': -1,
252
+ }
253
+ keys = packed_examples[0].keys()
254
+ batch = {}
255
+ for key in keys:
256
+ batch[key] = torch.stack([
257
+ pad_tensor(example[key], pad_vals[key])
258
+ for example in packed_examples
259
+ ])
260
+ return batch
261
+
262
+
263
+ if __name__ == '__main__':
264
+ from argparse import ArgumentParser, Namespace
265
+
266
+ from omegaconf import OmegaConf as om
267
+
268
+ from llmfoundry import (build_finetuning_dataloader,
269
+ build_text_denoising_dataloader)
270
+ from llmfoundry.data import build_text_dataloader
271
+ from llmfoundry.utils import build_tokenizer
272
+
273
+ def parse_args() -> Namespace:
274
+ """Parse commandline arguments."""
275
+ parser = ArgumentParser(
276
+ description=
277
+ 'Profile packing_ratio choices for a particular workload.')
278
+ parser.add_argument(
279
+ '--yaml-path',
280
+ type=str,
281
+ required=True,
282
+ help='Path to the YAML that defines the workload to profile.')
283
+ parser.add_argument('--num-devices',
284
+ type=int,
285
+ default=None,
286
+ help='How many devices your run will use.')
287
+ parser.add_argument('--min',
288
+ type=float,
289
+ required=True,
290
+ help='Smallest packing_ratio to test. Must be >=1.')
291
+ parser.add_argument(
292
+ '--max',
293
+ type=float,
294
+ required=True,
295
+ help='Largest packing_ratio to test. Must be larger than `min`.')
296
+ parser.add_argument(
297
+ '--num-packing-ratios',
298
+ type=int,
299
+ default=10,
300
+ help=
301
+ 'Number of packing_ratio values (spaced between `min` and `max`) to try.'
302
+ )
303
+
304
+ args = parser.parse_args()
305
+
306
+ if not os.path.isfile(args.yaml_path):
307
+ raise FileNotFoundError(
308
+ '`yaml_path` does not correspond to any existing file.')
309
+ if args.num_devices < 1:
310
+ raise ValueError('`num_devices` must be a positive integer.')
311
+ if args.min < 1.0:
312
+ raise ValueError('`min` must be >=1.0.')
313
+ if args.max < args.min:
314
+ raise ValueError('`max` cannot be less than `min`.')
315
+ if args.num_packing_ratios < 1:
316
+ raise ValueError('`num_packing_ratios` must be a positive integer.')
317
+ return args
318
+
319
+ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
320
+ device_batch_size: int):
321
+ if cfg.name == 'text':
322
+ return build_text_dataloader(cfg, tokenizer, device_batch_size)
323
+ elif cfg.name == 'text_denoising':
324
+ return build_text_denoising_dataloader(cfg, tokenizer,
325
+ device_batch_size)
326
+ elif cfg.name == 'finetuning':
327
+ return build_finetuning_dataloader(cfg, tokenizer,
328
+ device_batch_size)
329
+ else:
330
+ raise ValueError(
331
+ f'Not sure how to build dataloader with config: {cfg}')
332
+
333
+ args = parse_args()
334
+
335
+ with open(args.yaml_path) as f:
336
+ cfg = om.load(f)
337
+ if 'parameters' in cfg:
338
+ cfg = om.to_container(cfg.parameters)
339
+ cfg = om.create(cfg)
340
+ device_batch_size = cfg.global_train_batch_size // args.num_devices
341
+
342
+ # Determine the packing_ratio values we'll try
343
+ packing_ratios, raw_batch_sizes = [], []
344
+ for packing_ratio in np.linspace(args.min,
345
+ args.max,
346
+ args.num_packing_ratios,
347
+ endpoint=True):
348
+ packing_ratio = np.round(10 * packing_ratio) / 10
349
+ raw_batch_size = int(packing_ratio * device_batch_size)
350
+ if raw_batch_size not in raw_batch_sizes:
351
+ packing_ratios.append(packing_ratio)
352
+ raw_batch_sizes.append(raw_batch_size)
353
+
354
+ # Fetch a bunch of raw examples once, which we'll re-use
355
+ if 'train_loader' not in cfg:
356
+ raise ValueError('config must define train_loader')
357
+ dataloader_cfg = cfg.train_loader
358
+
359
+ max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep',
360
+ None)
361
+
362
+ # build tokenizer
363
+ if 'tokenizer' not in cfg:
364
+ raise ValueError('config must define tokenizer')
365
+
366
+ resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True)
367
+ if not isinstance(resolved_tokenizer_cfg, Dict):
368
+ raise ValueError(
369
+ 'tokenizer config needs to be resolved by omegaconf into a Dict.')
370
+ tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg
371
+
372
+ tokenizer_name = tokenizer_cfg['name']
373
+ tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
374
+ tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
375
+
376
+ # Turn off packing for the dataloader (we want raw, pre-packed examples)
377
+ dataloader_cfg.dataset.packing_ratio = None
378
+ dataloader_cfg.dataset.max_leftovers_to_keep = None
379
+ train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
380
+ max(raw_batch_sizes) * 100)
381
+
382
+ # Get a bunch of raw examples
383
+ big_batch = next(iter(train_dataloader))
384
+
385
+ def split_big_batch(raw_batch_size: int) -> List:
386
+ input_ids = big_batch['input_ids'].split(raw_batch_size)
387
+ batches = [{'input_ids': x} for x in input_ids]
388
+
389
+ for key in big_batch.keys():
390
+ if key == 'input_ids':
391
+ continue
392
+ for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
393
+ batches[idx].update({key: split})
394
+ return batches
395
+
396
+ def profile_packing(raw_batch_size: int) -> Tuple[float, float]:
397
+ packer = BinPackWrapper(
398
+ collator=lambda x: x,
399
+ target_batch_size=device_batch_size,
400
+ max_seq_len=dataloader_cfg.dataset.max_seq_len,
401
+ pad_token_id=0, # <-- Doesn't need to be correct for profiling
402
+ padding_side='left', # <-- Doesn't need to be correct for profiling
403
+ max_leftover_bins_to_keep=max_leftovers_to_keep)
404
+
405
+ # Simulate feeding the packing collator a bunch of data
406
+ for batch in split_big_batch(raw_batch_size):
407
+ if batch['input_ids'].shape[0] < device_batch_size:
408
+ continue
409
+ _ = packer(batch)
410
+
411
+ # Return the padding / waste stats over that bunch of data
412
+ padding_percent = 100 * (1 - packer.efficiency)
413
+ waste_percent = 100 * packer.waste
414
+ return padding_percent, waste_percent
415
+
416
+ header = '\n\n\n packing_ratio | % PADDING | % WASTE'
417
+ fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%'
418
+
419
+ print(header)
420
+ print('-' * len(header))
421
+ for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
422
+ padding, waste = profile_packing(raw_batch_size)
423
+ print(fstr.format(packing_ratio, padding, waste))
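As a rough illustration of the first-fit-decreasing strategy that `first_fit_bin_packing` implements above, the toy sketch below packs plain integer lengths into bins with a fixed token budget. It is not part of the repo's API; the sizes and the 1024-token budget are invented for the example.

```python
# Toy sketch of first-fit-decreasing bin packing on plain integer lengths.
# This only illustrates the grouping strategy used by first_fit_bin_packing;
# it is not the repo's API. Sizes and max_bin_size are made-up values.
from typing import List, Tuple


def first_fit_decreasing(sizes: List[int], max_bin_size: int) -> List[List[int]]:
    bins: List[Tuple[int, List[int]]] = []  # (tokens used, item sizes)
    for size in sorted(sizes, reverse=True):
        for i, (used, items) in enumerate(bins):
            if used + size <= max_bin_size:
                bins[i] = (used + size, items + [size])  # fits: add to this bin
                break
        else:
            bins.append((size, [size]))  # didn't fit anywhere: open a new bin
    return [items for _, items in bins]


if __name__ == '__main__':
    # Pack sequences of length 900, 700, 400, 300, 100 under a 1024-token budget.
    print(first_fit_decreasing([900, 700, 400, 300, 100], max_bin_size=1024))
    # -> [[900, 100], [700, 300], [400]]
```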
Perceptrix/finetune/build/lib/llmfoundry/data/text_data.py ADDED
@@ -0,0 +1,367 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Build a StreamingTextDataset dataset and dataloader for training."""
5
+
6
+ import os
7
+ from itertools import islice
8
+ from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence,
9
+ Union, cast)
10
+
11
+ import numpy as np
12
+ import torch
13
+ import transformers
14
+ from omegaconf import DictConfig
15
+ from omegaconf import OmegaConf as om
16
+ from streaming import Stream, StreamingDataset
17
+ from torch.utils.data import DataLoader
18
+ from transformers import PreTrainedTokenizerBase
19
+
20
+
21
+ class StreamingTextDataset(StreamingDataset):
22
+ """Generic text dataset using MosaicML's StreamingDataset.
23
+
24
+ Args:
25
+ tokenizer (Tokenizer): HuggingFace tokenizer to
26
+ tokenize samples.
27
+ max_seq_len (int): The max sequence length of each sample.
28
+ streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
29
+ which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
30
+ ``remote``/``local``. Defaults to ``None``.
31
+ remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
32
+ its data must exist locally. StreamingDataset uses either ``streams`` or
33
+ ``remote``/``local``. Defaults to ``None``.
34
+ local (str, optional): Local working directory to download shards to. This is where shards
35
+ are cached while they are being used. Uses a temp directory if not set.
36
+ StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``.
37
+ split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
38
+ the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
39
+ download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
40
+ download_timeout (float): Number of seconds to wait for a shard to download before raising
41
+ an exception. Defaults to ``60``.
42
+ validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
43
+ shards. Defaults to ``None``.
44
+ keep_zip (bool): Whether to keep or delete the compressed form when decompressing
45
+ downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
46
+ ``False``.
47
+ epoch_size (int, optional): Number of samples to draw per epoch balanced across all
48
+ streams. If ``None``, takes its value from the total number of underlying samples.
49
+ Provide this field if you are weighting streams relatively to target a larger or
50
+ smaller epoch size. Defaults to ``None``.
51
+ predownload (int, optional): Target number of samples ahead to download the shards of while
52
+ iterating. Defaults to ``100_000``.
53
+ cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
54
+ shard cache. Before downloading a shard, the least recently used resident shard(s) may
55
+ be evicted (deleted from the local cache) in order to stay under the limit. Set to None
56
+ to disable shard eviction. Supports integer bytes as well as string human-readable
57
+ bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
58
+ partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
59
+ num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
60
+ resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
61
+ initial run.
62
+ batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
63
+ partitioned over the workers. Defaults to ``None``.
64
+ shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
65
+ ``False``.
66
+ shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``.
67
+ shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
68
+ shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
69
+ sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
70
+ Defaults to ``balanced``.
71
+ sampling_granularity (int): When picking samples for a stream's final partial repeat,
72
+ how many samples to pick from the same shard at a time (``1`` for evenly balanced
73
+ across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
74
+ Defaults to ``1``.
75
+ batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
76
+ ``per_stream``. Defaults to ``random``.
77
+ """
78
+
79
+ def __init__(self,
80
+ tokenizer: PreTrainedTokenizerBase,
81
+ max_seq_len: int,
82
+ streams: Optional[Sequence[Stream]] = None,
83
+ remote: Optional[str] = None,
84
+ local: Optional[str] = None,
85
+ split: Optional[str] = None,
86
+ download_retry: int = 2,
87
+ download_timeout: float = 60,
88
+ validate_hash: Optional[str] = None,
89
+ keep_zip: bool = False,
90
+ epoch_size: Optional[int] = None,
91
+ predownload: int = 100_000,
92
+ cache_limit: Optional[Union[int, str]] = None,
93
+ partition_algo: str = 'orig',
94
+ num_canonical_nodes: Optional[int] = None,
95
+ batch_size: Optional[int] = None,
96
+ shuffle: bool = False,
97
+ shuffle_algo: str = 'py1b',
98
+ shuffle_seed: int = 9176,
99
+ shuffle_block_size: int = 1 << 18,
100
+ sampling_method: str = 'balanced',
101
+ sampling_granularity: int = 1,
102
+ batching_method: str = 'random',
103
+ **kwargs: Any):
104
+
105
+ group_method = kwargs.pop('group_method', None)
106
+ if group_method is not None:
107
+ raise NotImplementedError(
108
+ 'group_method is deprecated and has been removed.\nTo ' +
109
+ 'concatenate, use the --concat_tokens ' +
110
+ 'argument when creating your MDS dataset with concat_c4.py')
111
+
112
+ if len(kwargs) > 0:
113
+ raise ValueError(
114
+ f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}'
115
+ )
116
+
117
+ if local is not None and (remote is None or (local == remote)):
118
+ if os.path.isdir(local):
119
+ contents = set(os.listdir(local))
120
+ if split not in contents:
121
+ raise ValueError(
122
+ f'local directory {local} does not contain split {split}'
123
+ )
124
+
125
+ # TODO: discover where yamls are being converted incorrectly; this is a temporary workaround
126
+ if isinstance(shuffle_block_size, float):
127
+ shuffle_block_size = int(shuffle_block_size)
128
+
129
+ # Build Dataset
130
+ super().__init__(
131
+ streams=streams,
132
+ remote=remote,
133
+ local=local,
134
+ split=split,
135
+ download_retry=download_retry,
136
+ download_timeout=download_timeout,
137
+ validate_hash=validate_hash,
138
+ keep_zip=keep_zip,
139
+ epoch_size=epoch_size,
140
+ predownload=predownload,
141
+ cache_limit=cache_limit,
142
+ partition_algo=partition_algo,
143
+ num_canonical_nodes=num_canonical_nodes,
144
+ batch_size=batch_size,
145
+ shuffle=shuffle,
146
+ shuffle_algo=shuffle_algo,
147
+ shuffle_seed=shuffle_seed,
148
+ shuffle_block_size=shuffle_block_size,
149
+ sampling_method=sampling_method,
150
+ sampling_granularity=sampling_granularity,
151
+ batching_method=batching_method,
152
+ )
153
+ self.tokenizer = tokenizer
154
+ self.max_seq_len = max_seq_len
155
+
156
+ # How to tokenize a text sample to a token sample
157
+ def _tokenize(self, text_sample: Mapping) -> Dict[str, List[int]]:
158
+ if self.tokenizer._pad_token is None:
159
+ # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs
160
+ raise RuntimeError(
161
+ 'If tokenizing on-the-fly, tokenizer must have a pad_token_id')
162
+
163
+ return self.tokenizer(text_sample['text'],
164
+ truncation=True,
165
+ padding='max_length',
166
+ max_length=self.max_seq_len)
167
+
168
+ def _read_binary_tokenized_sample(self, sample: Dict[str,
169
+ Any]) -> torch.Tensor:
170
+ return torch.from_numpy(
171
+ np.frombuffer(sample['tokens'],
172
+ dtype=np.int64)[:self.max_seq_len].copy())
173
+
174
+ # How to process a sample
175
+ def __getitem__(self,
176
+ idx: int) -> Union[Dict[str, List[int]], torch.Tensor]:
177
+ sample = super().__getitem__(idx)
178
+ if 'text' in sample:
179
+ token_sample = self._tokenize(sample)
180
+ elif 'tokens' in sample:
181
+ token_sample = self._read_binary_tokenized_sample(sample)
182
+ else:
183
+ raise RuntimeError(
184
+ 'StreamingTextDataset needs samples to have a `text` or `tokens` column'
185
+ )
186
+ return token_sample
187
+
188
+
189
+ class ConcatenatedSequenceCollatorWrapper:
190
+ """Collator wrapper to add sequence_id to batch."""
191
+
192
+ def __init__(
193
+ self,
194
+ base_collator: Callable,
195
+ eos_token_id: Optional[int] = None,
196
+ bos_token_id: Optional[int] = None,
197
+ ):
198
+ self.base_collator = base_collator
199
+ if (eos_token_id is None) and (bos_token_id is None):
200
+ raise ValueError(
201
+ 'Must supply a value for either eos_token_id or bos_token_id, but got None for both.'
202
+ )
203
+ if (eos_token_id is not None) and (bos_token_id is not None):
204
+ raise ValueError(
205
+ 'Cannot use *both* EOS and BOS tokens for detecting sequence boundaries. ' +\
206
+ 'Please supply `eos_token_id` if sequences end with an EOS token, or use ' +\
207
+ '`bos_token_id` if sequences start with a BOS token.'
208
+ )
209
+
210
+ if eos_token_id is None:
211
+ self.split_token_id = cast(int, bos_token_id)
212
+ self.bos_mode = True
213
+ else:
214
+ self.split_token_id = eos_token_id
215
+ self.bos_mode = False
216
+
217
+ def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
218
+ batch = self.base_collator(examples)
219
+ batch['sequence_id'] = self.get_sequence_id_from_batch(batch)
220
+ return batch
221
+
222
+ def get_sequence_id_from_batch(
223
+ self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
224
+ is_separator = torch.eq(batch['input_ids'], self.split_token_id)
225
+ cumulative_sep = torch.cumsum(is_separator,
226
+ dim=1).to(batch['input_ids'].dtype)
227
+ # If separator token is bos, we're already done
228
+ if self.bos_mode:
229
+ return cumulative_sep
230
+
231
+ # If separator token is eos, right shift 1 space
232
+ left_zeros = cumulative_sep.new_zeros((cumulative_sep.shape[0], 1))
233
+ return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)
234
+
235
+
236
+ def build_text_dataloader(
237
+ cfg: DictConfig,
238
+ tokenizer: PreTrainedTokenizerBase,
239
+ device_batch_size: int,
240
+ ) -> DataLoader:
241
+ assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
242
+ if cfg.dataset.get('group_method', None) is not None:
243
+ raise NotImplementedError(
244
+ 'group_method is deprecated and has been removed.\nTo ' +
245
+ 'concatenate, use the --concat_tokens ' +
246
+ 'argument when creating your MDS dataset with convert_dataset_hf.py'
247
+ )
248
+
249
+ # get kwargs
250
+ streams_dict = cfg.dataset.pop('streams', None)
251
+ mlm_probability = cfg.dataset.pop('mlm_probability', None)
252
+ eos_token_id = cfg.dataset.pop('eos_token_id', None)
253
+ bos_token_id = cfg.dataset.pop('bos_token_id', None)
254
+
255
+ # build streams
256
+ streams = None
257
+ if streams_dict is not None:
258
+ streams = []
259
+ for _, stream in streams_dict.items():
260
+ # stream is the streams kwargs
261
+ # fwd all kwargs with **stream allows streaming to check args
262
+ streams.append(Stream(**stream))
263
+
264
+ # build dataset potentially with streams
265
+ dataset = StreamingTextDataset(
266
+ tokenizer=tokenizer,
267
+ streams=streams,
268
+ batch_size=device_batch_size,
269
+ **cfg.dataset,
270
+ )
271
+
272
+ collate_fn = transformers.DataCollatorForLanguageModeling(
273
+ tokenizer=dataset.tokenizer,
274
+ mlm=mlm_probability is not None,
275
+ mlm_probability=mlm_probability)
276
+
277
+ if (eos_token_id is not None) or (bos_token_id is not None):
278
+ # Note: Will raise an error if both are non-None
279
+ collate_fn = ConcatenatedSequenceCollatorWrapper(
280
+ base_collator=collate_fn,
281
+ eos_token_id=eos_token_id,
282
+ bos_token_id=bos_token_id)
283
+
284
+ return DataLoader(
285
+ dataset,
286
+ collate_fn=collate_fn,
287
+ batch_size=device_batch_size,
288
+ drop_last=cfg.drop_last,
289
+ num_workers=cfg.num_workers,
290
+ pin_memory=cfg.get('pin_memory', True),
291
+ prefetch_factor=cfg.get('prefetch_factor', 2),
292
+ persistent_workers=cfg.get('persistent_workers', True),
293
+ timeout=cfg.get('timeout', 0),
294
+ )
295
+
296
+
297
+ # Helpful to test if your dataloader is working locally
298
+ # Run `python text_data.py --local_path [local] [--remote_path remote, optional]` and verify that batches are printed out
299
+ if __name__ == '__main__':
300
+ import argparse
301
+
302
+ from llmfoundry.utils.builders import build_tokenizer
303
+
304
+ parser = argparse.ArgumentParser()
305
+ parser.add_argument('--tokenizer',
306
+ type=str,
307
+ default='EleutherAI/gpt-neox-20b',
308
+ help='the name of the tokenizer to use')
309
+ parser.add_argument('--local_path',
310
+ type=str,
311
+ required=True,
312
+ help='the path to the local copy of the dataset')
313
+ parser.add_argument(
314
+ '--remote_path',
315
+ type=str,
316
+ default=None,
317
+ help='the path to the remote copy to stream from (optional)')
318
+ parser.add_argument('--split',
319
+ type=str,
320
+ default='val',
321
+ help='which split of the dataset to use')
322
+ parser.add_argument('--max_seq_len',
323
+ type=int,
324
+ default=32,
325
+ help='max sequence length to test')
326
+
327
+ args = parser.parse_args()
328
+
329
+ if args.remote_path is not None:
330
+ print(
331
+ f'Reading {args.split} split from {args.local_path} <- streamed from <- {args.remote_path}'
332
+ )
333
+ else:
334
+ print(f'Reading {args.split} split from {args.local_path}')
335
+
336
+ cfg = {
337
+ 'name': 'text',
338
+ 'dataset': {
339
+ 'local': args.local_path,
340
+ 'remote': args.remote_path,
341
+ 'split': args.split,
342
+ 'shuffle': False,
343
+ 'max_seq_len': args.max_seq_len,
344
+ 'keep_zip': True, # in case we need compressed files after testing
345
+ },
346
+ 'drop_last': False,
347
+ 'num_workers': 4,
348
+ }
349
+ cfg = om.create(cfg)
350
+ device_batch_size = 2
351
+
352
+ tokenizer_name = args.tokenizer
353
+ tokenizer_kwargs = {'model_max_length': args.max_seq_len}
354
+ tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
355
+
356
+ loader = build_text_dataloader(cfg, tokenizer, device_batch_size)
357
+ assert isinstance(loader.dataset, StreamingTextDataset)
358
+ tokenizer = loader.dataset.tokenizer
359
+
360
+ for batch_ix, batch in enumerate(islice(loader, 5)):
361
+ print('\n')
362
+ print('#' * 20, f'Batch {batch_ix}', '#' * 20)
363
+ for k, v in batch.items():
364
+ print(k, v.shape, v.dtype)
365
+ for sample_ix, token_sample in enumerate(batch['input_ids']):
366
+ print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
367
+ print(tokenizer.decode(token_sample))
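For readers tracing `ConcatenatedSequenceCollatorWrapper.get_sequence_id_from_batch` above, here is a minimal, self-contained sketch of its cumulative-sum / right-shift logic on a made-up batch; the token ids and the EOS id of 2 are illustrative assumptions only.

```python
# Minimal sketch of deriving sequence_id from EOS positions, mirroring
# get_sequence_id_from_batch above. Token ids and eos_token_id=2 are made up.
import torch

input_ids = torch.tensor([[5, 6, 2, 7, 8, 2, 9]])  # EOS tokens at positions 2 and 5
is_separator = torch.eq(input_ids, 2)
cumulative_sep = torch.cumsum(is_separator, dim=1).to(input_ids.dtype)
# EOS terminates a sequence, so shift right by one position so that the EOS
# token keeps the id of the sequence it ends.
left_zeros = cumulative_sep.new_zeros((cumulative_sep.shape[0], 1))
sequence_id = torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)
print(sequence_id)  # tensor([[0, 0, 0, 1, 1, 1, 2]])
```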
Perceptrix/finetune/build/lib/llmfoundry/models/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from llmfoundry.models.hf import (ComposerHFCausalLM, ComposerHFPrefixLM,
5
+ ComposerHFT5)
6
+ from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
7
+ MPTForCausalLM, MPTModel, MPTPreTrainedModel)
8
+
9
+ __all__ = [
10
+ 'ComposerHFCausalLM',
11
+ 'ComposerHFPrefixLM',
12
+ 'ComposerHFT5',
13
+ 'MPTConfig',
14
+ 'MPTPreTrainedModel',
15
+ 'MPTModel',
16
+ 'MPTForCausalLM',
17
+ 'ComposerMPTCausalLM',
18
+ ]
Perceptrix/finetune/build/lib/llmfoundry/models/hf/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
5
+ from llmfoundry.models.hf.hf_fsdp import (prepare_hf_causal_lm_model_for_fsdp,
6
+ prepare_hf_enc_dec_model_for_fsdp,
7
+ prepare_hf_model_for_fsdp)
8
+ from llmfoundry.models.hf.hf_prefix_lm import ComposerHFPrefixLM
9
+ from llmfoundry.models.hf.hf_t5 import ComposerHFT5
10
+
11
+ __all__ = [
12
+ 'ComposerHFCausalLM',
13
+ 'ComposerHFPrefixLM',
14
+ 'ComposerHFT5',
15
+ 'prepare_hf_causal_lm_model_for_fsdp',
16
+ 'prepare_hf_enc_dec_model_for_fsdp',
17
+ 'prepare_hf_model_for_fsdp',
18
+ ]
Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_causal_lm.py ADDED
@@ -0,0 +1,227 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""
5
+
6
+ import logging
7
+ import os
8
+ from typing import Mapping, Union
9
+
10
+ # required for loading a python model into composer
11
+ import transformers
12
+ from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
13
+ InContextLearningLMAccuracy,
14
+ InContextLearningLMExpectedCalibrationError,
15
+ InContextLearningMCExpectedCalibrationError,
16
+ InContextLearningMultipleChoiceAccuracy,
17
+ InContextLearningQAAccuracy,
18
+ LanguageCrossEntropy, LanguagePerplexity)
19
+ from composer.utils import dist
20
+ from omegaconf import DictConfig
21
+ from torch import nn
22
+ from transformers import (AutoConfig, AutoModelForCausalLM,
23
+ PreTrainedTokenizerBase)
24
+
25
+ from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
26
+ from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
27
+ from llmfoundry.models.layers.llama_attention_monkeypatch import \
28
+ get_llama_attention_patch_fn
29
+ from llmfoundry.models.utils import init_empty_weights
30
+
31
+ try:
32
+ from peft.peft_model import PeftModel
33
+ model_types = PeftModel, transformers.PreTrainedModel
34
+
35
+ except ImportError:
36
+ model_types = transformers.PreTrainedModel
37
+
38
+ __all__ = ['ComposerHFCausalLM']
39
+
40
+ log = logging.getLogger(__name__)
41
+
42
+
43
+ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
44
+ """Configures a :class:`.HuggingFaceModel` around a Causal LM.
45
+
46
+ Args:
47
+ om_model_config (DictConfig | PeftModel | transformers.PreTrainedModel): either an omegaconf dictionary used to configure the model, or an instantiated model object from the peft or transformers library.
48
+ if DictConfig, the following keys are required:
49
+ cfg.pretrained_model_name_or_path (str): The name of or local path to
50
+ the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel).
51
+ cfg.config_overrides (dict, optional): An optional dictionary of keyword
52
+ arguments that override the default configuration associated with
53
+ cfg.pretrained_model_name_or_path.
54
+ cfg.pretrained (bool): Whether to instantiate the model with pre-trained
55
+ weights coming from cfg.pretrained_model_name_or_path. If ``True``,
56
+ cfg.config_overrides must be compatible with the pre-trained weights.
57
+ cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
58
+ initialize the model on. Currently, `meta` is only supported when
59
+ cfg.pretrained is ``False``. Default: ``'cpu'``.
60
+ tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
61
+ """
62
+
63
+ def __init__(self, om_model_config: Union[DictConfig,
64
+ transformers.PreTrainedModel,
65
+ nn.Module],
66
+ tokenizer: PreTrainedTokenizerBase):
67
+ # set up training and eval metrics
68
+ train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
69
+ eval_metrics = [
70
+ LanguageCrossEntropy(),
71
+ LanguagePerplexity(),
72
+ InContextLearningLMAccuracy(),
73
+ InContextLearningMultipleChoiceAccuracy(),
74
+ InContextLearningQAAccuracy(),
75
+ InContextLearningCodeEvalAccuracy(),
76
+ InContextLearningLMExpectedCalibrationError(),
77
+ InContextLearningMCExpectedCalibrationError()
78
+ ]
79
+
80
+ # if we are passed a DictConfig, we need to instantiate the model
81
+ if isinstance(om_model_config, DictConfig):
82
+ if not om_model_config.get('trust_remote_code',
83
+ True) and om_model_config.get(
84
+ 'pretrained_model_name_or_path',
85
+ None).startswith('mosaicml/mpt'):
86
+ raise ValueError(
87
+ 'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, '
88
+ +
89
+ 'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
90
+ )
91
+
92
+ if not om_model_config.get('use_train_metrics', True):
93
+ train_metrics = []
94
+
95
+ # load the model config
96
+ trust_remote_code = om_model_config.get('trust_remote_code', True)
97
+ use_auth_token = om_model_config.get('use_auth_token', False)
98
+ config = AutoConfig.from_pretrained(
99
+ om_model_config.pretrained_model_name_or_path,
100
+ trust_remote_code=trust_remote_code,
101
+ use_auth_token=use_auth_token,
102
+ )
103
+
104
+ # set config overrides
105
+ for k, v in om_model_config.get('config_overrides', {}).items():
106
+ if not hasattr(config, k):
107
+ raise ValueError(
108
+ f'config does not have attribute "{k}" to override ({k}: {v}).'
109
+ )
110
+
111
+ attr = getattr(config, k)
112
+ # attempt to disallow typos in nested configs
113
+ if isinstance(attr, Mapping):
114
+ extra_keys = [
115
+ _k for _k in v.keys() if _k not in attr.keys()
116
+ ]
117
+ if extra_keys:
118
+ raise ValueError(
119
+ f'Config dict override got unknown keys. ' +
120
+ f'Extra keys: {extra_keys}. ' +
121
+ f'Expected (a subset of) keys: {list(attr.keys())}.'
122
+ )
123
+ getattr(config, k).update(v)
124
+ # necessary case to allow for rope_scaling to be overridden in llama config
125
+ elif attr is None and isinstance(v, Mapping):
126
+ setattr(config, k, {})
127
+ getattr(config, k).update(v)
128
+ else:
129
+ setattr(config, k, v)
130
+
131
+ load_in_8bit = om_model_config.get('load_in_8bit', False)
132
+
133
+ # below we set up the device to initialize the model on
134
+ init_device = om_model_config.get('init_device', 'cpu')
135
+
136
+ # Get the device we want to initialize, and use the
137
+ # resolved version to initialize the HF model
138
+ resolved_init_device = hf_get_init_device(init_device)
139
+
140
+ # We need to have all non-zero local ranks be not-pretrained
141
+ # Rank 0 will still be pretrained, and distribute the weights appropriately
142
+ if dist.get_local_rank() != 0 and init_device == 'mixed':
143
+ om_model_config.pretrained = False
144
+
145
+ # initialize the model on the correct device
146
+ if resolved_init_device == 'cpu':
147
+ if om_model_config.pretrained:
148
+ model = AutoModelForCausalLM.from_pretrained(
149
+ om_model_config.pretrained_model_name_or_path,
150
+ trust_remote_code=trust_remote_code,
151
+ use_auth_token=use_auth_token,
152
+ load_in_8bit=load_in_8bit,
153
+ config=config)
154
+ else:
155
+ model = AutoModelForCausalLM.from_config(
156
+ config,
157
+ trust_remote_code=trust_remote_code,
158
+ )
159
+ elif resolved_init_device == 'meta':
160
+ if om_model_config.pretrained:
161
+ raise ValueError(
162
+ 'Setting cfg.pretrained=True is not supported when init_device="meta".'
163
+ )
164
+ with init_empty_weights(include_buffers=False):
165
+ model = AutoModelForCausalLM.from_config(
166
+ config,
167
+ trust_remote_code=trust_remote_code,
168
+ )
169
+ else:
170
+ raise ValueError(
171
+ f'init_device="{init_device}" must be either "cpu" or "meta".'
172
+ )
173
+
174
+ signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
175
+ if dist.get_local_rank() == 0:
176
+ with open(signal_file_path, 'wb') as f:
177
+ f.write(b'local_rank0_completed_download')
178
+
179
+ # Avoid the collective call until the local rank zero has finished trying to download the checkpoint
180
+ # so that we don't timeout for large downloads. This syncs all processes on the node
181
+ with dist.local_rank_zero_download_and_wait(signal_file_path):
182
+ # Then, wait to ensure every node has finished downloading the checkpoint
183
+ dist.barrier()
184
+
185
+ if dist.get_local_rank() == 0:
186
+ os.remove(signal_file_path)
187
+
188
+ z_loss = om_model_config.get('z_loss', 0.0)
189
+
190
+ attention_patch_type = om_model_config.get('attention_patch_type',
191
+ None)
192
+ if attention_patch_type is not None:
193
+ if model.config.model_type != 'llama':
194
+ raise ValueError(
195
+ f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
196
+ )
197
+
198
+ log.debug(
199
+ f'Patching llama attention with {attention_patch_type} attention'
200
+ )
201
+ from transformers.models.llama.modeling_llama import \
202
+ LlamaAttention
203
+ LlamaAttention.forward = get_llama_attention_patch_fn(
204
+ attention_patch_type)
205
+ model.config.use_cache = False
206
+
207
+ # elif the model is either a PeftModel or a PreTrainedModel
208
+ elif isinstance(om_model_config, model_types):
209
+ model = om_model_config
210
+ init_device = 'cpu'
211
+ z_loss = 0.0
212
+
213
+ # else, unsupported type
214
+ else:
215
+ raise ValueError(
216
+ f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
217
+ )
218
+
219
+ composer_model = super().__init__(model=model,
220
+ shift_labels=True,
221
+ tokenizer=tokenizer,
222
+ metrics=train_metrics,
223
+ eval_metrics=eval_metrics,
224
+ z_loss=z_loss,
225
+ init_device=init_device)
226
+
227
+ return composer_model
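Below is a minimal sketch of constructing `ComposerHFCausalLM` from an omegaconf `DictConfig`, based only on the keys documented in the class docstring above. The `gpt2` model name and the tokenizer choice are illustrative assumptions, not requirements of the repo, and running it end-to-end assumes composer's single-process defaults.

```python
# Minimal usage sketch (assumptions: gpt2 as the base model, single process).
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.models.hf import ComposerHFCausalLM

model_cfg = OmegaConf.create({
    'pretrained_model_name_or_path': 'gpt2',  # hypothetical choice for the example
    'pretrained': True,
    'init_device': 'cpu',
    'config_overrides': {},
})
tokenizer = AutoTokenizer.from_pretrained('gpt2')
composer_model = ComposerHFCausalLM(model_cfg, tokenizer)
print(type(composer_model.model))  # the wrapped HF causal LM
```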
Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_fsdp.py ADDED
@@ -0,0 +1,257 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # helper functions from https://github.com/CarperAI/trlx/blob/main/trlx/utils/modeling.py
5
+ # which is MIT licensed
6
+
7
+ import functools
8
+ from typing import Any, Iterable, List, Optional
9
+
10
+ import torch
11
+ from transformers import PreTrainedModel
12
+ from transformers.models.opt.modeling_opt import OPTDecoder
13
+
14
+
15
+ # helper functions
16
+ def rhasattr(obj: Any, attr: str) -> bool:
17
+ """A chain-able attribute version of hasattr.
18
+
19
+ For example, to check if
20
+ `obj` has the attribute `foo.bar.baz`, you can use:
21
+ `rhasattr(obj, "foo.bar.baz")`
22
+ Reference: https://stackoverflow.com/a/67303315
23
+ """
24
+ _nested_attrs = attr.split('.')
25
+ _curr_obj = obj
26
+ for _a in _nested_attrs[:-1]:
27
+ if hasattr(_curr_obj, _a):
28
+ _curr_obj = getattr(_curr_obj, _a)
29
+ else:
30
+ return False
31
+ return hasattr(_curr_obj, _nested_attrs[-1])
32
+
33
+
34
+ def rgetattr(obj: Any, attr: str, *args: List[Any]) -> Any:
35
+ """A chain-able attribute version of getattr.
36
+
37
+ For example, to get the attribute `foo.bar.baz` from `obj`, you can use:
38
+ `rgetattr(obj, "foo.bar.baz")`
39
+ Reference: https://stackoverflow.com/a/31174427
40
+ """
41
+
42
+ def _getattr(obj: Any, attr: str):
43
+ return getattr(obj, attr, *args)
44
+
45
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
46
+
47
+
48
+ def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]:
49
+ for attr in attrs:
50
+ if rhasattr(obj, attr):
51
+ return rgetattr(obj, attr)
52
+ return None
53
+
54
+
55
+ def hf_get_causal_base_model(model: PreTrainedModel) -> Any:
56
+ """Returns the causal decoder backbone of the specified HuggingFace model.
57
+
58
+ Newer HF models have a `self.get_decoder()` method. Older models do not.
59
+
60
+ NOTE: Different model configurations have different causal decoder attribute
61
+ names.
62
+ - transformer: (GPT2LMHeadModel, GPTJConfig)
63
+ - model.decoder: (OPTConfig, BloomConfig)
64
+ - gpt_neox: (GPTNeoXConfig)
65
+ """
66
+ if hasattr(model, 'get_decoder'):
67
+ return model.get_decoder()
68
+
69
+ decoder_attrs = ('transformer', 'model.decoder', 'gpt_neox')
70
+ causal_base_model = findattr(model, decoder_attrs)
71
+ if causal_base_model is None:
72
+ raise ValueError(
73
+ f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.'
74
+ )
75
+ return causal_base_model
76
+
77
+
78
+ def hf_get_hidden_layers(model: PreTrainedModel) -> Any:
79
+ """Returns the hidden layers of the specified model.
80
+
81
+ NOTE: Different model configurations have different hidden layer attribute names.
82
+ - transformer.h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)
83
+ - model.decoder.layers: (OPTForCausalLM)
84
+ - gpt_neox.layers: (GPTNeoXForCausalLM)
85
+ - model.layers: (LlamaForCausalLM)
86
+ - transformer.blocks: (MPTForCausalLM)
87
+ """
88
+ hidden_layers_attrs = (
89
+ 'transformer.h', # BLOOM, GPT2, GPTJ
90
+ 'model.decoder.layers', # OPT
91
+ 'gpt_neox.layers', # GPTNeoX
92
+ 'block', # T5, BART, Pegasus (from encoder)
93
+ 'layers', # ProphetNet, Marian (from encoder)
94
+ 'model.layers', # LLaMa
95
+ 'transformer.blocks', # MPT
96
+ )
97
+ layers = findattr(model, hidden_layers_attrs)
98
+ if layers is None:
99
+ raise ValueError(
100
+ f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}'
101
+ )
102
+ return layers
103
+
104
+
105
+ def hf_get_init_device(init_device: Optional[str]) -> Optional[str]:
106
+ """Returns the appropriate device to initialize models."""
107
+ from composer.utils import dist
108
+ if init_device == 'mixed':
109
+ if dist.get_local_rank() == 0:
110
+ return 'cpu'
111
+ return 'meta'
112
+ return init_device
113
+
114
+
115
+ # /end helper functions
116
+
117
+
118
+ def prepare_hf_model_for_fsdp(model: PreTrainedModel,
119
+ init_device: Optional[str]) -> None:
120
+ """FSDP wrap a HuggingFace model.
121
+
122
+ Dispatches to the wrapping function appropriate for the model architecture (encoder/decoder vs. decoder-only).
123
+ """
124
+ if model.config.is_encoder_decoder:
125
+ prepare_hf_enc_dec_model_for_fsdp(model, init_device)
126
+ else:
127
+ # many common decoder-only models do not set the flag
128
+ # model.config.is_decoder, so we can't trust it
129
+ prepare_hf_causal_lm_model_for_fsdp(model, init_device)
130
+
131
+
132
+ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
133
+ init_device: Optional[str]) -> None:
134
+ """FSDP wrap a HuggingFace decoder.
135
+
136
+ Wrap any model for FSDP which follows one of the 3 existing conventions from
137
+ HuggingFace for decoder-only LLMs.
138
+ """
139
+ causal_base_model = hf_get_causal_base_model(model)
140
+
141
+ # OPT has an extra layer of wrapping, so special case here
142
+ if isinstance(causal_base_model, OPTDecoder):
143
+ model.model._fsdp_wrap = False
144
+ model_block = hf_get_hidden_layers(model)
145
+ lm_head = model.get_output_embeddings()
146
+ # some models (OPT) implement .get_input_embeddings for the causal subclass
147
+ # but all of them implement it for the base model
148
+ tied_embeddings = causal_base_model.get_input_embeddings()
149
+ modules = {
150
+ 'base_model': causal_base_model,
151
+ 'model_block': model_block,
152
+ 'lm_head': lm_head,
153
+ 'tied_embeddings': tied_embeddings
154
+ }
155
+
156
+ for mod_name, module in modules.items():
157
+ if module is None:
158
+ raise ValueError(
159
+ f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
160
+ 'follow common layer/weight naming conventions.')
161
+ block_type = type(model_block[0])
162
+ if init_device == 'mixed':
163
+ # For FSDP with the `mixed` device initialization scheme, which
164
+ # initializes the model on rank 0 on `cpu` and on all other ranks on `meta`,
165
+ # we need to tag all child modules that are torch.nn.Modules with `_fsdp_wrap`.
166
+ for child in model.children():
167
+ if isinstance(child, type(causal_base_model)):
168
+ continue
169
+ if isinstance(child, torch.nn.Module):
170
+ child._fsdp_wrap = True
171
+
172
+ for child in causal_base_model.children():
173
+ if isinstance(child, torch.nn.ModuleList):
174
+ continue
175
+ if isinstance(child, torch.nn.Module):
176
+ child._fsdp_wrap = True
177
+
178
+ if model.config.tie_word_embeddings and not model.config.model_type == 'mpt':
179
+ raise ValueError(
180
+ 'The passed in HuggingFaceModel has tied word embeddings ' +
181
+ 'and the passed in initialization device is `mixed`. ' +
182
+ 'In order to support this initialization scheme, we would need to break '
183
+ +
184
+ 'the weight tying. As a result, either use a different initialization scheme '
185
+ + 'or in the model config set `tie_word_embeddings=False`.')
186
+ else:
187
+ # When using the HF LM models,
188
+ # the weights of the self.lm_head and self.transformer.wte are tied.
189
+ # This tying occurs inside the `self.post_init()` function.
190
+ # This is a hurdle for FSDP because they need to be in the same FSDP block
191
+ # These lines ensures that both modules stay together in the top-most block when
192
+ # the model has this tying enabled (almost all do; this property defaults to True)
193
+ if model.config.tie_word_embeddings:
194
+ causal_base_model._fsdp_wrap = False
195
+ tied_embeddings._fsdp_wrap = False
196
+ lm_head._fsdp_wrap = False
197
+
198
+ # FSDP Wrap and Activation Checkpoint every model block
199
+ model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
200
+ model.activation_checkpointing_fn = lambda module: isinstance(
201
+ module, block_type)
202
+
203
+
204
+ def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel,
205
+ init_device: Optional[str]) -> None:
206
+ """Wrap an encoder/decoder HF model.
207
+
208
+ This works for T5, BART, Pegasus, PegasusX, but not all enc/dec (ProphetNet)
209
+ You have model.shared, model.encoder, model.decoder and model.lm_head, where
210
+ model.shared are the embeddings which are tied to model.lm_head, and
211
+ model.shared == model.encoder.embed_tokens and model.shared ==
212
+ model.decoder.embed_tokens
213
+ """
214
+ tied_embeddings = model.get_input_embeddings()
215
+ encoder = model.get_encoder()
216
+ decoder = model.get_decoder()
217
+ lm_head = model.get_output_embeddings()
218
+ # some encoder/decoders have different layers for encoder vs decoder
219
+ encoder_block = hf_get_hidden_layers(encoder)
220
+ decoder_block = hf_get_hidden_layers(decoder)
221
+
222
+ modules = {
223
+ 'encoder': encoder,
224
+ 'decoder': decoder,
225
+ 'encoder_block': encoder_block,
226
+ 'decoder_block': decoder_block,
227
+ 'lm_head': lm_head,
228
+ 'tied_embeddings': tied_embeddings
229
+ }
230
+
231
+ for mod_name, module in modules.items():
232
+ if module is None:
233
+ raise ValueError(
234
+ f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
235
+ 'follow common layer/weight naming conventions.')
236
+ decoder_block_type = type(decoder_block[0])
237
+ encoder_block_type = type(encoder_block[0])
238
+
239
+ if model.config.tie_word_embeddings:
240
+ # it is possible to train an enc/dec without tied embeddings, hence the check
241
+ tied_embeddings._fsdp_wrap = False
242
+ encoder._fsdp_wrap = False
243
+ decoder._fsdp_wrap = False
244
+ lm_head._fsdp_wrap = False
245
+
246
+ # FSDP Wrap and Activation Checkpoint every decoder block
247
+ model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type)
248
+ model.activation_checkpointing_fn = lambda module: isinstance(
249
+ module, decoder_block_type)
250
+
251
+ if encoder_block_type == decoder_block_type:
252
+ return
253
+
254
+ # need to wrap encoder blocks separately for ProphetNet and Marian
255
+ model.fsdp_wrap_fn = lambda module: isinstance(module, encoder_block_type)
256
+ model.activation_checkpointing_fn = lambda module: isinstance(
257
+ module, encoder_block_type)
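A small illustration of the `rhasattr` / `rgetattr` / `findattr` helpers defined above, using a throwaway nested object instead of a real HF model; the attribute names are invented for the example.

```python
# Demonstrates the dotted-attribute helpers on a fake nested object.
from types import SimpleNamespace

from llmfoundry.models.hf.hf_fsdp import findattr, rgetattr, rhasattr

fake_model = SimpleNamespace(transformer=SimpleNamespace(h=['block0', 'block1']))

print(rhasattr(fake_model, 'transformer.h'))     # True
print(rgetattr(fake_model, 'transformer.h')[0])  # 'block0'
# findattr returns the first attribute path that exists among the candidates.
print(findattr(fake_model, ('model.decoder.layers', 'transformer.h')))  # ['block0', 'block1']
```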
Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_prefix_lm.py ADDED
@@ -0,0 +1,150 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Implements a Hugging Prefix LM wrapped inside a :class:`.ComposerModel`."""
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Mapping, MutableMapping
9
+
10
+ from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
11
+ from composer.utils import dist
12
+ from omegaconf import DictConfig
13
+ from transformers import (AutoConfig, AutoModelForCausalLM,
14
+ PreTrainedTokenizerBase)
15
+
16
+ from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
17
+ from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
18
+ from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
19
+ add_bidirectional_mask_if_missing,
20
+ convert_hf_causal_lm_to_prefix_lm,
21
+ init_empty_weights)
22
+
23
+ __all__ = ['ComposerHFPrefixLM']
24
+
25
+ # HuggingFace hardcodes the ignore index to -100
26
+ _HF_IGNORE_INDEX = -100
27
+
28
+
29
+ class ComposerHFPrefixLM(HuggingFaceModelWithZLoss):
30
+ """Configures a :class:`.HuggingFaceModel` around a Prefix LM.
31
+
32
+ Note: HuggingFace does not natively support Prefix LM-style models. This function uses
33
+ `transformers.AutoModelForCausalLM` to instantiate a Causal LM, then uses a conversion utility
34
+ to turn the model into a Prefix LM. Currently, that conversion utility only supports the
35
+ following HuggingFace Causal LM types:
36
+ - `GPT2LMHeadModel`
37
+ - `GPTNeoForCausalLM`
38
+ - `GPTNeoXForCausalLM`
39
+ - `GPTJForCausalLM`
40
+ - `BloomForCausalLM`
41
+ - `OPTForCausalLM`
42
+
43
+ Args:
44
+ cfg (DictConfig): An omegaconf dictionary used to configure the model:
45
+ cfg.pretrained_model_name_or_path (str): The name of or local path to
46
+ the HF model (e.g., `gpt2` to instantiate a GPT2LMHeadModel). The model
47
+ will be converted to a Prefix LM during initialization.
48
+ cfg.config_overrides (dict, optional): An optional dictionary of keyword
49
+ arguments that override the default configuration associated with
50
+ cfg.pretrained_model_name_or_path. Default: ``{}``.
51
+ cfg.pretrained (bool): Whether to instantiate the model with pre-trained
52
+ weights coming from cfg.pretrained_model_name_or_path. If ``True``,
53
+ cfg.config_overrides must be compatible with the pre-trained weights.
54
+ cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
55
+ initialize the model on. Currently, `meta` is only supported when
56
+ cfg.pretrained is ``False``. Default: ``'cpu'``.
57
+ cfg.z_loss (float, optional): The coefficient of the z-loss. If >0.0, this
58
+ the z-loss will be multiplied by this value before being added to the
59
+ standard loss term. Default: ``0.0``.
60
+ cfg.adapt_vocab_for_denoising (bool, optional): Whether to adapt the vocab
61
+ of the model/tokenizer to include sentinel tokens that are used in denoising
62
+ tasks like Span Corruption. If you intend to load from an existing Composer
63
+ checkpoint that was trained on such a task, set this to ``True`` to ensure
64
+ that the model vocab size matches your checkpoint's vocab size when loading
65
+ the weights. Default: ``False``.
66
+ tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
67
+ """
68
+
69
+ def __init__(self, om_model_config: DictConfig,
70
+ tokenizer: PreTrainedTokenizerBase):
71
+ config = AutoConfig.from_pretrained(
72
+ om_model_config.pretrained_model_name_or_path,
73
+ trust_remote_code=om_model_config.get('trust_remote_code', True),
74
+ use_auth_token=om_model_config.get('use_auth_token', False),
75
+ )
76
+
77
+ # set config overrides
78
+ for k, v in om_model_config.get('config_overrides', {}).items():
79
+ if not hasattr(config, k):
80
+ raise ValueError(
81
+ f'config does not have attribute "{k}" to override ({k}: {v}).'
82
+ )
83
+
84
+ attr = getattr(config, k)
85
+ if isinstance(attr, Mapping):
86
+ extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
87
+ if extra_keys:
88
+ raise ValueError(
89
+ f'Config dict override got unknown keys. ' +
90
+ f'Extra keys: {extra_keys}. ' +
91
+ f'Expected (a subset of) keys: {list(attr.keys())}.')
92
+ getattr(config, k).update(v)
93
+ else:
94
+ setattr(config, k, v)
95
+
96
+ # Set up the tokenizer (add tokens for denoising sentinels if needed)
97
+ if om_model_config.get('adapt_vocab_for_denoising', False):
98
+ adapt_tokenizer_for_denoising(tokenizer)
99
+
100
+ init_device = om_model_config.get('init_device', 'cpu')
101
+
102
+ # Get the device we want to initialize, and use the
103
+ # resolved version to initialize the HF model
104
+ resolved_init_device = hf_get_init_device(init_device)
105
+
106
+ # We need to have all non-zero local ranks be not-pretrained
107
+ # Rank 0 will still be pretrained, and distribute the weights appropriately
108
+ if dist.get_local_rank() != 0 and init_device == 'mixed':
109
+ om_model_config.pretrained = False
110
+
111
+ if resolved_init_device == 'cpu':
112
+ if om_model_config.pretrained:
113
+ model = AutoModelForCausalLM.from_pretrained(
114
+ om_model_config.pretrained_model_name_or_path,
115
+ config=config)
116
+ else:
117
+ model = AutoModelForCausalLM.from_config(config)
118
+ elif resolved_init_device == 'meta':
119
+ if om_model_config.pretrained:
120
+ raise ValueError(
121
+ 'Setting cfg.pretrained=True is not supported when init_device="meta".'
122
+ )
123
+ with init_empty_weights(include_buffers=False):
124
+ model = AutoModelForCausalLM.from_config(config)
125
+ else:
126
+ raise ValueError(
127
+ f'init_device="{init_device}" must be either "cpu" or "meta".')
128
+
129
+ # Convert the Causal LM into a Prefix LM via our custom wrapper
130
+ model = convert_hf_causal_lm_to_prefix_lm(model)
131
+
132
+ metrics = [
133
+ LanguageCrossEntropy(ignore_index=_HF_IGNORE_INDEX),
134
+ MaskedAccuracy(ignore_index=_HF_IGNORE_INDEX)
135
+ ]
136
+
137
+ composer_model = super().__init__(model=model,
138
+ shift_labels=True,
139
+ tokenizer=tokenizer,
140
+ metrics=metrics,
141
+ z_loss=om_model_config.get(
142
+ 'z_loss', 0.0),
143
+ init_device=init_device)
144
+
145
+ return composer_model
146
+
147
+ def forward(self, batch: MutableMapping):
148
+ # Add bidirectional_mask if it is missing and can be constructed
149
+ add_bidirectional_mask_if_missing(batch)
150
+ return super().forward(batch)
Perceptrix/finetune/build/lib/llmfoundry/models/hf/hf_t5.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Implements a Hugging Face T5 wrapped inside a :class:`.ComposerModel`."""
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Mapping
9
+
10
+ from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
11
+ from composer.utils import dist
12
+ from omegaconf import DictConfig
13
+ from transformers import (AutoConfig, PreTrainedTokenizerBase,
14
+ T5ForConditionalGeneration)
15
+
16
+ from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
17
+ from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
18
+ from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
19
+ init_empty_weights)
20
+
21
+ __all__ = ['ComposerHFT5']
22
+
23
+ # HuggingFace hardcodes the ignore index to -100
24
+ _HF_IGNORE_INDEX = -100
25
+
26
+
27
+ class ComposerHFT5(HuggingFaceModelWithZLoss):
28
+ """Configures a :class:`.HuggingFaceModel` around a T5.
29
+
30
+ Note: This function uses `transformers.T5ForConditionalGeneration`. Future releases
31
+ will expand support to more general classes of HF Encoder-Decoder models.
32
+
33
+ Args:
34
+ cfg (DictConfig): An omegaconf dictionary used to configure the model:
35
+ cfg.pretrained_model_name_or_path (str): The name of or local path to
36
+ the HF model (e.g., `t5-base` to instantiate a T5 using the base config).
37
+ cfg.config_overrides (dict, optional): An optional dictionary of keyword
38
+ arguments that override the default configuration associated with
39
+ cfg.pretrained_model_name_or_path. Default: ``{}``.
40
+ cfg.pretrained (bool): Whether to instantiate the model with pre-trained
41
+ weights coming from cfg.pretrained_model_name_or_path. If ``True``,
42
+ cfg.config_overrides must be compatible with the pre-trained weights.
43
+ cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
44
+ initialize the model on. Currently, `meta` is only supported when
45
+ cfg.pretrained is ``False``. Default: ``'cpu'``.
46
+ cfg.z_loss (float, optional): The coefficient of the z-loss. If >0.0,
47
+ the z-loss will be multiplied by this value before being added to the
48
+ standard loss term. Default: ``0.0``.
49
+ cfg.adapt_vocab_for_denoising (bool, optional): Whether to adapt the vocab
50
+ of the model/tokenizer to include sentinel tokens that are used in denoising
51
+ tasks like Span Corruption. If you intend to load from an existing Composer
52
+ checkpoint that was trained on such a task, set this to ``True`` to ensure
53
+ that the model vocab size matches your checkpoint's vocab size when loading
54
+ the weights. Default: ``False``.
55
+ tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
56
+ """
57
+
58
+ def __init__(self, om_model_config: DictConfig,
59
+ tokenizer: PreTrainedTokenizerBase):
60
+ config = AutoConfig.from_pretrained(
61
+ om_model_config.pretrained_model_name_or_path,
62
+ trust_remote_code=om_model_config.get('trust_remote_code', True),
63
+ use_auth_token=om_model_config.get('use_auth_token', False),
64
+ )
65
+
66
+ # set config overrides
67
+ for k, v in om_model_config.get('config_overrides', {}).items():
68
+ if not hasattr(config, k):
69
+ raise ValueError(
70
+ f'config does not have attribute "{k}" to override ({k}: {v}).'
71
+ )
72
+
73
+ attr = getattr(config, k)
74
+ if isinstance(attr, Mapping):
75
+ extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
76
+ if extra_keys:
77
+ raise ValueError(
78
+ f'Config dict override got unknown keys. ' +
79
+ f'Extra keys: {extra_keys}. ' +
80
+ f'Expected (a subset of) keys: {list(attr.keys())}.')
81
+ getattr(config, k).update(v)
82
+ else:
83
+ setattr(config, k, v)
84
+
85
+ if not config.is_encoder_decoder:
86
+ raise ValueError(f'Model type "hf_t5" currently only supports T5 models ' +\
87
+ f'using configs where `is_encoder_decoder` is ``True``.')
88
+
89
+ # Set up the tokenizer (add tokens for denoising sentinels if needed)
90
+ if om_model_config.get('adapt_vocab_for_denoising', False):
91
+ adapt_tokenizer_for_denoising(tokenizer)
92
+
93
+ init_device = om_model_config.get('init_device', 'cpu')
94
+
95
+ # Get the device we want to initialize, and use the
96
+ # resolved version to initialize the HF model
97
+ resolved_init_device = hf_get_init_device(init_device)
98
+
99
+ # We need to have all non-zero local ranks be not-pretrained
100
+ # Rank 0 will still be pretrained, and distribute the weights appropriately
101
+ if dist.get_local_rank() != 0 and init_device == 'mixed':
102
+ om_model_config.pretrained = False
103
+
104
+ if resolved_init_device == 'cpu':
105
+ if om_model_config.pretrained:
106
+ model = T5ForConditionalGeneration.from_pretrained(
107
+ om_model_config.pretrained_model_name_or_path,
108
+ config=config)
109
+ else:
110
+ model = T5ForConditionalGeneration(config)
111
+ elif resolved_init_device == 'meta':
112
+ if om_model_config.pretrained:
113
+ raise ValueError(
114
+ 'Setting cfg.pretrained=True is not supported when init_device="meta".'
115
+ )
116
+ with init_empty_weights(include_buffers=False):
117
+ model = T5ForConditionalGeneration(config)
118
+ else:
119
+ raise ValueError(
120
+ f'init_device="{init_device}" must be either "cpu" or "meta".')
121
+
122
+ metrics = [
123
+ LanguageCrossEntropy(ignore_index=_HF_IGNORE_INDEX),
124
+ MaskedAccuracy(ignore_index=_HF_IGNORE_INDEX)
125
+ ]
126
+
127
+ super().__init__(model=model,
128
+ tokenizer=tokenizer,
129
+ metrics=metrics,
130
+ z_loss=om_model_config.get(
131
+ 'z_loss', 0.0),
132
+ init_device=init_device)
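As a rough illustration of how the documented config keys map onto this constructor, here is a minimal sketch (editorial, not part of the repository; it assumes the `t5-small` checkpoint and tokenizer can be downloaded, and the usual entry point is the YAML-driven training script rather than direct construction):

    from omegaconf import OmegaConf
    from transformers import AutoTokenizer

    cfg = OmegaConf.create({
        'pretrained_model_name_or_path': 't5-small',
        'pretrained': True,        # load HF weights rather than a random init
        'init_device': 'cpu',
        'z_loss': 0.0001,          # small auxiliary z-loss coefficient
    })
    tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
    model = ComposerHFT5(cfg, tokenizer)  # ComposerModel wrapping T5ForConditionalGeneration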
Perceptrix/finetune/build/lib/llmfoundry/models/hf/model_wrapper.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Re-usable :class:`.ComposerModel` for LLM HF Models."""
5
+
6
+ from __future__ import annotations
7
+
8
+ import inspect
9
+ from collections import UserDict
10
+ from typing import List, Mapping, Optional
11
+
12
+ import torch
13
+ import transformers
14
+ from composer.models.huggingface import HuggingFaceModel
15
+ from torchmetrics import Metric
16
+ from transformers import PreTrainedTokenizerBase
17
+ from transformers.utils.generic import ModelOutput
18
+
19
+ from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
20
+
21
+ # HuggingFace hardcodes the ignore index to -100
22
+ _HF_IGNORE_INDEX = -100
23
+
24
+
25
+ class HuggingFaceModelWithZLoss(HuggingFaceModel):
26
+ """Wrapper around HuggingFaceModel.
27
+
28
+ This adds z-loss, which is used in some training contexts,
29
+ and is a convenient way to patch features that are generically
30
+ useful for HF models.
31
+ See use of z_loss in PaLM: https://arxiv.org/abs/2204.02311v3, Section 5.
32
+ Also, from https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666:
33
+ Two uses of z_loss are:
34
+ - To keep the logits from drifting too far from zero, which can cause
35
+ unacceptable roundoff errors in bfloat16.
36
+ - To encourage the logits to be normalized log-probabilities.
37
+
38
+ Handles preparation for FSDP wrapping.
39
+ """
40
+
41
+ def __init__(self,
42
+ model: transformers.PreTrainedModel,
43
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
44
+ metrics: Optional[List[Metric]] = None,
45
+ eval_metrics: Optional[List[Metric]] = None,
46
+ z_loss: float = 0.0,
47
+ shift_labels: bool = False,
48
+ init_device: Optional[str] = None):
49
+ super().__init__(model,
50
+ tokenizer,
51
+ use_logits=True,
52
+ metrics=metrics,
53
+ eval_metrics=eval_metrics,
54
+ shift_labels=shift_labels)
55
+ self.z_loss = float(z_loss)
56
+ if self.z_loss < 0.0:
57
+ raise ValueError(f'z_loss(={z_loss}) cannot be negative.')
58
+
59
+ self.model_forward_args = inspect.getfullargspec(
60
+ self.model.forward).args
61
+ # inspect.getfullargspec may not return the forward args correctly for HuggingFace quantized models
62
+ if not self.model_forward_args:
63
+ self.model_forward_args = inspect.signature(
64
+ self.model.forward).parameters.keys()
65
+
66
+ # Note: We need to add the FSDP related attributes to the model AFTER the super init,
67
+ # so that the (possible) embedding resizing doesn't destroy them
68
+ prepare_hf_model_for_fsdp(self.model, init_device)
69
+
70
+ # This provides support for meta initialization when using FSDP
71
+ self.model.param_init_fn = lambda module: self.model._init_weights(
72
+ module)
73
+
74
+ def forward(self, batch: Mapping):
75
+ if isinstance(batch, dict) or isinstance(batch, UserDict):
76
+ # Further input validation is left to the huggingface forward call
77
+ batch = {
78
+ k: v for k, v in batch.items() if k in self.model_forward_args
79
+ }
80
+ output = self.model(**batch) # type: ignore (thirdparty)
81
+ else:
82
+ raise ValueError(
83
+ 'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model'
84
+ )
85
+ return output
86
+
87
+ def loss(self, outputs: ModelOutput, batch: Mapping):
88
+ if self.config.use_return_dict:
89
+ loss, logits = outputs['loss'], outputs['logits']
90
+ else:
91
+ # loss is at index 0 in the output tuple, logits are at index 1
92
+ loss, logits = outputs[:2]
93
+ if self.z_loss == 0.0:
94
+ return loss
95
+
96
+ # Add a z_loss to the standard loss
97
+ logits_flat = logits.view(-1, logits.size(-1))
98
+ labels_flat = batch['labels'].view(-1)
99
+ log_z = torch.logsumexp(logits_flat[labels_flat != _HF_IGNORE_INDEX],
100
+ dim=1)
101
+ log_z2 = log_z**2
102
+ z_loss = log_z2.mean() * self.z_loss
103
+ if self.config.use_return_dict:
104
+ outputs['loss'] += z_loss
105
+ return outputs['loss']
106
+ else:
107
+ outputs[0] += z_loss
108
+ return outputs[0]
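To make the auxiliary term above concrete, here is a small self-contained sketch of the same z-loss computation on toy tensors (values are illustrative; the 1e-4 coefficient mirrors the PaLM setting referenced in the class docstring):

    import torch

    logits = torch.randn(4, 6)              # 4 token positions, vocabulary of 6
    labels = torch.tensor([2, 5, -100, 1])  # -100 marks an ignored position
    z_loss_coeff = 1e-4

    # log of the softmax normalizer Z at each non-ignored position
    log_z = torch.logsumexp(logits[labels != -100], dim=1)
    # penalizes logits drifting away from normalized log-probabilities
    z_loss = z_loss_coeff * (log_z ** 2).mean()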
Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from llmfoundry.models.inference_api_wrapper.interface import \
5
+ InferenceAPIEvalWrapper
6
+ from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
7
+ OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper)
8
+
9
+ __all__ = [
10
+ 'OpenAICausalLMEvalWrapper',
11
+ 'OpenAIChatAPIEvalWrapper',
12
+ 'InferenceAPIEvalWrapper',
13
+ ]
Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/interface.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, Dict, Optional
5
+
6
+ import torch
7
+ from composer.core.types import Batch
8
+ from composer.metrics import InContextLearningMetric
9
+ from composer.metrics.nlp import (InContextLearningLMAccuracy,
10
+ InContextLearningLMExpectedCalibrationError,
11
+ InContextLearningMCExpectedCalibrationError,
12
+ InContextLearningMultipleChoiceAccuracy,
13
+ InContextLearningQAAccuracy,
14
+ LanguageCrossEntropy, LanguagePerplexity)
15
+ from composer.models import ComposerModel
16
+ from torchmetrics import Metric
17
+ from transformers import AutoTokenizer
18
+
19
+
20
+ class InferenceAPIEvalWrapper(ComposerModel):
21
+
22
+ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
23
+ self.tokenizer = tokenizer
24
+ self.labels = None
25
+ # set up training and eval metrics
26
+ eval_metrics = [
27
+ LanguageCrossEntropy(),
28
+ LanguagePerplexity(),
29
+ InContextLearningLMAccuracy(),
30
+ InContextLearningMultipleChoiceAccuracy(),
31
+ InContextLearningQAAccuracy(),
32
+ InContextLearningLMExpectedCalibrationError(),
33
+ InContextLearningMCExpectedCalibrationError()
34
+ ]
35
+ self.eval_metrics = {
36
+ metric.__class__.__name__: metric for metric in eval_metrics
37
+ }
38
+ super().__init__()
39
+
40
+ def get_metrics(self, is_train: bool = False):
41
+ if is_train:
42
+ raise NotImplementedError(
43
+ 'You cannot use inference wrappers for training')
44
+ else:
45
+ metrics = self.eval_metrics
46
+
47
+ return metrics if metrics else {}
48
+
49
+ def get_next_token_logit_tensor(self,
50
+ prompt: str) -> Optional[torch.Tensor]:
51
+ raise NotImplementedError
52
+
53
+ def rebatch(self, batch: Batch):
54
+ # default is a no-op, but Chat API modifies these
55
+ return batch
56
+
57
+ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
58
+ # If the batch mode is generate, we will generate a requested number of tokens using the underlying
59
+ # model's generate function. Extra generation kwargs can be passed in via the batch. Strings will
60
+ # be returned from eval_forward
61
+ output_logits_batch = []
62
+ for tokens, cont_idxs in zip(batch['input_ids'],
63
+ batch['continuation_indices']):
64
+
65
+ seqlen = tokens.shape[0]
66
+ tokens = tokens.tolist()
67
+ cont_idxs = cont_idxs.tolist()
68
+ expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
69
+ output_logits = torch.nn.functional.one_hot(
70
+ torch.tensor(tokens[1:cont_idxs[0]]),
71
+ num_classes=self.tokenizer.vocab_size)
72
+ for i in range(len(expected_cont_tokens)):
73
+ # decode one token at a time
74
+ prompt = self.tokenizer.decode(tokens[:cont_idxs[0]] +
75
+ expected_cont_tokens[0:i])
76
+ next_logit_tensor = self.get_next_token_logit_tensor(prompt)
77
+ if next_logit_tensor is None:
78
+ continue
79
+ output_logits = torch.cat(
80
+ [output_logits,
81
+ next_logit_tensor.reshape(1, -1)])
82
+ padding = torch.nn.functional.one_hot(
83
+ torch.full((seqlen - output_logits.shape[0],),
84
+ self.tokenizer.pad_token_id),
85
+ num_classes=self.tokenizer.vocab_size)
86
+ output_logits = torch.cat([output_logits, padding])
87
+ output_logits_batch.append(output_logits)
88
+
89
+ return torch.stack(output_logits_batch).to(batch['input_ids'].device)
90
+
91
+ def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
92
+ batch = self.rebatch(batch)
93
+ self.labels = batch.pop('labels')
94
+ self.labels[:, :-1] = self.labels[:, 1:].clone()
95
+ self.labels[:, -1] = -100
96
+ if isinstance(metric, InContextLearningMetric) and batch.get(
97
+ 'mode', None) == 'icl_task':
98
+ assert self.labels is not None
99
+ metric.update(batch, outputs, self.labels)
100
+ else:
101
+ raise NotImplementedError(
102
+ 'Inference API wrapper only supports InContextLearningMetrics and mode=icl_task'
103
+ )
104
+
105
+ def forward(self):
106
+ raise NotImplementedError(
107
+ "Inference API wrapper doesn't support forward")
108
+
109
+ def loss(self):
110
+ raise NotImplementedError("Inference API wrapper doesn't support loss")
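The base class leaves `get_next_token_logit_tensor` abstract; a concrete wrapper only needs to turn a prompt into one vocab-sized row of logits. A hypothetical, do-nothing subclass (a sketch that assumes `InferenceAPIEvalWrapper` above is in scope, not a real API client) might look like this:

    import torch

    class ConstantTokenEvalWrapper(InferenceAPIEvalWrapper):
        """Toy wrapper that always 'predicts' token id 0, for illustration only."""

        def get_next_token_logit_tensor(self, prompt: str):
            return torch.nn.functional.one_hot(
                torch.tensor(0), num_classes=self.tokenizer.vocab_size).float()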
Perceptrix/finetune/build/lib/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Implements OpenAI chat and causal LM inference API wrappers."""
5
+
6
+ import logging
7
+ import os
8
+ from time import sleep
9
+ from typing import Any, Dict, List, Optional, Union
10
+
11
+ import torch
12
+ from composer.core.types import Batch
13
+ from composer.utils.import_helpers import MissingConditionalImportError
14
+ from transformers import AutoTokenizer
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+ from llmfoundry.models.inference_api_wrapper.interface import \
19
+ InferenceAPIEvalWrapper
20
+
21
+ __all__ = [
22
+ 'OpenAICausalLMEvalWrapper',
23
+ 'OpenAIChatAPIEvalWrapper',
24
+ ]
25
+
26
+ MAX_RETRIES = 10
27
+
28
+
29
+ class OpenAIEvalInterface(InferenceAPIEvalWrapper):
30
+
31
+ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
32
+ super().__init__(model_cfg, tokenizer)
33
+ try:
34
+ import openai
35
+ except ImportError as e:
36
+ raise MissingConditionalImportError(
37
+ extra_deps_group='openai',
38
+ conda_package='openai',
39
+ conda_channel='conda-forge') from e
40
+ openai.api_key = os.getenv('OPENAI_API_KEY')
41
+ self.model_name = model_cfg['version']
42
+
43
+ def generate_completion(self, prompt: str, num_tokens: int):
44
+ raise NotImplementedError()
45
+
46
+ def process_result(self, completion: Optional[dict]):
47
+ raise NotImplementedError()
48
+
49
+ def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1):
50
+ completion = self.try_generate_completion(prompt, num_tokens)
51
+ return self.process_result(completion)
52
+
53
+ def try_generate_completion(self, prompt: str, num_tokens: int):
54
+ try:
55
+ from openai.error import RateLimitError
56
+ except ImportError as e:
57
+ raise MissingConditionalImportError(
58
+ extra_deps_group='openai',
59
+ conda_package='openai',
60
+ conda_channel='conda-forge') from e
61
+ tries = 0
62
+ completion = None
63
+ while tries < MAX_RETRIES:
64
+ tries += 1
65
+ try:
66
+
67
+ completion = self.generate_completion(prompt, num_tokens)
68
+ break
69
+ except RateLimitError as e:
70
+ if 'You exceeded your current quota' in str(e._message):
71
+ raise e
72
+ sleep(60)
73
+ continue
74
+ except Exception:
75
+ continue
76
+ return completion
77
+
78
+
79
+ class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface):
80
+
81
+ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
82
+ super().__init__(model_cfg, tokenizer)
83
+ try:
84
+ import openai
85
+ except ImportError as e:
86
+ raise MissingConditionalImportError(
87
+ extra_deps_group='openai',
88
+ conda_package='openai',
89
+ conda_channel='conda-forge') from e
90
+
91
+ self.generate_completion = lambda prompt, num_tokens: openai.ChatCompletion.create(
92
+ model=self.model_name,
93
+ messages=[{
94
+ 'role': 'user',
95
+ 'content': prompt
96
+ }],
97
+ max_tokens=num_tokens,
98
+ temperature=0.0)
99
+
100
+ def retokenize(self, tokens: List[int], cont_idxs: List[int]):
101
+ """Chat API will never respond with a word-initial space.
102
+
103
+ If the continuation tokens begin with a word-initial space, we need to
104
+ re-tokenize with the space removed.
105
+ """
106
+ original_len = len(tokens)
107
+ retokenized_continuation = self.tokenizer(
108
+ self.tokenizer.decode(tokens[cont_idxs[0]:cont_idxs[-1] +
109
+ 1]).strip())['input_ids']
110
+
111
+ # replace the original continuation with the retokenized continuation + padding
112
+ padding = [tokens[-1]] * (
113
+ len(tokens) - len(tokens[:cont_idxs[0]] + retokenized_continuation))
114
+ tokens = tokens[:cont_idxs[0]] + retokenized_continuation + padding
115
+
116
+ if len(tokens) > original_len:
117
+ # this only happens if we were already at max seq len and the continuation got LARGER
118
+ tokens = tokens[-original_len:]
119
+ cont_idxs = list(
120
+ range(original_len - len(retokenized_continuation),
121
+ original_len))
122
+ else:
123
+ cont_idxs = list(
124
+ range(cont_idxs[0],
125
+ cont_idxs[0] + len(retokenized_continuation)))
126
+ return torch.tensor(tokens), torch.tensor(cont_idxs)
127
+
128
+ def rebatch(self, batch: Batch):
129
+ """Chat API tokenization has different behavior than GPT3.
130
+
131
+ Model responses will never begin with spaces even if the continuation is
132
+ expected to, so we need to retokenize the input to account for that.
133
+ """
134
+ new_batch: Dict[str, Union[List[torch.Tensor], torch.Tensor]] = {
135
+ 'input_ids': [],
136
+ 'continuation_indices': [],
137
+ 'labels': []
138
+ }
139
+ for tokens, cont_idxs in zip(batch['input_ids'],
140
+ batch['continuation_indices']):
141
+ tokens, cont_idxs = self.retokenize(tokens.tolist(),
142
+ cont_idxs.tolist())
143
+
144
+ assert isinstance(new_batch['input_ids'], list)
145
+ new_batch['input_ids'].append(tokens)
146
+ assert isinstance(new_batch['labels'], list)
147
+ new_batch['labels'].append(tokens)
148
+ assert isinstance(new_batch['continuation_indices'], list)
149
+ new_batch['continuation_indices'].append(cont_idxs)
150
+
151
+ new_batch.update({
152
+ k: torch.stack(new_batch[k]) # pyright: ignore
153
+ for k in ['input_ids', 'labels']
154
+ })
155
+
156
+ new_batch.update({k: v for k, v in batch.items() if k not in new_batch})
157
+
158
+ return new_batch
159
+
160
+ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
161
+ # Override the base class because Chat's API always strips spacing from model outputs resulting in different tokens
162
+ # than what the continuation would expect.
163
+ # Get around this issue by retokenizing the batch to remove spacing from the continuation as well as
164
+ # decoding the whole continuation at once.
165
+ output_logits_batch = []
166
+ batch = self.rebatch(batch)
167
+ for tokens, cont_idxs in zip(batch['input_ids'],
168
+ batch['continuation_indices']):
169
+
170
+ seqlen = tokens.shape[0]
171
+ tokens = tokens.tolist()
172
+ cont_idxs = cont_idxs.tolist()
173
+ expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
174
+ output_logits = torch.nn.functional.one_hot(
175
+ torch.tensor(tokens[1:cont_idxs[0]]),
176
+ num_classes=self.tokenizer.vocab_size)
177
+
178
+ prompt = self.tokenizer.decode(tokens[:cont_idxs[0]])
179
+ next_logit_tensor = self.get_next_token_logit_tensor(
180
+ prompt, num_tokens=len(expected_cont_tokens))
181
+
182
+ if next_logit_tensor is not None:
183
+ output_logits = torch.cat([output_logits, next_logit_tensor])
184
+ padding = torch.nn.functional.one_hot(
185
+ torch.full((seqlen - output_logits.shape[0],),
186
+ self.tokenizer.pad_token_id),
187
+ num_classes=self.tokenizer.vocab_size)
188
+ output_logits = torch.cat([output_logits, padding])
189
+ output_logits_batch.append(output_logits)
190
+
191
+ return torch.stack(output_logits_batch).to(batch['input_ids'].device)
192
+
193
+ def process_result(self, completion: Optional[dict]):
194
+ assert isinstance(completion, dict)
195
+ if len(completion['choices']) > 0:
196
+ tensors = []
197
+ for t in self.tokenizer(completion['choices'][0]['message']
198
+ ['content'])['input_ids']:
199
+ tensors.append(
200
+ self.tokenizer.construct_logit_tensor(
201
+ {self.tokenizer.decode([t]): 0.0}))
202
+
203
+ if len(tensors) == 0:
204
+ return None
205
+ return torch.stack(tensors)
206
+ else:
207
+ # the model sometimes stops early even though we are still requesting tokens!
208
+ # not sure if there's a fix
209
+ return None
210
+
211
+
212
+ class OpenAICausalLMEvalWrapper(OpenAIEvalInterface):
213
+
214
+ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
215
+ super().__init__(model_cfg, tokenizer)
216
+ try:
217
+ import openai
218
+ except ImportError as e:
219
+ raise MissingConditionalImportError(
220
+ extra_deps_group='openai',
221
+ conda_package='openai',
222
+ conda_channel='conda-forge') from e
223
+
224
+ self.generate_completion = lambda prompt, num_tokens: openai.Completion.create(
225
+ engine=self.model_name,
226
+ prompt=prompt,
227
+ max_tokens=1,
228
+ logprobs=5,
229
+ temperature=0.0)
230
+
231
+ def process_result(self, completion: Optional[dict]):
232
+ if completion is None:
233
+ raise ValueError("Couldn't generate model output")
234
+
235
+ assert isinstance(completion, dict)
236
+ if len(completion['choices'][0]['logprobs']['top_logprobs']) > 0:
237
+ tensor = self.tokenizer.construct_logit_tensor(
238
+ dict(completion['choices'][0]['logprobs']['top_logprobs'][0]))
239
+ return tensor
240
+ else:
241
+ # the model sometimes stops early even though we are still requesting tokens!
242
+ # not sure if there's a fix
243
+ return None
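For reference, a wrapper instance is configured with the OpenAI model name under the `version` key and reads `OPENAI_API_KEY` from the environment. The sketch below is illustrative only: it assumes the pre-1.0 `openai` package used above, uses a placeholder key, and borrows a GPT-2 tokenizer purely as a stand-in.

    import os
    from transformers import AutoTokenizer

    os.environ.setdefault('OPENAI_API_KEY', 'sk-...')   # placeholder, not a real key
    tokenizer = AutoTokenizer.from_pretrained('gpt2')   # stand-in tokenizer for the sketch
    chat_eval = OpenAIChatAPIEvalWrapper({'version': 'gpt-3.5-turbo'}, tokenizer)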
Perceptrix/finetune/build/lib/llmfoundry/models/layers/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from llmfoundry.models.layers.attention import (
5
+ ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention,
6
+ attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
7
+ scaled_multihead_dot_product_attention, triton_flash_attn_fn)
8
+ from llmfoundry.models.layers.blocks import MPTBlock
9
+ from llmfoundry.models.layers.custom_embedding import SharedEmbedding
10
+ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
11
+ from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
12
+ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm
13
+
14
+ __all__ = [
15
+ 'scaled_multihead_dot_product_attention',
16
+ 'flash_attn_fn',
17
+ 'triton_flash_attn_fn',
18
+ 'MultiheadAttention',
19
+ 'MultiQueryAttention',
20
+ 'attn_bias_shape',
21
+ 'build_attn_bias',
22
+ 'build_alibi_bias',
23
+ 'ATTN_CLASS_REGISTRY',
24
+ 'MPTMLP',
25
+ 'MPTBlock',
26
+ 'NORM_CLASS_REGISTRY',
27
+ 'LPLayerNorm',
28
+ 'FC_CLASS_REGISTRY',
29
+ 'SharedEmbedding',
30
+ 'FFN_CLASS_REGISTRY',
31
+ 'build_ffn',
32
+ ]
Perceptrix/finetune/build/lib/llmfoundry/models/layers/attention.py ADDED
@@ -0,0 +1,768 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Attention layers."""
5
+
6
+ import math
7
+ import warnings
8
+ from typing import Any, List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange
13
+ from packaging import version
14
+ from torch import nn
15
+
16
+ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
17
+ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
18
+
19
+
20
+ def is_flash_v2_installed():
21
+ try:
22
+ import flash_attn as flash_attn
23
+ except:
24
+ return False
25
+ return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
26
+
27
+
28
+ def is_flash_v1_installed():
29
+ try:
30
+ import flash_attn as flash_attn
31
+ except:
32
+ return False
33
+ return version.parse(flash_attn.__version__) < version.parse('2.0.0')
34
+
35
+
36
+ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
37
+ original_is_causal: bool) -> bool:
38
+ # disable causal when it is not needed
39
+ # necessary for flash & triton for generation with kv_cache
40
+ if original_is_causal and num_query_tokens != num_key_tokens:
41
+ if num_query_tokens != 1:
42
+ raise NotImplementedError(
43
+ 'MPT does not support query and key with different number of tokens, unless number of query tokens is 1.'
44
+ )
45
+ else:
46
+ return False
47
+ return original_is_causal
48
+
49
+
50
+ def repeat_kv_for_gqa(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
51
+ """Perform repeat of kv heads along a particular dimension.
52
+
53
+ hidden.shape expected to be: (batch size, seq len, kv_n_heads, head_dim)
54
+ n_rep: number of repetitions of kv_n_heads
55
+ Unlike torch.repeat_interleave, this function avoids allocating new memory.
56
+ """
57
+ if n_rep == 1:
58
+ return hidden
59
+
60
+ b, s, kv_n_heads, d = hidden.shape
61
+
62
+ hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
63
+
64
+ return hidden.reshape(b, s, kv_n_heads * n_rep, d)
65
+
66
+
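A quick shape check of the helper above (illustrative only): repeating 2 KV heads three times yields 6 heads, using an `expand` view rather than copies until the final reshape.

    import torch

    h = torch.randn(1, 4, 2, 16)   # (batch, seq, kv_n_heads, head_dim)
    assert repeat_kv_for_gqa(h, 3).shape == (1, 4, 6, 16)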
67
+ def scaled_multihead_dot_product_attention(
68
+ query: torch.Tensor,
69
+ key: torch.Tensor,
70
+ value: torch.Tensor,
71
+ n_heads: int,
72
+ kv_n_heads: Optional[int] = None,
73
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
74
+ softmax_scale: Optional[float] = None,
75
+ attn_bias: Optional[torch.Tensor] = None,
76
+ key_padding_mask: Optional[torch.Tensor] = None,
77
+ is_causal: bool = False,
78
+ dropout_p: float = 0.0,
79
+ training: bool = False,
80
+ needs_weights: bool = False,
81
+ multiquery: bool = False,
82
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
83
+ torch.Tensor]]]:
84
+
85
+ if multiquery:
86
+ warnings.warn(
87
+ DeprecationWarning(
88
+ 'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
89
+ ))
90
+ kv_n_heads = 1
91
+ elif kv_n_heads is None:
92
+ warnings.warn(
93
+ DeprecationWarning(
94
+ 'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
95
+ ))
96
+ kv_n_heads = n_heads
97
+
98
+ q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
99
+ k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
100
+ v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
101
+
102
+ if past_key_value is not None:
103
+ # attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
104
+ # kv_cache is therefore stored using that shape.
105
+ # attn_impl: torch stores the kv_cache in the ordering which is most advantageous
106
+ # for its attn computation ie
107
+ # keys are stored as tensors with shape [b, h, d_head, s] and
108
+ # values are stored as tensors with shape [b, h, s, d_head]
109
+ if len(past_key_value) != 0:
110
+ k = torch.cat([past_key_value[0], k], dim=3)
111
+ v = torch.cat([past_key_value[1], v], dim=2)
112
+
113
+ past_key_value = (k, v)
114
+
115
+ b, _, s_q, d = q.shape
116
+ s_k = k.size(-1)
117
+
118
+ # grouped query case
119
+ if kv_n_heads > 1 and kv_n_heads < n_heads:
120
+ # necessary to do a transpose to swap (b h s d) -> (b s h d) for repeat_kv_for_gqa function
121
+ k = repeat_kv_for_gqa(k.transpose(1, 2),
122
+ n_heads // kv_n_heads).transpose(1, 2)
123
+ v = repeat_kv_for_gqa(v.transpose(1, 2),
124
+ n_heads // kv_n_heads).transpose(1, 2)
125
+
126
+ if softmax_scale is None:
127
+ softmax_scale = 1 / math.sqrt(d)
128
+
129
+ attn_weight = q.matmul(k) * softmax_scale
130
+
131
+ if attn_bias is not None:
132
+ # clamp to 0 necessary for torch 2.0 compile()
133
+ _s_q = max(0, attn_bias.size(2) - s_q)
134
+ _s_k = max(0, attn_bias.size(3) - s_k)
135
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
136
+
137
+ if (attn_bias.size(-1) != 1 and
138
+ attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and
139
+ attn_bias.size(-2) != s_q):
140
+ raise RuntimeError(
141
+ f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
142
+ )
143
+ attn_weight = attn_weight + attn_bias
144
+
145
+ min_val = torch.finfo(q.dtype).min
146
+
147
+ if key_padding_mask is not None:
148
+ if attn_bias is not None:
149
+ warnings.warn(
150
+ 'Propagating key_padding_mask to the attention module ' +\
151
+ 'and applying it within the attention module can cause ' +\
152
+ 'unnecessary computation/memory usage. Consider integrating ' +\
153
+ 'into attn_bias once and passing that to each attention ' +\
154
+ 'module instead.'
155
+ )
156
+ attn_weight = attn_weight.masked_fill(
157
+ ~key_padding_mask.view((b, 1, 1, s_k)), min_val)
158
+
159
+ if is_causal and (not q.size(2) == 1):
160
+ s = max(s_q, s_k)
161
+ causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
162
+ causal_mask = causal_mask.tril()
163
+ causal_mask = causal_mask.to(torch.bool)
164
+ causal_mask = ~causal_mask
165
+ causal_mask = causal_mask[-s_q:, -s_k:]
166
+ attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
167
+ min_val)
168
+
169
+ attn_weight = torch.softmax(attn_weight, dim=-1)
170
+
171
+ if dropout_p:
172
+ attn_weight = torch.nn.functional.dropout(attn_weight,
173
+ p=dropout_p,
174
+ training=training,
175
+ inplace=True)
176
+
177
+ out = attn_weight.to(v.dtype).matmul(v)
178
+ out = rearrange(out, 'b h s d -> b s (h d)')
179
+
180
+ if needs_weights:
181
+ return out, attn_weight, past_key_value
182
+ return out, None, past_key_value
183
+
184
+
185
+ def check_valid_inputs(*tensors: torch.Tensor,
186
+ valid_dtypes: Optional[List[torch.dtype]] = None):
187
+ if valid_dtypes is None:
188
+ valid_dtypes = [torch.float16, torch.bfloat16]
189
+ for tensor in tensors:
190
+ if tensor.dtype not in valid_dtypes:
191
+ raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.')
192
+ if not tensor.is_cuda:
193
+ raise TypeError(f'Inputs must be cuda tensors ({tensor.is_cuda=}).')
194
+
195
+
196
+ def flash_attn_fn(
197
+ query: torch.Tensor,
198
+ key: torch.Tensor,
199
+ value: torch.Tensor,
200
+ n_heads: int,
201
+ kv_n_heads: Optional[int] = None,
202
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
203
+ softmax_scale: Optional[float] = None,
204
+ attn_bias: Optional[torch.Tensor] = None,
205
+ key_padding_mask: Optional[torch.Tensor] = None,
206
+ is_causal: bool = False,
207
+ dropout_p: float = 0.0,
208
+ training: bool = False,
209
+ needs_weights: bool = False,
210
+ multiquery: bool = False,
211
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
212
+ torch.Tensor]]]:
213
+ try:
214
+ from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
215
+ except:
216
+ raise RuntimeError(
217
+ 'Please install flash-attn==1.0.9 or flash-attn==2.3.2')
218
+
219
+ check_valid_inputs(query, key, value)
220
+
221
+ if multiquery:
222
+ warnings.warn(
223
+ DeprecationWarning(
224
+ 'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
225
+ ))
226
+ kv_n_heads = 1
227
+ elif kv_n_heads is None:
228
+ warnings.warn(
229
+ DeprecationWarning(
230
+ 'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
231
+ ))
232
+ kv_n_heads = n_heads
233
+
234
+ if past_key_value is not None:
235
+ if len(past_key_value) != 0:
236
+ key = torch.cat([past_key_value[0], key], dim=1)
237
+ value = torch.cat([past_key_value[1], value], dim=1)
238
+
239
+ past_key_value = (key, value)
240
+
241
+ if attn_bias is not None:
242
+ # clamp to 0 necessary for torch 2.0 compile()
243
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
244
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
245
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
246
+
247
+ if attn_bias is not None:
248
+ raise NotImplementedError(f'attn_bias not implemented for flash attn.')
249
+
250
+ batch_size, seqlen = query.shape[:2]
251
+
252
+ if key_padding_mask is None:
253
+ key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
254
+ query_padding_mask = key_padding_mask[:, -query.size(1):]
255
+
256
+ query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
257
+ query, query_padding_mask)
258
+ query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
259
+
260
+ key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
261
+ key, key_padding_mask)
262
+ key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
263
+
264
+ value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
265
+ value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
266
+
267
+ # multi-query case
268
+ if kv_n_heads == 1:
269
+ # Expanding a tensor does not allocate new memory, but only creates a new
270
+ # view on the existing tensor where a dimension of size one is expanded
271
+ # to a larger size by setting the stride to 0.
272
+ # - pytorch docs
273
+ #
274
+ # hopefully the kernels can utilize this and we're not just wasting bandwidth here
275
+ key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
276
+ key_unpad.size(-1))
277
+ value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
278
+ value_unpad.size(-1))
279
+ # grouped query case
280
+ elif kv_n_heads < n_heads:
281
+ # Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
282
+ # We repeat each kv head by the group size number to use the underlying MHA kernels
283
+
284
+ # since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d)
285
+ # we use .view to modify {key, value}_unpad appropriately
286
+
287
+ key_unpad = repeat_kv_for_gqa(
288
+ key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
289
+ n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
290
+ value_unpad = repeat_kv_for_gqa(
291
+ value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
292
+ n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
293
+
294
+ dropout_p = dropout_p if training else 0.0
295
+
296
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
297
+
298
+ if is_flash_v1_installed():
299
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
300
+ q=query_unpad,
301
+ k=key_unpad,
302
+ v=value_unpad,
303
+ cu_seqlens_q=cu_seqlens_q,
304
+ cu_seqlens_k=cu_seqlens_k,
305
+ max_seqlen_q=max_seqlen_q,
306
+ max_seqlen_k=max_seqlen_k,
307
+ dropout_p=dropout_p,
308
+ softmax_scale=softmax_scale,
309
+ causal=reset_is_causal,
310
+ return_attn_probs=needs_weights)
311
+ elif is_flash_v2_installed():
312
+ output_unpad = flash_attn_interface.flash_attn_varlen_func(
313
+ q=query_unpad,
314
+ k=key_unpad,
315
+ v=value_unpad,
316
+ cu_seqlens_q=cu_seqlens_q,
317
+ cu_seqlens_k=cu_seqlens_k,
318
+ max_seqlen_q=max_seqlen_q,
319
+ max_seqlen_k=max_seqlen_k,
320
+ dropout_p=dropout_p,
321
+ softmax_scale=softmax_scale,
322
+ causal=reset_is_causal,
323
+ return_attn_probs=needs_weights)
324
+ else:
325
+ raise RuntimeError(
326
+ 'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')
327
+
328
+ output = bert_padding.pad_input(
329
+ rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
330
+ seqlen)
331
+ return output, None, past_key_value
332
+
333
+
334
+ def triton_flash_attn_fn(
335
+ query: torch.Tensor,
336
+ key: torch.Tensor,
337
+ value: torch.Tensor,
338
+ n_heads: int,
339
+ kv_n_heads: Optional[int] = None,
340
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
341
+ softmax_scale: Optional[float] = None,
342
+ attn_bias: Optional[torch.Tensor] = None,
343
+ key_padding_mask: Optional[torch.Tensor] = None,
344
+ is_causal: bool = False,
345
+ dropout_p: float = 0.0,
346
+ training: bool = False,
347
+ needs_weights: bool = False,
348
+ multiquery: bool = False,
349
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
350
+ torch.Tensor]]]:
351
+ try:
352
+ from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
353
+ except:
354
+ _installed = False
355
+ if version.parse(torch.__version__) < version.parse('2.0.0'):
356
+ _installed = True
357
+ # if torch 1.13.1, revert to using the triton flash attn from HazyResearch
358
+ # with flash-attn==1.0.9 and triton==2.0.0.dev20221202
359
+ try:
360
+ from flash_attn.flash_attn_triton import flash_attn_func
361
+ except:
362
+ _installed = False
363
+ if not _installed:
364
+ # installing triton-pre-mlir works for both torch1.13.1 and torch2.0+
365
+ # default recommendation is to install this variant
366
+ raise RuntimeError(
367
+ 'Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU '
368
+ +
369
+ 'and `pip install .[gpu]` if installing from llm-foundry source or '
370
+ +
371
+ '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` '
372
+ +
373
+ 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). '
374
+ +
375
+ 'Note: (1) requires you have CMake and PyTorch already installed.'
376
+ )
377
+
378
+ check_valid_inputs(query, key, value)
379
+
380
+ if multiquery:
381
+ warnings.warn(
382
+ DeprecationWarning(
383
+ 'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
384
+ ))
385
+ kv_n_heads = 1
386
+ elif kv_n_heads is None:
387
+ warnings.warn(
388
+ DeprecationWarning(
389
+ 'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
390
+ ))
391
+ kv_n_heads = n_heads
392
+
393
+ if past_key_value is not None:
394
+ if len(past_key_value) != 0:
395
+ key = torch.cat([past_key_value[0], key], dim=1)
396
+ value = torch.cat([past_key_value[1], value], dim=1)
397
+
398
+ past_key_value = (key, value)
399
+
400
+ if attn_bias is not None:
401
+ # clamp to 0 necessary for torch 2.0 compile()
402
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
403
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
404
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
405
+
406
+ if dropout_p:
407
+ raise NotImplementedError(
408
+ f'Dropout not implemented for attn_impl: triton.')
409
+ dropout_p = dropout_p if training else 0.0
410
+
411
+ if needs_weights:
412
+ raise NotImplementedError(
413
+ f'attn_impl: triton cannot return attn weights.')
414
+
415
+ if key_padding_mask is not None:
416
+ warnings.warn(
417
+ 'Propagating key_padding_mask to the attention module ' +\
418
+ 'and applying it within the attention module can cause ' +\
419
+ 'unnecessary computation/memory usage. Consider integrating ' +\
420
+ 'into attn_bias once and passing that to each attention ' +\
421
+ 'module instead.'
422
+ )
423
+ b_size, s_k = key_padding_mask.shape[:2]
424
+
425
+ if attn_bias is None:
426
+ attn_bias = query.new_zeros(b_size, 1, 1, s_k)
427
+
428
+ attn_bias = attn_bias.masked_fill(
429
+ ~key_padding_mask.view((b_size, 1, 1, s_k)),
430
+ torch.finfo(query.dtype).min)
431
+
432
+ query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
433
+ key = rearrange(key, 'b s (h d) -> b s h d', h=kv_n_heads)
434
+ value = rearrange(value, 'b s (h d) -> b s h d', h=kv_n_heads)
435
+
436
+ # multi-query case
437
+ if kv_n_heads == 1:
438
+ # necessary to repeat instead of expand tensor because
439
+ # output contains NaN in edge cases such as with head dimension = 8
440
+ key = key.repeat(1, 1, n_heads, 1)
441
+ value = value.repeat(1, 1, n_heads, 1)
442
+ # grouped query case
443
+ elif kv_n_heads < n_heads:
444
+ # Each query belongs to a group of kv heads of group size n_heads // kv_n_heads
445
+ # We repeat each kv head by the group size number to use the underlying MHA kernels
446
+ key = repeat_kv_for_gqa(key, n_heads // kv_n_heads)
447
+ value = repeat_kv_for_gqa(value, n_heads // kv_n_heads)
448
+
449
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
450
+ attn_output = flash_attn_func( # type: ignore
451
+ query, key, value, attn_bias, reset_is_causal, softmax_scale)
452
+
453
+ output = attn_output.view(*attn_output.shape[:2], -1) # type: ignore
454
+
455
+ return output, None, past_key_value
456
+
457
+
458
+ class GroupedQueryAttention(nn.Module):
459
+ """Grouped Query Attention (GQA).
460
+
461
+ GQA is a generalization of Multi-head Attention (MHA) and Multi-query Attention (MQA).
462
+
463
+ This allows the user to set a variable number of kv_n_heads, rather than
464
+ just n_heads or 1, as in MHA and MQA. Using the torch or triton attention
465
+ implementation enables the user to also use additive bias.
466
+ """
467
+
468
+ def __init__(
469
+ self,
470
+ d_model: int,
471
+ n_heads: int,
472
+ kv_n_heads: int,
473
+ attn_impl: str = 'triton',
474
+ clip_qkv: Optional[float] = None,
475
+ qk_ln: bool = False,
476
+ softmax_scale: Optional[float] = None,
477
+ attn_pdrop: float = 0.0,
478
+ norm_type: str = 'low_precision_layernorm',
479
+ fc_type: str = 'torch',
480
+ device: Optional[str] = None,
481
+ bias: bool = True,
482
+ ):
483
+ super().__init__()
484
+
485
+ self.attn_impl = attn_impl
486
+ self.clip_qkv = clip_qkv
487
+ self.qk_ln = qk_ln
488
+
489
+ self.d_model = d_model
490
+ self.n_heads = n_heads
491
+ self.kv_n_heads = kv_n_heads
492
+
493
+ self.head_dim = d_model // n_heads
494
+
495
+ if self.kv_n_heads <= 0:
496
+ raise ValueError('kv_n_heads should be greater than zero.')
497
+
498
+ if self.kv_n_heads > self.n_heads:
499
+ raise ValueError(
500
+ 'The number of KV heads should be less than or equal to Q heads.'
501
+ )
502
+
503
+ if self.n_heads % self.kv_n_heads != 0:
504
+ raise ValueError(
505
+ 'Each Q head should get the same number of KV heads, so n_heads must be divisible by kv_n_heads.'
506
+ )
507
+
508
+ self.softmax_scale = softmax_scale
509
+ if self.softmax_scale is None:
510
+ self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
511
+ self.attn_dropout_p = attn_pdrop
512
+
513
+ fc_kwargs: dict[str, Any] = {
514
+ 'bias': bias,
515
+ }
516
+ if fc_type != 'te':
517
+ fc_kwargs['device'] = device
518
+ self.Wqkv = FC_CLASS_REGISTRY[fc_type](
519
+ self.d_model,
520
+ self.d_model + 2 * self.kv_n_heads * self.head_dim,
521
+ **fc_kwargs,
522
+ )
523
+ # for param init fn; enables shape based init of fused layers
524
+ fuse_splits = [
525
+ i * self.head_dim
526
+ for i in range(1, self.n_heads + 2 * self.kv_n_heads)
527
+ ]
528
+ self.Wqkv._fused = (0, fuse_splits)
529
+
530
+ if self.qk_ln:
531
+ norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
532
+ self.q_ln = norm_class(self.d_model, device=device)
533
+ self.k_ln = norm_class(self.kv_n_heads * self.head_dim,
534
+ device=device)
535
+
536
+ if self.attn_impl == 'flash':
537
+ self.attn_fn = flash_attn_fn
538
+ elif self.attn_impl == 'triton':
539
+ self.attn_fn = triton_flash_attn_fn
540
+ elif self.attn_impl == 'torch':
541
+ self.attn_fn = scaled_multihead_dot_product_attention
542
+ else:
543
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
544
+
545
+ self.out_proj = FC_CLASS_REGISTRY[fc_type](
546
+ self.d_model,
547
+ self.d_model,
548
+ **fc_kwargs,
549
+ )
550
+ self.out_proj._is_residual = True
551
+
552
+ def forward(
553
+ self,
554
+ x: torch.Tensor,
555
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
556
+ attn_bias: Optional[torch.Tensor] = None,
557
+ attention_mask: Optional[torch.Tensor] = None,
558
+ is_causal: bool = True,
559
+ needs_weights: bool = False,
560
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
561
+ torch.Tensor, torch.Tensor]]]:
562
+ qkv = self.Wqkv(x)
563
+
564
+ if self.clip_qkv:
565
+ qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
566
+
567
+ query, key, value = qkv.split(
568
+ [
569
+ self.d_model,
570
+ self.kv_n_heads * self.head_dim,
571
+ self.kv_n_heads * self.head_dim,
572
+ ],
573
+ dim=2,
574
+ )
575
+
576
+ key_padding_mask = attention_mask
577
+
578
+ if self.qk_ln:
579
+ # Applying layernorm to qk
580
+ dtype = query.dtype
581
+ query = self.q_ln(query).to(dtype)
582
+ key = self.k_ln(key).to(dtype)
583
+
584
+ context, attn_weights, past_key_value = self.attn_fn(
585
+ query,
586
+ key,
587
+ value,
588
+ self.n_heads,
589
+ self.kv_n_heads,
590
+ past_key_value=past_key_value,
591
+ softmax_scale=self.softmax_scale,
592
+ attn_bias=attn_bias,
593
+ key_padding_mask=key_padding_mask,
594
+ is_causal=is_causal,
595
+ dropout_p=self.attn_dropout_p,
596
+ training=self.training,
597
+ needs_weights=needs_weights,
598
+ )
599
+
600
+ return self.out_proj(context), attn_weights, past_key_value
601
+
602
+
603
+ class MultiheadAttention(GroupedQueryAttention):
604
+ """Multi-head self attention.
605
+
606
+ Using the torch or triton attention implementation enables the user to also use
607
+ additive bias.
608
+ """
609
+
610
+ def __init__(
611
+ self,
612
+ d_model: int,
613
+ n_heads: int,
614
+ attn_impl: str = 'triton',
615
+ clip_qkv: Optional[float] = None,
616
+ qk_ln: bool = False,
617
+ softmax_scale: Optional[float] = None,
618
+ attn_pdrop: float = 0.0,
619
+ norm_type: str = 'low_precision_layernorm',
620
+ fc_type: str = 'torch',
621
+ device: Optional[str] = None,
622
+ bias: bool = True,
623
+ ):
624
+ super().__init__(
625
+ d_model=d_model,
626
+ n_heads=n_heads,
627
+ kv_n_heads=n_heads, # for MHA, same # heads as kv groups
628
+ attn_impl=attn_impl,
629
+ clip_qkv=clip_qkv,
630
+ qk_ln=qk_ln,
631
+ softmax_scale=softmax_scale,
632
+ attn_pdrop=attn_pdrop,
633
+ norm_type=norm_type,
634
+ fc_type=fc_type,
635
+ device=device,
636
+ bias=bias,
637
+ )
638
+
639
+
640
+ class MultiQueryAttention(GroupedQueryAttention):
641
+ """Multi-Query self attention.
642
+
643
+ Using the torch or triton attention implementation enables the user to also use
644
+ additive bias.
645
+ """
646
+
647
+ def __init__(
648
+ self,
649
+ d_model: int,
650
+ n_heads: int,
651
+ attn_impl: str = 'triton',
652
+ clip_qkv: Optional[float] = None,
653
+ qk_ln: bool = False,
654
+ softmax_scale: Optional[float] = None,
655
+ attn_pdrop: float = 0.0,
656
+ norm_type: str = 'low_precision_layernorm',
657
+ fc_type: str = 'torch',
658
+ device: Optional[str] = None,
659
+ bias: bool = True,
660
+ ):
661
+ super().__init__(
662
+ d_model=d_model,
663
+ n_heads=n_heads,
664
+ kv_n_heads=1, # for MQA, 1 head
665
+ attn_impl=attn_impl,
666
+ clip_qkv=clip_qkv,
667
+ qk_ln=qk_ln,
668
+ softmax_scale=softmax_scale,
669
+ attn_pdrop=attn_pdrop,
670
+ norm_type=norm_type,
671
+ fc_type=fc_type,
672
+ device=device,
673
+ bias=bias,
674
+ )
675
+
676
+
677
+ def attn_bias_shape(
678
+ attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
679
+ prefix_lm: bool, causal: bool,
680
+ use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
681
+ if attn_impl == 'flash':
682
+ return None
683
+ elif attn_impl in ['torch', 'triton']:
684
+ if alibi:
685
+ if (prefix_lm or not causal) or use_sequence_id:
686
+ return (1, n_heads, seq_len, seq_len)
687
+ return (1, n_heads, 1, seq_len)
688
+ elif prefix_lm or use_sequence_id:
689
+ return (1, 1, seq_len, seq_len)
690
+ return None
691
+ else:
692
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
693
+
694
+
695
+ def build_attn_bias(
696
+ attn_impl: str,
697
+ attn_bias: torch.Tensor,
698
+ n_heads: int,
699
+ seq_len: int,
700
+ causal: bool = False,
701
+ alibi: bool = False,
702
+ alibi_bias_max: int = 8,
703
+ ) -> Optional[torch.Tensor]:
704
+ if attn_impl == 'flash':
705
+ return None
706
+ elif attn_impl in ['torch', 'triton']:
707
+ if alibi:
708
+ # in place add alibi to attn bias
709
+ device, dtype = attn_bias.device, attn_bias.dtype
710
+ attn_bias = attn_bias.add(
711
+ build_alibi_bias(
712
+ n_heads,
713
+ seq_len,
714
+ full=not causal,
715
+ alibi_bias_max=alibi_bias_max,
716
+ device=device,
717
+ dtype=dtype,
718
+ ))
719
+ return attn_bias
720
+ else:
721
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
722
+
723
+
724
+ def gen_slopes(n_heads: int,
725
+ alibi_bias_max: int = 8,
726
+ device: Optional[torch.device] = None) -> torch.Tensor:
727
+ _n_heads = 2**math.ceil(math.log2(n_heads))
728
+ m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
729
+ m = m.mul(alibi_bias_max / _n_heads)
730
+ slopes = (1. / torch.pow(2, m))
731
+
732
+ if _n_heads != n_heads:
733
+ # if n_heads is not a power of two,
734
+ # Huggingface and FasterTransformer calculate slopes normally,
735
+ # then return this strided concatenation of slopes
736
+ slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
737
+
738
+ return slopes.view(1, n_heads, 1, 1)
739
+
740
+
741
+ def build_alibi_bias(
742
+ n_heads: int,
743
+ seq_len: int,
744
+ full: bool = False,
745
+ alibi_bias_max: int = 8,
746
+ device: Optional[torch.device] = None,
747
+ dtype: Optional[torch.dtype] = None,
748
+ ) -> torch.Tensor:
749
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32,
750
+ device=device).view(1, 1, 1, seq_len)
751
+ if full:
752
+ # generate 1 x Heads x SeqLen x SeqLen alibi bias mask
753
+ # otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast to the appropriate size)
754
+ alibi_bias = alibi_bias - torch.arange(
755
+ 1 - seq_len, 1, dtype=torch.int32, device=device).view(
756
+ 1, 1, seq_len, 1)
757
+ alibi_bias = alibi_bias.abs().mul(-1)
758
+
759
+ slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
760
+ alibi_bias = alibi_bias * slopes
761
+ return alibi_bias.to(dtype=dtype)
762
+
763
+
764
+ ATTN_CLASS_REGISTRY = {
765
+ 'multihead_attention': MultiheadAttention,
766
+ 'multiquery_attention': MultiQueryAttention,
767
+ 'grouped_query_attention': GroupedQueryAttention
768
+ }
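To tie the last few helpers together, here is a minimal sketch (assuming the functions defined in this file are importable) that builds a causal ALiBi bias for the `torch` attention implementation:

    import torch

    n_heads, seq_len = 4, 8
    shape = attn_bias_shape('torch', n_heads, seq_len, alibi=True,
                            prefix_lm=False, causal=True, use_sequence_id=False)
    attn_bias = torch.zeros(shape)                       # (1, n_heads, 1, seq_len) for causal ALiBi
    attn_bias = build_attn_bias('torch', attn_bias, n_heads, seq_len,
                                causal=True, alibi=True)  # adds per-head linear-slope biases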
Perceptrix/finetune/build/lib/llmfoundry/models/layers/blocks.py ADDED
@@ -0,0 +1,117 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """GPT Blocks used for the GPT Model."""
5
+
6
+ from typing import Any, Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
12
+ from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
13
+ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
14
+
15
+
16
+ class MPTBlock(nn.Module):
17
+
18
+ def __init__(
19
+ self,
20
+ d_model: int,
21
+ n_heads: int,
22
+ expansion_ratio: int,
23
+ attn_config: Optional[Dict] = None,
24
+ ffn_config: Optional[Dict] = None,
25
+ resid_pdrop: float = 0.0,
26
+ norm_type: str = 'low_precision_layernorm',
27
+ fc_type: str = 'torch',
28
+ device: Optional[str] = None,
29
+ no_bias: bool = False,
30
+ **kwargs: Any,
31
+ ):
32
+ if attn_config is None:
33
+ attn_config = {
34
+ 'attn_type': 'multihead_attention',
35
+ 'attn_pdrop': 0.0,
36
+ 'attn_impl': 'triton',
37
+ 'qk_ln': False,
38
+ 'clip_qkv': None,
39
+ 'softmax_scale': None,
40
+ 'prefix_lm': False,
41
+ 'attn_uses_sequence_id': False,
42
+ 'alibi': False,
43
+ 'alibi_bias_max': 8,
44
+ }
45
+
46
+ if ffn_config is None:
47
+ ffn_config = {
48
+ 'ffn_type': 'mptmlp',
49
+ }
50
+
51
+ del kwargs # unused, just to capture any extra args from the config
52
+ super().__init__()
53
+
54
+ norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
55
+ assert isinstance(attn_config['attn_type'], str)
56
+ attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
57
+
58
+ # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
59
+ args_to_exclude_in_attn_class = {
60
+ 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id',
61
+ 'alibi_bias_max'
62
+ }
63
+ attn_config_subset_for_attn_class = {
64
+ k: v
65
+ for k, v in attn_config.items()
66
+ if k not in args_to_exclude_in_attn_class
67
+ }
68
+
69
+ self.norm_1 = norm_class(d_model, device=device)
70
+ self.attn = attn_class(
71
+ d_model=d_model,
72
+ n_heads=n_heads,
73
+ fc_type=fc_type,
74
+ device=device,
75
+ **attn_config_subset_for_attn_class,
76
+ bias=not no_bias,
77
+ )
78
+ self.norm_2 = None
79
+ if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
80
+ False):
81
+ self.norm_2 = norm_class(d_model, device=device)
82
+ self.ffn = build_ffn(
83
+ d_model=d_model,
84
+ expansion_ratio=expansion_ratio,
85
+ device=device,
86
+ bias=not no_bias,
87
+ **ffn_config,
88
+ )
89
+ self.resid_attn_dropout = nn.Dropout(resid_pdrop)
90
+ self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
91
+
92
+ def forward(
93
+ self,
94
+ x: torch.Tensor,
95
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
96
+ attn_bias: Optional[torch.Tensor] = None,
97
+ attention_mask: Optional[torch.ByteTensor] = None,
98
+ is_causal: bool = True,
99
+ output_attentions: bool = False,
100
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
101
+ torch.Tensor, torch.Tensor]]]:
102
+ a = self.norm_1(x)
103
+ b, attn_weights, past_key_value = self.attn(
104
+ a,
105
+ past_key_value=past_key_value,
106
+ attn_bias=attn_bias,
107
+ attention_mask=attention_mask,
108
+ is_causal=is_causal,
109
+ needs_weights=output_attentions,
110
+ )
111
+ x = x + self.resid_attn_dropout(b)
112
+ m = x
113
+ if self.norm_2 is not None:
114
+ m = self.norm_2(x)
115
+ n = self.ffn(m)
116
+ x = x + self.resid_ffn_dropout(n)
117
+ return x, attn_weights, past_key_value
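As a rough smoke test of the block above (a sketch, not part of the repository), the pure-torch attention path can run on CPU; the `triton` and `flash` implementations require their respective kernels:

    import torch

    block = MPTBlock(d_model=64, n_heads=4, expansion_ratio=4,
                     attn_config={'attn_type': 'multihead_attention', 'attn_impl': 'torch'})
    x = torch.randn(2, 8, 64)              # (batch, seq, d_model)
    out, attn_weights, past_kv = block(x)  # attn_weights is None unless output_attentions=True
    assert out.shape == x.shape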
Perceptrix/finetune/build/lib/llmfoundry/models/layers/custom_embedding.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+
8
+
9
+ class SharedEmbedding(nn.Embedding):
10
+
11
+ def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
12
+ if unembed:
13
+ return F.linear(input, self.weight)
14
+ return super().forward(input)
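A short illustration of the weight tying this class enables (a sketch; shapes are arbitrary): the same matrix embeds token ids and, with `unembed=True`, projects hidden states back to vocabulary logits.

    import torch

    emb = SharedEmbedding(num_embeddings=100, embedding_dim=16)
    hidden = emb(torch.tensor([[1, 2, 3]]))   # (1, 3, 16) token embeddings
    logits = emb(hidden, unembed=True)        # (1, 3, 100) logits via the tied weight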