# NOTE(review): extraction-residue header removed (raw file size "38,596 Bytes",
# commit hash f7bf4f5, and a dump of line numbers 1-571 from the viewer that
# produced this paste). None of it was part of the Python module.
from tvm.script import ir as I
from tvm.script import tir as T
# fmt: off
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def NT_matmul(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_NT_matmul: T.handle):
        # Matmul with transposed RHS ("NT"): out[0, i, j] = sum_k A[0, i, k] * B[j, k], all fp16.
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension shared by A and the output
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        NT_matmul_1 = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                # "SSSR": i0/i1/i2 spatial, k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
                T.writes(NT_matmul_1[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_1[v_i0, v_i1, v_i2] = T.float16(0)
                NT_matmul_1[v_i0, v_i1, v_i2] = NT_matmul_1[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k]
    @T.prim_func
    def extend_te(var_A: T.handle, var_concat_te: T.handle):
        # Extend an (n, n) mask to (n, m): the first m - n columns are filled with
        # fp16 max (65504, i.e. "unmasked"); the remaining n columns copy A shifted
        # by n - m. Presumably m >= n (prefix extension of an attention mask) —
        # TODO confirm against callers.
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic: last two dims of the input mask
        A = T.match_buffer(var_A, (T.int64(1), T.int64(1), n, n), "float16")
        m = T.int64()  # symbolic: extended last dim of the output
        concat_te = T.match_buffer(var_concat_te, (T.int64(1), T.int64(1), n, m), "float16")
        # with T.block("root"):
        for b, _, i, j in T.grid(T.int64(1), T.int64(1), n, m):
            with T.block("concat_te"):
                v_b, v__, v_i, v_j = T.axis.remap("SSSS", [b, _, i, j])
                T.reads(A[v_b, v__, v_i, v_j + n - m])
                T.writes(concat_te[v_b, v__, v_i, v_j])
                # 65504 is the largest finite float16 value.
                concat_te[v_b, v__, v_i, v_j] = T.if_then_else(v_j < m - n, T.float16(65504), A[v_b, v__, v_i, v_j + n - m])
    @T.prim_func
    def full(var_T_full: T.handle):
        # Fill a (1, 1, 1, n) fp16 buffer with 65504 (max finite float16).
        T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) last dimension
        T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), n):
            with T.block("T_full"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
                T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(65504)
    @T.prim_func
    def fused_NT_matmul1_divide_maximum_minimum_cast(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle):
        # Fused chain: NT matmul (lv28 @ lv29^T over the 128-wide axis), scale by a
        # constant, clamp below at -65504 (fp16 lowest), take elementwise min with
        # the lv5 mask (broadcast over the 32 heads), then cast fp16 -> fp32.
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic: rows of lv28 / output
        lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        m = T.int64()  # symbolic: rows of lv29 / output columns
        lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
        lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m))
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                # Constant ~= 1/sqrt(128) (fp16-rounded) — presumably the
                # attention head-dim scaling; TODO confirm.
                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_maximum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                # Clamp below at the most negative finite float16.
                var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_minimum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
                T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                # lv5 is indexed with a fixed 0 on axis 1: the mask is shared by all heads.
                var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
    @T.prim_func
    def fused_NT_matmul2_multiply(p_lv45: T.handle, lv45: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_lv50: T.handle, p_output0: T.handle):
        # Fused: NT matmul (lv45_1 @ lv45^T, 4096 -> 11008) followed by an
        # elementwise multiply with lv50. Note the name collision: `lv45` is the
        # weight buffer, `lv45_1` (from p_lv45) is the activation.
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        lv45_1 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        lv50 = T.match_buffer(p_lv50, (T.int64(1), n, T.int64(11008)), "float16")
        var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45_1[v_i0, v_i1, v_k], lv45[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45_1[v_i0, v_i1, v_k] * lv45[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv50[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = lv50[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
    @T.prim_func
    def fused_NT_matmul2_silu(p_lv45: T.handle, lv38: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_output0: T.handle):
        # Fused: NT matmul (lv45 @ lv38^T, 4096 -> 11008) followed by SiLU:
        # out = x * sigmoid(x), where x is the matmul result.
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45[v_i0, v_i1, v_k], lv38[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv38[v_i2, v_k]
        for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
    @T.prim_func
    def fused_NT_matmul3_add(p_lv51: T.handle, lv52: T.Buffer((T.int64(4096), T.int64(11008)), "float16"), p_lv44: T.handle, p_output0: T.handle):
        # Fused: NT matmul (lv51 @ lv52^T, 11008 -> 4096) plus elementwise add of
        # lv44 (a residual-style add: out = lv44 + matmul).
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        lv51 = T.match_buffer(p_lv51, (T.int64(1), n, T.int64(11008)), "float16")
        lv44 = T.match_buffer(p_lv44, (T.int64(1), n, T.int64(4096)), "float16")
        var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv51[v_i0, v_i1, v_k], lv52[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv51[v_i0, v_i1, v_k] * lv52[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv44[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv44[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
    @T.prim_func
    def fused_NT_matmul4_divide2_maximum1_minimum1_cast3(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle):
        # Single-row variant of the fused NT-matmul/scale/clamp/mask/cast chain
        # (third spatial dim fixed to 1): lv1605 @ lv1606^T, scale, clamp below at
        # -65504, min with lv1582, cast fp16 -> fp32.
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic: rows of lv1606 / output columns
        lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n))
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(lv1605[v_i0, v_i1, v_i2, v_k], lv1606[v_i0, v_i1, v_i3, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1605[v_i0, v_i1, v_i2, v_k] * lv1606[v_i0, v_i1, v_i3, v_k]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                # Constant ~= 1/sqrt(128) (fp16-rounded) — presumably the
                # attention head-dim scaling; TODO confirm.
                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_maximum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                # Clamp below at the most negative finite float16.
                var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_minimum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
                T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                # lv1582 is indexed with a fixed 0 on axis 1: shared across all 32 heads.
                var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
    @T.prim_func
    def fused_NT_matmul_add(p_lv41: T.handle, lv31: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), p_lv2: T.handle, p_output0: T.handle):
        # Fused: NT matmul (lv41 @ lv31^T, 4096 -> 4096) plus elementwise add of
        # lv2 (a residual-style add: out = lv2 + matmul).
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        lv41 = T.match_buffer(p_lv41, (T.int64(1), n, T.int64(4096)), "float16")
        lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16")
        var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv41[v_i0, v_i1, v_k], lv31[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv41[v_i0, v_i1, v_k] * lv31[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv2[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
    @T.prim_func
    def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64):
        # Build an (n, n) lower-triangular mask — -65504 strictly above the
        # diagonal (i < j), +65504 elsewhere — then broadcast it to (1, 1, n, n).
        T.func_attr({"tir.noalias": T.bool(True)})
        var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n), "float16")
        # with T.block("root"):
        var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n), "float16")
        for i, j in T.grid(n, n):
            with T.block("make_diag_mask_te"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads()
                T.writes(var_make_diag_mask_te_intermediate[v_i, v_j])
                # +/-65504 are the extreme finite float16 values.
                var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float16(-65504), T.float16(65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n):
            with T.block("T_broadcast_to"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3])
                T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3]
    @T.prim_func
    def fused_softmax2_cast4(p_lv1613: T.handle, p_output0: T.handle):
        # Numerically stable softmax over the last axis of a (1, 32, 1, n) fp32
        # input (max-subtract, exp, sum, normalize), then cast to fp16.
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic: softmax axis length
        lv1613 = T.match_buffer(p_lv1613, (T.int64(1), T.int64(32), T.int64(1), n))
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
        T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
        var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv1613[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
                with T.init():
                    # Identity for max-reduction: the most negative finite float32.
                    T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
                T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1613[v_i0, v_i1, v_i2, v_k])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv1613[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
                T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1613[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
                T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
                T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.block_attr({"axis": 3})
                var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
    @T.prim_func
    def fused_softmax_cast1(p_lv36: T.handle, p_output0: T.handle):
        # Numerically stable softmax over the last axis of a (1, 32, n, m) fp32
        # input (max-subtract, exp, sum, normalize), then cast to fp16.
        T.func_attr({"tir.noalias": T.bool(True)})
        n, m = T.int64(), T.int64()  # symbolic: rows / softmax axis length
        lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), n, m))
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16")
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
        T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n))
        var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv36[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
                with T.init():
                    # Identity for max-reduction: the most negative finite float32.
                    T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
                T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv36[v_i0, v_i1, v_i2, v_k])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv36[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
                T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv36[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
                T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
                T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.block_attr({"axis": 3})
                var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
    @T.prim_func
    def matmul(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
        # Plain (non-transposed) batched matmul: out[., ., i, j] = sum_k A[., ., i, k] * B[., ., k, j].
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n, m = T.int64(), T.int64()  # symbolic: output rows / reduction length
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16")
        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
        matmul_1 = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), m):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
                T.writes(matmul_1[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    matmul_1[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                matmul_1[v_i0, v_i1, v_i2, v_i3] = matmul_1[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
    @T.prim_func
    def matmul7(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
        # Single-row variant of the batched matmul: A is (1, 32, 1, n), B is
        # (1, 32, n, 128); output shape is fully static (1, 32, 1, 128).
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic: reduction length
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
                T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
    @T.prim_func
    def reshape(var_A: T.handle, var_T_reshape: T.handle):
        # Reshape (1, n) int32 -> (n,). The `v_ax0 % n` index is the generator's
        # generic flattening formula; it is the identity here since v_ax0 < n.
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        A = T.match_buffer(var_A, (T.int64(1), n), "int32")
        T_reshape = T.match_buffer(var_T_reshape, (n,), "int32")
        # with T.block("root"):
        for ax0 in range(n):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(n, ax0)
                T.reads(A[T.int64(0), v_ax0 % n])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n]
    @T.prim_func
    def reshape1(var_A: T.handle, var_T_reshape: T.handle):
        # Reshape (n, 4096) -> (1, n, 4096). The div/mod index expression is the
        # generator's generic flattening formula and simplifies to A[v_ax1, v_ax2].
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        A = T.match_buffer(var_A, (n, T.int64(4096)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)]
    @T.prim_func
    def reshape2(var_A: T.handle, var_T_reshape: T.handle):
        # Reshape (1, n, 4096) -> (1, n, 32, 128), splitting the 4096 axis into
        # 32 x 128. The div/mod expressions are the generator's generic flattening
        # formula over the fused 4096-wide index.
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
    @T.prim_func
    def reshape3(var_A: T.handle, var_T_reshape: T.handle):
        # Reshape (m, 32, 128) -> (1, m, 32, 128), i.e. prepend a unit batch dim.
        # The div/mod expressions are the generator's generic flattening formula
        # and simplify to A[v_ax1, v_ax2, v_ax3].
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        m = T.int64()  # symbolic (dynamic) dimension
        A = T.match_buffer(var_A, (m, T.int64(32), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), m, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), m, T.int64(32), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * m + v_ax1) % m, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * m + v_ax1) % m, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)]
    @T.prim_func
    def reshape4(var_A: T.handle, var_T_reshape: T.handle):
        # Reshape (1, n, 32, 128) -> (1, n, 4096), fusing the last two axes.
        # Inverse of the 32 x 128 split: last-axis index decomposes as
        # (v_ax2 // 128, v_ax2 % 128).
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]
    @T.prim_func
    def rms_norm(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle):
        # RMS normalization over the 4096-wide axis with learned scale B:
        # out = B * A / sqrt(mean(A^2) + eps). The sum-of-squares reduction and
        # the normalization are computed in fp32, then cast back to fp16.
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        Ared_temp = T.alloc_buffer((T.int64(1), n))  # fp32 sum of squares per row
        for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("Ared_temp"):
                v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
                T.reads(A[v_bsz, v_i, v_k])
                T.writes(Ared_temp[v_bsz, v_i])
                with T.init():
                    Ared_temp[v_bsz, v_i] = T.float32(0)
                Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
        for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("rms_norm"):
                v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
                T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
                T.writes(rms_norm_1[v_bsz, v_i, v_k])
                # 0.000244140625 == 1/4096 (mean over the reduced axis); ~1e-6 is the eps.
                rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
    @T.prim_func
    def rotary_embedding(var_A: T.handle, B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), var_rotary: T.handle, m: T.int64):
        # Rotary position embedding. B/C are cos/sin-style lookup tables (2048
        # positions x 128 dims — presumably cos and sin; TODO confirm) indexed at
        # absolute position m + v_i1 - n, so m is the total sequence offset. Each
        # element is paired with the one 64 slots away: for v_i3 >= 64 the partner
        # is A[.., v_i3 - 64], otherwise -A[.., v_i3 + 64].
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic: current chunk length
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        rotary = T.match_buffer(var_rotary, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("rotary"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                # The A read region spans [v_i3 - 64, v_i3 + 64] to cover both Select arms.
                T.reads(B[m + v_i1 - n, v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[m + v_i1 - n, v_i3])
                T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
                rotary[v_i0, v_i1, v_i2, v_i3] = B[m + v_i1 - n, v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[m + v_i1 - n, v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))
    @T.prim_func
    def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        # Take the last row along the dynamic axis: out[0, 0, :] = A[0, n - 1, :].
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
            with T.block("slice"):
                v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])
                T.reads(A[v_i, n - T.int64(1), v_k])
                T.writes(slice_1[v_i, v_j, v_k])
                slice_1[v_i, v_j, v_k] = A[v_i, n - T.int64(1), v_k]
    @T.prim_func
    def squeeze(var_A: T.handle, var_T_squeeze: T.handle):
        # Drop the leading unit axis: (1, n, 32, 128) -> (n, 32, 128).
        T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        T_squeeze = T.match_buffer(var_T_squeeze, (n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(n, T.int64(32), T.int64(128)):
            with T.block("T_squeeze"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
                T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
                T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]
    @T.prim_func
    def take_decode(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), var_C: T.handle, var_take_decode: T.handle):
        # Gather rows C[i] from a quantized table and dequantize on the fly:
        # each uint16 word of A packs five 3-bit values (shift by 3*(j % 5),
        # mask with 7); subtract the zero-point 3 and multiply by the per-group
        # fp16 scale B (one scale per 40 output elements, 103 groups covering 4096).
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic: number of indices to gather
        C = T.match_buffer(var_C, (n,), "int32")
        take_decode_1 = T.match_buffer(var_take_decode, (n, T.int64(4096)), "float16")
        # with T.block("root"):
        for i, j in T.grid(n, T.int64(4096)):
            with T.block("take_decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])
                T.writes(take_decode_1[v_i, v_j])
                take_decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]
    @T.prim_func
    def transpose(var_A: T.handle, var_T_transpose: T.handle):
        # Swap axes 1 and 2: (1, n, 32, 128) -> (1, 32, n, 128).
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
    @T.prim_func
    def transpose1(var_A: T.handle, var_T_transpose: T.handle):
        # Swap axes 1 and 2 (inverse layout of `transpose`): (1, 32, n, 128) -> (1, n, 32, 128).
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        n = T.int64()  # symbolic (dynamic) dimension
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
# fmt: on