|
|
|
from tvm.script import ir as I |
|
from tvm.script import tir as T |
|
|
|
|
|
|
|
|
|
|
|
@I.ir_module

class Module:

    # Auto-generated TVMScript (TIR) kernels for a single-token (seq_len == 1)
    # transformer decode step with 3-bit weight quantization:
    #   * quantized weights are uint16 words packing 5 values of 3 bits each
    #     (unpacked via shift_right by (row % 5) * 3 and bitwise_and with 0x7);
    #   * each unpacked value is re-centered by -3 and multiplied by one
    #     float16 scale shared by 40 consecutive rows (scale index = row // 40).
    # NOTE(review): the shapes (hidden 4096, FFN 11008, vocab 32000, 32 heads
    # of dim 128) look like a LLaMA-7B-style model -- confirm against callers.

    @T.prim_func

    # Dequantize a packed (824, 4096) 3-bit weight matrix into (4096, 4096)
    # float16, then write out its transpose.
    def decode1(A: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")):

        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})



        decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")

        # Unpack 3-bit value, subtract the fixed zero-point 3, apply the
        # per-40-row float16 scale.
        for i, j in T.grid(T.int64(4096), T.int64(4096)):

            with T.block("decode"):

                v_i, v_j = T.axis.remap("SS", [i, j])

                T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j])

                T.writes(decode[v_i, v_j])

                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]

        # Emit the transpose of the dequantized matrix.
        for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):

            with T.block("T_transpose"):

                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])

                T.reads(decode[v_ax1, v_ax0])

                T.writes(T_transpose[v_ax0, v_ax1])

                T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]



    @T.prim_func

    # Same as decode1 but for a (4096, 11008) weight matrix; output is the
    # (11008, 4096) transpose.
    def decode2(A: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), B: T.Buffer((T.int64(103), T.int64(11008)), "float16"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float16")):

        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})



        decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")

        for i, j in T.grid(T.int64(4096), T.int64(11008)):

            with T.block("decode"):

                v_i, v_j = T.axis.remap("SS", [i, j])

                T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j])

                T.writes(decode[v_i, v_j])

                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]

        for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)):

            with T.block("T_transpose"):

                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])

                T.reads(decode[v_ax1, v_ax0])

                T.writes(T_transpose[v_ax0, v_ax1])

                T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]



    @T.prim_func

    # Same as decode1 but for a (11008, 4096) weight matrix; output is the
    # (4096, 11008) transpose.
    def decode3(A: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), B: T.Buffer((T.int64(276), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float16")):

        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})



        decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")

        for i, j in T.grid(T.int64(11008), T.int64(4096)):

            with T.block("decode"):

                v_i, v_j = T.axis.remap("SS", [i, j])

                T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j])

                T.writes(decode[v_i, v_j])

                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]

        for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)):

            with T.block("T_transpose"):

                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])

                T.reads(decode[v_ax1, v_ax0])

                T.writes(T_transpose[v_ax0, v_ax1])

                T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]



    @T.prim_func

    # Elementwise division of the (1, 1, 32000) logits by the scalar B.
    # NOTE(review): presumably a sampling-temperature divide -- confirm caller.
    def divide1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):

        T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})



        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):

            with T.block("T_divide"):

                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])

                T.reads(A[v_ax0, v_ax1, v_ax2], B[()])

                T.writes(T_divide[v_ax0, v_ax1, v_ax2])

                T_divide[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] / B[()]



    @T.prim_func

    # Fused: dequantize a (4096, 32000) weight matrix, matmul with the
    # (1, 1, 4096) activation, then cast the float16 result to float32.
    # NOTE(review): the 32000-wide output looks like the vocabulary-projection
    # (lm-head) step -- confirm against the Relax graph.
    def fused_decode4_fused_matmul1_cast2(lv1576: T.Buffer((T.int64(824), T.int64(32000)), "uint16"), lv1577: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):

        T.func_attr({"tir.noalias": T.bool(True)})



        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16")

        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16")

        for i, j in T.grid(T.int64(4096), T.int64(32000)):

            with T.block("decode"):

                v_i, v_j = T.axis.remap("SS", [i, j])

                T.reads(lv1576[v_i // T.int64(5), v_j], lv1577[v_i // T.int64(40), v_j])

                T.writes(var_decode_intermediate[v_i, v_j])

                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1576[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1577[v_i // T.int64(40), v_j]

        # GEMV: accumulate over the 4096-wide reduction axis k in float16.
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)):

            with T.block("matmul"):

                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])

                T.reads(lv1575[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])

                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])

                with T.init():

                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)

                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1575[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]

        # Widen the float16 logits to float32.
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):

            with T.block("compute"):

                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])

                T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])

                T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])

                p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2])



    @T.prim_func

    # Fused: dequantize (4096, 4096) weights, matmul with lv197, then add
    # lv1581 elementwise (residual-style add).
    def fused_decode5_fused_matmul6_add1(lv1605: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv1606: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv197: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1581: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):

        T.func_attr({"tir.noalias": T.bool(True)})



        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")

        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")

        for i, j in T.grid(T.int64(4096), T.int64(4096)):

            with T.block("decode"):

                v_i, v_j = T.axis.remap("SS", [i, j])

                T.reads(lv1605[v_i // T.int64(5), v_j], lv1606[v_i // T.int64(40), v_j])

                T.writes(var_decode_intermediate[v_i, v_j])

                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1605[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1606[v_i // T.int64(40), v_j]

        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):

            with T.block("matmul"):

                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])

                T.reads(lv197[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])

                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])

                with T.init():

                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)

                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv197[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]

        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):

            with T.block("T_add"):

                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])

                T.reads(lv1581[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])

                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])

                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1581[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]



    @T.prim_func

    # Fused: dequantize (4096, 4096) weights and matmul with lv1583 -- same as
    # fused_decode5_fused_matmul6_add1 but without the trailing add.
    def fused_decode5_matmul6(lv1587: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv1588: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv1583: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):

        T.func_attr({"tir.noalias": T.bool(True)})



        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")

        for i, j in T.grid(T.int64(4096), T.int64(4096)):

            with T.block("decode"):

                v_i, v_j = T.axis.remap("SS", [i, j])

                T.reads(lv1587[v_i // T.int64(5), v_j], lv1588[v_i // T.int64(40), v_j])

                T.writes(var_decode_intermediate[v_i, v_j])

                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1587[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1588[v_i // T.int64(40), v_j]

        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):

            with T.block("matmul"):

                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])

                T.reads(lv1583[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])

                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])

                with T.init():

                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)

                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1583[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]



    @T.prim_func

    # Fused: dequantize (4096, 11008) weights, matmul with lv1622, then
    # multiply elementwise by lv4.
    # NOTE(review): looks like the gated-FFN "up-projection * gate" step --
    # confirm against the Relax graph.
    def fused_decode6_fused_matmul8_multiply1(lv1617: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv1618: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):

        T.func_attr({"tir.noalias": T.bool(True)})



        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")

        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")

        for i, j in T.grid(T.int64(4096), T.int64(11008)):

            with T.block("decode"):

                v_i, v_j = T.axis.remap("SS", [i, j])

                T.reads(lv1617[v_i // T.int64(5), v_j], lv1618[v_i // T.int64(40), v_j])

                T.writes(var_decode_intermediate[v_i, v_j])

                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1617[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1618[v_i // T.int64(40), v_j]

        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):

            with T.block("matmul"):

                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])

                T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])

                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])

                with T.init():

                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)

                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]

        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):

            with T.block("T_multiply"):

                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])

                T.reads(lv4[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])

                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])

                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv4[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2]



    @T.prim_func

    # Fused: dequantize (4096, 11008) weights, matmul with lv1622, then apply
    # SiLU (x * sigmoid(x)) to the matmul result.
    def fused_decode6_fused_matmul8_silu1(lv1611: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv1612: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):

        T.func_attr({"tir.noalias": T.bool(True)})



        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")

        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")

        compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")

        for i, j in T.grid(T.int64(4096), T.int64(11008)):

            with T.block("decode"):

                v_i, v_j = T.axis.remap("SS", [i, j])

                T.reads(lv1611[v_i // T.int64(5), v_j], lv1612[v_i // T.int64(40), v_j])

                T.writes(var_decode_intermediate[v_i, v_j])

                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1611[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1612[v_i // T.int64(40), v_j]

        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):

            with T.block("matmul"):

                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])

                T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])

                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])

                with T.init():

                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)

                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]

        # sigmoid(matmul) ...
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):

            with T.block("compute"):

                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])

                T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])

                T.writes(compute[v_i0, v_i1, v_i2])

                compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2])

        # ... times matmul gives SiLU.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):

            with T.block("T_multiply"):

                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])

                T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])

                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])

                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]



    @T.prim_func

    # Fused: dequantize an (11008, 4096) weight matrix (FFN down-projection
    # width), matmul with the (1, 1, 11008) activation lv200, then add lv198.
    def fused_decode7_fused_matmul9_add1(lv1623: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv1624: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv200: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv198: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):

        T.func_attr({"tir.noalias": T.bool(True)})



        var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")

        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")

        for i, j in T.grid(T.int64(11008), T.int64(4096)):

            with T.block("decode"):

                v_i, v_j = T.axis.remap("SS", [i, j])

                T.reads(lv1623[v_i // T.int64(5), v_j], lv1624[v_i // T.int64(40), v_j])

                T.writes(var_decode_intermediate[v_i, v_j])

                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1623[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1624[v_i // T.int64(40), v_j]

        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)):

            with T.block("matmul"):

                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])

                T.reads(lv200[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])

                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])

                with T.init():

                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)

                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv200[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]

        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):

            with T.block("T_add"):

                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])

                T.reads(lv198[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])

                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])

                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv198[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]



    @T.prim_func

    # Fused: reshape (1, 1, 4096) -> (1, 1, 32, 128), then squeeze the
    # leading unit dim -> (1, 32, 128).
    def fused_reshape7_squeeze1(lv1591: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_T_squeeze_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(128)), "float16")):

        T.func_attr({"tir.noalias": T.bool(True)})



        var_T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")

        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):

            with T.block("T_reshape"):

                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])

                T.reads(lv1591[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])

                T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])

                var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1591[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]

        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(128)):

            with T.block("T_squeeze"):

                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])

                T.reads(var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2])

                T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2])

                var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2] = var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2]



    @T.prim_func

    # Fused: transpose (1, 32, 1, 128) -> (1, 1, 32, 128) (swap axes 1 and 2),
    # then flatten to (1, 1, 4096).
    def fused_transpose6_reshape8(lv1616: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), var_T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):

        T.func_attr({"tir.noalias": T.bool(True)})



        var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")

        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):

            with T.block("T_transpose"):

                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])

                T.reads(lv1616[v_ax0, v_ax2, v_ax1, v_ax3])

                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])

                var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1616[v_ax0, v_ax2, v_ax1, v_ax3]

        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):

            with T.block("T_reshape"):

                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])

                T.reads(var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])

                T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2])

                var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2] = var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]



    @T.prim_func

    # Reshape a (1, 1) int32 to a flat (1,) vector.
    # NOTE(review): presumably the current token id -- verify against caller.
    def reshape5(A: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")):

        T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})



        for ax0 in range(T.int64(1)):

            with T.block("T_reshape"):

                v_ax0 = T.axis.spatial(T.int64(1), ax0)

                T.reads(A[T.int64(0), T.int64(0)])

                T.writes(T_reshape[v_ax0])

                T_reshape[v_ax0] = A[T.int64(0), T.int64(0)]



    @T.prim_func

    # Reshape (1, 4096) -> (1, 1, 4096) (prepend a unit dim).
    def reshape6(A: T.Buffer((T.int64(1), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):

        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})



        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):

            with T.block("T_reshape"):

                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])

                T.reads(A[T.int64(0), v_ax2 % T.int64(4096)])

                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])

                T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax2 % T.int64(4096)]



    @T.prim_func

    # Reshape (1, 1, 4096) -> (1, 1, 32, 128), splitting the hidden dim into
    # 32 groups of 128 (presumably attention heads -- confirm).
    def reshape7(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")):

        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})



        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):

            with T.block("T_reshape"):

                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])

                T.reads(A[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])

                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])

                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]



    @T.prim_func

    # RMSNorm over the last axis with float32 accumulation:
    #   out = B * A / sqrt(mean(A^2) + 1e-6), cast back to float16.
    # 0.000244140625 == 1/4096 (the mean divisor folded into the multiply).
    def rms_norm1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), B: T.Buffer((T.int64(4096),), "float16"), rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):

        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})



        Ared_temp = T.alloc_buffer((T.int64(1), T.int64(1)))

        # Sum of squares, accumulated in float32 to avoid fp16 overflow.
        for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):

            with T.block("Ared_temp"):

                v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])

                T.reads(A[v_bsz, v_i, v_k])

                T.writes(Ared_temp[v_bsz, v_i])

                with T.init():

                    Ared_temp[v_bsz, v_i] = T.float32(0)

                Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])

        for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):

            with T.block("rms_norm"):

                v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])

                T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])

                T.writes(rms_norm[v_bsz, v_i, v_k])

                rms_norm[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))



    @T.prim_func

    # Rotary position embedding (RoPE) for the single token at position
    # n - 1, using precomputed cos (B) and sin (C) tables of length 2048.
    # The Select implements the half-rotation: lanes >= 64 pair with
    # A[i3 - 64]; lanes < 64 pair with -A[i3 + 64].
    def rotary_embedding1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), n: T.int64):

        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})



        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):

            with T.block("rotary"):

                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])

                T.reads(B[n + v_i1 - T.int64(1), v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[n + v_i1 - T.int64(1), v_i3])

                T.writes(rotary[v_i0, v_i1, v_i2, v_i3])

                rotary[v_i0, v_i1, v_i2, v_i3] = B[n + v_i1 - T.int64(1), v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[n + v_i1 - T.int64(1), v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))



    @T.prim_func

    # Take index 0 along axis 1. With seq_len == 1 here this is a plain copy
    # (it selects the last/only position).
    def slice1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), slice: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):

        T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})



        for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):

            with T.block("slice"):

                v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])

                T.reads(A[v_i, T.int64(0), v_k])

                T.writes(slice[v_i, v_j, v_k])

                slice[v_i, v_j, v_k] = A[v_i, T.int64(0), v_k]



    @T.prim_func

    # Numerically stable softmax over the last (32000-wide) axis:
    # max-subtract, exp, sum, normalize.
    def softmax1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):

        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})



        T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1)))

        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)))

        T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1)))

        # Row maximum (init is -FLT_MAX).
        for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):

            with T.block("T_softmax_maxelem"):

                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])

                T.reads(A[v_i0, v_i1, v_k])

                T.writes(T_softmax_maxelem[v_i0, v_i1])

                with T.init():

                    T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38)

                T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k])

        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):

            with T.block("T_softmax_exp"):

                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])

                T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])

                T.writes(T_softmax_exp[v_i0, v_i1, v_i2])

                T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])

        for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):

            with T.block("T_softmax_expsum"):

                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])

                T.reads(T_softmax_exp[v_i0, v_i1, v_k])

                T.writes(T_softmax_expsum[v_i0, v_i1])

                with T.init():

                    T_softmax_expsum[v_i0, v_i1] = T.float32(0)

                T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]

        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):

            with T.block("T_softmax_norm"):

                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])

                T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])

                T.writes(T_softmax_norm[v_i0, v_i1, v_i2])

                T.block_attr({"axis": 2})

                T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]



    @T.prim_func

    # Squeeze the leading unit dim: (1, 1, 32, 128) -> (1, 32, 128).
    def squeeze1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), T_squeeze: T.Buffer((T.int64(1), T.int64(32), T.int64(128)), "float16")):

        T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})



        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(128)):

            with T.block("T_squeeze"):

                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])

                T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])

                T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])

                T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]



    @T.prim_func

    # Fused take + dequantize: gather row C[0] from the quantized (32000-row)
    # embedding table A / scales B and unpack it to a (1, 4096) float16 row.
    # Note: here the 3-bit packing runs along the column axis (j // 5, j % 5),
    # unlike the weight decode functions above which pack along rows.
    def take_decode1(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), C: T.Buffer((T.int64(1),), "int32"), take_decode: T.Buffer((T.int64(1), T.int64(4096)), "float16")):

        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})



        for i, j in T.grid(T.int64(1), T.int64(4096)):

            with T.block("take_decode"):

                v_i, v_j = T.axis.remap("SS", [i, j])

                T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])

                T.writes(take_decode[v_i, v_j])

                take_decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]



    @T.prim_func

    # Transpose axes 1 and 2: (1, 1, 32, 128) -> (1, 32, 1, 128).
    def transpose5(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):

        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})



        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128)):

            with T.block("T_transpose"):

                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])

                T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])

                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])

                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
|
|
|
|