Commit adf0368: Update space

Files changed:
- README.md +14 -0
- app.py +265 -0
- commit.sh +1 -0
- export_onnx.py +58 -0
- inference/__init__.py +0 -0
- inference/__pycache__/__init__.cpython-312.pyc +0 -0
- inference/__pycache__/__init__.cpython-313.pyc +0 -0
- inference/__pycache__/inference.cpython-312.pyc +0 -0
- inference/__pycache__/inference.cpython-313.pyc +0 -0
- inference/__pycache__/model.cpython-312.pyc +0 -0
- inference/__pycache__/model.cpython-313.pyc +0 -0
- inference/__pycache__/onnx_inference.cpython-312.pyc +0 -0
- inference/__pycache__/onnx_inference.cpython-313.pyc +0 -0
- inference/__pycache__/optimized_diffattn.cpython-312.pyc +0 -0
- inference/__pycache__/optimized_diffattn.cpython-313.pyc +0 -0
- inference/__pycache__/rotary.cpython-312.pyc +0 -0
- inference/__pycache__/rotary.cpython-313.pyc +0 -0
- inference/inference.py +335 -0
- inference/model.py +189 -0
- inference/onnx_inference.py +202 -0
- inference/optimized_diffattn.py +177 -0
- inference/rotary.py +76 -0
- models/small.onnx +3 -0
- requirements.txt +1 -0
- test-trad.py +119 -0
- test.py +81 -0
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: DAT Byte
emoji: 💬
colorFrom: yellow
colorTo: purple
sdk: gradio
sdk_version: 5.34.2
app_file: app.py
pinned: false
license: apache-2.0
short_description: A chat interface for the DAT Byte LLM.
---

An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py
ADDED
@@ -0,0 +1,265 @@
import gradio as gr
import os
import onnxruntime as ort
from inference.onnx_inference import generate_text, sequence_breaker_strings
from inference.model import ByteTokenizer

# --- Globals ---
MODEL_OPTIONS = [
    ("DAT-Byte Small (200M)", "small", True),
    ("DAT-Byte Medium", "medium", False),
    ("DAT-Byte Large", "large", False),
]

ONNX_PATH = "models/small.onnx"  # Assumes model.onnx is in the root directory

# Cache for the ONNX session
SESSION_CACHE = {}
TOKENIZER = ByteTokenizer()

# Prepare sequence breakers
SEQUENCE_BREAKER_IDS = {TOKENIZER.im_start_id, TOKENIZER.im_end_id}
for s in sequence_breaker_strings:
    # These are single-byte tokens, so encode will return a list with one ID
    try:
        SEQUENCE_BREAKER_IDS.add(TOKENIZER.encode(s.encode("utf-8"))[0])
    except IndexError:
        print(f"Warning: Could not encode sequence breaker string: {s}")


# --- Model Loading ---
def get_session(model_key):
    if model_key != "small":
        raise ValueError("Only DAT-Byte Small is available.")
    if model_key not in SESSION_CACHE:
        if not os.path.exists(ONNX_PATH):
            raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}")
        # Using CPUExecutionProvider as per the project's goal
        SESSION_CACHE[model_key] = ort.InferenceSession(
            ONNX_PATH, providers=["CPUExecutionProvider"]
        )
    return SESSION_CACHE[model_key]


# --- Gradio Callbacks ---
def chat_respond(
    message,
    history,
    model_name,
    max_tokens,
    temperature,
    top_k,
    dry_range,
    dry_allowed_length,
    dry_base,
    dry_multiplier,
    user_role="user",
    assistant_role="assistant",
):
    model_key = next(
        (key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
        None,
    )
    if not model_key:
        history.append({"role": "user", "content": message})
        history.append(
            {"role": "assistant", "content": f"Model '{model_name}' is not available."}
        )
        return history

    history = history or []
    try:
        session = get_session(model_key)
    except Exception as e:
        history.append({"role": "user", "content": message})
        history.append(
            {"role": "assistant", "content": f"[Model loading error: {str(e)}]"}
        )
        return history

    prompt = ""
    for turn in history:
        prompt += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
    prompt += (
        f"<|im_start|>{user_role}\n{message}<|im_end|>\n<|im_start|>{assistant_role}\n"
    )

    generated_text, _ = generate_text(
        session=session,
        tokenizer=TOKENIZER,
        prompt=prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        stop_sequences=["<|im_end|>".encode("utf-8")],
        dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
        dry_range=dry_range,
        dry_allowed_length=dry_allowed_length,
        dry_base=dry_base,
        dry_multiplier=dry_multiplier,
    )
    generated_text = generated_text.decode("utf-8", "ignore")

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": generated_text})
    return history


def completion_respond(
    prompt,
    model_name,
    max_tokens,
    temperature,
    top_k,
    dry_range,
    dry_allowed_length,
    dry_base,
    dry_multiplier,
):
    model_key = next(
        (key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
        None,
    )
    if not model_key:
        return f"[Model '{model_name}' is not available or unknown.]"

    try:
        session = get_session(model_key)
    except Exception as e:
        return f"[Model loading error: {str(e)}]"

    generated_text, _ = generate_text(
        session=session,
        tokenizer=TOKENIZER,
        prompt=prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
        dry_range=dry_range,
        dry_allowed_length=dry_allowed_length,
        dry_base=dry_base,
        dry_multiplier=dry_multiplier,
    )
    return generated_text


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# DAT-Byte Playground (ONNX Accelerated)")
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Radio(
                [opt[0] for opt in MODEL_OPTIONS],
                value=MODEL_OPTIONS[0][0],
                label="Model",
                interactive=True,
            )
            gr.Markdown("**Note:** Only DAT-Byte Small is currently available.")
            mode_selector = gr.Radio(
                ["Chat", "Raw Completion"], value="Chat", label="Mode"
            )
            max_tokens = gr.Slider(
                minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
            )
            temperature = gr.Slider(
                minimum=0.05, maximum=2.0, value=0.5, step=0.05, label="Temperature"
            )
            top_k = gr.Slider(minimum=0, maximum=256, value=15, step=1, label="Top-k")
            with gr.Accordion("DRY Sampling (Don't Repeat Yourself)", open=False):
                dry_range = gr.Slider(
                    minimum=0, maximum=2048, value=1024, step=32, label="Range"
                )
                dry_allowed_length = gr.Slider(
                    minimum=1, maximum=64, value=20, step=1, label="Allowed Length"
                )
                dry_base = gr.Slider(
                    minimum=1.0, maximum=5.0, value=2.0, step=0.1, label="Base"
                )
                dry_multiplier = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.05, label="Multiplier"
                )
            user_role_box = gr.Textbox("user", label="User Role", visible=True)
            assistant_role_box = gr.Textbox(
                "assistant", label="Assistant Role", visible=True
            )

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat", type="messages", height=600)
            with gr.Row():
                chat_input = gr.Textbox(
                    label="Message", placeholder="Type a message...", scale=4
                )
                send_button = gr.Button("Send", scale=1)
            completion_input = gr.Textbox(label="Prompt", visible=False)
            completion_output = gr.Textbox(label="Completion", visible=False)

    # UI Logic
    def update_mode(mode):
        is_chat = mode == "Chat"
        return (
            gr.update(visible=is_chat),  # chatbot
            gr.update(visible=is_chat),  # chat_input row
            gr.update(visible=not is_chat),  # completion_input
            gr.update(visible=not is_chat),  # completion_output
            gr.update(visible=is_chat),  # user_role_box
            gr.update(visible=is_chat),  # assistant_role_box
        )

    mode_selector.change(
        update_mode,
        [mode_selector],
        [
            chatbot,
            chat_input.parent,
            completion_input,
            completion_output,
            user_role_box,
            assistant_role_box,
        ],
    )

    # Event Handlers
    chat_inputs = [
        chat_input,
        chatbot,
        model_selector,
        max_tokens,
        temperature,
        top_k,
        dry_range,
        dry_allowed_length,
        dry_base,
        dry_multiplier,
        user_role_box,
        assistant_role_box,
    ]
    chat_args = {"fn": chat_respond, "inputs": chat_inputs, "outputs": [chatbot]}

    def clear_input():
        return ""

    clear_args = {"fn": clear_input, "inputs": [], "outputs": [chat_input]}

    send_button.click(**chat_args).then(**clear_args)
    chat_input.submit(**chat_args).then(**clear_args)

    completion_inputs = [
        completion_input,
        model_selector,
        max_tokens,
        temperature,
        top_k,
        dry_range,
        dry_allowed_length,
        dry_base,
        dry_multiplier,
    ]
    completion_input.submit(
        completion_respond,
        completion_inputs,
        [completion_output],
    )

if __name__ == "__main__":
    demo.launch()
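For reference, chat_respond serializes the conversation into the ChatML-style byte format shown above before handing it to generate_text. A minimal sketch of the resulting prompt layout (not part of the commit; the one-turn history here is made up purely for illustration):

# Hypothetical history; mirrors the prompt-building loop in chat_respond above.
history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
message = "How are you?"

prompt = ""
for turn in history:
    prompt += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"

print(prompt)
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant
# Hello!<|im_end|>
# <|im_start|>user
# How are you?<|im_end|>
# <|im_start|>assistant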
commit.sh
ADDED
@@ -0,0 +1 @@
git commit -am 'Update space' && git push
export_onnx.py
ADDED
@@ -0,0 +1,58 @@
import torch
from inference.model import DiffTransformerLLM
from inference.inference import load_model
import argparse
import os


def main():
    parser = argparse.ArgumentParser(description="Export DiffTransformerLLM to ONNX")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="Path to model checkpoint (.pt)"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="model.onnx", help="Output ONNX file path"
    )
    parser.add_argument(
        "--seq_len", type=int, default=32, help="Dummy input sequence length"
    )
    args = parser.parse_args()

    device = torch.device("cpu")
    print(f"Loading model from {args.checkpoint}")
    model = load_model(args.checkpoint, device=device, fp16=False, quantize=False)
    model.eval()

    # Prepare dummy input
    batch_size = 1
    seq_len = args.seq_len
    input_ids = torch.randint(0, 259, (batch_size, seq_len), dtype=torch.long)

    # Create a dummy causal mask. This will be a dynamic input to the ONNX model.
    causal_mask = torch.triu(
        torch.ones(1, seq_len, seq_len, dtype=torch.bool), diagonal=1
    )
    attn_mask = torch.zeros(1, seq_len, seq_len, dtype=torch.float32)
    attn_mask.masked_fill_(causal_mask, float("-inf"))

    # Export to ONNX
    print(f"Exporting to ONNX: {args.onnx_path}")
    torch.onnx.export(
        model,
        (input_ids, attn_mask),
        args.onnx_path,
        input_names=["input_ids", "attn_mask"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "attn_mask": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
            "logits": {0: "batch_size", 1: "seq_len"},
        },
        opset_version=17,
        do_constant_folding=True,
    )
    print(f"ONNX export complete: {args.onnx_path}")


if __name__ == "__main__":
    main()
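A quick way to sanity-check the exported graph (a sketch, not part of this commit; it assumes the export above succeeded and onnxruntime/numpy are installed) is to run one forward pass through ONNX Runtime and verify the output shape:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

seq_len = 32
input_ids = np.random.randint(0, 259, size=(1, seq_len), dtype=np.int64)

# Same mask convention as the export: 0 = attend, -inf = masked (strictly upper triangle)
causal = np.triu(np.ones((1, seq_len, seq_len), dtype=bool), k=1)
attn_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
attn_mask[causal] = -np.inf

(logits,) = sess.run(["logits"], {"input_ids": input_ids, "attn_mask": attn_mask})
assert logits.shape == (1, seq_len, 259)  # (batch_size, seq_len, vocab_size)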
inference/__init__.py
ADDED
File without changes

inference/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (155 Bytes)

inference/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (155 Bytes)

inference/__pycache__/inference.cpython-312.pyc
ADDED
Binary file (11.4 kB)

inference/__pycache__/inference.cpython-313.pyc
ADDED
Binary file (11.8 kB)

inference/__pycache__/model.cpython-312.pyc
ADDED
Binary file (9.27 kB)

inference/__pycache__/model.cpython-313.pyc
ADDED
Binary file (9.51 kB)

inference/__pycache__/onnx_inference.cpython-312.pyc
ADDED
Binary file (5.24 kB)

inference/__pycache__/onnx_inference.cpython-313.pyc
ADDED
Binary file (7.44 kB)

inference/__pycache__/optimized_diffattn.cpython-312.pyc
ADDED
Binary file (7.76 kB)

inference/__pycache__/optimized_diffattn.cpython-313.pyc
ADDED
Binary file (7.79 kB)

inference/__pycache__/rotary.cpython-312.pyc
ADDED
Binary file (2.52 kB)

inference/__pycache__/rotary.cpython-313.pyc
ADDED
Binary file (2.41 kB)
inference/inference.py
ADDED
@@ -0,0 +1,335 @@
import torch
import torch.nn.functional as F
import os
import torch.quantization
from .model import (
    DiffTransformerLLM,
    ByteTokenizer,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PAD_TOKEN,
)

force_CPU = True


def list_checkpoints(checkpoint_dir="checkpoints"):
    """List all available checkpoints in the directory."""
    if not os.path.exists(checkpoint_dir):
        print(f"Checkpoint directory {checkpoint_dir} not found.")
        return []

    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    return sorted(checkpoints)


def load_model(checkpoint_path, device=None, fp16=True):
    """Load a trained model from a checkpoint, applying optimizations as needed."""
    import torch

    if device is None:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )

    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Hyperparams
    vocab_size = 259  # 256 bytes + 3 special tokens
    embed_dim = 768
    num_layers = 28
    num_heads = 12
    ffn_hidden_dim = embed_dim * 4
    max_seq_len = 512
    dropout = 0.1  # For inference you can set dropout=0

    # Model
    model = DiffTransformerLLM(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        ffn_hidden_dim=ffn_hidden_dim,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

    # The checkpoint is the state dict itself
    state_dict = checkpoint

    # Load the state dict into the float32 model first
    model.load_state_dict(state_dict)
    model.eval()

    # Apply device-specific optimizations
    if device.type == "cpu":
        print("Optimizing for CPU with dynamic quantization (int8).")
        # Set the quantization engine
        torch.backends.quantized.engine = "qnnpack"
        # Quantize the linear layers to int8 for performance
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
    elif device.type == "cuda" and fp16:
        print("Casting model to fp16 for CUDA.")
        model = model.half()

    model = model.to(device)

    print("Model loaded successfully.")
    return model


def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=[],
):
    """
    Generate text from a prompt using the trained model.

    Args:
        model: The trained DiffTransformerLLM model
        tokenizer: ByteTokenizer instance
        prompt: Text prompt to start generation (as a string)
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Controls randomness. Lower is more deterministic.
        top_k: If > 0, only sample from the top k most likely tokens
        top_p: If > 0, sample from the smallest set of tokens whose cumulative probability exceeds p
        repetition_penalty: Penalize repetition. 1.0 means no penalty.
        device: Device to run inference on

    Returns:
        The generated text as a string
    """
    if device is None:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )

    # Convert prompt to bytes and tokenize - process as-is without adding special tokens
    prompt_bytes = prompt.encode("utf-8", errors="replace")
    input_ids = (
        torch.tensor(
            tokenizer.encode(prompt_bytes, add_special_tokens=False), dtype=torch.long
        )
        .unsqueeze(0)
        .to(device)
    )
    stop_sequences = [
        tokenizer.encode(
            seq.encode("utf-8", errors="replace"), add_special_tokens=False
        )
        for seq in stop_sequences
    ]

    # Track generated token IDs
    generated_ids = input_ids.clone()
    generated_bytes = b""

    # Set the model to evaluation mode
    model.eval()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Only use the last max_seq_len tokens if we exceed the model's context length
            if generated_ids.size(1) > model.max_seq_len:
                input_ids = generated_ids[:, -model.max_seq_len :]
            else:
                input_ids = generated_ids

            # Forward pass to get logits for the next token
            logits = model(input_ids)

            # Get logits for the next token (last position)
            next_token_logits = logits[:, -1, :].squeeze(0)

            # Apply temperature
            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            # Apply repetition penalty
            if repetition_penalty > 1.0:
                for token_id in set(generated_ids[0].tolist()):
                    next_token_logits[token_id] /= repetition_penalty

            # Apply top-k filtering
            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                next_token_logits = torch.full_like(next_token_logits, float("-inf"))
                next_token_logits.scatter_(0, top_k_indices, top_k_logits)

            # Apply top-p (nucleus) filtering
            if 0 < top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(
                    next_token_logits, descending=True
                )
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=0), dim=0)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the indices to the right to keep the first token above the threshold
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float("-inf")

            # Sample from the filtered distribution
            probs = F.softmax(next_token_logits, dim=0)
            next_token = torch.multinomial(probs, 1)

            # Append the generated token to the sequence
            generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)
            # Check if IM_END_TOKEN has been generated
            token_bytes = tokenizer.decode([next_token.item()])
            generated_bytes += token_bytes
            try:
                print(token_bytes.decode("utf-8", errors="replace"), end="", flush=True)
            except Exception as e:
                print(f"<Error decoding token: {e}>", end="", flush=True)
            stop_generated = False
            stop_seq = None
            for stop_seq in stop_sequences:
                if generated_ids.tolist()[0][-len(stop_seq) :] == stop_seq:
                    stop_generated = True
                    break
            if stop_generated:
                # Remove the stop sequence from the generated IDs
                generated_ids = generated_ids[:, : -len(stop_seq)]
                generated_bytes = generated_bytes[: -len(stop_seq)]
                break

    # Decode to bytes and then to string
    try:
        generated_text = generated_bytes.decode("utf-8", errors="replace")
    except Exception as e:
        print(f"\nError decoding generated text: {e}")
        generated_text = "<decoding error>"

    return generated_text, prompt + generated_text


def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""\nHow many 'b's are in "barber"? \n""",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=10, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    args = parser.parse_args()

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        for i, ckpt in enumerate(checkpoints):
            print(f"{i+1}. {ckpt}")
        return

    # If no checkpoint specified, use the latest one
    if not args.checkpoint:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return

        # Find the latest epoch_end checkpoint
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)

        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    else:
        checkpoint_path = args.checkpoint

    # Set device
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
    )
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = ByteTokenizer()

    # Load model
    model = load_model(checkpoint_path, device)

    # Generate text
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")

    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)

    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


if __name__ == "__main__":
    import argparse

    main()
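To make the sampling steps above concrete, here is a small self-contained illustration of the same top-k and top-p (nucleus) filtering on a toy five-token distribution; the logit values are invented purely for illustration:

import torch
import torch.nn.functional as F

next_token_logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])
top_k, top_p = 3, 0.9

# Top-k: keep the 3 largest logits, mask out the rest
top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
next_token_logits = torch.full_like(next_token_logits, float("-inf"))
next_token_logits.scatter_(0, top_k_indices, top_k_logits)

# Top-p: drop tokens once cumulative probability exceeds 0.9,
# always keeping the single most likely token
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=0), dim=0)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False
next_token_logits[sorted_indices[sorted_indices_to_remove]] = float("-inf")

probs = F.softmax(next_token_logits, dim=0)
print(probs)  # only the surviving tokens carry non-zero probability
next_token = torch.multinomial(probs, 1)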
inference/model.py
ADDED
@@ -0,0 +1,189 @@
import torch
import torch.nn as nn
import math

from .optimized_diffattn import MultiheadDiffAttn

# --- Tokenizer Definition ---
# Vocabulary: 256 bytes + IM_START_TOKEN + IM_END_TOKEN + <pad>
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PAD_TOKEN = "<pad>"

SPECIAL_TOKENS = [IM_START_TOKEN, IM_END_TOKEN, PAD_TOKEN]
VOCAB_SIZE = 256 + len(SPECIAL_TOKENS)

# Create token to id mapping
token_to_id = {}
id_to_token = {}

for i in range(256):
    token_to_id[bytes([i])] = i
    id_to_token[i] = bytes([i])

for i, token_str in enumerate(SPECIAL_TOKENS):
    token_id = 256 + i
    token_to_id[token_str] = token_id
    id_to_token[token_id] = token_str

PAD_ID = token_to_id[PAD_TOKEN]
IM_START_ID = token_to_id[IM_START_TOKEN]
IM_END_ID = token_to_id[IM_END_TOKEN]


class ByteTokenizer:
    def __init__(self):
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.vocab_size = VOCAB_SIZE
        self.pad_id = PAD_ID
        self.im_start_id = IM_START_ID
        self.im_end_id = IM_END_ID

    def encode(self, text_bytes: bytes, add_special_tokens=True):
        ids = [self.token_to_id[bytes([b])] for b in text_bytes]
        if add_special_tokens:
            return [self.im_start_id] + ids + [self.im_end_id]
        return ids

    def decode(self, ids: list[int]):
        tokens = []
        for i in ids:
            token = self.id_to_token.get(i)
            if token is None:
                # Handle unknown token ID if necessary, or raise error
                tokens.append(b"?")  # Placeholder for unknown
            elif isinstance(token, bytes):
                tokens.append(token)
            # Ignore special tokens for decoding to raw text, or handle as needed
        return b"".join(tokens)


# --- RoPE Embeddings --- (Reused from previous script)
def get_rotary_embeddings(seq_len, dim_model, theta=10000.0):
    if dim_model % 2 != 0:
        raise ValueError(f"dim_model must be even, got {dim_model}")
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim_model, 2).float() * -(math.log(theta) / dim_model)
    )
    angles = position * div_term
    cos_emb = torch.cos(angles)
    sin_emb = torch.sin(angles)
    return cos_emb, sin_emb


# --- Model Definition ---
class FeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.dropout(self.act(self.fc1(x))))


class DiffTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, depth, ffn_hidden_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiheadDiffAttn(embed_dim, depth, num_heads, dropout=dropout)
        self.ffn = FeedForward(embed_dim, ffn_hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, rel_pos, attn_mask=None):
        # Pre-norm
        attn_out = self.attn(self.norm1(x), rel_pos, attn_mask)
        x = x + self.dropout(attn_out)
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)
        return x


class DiffTransformerLLM(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ffn_hidden_dim,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        # Positional embeddings are handled by RoPE, so no separate nn.Embedding for positions
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList(
            [
                DiffTransformerBlock(
                    embed_dim, num_heads, depth, ffn_hidden_dim, dropout
                )
                for depth in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

        # Tie weights
        self.token_embeddings.weight = self.lm_head.weight

        # RoPE precomputation
        # The head_dim for MultiheadDiffAttn is embed_dim // num_heads // 2
        self.rope_head_dim = embed_dim // num_heads // 2
        cos_emb, sin_emb = get_rotary_embeddings(max_seq_len, self.rope_head_dim)
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    def forward(self, input_ids, attn_mask=None):
        batch_size, seq_len = input_ids.shape

        x = self.token_embeddings(input_ids) * math.sqrt(self.embed_dim)
        x = self.dropout(x)

        # Ensure RoPE embeddings are on the same device *and* dtype as activations
        rel_pos = (
            self.cos_emb[:seq_len, :].to(x.device, dtype=x.dtype),
            self.sin_emb[:seq_len, :].to(x.device, dtype=x.dtype),
        )

        # Create causal attention mask if not provided
        if attn_mask is None:
            # Standard causal mask for autoregressive decoding
            # MultiheadDiffAttn expects a mask where -inf indicates masked positions
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device) * float("-inf"),
                diagonal=1,
            )
        else:
            # If a custom mask is provided (e.g., for padding), ensure it's correctly formatted
            # For MultiheadDiffAttn, 0 means attend, -inf means mask.
            # Assuming input attn_mask is 1 for attend, 0 for mask (like Hugging Face)
            # We need to convert it: (1 - attn_mask) * -inf
            # However, MultiheadDiffAttn's internal mask logic might be sufficient if it handles padding.
            # For simplicity, let's assume the provided attn_mask is already in the correct format if not None.
            # If it's a padding mask (1 for real tokens, 0 for pad), we need to adapt it.
            # Let's stick to causal mask for now, padding handled by loss_fn ignore_index.
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device) * float("-inf"),
                diagonal=1,
            )

        for layer in self.layers:
            x = layer(x, rel_pos, attn_mask=causal_mask)

        x = self.norm_out(x)
        logits = self.lm_head(x)
        return logits

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
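Since the tokenizer is byte-level, every UTF-8 byte maps to its own ID (0-255) and the three special tokens occupy IDs 256-258. A minimal round-trip sketch (not one of the committed files; it assumes the repository root is on the import path):

from inference.model import ByteTokenizer, IM_START_ID, IM_END_ID, PAD_ID

tok = ByteTokenizer()
print(tok.vocab_size)                  # 259
print(IM_START_ID, IM_END_ID, PAD_ID)  # 256 257 258

ids = tok.encode("héllo".encode("utf-8"), add_special_tokens=False)
print(ids)                              # one ID per UTF-8 byte, so 6 IDs for 5 characters
print(tok.decode(ids))                  # b'h\xc3\xa9llo'
print(tok.decode(ids).decode("utf-8"))  # 'héllo'

# With add_special_tokens=True the sequence is wrapped in <|im_start|> ... <|im_end|>
print(tok.encode(b"hi"))  # [256, 104, 105, 257]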
inference/onnx_inference.py
ADDED
@@ -0,0 +1,202 @@
import onnxruntime as ort
import numpy as np
import torch
import time
import argparse
from typing import Set, Optional
from .model import ByteTokenizer

sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]


class DRYLogitsProcessor:
    """
    Don't Repeat Yourself (DRY) Logits Processor that penalizes repetitive sequences.
    """

    def __init__(
        self,
        multiplier: float = 0.5,
        base: float = 2.0,
        allowed_length: int = 1,
        sequence_breakers: Optional[Set[int]] = None,
        range: int = 512,
    ):
        """
        Args:
            multiplier: Base penalty multiplier
            base: Exponential base for penalty calculation
            allowed_length: Length of sequence that's allowed to repeat without penalty
            sequence_breakers: Set of token IDs that should break sequence matching
            range: Number of previous tokens to consider for repetition checking
        """
        self.multiplier = multiplier
        self.base = base
        self.allowed_length = allowed_length
        self.sequence_breakers = sequence_breakers or set()
        self.range = range

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """
        Apply DRY penalty to logits.

        Args:
            input_ids: Array of shape (batch_size, seq_len)
            scores: Array of shape (vocab_size,) with logits

        Returns:
            Modified scores with penalties applied
        """
        if self.range > 0:
            input_ids = input_ids[:, -self.range :]

        # Convert to torch tensors for easier manipulation
        input_tensor = torch.from_numpy(input_ids)
        scores_tensor = torch.from_numpy(scores)

        for input_ids_row in input_tensor:
            # Raw integer must be extracted here to check for set membership
            last_token = input_ids_row[-1].item()

            if last_token in self.sequence_breakers:
                continue

            # Exclude the last token as it always matches
            match_indices = (input_ids_row[:-1] == last_token).nonzero(as_tuple=False)

            # Stores the maximum matching sequence length for each next token
            match_lengths = {}

            for i in match_indices.squeeze(1):
                i = i.item()
                if i + 1 >= len(input_ids_row):
                    continue

                next_token = input_ids_row[i + 1].item()

                if next_token in self.sequence_breakers:
                    continue

                # We have already found that `last_token` matches at this index,
                # so the match is at least of length 1.
                match_length = 1

                # Extend the match backwards as far as possible
                while True:
                    j = i - match_length
                    if j < 0:
                        break  # Start of input reached

                    if match_length + 1 > len(input_ids_row):
                        break  # End of input reached

                    previous_token = input_ids_row[-(match_length + 1)].item()
                    if input_ids_row[j] != previous_token:
                        break  # Start of match reached

                    if previous_token in self.sequence_breakers:
                        break  # Sequence-breaking token reached

                    match_length += 1

                # Update the maximum match length for this next token
                if match_length >= match_lengths.get(next_token, 0):
                    match_lengths[next_token] = match_length

            # Apply penalties
            for token, match_length in match_lengths.items():
                if match_length >= self.allowed_length:
                    penalty = self.multiplier * (
                        self.base ** (match_length - self.allowed_length)
                    )
                    scores_tensor[token] -= penalty

        return scores_tensor.numpy()


def generate_text(
    session,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=0.8,
    top_k=25,  # There are only 256 bytes total
    stop_sequences=None,
    dry_multiplier: float = 0.0,  # Set to 0 to disable DRY by default
    dry_base: float = 2.0,
    dry_allowed_length: int = 20,  # 20 since this is byte level.
    dry_sequence_breakers: Optional[Set[int]] = None,
    dry_range: int = 512,
):
    """Generate text using an ONNX model with DRY sampling and stop sequences."""
    input_ids_list = tokenizer.encode(prompt.encode("utf-8"), add_special_tokens=False)
    input_ids = np.array([input_ids_list], dtype=np.int64)

    generated_token_ids = []
    start_time = time.time()

    for _ in range(max_new_tokens):
        seq_len = input_ids.shape[1]

        # Create a causal mask for the current sequence length.
        causal_mask = np.triu(np.ones((1, seq_len, seq_len), dtype=np.bool_), k=1)
        attn_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
        attn_mask[causal_mask] = -np.inf

        ort_inputs = {"input_ids": input_ids, "attn_mask": attn_mask}

        try:
            ort_outs = session.run(None, ort_inputs)
        except Exception as e:
            print(f"ONNX Runtime Error: {e}")
            # Potentially return or handle the error gracefully
            return "[ONNX Error]", 0

        logits = ort_outs[0][0, -1, :]

        # Apply DRY penalty if enabled
        if dry_multiplier > 0:
            dry_processor = DRYLogitsProcessor(
                multiplier=dry_multiplier,
                base=dry_base,
                allowed_length=dry_allowed_length,
                sequence_breakers=dry_sequence_breakers,
                range=dry_range,
            )
            logits = dry_processor(input_ids, logits)

        # Apply temperature scaling
        logits = logits / temperature

        # Apply top-k filtering
        if top_k > 0:
            top_k = min(top_k, logits.shape[-1])
            indices_to_remove = logits.argsort()[:-top_k]
            logits[indices_to_remove] = -float("inf")

        # Sample from the distribution
        probs = torch.softmax(torch.from_numpy(logits), dim=-1).numpy()
        next_token_id = np.random.choice(len(probs), p=probs)

        if next_token_id == tokenizer.im_end_id:
            break

        input_ids = np.append(input_ids, [[next_token_id]], axis=1)
        generated_token_ids.append(next_token_id)

        if stop_sequences:
            current_output = tokenizer.decode(np.array(generated_token_ids))
            stop_generation = False
            for seq in stop_sequences:
                if current_output.endswith(seq):
                    stop_generation = True
                    # Remove the stop sequence from the generated text
                    generated_token_ids = generated_token_ids[: -len(seq)]
                    current_output = tokenizer.decode(np.array(generated_token_ids))
                    break
            if stop_generation:
                break

    final_text = tokenizer.decode(np.array(generated_token_ids))
    tps = len(generated_token_ids) / (time.time() - start_time)
    return final_text, tps
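As a usage sketch (not one of the committed files; it assumes models/small.onnx exists and the repository root is on the import path), this is roughly how app.py drives generate_text against the ONNX session:

import onnxruntime as ort
from inference.model import ByteTokenizer
from inference.onnx_inference import generate_text

session = ort.InferenceSession("models/small.onnx", providers=["CPUExecutionProvider"])
tokenizer = ByteTokenizer()

# The chat tags are encoded as literal bytes, exactly as app.py builds its prompts
prompt = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"
text_bytes, tokens_per_second = generate_text(
    session=session,
    tokenizer=tokenizer,
    prompt=prompt,
    max_new_tokens=64,
    temperature=0.5,
    top_k=15,
    stop_sequences=[b"<|im_end|>"],
    dry_multiplier=0.0,  # DRY disabled; set > 0 to enable the repetition penalty
)
print(text_bytes.decode("utf-8", "ignore"), tokens_per_second)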
inference/optimized_diffattn.py
ADDED
@@ -0,0 +1,177 @@
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

# Re-use rotary embedding helper from the original codebase
from .rotary import apply_rotary_emb

# -----------------------------------------------------------------------------
# Utility helpers (copied from the original implementation)
# -----------------------------------------------------------------------------


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Efficiently repeat keys / values for GQA without allocating new memory."""
    bs, n_kv_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )


def lambda_init_fn(depth: int) -> float:
    """Init schedule described in the DiffAttention paper."""
    return 0.8 - 0.6 * math.exp(-0.3 * depth)


# -----------------------------------------------------------------------------
# Optimised Multi-head DiffAttention implementation
# -----------------------------------------------------------------------------


class MultiheadDiffAttn(nn.Module):
    """Optimised DiffAttention block.

    Differences from the original implementation:
    1. Removes the dependency on Apex / FusedRMSNorm; uses native LayerNorm.
    2. Keeps all tensors on-device and works well with autocast fp16/bf16.
    3. Minimises Python-side tensor reshapes and kernel launches.
    """

    def __init__(
        self,
        embed_dim: int,
        depth: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads  # query heads (will be doubled internally)
        self.num_kv_heads = num_kv_heads or num_heads
        self.n_rep = (
            self.num_heads // self.num_kv_heads
        )  # replication factor for keys / values (GQA)
        self.attn_dropout = dropout  # Store dropout rate for attention

        # One half of a traditional head – DiffAttention uses pairs of heads
        self.head_dim = embed_dim // self.num_heads // 2
        assert (
            self.head_dim * self.num_heads * 2 == embed_dim
        ), "embed_dim must be divisible by num_heads * 2"
        self.scaling = self.head_dim**-0.5

        # Projections. We keep them separated because K/V are smaller (GQA)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # Add dropout for regularization
        self.dropout = nn.Dropout(dropout)

        # DiffAttention lambda parameters (learnable)
        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_q2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)

        # Use standard LayerNorm which has a highly-optimised CUDA kernel
        self.subln = nn.LayerNorm(2 * self.head_dim, eps=1e-5)

    # ---------------------------------------------------------------------
    # Forward
    # ---------------------------------------------------------------------
    def forward(
        self,
        x: torch.Tensor,  # [bsz, seq_len, embed_dim]
        rel_pos: tuple[torch.Tensor, torch.Tensor],
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        bsz, seq_len, _ = x.size()

        # ---- Projections --------------------------------------------------
        # Projections (run inside the outer autocast context so they stay in
        # the low-precision dtype and use tensor cores)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape into paired heads (2 × heads)
        q = q.view(bsz, seq_len, 2 * self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, 2 * self.num_kv_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_kv_heads, 2 * self.head_dim)

        # Rotary position encodings (ensure dtype matches q)
        cos, sin = rel_pos
        cos = cos.to(dtype=q.dtype)
        sin = sin.to(dtype=q.dtype)
        q = apply_rotary_emb(q, cos, sin, interleaved=True)
        k = apply_rotary_emb(k, cos, sin, interleaved=True)

        # ---- Prepare tensors for matmul ----------------------------------
        # Shape conventions follow PyTorch's `scaled_dot_product_attention`:
        # (bsz, heads, seq, head_dim)
        q = q.transpose(1, 2)  # [bsz, 2*heads, seq, head_dim]
        k = k.transpose(1, 2)  # [bsz, 2*kv_heads, seq, head_dim]
        v = v.transpose(1, 2)  # [bsz, kv_heads, seq, 2*head_dim]

        # Replicate k/v heads when using GQA
        k = repeat_kv(k, self.n_rep)  # [bsz, 2*heads, seq, head_dim]
        v = repeat_kv(v, self.n_rep)  # [bsz, heads, seq, 2*head_dim]

        # ---- Fused scaled dot-product attention (Flash / SDPA) -----------
        #
        # We avoid instantiating the full (seq×seq) score matrix. Instead we
        # run the fused attention kernel twice (positive/negative queries) and
        # combine the resulting context tensors with the λ weighting. This
        # keeps everything in fp16/bf16 and leverages Blackwell's Flash/SDPA
        # path, giving ~30-80× speed-up vs. the naive implementation.
        # ------------------------------------------------------------------

        # Re-arrange the paired heads: [bsz, 2*H, S, D] → [bsz, H, 2, S, D]
        q_pairs = q.view(bsz, 2, self.num_heads, seq_len, self.head_dim).permute(
            0, 2, 1, 3, 4
        )
        k_pairs = k.view(bsz, 2, self.num_heads, seq_len, self.head_dim).permute(
            0, 2, 1, 3, 4
        )

        q_pos, q_neg = q_pairs[:, :, 0], q_pairs[:, :, 1]  # [bsz, H, S, D]
        k_pos, k_neg = k_pairs[:, :, 0], k_pairs[:, :, 1]

        # λ scalar (identical across heads / sequence)
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1)).type_as(q_pos)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2)).type_as(q_pos)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init  # scalar tensor

        # --- Fused attention (only TWO SDPA calls) -------------------------
        ctx_pos = F.scaled_dot_product_attention(
            q_pos, k_pos, v, dropout_p=self.attn_dropout, is_causal=True
        )  # [bsz, H, S, 2*D]
        ctx_neg = F.scaled_dot_product_attention(
            q_neg, k_neg, v, dropout_p=self.attn_dropout, is_causal=True
        )  # [bsz, H, S, 2*D]

        # DiffAttention combination
        attn_out = ctx_pos - lambda_full * ctx_neg  # [bsz, H, S, 2*D]

        # LayerNorm & residual scaling
        attn_out = self.subln(attn_out) * (1.0 - self.lambda_init)

        # Collapse heads and project out
        attn_out = attn_out.transpose(1, 2).reshape(  # [bsz, seq, heads, 2*head_dim]
            bsz, seq_len, self.embed_dim
        )
        # Apply output projection and dropout
        out = self.out_proj(attn_out)
        return self.dropout(out)
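In equation form, the per-layer combination implemented above is roughly the following (dot products summed over head_dim, with SDPA supplying the softmax and causal masking):

\lambda_{\text{init}} = 0.8 - 0.6\, e^{-0.3\,\mathrm{depth}}, \qquad
\lambda = e^{\lambda_{q1} \cdot \lambda_{k1}} - e^{\lambda_{q2} \cdot \lambda_{k2}} + \lambda_{\text{init}}

\mathrm{out} = \mathrm{LayerNorm}\big(\mathrm{Attn}(Q_1, K_1, V) - \lambda\, \mathrm{Attn}(Q_2, K_2, V)\big)\,(1 - \lambda_{\text{init}})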
inference/rotary.py
ADDED
@@ -0,0 +1,76 @@
+# Copyright (c) 2023, Tri Dao.
+
+from typing import Optional, Union
+
+import torch
+
+
+def apply_rotary_emb_torch(
+    x,
+    cos,
+    sin,
+    interleaved=False,
+    inplace=False,
+    seqlen_offsets=0,
+    cu_seqlens=None,
+    max_seqlen=None,
+):
+    # Only supports the basic (not interleaved, not variable-length) case.
+    rotary_dim = cos.shape[1] * 2
+    x1 = x[..., :rotary_dim]
+    x2 = x[..., rotary_dim:]
+
+    # Split [even, odd] pairs
+    x1_1, x1_2 = x1[..., ::2], x1[..., 1::2]  # (..., rotary_dim/2)
+
+    # Reshape cos/sin for broadcasting
+    # x: [batch, seqlen, nheads, rotary_dim]
+    # cos/sin: [seqlen, rotary_dim/2]
+    # reshape to [1, seqlen, 1, rotary_dim/2] to broadcast
+    cos = cos.unsqueeze(0).unsqueeze(2)
+    sin = sin.unsqueeze(0).unsqueeze(2)
+
+    rot_x1 = x1_1 * cos - x1_2 * sin
+    rot_x2 = x1_1 * sin + x1_2 * cos
+    # Interleave last dimension: (..., rotary_dim/2, 2) -> (..., rotary_dim)
+    rot_x = torch.stack([rot_x1, rot_x2], dim=-1).reshape_as(x1)
+    out = torch.cat([rot_x, x2], dim=-1)
+    return out
+
+
+def apply_rotary_emb(
+    x,
+    cos,
+    sin,
+    interleaved=False,
+    inplace=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+):
+    """
+    Arguments:
+        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
+        cos, sin: (seqlen_rotary, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        inplace: if True, apply rotary embedding in-place.
+        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Return:
+        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding to the first rotary_dim of x.
+    """
+    # We are forcing the use of the pure PyTorch implementation (`apply_rotary_emb_torch`)
+    # for all devices. The custom Triton kernel (`ApplyRotaryEmb`) was causing a graph
+    # break in `torch.compile`, pushing expensive operations to the CPU.
+    # By using the pure PyTorch version, `torch.compile` can create a single, fully-optimized
+    # graph, which should resolve the CPU bottleneck and improve GPU utilization.
+    return apply_rotary_emb_torch(
+        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
+    )
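As a usage sketch, the cos/sin tables that apply_rotary_emb_torch expects can be precomputed with the usual inverse-frequency recipe; the base of 10000.0 and the tensor sizes below are assumptions for illustration, not values taken from this repository:

import torch
from inference.rotary import apply_rotary_emb_torch

seqlen, nheads, headdim, rotary_dim = 128, 8, 64, 64  # assumed sizes
base = 10000.0

inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
t = torch.arange(seqlen, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)      # [seqlen, rotary_dim / 2]
cos, sin = freqs.cos(), freqs.sin()

x = torch.randn(1, seqlen, nheads, headdim)   # [batch, seqlen, nheads, headdim]
x_rot = apply_rotary_emb_torch(x, cos, sin)
print(x_rot.shape)  # torch.Size([1, 128, 8, 64])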
models/small.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06daa397631d28d8b2c1eee51f0f992c4e69927cc770a20d8ed5e2c40f95cc33
+size 796014268
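These three lines are a Git LFS pointer: the ~796 MB ONNX graph itself is fetched by git-lfs at checkout rather than stored in the repository. Once the real file is present, it can be sanity-checked by opening a session and listing its interface (the input/output names depend on how the model was exported, so treat the printed values as unknowns):

import onnxruntime as ort

session = ort.InferenceSession("models/small.onnx", providers=["CPUExecutionProvider"])
for inp in session.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)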
requirements.txt
ADDED
@@ -0,0 +1 @@
+torch>=2.2.0
test-trad.py
ADDED
@@ -0,0 +1,119 @@
+from inference.inference import generate_text, list_checkpoints, load_model
+import argparse
+import os
+import torch
+from inference.model import ByteTokenizer
+
+# Set to True to force CPU inference even when CUDA is available.
+force_CPU = False
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Text generation with DiffAttention LLM"
+    )
+    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="""<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n""",
+    )
+    parser.add_argument(
+        "--max_tokens",
+        type=int,
+        default=500,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--temperature", type=float, default=0.7, help="Sampling temperature"
+    )
+    parser.add_argument(
+        "--top_k", type=int, default=1, help="Top-k sampling parameter (0 to disable)"
+    )
+    parser.add_argument(
+        "--top_p",
+        type=float,
+        default=0.9,
+        help="Top-p (nucleus) sampling parameter (0 to disable)",
+    )
+    parser.add_argument(
+        "--repetition_penalty",
+        type=float,
+        default=1.0,
+        help="Repetition penalty (1.0 for no penalty)",
+    )
+    parser.add_argument(
+        "--list_checkpoints",
+        action="store_true",
+        help="List available checkpoints and exit",
+    )
+    args = parser.parse_args()
+
+    # List checkpoints if requested
+    if args.list_checkpoints:
+        print("Available checkpoints:")
+        checkpoints = list_checkpoints()
+        for i, ckpt in enumerate(checkpoints):
+            print(f"{i+1}. {ckpt}")
+        return
+
+    # If no checkpoint specified, use the latest one
+    if not args.checkpoint:
+        checkpoints = list_checkpoints()
+        if not checkpoints:
+            print("No checkpoints found. Please train the model first.")
+            return
+
+        # Find the latest epoch_end checkpoint
+        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
+        if end_checkpoints:
+            latest_checkpoint = max(end_checkpoints)
+        else:
+            latest_checkpoint = max(checkpoints)
+
+        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
+    else:
+        checkpoint_path = args.checkpoint
+
+    # Set device
+    device = torch.device(
+        "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
+    )
+    print(f"Using device: {device}")
+
+    # Initialize tokenizer
+    tokenizer = ByteTokenizer()
+
+    # Load model
+    model = load_model(checkpoint_path, device)
+
+    # Generate text
+    print(f"\nGenerating text with prompt: '{args.prompt}'")
+    print(
+        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
+    )
+    print("\nGenerating...")
+
+    generated_text, full_text = generate_text(
+        model=model,
+        tokenizer=tokenizer,
+        prompt=args.prompt,
+        max_new_tokens=args.max_tokens,
+        temperature=args.temperature,
+        top_k=args.top_k,
+        top_p=args.top_p,
+        repetition_penalty=args.repetition_penalty,
+        device=device,
+    )
+
+    print("\n\nGenerated completion only:")
+    print("-" * 40)
+    print(generated_text)
+    print("-" * 40)
+
+    print("\nFull generated text (prompt + completion):")
+    print("-" * 40)
+    print(full_text)
+    print("-" * 40)
+
+
+if __name__ == "__main__":
+    main()
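The --temperature, --top_k and --top_p flags above map onto the standard logit transforms applied at sampling time. A generic sketch of how such filtering is usually implemented (this illustrates the idea only and is not the code inside inference/inference.py; the vocabulary size of 260 is an assumption):

import torch

def sample_next_token(logits, temperature=0.7, top_k=0, top_p=0.0):
    # Generic top-k / top-p (nucleus) sampling over a 1-D logits vector.
    logits = logits / max(temperature, 1e-6)
    if top_k > 0:
        kth = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p > 0.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        cutoff = cum > top_p
        cutoff[0] = False  # always keep the most likely token
        logits[sorted_idx[cutoff]] = float("-inf")
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()

next_id = sample_next_token(torch.randn(260), temperature=0.7, top_k=50, top_p=0.9)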
test.py
ADDED
@@ -0,0 +1,81 @@
+from inference.onnx_inference import generate_text
+import argparse
+import onnxruntime as ort
+from inference.model import ByteTokenizer
+
+sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Inference with ONNX DiffTransformerLLM"
+    )
+    parser.add_argument(
+        "--onnx_path", type=str, default="models/small.onnx", help="Path to ONNX model"
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
+        help="Prompt for the model",
+    )
+    parser.add_argument("--max_tokens", type=int, default=100, help="Max new tokens")
+    parser.add_argument(
+        "--temperature", type=float, default=0.7, help="Temperature for sampling"
+    )
+    parser.add_argument("--top_k", type=int, default=1, help="Top-k for sampling")
+    parser.add_argument(
+        "--stop_sequence", type=str, action="append", help="Stop sequence(s)"
+    )
+    # DRY sampling args
+    parser.add_argument(
+        "--dry_range", type=int, default=1024, help="Range for DRY sampling"
+    )
+    parser.add_argument(
+        "--dry_allowed_length",
+        type=int,
+        default=17,
+        help="Allowed repeat length for DRY sampling",
+    )
+    parser.add_argument(
+        "--dry_base", type=float, default=1.1, help="Base for DRY penalty"
+    )
+    parser.add_argument(
+        "--dry_multiplier", type=float, default=0.0, help="Multiplier for DRY penalty"
+    )
+
+    args = parser.parse_args()
+
+    print(f"Loading ONNX model from {args.onnx_path}")
+    session = ort.InferenceSession(args.onnx_path, providers=["CPUExecutionProvider"])
+    tokenizer = ByteTokenizer()
+
+    sequence_breaker_ids = {tokenizer.im_start_id, tokenizer.im_end_id}
+    for s in sequence_breaker_strings:
+        # These are single-byte tokens, so encode will return a list with one ID
+        sequence_breaker_ids.add(tokenizer.encode(s.encode("utf-8"))[0])
+
+    print(f"Prompt: {args.prompt}")
+    print("--- Output ---")
+    generated_text, tps = generate_text(
+        session,
+        tokenizer,
+        args.prompt,
+        max_new_tokens=args.max_tokens,
+        temperature=args.temperature,
+        top_k=args.top_k,
+        stop_sequences=["<|im_end|>".encode("utf-8")],
+        dry_sequence_breakers=sequence_breaker_ids,
+        dry_range=args.dry_range,
+        dry_allowed_length=args.dry_allowed_length,
+        dry_base=args.dry_base,
+        dry_multiplier=args.dry_multiplier,
+    )
+    print(generated_text)
+    print(generated_text.decode("utf-8", "ignore"))
+    print("--------------")
+    print(f"\nPerformance: {tps:.2f} tokens/second")
+
+
+if __name__ == "__main__":
+    main()
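The --dry_* flags configure DRY ("don't repeat yourself") anti-repetition sampling as it is commonly defined: when the end of the context is extending a verbatim repeat longer than dry_allowed_length tokens (searched over the last dry_range tokens, with sequence-breaker tokens cutting a match short), the logit of the continuation token is reduced by roughly dry_multiplier * dry_base ** (match_length - dry_allowed_length); the exact matching rules live in inference/onnx_inference.py. With the default dry_multiplier of 0.0 the penalty is disabled. A quick illustration of how the penalty grows with the matched length (the non-zero multiplier is assumed for the sake of the example):

# penalty = dry_multiplier * dry_base ** (match_length - dry_allowed_length)
dry_multiplier, dry_base, dry_allowed_length = 0.8, 1.1, 17  # multiplier assumed non-zero
for match_length in (17, 20, 30, 50):
    penalty = dry_multiplier * dry_base ** (match_length - dry_allowed_length)
    print(match_length, round(penalty, 3))
# 17 -> 0.8, 20 -> ~1.065, 30 -> ~2.762, 50 -> ~18.58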