File size: 19,912 Bytes
72268ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 |
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "cpu_func/rep_penalty.h"
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "cuda_func/q4_matrix.cuh"
#include "cuda_func/q4_matmul.cuh"
#include "cuda_func/column_remap.cuh"
#include "cuda_func/rms_norm.cuh"
#include "cuda_func/rope.cuh"
#include "cuda_func/half_matmul.cuh"
#include "cuda_func/q4_attn.cuh"
#include "cuda_func/q4_mlp.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
void check_cuda(cudaError_t ret)
{
switch (ret)
{
case cudaSuccess:
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
default:
printf(" **** CUDA error\n"); \
printf(" **** %s\n", cudaGetErrorString(ret)); \
TORCH_CHECK(false, "CUDA error"); \
break;
}
}
// Some decluttering macros
#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
}
// Tuning parameters
ExLlamaTuning tuningParams;
void set_tuning_params
(
int matmul_recons_thd,
int fused_mlp_thd,
int sdp_thd,
bool matmul_fused_remap,
bool rmsnorm_no_half2,
bool rope_no_half2,
bool matmul_no_half2,
bool silu_no_half2,
bool concurrent_streams
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.fused_mlp_thd = fused_mlp_thd;
tuningParams.sdp_thd = sdp_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.rmsnorm_no_half2 = rmsnorm_no_half2;
tuningParams.rope_no_half2 = rope_no_half2;
tuningParams.matmul_no_half2 = matmul_no_half2;
tuningParams.silu_no_half2 = silu_no_half2;
tuningParams.concurrent_streams = concurrent_streams;
}
// Release all unmanaged objects allocated by the extension
void cleanup()
{
cleanup_buffers_cuda();
g_q4_free_matrices();
}
// Prepare buffers for forward pass
void prepare_buffers
(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_mlp,
torch::Tensor temp_zeros_float,
torch::Tensor temp_dq
)
{
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
int max_zeros_float = temp_zeros_float.size(-1);
prepare_buffers_cuda
(
device_index,
(half*) temp_state.data_ptr(),
// buffer size used for sanity checks
temp_state.numel(),
(half*) temp_mlp.data_ptr(),
(float*) temp_zeros_float.data_ptr(),
(half*) temp_dq.data_ptr(),
max_zeros_float
);
}
// Create Q4Matrix, return handle
uintptr_t make_q4
(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device
)
{
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
Q4Matrix* m = new Q4Matrix
(
height,
width,
groups,
(uint32_t*) qweight.data_ptr(),
(uint32_t*) qzeros.data_ptr(),
(half*) scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
device
);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t> (m);
}
// Matmul half @ quant -> half
void q4_matmul
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle()
);
}
}
// Matmul half @ quant + half @ half @ half -> half
// Same as q4_matmul, but adds (x @ lora_A) @ lora_B to the result
void q4_matmul_lora
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out,
torch::Tensor lora_A,
torch::Tensor lora_B,
torch::Tensor lora_temp // empty tensor, shape of (x @ lora_A)
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK_SHAPES(x, 0, lora_temp, 0, 1);
TORCH_CHECK_SHAPES(x, 1, lora_A, 0, 1);
TORCH_CHECK_SHAPES(lora_A, 1, lora_B, 0, 1);
TORCH_CHECK_SHAPES(lora_B, 1, out, 1, 1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// lora_temp = x @ lora_A
half_matmul_cublas_cuda
(
&tuningParams,
(half*) x.data_ptr(),
(half*) lora_A.data_ptr(),
(half*) lora_temp.data_ptr(),
x.size(0),
x.size(1),
lora_A.size(1),
handle
);
// out = lora_temp @ lora_B
half_matmul_cublas_cuda
(
&tuningParams,
(half*) lora_temp.data_ptr(),
(half*) lora_B.data_ptr(),
(half*) out.data_ptr(),
lora_temp.size(0),
lora_temp.size(1),
lora_B.size(1),
handle
);
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
true
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
handle,
true
);
}
}
// Remap columns in half tensor
void column_remap
(
torch::Tensor x,
torch::Tensor x_new,
torch::Tensor x_map
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
int height = x.size(0);
int width = x.size(1);
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
column_remap_cuda
(
(half*) x.data_ptr(),
(half*) x_new.data_ptr(),
height,
width,
(uint32_t*) x_map.data_ptr()
);
}
// Matmul half @ half -> half, custom kernel
void half_matmul
(
torch::Tensor x,
torch::Tensor w,
torch::Tensor out
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(w, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 1, w, 0, 1);
int height = x.size(0);
int dim = x.size(1);
int width = w.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
half_matmul_cuda
(
(half*) x.data_ptr(),
(half*) w.data_ptr(),
(half*) out.data_ptr(),
height,
dim,
width
);
}
// Matmul half @ half -> half using cuBLAS
void half_matmul_cublas
(
torch::Tensor x,
torch::Tensor w,
torch::Tensor out
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(w, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 1, w, 0, 1);
int height = x.size(0);
int dim = x.size(1);
int width = w.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
half_matmul_cublas_cuda
(
&tuningParams,
(half*) x.data_ptr(),
(half*) w.data_ptr(),
(half*) out.data_ptr(),
height,
dim,
width,
handle
);
}
// Llama self attention (WIP)
void q4_attn
(
torch::Tensor x, // shape == (bsz, q_len, dim)
torch::Tensor rms_norm_weight, // shape == (x.shape[1],) == (dim,)
float epsilon,
torch::Tensor query_states, // shape == (bsz, q_len, dim)
torch::Tensor key_states, // shape == (bsz, q_len, dim)
torch::Tensor value_states, // shape == (bsz, q_len, dim)
uintptr_t q_proj,
uintptr_t k_proj,
uintptr_t v_proj,
torch::Tensor sin,
torch::Tensor cos,
int q_len,
int past_len,
int num_heads,
int num_kv_heads,
int head_dim,
torch::Tensor key_cache,
torch::Tensor value_cache,
int max_seq_len,
torch::Tensor q_a,
torch::Tensor q_b,
torch::Tensor k_a,
torch::Tensor k_b,
torch::Tensor v_a,
torch::Tensor v_b,
torch::Tensor lora_temp
)
{
TORCH_CHECK_DTYPE(query_states, kHalf);
TORCH_CHECK_DTYPE(key_states, kHalf);
int bsz = query_states.size(0);
int dim = query_states.size(2);
torch::Device device = x.device();
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
cudaStream_t current_stream = at::cuda::getCurrentCUDAStream().stream();
int q_rank = q_a.device().is_meta() ? 0 : q_a.size(1);
int k_rank = k_a.device().is_meta() ? 0 : k_a.size(1);
int v_rank = v_a.device().is_meta() ? 0 : v_a.size(1);
q4_attn_cuda
(
&tuningParams,
current_stream,
at::cuda::getCurrentCUDABlasHandle(),
(half*) x.data_ptr(),
(half*) rms_norm_weight.data_ptr(),
epsilon,
(half*) query_states.data_ptr(),
(half*) key_states.data_ptr(),
(half*) value_states.data_ptr(),
reinterpret_cast<Q4Matrix*>(q_proj),
reinterpret_cast<Q4Matrix*>(k_proj),
reinterpret_cast<Q4Matrix*>(v_proj),
(half*) sin.data_ptr(),
(half*) cos.data_ptr(),
bsz,
q_len,
dim,
head_dim,
num_heads,
num_kv_heads,
past_len,
(half*) key_cache.data_ptr(),
(half*) value_cache.data_ptr(),
q_rank ? (half*) q_a.data_ptr() : NULL,
q_rank ? (half*) q_b.data_ptr() : NULL,
q_rank,
k_rank ? (half*) k_a.data_ptr() : NULL,
k_rank ? (half*) k_b.data_ptr() : NULL,
k_rank,
v_rank ? (half*) v_a.data_ptr() : NULL,
v_rank ? (half*) v_b.data_ptr() : NULL,
v_rank,
lora_temp.device().is_meta() ? NULL : (half*) lora_temp.data_ptr(),
max_seq_len,
device_index
);
}
void q4_attn_2
(
torch::Tensor x,
torch::Tensor attn_output,
uintptr_t o_proj,
torch::Tensor o_a,
torch::Tensor o_b,
torch::Tensor lora_temp
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(attn_output, kHalf);
const at::cuda::OptionalCUDAGuard device_guard(x.device());
int height = x.size(0);
int o_rank = o_a.device().is_meta() ? 0 : o_a.size(1);
q4_attn_2_cuda
(
&tuningParams,
at::cuda::getCurrentCUDABlasHandle(),
(half*) x.data_ptr(),
(half*) attn_output.data_ptr(),
reinterpret_cast<Q4Matrix*>(o_proj),
height,
o_rank ? (half*) o_a.data_ptr() : NULL,
o_rank ? (half*) o_b.data_ptr() : NULL,
o_rank,
lora_temp.device().is_meta() ? NULL : (half*) lora_temp.data_ptr()
);
}
// Llama MLP
void q4_mlp
(
torch::Tensor x, // shape == (height, dim)
torch::Tensor rms_norm_weight, // shape == (x.shape[1],) == (dim,)
float epsilon,
uintptr_t gate,
uintptr_t up,
uintptr_t down,
torch::Tensor gate_a,
torch::Tensor gate_b,
torch::Tensor up_a,
torch::Tensor up_b,
torch::Tensor down_a,
torch::Tensor down_b,
torch::Tensor lora_temp
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(rms_norm_weight, kHalf);
int height = x.size(0);
int dim = x.size(1);
torch::Device device = x.device();
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
int gate_rank = gate_a.device().is_meta() ? 0 : gate_a.size(1);
int up_rank = gate_a.device().is_meta() ? 0 : up_a.size(1);
int down_rank = gate_a.device().is_meta() ? 0 : down_a.size(1);
q4_mlp_cuda
(
&tuningParams,
(half*) x.data_ptr(),
(half*) rms_norm_weight.data_ptr(),
epsilon,
reinterpret_cast<Q4Matrix*>(gate),
reinterpret_cast<Q4Matrix*>(up),
reinterpret_cast<Q4Matrix*>(down),
height,
dim,
gate_rank ? (half*) gate_a.data_ptr() : NULL,
gate_rank ? (half*) gate_b.data_ptr() : NULL,
gate_rank,
up_rank ? (half*) up_a.data_ptr() : NULL,
up_rank ? (half*) up_b.data_ptr() : NULL,
up_rank,
down_rank ? (half*) down_a.data_ptr() : NULL,
down_rank ? (half*) down_b.data_ptr() : NULL,
down_rank,
lora_temp.device().is_meta() ? NULL : (half*) lora_temp.data_ptr(),
at::cuda::getCurrentCUDABlasHandle(),
device_index
);
}
// RMS layernorm
void rms_norm
(
torch::Tensor x,
torch::Tensor w,
torch::Tensor out,
float epsilon
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(w, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 1, w, 0, 1);
TORCH_CHECK_SHAPES(x, 1, w, 0, 1);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK_SHAPES(x, 1, out, 1, 1);
int rows = x.size(0);
int dim = x.size(1);
torch::Device device = x.device();
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
rms_norm_cuda
(
&tuningParams,
(half*) x.data_ptr(),
(half*) w.data_ptr(),
(half*) out.data_ptr(),
epsilon,
rows,
dim,
device_index
);
}
// RoPE rotary positional embeddings
void rope_
(
torch::Tensor x,
torch::Tensor sin,
torch::Tensor cos,
int past_len,
int num_heads,
int head_dim
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(sin, kHalf);
TORCH_CHECK_DTYPE(cos, kHalf);
TORCH_CHECK(head_dim == cos.size(-1), "cos table does not match head_dim");
TORCH_CHECK(head_dim == sin.size(-1), "sin table does not match head_dim");
int bsz = x.size(0);
int rows_per_batch = x.numel() / head_dim / bsz;
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
rope_cuda
(
&tuningParams,
(half*) x.data_ptr(),
(half*) sin.data_ptr(),
(half*) cos.data_ptr(),
bsz,
rows_per_batch,
head_dim,
num_heads,
past_len
);
}
// Repetition penalty (CPU)
void rep_penalty
(
torch::Tensor sequence,
torch::Tensor rep_mask,
float penalty_max,
int sustain,
int decay
)
{
TORCH_CHECK_DTYPE(sequence, kLong);
TORCH_CHECK_DTYPE(rep_mask, kFloat);
int vocab_size = rep_mask.size(0);
int seq_len = sequence.size(-1);
// TODO: Support batch size
rep_penalty_cpu
(
vocab_size,
(uint64_t*) sequence.data_ptr(),
(float*) rep_mask.data_ptr(),
penalty_max,
sustain,
decay,
seq_len
);
}
void apply_rep_penalty
(
torch::Tensor sequence,
float penalty_max,
int sustain,
int decay,
torch::Tensor logits
)
{
TORCH_CHECK_DTYPE(sequence, kLong);
TORCH_CHECK_DTYPE(logits, kFloat);
TORCH_CHECK_SHAPES(sequence, 0, logits, 0, 1);
int vocab_size = logits.size(-1);
int bsz = sequence.size(0);
int seq_len = sequence.size(-1);
for (int i = 0; i < bsz; i++)
{
apply_rep_penalty_cpu
(
vocab_size,
((uint64_t*) sequence.data_ptr()) + i * seq_len,
penalty_max,
sustain,
decay,
seq_len,
((float*) logits.data_ptr()) + i * vocab_size
);
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
m.def("q4_matmul_lora", &q4_matmul_lora, "q4_matmul_lora");
m.def("q4_attn", &q4_attn, "q4_attn");
m.def("q4_attn_2", &q4_attn_2, "q4_attn_2");
m.def("q4_mlp", &q4_mlp, "q4_mlp");
m.def("column_remap", &column_remap, "column_remap");
m.def("rms_norm", &rms_norm, "rms_norm");
m.def("rope_", &rope_, "rope_");
m.def("half_matmul", &half_matmul, "half_matmul");
m.def("half_matmul_cublas", &half_matmul_cublas, "half_matmul_cublas");
m.def("rep_penalty", &rep_penalty, "rep_penalty");
m.def("apply_rep_penalty", &apply_rep_penalty, "apply_rep_penalty");
}
|