Deleting prior code
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- CONTRIBUTING.md +0 -77
- Dockerfile +0 -121
- Dockerfile.rocm +0 -88
- LICENSE +0 -201
- MANIFEST.in +0 -4
- README.md +0 -8
- benchmarks/README.md +0 -8
- benchmarks/benchmark_latency.py +0 -139
- benchmarks/benchmark_serving.py +0 -249
- benchmarks/benchmark_throughput.py +0 -328
- benchmarks/kernels/benchmark_paged_attention.py +0 -196
- benchmarks/launch_tgi_server.sh +0 -16
- csrc/activation_kernels.cu +0 -118
- csrc/attention/attention_dtypes.h +0 -7
- csrc/attention/attention_generic.cuh +0 -64
- csrc/attention/attention_kernels.cu +0 -951
- csrc/attention/attention_utils.cuh +0 -56
- csrc/attention/dtype_bfloat16.cuh +0 -451
- csrc/attention/dtype_float16.cuh +0 -502
- csrc/attention/dtype_float32.cuh +0 -273
- csrc/attention/dtype_fp8_e5m2.cuh +0 -35
- csrc/cache.h +0 -36
- csrc/cache_kernels.cu +0 -474
- csrc/cuda_compat.h +0 -28
- csrc/cuda_utils.h +0 -10
- csrc/cuda_utils_kernels.cu +0 -35
- csrc/custom_all_reduce.cu +0 -148
- csrc/custom_all_reduce.cuh +0 -562
- csrc/custom_all_reduce_test.cu +0 -284
- csrc/dispatch_utils.h +0 -37
- csrc/layernorm_kernels.cu +0 -120
- csrc/moe_align_block_size_kernels.cu +0 -108
- csrc/ops.h +0 -130
- csrc/pos_encoding_kernels.cu +0 -130
- csrc/punica/LICENSE +0 -217
- csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu +0 -4
- csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu +0 -4
- csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu +0 -4
- csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu +0 -4
- csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu +0 -4
- csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu +0 -4
- csrc/punica/bgmv/bgmv_config.h +0 -59
- csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu +0 -4
- csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu +0 -4
- csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu +0 -4
- csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu +0 -4
- csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu +0 -4
- csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu +0 -4
- csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu +0 -4
- csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu +0 -4
CONTRIBUTING.md
DELETED
@@ -1,77 +0,0 @@
# Contributing to vLLM

Thank you for your interest in contributing to vLLM!
Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
There are several ways you can contribute to the project:

- Identify and report any issues or bugs.
- Request or add a new model.
- Suggest or implement new features.

However, remember that contributions aren't just about code.
We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.

Finally, one of the most impactful ways to support us is by raising awareness about vLLM.
Talk about it in your blog posts, highlighting how it's driving your incredible projects.
Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository.


## Setup for development

### Build from source

```bash
pip install -r requirements.txt
pip install -e .  # This may take several minutes.
```

### Testing

```bash
pip install -r requirements-dev.txt

# Static type checking
mypy
# Unit tests
pytest tests/
```
**Note:** Currently, the repository does not pass the mypy tests.


## Contributing Guidelines

### Issue Reporting

If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
If not, please file a new issue, providing as much relevant information as possible.

### Coding Style Guide

In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).

We include a formatting script [`format.sh`](./format.sh) to format the code.

### Pull Requests

When submitting a pull request:

1. Make sure your code has been rebased on top of the latest commit on the main branch.
2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
3. Include a detailed description of the changes in the pull request.
Explain why you made the changes you did.
If your pull request fixes an open issue, please include a reference to it in the description.

### Code Reviews

All submissions, including submissions by project members, require a code review.
To make the review process as smooth as possible, please:

1. Keep your changes as concise as possible.
If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests.
2. Respond to all comments within a reasonable time frame.
If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

### Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
Your contributions make vLLM a great tool for everyone!
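For quick reference, the checks described in the removed CONTRIBUTING.md boil down to a short pre-submission routine. The following is only a sketch assembled from the commands above (run from the repository root); it is not part of the deleted file itself.

```bash
# Sketch of the pre-submission checks described above (run from the repo root).
./format.sh       # auto-format per the Google Python/C++ style guides
mypy              # static type checking (known not to pass yet)
pytest tests/     # unit tests
```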
Dockerfile
DELETED
@@ -1,121 +0,0 @@
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
# to run the OpenAI compatible server.

#################### BASE BUILD IMAGE ####################
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev

RUN apt-get update -y \
    && apt-get install -y python3-pip git

WORKDIR /workspace

# install build and runtime dependencies
COPY requirements.txt requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements.txt

# install development dependencies
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements-dev.txt
#################### BASE BUILD IMAGE ####################


#################### EXTENSION BUILD IMAGE ####################
FROM dev AS build

# install build dependencies
COPY requirements-build.txt requirements-build.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements-build.txt

# copy input files
COPY csrc csrc
COPY setup.py setup.py
COPY requirements.txt requirements.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py

# cuda arch list used by torch
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# max jobs used by Ninja to build extensions
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc
ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads
# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1

RUN python3 setup.py build_ext --inplace
#################### EXTENSION BUILD IMAGE ####################


#################### TEST IMAGE ####################
# image to run unit testing suite
FROM dev AS test

# copy pytorch extensions separately to avoid having to rebuild
# when python code changes
WORKDIR /vllm-workspace
# ADD is used to preserve directory structure
ADD . /vllm-workspace/
COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
# ignore build dependencies installation because we are using pre-compiled extensions
RUN rm pyproject.toml
RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
#################### TEST IMAGE ####################


#################### RUNTIME BASE IMAGE ####################
# use CUDA base as CUDA runtime dependencies are already installed via pip
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base

# libnccl required for ray
RUN apt-get update -y \
    && apt-get install -y python3-pip

WORKDIR /workspace
COPY requirements.txt requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements.txt
#################### RUNTIME BASE IMAGE ####################


#################### OPENAI API SERVER ####################
# openai api server alternative
FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install accelerate

# Create a non-root user
RUN useradd -m appuser

# Transfer ownership of the /workspace to the new non-root user
RUN chown -R appuser:appuser /workspace

# Create a cache directory within the appuser's home directory and transfer ownership
RUN mkdir -p /home/appuser/cache && \
    chown -R appuser:appuser /home/appuser/cache

# Switch to the non-root user for subsequent commands and container runtime
USER appuser

# Set the Hugging Face cache directory environment variable
ENV TRANSFORMERS_CACHE=/home/appuser/cache

COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

CMD ["python3", "-m", \
     "vllm.entrypoints.openai.api_server", \
     "--host", "0.0.0.0", \
     "--port", "7860", \
     "--served-model-name", "default", \
     "--model", "facebook/opt-125m", \
     "--max-model-len", "1024", \
     "--tensor-parallel-size", "1", \
     "--max-num-seqs", "16"]
#################### OPENAI API SERVER ####################
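For context, the removed Dockerfile targets its final `vllm-openai` stage at an OpenAI-compatible server listening on port 7860. Below is a minimal sketch of how such an image could be built and run locally, assuming a CUDA-capable host with the NVIDIA container runtime; the image tag is an arbitrary assumption, not part of the deleted file.

```bash
# Hypothetical build of the final stage and a local run mapping the port from the CMD above.
docker build --target vllm-openai -t vllm-openai-demo .
docker run --gpus all -p 7860:7860 vllm-openai-demo
```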
Dockerfile.rocm
DELETED
@@ -1,88 +0,0 @@
# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

FROM $BASE_IMAGE

ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

RUN echo "Base image is $BASE_IMAGE"

# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

# this does not always work for all rocm versions
RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
    echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"

ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"

ARG FA_BRANCH="3d2b6f5"
RUN echo "FA_BRANCH is $FA_BRANCH"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

# Install some basic utilities
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    build-essential \
    wget \
    unzip \
    nvidia-cuda-toolkit \
    tmux \
    && rm -rf /var/lib/apt/lists/*

### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

# Install ROCm flash-attention
RUN mkdir libs \
    && cd libs \
    && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
    && cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
    && export GPU_ARCHS=${FA_GFX_ARCHS} \
    && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
    && python3 setup.py install \
    && cd ..

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install xformers==0.0.23 --no-deps

# Error related to odd state for numpy 1.20.3 where there is no METADATA etc., but an extra LICENSES_bundled.txt.
# Manually remove it so that later steps of the numpy upgrade can continue.
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

RUN cd /app \
    && cd vllm \
    && pip install -U -r requirements-rocm.txt \
    && bash patch_xformers.rocm.sh \
    && python3 setup.py install \
    && cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]

CMD ["/bin/bash"]
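The removed Dockerfile.rocm exposes `BASE_IMAGE` as a build argument, so the ROCm 5.7 or 6.0 base noted in its comments can be selected at build time. The following sketch shows one possible build-and-run invocation; the image tag, device flags, and bind mount are assumptions for illustration, not part of the deleted file.

```bash
# Hypothetical ROCm build; override BASE_IMAGE to use the ROCm 5.7 base instead of the default 6.0 one.
docker build -f Dockerfile.rocm \
    --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
    -t vllm-rocm .
# ROCm containers typically need the kfd/dri devices; mount the source tree at /app as the comments suggest.
docker run -it --device=/dev/kfd --device=/dev/dri -v "$(pwd)":/app vllm-rocm
```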
LICENSE
DELETED
@@ -1,201 +0,0 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
MANIFEST.in
DELETED
@@ -1,4 +0,0 @@
include LICENSE
include requirements.txt

recursive-include csrc *
README.md
DELETED
@@ -1,8 +0,0 @@
---
title: CertifAIer Demo
emoji: ✈️
colorFrom: blue
colorTo: blue
sdk: docker
pinned: false
---
benchmarks/README.md
DELETED
@@ -1,8 +0,0 @@
# Benchmarking vLLM

## Downloading the ShareGPT dataset

You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
benchmarks/benchmark_latency.py
DELETED
@@ -1,139 +0,0 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams


def main(args: argparse.Namespace):
    print(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
    )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

    def run_to_completion(profile_dir: Optional[str] = None):
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
        else:
            start_time = time.perf_counter()
            llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                         sampling_params=sampling_params,
                         use_tqdm=False)
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

    print("Warming up...")
    run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=args.profile_result_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters',
                        type=int,
                        default=3,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=['auto', 'fp8_e5m2'],
        default='auto',
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    args = parser.parse_args()
    main(args)
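For reference, the removed latency benchmark is driven entirely by the CLI flags defined above. The following is a sketch of one possible invocation using those flags; the chosen values mirror the script's defaults and are illustrative, not taken from the deleted file.

```bash
# Hypothetical single-batch latency run with the flags defined in the script above.
python benchmarks/benchmark_latency.py \
    --model facebook/opt-125m \
    --input-len 32 --output-len 128 \
    --batch-size 8 --num-iters 3
```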
benchmarks/benchmark_serving.py
DELETED
@@ -1,249 +0,0 @@
"""Benchmark online serving throughput.

On the server side, run one of the following commands:
    (vLLM backend)
    python -m vllm.entrypoints.api_server \
        --model <your_model> --swap-space 16 \
        --disable-log-requests

    (TGI backend)
    ./launch_hf_server.sh <your_model>

On the client side, run:
    python benchmarks/benchmark_serving.py \
        --backend <backend> \
        --tokenizer <your_model> --dataset <target_dataset> \
        --request-rate <request_rate>
"""
import argparse
import asyncio
import json
import random
import time
from typing import AsyncGenerator, List, Tuple

import aiohttp
import numpy as np
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer

# (prompt len, output len, latency)
REQUEST_LATENCY: List[Tuple[int, int, float]] = []


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            # This is because TGI causes errors when the input or output length
            # is too short.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue
        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


async def send_request(backend: str, model: str, api_url: str, prompt: str,
                       prompt_len: int, output_len: int, best_of: int,
                       use_beam_search: bool, pbar: tqdm) -> None:
    request_start_time = time.perf_counter()

    headers = {"User-Agent": "Benchmark Client"}
    if backend == "vllm":
        pload = {
            "prompt": prompt,
            "n": 1,
            "best_of": best_of,
            "use_beam_search": use_beam_search,
            "temperature": 0.0 if use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": output_len,
            "ignore_eos": True,
            "stream": False,
        }
        if model is not None:
            pload["model"] = model
    elif backend == "tgi":
        assert not use_beam_search
        params = {
            "best_of": best_of,
            "max_new_tokens": output_len,
            "do_sample": True,
        }
        pload = {
            "inputs": prompt,
            "parameters": params,
        }
    else:
        raise ValueError(f"Unknown backend: {backend}")

    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        while True:
            async with session.post(api_url, headers=headers,
                                    json=pload) as response:
                chunks = []
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)
            output = b"".join(chunks).decode("utf-8")
            output = json.loads(output)

            # Re-send the request if it failed.
            if "error" not in output:
                break

    request_end_time = time.perf_counter()
    request_latency = request_end_time - request_start_time
    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
    pbar.update(1)


async def benchmark(
    backend: str,
    model: str,
    api_url: str,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
    use_beam_search: bool,
    request_rate: float,
) -> None:
    tasks: List[asyncio.Task] = []
    pbar = tqdm(total=len(input_requests))
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        task = asyncio.create_task(
            send_request(backend, model, api_url, prompt, prompt_len,
                         output_len, best_of, use_beam_search, pbar))
        tasks.append(task)
    await asyncio.gather(*tasks)
    pbar.close()


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
    tokenizer = get_tokenizer(args.tokenizer,
                              trust_remote_code=args.trust_remote_code)
    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    benchmark_start_time = time.perf_counter()
    asyncio.run(
        benchmark(args.backend, args.model, api_url, input_requests,
                  args.best_of, args.use_beam_search, args.request_rate))
    benchmark_end_time = time.perf_counter()
    benchmark_time = benchmark_end_time - benchmark_start_time
    print(f"Total time: {benchmark_time:.2f} s")
    print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")

    # Compute the latency statistics.
    avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
    print(f"Average latency: {avg_latency:.2f} s")
    avg_per_token_latency = np.mean([
        latency / (prompt_len + output_len)
        for prompt_len, output_len, latency in REQUEST_LATENCY
    ])
    print(f"Average latency per token: {avg_per_token_latency:.2f} s")
    avg_per_output_token_latency = np.mean(
        [latency / output_len for _, output_len, latency in REQUEST_LATENCY])
    print("Average latency per output token: "
          f"{avg_per_output_token_latency:.2f} s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the online serving throughput.")
    parser.add_argument("--backend",
                        type=str,
                        default="vllm",
                        choices=["vllm", "tgi"])
    parser.add_argument("--protocol",
                        type=str,
                        default="http",
                        choices=["http", "https"])
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--endpoint", type=str, default="/generate")
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="Path to the dataset.")
    parser.add_argument("--tokenizer",
                        type=str,
                        required=True,
                        help="Name or path of the tokenizer.")
    parser.add_argument("--best-of",
                        type=int,
                        default=1,
                        help="Generates `best_of` sequences per prompt and "
                        "returns the best one.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--request-rate",
                        type=float,
                        default=float("inf"),
                        help="Number of requests per second. If this is inf, "
                        "then all the requests are sent at time 0. "
                        "Otherwise, we use Poisson process to synthesize "
                        "the request arrival times.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    args = parser.parse_args()
    main(args)
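As the removed script's docstring notes, it needs a running server plus a client invocation. Below is a sketch of one possible end-to-end run against the ShareGPT file from benchmarks/README.md; the model name and request rate are illustrative assumptions, not taken from the deleted file.

```bash
# Hypothetical run: start the vLLM API server in one shell...
python -m vllm.entrypoints.api_server \
    --model facebook/opt-125m --swap-space 16 --disable-log-requests

# ...then drive it with the benchmark client in another.
python benchmarks/benchmark_serving.py \
    --backend vllm \
    --tokenizer facebook/opt-125m \
    --dataset ShareGPT_V3_unfiltered_cleaned_split.json \
    --request-rate 4
```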
benchmarks/benchmark_throughput.py
DELETED
@@ -1,328 +0,0 @@
"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
from tqdm import tqdm


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        if fixed_output_len is not None:
            output_len = fixed_output_len
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
) -> float:
    from vllm import LLM, SamplingParams
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
    )

    # Add the requests to the engine.
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )

    start = time.perf_counter()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    end = time.perf_counter()
    return end - start


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start


def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    from mii import pipeline
    llm = pipeline(model, tensor_parallel=tensor_parallel_size)
    prompts = [prompt for prompt, _, _ in requests]

    start = time.perf_counter()
    llm(prompts, max_new_tokens=output_len)
    end = time.perf_counter()
    return end - start


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
                                args.quantization, args.tensor_parallel_size,
                                args.seed, args.n, args.use_beam_search,
                                args.trust_remote_code, args.dtype,
                                args.max_model_len, args.enforce_eager,
                                args.kv_cache_dtype)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8_e5m2"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    if args.dataset is None:
        assert args.input_len is not None
        assert args.output_len is not None
    else:
        assert args.input_len is None

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/kernels/benchmark_paged_attention.py
DELETED
@@ -1,196 +0,0 @@
|
|
1 |
-
from typing import Optional
|
2 |
-
import argparse
|
3 |
-
import random
|
4 |
-
import time
|
5 |
-
|
6 |
-
import torch
|
7 |
-
|
8 |
-
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
9 |
-
from vllm._C import ops
|
10 |
-
|
11 |
-
NUM_BLOCKS = 1024
|
12 |
-
PARTITION_SIZE = 512
|
13 |
-
|
14 |
-
|
15 |
-
@torch.inference_mode()
|
16 |
-
def main(
|
17 |
-
version: str,
|
18 |
-
num_seqs: int,
|
19 |
-
context_len: int,
|
20 |
-
num_query_heads: int,
|
21 |
-
num_kv_heads: int,
|
22 |
-
head_size: int,
|
23 |
-
use_alibi: bool,
|
24 |
-
block_size: int,
|
25 |
-
dtype: torch.dtype,
|
26 |
-
seed: int,
|
27 |
-
do_profile: bool,
|
28 |
-
kv_cache_dtype: Optional[str] = None,
|
29 |
-
) -> None:
|
30 |
-
random.seed(seed)
|
31 |
-
torch.random.manual_seed(seed)
|
32 |
-
torch.cuda.manual_seed(seed)
|
33 |
-
|
34 |
-
scale = float(1.0 / (head_size**0.5))
|
35 |
-
query = torch.empty(num_seqs,
|
36 |
-
num_query_heads,
|
37 |
-
head_size,
|
38 |
-
dtype=dtype,
|
39 |
-
device="cuda")
|
40 |
-
query.uniform_(-scale, scale)
|
41 |
-
|
42 |
-
assert num_query_heads % num_kv_heads == 0
|
43 |
-
alibi_slopes = None
|
44 |
-
if use_alibi:
|
45 |
-
alibi_slopes = torch.randn(num_query_heads,
|
46 |
-
dtype=torch.float,
|
47 |
-
device="cuda")
|
48 |
-
|
49 |
-
context_lens = [context_len for _ in range(num_seqs)]
|
50 |
-
max_context_len = max(context_lens)
|
51 |
-
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
52 |
-
|
53 |
-
# Create the block tables.
|
54 |
-
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
55 |
-
block_tables = []
|
56 |
-
for _ in range(num_seqs):
|
57 |
-
block_table = [
|
58 |
-
random.randint(0, NUM_BLOCKS - 1)
|
59 |
-
for _ in range(max_num_blocks_per_seq)
|
60 |
-
]
|
61 |
-
block_tables.append(block_table)
|
62 |
-
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
63 |
-
|
64 |
-
# Create the KV cache.
|
65 |
-
key_caches, value_caches = create_kv_caches_with_random(
|
66 |
-
NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
|
67 |
-
dtype)
|
68 |
-
key_cache, value_cache = key_caches[0], value_caches[0]
|
69 |
-
|
70 |
-
# Prepare for the paged attention kernel.
|
71 |
-
output = torch.empty_like(query)
|
72 |
-
if version == "v2":
|
73 |
-
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
|
74 |
-
PARTITION_SIZE)
|
75 |
-
tmp_output = torch.empty(
|
76 |
-
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
77 |
-
dtype=output.dtype,
|
78 |
-
device=output.device,
|
79 |
-
)
|
80 |
-
exp_sums = torch.empty(
|
81 |
-
size=(num_seqs, num_query_heads, num_partitions),
|
82 |
-
dtype=torch.float32,
|
83 |
-
device=output.device,
|
84 |
-
)
|
85 |
-
max_logits = torch.empty_like(exp_sums)
|
86 |
-
|
87 |
-
def run_benchmark(num_iters: int, profile: bool = False) -> float:
|
88 |
-
torch.cuda.synchronize()
|
89 |
-
if profile:
|
90 |
-
torch.cuda.cudart().cudaProfilerStart()
|
91 |
-
start_time = time.perf_counter()
|
92 |
-
|
93 |
-
for _ in range(num_iters):
|
94 |
-
if version == "v1":
|
95 |
-
ops.paged_attention_v1(
|
96 |
-
output,
|
97 |
-
query,
|
98 |
-
key_cache,
|
99 |
-
value_cache,
|
100 |
-
num_kv_heads,
|
101 |
-
scale,
|
102 |
-
block_tables,
|
103 |
-
context_lens,
|
104 |
-
block_size,
|
105 |
-
max_context_len,
|
106 |
-
alibi_slopes,
|
107 |
-
kv_cache_dtype,
|
108 |
-
)
|
109 |
-
elif version == "v2":
|
110 |
-
ops.paged_attention_v2(
|
111 |
-
output,
|
112 |
-
exp_sums,
|
113 |
-
max_logits,
|
114 |
-
tmp_output,
|
115 |
-
query,
|
116 |
-
key_cache,
|
117 |
-
value_cache,
|
118 |
-
num_kv_heads,
|
119 |
-
scale,
|
120 |
-
block_tables,
|
121 |
-
context_lens,
|
122 |
-
block_size,
|
123 |
-
max_context_len,
|
124 |
-
alibi_slopes,
|
125 |
-
kv_cache_dtype,
|
126 |
-
)
|
127 |
-
else:
|
128 |
-
raise ValueError(f"Invalid version: {version}")
|
129 |
-
torch.cuda.synchronize()
|
130 |
-
|
131 |
-
end_time = time.perf_counter()
|
132 |
-
if profile:
|
133 |
-
torch.cuda.cudart().cudaProfilerStart()
|
134 |
-
return (end_time - start_time) / num_iters
|
135 |
-
|
136 |
-
# Warmup.
|
137 |
-
print("Warming up...")
|
138 |
-
run_benchmark(num_iters=3, profile=False)
|
139 |
-
|
140 |
-
# Benchmark.
|
141 |
-
if do_profile:
|
142 |
-
latency = run_benchmark(num_iters=1, profile=True)
|
143 |
-
else:
|
144 |
-
latency = run_benchmark(num_iters=100, profile=False)
|
145 |
-
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
146 |
-
|
147 |
-
|
148 |
-
if __name__ == '__main__':
|
149 |
-
parser = argparse.ArgumentParser(
|
150 |
-
description="Benchmark the paged attention kernel.")
|
151 |
-
parser.add_argument("--version",
|
152 |
-
type=str,
|
153 |
-
choices=["v1", "v2"],
|
154 |
-
default="v2")
|
155 |
-
parser.add_argument("--batch-size", type=int, default=8)
|
156 |
-
parser.add_argument("--context-len", type=int, default=4096)
|
157 |
-
parser.add_argument("--num-query-heads", type=int, default=64)
|
158 |
-
parser.add_argument("--num-kv-heads", type=int, default=8)
|
159 |
-
parser.add_argument("--head-size",
|
160 |
-
type=int,
|
161 |
-
choices=[64, 80, 96, 112, 128, 256],
|
162 |
-
default=128)
|
163 |
-
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
164 |
-
parser.add_argument("--use-alibi", action="store_true")
|
165 |
-
parser.add_argument("--dtype",
|
166 |
-
type=str,
|
167 |
-
choices=["half", "bfloat16", "float"],
|
168 |
-
default="half")
|
169 |
-
parser.add_argument("--seed", type=int, default=0)
|
170 |
-
parser.add_argument("--profile", action="store_true")
|
171 |
-
parser.add_argument(
|
172 |
-
"--kv-cache-dtype",
|
173 |
-
type=str,
|
174 |
-
choices=["auto", "fp8_e5m2"],
|
175 |
-
default="auto",
|
176 |
-
help=
|
177 |
-
'Data type for kv cache storage. If "auto", will use model data type.')
|
178 |
-
args = parser.parse_args()
|
179 |
-
print(args)
|
180 |
-
|
181 |
-
if args.num_query_heads % args.num_kv_heads != 0:
|
182 |
-
raise ValueError("num_query_heads must be divisible by num_kv_heads")
|
183 |
-
main(
|
184 |
-
version=args.version,
|
185 |
-
num_seqs=args.batch_size,
|
186 |
-
context_len=args.context_len,
|
187 |
-
num_query_heads=args.num_query_heads,
|
188 |
-
num_kv_heads=args.num_kv_heads,
|
189 |
-
head_size=args.head_size,
|
190 |
-
block_size=args.block_size,
|
191 |
-
use_alibi=args.use_alibi,
|
192 |
-
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
193 |
-
seed=args.seed,
|
194 |
-
do_profile=args.profile,
|
195 |
-
kv_cache_dtype=args.kv_cache_dtype,
|
196 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/launch_tgi_server.sh
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
#!/bin/bash
|
2 |
-
|
3 |
-
PORT=8000
|
4 |
-
MODEL=$1
|
5 |
-
TOKENS=$2
|
6 |
-
|
7 |
-
docker run --gpus all --shm-size 1g -p $PORT:80 \
|
8 |
-
-v $PWD/data:/data \
|
9 |
-
ghcr.io/huggingface/text-generation-inference:0.8 \
|
10 |
-
--model-id $MODEL \
|
11 |
-
--sharded false \
|
12 |
-
--max-input-length 1024 \
|
13 |
-
--max-total-tokens 2048 \
|
14 |
-
--max-best-of 5 \
|
15 |
-
--max-concurrent-requests 5000 \
|
16 |
-
--max-batch-total-tokens $TOKENS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
csrc/activation_kernels.cu
DELETED
@@ -1,118 +0,0 @@
|
|
1 |
-
#include <ATen/cuda/CUDAContext.h>
|
2 |
-
#include <torch/extension.h>
|
3 |
-
#include <c10/cuda/CUDAGuard.h>
|
4 |
-
|
5 |
-
#include "cuda_compat.h"
|
6 |
-
#include "dispatch_utils.h"
|
7 |
-
|
8 |
-
namespace vllm {
|
9 |
-
|
10 |
-
template<typename T>
|
11 |
-
__device__ __forceinline__ T silu(const T& x) {
|
12 |
-
// x * sigmoid(x)
|
13 |
-
return (T) (((float) x) / (1.0f + expf((float) -x)));
|
14 |
-
}
|
15 |
-
|
16 |
-
template<typename scalar_t>
|
17 |
-
__global__ void silu_and_mul_kernel(
|
18 |
-
scalar_t* __restrict__ out, // [..., d]
|
19 |
-
const scalar_t* __restrict__ input, // [..., 2, d]
|
20 |
-
const int d) {
|
21 |
-
const int64_t token_idx = blockIdx.x;
|
22 |
-
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
23 |
-
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
24 |
-
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
25 |
-
out[token_idx * d + idx] = silu(x) * y;
|
26 |
-
}
|
27 |
-
}
|
28 |
-
|
29 |
-
} // namespace vllm
|
30 |
-
|
31 |
-
void silu_and_mul(
|
32 |
-
torch::Tensor& out, // [..., d]
|
33 |
-
torch::Tensor& input) // [..., 2 * d]
|
34 |
-
{
|
35 |
-
int64_t num_tokens = input.numel() / input.size(-1);
|
36 |
-
int d = input.size(-1) / 2;
|
37 |
-
|
38 |
-
dim3 grid(num_tokens);
|
39 |
-
dim3 block(std::min(d, 1024));
|
40 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
41 |
-
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
42 |
-
VLLM_DISPATCH_FLOATING_TYPES(
|
43 |
-
input.scalar_type(),
|
44 |
-
"silu_and_mul_kernel",
|
45 |
-
[&] {
|
46 |
-
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
47 |
-
out.data_ptr<scalar_t>(),
|
48 |
-
input.data_ptr<scalar_t>(),
|
49 |
-
d);
|
50 |
-
});
|
51 |
-
}
|
52 |
-
|
53 |
-
namespace vllm {
|
54 |
-
|
55 |
-
// Element-wise activation kernel template.
|
56 |
-
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
57 |
-
__global__ void activation_kernel(
|
58 |
-
scalar_t* __restrict__ out, // [..., d]
|
59 |
-
const scalar_t* __restrict__ input, // [..., d]
|
60 |
-
const int d) {
|
61 |
-
const int64_t token_idx = blockIdx.x;
|
62 |
-
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
63 |
-
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
64 |
-
out[token_idx * d + idx] = ACT_FN(x);
|
65 |
-
}
|
66 |
-
}
|
67 |
-
|
68 |
-
} // namespace vllm
|
69 |
-
|
70 |
-
// Launch element-wise activation kernel.
|
71 |
-
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
72 |
-
int d = input.size(-1); \
|
73 |
-
int64_t num_tokens = input.numel() / d; \
|
74 |
-
dim3 grid(num_tokens); \
|
75 |
-
dim3 block(std::min(d, 1024)); \
|
76 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
77 |
-
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
78 |
-
VLLM_DISPATCH_FLOATING_TYPES( \
|
79 |
-
input.scalar_type(), \
|
80 |
-
"activation_kernel", \
|
81 |
-
[&] { \
|
82 |
-
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
83 |
-
out.data_ptr<scalar_t>(), \
|
84 |
-
input.data_ptr<scalar_t>(), \
|
85 |
-
d); \
|
86 |
-
});
|
87 |
-
|
88 |
-
namespace vllm {
|
89 |
-
|
90 |
-
template<typename T>
|
91 |
-
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
92 |
-
const float x3 = (float) (x * x * x);
|
93 |
-
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
|
94 |
-
return ((T) 0.5) * x * (((T) 1.0) + t);
|
95 |
-
}
|
96 |
-
|
97 |
-
template<typename T>
|
98 |
-
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
99 |
-
const float f = (float) x;
|
100 |
-
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
|
101 |
-
return ((T) 0.5) * x * (((T) 1.0) + t);
|
102 |
-
}
|
103 |
-
|
104 |
-
} // namespace vllm
|
105 |
-
|
106 |
-
void gelu_new(
|
107 |
-
torch::Tensor& out, // [..., d]
|
108 |
-
torch::Tensor& input) // [..., d]
|
109 |
-
{
|
110 |
-
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
111 |
-
}
|
112 |
-
|
113 |
-
void gelu_fast(
|
114 |
-
torch::Tensor& out, // [..., d]
|
115 |
-
torch::Tensor& input) // [..., d]
|
116 |
-
{
|
117 |
-
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
118 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
csrc/attention/attention_dtypes.h
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
#pragma once
|
2 |
-
|
3 |
-
#include "attention_generic.cuh"
|
4 |
-
#include "dtype_float16.cuh"
|
5 |
-
#include "dtype_float32.cuh"
|
6 |
-
#include "dtype_bfloat16.cuh"
|
7 |
-
#include "dtype_fp8_e5m2.cuh"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
csrc/attention/attention_generic.cuh
DELETED
@@ -1,64 +0,0 @@
|
|
1 |
-
/*
|
2 |
-
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
3 |
-
* Copyright (c) 2023, The vLLM team.
|
4 |
-
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
5 |
-
*
|
6 |
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
-
* you may not use this file except in compliance with the License.
|
8 |
-
* You may obtain a copy of the License at
|
9 |
-
*
|
10 |
-
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
-
*
|
12 |
-
* Unless required by applicable law or agreed to in writing, software
|
13 |
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
-
* See the License for the specific language governing permissions and
|
16 |
-
* limitations under the License.
|
17 |
-
*/
|
18 |
-
#pragma once
|
19 |
-
|
20 |
-
#include <stdint.h>
|
21 |
-
|
22 |
-
namespace vllm {
|
23 |
-
|
24 |
-
// A vector type to store Q, K, V elements.
|
25 |
-
template<typename T, int VEC_SIZE>
|
26 |
-
struct Vec {};
|
27 |
-
|
28 |
-
// A vector type to store FP32 accumulators.
|
29 |
-
template<typename T>
|
30 |
-
struct FloatVec {};
|
31 |
-
|
32 |
-
// Template vector operations.
|
33 |
-
template<typename Acc, typename A, typename B>
|
34 |
-
inline __device__ Acc mul(A a, B b);
|
35 |
-
|
36 |
-
template<typename T>
|
37 |
-
inline __device__ float sum(T v);
|
38 |
-
|
39 |
-
template<typename T>
|
40 |
-
inline __device__ float dot(T a, T b) {
|
41 |
-
return sum(mul<T, T, T>(a, b));
|
42 |
-
}
|
43 |
-
|
44 |
-
template<typename A, typename T>
|
45 |
-
inline __device__ float dot(T a, T b) {
|
46 |
-
return sum(mul<A, T, T>(a, b));
|
47 |
-
}
|
48 |
-
|
49 |
-
template<typename T>
|
50 |
-
inline __device__ void zero(T& dst) {
|
51 |
-
constexpr int WORDS = sizeof(T) / 4;
|
52 |
-
union {
|
53 |
-
T raw;
|
54 |
-
uint32_t words[WORDS];
|
55 |
-
} tmp;
|
56 |
-
|
57 |
-
#pragma unroll
|
58 |
-
for (int ii = 0; ii < WORDS; ++ii) {
|
59 |
-
tmp.words[ii] = 0u;
|
60 |
-
}
|
61 |
-
dst = tmp.raw;
|
62 |
-
}
|
63 |
-
|
64 |
-
} // namespace vllm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
csrc/attention/attention_kernels.cu
DELETED
@@ -1,951 +0,0 @@
|
|
1 |
-
/*
|
2 |
-
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
3 |
-
* Copyright (c) 2023, The vLLM team.
|
4 |
-
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
5 |
-
*
|
6 |
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
-
* you may not use this file except in compliance with the License.
|
8 |
-
* You may obtain a copy of the License at
|
9 |
-
*
|
10 |
-
* http://www.apache.org/licenses/LICENSE-2.0
|
11 |
-
*
|
12 |
-
* Unless required by applicable law or agreed to in writing, software
|
13 |
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
-
* See the License for the specific language governing permissions and
|
16 |
-
* limitations under the License.
|
17 |
-
*/
|
18 |
-
#ifdef USE_ROCM
|
19 |
-
#include <hip/hip_runtime.h>
|
20 |
-
#endif
|
21 |
-
|
22 |
-
#include <torch/extension.h>
|
23 |
-
#include <ATen/cuda/CUDAContext.h>
|
24 |
-
#include <c10/cuda/CUDAGuard.h>
|
25 |
-
|
26 |
-
#include "attention_dtypes.h"
|
27 |
-
#include "attention_utils.cuh"
|
28 |
-
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
29 |
-
|
30 |
-
#include <algorithm>
|
31 |
-
|
32 |
-
#ifndef USE_ROCM
|
33 |
-
#define WARP_SIZE 32
|
34 |
-
#else
|
35 |
-
#define WARP_SIZE warpSize
|
36 |
-
#endif
|
37 |
-
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
38 |
-
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
39 |
-
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
40 |
-
|
41 |
-
namespace vllm {
|
42 |
-
|
43 |
-
// Utility function for attention softmax.
|
44 |
-
template<int NUM_WARPS>
|
45 |
-
inline __device__ float block_sum(float* red_smem, float sum) {
|
46 |
-
// Decompose the thread index into warp / lane.
|
47 |
-
int warp = threadIdx.x / WARP_SIZE;
|
48 |
-
int lane = threadIdx.x % WARP_SIZE;
|
49 |
-
|
50 |
-
// Compute the sum per warp.
|
51 |
-
#pragma unroll
|
52 |
-
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
53 |
-
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
54 |
-
}
|
55 |
-
|
56 |
-
// Warp leaders store the data to shared memory.
|
57 |
-
if (lane == 0) {
|
58 |
-
red_smem[warp] = sum;
|
59 |
-
}
|
60 |
-
|
61 |
-
// Make sure the data is in shared memory.
|
62 |
-
__syncthreads();
|
63 |
-
|
64 |
-
// The warps compute the final sums.
|
65 |
-
if (lane < NUM_WARPS) {
|
66 |
-
sum = red_smem[lane];
|
67 |
-
}
|
68 |
-
|
69 |
-
// Parallel reduction inside the warp.
|
70 |
-
#pragma unroll
|
71 |
-
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
72 |
-
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
73 |
-
}
|
74 |
-
|
75 |
-
// Broadcast to other threads.
|
76 |
-
return VLLM_SHFL_SYNC(sum, 0);
|
77 |
-
}
|
78 |
-
|
79 |
-
// TODO(woosuk): Merge the last two dimensions of the grid.
|
80 |
-
// Grid: (num_heads, num_seqs, max_num_partitions).
|
81 |
-
template<
|
82 |
-
typename scalar_t,
|
83 |
-
typename cache_t,
|
84 |
-
int HEAD_SIZE,
|
85 |
-
int BLOCK_SIZE,
|
86 |
-
int NUM_THREADS,
|
87 |
-
bool IS_FP8_E5M2_KV_CACHE,
|
88 |
-
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
89 |
-
__device__ void paged_attention_kernel(
|
90 |
-
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
91 |
-
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
92 |
-
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
93 |
-
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
94 |
-
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
95 |
-
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
96 |
-
const int num_kv_heads, // [num_heads]
|
97 |
-
const float scale,
|
98 |
-
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
99 |
-
const int* __restrict__ context_lens, // [num_seqs]
|
100 |
-
const int max_num_blocks_per_seq,
|
101 |
-
const float* __restrict__ alibi_slopes, // [num_heads]
|
102 |
-
const int q_stride,
|
103 |
-
const int kv_block_stride,
|
104 |
-
const int kv_head_stride) {
|
105 |
-
const int seq_idx = blockIdx.y;
|
106 |
-
const int partition_idx = blockIdx.z;
|
107 |
-
const int max_num_partitions = gridDim.z;
|
108 |
-
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
|
109 |
-
const int context_len = context_lens[seq_idx];
|
110 |
-
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
|
111 |
-
// No work to do. Terminate the thread block.
|
112 |
-
return;
|
113 |
-
}
|
114 |
-
|
115 |
-
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
|
116 |
-
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
|
117 |
-
|
118 |
-
// [start_block_idx, end_block_idx) is the range of blocks to process.
|
119 |
-
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
|
120 |
-
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
|
121 |
-
const int num_blocks = end_block_idx - start_block_idx;
|
122 |
-
|
123 |
-
// [start_token_idx, end_token_idx) is the range of tokens to process.
|
124 |
-
const int start_token_idx = start_block_idx * BLOCK_SIZE;
|
125 |
-
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
|
126 |
-
const int num_tokens = end_token_idx - start_token_idx;
|
127 |
-
|
128 |
-
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
129 |
-
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
|
130 |
-
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
|
131 |
-
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
|
132 |
-
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
133 |
-
const int thread_idx = threadIdx.x;
|
134 |
-
const int warp_idx = thread_idx / WARP_SIZE;
|
135 |
-
const int lane = thread_idx % WARP_SIZE;
|
136 |
-
|
137 |
-
const int head_idx = blockIdx.x;
|
138 |
-
const int num_heads = gridDim.x;
|
139 |
-
const int num_queries_per_kv = num_heads / num_kv_heads;
|
140 |
-
const int kv_head_idx = head_idx / num_queries_per_kv;
|
141 |
-
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
142 |
-
|
143 |
-
// A vector type to store a part of a key or a query.
|
144 |
-
// The vector size is configured in such a way that the threads in a thread group
|
145 |
-
// fetch or compute 16 bytes at a time.
|
146 |
-
// For example, if the size of a thread group is 4 and the data type is half,
|
147 |
-
// then the vector size is 16 / (4 * sizeof(half)) == 2.
|
148 |
-
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
149 |
-
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
150 |
-
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
151 |
-
#ifdef ENABLE_FP8_E5M2
|
152 |
-
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
153 |
-
#endif
|
154 |
-
|
155 |
-
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
156 |
-
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
157 |
-
|
158 |
-
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
|
159 |
-
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
|
160 |
-
|
161 |
-
// Load the query to registers.
|
162 |
-
// Each thread in a thread group has a different part of the query.
|
163 |
-
// For example, if the the thread group size is 4, then the first thread in the group
|
164 |
-
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
165 |
-
// th vectors of the query, and so on.
|
166 |
-
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
167 |
-
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
168 |
-
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
169 |
-
#pragma unroll
|
170 |
-
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
|
171 |
-
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
172 |
-
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
173 |
-
}
|
174 |
-
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
|
175 |
-
|
176 |
-
// Memory planning.
|
177 |
-
extern __shared__ char shared_mem[];
|
178 |
-
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
|
179 |
-
float* logits = reinterpret_cast<float*>(shared_mem);
|
180 |
-
// Workspace for reduction.
|
181 |
-
__shared__ float red_smem[2 * NUM_WARPS];
|
182 |
-
|
183 |
-
// x == THREAD_GROUP_SIZE * VEC_SIZE
|
184 |
-
// Each thread group fetches x elements from the key at a time.
|
185 |
-
constexpr int x = 16 / sizeof(cache_t);
|
186 |
-
float qk_max = -FLT_MAX;
|
187 |
-
|
188 |
-
// Iterate over the key blocks.
|
189 |
-
// Each warp fetches a block of keys for each iteration.
|
190 |
-
// Each thread group in a warp fetches a key from the block, and computes
|
191 |
-
// dot product with the query.
|
192 |
-
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
193 |
-
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
194 |
-
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
195 |
-
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
196 |
-
// (e.g., kv_block_stride).
|
197 |
-
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
198 |
-
|
199 |
-
// Load a key to registers.
|
200 |
-
// Each thread in a thread group has a different part of the key.
|
201 |
-
// For example, if the the thread group size is 4, then the first thread in the group
|
202 |
-
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
|
203 |
-
// vectors of the key, and so on.
|
204 |
-
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
|
205 |
-
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
|
206 |
-
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
207 |
-
K_vec k_vecs[NUM_VECS_PER_THREAD];
|
208 |
-
|
209 |
-
#pragma unroll
|
210 |
-
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
211 |
-
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
212 |
-
+ kv_head_idx * kv_head_stride
|
213 |
-
+ physical_block_offset * x;
|
214 |
-
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
215 |
-
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
216 |
-
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
217 |
-
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
218 |
-
#ifdef ENABLE_FP8_E5M2
|
219 |
-
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
220 |
-
// Vector conversion from Quant_vec to K_vec.
|
221 |
-
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
|
222 |
-
#else
|
223 |
-
assert(false);
|
224 |
-
#endif
|
225 |
-
} else {
|
226 |
-
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
227 |
-
}
|
228 |
-
}
|
229 |
-
|
230 |
-
// Compute dot product.
|
231 |
-
// This includes a reduction across the threads in the same thread group.
|
232 |
-
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
233 |
-
// Add the ALiBi bias if slopes are given.
|
234 |
-
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
|
235 |
-
|
236 |
-
if (thread_group_offset == 0) {
|
237 |
-
// Store the partial reductions to shared memory.
|
238 |
-
// NOTE(woosuk): It is required to zero out the masked logits.
|
239 |
-
const bool mask = token_idx >= context_len;
|
240 |
-
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
|
241 |
-
// Update the max value.
|
242 |
-
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
|
243 |
-
}
|
244 |
-
}
|
245 |
-
}
|
246 |
-
|
247 |
-
// Perform reduction across the threads in the same warp to get the
|
248 |
-
// max qk value for each "warp" (not across the thread block yet).
|
249 |
-
// The 0-th thread of each thread group already has its max qk value.
|
250 |
-
#pragma unroll
|
251 |
-
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
252 |
-
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
253 |
-
}
|
254 |
-
if (lane == 0) {
|
255 |
-
red_smem[warp_idx] = qk_max;
|
256 |
-
}
|
257 |
-
__syncthreads();
|
258 |
-
|
259 |
-
// TODO(woosuk): Refactor this part.
|
260 |
-
// Get the max qk value for the sequence.
|
261 |
-
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
262 |
-
#pragma unroll
|
263 |
-
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
264 |
-
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
265 |
-
}
|
266 |
-
// Broadcast the max qk value to all threads.
|
267 |
-
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
|
268 |
-
|
269 |
-
// Get the sum of the exp values.
|
270 |
-
float exp_sum = 0.f;
|
271 |
-
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
272 |
-
float val = __expf(logits[i] - qk_max);
|
273 |
-
logits[i] = val;
|
274 |
-
exp_sum += val;
|
275 |
-
}
|
276 |
-
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
|
277 |
-
|
278 |
-
// Compute softmax.
|
279 |
-
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
|
280 |
-
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
281 |
-
logits[i] *= inv_sum;
|
282 |
-
}
|
283 |
-
__syncthreads();
|
284 |
-
|
285 |
-
// If partitioning is enabled, store the max logit and exp_sum.
|
286 |
-
if (USE_PARTITIONING && thread_idx == 0) {
|
287 |
-
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
288 |
-
+ head_idx * max_num_partitions
|
289 |
-
+ partition_idx;
|
290 |
-
*max_logits_ptr = qk_max;
|
291 |
-
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
292 |
-
+ head_idx * max_num_partitions
|
293 |
-
+ partition_idx;
|
294 |
-
*exp_sums_ptr = exp_sum;
|
295 |
-
}
|
296 |
-
|
297 |
-
// Each thread will fetch 16 bytes from the value cache at a time.
|
298 |
-
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
299 |
-
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
300 |
-
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
301 |
-
#ifdef ENABLE_FP8_E5M2
|
302 |
-
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
303 |
-
#endif
|
304 |
-
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
305 |
-
|
306 |
-
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
307 |
-
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
|
308 |
-
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
|
309 |
-
|
310 |
-
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
|
311 |
-
float accs[NUM_ROWS_PER_THREAD];
|
312 |
-
#pragma unroll
|
313 |
-
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
314 |
-
accs[i] = 0.f;
|
315 |
-
}
|
316 |
-
|
317 |
-
scalar_t zero_value;
|
318 |
-
zero(zero_value);
|
319 |
-
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
320 |
-
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
321 |
-
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
322 |
-
// (e.g., kv_block_stride).
|
323 |
-
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
324 |
-
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
325 |
-
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
326 |
-
L_vec logits_vec;
|
327 |
-
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
|
328 |
-
|
329 |
-
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
330 |
-
+ kv_head_idx * kv_head_stride;
|
331 |
-
#pragma unroll
|
332 |
-
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
333 |
-
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
334 |
-
if (row_idx < HEAD_SIZE) {
|
335 |
-
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
336 |
-
V_vec v_vec;
|
337 |
-
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
338 |
-
#ifdef ENABLE_FP8_E5M2
|
339 |
-
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
340 |
-
// Vector conversion from V_quant_vec to V_vec.
|
341 |
-
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
|
342 |
-
#else
|
343 |
-
assert(false);
|
344 |
-
#endif
|
345 |
-
} else {
|
346 |
-
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
347 |
-
}
|
348 |
-
if (block_idx == num_context_blocks - 1) {
|
349 |
-
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
350 |
-
// we should explicitly zero out the values since they may contain NaNs.
|
351 |
-
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
352 |
-
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
|
353 |
-
#pragma unroll
|
354 |
-
for (int j = 0; j < V_VEC_SIZE; j++) {
|
355 |
-
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
|
356 |
-
}
|
357 |
-
}
|
358 |
-
accs[i] += dot(logits_vec, v_vec);
|
359 |
-
}
|
360 |
-
}
|
361 |
-
}
|
362 |
-
|
363 |
-
// Perform reduction within each warp.
|
364 |
-
#pragma unroll
|
365 |
-
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
366 |
-
float acc = accs[i];
|
367 |
-
#pragma unroll
|
368 |
-
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
369 |
-
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
|
370 |
-
}
|
371 |
-
accs[i] = acc;
|
372 |
-
}
|
373 |
-
|
374 |
-
// NOTE(woosuk): A barrier is required because the shared memory space for logits
|
375 |
-
// is reused for the output.
|
376 |
-
__syncthreads();
|
377 |
-
|
378 |
-
// Perform reduction across warps.
|
379 |
-
float* out_smem = reinterpret_cast<float*>(shared_mem);
|
380 |
-
#pragma unroll
|
381 |
-
for (int i = NUM_WARPS; i > 1; i /= 2) {
|
382 |
-
int mid = i / 2;
|
383 |
-
// Upper warps write to shared memory.
|
384 |
-
if (warp_idx >= mid && warp_idx < i) {
|
385 |
-
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
|
386 |
-
#pragma unroll
|
387 |
-
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
388 |
-
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
389 |
-
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
390 |
-
dst[row_idx] = accs[i];
|
391 |
-
}
|
392 |
-
}
|
393 |
-
}
|
394 |
-
__syncthreads();
|
395 |
-
|
396 |
-
// Lower warps update the output.
|
397 |
-
if (warp_idx < mid) {
|
398 |
-
const float* src = &out_smem[warp_idx * HEAD_SIZE];
|
399 |
-
#pragma unroll
|
400 |
-
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
401 |
-
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
402 |
-
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
403 |
-
accs[i] += src[row_idx];
|
404 |
-
}
|
405 |
-
}
|
406 |
-
}
|
407 |
-
__syncthreads();
|
408 |
-
}
|
409 |
-
|
410 |
-
// Write the final output.
|
411 |
-
if (warp_idx == 0) {
|
412 |
-
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
413 |
-
+ head_idx * max_num_partitions * HEAD_SIZE
|
414 |
-
+ partition_idx * HEAD_SIZE;
|
415 |
-
#pragma unroll
|
416 |
-
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
417 |
-
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
418 |
-
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
419 |
-
from_float(*(out_ptr + row_idx), accs[i]);
|
420 |
-
}
|
421 |
-
}
|
422 |
-
}
|
423 |
-
}
|
424 |
-
|
425 |
-
// Grid: (num_heads, num_seqs, 1).
|
426 |
-
template<
|
427 |
-
typename scalar_t,
|
428 |
-
typename cache_t,
|
429 |
-
int HEAD_SIZE,
|
430 |
-
int BLOCK_SIZE,
|
431 |
-
int NUM_THREADS,
|
432 |
-
bool IS_FP8_E5M2_KV_CACHE>
|
433 |
-
__global__ void paged_attention_v1_kernel(
|
434 |
-
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
435 |
-
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
436 |
-
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
437 |
-
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
438 |
-
const int num_kv_heads, // [num_heads]
|
439 |
-
const float scale,
|
440 |
-
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
441 |
-
const int* __restrict__ context_lens, // [num_seqs]
|
442 |
-
const int max_num_blocks_per_seq,
|
443 |
-
const float* __restrict__ alibi_slopes, // [num_heads]
|
444 |
-
const int q_stride,
|
445 |
-
const int kv_block_stride,
|
446 |
-
const int kv_head_stride) {
|
447 |
-
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
|
448 |
-
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
449 |
-
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
|
450 |
-
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
451 |
-
}
|
452 |
-
|
453 |
-
// Grid: (num_heads, num_seqs, max_num_partitions).
|
454 |
-
template<
|
455 |
-
typename scalar_t,
|
456 |
-
typename cache_t,
|
457 |
-
int HEAD_SIZE,
|
458 |
-
int BLOCK_SIZE,
|
459 |
-
int NUM_THREADS,
|
460 |
-
bool IS_FP8_E5M2_KV_CACHE,
|
461 |
-
int PARTITION_SIZE>
|
462 |
-
__global__ void paged_attention_v2_kernel(
|
463 |
-
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
464 |
-
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
465 |
-
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
466 |
-
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
467 |
-
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
468 |
-
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
469 |
-
const int num_kv_heads, // [num_heads]
|
470 |
-
const float scale,
|
471 |
-
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
472 |
-
const int* __restrict__ context_lens, // [num_seqs]
|
473 |
-
const int max_num_blocks_per_seq,
|
474 |
-
const float* __restrict__ alibi_slopes, // [num_heads]
|
475 |
-
const int q_stride,
|
476 |
-
const int kv_block_stride,
|
477 |
-
const int kv_head_stride) {
|
478 |
-
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
|
479 |
-
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
480 |
-
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
481 |
-
q_stride, kv_block_stride, kv_head_stride);
|
482 |
-
}
|
483 |
-
|
484 |
-
// Grid: (num_heads, num_seqs).
|
485 |
-
template<
|
486 |
-
typename scalar_t,
|
487 |
-
int HEAD_SIZE,
|
488 |
-
int NUM_THREADS,
|
489 |
-
int PARTITION_SIZE>
|
490 |
-
__global__ void paged_attention_v2_reduce_kernel(
|
491 |
-
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
492 |
-
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
493 |
-
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
494 |
-
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
495 |
-
const int* __restrict__ context_lens, // [num_seqs]
|
496 |
-
const int max_num_partitions) {
|
497 |
-
const int num_heads = gridDim.x;
|
498 |
-
const int head_idx = blockIdx.x;
|
499 |
-
const int seq_idx = blockIdx.y;
|
500 |
-
const int context_len = context_lens[seq_idx];
|
501 |
-
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
|
502 |
-
if (num_partitions == 1) {
|
503 |
-
// No need to reduce. Only copy tmp_out to out.
|
504 |
-
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
505 |
-
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
506 |
-
+ head_idx * max_num_partitions * HEAD_SIZE;
|
507 |
-
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
|
508 |
-
out_ptr[i] = tmp_out_ptr[i];
|
509 |
-
}
|
510 |
-
// Terminate the thread block.
|
511 |
-
return;
|
512 |
-
}
|
513 |
-
|
514 |
-
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
515 |
-
const int warp_idx = threadIdx.x / WARP_SIZE;
|
516 |
-
const int lane = threadIdx.x % WARP_SIZE;
|
517 |
-
|
518 |
-
// Size: 2 * num_partitions.
|
519 |
-
extern __shared__ char shared_mem[];
|
520 |
-
// Workspace for reduction.
|
521 |
-
__shared__ float red_smem[2 * NUM_WARPS];
|
522 |
-
|
523 |
-
// Load max logits to shared memory.
|
524 |
-
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
|
525 |
-
const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
526 |
-
+ head_idx * max_num_partitions;
|
527 |
-
float max_logit = -FLT_MAX;
|
528 |
-
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
|
529 |
-
const float l = max_logits_ptr[i];
|
530 |
-
shared_max_logits[i] = l;
|
531 |
-
max_logit = fmaxf(max_logit, l);
|
532 |
-
}
|
533 |
-
__syncthreads();
|
534 |
-
|
535 |
-
// Get the global max logit.
|
536 |
-
// Reduce within the warp.
|
537 |
-
#pragma unroll
|
538 |
-
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
539 |
-
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
540 |
-
}
|
541 |
-
if (lane == 0) {
|
542 |
-
red_smem[warp_idx] = max_logit;
|
543 |
-
}
|
544 |
-
__syncthreads();
|
545 |
-
// Reduce across warps.
|
546 |
-
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
547 |
-
#pragma unroll
|
548 |
-
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
549 |
-
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
550 |
-
}
|
551 |
-
// Broadcast the max value to all threads.
|
552 |
-
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
|
553 |
-
|
554 |
-
// Load rescaled exp sums to shared memory.
|
555 |
-
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
556 |
-
const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
557 |
-
+ head_idx * max_num_partitions;
|
558 |
-
float global_exp_sum = 0.0f;
|
559 |
-
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
|
560 |
-
float l = shared_max_logits[i];
|
561 |
-
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
|
562 |
-
global_exp_sum += rescaled_exp_sum;
|
563 |
-
shared_exp_sums[i] = rescaled_exp_sum;
|
564 |
-
}
|
565 |
-
__syncthreads();
|
566 |
-
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
|
567 |
-
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
|
568 |
-
|
569 |
-
// Aggregate tmp_out to out.
|
570 |
-
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
571 |
-
+ head_idx * max_num_partitions * HEAD_SIZE;
|
572 |
-
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
573 |
-
#pragma unroll
|
574 |
-
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
|
575 |
-
float acc = 0.0f;
|
576 |
-
for (int j = 0; j < num_partitions; ++j) {
|
577 |
-
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
|
578 |
-
}
|
579 |
-
from_float(out_ptr[i], acc);
|
580 |
-
}
|
581 |
-
}
|
582 |
-
|
583 |
-
} // namespace vllm
|
584 |
-
|
585 |
-
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
586 |
-
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
587 |
-
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
588 |
-
IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
|
589 |
-
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
590 |
-
IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
591 |
-
out_ptr, \
|
592 |
-
query_ptr, \
|
593 |
-
key_cache_ptr, \
|
594 |
-
value_cache_ptr, \
|
595 |
-
num_kv_heads, \
|
596 |
-
scale, \
|
597 |
-
block_tables_ptr, \
|
598 |
-
context_lens_ptr, \
|
599 |
-
max_num_blocks_per_seq, \
|
600 |
-
alibi_slopes_ptr, \
|
601 |
-
q_stride, \
|
602 |
-
kv_block_stride, \
|
603 |
-
kv_head_stride);
|
604 |
-
|
605 |
-
// TODO(woosuk): Tune NUM_THREADS.
|
606 |
-
template<
|
607 |
-
typename T,
|
608 |
-
typename CACHE_T,
|
609 |
-
int BLOCK_SIZE,
|
610 |
-
bool IS_FP8_E5M2_KV_CACHE,
|
611 |
-
int NUM_THREADS = 128>
|
612 |
-
void paged_attention_v1_launcher(
|
613 |
-
torch::Tensor& out,
|
614 |
-
torch::Tensor& query,
|
615 |
-
torch::Tensor& key_cache,
|
616 |
-
torch::Tensor& value_cache,
|
617 |
-
int num_kv_heads,
|
618 |
-
float scale,
|
619 |
-
torch::Tensor& block_tables,
|
620 |
-
torch::Tensor& context_lens,
|
621 |
-
int max_context_len,
|
622 |
-
const c10::optional<torch::Tensor>& alibi_slopes) {
|
623 |
-
int num_seqs = query.size(0);
|
624 |
-
int num_heads = query.size(1);
|
625 |
-
int head_size = query.size(2);
|
626 |
-
int max_num_blocks_per_seq = block_tables.size(1);
|
627 |
-
int q_stride = query.stride(0);
|
628 |
-
int kv_block_stride = key_cache.stride(0);
|
629 |
-
int kv_head_stride = key_cache.stride(1);
|
630 |
-
|
631 |
-
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
632 |
-
assert(head_size % thread_group_size == 0);
|
633 |
-
|
634 |
-
// NOTE: alibi_slopes is optional.
|
635 |
-
const float* alibi_slopes_ptr = alibi_slopes ?
|
636 |
-
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
637 |
-
: nullptr;
|
638 |
-
|
639 |
-
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
640 |
-
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
641 |
-
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
642 |
-
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
643 |
-
int* block_tables_ptr = block_tables.data_ptr<int>();
|
644 |
-
int* context_lens_ptr = context_lens.data_ptr<int>();
|
645 |
-
|
646 |
-
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
647 |
-
int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
|
648 |
-
int logits_size = padded_max_context_len * sizeof(float);
|
649 |
-
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
650 |
-
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
|
651 |
-
// Keep that in sync with the logic here!
|
652 |
-
int shared_mem_size = std::max(logits_size, outputs_size);
|
653 |
-
|
654 |
-
dim3 grid(num_heads, num_seqs, 1);
|
655 |
-
dim3 block(NUM_THREADS);
|
656 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
657 |
-
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
658 |
-
switch (head_size) {
|
659 |
-
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
660 |
-
// head sizes that we use in the model. However, we can easily extend this
|
661 |
-
// to support any head size which is a multiple of 16.
|
662 |
-
case 64:
|
663 |
-
LAUNCH_PAGED_ATTENTION_V1(64);
|
664 |
-
break;
|
665 |
-
case 80:
|
666 |
-
LAUNCH_PAGED_ATTENTION_V1(80);
|
667 |
-
break;
|
668 |
-
case 96:
|
669 |
-
LAUNCH_PAGED_ATTENTION_V1(96);
|
670 |
-
break;
|
671 |
-
case 112:
|
672 |
-
LAUNCH_PAGED_ATTENTION_V1(112);
|
673 |
-
break;
|
674 |
-
case 128:
|
675 |
-
LAUNCH_PAGED_ATTENTION_V1(128);
|
676 |
-
break;
|
677 |
-
case 256:
|
678 |
-
LAUNCH_PAGED_ATTENTION_V1(256);
|
679 |
-
break;
|
680 |
-
default:
|
681 |
-
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
682 |
-
break;
|
683 |
-
}
|
684 |
-
}
|
685 |
-
|
686 |
-
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
687 |
-
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
688 |
-
out, \
|
689 |
-
query, \
|
690 |
-
key_cache, \
|
691 |
-
value_cache, \
|
692 |
-
num_kv_heads, \
|
693 |
-
scale, \
|
694 |
-
block_tables, \
|
695 |
-
context_lens, \
|
696 |
-
max_context_len, \
|
697 |
-
alibi_slopes);
|
698 |
-
|
699 |
-
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
700 |
-
// 1, 2, 4, 64, 128, 256.
|
701 |
-
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
702 |
-
switch (block_size) { \
|
703 |
-
case 8: \
|
704 |
-
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
705 |
-
break; \
|
706 |
-
case 16: \
|
707 |
-
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
708 |
-
break; \
|
709 |
-
case 32: \
|
710 |
-
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
711 |
-
break; \
|
712 |
-
default: \
|
713 |
-
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
714 |
-
break; \
|
715 |
-
}
|
716 |
-
|
717 |
-
void paged_attention_v1(
|
718 |
-
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
719 |
-
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
720 |
-
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
721 |
-
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
722 |
-
int num_kv_heads, // [num_heads]
|
723 |
-
float scale,
|
724 |
-
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
725 |
-
torch::Tensor& context_lens, // [num_seqs]
|
726 |
-
int block_size,
|
727 |
-
int max_context_len,
|
728 |
-
const c10::optional<torch::Tensor>& alibi_slopes,
|
729 |
-
const std::string& kv_cache_dtype) {
|
730 |
-
if (kv_cache_dtype == "auto") {
|
731 |
-
if (query.dtype() == at::ScalarType::Float) {
|
732 |
-
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
733 |
-
} else if (query.dtype() == at::ScalarType::Half) {
|
734 |
-
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
735 |
-
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
736 |
-
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,           \
                                  NUM_THREADS, IS_FP8_E5M2_KV_CACHE,           \
                                  PARTITION_SIZE>                              \
    <<<grid, block, shared_mem_size, stream>>>(                                \
      exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr,     \
      value_cache_ptr, num_kv_heads, scale, block_tables_ptr,                  \
      context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,    \
      kv_block_stride, kv_head_stride);                                        \
  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,            \
                                         PARTITION_SIZE>                       \
    <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                  \
      out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, context_lens_ptr,    \
      max_num_partitions);

template<
  typename T,
  typename CACHE_T,
  int BLOCK_SIZE,
  bool IS_FP8_E5M2_KV_CACHE,
  int NUM_THREADS = 128,
  int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:  LAUNCH_PAGED_ATTENTION_V2(64);  break;
    case 80:  LAUNCH_PAGED_ATTENTION_V2(80);  break;
    case 96:  LAUNCH_PAGED_ATTENTION_V2(96);  break;
    case 112: LAUNCH_PAGED_ATTENTION_V2(112); break;
    case 128: LAUNCH_PAGED_ATTENTION_V2(128); break;
    case 256: LAUNCH_PAGED_ATTENTION_V2(256); break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)         \
  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>(   \
    out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,         \
    num_kv_heads, scale, block_tables, context_lens, max_context_len,          \
    alibi_slopes);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE)          \
  switch (block_size) {                                                        \
    case 8:  CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);  break;    \
    case 16: CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); break;    \
    case 32: CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); break;    \
    default:                                                                   \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);              \
      break;                                                                   \
  }

void paged_attention_v2(
  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
  int num_kv_heads,               // [num_heads]
  float scale,
  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
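For orientation only: the v2 launcher above splits each sequence into fixed-size partitions and sizes its dynamic shared memory as the larger of the per-partition logits buffer and the cross-warp reduction buffer. A minimal standalone CUDA C++ sketch of that arithmetic, using hypothetical example values (it is not part of the deleted file):

```cpp
#include <algorithm>
#include <cstdio>

// Mirrors the sizing logic in paged_attention_v2_launcher (sketch only).
int main() {
  constexpr int kPartitionSize = 512;  // PARTITION_SIZE template default
  constexpr int kNumThreads = 128;     // NUM_THREADS template default
  constexpr int kWarpSize = 32;
  const int max_context_len = 2000;    // hypothetical
  const int head_size = 128;           // hypothetical

  // DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE)
  const int max_num_partitions =
      (max_context_len + kPartitionSize - 1) / kPartitionSize;
  const int logits_size = static_cast<int>(kPartitionSize * sizeof(float));
  const int num_warps = kNumThreads / kWarpSize;
  const int outputs_size =
      static_cast<int>((num_warps / 2) * head_size * sizeof(float));
  const int shared_mem_size = std::max(logits_size, outputs_size);

  std::printf("partitions=%d shared_mem=%d bytes\n",
              max_num_partitions, shared_mem_size);
  return 0;
}
```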
csrc/attention/attention_utils.cuh
DELETED
@@ -1,56 +0,0 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "../cuda_compat.h"
#include "attention_dtypes.h"

#include <float.h>
#include <type_traits>

namespace vllm {

// Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  using A_vec = typename FloatVec<Vec>::Type;
  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction across lanes.
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
  }
  return qk;
}

template<typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
  template<typename Vec, int N>
  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
  }
};

} // namespace vllm
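The deleted qk_dot_ helper finishes with an XOR butterfly reduction across the thread group. A minimal standalone device sketch of that pattern follows; the raw __shfl_xor_sync call stands in for the repository's VLLM_SHFL_XOR_SYNC wrapper, and the function name is hypothetical:

```cpp
// Sketch: sum one per-lane value across a power-of-two thread group
// using XOR butterfly shuffles, as qk_dot_ does for the Q*K^T partial sums.
template <int THREAD_GROUP_SIZE>
__device__ float group_sum(float v) {
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    // Stand-in for VLLM_SHFL_XOR_SYNC; exchanges values between lanes
    // whose indices differ by `mask`.
    v += __shfl_xor_sync(0xffffffff, v, mask);
  }
  return v;  // every lane in the group now holds the group total
}
```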
csrc/attention/dtype_bfloat16.cuh
DELETED
@@ -1,451 +0,0 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

  typedef __hip_bfloat162 __nv_bfloat162;
  typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace vllm {

// Define custom BF16 vector data types.
struct bf16_4_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
};

struct bf16_8_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
};

// BF16 vector types for Q, K, V.
template<> struct Vec<__nv_bfloat16, 1> { using Type = __nv_bfloat16; };
template<> struct Vec<__nv_bfloat16, 2> { using Type = __nv_bfloat162; };
template<> struct Vec<__nv_bfloat16, 4> { using Type = bf16_4_t; };
template<> struct Vec<__nv_bfloat16, 8> { using Type = bf16_8_t; };

// FP32 accumulator vector types corresponding to Vec.
template<> struct FloatVec<__nv_bfloat16> { using Type = float; };
template<> struct FloatVec<__nv_bfloat162> { using Type = float2; };
template<> struct FloatVec<bf16_4_t> { using Type = Float4_; };
template<> struct FloatVec<bf16_8_t> { using Type = Float8_; };

// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat1622float2(val);
#endif
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat162bfloat162(val);
#endif
}

// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  #ifndef USE_ROCM
    return a + b;
  #else
    return __hadd(a, b);
  #endif
#endif
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hadd2(a, b);
#endif
}

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
  float2 fa = bf1622float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul(a, b);
#endif
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul2(a, b);
#endif
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return c;
}

template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return c;
}

template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return c;
}

template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  float fa = __bfloat162float(a);
  float fb = __bfloat162float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(a, b, c);
#endif
}

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(bf162bf162(a), b, c);
#endif
}

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
  bf16_4_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
  bf16_8_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
  return __bfloat162float(a) * __bfloat162float(b) + fc;
}

inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
  return fma(bf162bf162(a), b, fc);
}

inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template<>
inline __device__ float sum(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

template<>
inline __device__ float sum(__nv_bfloat162 v) {
  float2 vf = bf1622float2(v);
  return vf.x + vf.y;
}

template<>
inline __device__ float sum(bf16_4_t v) {
  return sum(v.x) + sum(v.y);
}

template<>
inline __device__ float sum(bf16_8_t v) {
  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}

// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst = __float22bfloat162_rn(src);
#endif
}

inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
#endif
}

inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
#endif
}

// From bfloat16 to float32.
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}

// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}

} // namespace vllm
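These dtype headers were consumed through the Vec/FloatVec traits: a kernel picks a packed load type from the element type and vector width, and the matching FP32 accumulator type. A small hypothetical illustration of how that compile-time selection reads (assuming the deleted headers are included; the function name and usage are illustrative, not taken from the repository):

```cpp
// Hypothetical usage of the trait pattern defined above.
template <typename scalar_t, int VEC_SIZE>
__device__ float accumulate_dot(const scalar_t* q, const scalar_t* k) {
  using QVec = typename vllm::Vec<scalar_t, VEC_SIZE>::Type;  // e.g. bf16_8_t
  using AccVec = typename vllm::FloatVec<QVec>::Type;         // e.g. Float8_
  QVec qv = *reinterpret_cast<const QVec*>(q);
  QVec kv = *reinterpret_cast<const QVec*>(k);
  AccVec prod = vllm::mul<AccVec, QVec, QVec>(qv, kv);        // widen to FP32
  return vllm::sum(prod);                                     // horizontal sum
}
```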
csrc/attention/dtype_float16.cuh
DELETED
@@ -1,502 +0,0 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
  #include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {

// FP16 vector types for Q, K, V.
template<> struct Vec<uint16_t, 1> { using Type = uint16_t; };
template<> struct Vec<uint16_t, 2> { using Type = uint32_t; };
template<> struct Vec<uint16_t, 4> { using Type = uint2; };
template<> struct Vec<uint16_t, 8> { using Type = uint4; };

// FP32 accumulator vector types corresponding to Vec.
template<> struct FloatVec<uint16_t> { using Type = float; };
template<> struct FloatVec<uint32_t> { using Type = float2; };
template<> struct FloatVec<uint2> { using Type = Float4_; };
template<> struct FloatVec<uint4> { using Type = Float8_; };

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u16[0] = a;
  tmp.u16[1] = a;
  return tmp.u32;
#endif
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  float2 ret;
  ret.x = half_to_float(tmp.u16[0]);
  ret.y = half_to_float(tmp.u16[1]);
  return ret;
#endif
}

inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}

inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
  #else
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
  #endif
#else
  tmp.u16[0] = float_to_half(f.x);
  tmp.u16[1] = float_to_half(f.y);
#endif
  return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template<>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template<>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template<>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template<>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}

// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
inline __device__ float to_float(uint16_t u) {
  return half_to_float(u);
}

inline __device__ float2 to_float(uint32_t u) {
  return half2_to_float2(u);
}

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}

// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) {
  dst = uint16_t(0);
}

} // namespace vllm
csrc/attention/dtype_float32.cuh
DELETED
@@ -1,273 +0,0 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>

namespace vllm {

// Define custom FP32 vector data types.
struct Float4_ {
  float2 x;
  float2 y;
};

struct Float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};

// FP32 vector types for Q, K, V.
template<> struct Vec<float, 1> { using Type = float; };
template<> struct Vec<float, 2> { using Type = float2; };
template<> struct Vec<float, 4> { using Type = float4; };

// FP32 accumulator vector types corresponding to Vec.
template<> struct FloatVec<float> { using Type = float; };
template<> struct FloatVec<float2> { using Type = float2; };
template<> struct FloatVec<float4> { using Type = float4; };

// Vector addition.
inline __device__ float add(float a, float b) {
  return a + b;
}

inline __device__ float2 add(float2 a, float2 b) {
  float2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ float4 add(float4 a, float4 b) {
  float4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template<>
inline __device__ float2 mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template<>
inline __device__ float2 mul(float a, float2 b) {
  float2 c;
  c.x = a * b.x;
  c.y = a * b.y;
  return c;
}

template<>
inline __device__ float4 mul(float4 a, float4 b) {
  float4 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  c.z = a.z * b.z;
  c.w = a.w * b.w;
  return c;
}

template<>
inline __device__ float4 mul(float a, float4 b) {
  float4 c;
  c.x = a * b.x;
  c.y = a * b.y;
  c.z = a * b.z;
  c.w = a * b.w;
  return c;
}

// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) {
  return a * b + c;
}

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  Float4_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

// Vector sum.
template<>
inline __device__ float sum(float v) {
  return v;
}

template<>
inline __device__ float sum(float2 v) {
  return v.x + v.y;
}

template<>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;
}

template<>
inline __device__ float sum(Float4_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y;
}

template<>
inline __device__ float sum(Float8_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product.
inline __device__ float dot(float a, float b) {
  return a * b;
}

inline __device__ float dot(float2 a, float2 b) {
  float2 c = mul<float2, float2, float2>(a, b);
  return c.x + c.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  return acc.x + acc.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  acc = fma(a.z, b.z, acc);
  acc = fma(a.w, b.w, acc);
  return acc.x + acc.y;
}

// From float to float.
inline __device__ void from_float(float& dst, float src) {
  dst = src;
}

inline __device__ void from_float(float2& dst, float2 src) {
  dst = src;
}

inline __device__ void from_float(float4& dst, float4 src) {
  dst = src;
}

// From float to float.
inline __device__ float to_float(float u) {
  return u;
}

inline __device__ float2 to_float(float2 u) {
  return u;
}

inline __device__ float4 to_float(float4 u) {
  return u;
}

inline __device__ Float4_ to_float(Float4_ u) {
  return u;
}

inline __device__ Float8_ to_float(Float8_ u) {
  return u;
}

// Zero-out a variable.
inline __device__ void zero(float& dst) {
  dst = 0.f;
}

} // namespace vllm
csrc/attention/dtype_fp8_e5m2.cuh
DELETED
@@ -1,35 +0,0 @@

#pragma once

#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif

namespace vllm {
#ifdef ENABLE_FP8_E5M2
// fp8 vector types for quantization of kv cache

template<>
struct Vec<uint8_t, 1> {
  using Type = uint8_t;
};

template<>
struct Vec<uint8_t, 2> {
  using Type = uint16_t;
};

template<>
struct Vec<uint8_t, 4> {
  using Type = uint32_t;
};

template<>
struct Vec<uint8_t, 8> {
  using Type = uint2;
};
#endif // ENABLE_FP8_E5M2

} // namespace vllm
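As a side note, the `Vec<uint8_t, N>::Type` aliases above simply pick an unsigned integer wide enough to carry N packed fp8 values, so vectorized loads and stores stay aligned. A hedged illustration (the function name is illustrative, not from the deleted code, and assumes ENABLE_FP8_E5M2 is defined):

```cuda
// Illustrative only: four fp8 (1-byte) values travel as a single uint32_t.
using Quant4 = vllm::Vec<uint8_t, 4>::Type;  // uint32_t
__device__ void copy_quantized4(const Quant4* src, Quant4* dst, int idx) {
  dst[idx] = src[idx];  // moves four quantized KV-cache elements at once
}
```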
csrc/cache.h
DELETED
@@ -1,36 +0,0 @@

#pragma once

#include <torch/extension.h>

#include <map>
#include <vector>

void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping);

void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping);

void reshape_and_cache(
  torch::Tensor& key,
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping,
  const std::string& kv_cache_dtype);

void gather_cached_kv(
  torch::Tensor& key,
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping);

// Just for unittest
void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache);
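These declarations are the cache-management surface that was exposed through the torch extension. A rough host-side sketch of how the block mapping is shaped (the helper name and the concrete block numbers are illustrative assumptions, not from the deleted code):

```cuda
// Hypothetical host-side call against the swap_blocks declaration above.
#include <map>
#include <torch/extension.h>

void example_swap(torch::Tensor& gpu_cache, torch::Tensor& cpu_cache) {
  // Copy GPU block 0 into CPU block 3 and GPU block 1 into CPU block 7.
  const std::map<int64_t, int64_t> block_mapping = {{0, 3}, {1, 7}};
  swap_blocks(gpu_cache, cpu_cache, block_mapping);
}
```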
csrc/cache_kernels.cu
DELETED
@@ -1,474 +0,0 @@

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}

namespace vllm {

// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace vllm

void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }
  // Create block mapping array.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  int64_t* block_mapping_array = block_mapping_vec.data();
  int num_pairs = block_mapping_vec.size() / 2;

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
}

namespace vllm {

template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,         // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,       // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,          // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,        // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping, // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#else
      assert(false);
#endif
    } else {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    }
  }
}

} // namespace vllm

#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<KV_T*>(key.data_ptr()), \
    reinterpret_cast<KV_T*>(value.data_ptr()), \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
    slot_mapping.data_ptr<int64_t>(), \
    key_stride, \
    value_stride, \
    num_heads, \
    head_size, \
    block_size, \
    x);

void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

namespace vllm {

// Grid: (num_blocks, block_size).
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,               // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,             // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,     // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int num_tokens = num_heads * head_size;
  for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
    const int tgt_key_idx = token_idx * key_stride + i;
    const int tgt_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
    const int x_offset = head_offset % x;

    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                            + head_idx * (head_size / x) * block_size * x
                            + x_idx * block_size * x
                            + block_offset * x
                            + x_offset;
    const int src_value_idx = block_idx * num_heads * head_size * block_size
                              + head_idx * head_size * block_size
                              + head_offset * block_size
                              + block_offset;

    key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
    value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
  }
}

template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
  scalar_t *__restrict__ key,               // [num_tokens, [stride], num_heads, head_size]
  scalar_t *__restrict__ value,             // [num_tokens, [stride], num_heads, head_size]
  const scalar_t *__restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
  const int *__restrict__ slot_mapping,     // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x)
{
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int dim = num_heads * head_size;
  assert(dim % 4 == 0);  // this is true for known use cases
  const int unroll_factor = 4;
  const int unrolled_dim = dim / unroll_factor;

  for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
  {
    int tgt_key_indices[unroll_factor];
    int tgt_value_indices[unroll_factor];
    int src_key_indices[unroll_factor];
    int src_value_indices[unroll_factor];
    scalar_t keys_to_store[unroll_factor];
    scalar_t values_to_store[unroll_factor];

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j)
    {
      int index = i + j * unrolled_dim;

      const int tgt_key_idx = token_idx * key_stride + index;
      const int tgt_value_idx = token_idx * value_stride + index;

      const int head_idx = index / head_size;
      const int head_offset = index % head_size;
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;

      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                              + head_idx * (head_size / x) * block_size * x
                              + x_idx * block_size * x
                              + block_offset * x
                              + x_offset;
      const int src_value_idx = block_idx * num_heads * head_size * block_size
                                + head_idx * head_size * block_size
                                + head_offset * block_size
                                + block_offset;

      tgt_key_indices[j] = tgt_key_idx;
      tgt_value_indices[j] = tgt_value_idx;
      src_key_indices[j] = src_key_idx;
      src_value_indices[j] = src_value_idx;

      keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
      values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
    }

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j)
    {
      key[tgt_key_indices[j]] = keys_to_store[j];
      value[tgt_value_indices[j]] = values_to_store[j];
    }
  }
}

} // namespace vllm

void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
      vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}

namespace vllm {

template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace vllm

#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
  vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
    reinterpret_cast<Tin*>(src_cache.data_ptr()), \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
    block_stride);

void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(uint8_t, float);
  } else if (src_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(float, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
  }
}
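The kernels in this file share one addressing scheme: a flat slot index splits into a block index and an offset inside that block, and the key cache further splits head_size into head_size/x chunks of width x. A small worked example with illustrative numbers (block_size = 16): slot 37 lands in block 2 at offset 5. The helper below only restates that arithmetic; it is a sketch, not code from the deleted file.

```cuda
// Illustrative index math mirroring the kernels above.
__host__ __device__ void locate_slot(int64_t slot_idx, int block_size,
                                     int64_t* block_idx, int64_t* block_offset) {
  *block_idx = slot_idx / block_size;     // e.g. 37 / 16 == 2
  *block_offset = slot_idx % block_size;  // e.g. 37 % 16 == 5
}
```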
csrc/cuda_compat.h
DELETED
@@ -1,28 +0,0 @@

#pragma once

#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
#define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
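These portability macros let the same kernel source build for both CUDA and ROCm. A minimal sketch of the usual pattern they enable (the reduction below is illustrative, not code from the deleted files, and assumes a 32-lane warp):

```cuda
// Warp-wide sum via the portability macro; expands to __shfl_xor_sync on CUDA
// and __shfl_xor on ROCm.
template <typename T>
__device__ T warp_sum(T val) {
  for (int mask = 16; mask > 0; mask >>= 1) {
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return val;  // every lane ends up holding the full warp sum
}
```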
csrc/cuda_utils.h
DELETED
@@ -1,10 +0,0 @@

#pragma once

#include <torch/extension.h>

int get_device_attribute(
    int attribute,
    int device_id);

int get_max_shared_memory_per_block_device_attribute(
    int device_id);
csrc/cuda_utils_kernels.cu
DELETED
@@ -1,35 +0,0 @@

#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
  #include <hip/hip_runtime_api.h>
#endif
int get_device_attribute(
    int attribute,
    int device_id)
{
    int device, value;
    if (device_id < 0) {
        cudaGetDevice(&device);
    }
    else {
        device = device_id;
    }
    cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
    return value;
}


int get_max_shared_memory_per_block_device_attribute(
    int device_id)
{
    int attribute;
    // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
    // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74

#ifdef USE_ROCM
    attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else
    attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif

    return get_device_attribute(attribute, device_id);
}
csrc/custom_all_reduce.cu
DELETED
@@ -1,148 +0,0 @@

#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include "custom_all_reduce.cuh"

// fake pointer type
using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t));

fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int rank,
                      bool full_nvlink) {
  int world_size = offsets.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (world_size != handles.size())
    throw std::invalid_argument(
        "handles length should equal to offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  cudaIpcMemHandle_t ipc_handles[8];
  for (int i = 0; i < world_size; i++) {
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return (fptr_t) new vllm::CustomAllreduce(
      reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}

/**
 * Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
 * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
 * because it allows transpose of contiguous slice (i.e. slicing the first
 * dimension). Currently, we require this because stride information is not
 * passed into the kernels and we treat input tensors as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor &t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}

bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
                      bool full_nvlink) {
  auto inp_size = inp.numel() * inp.element_size();
  // custom allreduce requires input byte size to be multiples of 16
  if (inp_size % 16 != 0) return false;
  if (!_is_weak_contiguous(inp)) return false;
  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
  // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
  // <= 512k
  return world_size <= 4 && inp_size <= 512 * 1024;
}

void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
                 cudaStream_t stream) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
                           reinterpret_cast<float *>(out.data_ptr()),
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
                          reinterpret_cast<half *>(out.data_ptr()),
                          out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
          reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}

void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  _all_reduce(_fa, inp, out, stream);
}

void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  auto input_size = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
              "registered buffer is too small to contain the input");
  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                                input_size, cudaMemcpyDeviceToDevice, stream));
  _all_reduce(_fa, reg_buffer, out, stream);
}

void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  delete fa;
}

int meta_size() { return sizeof(vllm::Metadata); }

void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
}

std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  return fa->get_graph_buffer_ipc_meta();
}

void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  fa->register_graph_buffers(handles, offsets);
}
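Taken together, these bindings imply a three-step life cycle on each rank: create the handle, register the communication buffer, then run reductions through it. A hedged host-side sketch follows; the wrapper name and the way the IPC handles/offsets are exchanged between ranks are assumptions, only the called functions come from the file above.

```cuda
// Hypothetical per-rank flow against the bindings above; all tensors and the
// exchanged IPC handles/offsets are assumed to be prepared by the caller.
void run_custom_allreduce_once(torch::Tensor &meta, torch::Tensor &rank_data,
                               torch::Tensor &reg_buffer, torch::Tensor &inp,
                               torch::Tensor &out,
                               const std::vector<std::string> &ipc_handles,
                               const std::vector<int64_t> &handle_offsets,
                               int rank, int world_size, int max_size) {
  fptr_t fa = init_custom_ar(meta, rank_data, ipc_handles, handle_offsets,
                             rank, /*full_nvlink=*/true);
  register_buffer(fa, reg_buffer, ipc_handles, handle_offsets);
  if (should_custom_ar(inp, max_size, world_size, /*full_nvlink=*/true)) {
    // Unregistered input: staged through the registered buffer first.
    all_reduce_unreg(fa, inp, reg_buffer, out);
  }
  dispose(fa);
}
```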
csrc/custom_all_reduce.cuh
DELETED
@@ -1,562 +0,0 @@

#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

#define CUDACHECK(cmd) \
  do { \
    cudaError_t e = cmd; \
    if (e != cudaSuccess) { \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e)); \
      exit(EXIT_FAILURE); \
    } \
  } while (0)

namespace vllm {

struct Signal {
  alignas(64) union {
    uint64_t flag;
    unsigned char data[8];
  } start;
  alignas(64) union {
    uint64_t flag;
    unsigned char data[8];
  } end;
};

struct Metadata {
  alignas(128) Signal sg;
  alignas(128) int counter;
};
static_assert(offsetof(Metadata, counter) == 128);
static_assert(sizeof(Metadata) == 256);

struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };

struct RankSignals {
  volatile Signal *signals[8];
};

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};

#define DINLINE __device__ __forceinline__

// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float &assign_add(float &a, float b) { return a += b; }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}

// compute flag at compile time
__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
  auto m = std::numeric_limits<uint64_t>::max();
  return m >> ((8 - ngpus) * 8);
}

template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
                        int rank) {
  constexpr auto FLAG = compute_flag(ngpus);
  if (blockIdx.x == 0) {
    if (threadIdx.x < ngpus)
      // simultaneously write to the corresponding byte to all other ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->start.data[rank] = 255;
    else if (threadIdx.x == 32)
      // reset
      meta->sg.end.flag = 0;
  }
  if (threadIdx.x == 0) {
    while (meta->sg.start.flag != FLAG)
      ;
  }
  __syncthreads();
}

template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
                      int rank) {
  constexpr auto FLAG = compute_flag(ngpus);
  __syncthreads();
  __shared__ int num;
  if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
  __syncthreads();

  // Only the last completing block can perform the end synchronization
  // This can ensures when the final busy wait ends, all ranks must have
  // finished reading each other's buffer.
  if (num == gridDim.x - 1) {
    if (threadIdx.x == 32) {
      // reset in a different warp
      meta->counter = 0;
      meta->sg.start.flag = 0;
    } else if (threadIdx.x < ngpus) {
      // simultaneously write to the corresponding byte to all other ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->end.data[rank] = 255;
    }
    // if this is the final sync, only one block needs it
    // because kernel exit can serve as sync
    if constexpr (final_sync) {
      if (threadIdx.x == 0) {
        while (meta->sg.end.flag != FLAG)
          ;
      }
    }
  }
  if constexpr (!final_sync) {
    if (threadIdx.x == 0) {
      while (meta->sg.end.flag != FLAG)
        ;
    }
    __syncthreads();
  }
}

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
                               volatile Metadata *meta, T *__restrict__ result,
                               int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  start_sync<ngpus>(sg, meta, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P *)result)[idx] =
        packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, meta, rank);
}

template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) {
  return (P *)(((Metadata *)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
                               volatile Metadata *meta, T *__restrict__ result,
                               int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  const P *ptrs[ngpus];
  P *tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P *)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  start_sync<ngpus>(sg, meta, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  // Maybe TODO: replace this with per-block release-acquire
  // can save about 1-2us (not a lot though)
  end_sync<ngpus>(sg, meta, rank);

  // stage 2: allgather
  for (int idx = tid; idx < part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int dst_idx = ((rank + i) % ngpus) * part + idx;
      ((P *)result)[dst_idx] = tmps[i][idx];
    }
  }
  // process the last larger partition
  int remaining = size - part * ngpus;
  if (tid < remaining) {
    int dst_idx = tid + part * ngpus;
    ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
  }

  // faster than this
  // for (int idx = tid; idx < size; idx += stride) {
  //   int target_rank = idx / part;
  //   if (target_rank == ngpus) target_rank -= 1;
  //   ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
  // }
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
                                       volatile Metadata *meta,
                                       T *__restrict__ result, int rank,
                                       int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
  constexpr int hg = ngpus / 2;
  // Actually not quite half butterfly.
  // This is an all-to-all within each group containing half of the ranks
  // followed by cross-group add. Equivalent to half butterfly when there
  // are 4 GPUs, a common case for PCIe cards like T4 and A10.
  const P *ptrs[hg];
  {
    int start = rank - rank % hg;
#pragma unroll
    for (int i = 0; i < hg; i++) {
      ptrs[i] = (const P *)_dp->ptrs[i + start];
    }
  }
  start_sync<ngpus>(sg, meta, rank);
  for (int idx = tid; idx < size; idx += stride) {
    tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
  }
  end_sync<ngpus>(sg, meta, rank);

  auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
  // do the cross group reduction
  for (int idx = tid; idx < size; idx += stride) {
    auto tmp = tmp_out[idx];
    packed_assign_add(tmp, src[idx]);
    ((P *)result)[idx] = tmp;
  }
}

using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  std::unordered_map<void *, RankData *> buffers_;
  Metadata *meta_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void *> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char *> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.
   *
   * There's a total of sizeof(Metadata) of prefix before the actual data,
   * so meta + 1 points to actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
  CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t *handles,
                  const std::vector<int64_t> &offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        meta_(meta),
        d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      Metadata *rank_meta;
      if (i != rank_) {
        char *handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_meta = (Metadata *)handle;
      } else {
        rank_meta = meta_;
      }
      sg_.signals[i] = &rank_meta->sg;
    }
  }

  char *open_ipc_handle(const void *ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
    if (new_handle) {
      char *ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t *)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void *base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr,
                                CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char *)ptr) - ((char *)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  void register_buffer(const std::vector<std::string> &handles,
                       const std::vector<int64_t> &offsets, void *self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char *handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[self] = d_data;
  }

  // note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string> &handles,
      const std::vector<std::vector<int64_t>> &offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto &rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char *handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * This is the result after careful grid search. Using 36 blocks give the best
   * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
   * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
   * Not quite sure the underlying reason, but my guess is that too many SMs
   * will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T *input, T *output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));

    RankData *ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
  name<T, ngpus> \
      <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
#define REDUCE_CASE(ngpus) \
  case ngpus: { \
    if (world_size_ == 2) { \
      KL(ngpus, cross_device_reduce_1stage); \
    } else if (full_nvlink_) { \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
          (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage); \
      } else { \
        KL(ngpus, cross_device_reduce_2stage); \
      } \
    } else { \
      KL(ngpus, cross_device_reduce_half_butterfly); \
    } \
    break; \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }
};
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
 a template instantiation:
 * template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
 int, int, int);
 */
}  // namespace vllm
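One detail worth spelling out from the header above: compute_flag(ngpus) produces the value the busy-wait loops compare against, namely a 64-bit word whose low ngpus bytes are 0xFF, which is exactly what a signal word looks like once every rank has written 255 into its own byte. The checks below only restate that arithmetic; they are illustrative and not part of the deleted file.

```cuda
// Sanity check of the flag arithmetic above (assumes custom_all_reduce.cuh is included).
static_assert(vllm::compute_flag(2) == 0xFFFFULL, "2 ranks -> low 2 bytes set");
static_assert(vllm::compute_flag(8) == 0xFFFFFFFFFFFFFFFFULL, "8 ranks -> all 8 bytes set");
```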
csrc/custom_all_reduce_test.cu
DELETED
@@ -1,284 +0,0 @@
/**
 * This is a standalone test for custom allreduce.
 * To compile, make sure you have MPI and NCCL installed in your system.
 * export MPI_HOME=XXX
 * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
 * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
 *
 * Warning: this C++ test is not designed to be very readable and was used
 * during the rapid prototyping process.
 *
 * To run:
 * mpirun -np 8 ./custom_all_reduce_test
 */
#include <cuda.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>

#include <limits>
#include <vector>

#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "nccl.h"

#define MPICHECK(cmd)                                                  \
  do {                                                                 \
    int e = cmd;                                                       \
    if (e != MPI_SUCCESS) {                                            \
      printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
      exit(EXIT_FAILURE);                                              \
    }                                                                  \
  } while (0)

#define NCCLCHECK(cmd)                                              \
  do {                                                              \
    ncclResult_t r = cmd;                                           \
    if (r != ncclSuccess) {                                         \
      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
             ncclGetErrorString(r));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

__global__ void dummy_kernel() {
  for (int i = 0; i < 100; i++) __nanosleep(1000000);  // 100ms
}

template <typename T>
__global__ void set_data(T *data, int size, int myRank) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    data[idx] = myRank * 0.11f;
  }
}

template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
                             double *fdata2, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    fdata1[idx] = data1[idx];
    fdata2[idx] = data2[idx];
  }
}

__global__ void init_rand(curandState_t *state, int size, int nRanks) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    for (int i = 0; i < nRanks; i++) {
      curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
    }
  }
}

template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
                         int myRank, int nRanks, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    double sum = 0.0;
    for (int i = 0; i < nRanks; i++) {
      double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
      T hval = val;  // downcast first
      sum += static_cast<double>(hval);
      if (i == myRank) data[idx] = hval;
    }
    ground_truth[idx] = sum;
  }
}

template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
         int data_size) {
  T *result;
  cudaStream_t stream;
  CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
  CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));

  cudaIpcMemHandle_t self_data_handle;
  cudaIpcMemHandle_t data_handles[8];
  vllm::Metadata *buffer;
  T *self_data_copy;
  /**
   * Allocate IPC buffer
   *
   * The first section is a temporary buffer for storing intermediate allreduce
   * results, if a particular algorithm requires it. The second section is for
   * the input to the allreduce. The actual API takes the input pointer as an
   * argument (that is, they can and usually should be allocated separately).
   * But since the input pointers and the temporary buffer all require IPC
   * registration, they are allocated and registered together in the test for
   * convenience.
   */
  CUDACHECK(
      cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
  CUDACHECK(cudaMemset(buffer, 0,
                       2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
  CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
  CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));

  MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, MPI_COMM_WORLD));

  void *rank_data;
  size_t rank_data_sz = 16 * 1024 * 1024;
  CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
  std::vector<int64_t> offsets(nRanks, 0);
  vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
                           offsets, myRank);
  auto *self_data =
      reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
                            sizeof(vllm::Metadata) + data_size * sizeof(T));
  // hack buffer registration
  {
    std::vector<std::string> handles;
    handles.reserve(nRanks);
    for (int i = 0; i < nRanks; i++) {
      char *begin = (char *)&data_handles[i];
      char *end = (char *)&data_handles[i + 1];
      handles.emplace_back(begin, end);
    }
    std::vector<int64_t> offsets(
        nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T));
    fa.register_buffer(handles, offsets, self_data);
  }

  double *ground_truth;
  CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
  curandState_t *states;
  CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
  init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
  gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
                                        nRanks, data_size);
  CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
                            cudaMemcpyDeviceToDevice, stream));
  cudaEvent_t start, stop;
  CUDACHECK(cudaEventCreate(&start));
  CUDACHECK(cudaEventCreate(&stop));

  ncclDataType_t ncclDtype;
  if (std::is_same<T, half>::value) {
    ncclDtype = ncclFloat16;
  } else if (std::is_same<T, nv_bfloat16>::value) {
    ncclDtype = ncclBfloat16;
  } else {
    ncclDtype = ncclFloat;
  }

  dummy_kernel<<<1, 1, 0, stream>>>();
  constexpr int warmup_iters = 5;
  constexpr int num_iters = 25;
  // warmup
  for (int i = 0; i < warmup_iters; i++) {
    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
                            stream));
  }
  CUDACHECK(cudaEventRecord(start, stream));
  for (int i = 0; i < num_iters; i++) {
    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
                            stream));
  }
  CUDACHECK(cudaEventRecord(stop, stream));
  CUDACHECK(cudaStreamSynchronize(stream));
  float allreduce_ms = 0;
  cudaEventElapsedTime(&allreduce_ms, start, stop);

  // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
  // set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);

  dummy_kernel<<<1, 1, 0, stream>>>();
  // warm up
  for (int i = 0; i < warmup_iters; i++) {
    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
  }
  CUDACHECK(cudaEventRecord(start, stream));
  for (int i = 0; i < num_iters; i++) {
    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
  }
  CUDACHECK(cudaEventRecord(stop, stream));
  CUDACHECK(cudaStreamSynchronize(stream));

  float duration_ms = 0;
  cudaEventElapsedTime(&duration_ms, start, stop);
  if (myRank == 0)
    printf(
        "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
        "time:%.2fus\n",
        myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
        duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);

  // And wait for all the queued up work to complete
  CUDACHECK(cudaStreamSynchronize(stream));

  NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
                          ncclSum, comm, stream));

  double *nccl_result, *my_result;
  CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
  CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));

  convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
                                            my_result, data_size);
  CUDACHECK(cudaStreamSynchronize(stream));

  for (unsigned long j = 0; j < data_size; j++) {
    auto diff = abs(nccl_result[j] - my_result[j]);
    if (diff >= 1e-2) {
      printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
             myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
      break;
    }
  }

  long double nccl_diffs = 0.0;
  long double my_diffs = 0.0;
  for (int j = 0; j < data_size; j++) {
    nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
    my_diffs += abs(my_result[j] - ground_truth[j]);
  }
  if (myRank == 0)
    std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
              << " me: " << my_diffs / data_size << std::endl;

  CUDACHECK(cudaFree(result));
  CUDACHECK(cudaFree(self_data_copy));
  CUDACHECK(cudaFree(rank_data));
  CUDACHECK(cudaFree(buffer));
  CUDACHECK(cudaFree(states));
  CUDACHECK(cudaFreeHost(ground_truth));
  CUDACHECK(cudaFreeHost(nccl_result));
  CUDACHECK(cudaFreeHost(my_result));
  CUDACHECK(cudaStreamDestroy(stream));
}

int main(int argc, char **argv) {
  int nRanks, myRank;
  MPICHECK(MPI_Init(&argc, &argv));
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
  MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
  CUDACHECK(cudaSetDevice(myRank));
  ncclUniqueId id;
  ncclComm_t comm;
  if (myRank == 0) ncclGetUniqueId(&id);
  MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
                     MPI_COMM_WORLD));
  NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

  cudaProfilerStart();
  // for (int threads : {256, 512}) {
  //   for (int block_limit = 16; block_limit < 112; block_limit += 4) {
  //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
  //   }
  // }
  for (int sz = 512; sz <= (32 << 20); sz *= 2) {
    run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
  }

  cudaProfilerStop();
  return EXIT_SUCCESS;
}
csrc/dispatch_utils.h
DELETED
@@ -1,37 +0,0 @@
/*
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
#pragma once

#include <torch/extension.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)            \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
  AT_DISPATCH_SWITCH(                                             \
      TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)             \
  AT_DISPATCH_SWITCH(                                                      \
      TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)             \
  AT_DISPATCH_SWITCH(                                             \
      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
csrc/layernorm_kernels.cu
DELETED
@@ -1,120 +0,0 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "reduction_utils.cuh"

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,             // [..., hidden_size]
  const scalar_t* __restrict__ input,     // [..., hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

// TODO: Further optimize this kernel.
template<typename scalar_t>
__global__ void fused_add_rms_norm_kernel(
  scalar_t* __restrict__ input,           // [..., hidden_size]
  scalar_t* __restrict__ residual,        // [..., hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    x += (float) residual[blockIdx.x * hidden_size + idx];
    variance += x * x;
    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace vllm

void rms_norm(
  torch::Tensor& out,      // [..., hidden_size]
  torch::Tensor& input,    // [..., hidden_size]
  torch::Tensor& weight,   // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}

void fused_add_rms_norm(
  torch::Tensor& input,    // [..., hidden_size]
  torch::Tensor& residual, // [..., hidden_size]
  torch::Tensor& weight,   // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "fused_add_rms_norm_kernel",
    [&] {
      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        input.data_ptr<scalar_t>(),
        residual.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
csrc/moe_align_block_size_kernels.cu
DELETED
@@ -1,108 +0,0 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "cuda_compat.h"
#include "dispatch_utils.h"

const static size_t NUM_MAX_EXPERTS = 64;
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))

namespace vllm {
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                            int32_t *sorted_token_ids,
                                            int32_t *expert_ids,
                                            int32_t *total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size,
                                            size_t numel) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;
  __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
  __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[threadIdx.x + 1][i] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are assigned
   * to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  tokens_cnts[0][threadIdx.x] = 0;
  for (int i = 1; i <= blockDim.x; ++i) {
    tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding blocks
   * and stores the corresponding expert_id for each block.
   */
  for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
    expert_ids[i / block_size] = threadIdx.x;
  }

  /**
   * Each thread processes a token shard, calculating the index of each token after
   * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
   * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
   * where * represents a padding value(preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
     * stores the indices of the tokens processed by the expert with expert_id within
     * the current thread's token shard.
     */
    int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[threadIdx.x][expert_id];
  }
}
}

void moe_align_block_size(
  torch::Tensor topk_ids,
  int num_experts,
  int block_size,
  torch::Tensor sorted_token_ids,
  torch::Tensor experts_ids,
  torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  assert(num_experts <= NUM_MAX_EXPERTS);
  VLLM_DISPATCH_INTEGRAL_TYPES(
    topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
      vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
        block_size,
        topk_ids.numel());
    });
}
csrc/ops.h
DELETED
@@ -1,130 +0,0 @@
#pragma once

#include <torch/extension.h>

void paged_attention_v1(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype);

void paged_attention_v2(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype);

void rms_norm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weight,
  float epsilon);

void fused_add_rms_norm(
  torch::Tensor& input,
  torch::Tensor& residual,
  torch::Tensor& weight,
  float epsilon);

void rotary_embedding(
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
  torch::Tensor& cos_sin_cache,
  bool is_neox);

void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);

torch::Tensor awq_dequantize(
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters,
  int thx,
  int thy);
#endif

void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table);

torch::Tensor gptq_gemm(
  torch::Tensor a,
  torch::Tensor b_q_weight,
  torch::Tensor b_gptq_qzeros,
  torch::Tensor b_gptq_scales,
  torch::Tensor b_g_idx,
  bool use_exllama);

void gptq_shuffle(
  torch::Tensor q_weight,
  torch::Tensor q_perm);

void moe_align_block_size(
  torch::Tensor topk_ids,
  int num_experts,
  int block_size,
  torch::Tensor sorted_token_ids,
  torch::Tensor experts_ids,
  torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int rank,
                      bool full_nvlink);
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
                      bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out);
void dispose(fptr_t _fa);
int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets);
#endif
csrc/pos_encoding_kernels.cu
DELETED
@@ -1,130 +0,0 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
  scalar_t* __restrict__ arr,
  const scalar_t* __restrict__ cos_ptr,
  const scalar_t* __restrict__ sin_ptr,
  int rot_offset,
  int embed_dim)
{
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = VLLM_LDG(cos_ptr + x_index);
    sin = VLLM_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = VLLM_LDG(cos_ptr + x_index / 2);
    sin = VLLM_LDG(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }
}

} // namespace vllm

void rotary_embedding(
  torch::Tensor& positions,       // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,           // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,             // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,   // [max_position, rot_dim]
  bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding",
    [&] {
      if (is_neox) {
        vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });
}
csrc/punica/LICENSE
DELETED
@@ -1,217 +0,0 @@
Contains code from https://github.com/punica-ai/punica

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

------------------------------------------------------------------------------------

This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.


Apache-2.0
* third_party/nvbench (with LLVM exception)
* third_party/flashinfer

BSD-3-Clause:
* third_party/cutlass
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)

csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)

csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)

csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)

csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)

csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)

csrc/punica/bgmv/bgmv_config.h
DELETED
@@ -1,59 +0,0 @@
#pragma once

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale);

// clang-format off

#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
    f(in_T, out_T, W_T, narrow, 128) \
    f(in_T, out_T, W_T, narrow, 256) \
    f(in_T, out_T, W_T, narrow, 512) \
    f(in_T, out_T, W_T, narrow, 1024) \
    f(in_T, out_T, W_T, narrow, 1280) \
    f(in_T, out_T, W_T, narrow, 1728) \
    f(in_T, out_T, W_T, narrow, 1792) \
    f(in_T, out_T, W_T, narrow, 2048) \
    f(in_T, out_T, W_T, narrow, 2560) \
    f(in_T, out_T, W_T, narrow, 2752) \
    f(in_T, out_T, W_T, narrow, 3072) \
    f(in_T, out_T, W_T, narrow, 3456) \
    f(in_T, out_T, W_T, narrow, 3584) \
    f(in_T, out_T, W_T, narrow, 4096) \
    f(in_T, out_T, W_T, narrow, 5120) \
    f(in_T, out_T, W_T, narrow, 5504) \
    f(in_T, out_T, W_T, narrow, 5632) \
    f(in_T, out_T, W_T, narrow, 6912) \
    f(in_T, out_T, W_T, narrow, 7168) \
    f(in_T, out_T, W_T, narrow, 8192) \
    f(in_T, out_T, W_T, narrow, 9216) \
    f(in_T, out_T, W_T, narrow, 10240) \
    f(in_T, out_T, W_T, narrow, 11008) \
    f(in_T, out_T, W_T, narrow, 12288) \
    f(in_T, out_T, W_T, narrow, 13824) \
    f(in_T, out_T, W_T, narrow, 14336) \
    f(in_T, out_T, W_T, narrow, 16384) \
    f(in_T, out_T, W_T, narrow, 20480) \
    f(in_T, out_T, W_T, narrow, 28672) \
    f(in_T, out_T, W_T, narrow, 32000) \
    f(in_T, out_T, W_T, narrow, 32256) \
    f(in_T, out_T, W_T, narrow, 32512) \
    f(in_T, out_T, W_T, narrow, 32768) \
    f(in_T, out_T, W_T, narrow, 33024) \
    f(in_T, out_T, W_T, narrow, 36864) \
    f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA

// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)

// clang-format on

csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)

csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)

csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)

csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
DELETED
@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)