hudsongouge committed
Commit adf0368 · 0 Parent(s)

Update space
README.md ADDED
@@ -0,0 +1,14 @@
---
title: DAT Byte
emoji: 💬
colorFrom: yellow
colorTo: purple
sdk: gradio
sdk_version: 5.34.2
app_file: app.py
pinned: false
license: apache-2.0
short_description: A chat interface for the DAT Byte LLM.
---

A chat and raw-completion playground for the DAT-Byte LLM, built with [Gradio](https://gradio.app) and served locally from an ONNX export via ONNX Runtime.
app.py ADDED
@@ -0,0 +1,265 @@
import gradio as gr
import os
import onnxruntime as ort
from inference.onnx_inference import generate_text, sequence_breaker_strings
from inference.model import ByteTokenizer

# --- Globals ---
MODEL_OPTIONS = [
    ("DAT-Byte Small (200M)", "small", True),
    ("DAT-Byte Medium", "medium", False),
    ("DAT-Byte Large", "large", False),
]

ONNX_PATH = "models/small.onnx"  # Path to the exported DAT-Byte Small ONNX model

# Cache for the ONNX session
SESSION_CACHE = {}
TOKENIZER = ByteTokenizer()

# Prepare sequence breakers
SEQUENCE_BREAKER_IDS = {TOKENIZER.im_start_id, TOKENIZER.im_end_id}
for s in sequence_breaker_strings:
    # These are single-byte tokens, so encode returns a list with one ID
    try:
        SEQUENCE_BREAKER_IDS.add(
            TOKENIZER.encode(s.encode("utf-8"), add_special_tokens=False)[0]
        )
    except IndexError:
        print(f"Warning: Could not encode sequence breaker string: {s}")


# --- Model Loading ---
def get_session(model_key):
    if model_key != "small":
        raise ValueError("Only DAT-Byte Small is available.")
    if model_key not in SESSION_CACHE:
        if not os.path.exists(ONNX_PATH):
            raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}")
        # Using CPUExecutionProvider as per the project's goal
        SESSION_CACHE[model_key] = ort.InferenceSession(
            ONNX_PATH, providers=["CPUExecutionProvider"]
        )
    return SESSION_CACHE[model_key]


# --- Gradio Callbacks ---
def chat_respond(
    message,
    history,
    model_name,
    max_tokens,
    temperature,
    top_k,
    dry_range,
    dry_allowed_length,
    dry_base,
    dry_multiplier,
    user_role="user",
    assistant_role="assistant",
):
    history = history or []
    model_key = next(
        (key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
        None,
    )
    if not model_key:
        history.append({"role": "user", "content": message})
        history.append(
            {"role": "assistant", "content": f"Model '{model_name}' is not available."}
        )
        return history

    try:
        session = get_session(model_key)
    except Exception as e:
        history.append({"role": "user", "content": message})
        history.append(
            {"role": "assistant", "content": f"[Model loading error: {str(e)}]"}
        )
        return history

    prompt = ""
    for turn in history:
        prompt += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
    prompt += (
        f"<|im_start|>{user_role}\n{message}<|im_end|>\n<|im_start|>{assistant_role}\n"
    )

    generated_text, _ = generate_text(
        session=session,
        tokenizer=TOKENIZER,
        prompt=prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        stop_sequences=["<|im_end|>".encode("utf-8")],
        dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
        dry_range=dry_range,
        dry_allowed_length=dry_allowed_length,
        dry_base=dry_base,
        dry_multiplier=dry_multiplier,
    )
    generated_text = generated_text.decode("utf-8", "ignore")

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": generated_text})
    return history


def completion_respond(
    prompt,
    model_name,
    max_tokens,
    temperature,
    top_k,
    dry_range,
    dry_allowed_length,
    dry_base,
    dry_multiplier,
):
    model_key = next(
        (key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
        None,
    )
    if not model_key:
        return f"[Model '{model_name}' is not available or unknown.]"

    try:
        session = get_session(model_key)
    except Exception as e:
        return f"[Model loading error: {str(e)}]"

    generated_text, _ = generate_text(
        session=session,
        tokenizer=TOKENIZER,
        prompt=prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
        dry_range=dry_range,
        dry_allowed_length=dry_allowed_length,
        dry_base=dry_base,
        dry_multiplier=dry_multiplier,
    )
    # generate_text returns raw bytes; decode before handing the string to Gradio
    return generated_text.decode("utf-8", "ignore")


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# DAT-Byte Playground (ONNX Accelerated)")
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Radio(
                [opt[0] for opt in MODEL_OPTIONS],
                value=MODEL_OPTIONS[0][0],
                label="Model",
                interactive=True,
            )
            gr.Markdown("**Note:** Only DAT-Byte Small is currently available.")
            mode_selector = gr.Radio(
                ["Chat", "Raw Completion"], value="Chat", label="Mode"
            )
            max_tokens = gr.Slider(
                minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
            )
            temperature = gr.Slider(
                minimum=0.05, maximum=2.0, value=0.5, step=0.05, label="Temperature"
            )
            top_k = gr.Slider(minimum=0, maximum=256, value=15, step=1, label="Top-k")
            with gr.Accordion("DRY Sampling (Don't Repeat Yourself)", open=False):
                dry_range = gr.Slider(
                    minimum=0, maximum=2048, value=1024, step=32, label="Range"
                )
                dry_allowed_length = gr.Slider(
                    minimum=1, maximum=64, value=20, step=1, label="Allowed Length"
                )
                dry_base = gr.Slider(
                    minimum=1.0, maximum=5.0, value=2.0, step=0.1, label="Base"
                )
                dry_multiplier = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.05, label="Multiplier"
                )
            user_role_box = gr.Textbox("user", label="User Role", visible=True)
            assistant_role_box = gr.Textbox(
                "assistant", label="Assistant Role", visible=True
            )

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat", type="messages", height=600)
            with gr.Row():
                chat_input = gr.Textbox(
                    label="Message", placeholder="Type a message...", scale=4
                )
                send_button = gr.Button("Send", scale=1)
            completion_input = gr.Textbox(label="Prompt", visible=False)
            completion_output = gr.Textbox(label="Completion", visible=False)

    # UI Logic
    def update_mode(mode):
        is_chat = mode == "Chat"
        return (
            gr.update(visible=is_chat),  # chatbot
            gr.update(visible=is_chat),  # chat_input row
            gr.update(visible=not is_chat),  # completion_input
            gr.update(visible=not is_chat),  # completion_output
            gr.update(visible=is_chat),  # user_role_box
            gr.update(visible=is_chat),  # assistant_role_box
        )

    mode_selector.change(
        update_mode,
        [mode_selector],
        [
            chatbot,
            chat_input.parent,
            completion_input,
            completion_output,
            user_role_box,
            assistant_role_box,
        ],
    )

    # Event Handlers
    chat_inputs = [
        chat_input,
        chatbot,
        model_selector,
        max_tokens,
        temperature,
        top_k,
        dry_range,
        dry_allowed_length,
        dry_base,
        dry_multiplier,
        user_role_box,
        assistant_role_box,
    ]
    chat_args = {"fn": chat_respond, "inputs": chat_inputs, "outputs": [chatbot]}

    def clear_input():
        return ""

    clear_args = {"fn": clear_input, "inputs": [], "outputs": [chat_input]}

    send_button.click(**chat_args).then(**clear_args)
    chat_input.submit(**chat_args).then(**clear_args)

    completion_inputs = [
        completion_input,
        model_selector,
        max_tokens,
        temperature,
        top_k,
        dry_range,
        dry_allowed_length,
        dry_base,
        dry_multiplier,
    ]
    completion_input.submit(
        completion_respond,
        completion_inputs,
        [completion_output],
    )

if __name__ == "__main__":
    demo.launch()
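For reference, chat_respond flattens the running history into the <|im_start|>/<|im_end|> chat template before calling generate_text. A small sketch of the prompt string it builds for a two-turn history plus a new user message (illustrative values only, mirroring the loop above):

history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
prompt = ""
for turn in history:
    prompt += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
prompt += "<|im_start|>user\nHow are you?<|im_end|>\n<|im_start|>assistant\n"
print(prompt)  # the exact text fed to the ONNX model; it completes the assistant turn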
commit.sh ADDED
@@ -0,0 +1 @@
git commit -am 'Update space' && git push
export_onnx.py ADDED
@@ -0,0 +1,58 @@
import torch
from inference.model import DiffTransformerLLM
from inference.inference import load_model
import argparse
import os


def main():
    parser = argparse.ArgumentParser(description="Export DiffTransformerLLM to ONNX")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="Path to model checkpoint (.pt)"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="model.onnx", help="Output ONNX file path"
    )
    parser.add_argument(
        "--seq_len", type=int, default=32, help="Dummy input sequence length"
    )
    args = parser.parse_args()

    device = torch.device("cpu")
    print(f"Loading model from {args.checkpoint}")
    model = load_model(args.checkpoint, device=device, fp16=False, quantize=False)
    model.eval()

    # Prepare dummy input
    batch_size = 1
    seq_len = args.seq_len
    input_ids = torch.randint(0, 259, (batch_size, seq_len), dtype=torch.long)

    # Create a dummy causal mask. This will be a dynamic input to the ONNX model.
    causal_mask = torch.triu(
        torch.ones(1, seq_len, seq_len, dtype=torch.bool), diagonal=1
    )
    attn_mask = torch.zeros(1, seq_len, seq_len, dtype=torch.float32)
    attn_mask.masked_fill_(causal_mask, float("-inf"))

    # Export to ONNX
    print(f"Exporting to ONNX: {args.onnx_path}")
    torch.onnx.export(
        model,
        (input_ids, attn_mask),
        args.onnx_path,
        input_names=["input_ids", "attn_mask"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "attn_mask": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
            "logits": {0: "batch_size", 1: "seq_len"},
        },
        opset_version=17,
        do_constant_folding=True,
    )
    print(f"ONNX export complete: {args.onnx_path}")


if __name__ == "__main__":
    main()
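A quick way to sanity-check an exported file (a separate sketch, not part of this script) is to open it with ONNX Runtime and confirm the graph exposes the input/output names declared above; the path assumes the default location used by app.py:

import onnxruntime as ort

sess = ort.InferenceSession("models/small.onnx", providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])   # expect: ['input_ids', 'attn_mask']
print([o.name for o in sess.get_outputs()])  # expect: ['logits']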
inference/__init__.py ADDED
File without changes
inference/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (155 Bytes).
inference/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (155 Bytes).
inference/__pycache__/inference.cpython-312.pyc ADDED
Binary file (11.4 kB).
inference/__pycache__/inference.cpython-313.pyc ADDED
Binary file (11.8 kB).
inference/__pycache__/model.cpython-312.pyc ADDED
Binary file (9.27 kB).
inference/__pycache__/model.cpython-313.pyc ADDED
Binary file (9.51 kB).
inference/__pycache__/onnx_inference.cpython-312.pyc ADDED
Binary file (5.24 kB).
inference/__pycache__/onnx_inference.cpython-313.pyc ADDED
Binary file (7.44 kB).
inference/__pycache__/optimized_diffattn.cpython-312.pyc ADDED
Binary file (7.76 kB).
inference/__pycache__/optimized_diffattn.cpython-313.pyc ADDED
Binary file (7.79 kB).
inference/__pycache__/rotary.cpython-312.pyc ADDED
Binary file (2.52 kB).
inference/__pycache__/rotary.cpython-313.pyc ADDED
Binary file (2.41 kB).
inference/inference.py ADDED
@@ -0,0 +1,335 @@
import torch
import torch.nn.functional as F
import argparse
import os
import torch.quantization
from .model import (
    DiffTransformerLLM,
    ByteTokenizer,
    IM_START_TOKEN,
    IM_END_TOKEN,
    PAD_TOKEN,
)

force_CPU = True


def list_checkpoints(checkpoint_dir="checkpoints"):
    """List all available checkpoints in the directory."""
    if not os.path.exists(checkpoint_dir):
        print(f"Checkpoint directory {checkpoint_dir} not found.")
        return []

    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    return sorted(checkpoints)


def load_model(checkpoint_path, device=None, fp16=True, quantize=True):
    """Load a trained model from a checkpoint, applying optimizations as needed."""
    if device is None:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )

    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Hyperparams
    vocab_size = 259  # 256 bytes + 3 special tokens
    embed_dim = 768
    num_layers = 28
    num_heads = 12
    ffn_hidden_dim = embed_dim * 4
    max_seq_len = 512
    dropout = 0.1  # For inference you can set dropout=0

    # Model
    model = DiffTransformerLLM(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        ffn_hidden_dim=ffn_hidden_dim,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

    # The checkpoint is the state dict itself
    state_dict = checkpoint

    # Load the state dict into the float32 model first
    model.load_state_dict(state_dict)
    model.eval()

    # Apply device-specific optimizations
    # (quantize=False is used by the ONNX export path, which needs the float32 model)
    if device.type == "cpu" and quantize:
        print("Optimizing for CPU with dynamic quantization (int8).")
        # Set the quantization engine
        torch.backends.quantized.engine = "qnnpack"
        # Quantize the linear layers to int8 for performance
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
    elif device.type == "cuda" and fp16:
        print("Casting model to fp16 for CUDA.")
        model = model.half()

    model = model.to(device)

    print("Model loaded successfully.")
    return model


def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=1.0,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.0,
    device=None,
    stop_sequences=[],
):
    """
    Generate text from a prompt using the trained model.

    Args:
        model: The trained DiffTransformerLLM model
        tokenizer: ByteTokenizer instance
        prompt: Text prompt to start generation (as a string)
        max_new_tokens: Maximum number of new tokens to generate
        temperature: Controls randomness. Lower is more deterministic.
        top_k: If > 0, only sample from the top k most likely tokens
        top_p: If > 0, sample from the smallest set of tokens whose cumulative probability exceeds p
        repetition_penalty: Penalize repetition. 1.0 means no penalty.
        device: Device to run inference on
        stop_sequences: List of strings that end generation when produced

    Returns:
        A tuple of (generated completion, prompt + completion)
    """
    if device is None:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )

    # Convert prompt to bytes and tokenize - process as-is without adding special tokens
    prompt_bytes = prompt.encode("utf-8", errors="replace")
    input_ids = (
        torch.tensor(
            tokenizer.encode(prompt_bytes, add_special_tokens=False), dtype=torch.long
        )
        .unsqueeze(0)
        .to(device)
    )
    stop_sequences = [
        tokenizer.encode(
            seq.encode("utf-8", errors="replace"), add_special_tokens=False
        )
        for seq in stop_sequences
    ]

    # Track generated token IDs
    generated_ids = input_ids.clone()
    generated_bytes = b""

    # Set the model to evaluation mode
    model.eval()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Only use the last max_seq_len tokens if we exceed the model's context length
            if generated_ids.size(1) > model.max_seq_len:
                input_ids = generated_ids[:, -model.max_seq_len :]
            else:
                input_ids = generated_ids

            # Forward pass to get logits for the next token
            logits = model(input_ids)

            # Get logits for the next token (last position)
            next_token_logits = logits[:, -1, :].squeeze(0)

            # Apply temperature
            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            # Apply repetition penalty
            if repetition_penalty > 1.0:
                for token_id in set(generated_ids[0].tolist()):
                    next_token_logits[token_id] /= repetition_penalty

            # Apply top-k filtering
            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
                next_token_logits = torch.full_like(next_token_logits, float("-inf"))
                next_token_logits.scatter_(0, top_k_indices, top_k_logits)

            # Apply top-p (nucleus) filtering
            if 0 < top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(
                    next_token_logits, descending=True
                )
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=0), dim=0)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the indices to the right to keep the first token above the threshold
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float("-inf")

            # Sample from the filtered distribution
            probs = F.softmax(next_token_logits, dim=0)
            next_token = torch.multinomial(probs, 1)

            # Append the generated token to the sequence
            generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)
            # Decode the new token and stream it to stdout
            token_bytes = tokenizer.decode([next_token.item()])
            generated_bytes += token_bytes
            try:
                print(token_bytes.decode("utf-8", errors="replace"), end="", flush=True)
            except Exception as e:
                print(f"<Error decoding token: {e}>", end="", flush=True)

            # Check whether a stop sequence has just been completed
            stop_generated = False
            stop_seq = None
            for stop_seq in stop_sequences:
                if generated_ids.tolist()[0][-len(stop_seq) :] == stop_seq:
                    stop_generated = True
                    break
            if stop_generated:
                # Remove the stop sequence from the generated IDs
                generated_ids = generated_ids[:, : -len(stop_seq)]
                generated_bytes = generated_bytes[: -len(stop_seq)]
                break

    # Decode to bytes and then to string
    try:
        generated_text = generated_bytes.decode("utf-8", errors="replace")
    except Exception as e:
        print(f"\nError decoding generated text: {e}")
        generated_text = "<decoding error>"

    return generated_text, prompt + generated_text


def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""\nHow many 'b's are in "barber"? \n""",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=10, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    args = parser.parse_args()

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        for i, ckpt in enumerate(checkpoints):
            print(f"{i+1}. {ckpt}")
        return

    # If no checkpoint specified, use the latest one
    if not args.checkpoint:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return

        # Find the latest epoch_end checkpoint
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)

        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    else:
        checkpoint_path = args.checkpoint

    # Set device
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
    )
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = ByteTokenizer()

    # Load model
    model = load_model(checkpoint_path, device)

    # Generate text
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")

    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)

    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


if __name__ == "__main__":
    main()
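Besides the CLI in main(), the module can be driven directly from Python; a minimal sketch, with a hypothetical checkpoint filename:

import torch
from inference.inference import load_model, generate_text
from inference.model import ByteTokenizer

tokenizer = ByteTokenizer()
# "checkpoints/epoch_1_end.pt" is a placeholder path; use a real checkpoint file
model = load_model("checkpoints/epoch_1_end.pt", device=torch.device("cpu"))
completion, full_text = generate_text(
    model, tokenizer, "Hello, world!", max_new_tokens=50, temperature=0.7, top_k=10
)
print(completion)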
inference/model.py ADDED
@@ -0,0 +1,189 @@
import torch
import torch.nn as nn
import math

from .optimized_diffattn import MultiheadDiffAttn

# --- Tokenizer Definition ---
# Vocabulary: 256 bytes + IM_START_TOKEN + IM_END_TOKEN + <pad>
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PAD_TOKEN = "<pad>"

SPECIAL_TOKENS = [IM_START_TOKEN, IM_END_TOKEN, PAD_TOKEN]
VOCAB_SIZE = 256 + len(SPECIAL_TOKENS)

# Create token to id mapping
token_to_id = {}
id_to_token = {}

for i in range(256):
    token_to_id[bytes([i])] = i
    id_to_token[i] = bytes([i])

for i, token_str in enumerate(SPECIAL_TOKENS):
    token_id = 256 + i
    token_to_id[token_str] = token_id
    id_to_token[token_id] = token_str

PAD_ID = token_to_id[PAD_TOKEN]
IM_START_ID = token_to_id[IM_START_TOKEN]
IM_END_ID = token_to_id[IM_END_TOKEN]


class ByteTokenizer:
    def __init__(self):
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.vocab_size = VOCAB_SIZE
        self.pad_id = PAD_ID
        self.im_start_id = IM_START_ID
        self.im_end_id = IM_END_ID

    def encode(self, text_bytes: bytes, add_special_tokens=True):
        ids = [self.token_to_id[bytes([b])] for b in text_bytes]
        if add_special_tokens:
            return [self.im_start_id] + ids + [self.im_end_id]
        return ids

    def decode(self, ids: list[int]):
        tokens = []
        for i in ids:
            token = self.id_to_token.get(i)
            if token is None:
                # Handle unknown token ID if necessary, or raise error
                tokens.append(b"?")  # Placeholder for unknown
            elif isinstance(token, bytes):
                tokens.append(token)
            # Special tokens (stored as str) are skipped when decoding to raw bytes
        return b"".join(tokens)


# --- RoPE Embeddings --- (Reused from previous script)
def get_rotary_embeddings(seq_len, dim_model, theta=10000.0):
    if dim_model % 2 != 0:
        raise ValueError(f"dim_model must be even, got {dim_model}")
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim_model, 2).float() * -(math.log(theta) / dim_model)
    )
    angles = position * div_term
    cos_emb = torch.cos(angles)
    sin_emb = torch.sin(angles)
    return cos_emb, sin_emb


# --- Model Definition ---
class FeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.dropout(self.act(self.fc1(x))))


class DiffTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, depth, ffn_hidden_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiheadDiffAttn(embed_dim, depth, num_heads, dropout=dropout)
        self.ffn = FeedForward(embed_dim, ffn_hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, rel_pos, attn_mask=None):
        # Pre-norm
        attn_out = self.attn(self.norm1(x), rel_pos, attn_mask)
        x = x + self.dropout(attn_out)
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)
        return x


class DiffTransformerLLM(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ffn_hidden_dim,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        # Positional embeddings are handled by RoPE, so no separate nn.Embedding for positions
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList(
            [
                DiffTransformerBlock(
                    embed_dim, num_heads, depth, ffn_hidden_dim, dropout
                )
                for depth in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

        # Tie weights
        self.token_embeddings.weight = self.lm_head.weight

        # RoPE precomputation
        # The head_dim for MultiheadDiffAttn is embed_dim // num_heads // 2
        self.rope_head_dim = embed_dim // num_heads // 2
        cos_emb, sin_emb = get_rotary_embeddings(max_seq_len, self.rope_head_dim)
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    def forward(self, input_ids, attn_mask=None):
        batch_size, seq_len = input_ids.shape

        x = self.token_embeddings(input_ids) * math.sqrt(self.embed_dim)
        x = self.dropout(x)

        # Ensure RoPE embeddings are on the same device *and* dtype as activations
        rel_pos = (
            self.cos_emb[:seq_len, :].to(x.device, dtype=x.dtype),
            self.sin_emb[:seq_len, :].to(x.device, dtype=x.dtype),
        )

        # Create causal attention mask if not provided
        if attn_mask is None:
            # Standard causal mask for autoregressive decoding
            # MultiheadDiffAttn expects a mask where -inf indicates masked positions
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device) * float("-inf"),
                diagonal=1,
            )
        else:
            # A provided mask (e.g. a padding mask) is currently not converted or used;
            # a standard causal mask is rebuilt regardless, and padding is handled by
            # the loss function's ignore_index during training.
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device) * float("-inf"),
                diagonal=1,
            )

        for layer in self.layers:
            x = layer(x, rel_pos, attn_mask=causal_mask)

        x = self.norm_out(x)
        logits = self.lm_head(x)
        return logits

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
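The tokenizer is a plain byte-to-id mapping plus three special tokens, so encoding and decoding round-trip raw bytes exactly; a short sketch:

tok = ByteTokenizer()
ids = tok.encode("héllo".encode("utf-8"), add_special_tokens=False)
assert tok.decode(ids) == "héllo".encode("utf-8")  # byte-exact round trip

wrapped = tok.encode(b"hi")  # default wraps the bytes in <|im_start|> ... <|im_end|>
assert wrapped[0] == tok.im_start_id and wrapped[-1] == tok.im_end_id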
inference/onnx_inference.py ADDED
@@ -0,0 +1,202 @@
import onnxruntime as ort
import numpy as np
import torch
import time
import argparse
from typing import Set, Optional
from .model import ByteTokenizer

sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]


class DRYLogitsProcessor:
    """
    Don't Repeat Yourself (DRY) Logits Processor that penalizes repetitive sequences.
    """

    def __init__(
        self,
        multiplier: float = 0.5,
        base: float = 2.0,
        allowed_length: int = 1,
        sequence_breakers: Optional[Set[int]] = None,
        range: int = 512,
    ):
        """
        Args:
            multiplier: Base penalty multiplier
            base: Exponential base for penalty calculation
            allowed_length: Length of sequence that's allowed to repeat without penalty
            sequence_breakers: Set of token IDs that should break sequence matching
            range: Number of previous tokens to consider for repetition checking
        """
        self.multiplier = multiplier
        self.base = base
        self.allowed_length = allowed_length
        self.sequence_breakers = sequence_breakers or set()
        self.range = range

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """
        Apply DRY penalty to logits.

        Args:
            input_ids: Array of shape (batch_size, seq_len)
            scores: Array of shape (vocab_size,) with logits

        Returns:
            Modified scores with penalties applied
        """
        if self.range > 0:
            input_ids = input_ids[:, -self.range :]

        # Convert to torch tensors for easier manipulation
        input_tensor = torch.from_numpy(input_ids)
        scores_tensor = torch.from_numpy(scores)

        for input_ids_row in input_tensor:
            # Raw integer must be extracted here to check for set membership
            last_token = input_ids_row[-1].item()

            if last_token in self.sequence_breakers:
                continue

            # Exclude the last token as it always matches
            match_indices = (input_ids_row[:-1] == last_token).nonzero(as_tuple=False)

            # Stores the maximum matching sequence length for each next token
            match_lengths = {}

            for i in match_indices.squeeze(1):
                i = i.item()
                if i + 1 >= len(input_ids_row):
                    continue

                next_token = input_ids_row[i + 1].item()

                if next_token in self.sequence_breakers:
                    continue

                # We have already found that `last_token` matches at this index,
                # so the match is at least of length 1.
                match_length = 1

                # Extend the match backwards as far as possible
                while True:
                    j = i - match_length
                    if j < 0:
                        break  # Start of input reached

                    if match_length + 1 > len(input_ids_row):
                        break  # End of input reached

                    previous_token = input_ids_row[-(match_length + 1)].item()
                    if input_ids_row[j] != previous_token:
                        break  # Start of match reached

                    if previous_token in self.sequence_breakers:
                        break  # Sequence-breaking token reached

                    match_length += 1

                # Update the maximum match length for this next token
                if match_length >= match_lengths.get(next_token, 0):
                    match_lengths[next_token] = match_length

            # Apply penalties
            for token, match_length in match_lengths.items():
                if match_length >= self.allowed_length:
                    penalty = self.multiplier * (
                        self.base ** (match_length - self.allowed_length)
                    )
                    scores_tensor[token] -= penalty

        return scores_tensor.numpy()


def generate_text(
    session,
    tokenizer,
    prompt,
    max_new_tokens=100,
    temperature=0.8,
    top_k=25,  # There are only 256 bytes total
    stop_sequences=None,
    dry_multiplier: float = 0.0,  # Set to 0 to disable DRY by default
    dry_base: float = 2.0,
    dry_allowed_length: int = 20,  # 20 since this is byte level.
    dry_sequence_breakers: Optional[Set[int]] = None,
    dry_range: int = 512,
):
    """Generate text using an ONNX model with DRY sampling and stop sequences."""
    input_ids_list = tokenizer.encode(prompt.encode("utf-8"), add_special_tokens=False)
    input_ids = np.array([input_ids_list], dtype=np.int64)

    generated_token_ids = []
    start_time = time.time()

    for _ in range(max_new_tokens):
        seq_len = input_ids.shape[1]

        # Create a causal mask for the current sequence length.
        causal_mask = np.triu(np.ones((1, seq_len, seq_len), dtype=np.bool_), k=1)
        attn_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
        attn_mask[causal_mask] = -np.inf

        ort_inputs = {"input_ids": input_ids, "attn_mask": attn_mask}

        try:
            ort_outs = session.run(None, ort_inputs)
        except Exception as e:
            print(f"ONNX Runtime Error: {e}")
            # Potentially return or handle the error gracefully
            return "[ONNX Error]", 0

        logits = ort_outs[0][0, -1, :]

        # Apply DRY penalty if enabled
        if dry_multiplier > 0:
            dry_processor = DRYLogitsProcessor(
                multiplier=dry_multiplier,
                base=dry_base,
                allowed_length=dry_allowed_length,
                sequence_breakers=dry_sequence_breakers,
                range=dry_range,
            )
            logits = dry_processor(input_ids, logits)

        # Apply temperature scaling
        logits = logits / temperature

        # Apply top-k filtering
        if top_k > 0:
            top_k = min(top_k, logits.shape[-1])
            indices_to_remove = logits.argsort()[:-top_k]
            logits[indices_to_remove] = -float("inf")

        # Sample from the distribution
        probs = torch.softmax(torch.from_numpy(logits), dim=-1).numpy()
        next_token_id = np.random.choice(len(probs), p=probs)

        if next_token_id == tokenizer.im_end_id:
            break

        input_ids = np.append(input_ids, [[next_token_id]], axis=1)
        generated_token_ids.append(next_token_id)

        if stop_sequences:
            current_output = tokenizer.decode(np.array(generated_token_ids))
            stop_generation = False
            for seq in stop_sequences:
                if current_output.endswith(seq):
                    stop_generation = True
                    # Remove the stop sequence from the generated text
                    generated_token_ids = generated_token_ids[: -len(seq)]
                    current_output = tokenizer.decode(np.array(generated_token_ids))
                    break
            if stop_generation:
                break

    final_text = tokenizer.decode(np.array(generated_token_ids))
    tps = len(generated_token_ids) / (time.time() - start_time)
    return final_text, tps
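The DRY processor only penalises a candidate token once the repeated run it would extend reaches allowed_length, and the penalty grows exponentially beyond that. A toy illustration with made-up token IDs and settings:

import numpy as np

proc = DRYLogitsProcessor(multiplier=1.0, base=2.0, allowed_length=2, range=0)
# The context ends with "10 11 12", and "10 11 12" already occurred earlier,
# so continuing with token 13 would extend a repeat of length 3.
context = np.array([[10, 11, 12, 13, 99, 10, 11, 12]])
scores = np.zeros(259, dtype=np.float32)
penalized = proc(context, scores)
print(penalized[13])  # -2.0: penalty of multiplier * base**(3 - allowed_length)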
inference/optimized_diffattn.py ADDED
@@ -0,0 +1,177 @@
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

# Re-use rotary embedding helper from the original codebase
from .rotary import apply_rotary_emb

# -----------------------------------------------------------------------------
# Utility helpers (copied from the original implementation)
# -----------------------------------------------------------------------------


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Efficiently repeat keys / values for GQA without allocating new memory."""
    bs, n_kv_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )


def lambda_init_fn(depth: int) -> float:
    """Init schedule described in the DiffAttention paper."""
    return 0.8 - 0.6 * math.exp(-0.3 * depth)


# -----------------------------------------------------------------------------
# Optimised Multi-head DiffAttention implementation
# -----------------------------------------------------------------------------


class MultiheadDiffAttn(nn.Module):
    """Optimised DiffAttention block.

    Differences from the original implementation:
    1. Removes the dependency on Apex / FusedRMSNorm; uses native LayerNorm.
    2. Keeps all tensors on-device and works well with autocast fp16/bf16.
    3. Minimises Python-side tensor reshapes and kernel launches.
    """

    def __init__(
        self,
        embed_dim: int,
        depth: int,
        num_heads: int,
        num_kv_heads: Optional[int] = None,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads  # query heads (will be doubled internally)
        self.num_kv_heads = num_kv_heads or num_heads
        self.n_rep = (
            self.num_heads // self.num_kv_heads
        )  # replication factor for keys / values (GQA)
        self.attn_dropout = dropout  # Store dropout rate for attention

        # One half of a traditional head – DiffAttention uses pairs of heads
        self.head_dim = embed_dim // self.num_heads // 2
        assert (
            self.head_dim * self.num_heads * 2 == embed_dim
        ), "embed_dim must be divisible by num_heads * 2"
        self.scaling = self.head_dim**-0.5

        # Projections. We keep them separated because K/V are smaller (GQA)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # Add dropout for regularization
        self.dropout = nn.Dropout(dropout)

        # DiffAttention lambda parameters (learnable)
        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k1 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_q2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)
        self.lambda_k2 = nn.Parameter(torch.randn(self.head_dim) * 0.1)

        # Use standard LayerNorm which has a highly-optimised CUDA kernel
        self.subln = nn.LayerNorm(2 * self.head_dim, eps=1e-5)

    # ---------------------------------------------------------------------
    # Forward
    # ---------------------------------------------------------------------
    def forward(
        self,
        x: torch.Tensor,  # [bsz, seq_len, embed_dim]
        rel_pos: tuple[torch.Tensor, torch.Tensor],
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        bsz, seq_len, _ = x.size()

        # ---- Projections --------------------------------------------------
        # Projections (run inside the outer autocast context so they stay in
        # the low-precision dtype and use tensor cores)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape into paired heads (2 × heads)
        q = q.view(bsz, seq_len, 2 * self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, 2 * self.num_kv_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_kv_heads, 2 * self.head_dim)

        # Rotary position encodings (ensure dtype matches q)
        cos, sin = rel_pos
        cos = cos.to(dtype=q.dtype)
        sin = sin.to(dtype=q.dtype)
        q = apply_rotary_emb(q, cos, sin, interleaved=True)
        k = apply_rotary_emb(k, cos, sin, interleaved=True)

        # ---- Prepare tensors for matmul ----------------------------------
        # Shape conventions follow PyTorch’s `scaled_dot_product_attention`:
        # (bsz, heads, seq, head_dim)
        q = q.transpose(1, 2)  # [bsz, 2*heads, seq, head_dim]
        k = k.transpose(1, 2)  # [bsz, 2*kv_heads, seq, head_dim]
        v = v.transpose(1, 2)  # [bsz, kv_heads, seq, 2*head_dim]

        # Replicate k/v heads when using GQA
        k = repeat_kv(k, self.n_rep)  # [bsz, 2*heads, seq, head_dim]
        v = repeat_kv(v, self.n_rep)  # [bsz, heads, seq, 2*head_dim]

        # ---- Fused scaled dot-product attention (Flash / SDPA) -----------
        #
        # We avoid instantiating the full (seq×seq) score matrix. Instead we
        # run the fused attention kernel twice (positive/negative queries) and
        # combine the resulting context tensors with the λ weighting. This
        # keeps everything in fp16/bf16 and leverages Blackwell’s Flash/SDPA
        # path, giving ~30-80× speed-up vs. the naive implementation.
        # ------------------------------------------------------------------

        # Re-arrange the paired heads: [bsz, 2*H, S, D] → [bsz, H, 2, S, D]
        q_pairs = q.view(bsz, 2, self.num_heads, seq_len, self.head_dim).permute(
            0, 2, 1, 3, 4
        )
        k_pairs = k.view(bsz, 2, self.num_heads, seq_len, self.head_dim).permute(
            0, 2, 1, 3, 4
        )

        q_pos, q_neg = q_pairs[:, :, 0], q_pairs[:, :, 1]  # [bsz, H, S, D]
        k_pos, k_neg = k_pairs[:, :, 0], k_pairs[:, :, 1]

        # λ scalar (identical across heads / sequence)
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1)).type_as(q_pos)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2)).type_as(q_pos)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init  # scalar tensor

        # --- Fused attention (only TWO SDPA calls) -------------------------
        # Only apply attention dropout while training; keep inference deterministic.
        attn_dropout_p = self.attn_dropout if self.training else 0.0
        ctx_pos = F.scaled_dot_product_attention(
            q_pos, k_pos, v, dropout_p=attn_dropout_p, is_causal=True
        )  # [bsz, H, S, 2*D]
        ctx_neg = F.scaled_dot_product_attention(
            q_neg, k_neg, v, dropout_p=attn_dropout_p, is_causal=True
        )  # [bsz, H, S, 2*D]

        # DiffAttention combination
        attn_out = ctx_pos - lambda_full * ctx_neg  # [bsz, H, S, 2*D]

        # LayerNorm & residual scaling
        attn_out = self.subln(attn_out) * (1.0 - self.lambda_init)

        # Collapse heads and project out
        attn_out = attn_out.transpose(1, 2).reshape(  # [bsz, seq, heads, 2*head_dim]
            bsz, seq_len, self.embed_dim
        )
        # Apply output projection and dropout
        out = self.out_proj(attn_out)
        return self.dropout(out)
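A small shape check for the block above, a sketch assuming the defaults used elsewhere in this repo (embed_dim=768, 12 heads, so each half-head is 32-dimensional) and reusing get_rotary_embeddings from inference.model:

import torch
from inference.model import get_rotary_embeddings
from inference.optimized_diffattn import MultiheadDiffAttn

attn = MultiheadDiffAttn(embed_dim=768, depth=0, num_heads=12)
x = torch.randn(2, 16, 768)                           # [bsz, seq_len, embed_dim]
cos, sin = get_rotary_embeddings(16, attn.head_dim)   # head_dim == 32 here
out = attn(x, (cos, sin))
assert out.shape == x.shape  # DiffAttention preserves the embedding dimension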
inference/rotary.py ADDED
@@ -0,0 +1,76 @@
# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union

import torch


def apply_rotary_emb_torch(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,
    cu_seqlens=None,
    max_seqlen=None,
):
    # Only supports the basic (not interleaved, not variable-length) case.
    rotary_dim = cos.shape[1] * 2
    x1 = x[..., :rotary_dim]
    x2 = x[..., rotary_dim:]

    # Split [even, odd] pairs
    x1_1, x1_2 = x1[..., ::2], x1[..., 1::2]  # (..., rotary_dim/2)

    # Reshape cos/sin for broadcasting
    # x: [batch, seqlen, nheads, rotary_dim]
    # cos/sin: [seqlen, rotary_dim/2]
    # reshape to [1, seqlen, 1, rotary_dim/2] to broadcast
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)

    rot_x1 = x1_1 * cos - x1_2 * sin
    rot_x2 = x1_1 * sin + x1_2 * cos
    # Interleave last dimension: (..., rotary_dim/2, 2) -> (..., rotary_dim)
    rot_x = torch.stack([rot_x1, rot_x2], dim=-1).reshape_as(x1)
    out = torch.cat([rot_x, x2], dim=-1)
    return out


def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    # We are forcing the use of the pure PyTorch implementation (`apply_rotary_emb_torch`)
    # for all devices. The custom Triton kernel (`ApplyRotaryEmb`) was causing a graph
    # break in `torch.compile`, pushing expensive operations to the CPU.
    # By using the pure PyTorch version, `torch.compile` can create a single, fully-optimized
    # graph, which should resolve the CPU bottleneck and improve GPU utilization.
    return apply_rotary_emb_torch(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
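Usage sketch for the fallback path above; shapes follow the docstring, and the cos/sin tables here are random placeholders rather than real rotary tables:

import torch
from inference.rotary import apply_rotary_emb

x = torch.randn(1, 8, 4, 32)   # (batch, seqlen, nheads, headdim)
cos = torch.randn(8, 16)       # (seqlen, rotary_dim / 2), rotary_dim == headdim here
sin = torch.randn(8, 16)
out = apply_rotary_emb(x, cos, sin, interleaved=True)
assert out.shape == x.shape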
models/small.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:06daa397631d28d8b2c1eee51f0f992c4e69927cc770a20d8ed5e2c40f95cc33
size 796014268
requirements.txt ADDED
@@ -0,0 +1 @@
torch>=2.2.0
onnxruntime
numpy
test-trad.py ADDED
@@ -0,0 +1,119 @@
from inference.inference import generate_text, list_checkpoints, load_model, force_CPU
import argparse
import os
import torch
from inference.model import ByteTokenizer


def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM"
    )
    parser.add_argument("--checkpoint", type=str, help="Path to the checkpoint file")
    parser.add_argument(
        "--prompt",
        type=str,
        default="""<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n""",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=500,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter (0 to disable)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p (nucleus) sampling parameter (0 to disable)",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Repetition penalty (1.0 for no penalty)",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit",
    )
    args = parser.parse_args()

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        for i, ckpt in enumerate(checkpoints):
            print(f"{i+1}. {ckpt}")
        return

    # If no checkpoint specified, use the latest one
    if not args.checkpoint:
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found. Please train the model first.")
            return

        # Find the latest epoch_end checkpoint
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)

        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
    else:
        checkpoint_path = args.checkpoint

    # Set device
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
    )
    print(f"Using device: {device}")

    # Initialize tokenizer
    tokenizer = ByteTokenizer()

    # Load model
    model = load_model(checkpoint_path, device)

    # Generate text
    print(f"\nGenerating text with prompt: '{args.prompt}'")
    print(
        f"Parameters: temperature={args.temperature}, top_k={args.top_k}, top_p={args.top_p}, repetition_penalty={args.repetition_penalty}"
    )
    print("\nGenerating...")

    generated_text, full_text = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device,
    )

    print("\n\nGenerated completion only:")
    print("-" * 40)
    print(generated_text)
    print("-" * 40)

    print("\nFull generated text (prompt + completion):")
    print("-" * 40)
    print(full_text)
    print("-" * 40)


if __name__ == "__main__":
    main()
test.py ADDED
@@ -0,0 +1,81 @@
from inference.onnx_inference import generate_text
import argparse
import onnxruntime as ort
from inference.model import ByteTokenizer

sequence_breaker_strings = ["\n", ":", '"', "*", "<", ">", "|"]


def main():
    parser = argparse.ArgumentParser(
        description="Inference with ONNX DiffTransformerLLM"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="models/small.onnx", help="Path to ONNX model"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="<|im_start|>system\nYou are a helpful chatbot<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n",
        help="Prompt for the model",
    )
    parser.add_argument("--max_tokens", type=int, default=100, help="Max new tokens")
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Temperature for sampling"
    )
    parser.add_argument("--top_k", type=int, default=1, help="Top-k for sampling")
    parser.add_argument(
        "--stop_sequence", type=str, action="append", help="Stop sequence(s)"
    )
    # DRY sampling args
    parser.add_argument(
        "--dry_range", type=int, default=1024, help="Range for DRY sampling"
    )
    parser.add_argument(
        "--dry_allowed_length",
        type=int,
        default=17,
        help="Allowed repeat length for DRY sampling",
    )
    parser.add_argument(
        "--dry_base", type=float, default=1.1, help="Base for DRY penalty"
    )
    parser.add_argument(
        "--dry_multiplier", type=float, default=0.0, help="Multiplier for DRY penalty"
    )

    args = parser.parse_args()

    print(f"Loading ONNX model from {args.onnx_path}")
    session = ort.InferenceSession(args.onnx_path, providers=["CPUExecutionProvider"])
    tokenizer = ByteTokenizer()

    sequence_breaker_ids = {tokenizer.im_start_id, tokenizer.im_end_id}
    for s in sequence_breaker_strings:
        # These are single-byte tokens, so encode returns a list with one ID
        sequence_breaker_ids.add(
            tokenizer.encode(s.encode("utf-8"), add_special_tokens=False)[0]
        )

    print(f"Prompt: {args.prompt}")
    print("--- Output ---")
    generated_text, tps = generate_text(
        session,
        tokenizer,
        args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        stop_sequences=["<|im_end|>".encode("utf-8")],
        dry_sequence_breakers=sequence_breaker_ids,
        dry_range=args.dry_range,
        dry_allowed_length=args.dry_allowed_length,
        dry_base=args.dry_base,
        dry_multiplier=args.dry_multiplier,
    )
    # generate_text returns raw bytes
    print(generated_text.decode("utf-8", "ignore"))
    print("--------------")
    print(f"\nPerformance: {tps:.2f} tokens/second")


if __name__ == "__main__":
    main()