bsmit1659 committed
Commit b457808 · 1 Parent(s): e2d4dfc

Deleting prior code
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. CONTRIBUTING.md +0 -77
  2. Dockerfile +0 -121
  3. Dockerfile.rocm +0 -88
  4. LICENSE +0 -201
  5. MANIFEST.in +0 -4
  6. README.md +0 -8
  7. benchmarks/README.md +0 -8
  8. benchmarks/benchmark_latency.py +0 -139
  9. benchmarks/benchmark_serving.py +0 -249
  10. benchmarks/benchmark_throughput.py +0 -328
  11. benchmarks/kernels/benchmark_paged_attention.py +0 -196
  12. benchmarks/launch_tgi_server.sh +0 -16
  13. csrc/activation_kernels.cu +0 -118
  14. csrc/attention/attention_dtypes.h +0 -7
  15. csrc/attention/attention_generic.cuh +0 -64
  16. csrc/attention/attention_kernels.cu +0 -951
  17. csrc/attention/attention_utils.cuh +0 -56
  18. csrc/attention/dtype_bfloat16.cuh +0 -451
  19. csrc/attention/dtype_float16.cuh +0 -502
  20. csrc/attention/dtype_float32.cuh +0 -273
  21. csrc/attention/dtype_fp8_e5m2.cuh +0 -35
  22. csrc/cache.h +0 -36
  23. csrc/cache_kernels.cu +0 -474
  24. csrc/cuda_compat.h +0 -28
  25. csrc/cuda_utils.h +0 -10
  26. csrc/cuda_utils_kernels.cu +0 -35
  27. csrc/custom_all_reduce.cu +0 -148
  28. csrc/custom_all_reduce.cuh +0 -562
  29. csrc/custom_all_reduce_test.cu +0 -284
  30. csrc/dispatch_utils.h +0 -37
  31. csrc/layernorm_kernels.cu +0 -120
  32. csrc/moe_align_block_size_kernels.cu +0 -108
  33. csrc/ops.h +0 -130
  34. csrc/pos_encoding_kernels.cu +0 -130
  35. csrc/punica/LICENSE +0 -217
  36. csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu +0 -4
  37. csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu +0 -4
  38. csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu +0 -4
  39. csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu +0 -4
  40. csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu +0 -4
  41. csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu +0 -4
  42. csrc/punica/bgmv/bgmv_config.h +0 -59
  43. csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu +0 -4
  44. csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu +0 -4
  45. csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu +0 -4
  46. csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu +0 -4
  47. csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu +0 -4
  48. csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu +0 -4
  49. csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu +0 -4
  50. csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu +0 -4
CONTRIBUTING.md DELETED
@@ -1,77 +0,0 @@
- # Contributing to vLLM
-
- Thank you for your interest in contributing to vLLM!
- Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
- There are several ways you can contribute to the project:
-
- - Identify and report any issues or bugs.
- - Request or add a new model.
- - Suggest or implement new features.
-
- However, remember that contributions aren't just about code.
- We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.
-
- Finally, one of the most impactful ways to support us is by raising awareness about vLLM.
- Talk about it in your blog posts, highlighting how it's driving your incredible projects.
- Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository.
-
-
- ## Setup for development
-
- ### Build from source
-
- ```bash
- pip install -r requirements.txt
- pip install -e .  # This may take several minutes.
- ```
-
- ### Testing
-
- ```bash
- pip install -r requirements-dev.txt
-
- # Static type checking
- mypy
- # Unit tests
- pytest tests/
- ```
- **Note:** Currently, the repository does not pass the mypy tests.
-
-
- ## Contributing Guidelines
-
- ### Issue Reporting
-
- If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
- If not, please file a new issue, providing as much relevant information as possible.
-
- ### Coding Style Guide
-
- In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
-
- We include a formatting script [`format.sh`](./format.sh) to format the code.
-
- ### Pull Requests
-
- When submitting a pull request:
-
- 1. Make sure your code has been rebased on top of the latest commit on the main branch.
- 2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
- 3. Include a detailed description of the changes in the pull request.
- Explain why you made the changes you did.
- If your pull request fixes an open issue, please include a reference to it in the description.
-
- ### Code Reviews
-
- All submissions, including submissions by project members, require a code review.
- To make the review process as smooth as possible, please:
-
- 1. Keep your changes as concise as possible.
- If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests.
- 2. Respond to all comments within a reasonable time frame.
- If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
-
- ### Thank You
-
- Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
- Your contributions make vLLM a great tool for everyone!
Dockerfile DELETED
@@ -1,121 +0,0 @@
- # The vLLM Dockerfile is used to construct vLLM image that can be directly used
- # to run the OpenAI compatible server.
-
- #################### BASE BUILD IMAGE ####################
- FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
-
- RUN apt-get update -y \
-     && apt-get install -y python3-pip git
-
- WORKDIR /workspace
-
- # install build and runtime dependencies
- COPY requirements.txt requirements.txt
- RUN --mount=type=cache,target=/root/.cache/pip \
-     pip install -r requirements.txt
-
- # install development dependencies
- COPY requirements-dev.txt requirements-dev.txt
- RUN --mount=type=cache,target=/root/.cache/pip \
-     pip install -r requirements-dev.txt
- #################### BASE BUILD IMAGE ####################
-
-
- #################### EXTENSION BUILD IMAGE ####################
- FROM dev AS build
-
- # install build dependencies
- COPY requirements-build.txt requirements-build.txt
- RUN --mount=type=cache,target=/root/.cache/pip \
-     pip install -r requirements-build.txt
-
- # copy input files
- COPY csrc csrc
- COPY setup.py setup.py
- COPY requirements.txt requirements.txt
- COPY pyproject.toml pyproject.toml
- COPY vllm/__init__.py vllm/__init__.py
-
- # cuda arch list used by torch
- ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
- ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
- # max jobs used by Ninja to build extensions
- ARG max_jobs=2
- ENV MAX_JOBS=${max_jobs}
- # number of threads used by nvcc
- ARG nvcc_threads=8
- ENV NVCC_THREADS=$nvcc_threads
- # make sure punica kernels are built (for LoRA)
- ENV VLLM_INSTALL_PUNICA_KERNELS=1
-
- RUN python3 setup.py build_ext --inplace
- #################### EXTENSION BUILD IMAGE ####################
-
-
- #################### TEST IMAGE ####################
- # image to run unit testing suite
- FROM dev AS test
-
- # copy pytorch extensions separately to avoid having to rebuild
- # when python code changes
- WORKDIR /vllm-workspace
- # ADD is used to preserve directory structure
- ADD . /vllm-workspace/
- COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
- # ignore build dependencies installation because we are using pre-compiled extensions
- RUN rm pyproject.toml
- RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
- #################### TEST IMAGE ####################
-
-
- #################### RUNTIME BASE IMAGE ####################
- # use CUDA base as CUDA runtime dependencies are already installed via pip
- FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base
-
- # libnccl required for ray
- RUN apt-get update -y \
-     && apt-get install -y python3-pip
-
- WORKDIR /workspace
- COPY requirements.txt requirements.txt
- RUN --mount=type=cache,target=/root/.cache/pip \
-     pip install -r requirements.txt
- #################### RUNTIME BASE IMAGE ####################
-
-
- #################### OPENAI API SERVER ####################
- # openai api server alternative
- FROM vllm-base AS vllm-openai
- # install additional dependencies for openai api server
- RUN --mount=type=cache,target=/root/.cache/pip \
-     pip install accelerate
-
- # Create a non-root user
- RUN useradd -m appuser
-
- # Transfer ownership of the /workspace to the new non-root user
- RUN chown -R appuser:appuser /workspace
-
- # Create a cache directory within the appuser's home directory and transfer ownership
- RUN mkdir -p /home/appuser/cache && \
-     chown -R appuser:appuser /home/appuser/cache
-
- # Switch to the non-root user for subsequent commands and container runtime
- USER appuser
-
- # Set the Hugging Face cache directory environment variable
- ENV TRANSFORMERS_CACHE=/home/appuser/cache
-
- COPY --from=build /workspace/vllm/*.so /workspace/vllm/
- COPY vllm vllm
-
- CMD ["python3", "-m", \
-     "vllm.entrypoints.openai.api_server", \
-     "--host", "0.0.0.0", \
-     "--port", "7860", \
-     "--served-model-name", "default", \
-     "--model", "facebook/opt-125m", \
-     "--max-model-len", "1024", \
-     "--tensor-parallel-size", "1", \
-     "--max-num-seqs", "16"]
- #################### OPENAI API SERVER ####################
Dockerfile.rocm DELETED
@@ -1,88 +0,0 @@
- # default base image
- ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
- FROM $BASE_IMAGE
-
- ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
- RUN echo "Base image is $BASE_IMAGE"
-
- # BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
- # BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
- # this does not always work for all rocm versions
- RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
-     echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"
-
- ARG FA_GFX_ARCHS="gfx90a;gfx942"
- RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
-
- ARG FA_BRANCH="3d2b6f5"
- RUN echo "FA_BRANCH is $FA_BRANCH"
-
- # Install some basic utilities
- RUN apt-get update && apt-get install python3 python3-pip -y
-
- # Install some basic utilities
- RUN apt-get update && apt-get install -y \
-     curl \
-     ca-certificates \
-     sudo \
-     git \
-     bzip2 \
-     libx11-6 \
-     build-essential \
-     wget \
-     unzip \
-     nvidia-cuda-toolkit \
-     tmux \
-     && rm -rf /var/lib/apt/lists/*
-
- ### Mount Point ###
- # When launching the container, mount the code directory to /app
- ARG APP_MOUNT=/app
- VOLUME [ ${APP_MOUNT} ]
- WORKDIR ${APP_MOUNT}
-
- RUN python3 -m pip install --upgrade pip
- RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
-
- ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
- ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
- ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
- ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
-
- # Install ROCm flash-attention
- RUN mkdir libs \
-     && cd libs \
-     && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
-     && cd flash-attention \
-     && git checkout ${FA_BRANCH} \
-     && git submodule update --init \
-     && export GPU_ARCHS=${FA_GFX_ARCHS} \
-     && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
-         patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
-     && python3 setup.py install \
-     && cd ..
-
- COPY ./ /app/vllm
-
- RUN python3 -m pip install --upgrade pip
- RUN python3 -m pip install xformers==0.0.23 --no-deps
-
- # Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
- # Manually removed it so that later steps of numpy upgrade can continue
- RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
-     rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
-
- RUN cd /app \
-     && cd vllm \
-     && pip install -U -r requirements-rocm.txt \
-     && bash patch_xformers.rocm.sh \
-     && python3 setup.py install \
-     && cd ..
-
- RUN python3 -m pip install --upgrade pip
- RUN python3 -m pip install --no-cache-dir ray[all]
-
- CMD ["/bin/bash"]
LICENSE DELETED
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
MANIFEST.in DELETED
@@ -1,4 +0,0 @@
- include LICENSE
- include requirements.txt
-
- recursive-include csrc *
README.md DELETED
@@ -1,8 +0,0 @@
- ---
- title: CertifAIer Demo
- emoji: ✈️
- colorFrom: blue
- colorTo: blue
- sdk: docker
- pinned: false
- ---
benchmarks/README.md DELETED
@@ -1,8 +0,0 @@
- # Benchmarking vLLM
-
- ## Downloading the ShareGPT dataset
-
- You can download the dataset by running:
- ```bash
- wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
- ```
benchmarks/benchmark_latency.py DELETED
@@ -1,139 +0,0 @@
- """Benchmark the latency of processing a single batch of requests."""
- import argparse
- import time
- from pathlib import Path
- from typing import Optional
-
- import numpy as np
- import torch
- from tqdm import tqdm
-
- from vllm import LLM, SamplingParams
-
-
- def main(args: argparse.Namespace):
-     print(args)
-
-     # NOTE(woosuk): If the request cannot be processed in a single batch,
-     # the engine will automatically process the request in multiple batches.
-     llm = LLM(
-         model=args.model,
-         tokenizer=args.tokenizer,
-         quantization=args.quantization,
-         tensor_parallel_size=args.tensor_parallel_size,
-         trust_remote_code=args.trust_remote_code,
-         dtype=args.dtype,
-         enforce_eager=args.enforce_eager,
-         kv_cache_dtype=args.kv_cache_dtype,
-     )
-
-     sampling_params = SamplingParams(
-         n=args.n,
-         temperature=0.0 if args.use_beam_search else 1.0,
-         top_p=1.0,
-         use_beam_search=args.use_beam_search,
-         ignore_eos=True,
-         max_tokens=args.output_len,
-     )
-     print(sampling_params)
-     dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
-
-     def run_to_completion(profile_dir: Optional[str] = None):
-         if profile_dir:
-             with torch.profiler.profile(
-                     activities=[
-                         torch.profiler.ProfilerActivity.CPU,
-                         torch.profiler.ProfilerActivity.CUDA,
-                     ],
-                     on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                         str(profile_dir))) as p:
-                 llm.generate(prompt_token_ids=dummy_prompt_token_ids,
-                              sampling_params=sampling_params,
-                              use_tqdm=False)
-             print(p.key_averages())
-         else:
-             start_time = time.perf_counter()
-             llm.generate(prompt_token_ids=dummy_prompt_token_ids,
-                          sampling_params=sampling_params,
-                          use_tqdm=False)
-             end_time = time.perf_counter()
-             latency = end_time - start_time
-             return latency
-
-     print("Warming up...")
-     run_to_completion(profile_dir=None)
-
-     if args.profile:
-         profile_dir = args.profile_result_dir
-         if not profile_dir:
-             profile_dir = Path(
-                 "."
-             ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
-         print(f"Profiling (results will be saved to '{profile_dir}')...")
-         run_to_completion(profile_dir=args.profile_result_dir)
-         return
-
-     # Benchmark.
-     latencies = []
-     for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
-         latencies.append(run_to_completion(profile_dir=None))
-     print(f'Avg latency: {np.mean(latencies)} seconds')
-
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser(
-         description='Benchmark the latency of processing a single batch of '
-         'requests till completion.')
-     parser.add_argument('--model', type=str, default='facebook/opt-125m')
-     parser.add_argument('--tokenizer', type=str, default=None)
-     parser.add_argument('--quantization',
-                         '-q',
-                         choices=['awq', 'gptq', 'squeezellm', None],
-                         default=None)
-     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
-     parser.add_argument('--input-len', type=int, default=32)
-     parser.add_argument('--output-len', type=int, default=128)
-     parser.add_argument('--batch-size', type=int, default=8)
-     parser.add_argument('--n',
-                         type=int,
-                         default=1,
-                         help='Number of generated sequences per prompt.')
-     parser.add_argument('--use-beam-search', action='store_true')
-     parser.add_argument('--num-iters',
-                         type=int,
-                         default=3,
-                         help='Number of iterations to run.')
-     parser.add_argument('--trust-remote-code',
-                         action='store_true',
-                         help='trust remote code from huggingface')
-     parser.add_argument(
-         '--dtype',
-         type=str,
-         default='auto',
-         choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
-         help='data type for model weights and activations. '
-         'The "auto" option will use FP16 precision '
-         'for FP32 and FP16 models, and BF16 precision '
-         'for BF16 models.')
-     parser.add_argument('--enforce-eager',
-                         action='store_true',
-                         help='enforce eager mode and disable CUDA graph')
-     parser.add_argument(
-         "--kv-cache-dtype",
-         type=str,
-         choices=['auto', 'fp8_e5m2'],
-         default='auto',
-         help=
-         'Data type for kv cache storage. If "auto", will use model data type.')
-     parser.add_argument(
-         '--profile',
-         action='store_true',
-         help='profile the generation process of a single batch')
-     parser.add_argument(
-         '--profile-result-dir',
-         type=str,
-         default=None,
-         help=('path to save the pytorch profiler output. Can be visualized '
-               'with ui.perfetto.dev or Tensorboard.'))
-     args = parser.parse_args()
-     main(args)
benchmarks/benchmark_serving.py DELETED
@@ -1,249 +0,0 @@
- """Benchmark online serving throughput.
-
- On the server side, run one of the following commands:
-     (vLLM backend)
-     python -m vllm.entrypoints.api_server \
-         --model <your_model> --swap-space 16 \
-         --disable-log-requests
-
-     (TGI backend)
-     ./launch_hf_server.sh <your_model>
-
- On the client side, run:
-     python benchmarks/benchmark_serving.py \
-         --backend <backend> \
-         --tokenizer <your_model> --dataset <target_dataset> \
-         --request-rate <request_rate>
- """
- import argparse
- import asyncio
- import json
- import random
- import time
- from typing import AsyncGenerator, List, Tuple
-
- import aiohttp
- import numpy as np
- from tqdm.asyncio import tqdm
- from transformers import PreTrainedTokenizerBase
- from vllm.transformers_utils.tokenizer import get_tokenizer
-
- # (prompt len, output len, latency)
- REQUEST_LATENCY: List[Tuple[int, int, float]] = []
-
-
- def sample_requests(
-     dataset_path: str,
-     num_requests: int,
-     tokenizer: PreTrainedTokenizerBase,
- ) -> List[Tuple[str, int, int]]:
-     # Load the dataset.
-     with open(dataset_path) as f:
-         dataset = json.load(f)
-     # Filter out the conversations with less than 2 turns.
-     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-     # Only keep the first two turns of each conversation.
-     dataset = [(data["conversations"][0]["value"],
-                 data["conversations"][1]["value"]) for data in dataset]
-
-     # Tokenize the prompts and completions.
-     prompts = [prompt for prompt, _ in dataset]
-     prompt_token_ids = tokenizer(prompts).input_ids
-     completions = [completion for _, completion in dataset]
-     completion_token_ids = tokenizer(completions).input_ids
-     tokenized_dataset = []
-     for i in range(len(dataset)):
-         output_len = len(completion_token_ids[i])
-         tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
-
-     # Filter out too long sequences.
-     filtered_dataset: List[Tuple[str, int, int]] = []
-     for prompt, prompt_token_ids, output_len in tokenized_dataset:
-         prompt_len = len(prompt_token_ids)
-         if prompt_len < 4 or output_len < 4:
-             # Prune too short sequences.
-             # This is because TGI causes errors when the input or output length
-             # is too short.
-             continue
-         if prompt_len > 1024 or prompt_len + output_len > 2048:
-             # Prune too long sequences.
-             continue
-         filtered_dataset.append((prompt, prompt_len, output_len))
-
-     # Sample the requests.
-     sampled_requests = random.sample(filtered_dataset, num_requests)
-     return sampled_requests
-
-
- async def get_request(
-     input_requests: List[Tuple[str, int, int]],
-     request_rate: float,
- ) -> AsyncGenerator[Tuple[str, int, int], None]:
-     input_requests = iter(input_requests)
-     for request in input_requests:
-         yield request
-
-         if request_rate == float("inf"):
-             # If the request rate is infinity, then we don't need to wait.
-             continue
-         # Sample the request interval from the exponential distribution.
-         interval = np.random.exponential(1.0 / request_rate)
-         # The next request will be sent after the interval.
-         await asyncio.sleep(interval)
-
-
- async def send_request(backend: str, model: str, api_url: str, prompt: str,
-                        prompt_len: int, output_len: int, best_of: int,
-                        use_beam_search: bool, pbar: tqdm) -> None:
-     request_start_time = time.perf_counter()
-
-     headers = {"User-Agent": "Benchmark Client"}
-     if backend == "vllm":
-         pload = {
-             "prompt": prompt,
-             "n": 1,
-             "best_of": best_of,
-             "use_beam_search": use_beam_search,
-             "temperature": 0.0 if use_beam_search else 1.0,
-             "top_p": 1.0,
-             "max_tokens": output_len,
-             "ignore_eos": True,
-             "stream": False,
-         }
-         if model is not None:
-             pload["model"] = model
-     elif backend == "tgi":
-         assert not use_beam_search
-         params = {
-             "best_of": best_of,
-             "max_new_tokens": output_len,
-             "do_sample": True,
-         }
-         pload = {
-             "inputs": prompt,
-             "parameters": params,
-         }
-     else:
-         raise ValueError(f"Unknown backend: {backend}")
-
-     timeout = aiohttp.ClientTimeout(total=3 * 3600)
-     async with aiohttp.ClientSession(timeout=timeout) as session:
-         while True:
-             async with session.post(api_url, headers=headers,
-                                     json=pload) as response:
-                 chunks = []
-                 async for chunk, _ in response.content.iter_chunks():
-                     chunks.append(chunk)
-             output = b"".join(chunks).decode("utf-8")
-             output = json.loads(output)
-
-             # Re-send the request if it failed.
-             if "error" not in output:
-                 break
-
-     request_end_time = time.perf_counter()
-     request_latency = request_end_time - request_start_time
-     REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
-     pbar.update(1)
-
-
- async def benchmark(
-     backend: str,
-     model: str,
-     api_url: str,
-     input_requests: List[Tuple[str, int, int]],
-     best_of: int,
-     use_beam_search: bool,
-     request_rate: float,
- ) -> None:
-     tasks: List[asyncio.Task] = []
-     pbar = tqdm(total=len(input_requests))
-     async for request in get_request(input_requests, request_rate):
-         prompt, prompt_len, output_len = request
-         task = asyncio.create_task(
-             send_request(backend, model, api_url, prompt, prompt_len,
-                          output_len, best_of, use_beam_search, pbar))
-         tasks.append(task)
-     await asyncio.gather(*tasks)
-     pbar.close()
-
-
- def main(args: argparse.Namespace):
-     print(args)
-     random.seed(args.seed)
-     np.random.seed(args.seed)
-
-     api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
-     tokenizer = get_tokenizer(args.tokenizer,
-                               trust_remote_code=args.trust_remote_code)
-     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
-
-     benchmark_start_time = time.perf_counter()
-     asyncio.run(
-         benchmark(args.backend, args.model, api_url, input_requests,
-                   args.best_of, args.use_beam_search, args.request_rate))
-     benchmark_end_time = time.perf_counter()
-     benchmark_time = benchmark_end_time - benchmark_start_time
-     print(f"Total time: {benchmark_time:.2f} s")
-     print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
-
-     # Compute the latency statistics.
-     avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
-     print(f"Average latency: {avg_latency:.2f} s")
-     avg_per_token_latency = np.mean([
-         latency / (prompt_len + output_len)
-         for prompt_len, output_len, latency in REQUEST_LATENCY
-     ])
-     print(f"Average latency per token: {avg_per_token_latency:.2f} s")
-     avg_per_output_token_latency = np.mean(
-         [latency / output_len for _, output_len, latency in REQUEST_LATENCY])
-     print("Average latency per output token: "
-           f"{avg_per_output_token_latency:.2f} s")
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(
-         description="Benchmark the online serving throughput.")
-     parser.add_argument("--backend",
-                         type=str,
-                         default="vllm",
-                         choices=["vllm", "tgi"])
-     parser.add_argument("--protocol",
-                         type=str,
-                         default="http",
-                         choices=["http", "https"])
-     parser.add_argument("--host", type=str, default="localhost")
-     parser.add_argument("--port", type=int, default=8000)
-     parser.add_argument("--endpoint", type=str, default="/generate")
-     parser.add_argument("--model", type=str, default=None)
-     parser.add_argument("--dataset",
-                         type=str,
-                         required=True,
-                         help="Path to the dataset.")
-     parser.add_argument("--tokenizer",
-                         type=str,
-                         required=True,
-                         help="Name or path of the tokenizer.")
-     parser.add_argument("--best-of",
-                         type=int,
-                         default=1,
-                         help="Generates `best_of` sequences per prompt and "
-                         "returns the best one.")
-     parser.add_argument("--use-beam-search", action="store_true")
-     parser.add_argument("--num-prompts",
-                         type=int,
-                         default=1000,
-                         help="Number of prompts to process.")
-     parser.add_argument("--request-rate",
-                         type=float,
-                         default=float("inf"),
-                         help="Number of requests per second. If this is inf, "
-                         "then all the requests are sent at time 0. "
-                         "Otherwise, we use Poisson process to synthesize "
-                         "the request arrival times.")
-     parser.add_argument("--seed", type=int, default=0)
-     parser.add_argument('--trust-remote-code',
-                         action='store_true',
-                         help='trust remote code from huggingface')
-     args = parser.parse_args()
-     main(args)
benchmarks/benchmark_throughput.py DELETED
@@ -1,328 +0,0 @@
- """Benchmark offline inference throughput."""
- import argparse
- import json
- import random
- import time
- from typing import List, Optional, Tuple
-
- import torch
- from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                           PreTrainedTokenizerBase)
- from tqdm import tqdm
-
-
- def sample_requests(
-     dataset_path: str,
-     num_requests: int,
-     tokenizer: PreTrainedTokenizerBase,
-     fixed_output_len: Optional[int],
- ) -> List[Tuple[str, int, int]]:
-     if fixed_output_len is not None and fixed_output_len < 4:
-         raise ValueError("output_len too small")
-
-     # Load the dataset.
-     with open(dataset_path) as f:
-         dataset = json.load(f)
-     # Filter out the conversations with less than 2 turns.
-     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-     # Only keep the first two turns of each conversation.
-     dataset = [(data["conversations"][0]["value"],
-                 data["conversations"][1]["value"]) for data in dataset]
-
-     # Tokenize the prompts and completions.
-     prompts = [prompt for prompt, _ in dataset]
-     prompt_token_ids = tokenizer(prompts).input_ids
-     completions = [completion for _, completion in dataset]
-     completion_token_ids = tokenizer(completions).input_ids
-     tokenized_dataset = []
-     for i in range(len(dataset)):
-         output_len = len(completion_token_ids[i])
-         if fixed_output_len is not None:
-             output_len = fixed_output_len
-         tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
-
-     # Filter out too long sequences.
-     filtered_dataset: List[Tuple[str, int, int]] = []
-     for prompt, prompt_token_ids, output_len in tokenized_dataset:
-         prompt_len = len(prompt_token_ids)
-         if prompt_len < 4 or output_len < 4:
-             # Prune too short sequences.
-             continue
-         if prompt_len > 1024 or prompt_len + output_len > 2048:
-             # Prune too long sequences.
-             continue
-         filtered_dataset.append((prompt, prompt_len, output_len))
-
-     # Sample the requests.
-     sampled_requests = random.sample(filtered_dataset, num_requests)
-     return sampled_requests
-
-
- def run_vllm(
-     requests: List[Tuple[str, int, int]],
-     model: str,
-     tokenizer: str,
-     quantization: Optional[str],
-     tensor_parallel_size: int,
-     seed: int,
-     n: int,
-     use_beam_search: bool,
-     trust_remote_code: bool,
-     dtype: str,
-     max_model_len: Optional[int],
-     enforce_eager: bool,
-     kv_cache_dtype: str,
- ) -> float:
-     from vllm import LLM, SamplingParams
-     llm = LLM(
-         model=model,
-         tokenizer=tokenizer,
-         quantization=quantization,
-         tensor_parallel_size=tensor_parallel_size,
-         seed=seed,
-         trust_remote_code=trust_remote_code,
-         dtype=dtype,
-         max_model_len=max_model_len,
-         enforce_eager=enforce_eager,
-         kv_cache_dtype=kv_cache_dtype,
-     )
-
-     # Add the requests to the engine.
-     for prompt, _, output_len in requests:
-         sampling_params = SamplingParams(
-             n=n,
-             temperature=0.0 if use_beam_search else 1.0,
-             top_p=1.0,
-             use_beam_search=use_beam_search,
-             ignore_eos=True,
-             max_tokens=output_len,
-         )
-         # FIXME(woosuk): Do not use internal method.
-         llm._add_request(
-             prompt=prompt,
-             prompt_token_ids=None,
-             sampling_params=sampling_params,
-         )
-
-     start = time.perf_counter()
-     # FIXME(woosuk): Do not use internal method.
-     llm._run_engine(use_tqdm=True)
-     end = time.perf_counter()
-     return end - start
-
-
- def run_hf(
-     requests: List[Tuple[str, int, int]],
-     model: str,
-     tokenizer: PreTrainedTokenizerBase,
-     n: int,
-     use_beam_search: bool,
-     max_batch_size: int,
-     trust_remote_code: bool,
- ) -> float:
-     assert not use_beam_search
-     llm = AutoModelForCausalLM.from_pretrained(
-         model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
-     if llm.config.model_type == "llama":
-         # To enable padding in the HF backend.
-         tokenizer.pad_token = tokenizer.eos_token
-     llm = llm.cuda()
-
-     pbar = tqdm(total=len(requests))
-     start = time.perf_counter()
-     batch: List[str] = []
-     max_prompt_len = 0
-     max_output_len = 0
-     for i in range(len(requests)):
-         prompt, prompt_len, output_len = requests[i]
-         # Add the prompt to the batch.
-         batch.append(prompt)
-         max_prompt_len = max(max_prompt_len, prompt_len)
-         max_output_len = max(max_output_len, output_len)
-         if len(batch) < max_batch_size and i != len(requests) - 1:
-             # Check if we can add more requests to the batch.
-             _, next_prompt_len, next_output_len = requests[i + 1]
-             if (max(max_prompt_len, next_prompt_len) +
-                     max(max_output_len, next_output_len)) <= 2048:
-                 # We can add more requests to the batch.
-                 continue
-
-         # Generate the sequences.
-         input_ids = tokenizer(batch, return_tensors="pt",
-                               padding=True).input_ids
-         llm_outputs = llm.generate(
-             input_ids=input_ids.cuda(),
-             do_sample=not use_beam_search,
-             num_return_sequences=n,
-             temperature=1.0,
-             top_p=1.0,
-             use_cache=True,
-             max_new_tokens=max_output_len,
-         )
-         # Include the decoding time.
-         tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
-         pbar.update(len(batch))
-
-         # Clear the batch.
-         batch = []
-         max_prompt_len = 0
-         max_output_len = 0
-     end = time.perf_counter()
-     return end - start
-
-
- def run_mii(
-     requests: List[Tuple[str, int, int]],
-     model: str,
-     tensor_parallel_size: int,
-     output_len: int,
- ) -> float:
-     from mii import pipeline
-     llm = pipeline(model, tensor_parallel=tensor_parallel_size)
-     prompts = [prompt for prompt, _, _ in requests]
-
-     start = time.perf_counter()
-     llm(prompts, max_new_tokens=output_len)
-     end = time.perf_counter()
-     return end - start
-
-
- def main(args: argparse.Namespace):
-     print(args)
-     random.seed(args.seed)
-
-     # Sample the requests.
-     tokenizer = AutoTokenizer.from_pretrained(
-         args.tokenizer, trust_remote_code=args.trust_remote_code)
-     if args.dataset is None:
-         # Synthesize a prompt with the given input length.
-         prompt = "hi" * (args.input_len - 1)
-         requests = [(prompt, args.input_len, args.output_len)
-                     for _ in range(args.num_prompts)]
-     else:
-         requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
-                                    args.output_len)
-
-     if args.backend == "vllm":
-         elapsed_time = run_vllm(requests, args.model, args.tokenizer,
-                                 args.quantization, args.tensor_parallel_size,
-                                 args.seed, args.n, args.use_beam_search,
-                                 args.trust_remote_code, args.dtype,
-                                 args.max_model_len, args.enforce_eager,
-                                 args.kv_cache_dtype)
-     elif args.backend == "hf":
-         assert args.tensor_parallel_size == 1
-         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
-                               args.use_beam_search, args.hf_max_batch_size,
-                               args.trust_remote_code)
-     elif args.backend == "mii":
-         elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
-                                args.output_len)
-     else:
-         raise ValueError(f"Unknown backend: {args.backend}")
-     total_num_tokens = sum(prompt_len + output_len
-                            for _, prompt_len, output_len in requests)
-     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
-           f"{total_num_tokens / elapsed_time:.2f} tokens/s")
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description="Benchmark the throughput.")
-     parser.add_argument("--backend",
-                         type=str,
-                         choices=["vllm", "hf", "mii"],
-                         default="vllm")
-     parser.add_argument("--dataset",
-                         type=str,
-                         default=None,
-                         help="Path to the dataset.")
-     parser.add_argument("--input-len",
-                         type=int,
-                         default=None,
-                         help="Input prompt length for each request")
-     parser.add_argument("--output-len",
-                         type=int,
-                         default=None,
-                         help="Output length for each request. Overrides the "
-                         "output length from the dataset.")
-     parser.add_argument("--model", type=str, default="facebook/opt-125m")
-     parser.add_argument("--tokenizer", type=str, default=None)
-     parser.add_argument('--quantization',
-                         '-q',
-                         choices=['awq', 'gptq', 'squeezellm', None],
-                         default=None)
-     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
-     parser.add_argument("--n",
-                         type=int,
-                         default=1,
-                         help="Number of generated sequences per prompt.")
-     parser.add_argument("--use-beam-search", action="store_true")
-     parser.add_argument("--num-prompts",
-                         type=int,
-                         default=1000,
-                         help="Number of prompts to process.")
-     parser.add_argument("--seed", type=int, default=0)
-     parser.add_argument("--hf-max-batch-size",
-                         type=int,
-                         default=None,
-                         help="Maximum batch size for HF backend.")
-     parser.add_argument('--trust-remote-code',
-                         action='store_true',
-                         help='trust remote code from huggingface')
-     parser.add_argument(
-         '--max-model-len',
-         type=int,
-         default=None,
-         help='Maximum length of a sequence (including prompt and output). '
-         'If None, will be derived from the model.')
-     parser.add_argument(
-         '--dtype',
-         type=str,
-         default='auto',
-         choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
-         help='data type for model weights and activations. '
-         'The "auto" option will use FP16 precision '
-         'for FP32 and FP16 models, and BF16 precision '
-         'for BF16 models.')
-     parser.add_argument("--enforce-eager",
-                         action="store_true",
-                         help="enforce eager execution")
-     parser.add_argument(
-         "--kv-cache-dtype",
-         type=str,
-         choices=["auto", "fp8_e5m2"],
-         default="auto",
-         help=
-         'Data type for kv cache storage. If "auto", will use model data type.')
-     args = parser.parse_args()
-     if args.tokenizer is None:
-         args.tokenizer = args.model
-     if args.dataset is None:
-         assert args.input_len is not None
-         assert args.output_len is not None
-     else:
-         assert args.input_len is None
-
-     if args.backend == "vllm":
-         if args.hf_max_batch_size is not None:
-             raise ValueError("HF max batch size is only for HF backend.")
-     elif args.backend == "hf":
-         if args.hf_max_batch_size is None:
-             raise ValueError("HF max batch size is required for HF backend.")
-         if args.quantization is not None:
-             raise ValueError("Quantization is only for vLLM backend.")
-     elif args.backend == "mii":
-         if args.dtype != "auto":
-             raise ValueError("dtype must be auto for MII backend.")
-         if args.n != 1:
-             raise ValueError("n must be 1 for MII backend.")
-         if args.use_beam_search:
-             raise ValueError("Beam search is not supported for MII backend.")
-         if args.quantization is not None:
-             raise ValueError("Quantization is only for vLLM backend.")
-         if args.hf_max_batch_size is not None:
-             raise ValueError("HF max batch size is only for HF backend.")
-         if args.tokenizer != args.model:
-             raise ValueError("Tokenizer must be the same as the model for MII "
-                              "backend.")
-     main(args)
benchmarks/kernels/benchmark_paged_attention.py DELETED
@@ -1,196 +0,0 @@
- from typing import Optional
- import argparse
- import random
- import time
-
- import torch
-
- from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
- from vllm._C import ops
-
- NUM_BLOCKS = 1024
- PARTITION_SIZE = 512
-
-
- @torch.inference_mode()
- def main(
-     version: str,
-     num_seqs: int,
-     context_len: int,
-     num_query_heads: int,
-     num_kv_heads: int,
-     head_size: int,
-     use_alibi: bool,
-     block_size: int,
-     dtype: torch.dtype,
-     seed: int,
-     do_profile: bool,
-     kv_cache_dtype: Optional[str] = None,
- ) -> None:
-     random.seed(seed)
-     torch.random.manual_seed(seed)
-     torch.cuda.manual_seed(seed)
-
-     scale = float(1.0 / (head_size**0.5))
-     query = torch.empty(num_seqs,
-                         num_query_heads,
-                         head_size,
-                         dtype=dtype,
-                         device="cuda")
-     query.uniform_(-scale, scale)
-
-     assert num_query_heads % num_kv_heads == 0
-     alibi_slopes = None
-     if use_alibi:
-         alibi_slopes = torch.randn(num_query_heads,
-                                    dtype=torch.float,
-                                    device="cuda")
-
-     context_lens = [context_len for _ in range(num_seqs)]
-     max_context_len = max(context_lens)
-     context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
-
-     # Create the block tables.
-     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
-     block_tables = []
-     for _ in range(num_seqs):
-         block_table = [
-             random.randint(0, NUM_BLOCKS - 1)
-             for _ in range(max_num_blocks_per_seq)
-         ]
-         block_tables.append(block_table)
-     block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
-
-     # Create the KV cache.
-     key_caches, value_caches = create_kv_caches_with_random(
-         NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
-         dtype)
-     key_cache, value_cache = key_caches[0], value_caches[0]
-
-     # Prepare for the paged attention kernel.
-     output = torch.empty_like(query)
-     if version == "v2":
-         num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
-                           PARTITION_SIZE)
-         tmp_output = torch.empty(
-             size=(num_seqs, num_query_heads, num_partitions, head_size),
-             dtype=output.dtype,
-             device=output.device,
-         )
-         exp_sums = torch.empty(
-             size=(num_seqs, num_query_heads, num_partitions),
-             dtype=torch.float32,
-             device=output.device,
-         )
-         max_logits = torch.empty_like(exp_sums)
-
-     def run_benchmark(num_iters: int, profile: bool = False) -> float:
-         torch.cuda.synchronize()
-         if profile:
-             torch.cuda.cudart().cudaProfilerStart()
-         start_time = time.perf_counter()
-
-         for _ in range(num_iters):
-             if version == "v1":
-                 ops.paged_attention_v1(
-                     output,
-                     query,
-                     key_cache,
-                     value_cache,
-                     num_kv_heads,
-                     scale,
-                     block_tables,
-                     context_lens,
-                     block_size,
-                     max_context_len,
-                     alibi_slopes,
-                     kv_cache_dtype,
-                 )
-             elif version == "v2":
-                 ops.paged_attention_v2(
-                     output,
-                     exp_sums,
-                     max_logits,
-                     tmp_output,
-                     query,
-                     key_cache,
-                     value_cache,
-                     num_kv_heads,
-                     scale,
-                     block_tables,
-                     context_lens,
-                     block_size,
-                     max_context_len,
-                     alibi_slopes,
-                     kv_cache_dtype,
-                 )
-             else:
-                 raise ValueError(f"Invalid version: {version}")
-         torch.cuda.synchronize()
-
-         end_time = time.perf_counter()
-         if profile:
-             torch.cuda.cudart().cudaProfilerStart()
-         return (end_time - start_time) / num_iters
-
-     # Warmup.
-     print("Warming up...")
-     run_benchmark(num_iters=3, profile=False)
-
-     # Benchmark.
-     if do_profile:
-         latency = run_benchmark(num_iters=1, profile=True)
-     else:
-         latency = run_benchmark(num_iters=100, profile=False)
-     print(f"Kernel running time: {latency * 1000000:.3f} us")
-
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser(
-         description="Benchmark the paged attention kernel.")
-     parser.add_argument("--version",
-                         type=str,
-                         choices=["v1", "v2"],
-                         default="v2")
-     parser.add_argument("--batch-size", type=int, default=8)
-     parser.add_argument("--context-len", type=int, default=4096)
-     parser.add_argument("--num-query-heads", type=int, default=64)
-     parser.add_argument("--num-kv-heads", type=int, default=8)
-     parser.add_argument("--head-size",
-                         type=int,
-                         choices=[64, 80, 96, 112, 128, 256],
-                         default=128)
-     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
-     parser.add_argument("--use-alibi", action="store_true")
-     parser.add_argument("--dtype",
-                         type=str,
-                         choices=["half", "bfloat16", "float"],
-                         default="half")
-     parser.add_argument("--seed", type=int, default=0)
-     parser.add_argument("--profile", action="store_true")
-     parser.add_argument(
-         "--kv-cache-dtype",
-         type=str,
-         choices=["auto", "fp8_e5m2"],
-         default="auto",
-         help=
-         'Data type for kv cache storage. If "auto", will use model data type.')
-     args = parser.parse_args()
-     print(args)
-
-     if args.num_query_heads % args.num_kv_heads != 0:
-         raise ValueError("num_query_heads must be divisible by num_kv_heads")
-     main(
-         version=args.version,
-         num_seqs=args.batch_size,
-         context_len=args.context_len,
-         num_query_heads=args.num_query_heads,
-         num_kv_heads=args.num_kv_heads,
-         head_size=args.head_size,
-         block_size=args.block_size,
-         use_alibi=args.use_alibi,
-         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
-         seed=args.seed,
-         do_profile=args.profile,
-         kv_cache_dtype=args.kv_cache_dtype,
-     )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
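For orientation, the computation that the benchmark above times can be written out in plain PyTorch. The sketch below is a simplified, single-sequence reference (no partitioning, no fp8 cache); the function name and looping are illustrative only, but the tensor layouts follow the comments in the deleted kernel sources.

```python
# Simplified single-sequence PyTorch reference of the paged-attention computation.
import torch

def paged_attention_reference(query, key_cache, value_cache, block_table,
                              context_len, scale, alibi_slopes=None):
    # query:       [num_heads, head_size]
    # key_cache:   [num_blocks, num_kv_heads, head_size // x, block_size, x]
    # value_cache: [num_blocks, num_kv_heads, head_size, block_size]
    # block_table: [max_num_blocks_per_seq], logical -> physical block ids
    num_heads, head_size = query.shape
    num_kv_heads, block_size = value_cache.shape[1], value_cache.shape[3]
    num_queries_per_kv = num_heads // num_kv_heads

    output = torch.empty_like(query)
    for head in range(num_heads):
        kv_head = head // num_queries_per_kv
        keys, values = [], []
        for tok in range(context_len):
            block = block_table[tok // block_size]
            off = tok % block_size
            keys.append(key_cache[block, kv_head, :, off, :].reshape(head_size))
            values.append(value_cache[block, kv_head, :, off])
        keys, values = torch.stack(keys), torch.stack(values)

        logits = scale * (keys.float() @ query[head].float())  # [context_len]
        if alibi_slopes is not None:
            pos = torch.arange(context_len, dtype=torch.float32,
                               device=query.device)
            logits += alibi_slopes[head].float() * (pos - context_len + 1)
        probs = torch.softmax(logits, dim=-1).to(values.dtype)
        output[head] = probs @ values
    return output
```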
benchmarks/launch_tgi_server.sh DELETED
@@ -1,16 +0,0 @@
- #!/bin/bash
-
- PORT=8000
- MODEL=$1
- TOKENS=$2
-
- docker run --gpus all --shm-size 1g -p $PORT:80 \
-            -v $PWD/data:/data \
-            ghcr.io/huggingface/text-generation-inference:0.8 \
-            --model-id $MODEL \
-            --sharded false \
-            --max-input-length 1024 \
-            --max-total-tokens 2048 \
-            --max-best-of 5 \
-            --max-concurrent-requests 5000 \
-            --max-batch-total-tokens $TOKENS
 
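Once the container is up, the server can be smoke-tested against TGI's standard `/generate` route on the mapped port. A minimal sketch, assuming the script's default `PORT=8000` and the `requests` package:

```python
# Minimal smoke test against the TGI server started by launch_tgi_server.sh.
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"inputs": "San Francisco is a",
          "parameters": {"max_new_tokens": 32}},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```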
csrc/activation_kernels.cu DELETED
@@ -1,118 +0,0 @@
- #include <ATen/cuda/CUDAContext.h>
- #include <torch/extension.h>
- #include <c10/cuda/CUDAGuard.h>
-
- #include "cuda_compat.h"
- #include "dispatch_utils.h"
-
- namespace vllm {
-
- template<typename T>
- __device__ __forceinline__ T silu(const T& x) {
-   // x * sigmoid(x)
-   return (T) (((float) x) / (1.0f + expf((float) -x)));
- }
-
- template<typename scalar_t>
- __global__ void silu_and_mul_kernel(
-   scalar_t* __restrict__ out,               // [..., d]
-   const scalar_t* __restrict__ input,       // [..., 2, d]
-   const int d) {
-   const int64_t token_idx = blockIdx.x;
-   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
-     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-     out[token_idx * d + idx] = silu(x) * y;
-   }
- }
-
- } // namespace vllm
-
- void silu_and_mul(
-   torch::Tensor& out,      // [..., d]
-   torch::Tensor& input)    // [..., 2 * d]
- {
-   int64_t num_tokens = input.numel() / input.size(-1);
-   int d = input.size(-1) / 2;
-
-   dim3 grid(num_tokens);
-   dim3 block(std::min(d, 1024));
-   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-   VLLM_DISPATCH_FLOATING_TYPES(
-     input.scalar_type(),
-     "silu_and_mul_kernel",
-     [&] {
-       vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
-         out.data_ptr<scalar_t>(),
-         input.data_ptr<scalar_t>(),
-         d);
-     });
- }
-
- namespace vllm {
-
- // Element-wise activation kernel template.
- template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
- __global__ void activation_kernel(
-   scalar_t* __restrict__ out,               // [..., d]
-   const scalar_t* __restrict__ input,       // [..., d]
-   const int d) {
-   const int64_t token_idx = blockIdx.x;
-   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-     const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
-     out[token_idx * d + idx] = ACT_FN(x);
-   }
- }
-
- } // namespace vllm
-
- // Launch element-wise activation kernel.
- #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                \
-   int d = input.size(-1);                                                               \
-   int64_t num_tokens = input.numel() / d;                                               \
-   dim3 grid(num_tokens);                                                                \
-   dim3 block(std::min(d, 1024));                                                        \
-   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                     \
-   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                         \
-   VLLM_DISPATCH_FLOATING_TYPES(                                                         \
-     input.scalar_type(),                                                                \
-     "activation_kernel",                                                                \
-     [&] {                                                                               \
-       vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(  \
-         out.data_ptr<scalar_t>(),                                                       \
-         input.data_ptr<scalar_t>(),                                                     \
-         d);                                                                             \
-     });
-
- namespace vllm {
-
- template<typename T>
- __device__ __forceinline__ T gelu_new_kernel(const T& x) {
-   const float x3 = (float) (x * x * x);
-   const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
-   return ((T) 0.5) * x * (((T) 1.0) + t);
- }
-
- template<typename T>
- __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
-   const float f = (float) x;
-   const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
-   return ((T) 0.5) * x * (((T) 1.0) + t);
- }
-
- } // namespace vllm
-
- void gelu_new(
-   torch::Tensor& out,     // [..., d]
-   torch::Tensor& input)   // [..., d]
- {
-   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
- }
-
- void gelu_fast(
-   torch::Tensor& out,     // [..., d]
-   torch::Tensor& input)   // [..., d]
- {
-   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
- }
 
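All three deleted kernels are element-wise, so their math is easy to restate on the host side. A PyTorch sketch of the same formulas, for reference only:

```python
# Host-side restatement of the formulas implemented by the deleted kernels.
import torch

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Input [..., 2 * d] -> output [..., d]: silu(first half) * second half.
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

def gelu_new(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x.pow(3))))

def gelu_fast(x: torch.Tensor) -> torch.Tensor:
    # Same approximation with the polynomial factored slightly differently.
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))
```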
csrc/attention/attention_dtypes.h DELETED
@@ -1,7 +0,0 @@
- #pragma once
-
- #include "attention_generic.cuh"
- #include "dtype_float16.cuh"
- #include "dtype_float32.cuh"
- #include "dtype_bfloat16.cuh"
- #include "dtype_fp8_e5m2.cuh"

csrc/attention/attention_generic.cuh DELETED
@@ -1,64 +0,0 @@
- /*
-  * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
-  * Copyright (c) 2023, The vLLM team.
-  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
-  *
-  * Licensed under the Apache License, Version 2.0 (the "License");
-  * you may not use this file except in compliance with the License.
-  * You may obtain a copy of the License at
-  *
-  *     http://www.apache.org/licenses/LICENSE-2.0
-  *
-  * Unless required by applicable law or agreed to in writing, software
-  * distributed under the License is distributed on an "AS IS" BASIS,
-  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  * See the License for the specific language governing permissions and
-  * limitations under the License.
-  */
- #pragma once
-
- #include <stdint.h>
-
- namespace vllm {
-
- // A vector type to store Q, K, V elements.
- template<typename T, int VEC_SIZE>
- struct Vec {};
-
- // A vector type to store FP32 accumulators.
- template<typename T>
- struct FloatVec {};
-
- // Template vector operations.
- template<typename Acc, typename A, typename B>
- inline __device__ Acc mul(A a, B b);
-
- template<typename T>
- inline __device__ float sum(T v);
-
- template<typename T>
- inline __device__ float dot(T a, T b) {
-   return sum(mul<T, T, T>(a, b));
- }
-
- template<typename A, typename T>
- inline __device__ float dot(T a, T b) {
-   return sum(mul<A, T, T>(a, b));
- }
-
- template<typename T>
- inline __device__ void zero(T& dst) {
-   constexpr int WORDS = sizeof(T) / 4;
-   union {
-     T raw;
-     uint32_t words[WORDS];
-   } tmp;
-
- #pragma unroll
-   for (int ii = 0; ii < WORDS; ++ii) {
-     tmp.words[ii] = 0u;
-   }
-   dst = tmp.raw;
- }
-
- } // namespace vllm

csrc/attention/attention_kernels.cu DELETED
@@ -1,951 +0,0 @@
1
- /*
2
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
- * Copyright (c) 2023, The vLLM team.
4
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
5
- *
6
- * Licensed under the Apache License, Version 2.0 (the "License");
7
- * you may not use this file except in compliance with the License.
8
- * You may obtain a copy of the License at
9
- *
10
- * http://www.apache.org/licenses/LICENSE-2.0
11
- *
12
- * Unless required by applicable law or agreed to in writing, software
13
- * distributed under the License is distributed on an "AS IS" BASIS,
14
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- * See the License for the specific language governing permissions and
16
- * limitations under the License.
17
- */
18
- #ifdef USE_ROCM
19
- #include <hip/hip_runtime.h>
20
- #endif
21
-
22
- #include <torch/extension.h>
23
- #include <ATen/cuda/CUDAContext.h>
24
- #include <c10/cuda/CUDAGuard.h>
25
-
26
- #include "attention_dtypes.h"
27
- #include "attention_utils.cuh"
28
- #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
29
-
30
- #include <algorithm>
31
-
32
- #ifndef USE_ROCM
33
- #define WARP_SIZE 32
34
- #else
35
- #define WARP_SIZE warpSize
36
- #endif
37
- #define MAX(a, b) ((a) > (b) ? (a) : (b))
38
- #define MIN(a, b) ((a) < (b) ? (a) : (b))
39
- #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
40
-
41
- namespace vllm {
42
-
43
- // Utility function for attention softmax.
44
- template<int NUM_WARPS>
45
- inline __device__ float block_sum(float* red_smem, float sum) {
46
- // Decompose the thread index into warp / lane.
47
- int warp = threadIdx.x / WARP_SIZE;
48
- int lane = threadIdx.x % WARP_SIZE;
49
-
50
- // Compute the sum per warp.
51
- #pragma unroll
52
- for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
53
- sum += VLLM_SHFL_XOR_SYNC(sum, mask);
54
- }
55
-
56
- // Warp leaders store the data to shared memory.
57
- if (lane == 0) {
58
- red_smem[warp] = sum;
59
- }
60
-
61
- // Make sure the data is in shared memory.
62
- __syncthreads();
63
-
64
- // The warps compute the final sums.
65
- if (lane < NUM_WARPS) {
66
- sum = red_smem[lane];
67
- }
68
-
69
- // Parallel reduction inside the warp.
70
- #pragma unroll
71
- for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
72
- sum += VLLM_SHFL_XOR_SYNC(sum, mask);
73
- }
74
-
75
- // Broadcast to other threads.
76
- return VLLM_SHFL_SYNC(sum, 0);
77
- }
78
-
79
- // TODO(woosuk): Merge the last two dimensions of the grid.
80
- // Grid: (num_heads, num_seqs, max_num_partitions).
81
- template<
82
- typename scalar_t,
83
- typename cache_t,
84
- int HEAD_SIZE,
85
- int BLOCK_SIZE,
86
- int NUM_THREADS,
87
- bool IS_FP8_E5M2_KV_CACHE,
88
- int PARTITION_SIZE = 0> // Zero means no partitioning.
89
- __device__ void paged_attention_kernel(
90
- float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
91
- float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
92
- scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
93
- const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
94
- const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
95
- const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
96
- const int num_kv_heads, // [num_heads]
97
- const float scale,
98
- const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
99
- const int* __restrict__ context_lens, // [num_seqs]
100
- const int max_num_blocks_per_seq,
101
- const float* __restrict__ alibi_slopes, // [num_heads]
102
- const int q_stride,
103
- const int kv_block_stride,
104
- const int kv_head_stride) {
105
- const int seq_idx = blockIdx.y;
106
- const int partition_idx = blockIdx.z;
107
- const int max_num_partitions = gridDim.z;
108
- constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
109
- const int context_len = context_lens[seq_idx];
110
- if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
111
- // No work to do. Terminate the thread block.
112
- return;
113
- }
114
-
115
- const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
116
- const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
117
-
118
- // [start_block_idx, end_block_idx) is the range of blocks to process.
119
- const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
120
- const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
121
- const int num_blocks = end_block_idx - start_block_idx;
122
-
123
- // [start_token_idx, end_token_idx) is the range of tokens to process.
124
- const int start_token_idx = start_block_idx * BLOCK_SIZE;
125
- const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
126
- const int num_tokens = end_token_idx - start_token_idx;
127
-
128
- constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
129
- constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
130
- assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
131
- constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
132
- constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
133
- const int thread_idx = threadIdx.x;
134
- const int warp_idx = thread_idx / WARP_SIZE;
135
- const int lane = thread_idx % WARP_SIZE;
136
-
137
- const int head_idx = blockIdx.x;
138
- const int num_heads = gridDim.x;
139
- const int num_queries_per_kv = num_heads / num_kv_heads;
140
- const int kv_head_idx = head_idx / num_queries_per_kv;
141
- const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
142
-
143
- // A vector type to store a part of a key or a query.
144
- // The vector size is configured in such a way that the threads in a thread group
145
- // fetch or compute 16 bytes at a time.
146
- // For example, if the size of a thread group is 4 and the data type is half,
147
- // then the vector size is 16 / (4 * sizeof(half)) == 2.
148
- constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
149
- using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
150
- using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
151
- #ifdef ENABLE_FP8_E5M2
152
- using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
153
- #endif
154
-
155
- constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
156
- constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
157
-
158
- const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
159
- const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
160
-
161
- // Load the query to registers.
162
- // Each thread in a thread group has a different part of the query.
163
- // For example, if the the thread group size is 4, then the first thread in the group
164
- // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
165
- // th vectors of the query, and so on.
166
- // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
167
- const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
168
- __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
169
- #pragma unroll
170
- for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
171
- const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
172
- q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
173
- }
174
- __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
175
-
176
- // Memory planning.
177
- extern __shared__ char shared_mem[];
178
- // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
179
- float* logits = reinterpret_cast<float*>(shared_mem);
180
- // Workspace for reduction.
181
- __shared__ float red_smem[2 * NUM_WARPS];
182
-
183
- // x == THREAD_GROUP_SIZE * VEC_SIZE
184
- // Each thread group fetches x elements from the key at a time.
185
- constexpr int x = 16 / sizeof(cache_t);
186
- float qk_max = -FLT_MAX;
187
-
188
- // Iterate over the key blocks.
189
- // Each warp fetches a block of keys for each iteration.
190
- // Each thread group in a warp fetches a key from the block, and computes
191
- // dot product with the query.
192
- const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
193
- for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
194
- // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
195
- // because int32 can lead to overflow when this variable is multiplied by large numbers
196
- // (e.g., kv_block_stride).
197
- const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
198
-
199
- // Load a key to registers.
200
- // Each thread in a thread group has a different part of the key.
201
- // For example, if the the thread group size is 4, then the first thread in the group
202
- // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
203
- // vectors of the key, and so on.
204
- for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
205
- const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
206
- const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
207
- K_vec k_vecs[NUM_VECS_PER_THREAD];
208
-
209
- #pragma unroll
210
- for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
211
- const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
212
- + kv_head_idx * kv_head_stride
213
- + physical_block_offset * x;
214
- const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
215
- const int offset1 = (vec_idx * VEC_SIZE) / x;
216
- const int offset2 = (vec_idx * VEC_SIZE) % x;
217
- if constexpr (IS_FP8_E5M2_KV_CACHE) {
218
- #ifdef ENABLE_FP8_E5M2
219
- Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
220
- // Vector conversion from Quant_vec to K_vec.
221
- k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
222
- #else
223
- assert(false);
224
- #endif
225
- } else {
226
- k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
227
- }
228
- }
229
-
230
- // Compute dot product.
231
- // This includes a reduction across the threads in the same thread group.
232
- float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
233
- // Add the ALiBi bias if slopes are given.
234
- qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
235
-
236
- if (thread_group_offset == 0) {
237
- // Store the partial reductions to shared memory.
238
- // NOTE(woosuk): It is required to zero out the masked logits.
239
- const bool mask = token_idx >= context_len;
240
- logits[token_idx - start_token_idx] = mask ? 0.f : qk;
241
- // Update the max value.
242
- qk_max = mask ? qk_max : fmaxf(qk_max, qk);
243
- }
244
- }
245
- }
246
-
247
- // Perform reduction across the threads in the same warp to get the
248
- // max qk value for each "warp" (not across the thread block yet).
249
- // The 0-th thread of each thread group already has its max qk value.
250
- #pragma unroll
251
- for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
252
- qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
253
- }
254
- if (lane == 0) {
255
- red_smem[warp_idx] = qk_max;
256
- }
257
- __syncthreads();
258
-
259
- // TODO(woosuk): Refactor this part.
260
- // Get the max qk value for the sequence.
261
- qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
262
- #pragma unroll
263
- for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
264
- qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
265
- }
266
- // Broadcast the max qk value to all threads.
267
- qk_max = VLLM_SHFL_SYNC(qk_max, 0);
268
-
269
- // Get the sum of the exp values.
270
- float exp_sum = 0.f;
271
- for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
272
- float val = __expf(logits[i] - qk_max);
273
- logits[i] = val;
274
- exp_sum += val;
275
- }
276
- exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
277
-
278
- // Compute softmax.
279
- const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
280
- for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
281
- logits[i] *= inv_sum;
282
- }
283
- __syncthreads();
284
-
285
- // If partitioning is enabled, store the max logit and exp_sum.
286
- if (USE_PARTITIONING && thread_idx == 0) {
287
- float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
288
- + head_idx * max_num_partitions
289
- + partition_idx;
290
- *max_logits_ptr = qk_max;
291
- float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
292
- + head_idx * max_num_partitions
293
- + partition_idx;
294
- *exp_sums_ptr = exp_sum;
295
- }
296
-
297
- // Each thread will fetch 16 bytes from the value cache at a time.
298
- constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
299
- using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
300
- using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
301
- #ifdef ENABLE_FP8_E5M2
302
- using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
303
- #endif
304
- using Float_L_vec = typename FloatVec<L_vec>::Type;
305
-
306
- constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
307
- constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
308
- constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
309
-
310
- // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
311
- float accs[NUM_ROWS_PER_THREAD];
312
- #pragma unroll
313
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
314
- accs[i] = 0.f;
315
- }
316
-
317
- scalar_t zero_value;
318
- zero(zero_value);
319
- for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
320
- // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
321
- // because int32 can lead to overflow when this variable is multiplied by large numbers
322
- // (e.g., kv_block_stride).
323
- const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
324
- const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
325
- const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
326
- L_vec logits_vec;
327
- from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
328
-
329
- const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
330
- + kv_head_idx * kv_head_stride;
331
- #pragma unroll
332
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
333
- const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
334
- if (row_idx < HEAD_SIZE) {
335
- const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
336
- V_vec v_vec;
337
- if constexpr (IS_FP8_E5M2_KV_CACHE) {
338
- #ifdef ENABLE_FP8_E5M2
339
- V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
340
- // Vector conversion from V_quant_vec to V_vec.
341
- v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
342
- #else
343
- assert(false);
344
- #endif
345
- } else {
346
- v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
347
- }
348
- if (block_idx == num_context_blocks - 1) {
349
- // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
350
- // we should explicitly zero out the values since they may contain NaNs.
351
- // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
352
- scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
353
- #pragma unroll
354
- for (int j = 0; j < V_VEC_SIZE; j++) {
355
- v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
356
- }
357
- }
358
- accs[i] += dot(logits_vec, v_vec);
359
- }
360
- }
361
- }
362
-
363
- // Perform reduction within each warp.
364
- #pragma unroll
365
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
366
- float acc = accs[i];
367
- #pragma unroll
368
- for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
369
- acc += VLLM_SHFL_XOR_SYNC(acc, mask);
370
- }
371
- accs[i] = acc;
372
- }
373
-
374
- // NOTE(woosuk): A barrier is required because the shared memory space for logits
375
- // is reused for the output.
376
- __syncthreads();
377
-
378
- // Perform reduction across warps.
379
- float* out_smem = reinterpret_cast<float*>(shared_mem);
380
- #pragma unroll
381
- for (int i = NUM_WARPS; i > 1; i /= 2) {
382
- int mid = i / 2;
383
- // Upper warps write to shared memory.
384
- if (warp_idx >= mid && warp_idx < i) {
385
- float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
386
- #pragma unroll
387
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
388
- const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
389
- if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
390
- dst[row_idx] = accs[i];
391
- }
392
- }
393
- }
394
- __syncthreads();
395
-
396
- // Lower warps update the output.
397
- if (warp_idx < mid) {
398
- const float* src = &out_smem[warp_idx * HEAD_SIZE];
399
- #pragma unroll
400
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
401
- const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
402
- if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
403
- accs[i] += src[row_idx];
404
- }
405
- }
406
- }
407
- __syncthreads();
408
- }
409
-
410
- // Write the final output.
411
- if (warp_idx == 0) {
412
- scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
413
- + head_idx * max_num_partitions * HEAD_SIZE
414
- + partition_idx * HEAD_SIZE;
415
- #pragma unroll
416
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
417
- const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
418
- if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
419
- from_float(*(out_ptr + row_idx), accs[i]);
420
- }
421
- }
422
- }
423
- }
424
-
425
- // Grid: (num_heads, num_seqs, 1).
426
- template<
427
- typename scalar_t,
428
- typename cache_t,
429
- int HEAD_SIZE,
430
- int BLOCK_SIZE,
431
- int NUM_THREADS,
432
- bool IS_FP8_E5M2_KV_CACHE>
433
- __global__ void paged_attention_v1_kernel(
434
- scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
435
- const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
436
- const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
437
- const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
438
- const int num_kv_heads, // [num_heads]
439
- const float scale,
440
- const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
441
- const int* __restrict__ context_lens, // [num_seqs]
442
- const int max_num_blocks_per_seq,
443
- const float* __restrict__ alibi_slopes, // [num_heads]
444
- const int q_stride,
445
- const int kv_block_stride,
446
- const int kv_head_stride) {
447
- paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
448
- /* exp_sums */ nullptr, /* max_logits */ nullptr,
449
- out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
450
- max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
451
- }
452
-
453
- // Grid: (num_heads, num_seqs, max_num_partitions).
454
- template<
455
- typename scalar_t,
456
- typename cache_t,
457
- int HEAD_SIZE,
458
- int BLOCK_SIZE,
459
- int NUM_THREADS,
460
- bool IS_FP8_E5M2_KV_CACHE,
461
- int PARTITION_SIZE>
462
- __global__ void paged_attention_v2_kernel(
463
- float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
464
- float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
465
- scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
466
- const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
467
- const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
468
- const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
469
- const int num_kv_heads, // [num_heads]
470
- const float scale,
471
- const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
472
- const int* __restrict__ context_lens, // [num_seqs]
473
- const int max_num_blocks_per_seq,
474
- const float* __restrict__ alibi_slopes, // [num_heads]
475
- const int q_stride,
476
- const int kv_block_stride,
477
- const int kv_head_stride) {
478
- paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
479
- exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
480
- block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
481
- q_stride, kv_block_stride, kv_head_stride);
482
- }
483
-
484
- // Grid: (num_heads, num_seqs).
485
- template<
486
- typename scalar_t,
487
- int HEAD_SIZE,
488
- int NUM_THREADS,
489
- int PARTITION_SIZE>
490
- __global__ void paged_attention_v2_reduce_kernel(
491
- scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
492
- const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
493
- const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
494
- const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
495
- const int* __restrict__ context_lens, // [num_seqs]
496
- const int max_num_partitions) {
497
- const int num_heads = gridDim.x;
498
- const int head_idx = blockIdx.x;
499
- const int seq_idx = blockIdx.y;
500
- const int context_len = context_lens[seq_idx];
501
- const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
502
- if (num_partitions == 1) {
503
- // No need to reduce. Only copy tmp_out to out.
504
- scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
505
- const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
506
- + head_idx * max_num_partitions * HEAD_SIZE;
507
- for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
508
- out_ptr[i] = tmp_out_ptr[i];
509
- }
510
- // Terminate the thread block.
511
- return;
512
- }
513
-
514
- constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
515
- const int warp_idx = threadIdx.x / WARP_SIZE;
516
- const int lane = threadIdx.x % WARP_SIZE;
517
-
518
- // Size: 2 * num_partitions.
519
- extern __shared__ char shared_mem[];
520
- // Workspace for reduction.
521
- __shared__ float red_smem[2 * NUM_WARPS];
522
-
523
- // Load max logits to shared memory.
524
- float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
525
- const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
526
- + head_idx * max_num_partitions;
527
- float max_logit = -FLT_MAX;
528
- for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
529
- const float l = max_logits_ptr[i];
530
- shared_max_logits[i] = l;
531
- max_logit = fmaxf(max_logit, l);
532
- }
533
- __syncthreads();
534
-
535
- // Get the global max logit.
536
- // Reduce within the warp.
537
- #pragma unroll
538
- for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
539
- max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
540
- }
541
- if (lane == 0) {
542
- red_smem[warp_idx] = max_logit;
543
- }
544
- __syncthreads();
545
- // Reduce across warps.
546
- max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
547
- #pragma unroll
548
- for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
549
- max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
550
- }
551
- // Broadcast the max value to all threads.
552
- max_logit = VLLM_SHFL_SYNC(max_logit, 0);
553
-
554
- // Load rescaled exp sums to shared memory.
555
- float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
556
- const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
557
- + head_idx * max_num_partitions;
558
- float global_exp_sum = 0.0f;
559
- for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
560
- float l = shared_max_logits[i];
561
- float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
562
- global_exp_sum += rescaled_exp_sum;
563
- shared_exp_sums[i] = rescaled_exp_sum;
564
- }
565
- __syncthreads();
566
- global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
567
- const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
568
-
569
- // Aggregate tmp_out to out.
570
- const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
571
- + head_idx * max_num_partitions * HEAD_SIZE;
572
- scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
573
- #pragma unroll
574
- for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
575
- float acc = 0.0f;
576
- for (int j = 0; j < num_partitions; ++j) {
577
- acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
578
- }
579
- from_float(out_ptr[i], acc);
580
- }
581
- }
582
-
583
- } // namespace vllm
584
-
585
- #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
586
- VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
587
- ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
588
- IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
589
- vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
590
- IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
591
- out_ptr, \
592
- query_ptr, \
593
- key_cache_ptr, \
594
- value_cache_ptr, \
595
- num_kv_heads, \
596
- scale, \
597
- block_tables_ptr, \
598
- context_lens_ptr, \
599
- max_num_blocks_per_seq, \
600
- alibi_slopes_ptr, \
601
- q_stride, \
602
- kv_block_stride, \
603
- kv_head_stride);
604
-
605
- // TODO(woosuk): Tune NUM_THREADS.
606
- template<
607
- typename T,
608
- typename CACHE_T,
609
- int BLOCK_SIZE,
610
- bool IS_FP8_E5M2_KV_CACHE,
611
- int NUM_THREADS = 128>
612
- void paged_attention_v1_launcher(
613
- torch::Tensor& out,
614
- torch::Tensor& query,
615
- torch::Tensor& key_cache,
616
- torch::Tensor& value_cache,
617
- int num_kv_heads,
618
- float scale,
619
- torch::Tensor& block_tables,
620
- torch::Tensor& context_lens,
621
- int max_context_len,
622
- const c10::optional<torch::Tensor>& alibi_slopes) {
623
- int num_seqs = query.size(0);
624
- int num_heads = query.size(1);
625
- int head_size = query.size(2);
626
- int max_num_blocks_per_seq = block_tables.size(1);
627
- int q_stride = query.stride(0);
628
- int kv_block_stride = key_cache.stride(0);
629
- int kv_head_stride = key_cache.stride(1);
630
-
631
- int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
632
- assert(head_size % thread_group_size == 0);
633
-
634
- // NOTE: alibi_slopes is optional.
635
- const float* alibi_slopes_ptr = alibi_slopes ?
636
- reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
637
- : nullptr;
638
-
639
- T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
640
- T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
641
- CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
642
- CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
643
- int* block_tables_ptr = block_tables.data_ptr<int>();
644
- int* context_lens_ptr = context_lens.data_ptr<int>();
645
-
646
- constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
647
- int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
648
- int logits_size = padded_max_context_len * sizeof(float);
649
- int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
650
- // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
651
- // Keep that in sync with the logic here!
652
- int shared_mem_size = std::max(logits_size, outputs_size);
653
-
654
- dim3 grid(num_heads, num_seqs, 1);
655
- dim3 block(NUM_THREADS);
656
- const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
657
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
658
- switch (head_size) {
659
- // NOTE(woosuk): To reduce the compilation time, we only compile for the
660
- // head sizes that we use in the model. However, we can easily extend this
661
- // to support any head size which is a multiple of 16.
662
- case 64:
663
- LAUNCH_PAGED_ATTENTION_V1(64);
664
- break;
665
- case 80:
666
- LAUNCH_PAGED_ATTENTION_V1(80);
667
- break;
668
- case 96:
669
- LAUNCH_PAGED_ATTENTION_V1(96);
670
- break;
671
- case 112:
672
- LAUNCH_PAGED_ATTENTION_V1(112);
673
- break;
674
- case 128:
675
- LAUNCH_PAGED_ATTENTION_V1(128);
676
- break;
677
- case 256:
678
- LAUNCH_PAGED_ATTENTION_V1(256);
679
- break;
680
- default:
681
- TORCH_CHECK(false, "Unsupported head size: ", head_size);
682
- break;
683
- }
684
- }
685
-
686
- #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
687
- paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
688
- out, \
689
- query, \
690
- key_cache, \
691
- value_cache, \
692
- num_kv_heads, \
693
- scale, \
694
- block_tables, \
695
- context_lens, \
696
- max_context_len, \
697
- alibi_slopes);
698
-
699
- // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
700
- // 1, 2, 4, 64, 128, 256.
701
- #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
702
- switch (block_size) { \
703
- case 8: \
704
- CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
705
- break; \
706
- case 16: \
707
- CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
708
- break; \
709
- case 32: \
710
- CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
711
- break; \
712
- default: \
713
- TORCH_CHECK(false, "Unsupported block size: ", block_size); \
714
- break; \
715
- }
716
-
717
- void paged_attention_v1(
718
- torch::Tensor& out, // [num_seqs, num_heads, head_size]
719
- torch::Tensor& query, // [num_seqs, num_heads, head_size]
720
- torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
721
- torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
722
- int num_kv_heads, // [num_heads]
723
- float scale,
724
- torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
725
- torch::Tensor& context_lens, // [num_seqs]
726
- int block_size,
727
- int max_context_len,
728
- const c10::optional<torch::Tensor>& alibi_slopes,
729
- const std::string& kv_cache_dtype) {
730
- if (kv_cache_dtype == "auto") {
731
- if (query.dtype() == at::ScalarType::Float) {
732
- CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
733
- } else if (query.dtype() == at::ScalarType::Half) {
734
- CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
735
- } else if (query.dtype() == at::ScalarType::BFloat16) {
736
- CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
737
- } else {
738
- TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
739
- }
740
- } else if (kv_cache_dtype == "fp8_e5m2") {
741
- if (query.dtype() == at::ScalarType::Float) {
742
- CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
743
- } else if (query.dtype() == at::ScalarType::Half) {
744
- CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
745
- } else if (query.dtype() == at::ScalarType::BFloat16) {
746
- CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
747
- } else {
748
- TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
749
- }
750
- } else {
751
- TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
752
- }
753
- }
754
-
755
- #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
756
- vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
757
- IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
758
- <<<grid, block, shared_mem_size, stream>>>( \
759
- exp_sums_ptr, \
760
- max_logits_ptr, \
761
- tmp_out_ptr, \
762
- query_ptr, \
763
- key_cache_ptr, \
764
- value_cache_ptr, \
765
- num_kv_heads, \
766
- scale, \
767
- block_tables_ptr, \
768
- context_lens_ptr, \
769
- max_num_blocks_per_seq, \
770
- alibi_slopes_ptr, \
771
- q_stride, \
772
- kv_block_stride, \
773
- kv_head_stride); \
774
- vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
775
- <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
776
- out_ptr, \
777
- exp_sums_ptr, \
778
- max_logits_ptr, \
779
- tmp_out_ptr, \
780
- context_lens_ptr, \
781
- max_num_partitions);
782
-
783
- template<
784
- typename T,
785
- typename CACHE_T,
786
- int BLOCK_SIZE,
787
- bool IS_FP8_E5M2_KV_CACHE,
788
- int NUM_THREADS = 128,
789
- int PARTITION_SIZE = 512>
790
- void paged_attention_v2_launcher(
791
- torch::Tensor& out,
792
- torch::Tensor& exp_sums,
793
- torch::Tensor& max_logits,
794
- torch::Tensor& tmp_out,
795
- torch::Tensor& query,
796
- torch::Tensor& key_cache,
797
- torch::Tensor& value_cache,
798
- int num_kv_heads,
799
- float scale,
800
- torch::Tensor& block_tables,
801
- torch::Tensor& context_lens,
802
- int max_context_len,
803
- const c10::optional<torch::Tensor>& alibi_slopes) {
804
- int num_seqs = query.size(0);
805
- int num_heads = query.size(1);
806
- int head_size = query.size(2);
807
- int max_num_blocks_per_seq = block_tables.size(1);
808
- int q_stride = query.stride(0);
809
- int kv_block_stride = key_cache.stride(0);
810
- int kv_head_stride = key_cache.stride(1);
811
-
812
- int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
813
- assert(head_size % thread_group_size == 0);
814
-
815
- // NOTE: alibi_slopes is optional.
816
- const float* alibi_slopes_ptr = alibi_slopes ?
817
- reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
818
- : nullptr;
819
-
820
- T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
821
- float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
822
- float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
823
- T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
824
- T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
825
- CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
826
- CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
827
- int* block_tables_ptr = block_tables.data_ptr<int>();
828
- int* context_lens_ptr = context_lens.data_ptr<int>();
829
-
830
- constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
831
- int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
832
- int logits_size = PARTITION_SIZE * sizeof(float);
833
- int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
834
-
835
- // For paged attention v2 kernel.
836
- dim3 grid(num_heads, num_seqs, max_num_partitions);
837
- int shared_mem_size = std::max(logits_size, outputs_size);
838
- // For paged attention v2 reduce kernel.
839
- dim3 reduce_grid(num_heads, num_seqs);
840
- int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
841
-
842
- dim3 block(NUM_THREADS);
843
- const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
844
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
845
- switch (head_size) {
846
- // NOTE(woosuk): To reduce the compilation time, we only compile for the
847
- // head sizes that we use in the model. However, we can easily extend this
848
- // to support any head size which is a multiple of 16.
849
- case 64:
850
- LAUNCH_PAGED_ATTENTION_V2(64);
851
- break;
852
- case 80:
853
- LAUNCH_PAGED_ATTENTION_V2(80);
854
- break;
855
- case 96:
856
- LAUNCH_PAGED_ATTENTION_V2(96);
857
- break;
858
- case 112:
859
- LAUNCH_PAGED_ATTENTION_V2(112);
860
- break;
861
- case 128:
862
- LAUNCH_PAGED_ATTENTION_V2(128);
863
- break;
864
- case 256:
865
- LAUNCH_PAGED_ATTENTION_V2(256);
866
- break;
867
- default:
868
- TORCH_CHECK(false, "Unsupported head size: ", head_size);
869
- break;
870
- }
871
- }
872
-
873
- #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
874
- paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
875
- out, \
876
- exp_sums, \
877
- max_logits, \
878
- tmp_out, \
879
- query, \
880
- key_cache, \
881
- value_cache, \
882
- num_kv_heads, \
883
- scale, \
884
- block_tables, \
885
- context_lens, \
886
- max_context_len, \
887
- alibi_slopes);
888
-
889
- // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
890
- // 1, 2, 4, 64, 128, 256.
891
- #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
892
- switch (block_size) { \
893
- case 8: \
894
- CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
895
- break; \
896
- case 16: \
897
- CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
898
- break; \
899
- case 32: \
900
- CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
901
- break; \
902
- default: \
903
- TORCH_CHECK(false, "Unsupported block size: ", block_size); \
904
- break; \
905
- }
906
-
907
- void paged_attention_v2(
908
- torch::Tensor& out, // [num_seqs, num_heads, head_size]
909
- torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
910
- torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
911
- torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
912
- torch::Tensor& query, // [num_seqs, num_heads, head_size]
913
- torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
914
- torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
915
- int num_kv_heads, // [num_heads]
916
- float scale,
917
- torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
918
- torch::Tensor& context_lens, // [num_seqs]
919
- int block_size,
920
- int max_context_len,
921
- const c10::optional<torch::Tensor>& alibi_slopes,
922
- const std::string& kv_cache_dtype) {
923
- if (kv_cache_dtype == "auto") {
924
- if (query.dtype() == at::ScalarType::Float) {
925
- CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
926
- } else if (query.dtype() == at::ScalarType::Half) {
927
- CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
928
- } else if (query.dtype() == at::ScalarType::BFloat16) {
929
- CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
930
- } else {
931
- TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
932
- }
933
- } else if (kv_cache_dtype == "fp8_e5m2") {
934
- if (query.dtype() == at::ScalarType::Float) {
935
- CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
936
- } else if (query.dtype() == at::ScalarType::Half) {
937
- CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
938
- } else if (query.dtype() == at::ScalarType::BFloat16) {
939
- CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
940
- } else {
941
- TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
942
- }
943
- } else {
944
- TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
945
- }
946
- }
947
-
948
- #undef WARP_SIZE
949
- #undef MAX
950
- #undef MIN
951
- #undef DIVIDE_ROUND_UP
 
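The `paged_attention_v2` path above splits a long context into fixed-size partitions, computes a partial softmax-attention result per partition, and then merges the partials in `paged_attention_v2_reduce_kernel` using the stored per-partition max logits and exp sums. A small PyTorch sketch of that merge, for a single sequence and head (illustrative only):

```python
# Illustrative merge of per-partition partial results, mirroring the logic of
# the v2 reduce kernel for one sequence and one head.
import torch

def combine_partitions(tmp_out, max_logits, exp_sums):
    # tmp_out:    [num_partitions, head_size], attention output of each
    #             partition, normalized by that partition's own softmax sum
    # max_logits: [num_partitions], max logit observed in each partition
    # exp_sums:   [num_partitions], sum of exp(logit - partition max)
    global_max = max_logits.max()
    rescaled = exp_sums * torch.exp(max_logits - global_max)
    weights = rescaled / rescaled.sum()              # per-partition contribution
    return (weights[:, None] * tmp_out).sum(dim=0)   # [head_size]
```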
csrc/attention/attention_utils.cuh DELETED
@@ -1,56 +0,0 @@
- /*
-  * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
-  * Copyright (c) 2023, The vLLM team.
-  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
-  *
-  * Licensed under the Apache License, Version 2.0 (the "License");
-  * you may not use this file except in compliance with the License.
-  * You may obtain a copy of the License at
-  *
-  *     http://www.apache.org/licenses/LICENSE-2.0
-  *
-  * Unless required by applicable law or agreed to in writing, software
-  * distributed under the License is distributed on an "AS IS" BASIS,
-  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  * See the License for the specific language governing permissions and
-  * limitations under the License.
-  */
- #pragma once
-
- #include "../cuda_compat.h"
- #include "attention_dtypes.h"
-
- #include <float.h>
- #include <type_traits>
-
- namespace vllm {
-
- // Q*K^T operation.
- template<int THREAD_GROUP_SIZE, typename Vec, int N>
- inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
-   using A_vec = typename FloatVec<Vec>::Type;
-   // Compute the parallel products for Q*K^T (treat vector lanes separately).
-   A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
- #pragma unroll
-   for (int ii = 1; ii < N; ++ii) {
-     qk_vec = fma(q[ii], k[ii], qk_vec);
-   }
-
-   // Finalize the reduction across lanes.
-   float qk = sum(qk_vec);
- #pragma unroll
-   for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
-     qk += VLLM_SHFL_XOR_SYNC(qk, mask);
-   }
-   return qk;
- }
-
- template<typename T, int THREAD_GROUP_SIZE>
- struct Qk_dot {
-   template<typename Vec, int N>
-   static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
-     return qk_dot_<THREAD_GROUP_SIZE>(q, k);
-   }
- };
-
- } // namespace vllm
 
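The `VLLM_SHFL_XOR_SYNC` loop in `qk_dot_` is a butterfly (XOR) reduction: after log2(THREAD_GROUP_SIZE) steps, every lane in a thread group holds the full group sum. A small Python emulation of the data movement, illustrative only:

```python
# Emulates the XOR-shuffle butterfly reduction used in qk_dot_.
def xor_shuffle_reduce(lane_vals, group_size):
    # lane_vals: one partial value per lane; group_size: power of two.
    vals = list(lane_vals)
    mask = group_size // 2
    while mask >= 1:
        vals = [vals[i] + vals[i ^ mask] for i in range(len(vals))]
        mask //= 2
    return vals  # every lane in a group now holds its group's total

assert xor_shuffle_reduce([1, 2, 3, 4], 4) == [10, 10, 10, 10]
```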
csrc/attention/dtype_bfloat16.cuh DELETED
@@ -1,451 +0,0 @@
1
- /*
2
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
4
- * Copyright (c) 2023, The vLLM team.
5
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6
- *
7
- * Licensed under the Apache License, Version 2.0 (the "License");
8
- * you may not use this file except in compliance with the License.
9
- * You may obtain a copy of the License at
10
- *
11
- * http://www.apache.org/licenses/LICENSE-2.0
12
- *
13
- * Unless required by applicable law or agreed to in writing, software
14
- * distributed under the License is distributed on an "AS IS" BASIS,
15
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
- * See the License for the specific language governing permissions and
17
- * limitations under the License.
18
- */
19
- #pragma once
20
-
21
- #include "attention_generic.cuh"
22
- #include "dtype_float32.cuh"
23
-
24
- #ifndef USE_ROCM
25
- #include <cuda_bf16.h>
26
- #include <cuda_fp16.h>
27
- #else
28
- #include <hip/hip_bf16.h>
29
- #include <hip/hip_fp16.h>
30
-
31
- typedef __hip_bfloat162 __nv_bfloat162;
32
- typedef __hip_bfloat16 __nv_bfloat16;
33
- #endif
34
-
35
- #include <stdint.h>
36
-
37
- namespace vllm {
38
-
39
- // Define custom BF16 vector data types.
40
- struct bf16_4_t {
41
- __nv_bfloat162 x;
42
- __nv_bfloat162 y;
43
- };
44
-
45
- struct bf16_8_t {
46
- __nv_bfloat162 x;
47
- __nv_bfloat162 y;
48
- __nv_bfloat162 z;
49
- __nv_bfloat162 w;
50
- };
51
-
52
- // BF16 vector types for Q, K, V.
53
- template<>
54
- struct Vec<__nv_bfloat16, 1> {
55
- using Type = __nv_bfloat16;
56
- };
57
- template<>
58
- struct Vec<__nv_bfloat16, 2> {
59
- using Type = __nv_bfloat162;
60
- };
61
- template<>
62
- struct Vec<__nv_bfloat16, 4> {
63
- using Type = bf16_4_t;
64
- };
65
- template<>
66
- struct Vec<__nv_bfloat16, 8> {
67
- using Type = bf16_8_t;
68
- };
69
-
70
- // FP32 accumulator vector types corresponding to Vec.
71
- template<>
72
- struct FloatVec<__nv_bfloat16> {
73
- using Type = float;
74
- };
75
- template<>
76
- struct FloatVec<__nv_bfloat162> {
77
- using Type = float2;
78
- };
79
- template<>
80
- struct FloatVec<bf16_4_t> {
81
- using Type = Float4_;
82
- };
83
- template<>
84
- struct FloatVec<bf16_8_t> {
85
- using Type = Float8_;
86
- };
87
-
88
- // Utility functions for type conversions.
89
- inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
90
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
91
- assert(false);
92
- #else
93
- return __bfloat1622float2(val);
94
- #endif
95
- }
96
-
97
- inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
98
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
99
- assert(false);
100
- #else
101
- return __bfloat162bfloat162(val);
102
- #endif
103
- }
104
-
105
- // Vector addition.
106
- inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
107
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
108
- assert(false);
109
- #else
110
- #ifndef USE_ROCM
111
- return a + b;
112
- #else
113
- return __hadd(a, b);
114
- #endif
115
- #endif
116
- }
117
-
118
- inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
119
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
120
- assert(false);
121
- #else
122
- return __hadd2(a, b);
123
- #endif
124
- }
125
-
126
- inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
127
- bf16_4_t c;
128
- c.x = add(a.x, b.x);
129
- c.y = add(a.y, b.y);
130
- return c;
131
- }
132
-
133
- inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
134
- bf16_8_t c;
135
- c.x = add(a.x, b.x);
136
- c.y = add(a.y, b.y);
137
- c.z = add(a.z, b.z);
138
- c.w = add(a.w, b.w);
139
- return c;
140
- }
141
-
142
- inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
143
- float2 fa = bf1622float2(a);
144
- return add(fa, fb);
145
- }
146
-
147
- inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
148
- Float4_ fc;
149
- fc.x = add(a.x, fb.x);
150
- fc.y = add(a.y, fb.y);
151
- return fc;
152
- }
153
-
154
- inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
155
- Float8_ fc;
156
- fc.x = add(a.x, fb.x);
157
- fc.y = add(a.y, fb.y);
158
- fc.z = add(a.z, fb.z);
159
- fc.w = add(a.w, fb.w);
160
- return fc;
161
- }
162
-
163
- // Vector multiplication.
164
- template<>
165
- inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
166
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
167
- assert(false);
168
- #else
169
- return __hmul(a, b);
170
- #endif
171
- }
172
-
173
- template<>
174
- inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
175
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
176
- assert(false);
177
- #else
178
- return __hmul2(a, b);
179
- #endif
180
- }
181
-
182
- template<>
183
- inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
184
- return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
185
- }
186
-
187
- template<>
188
- inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
189
- bf16_4_t c;
190
- c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
191
- c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
192
- return c;
193
- }
194
-
195
- template<>
196
- inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
197
- __nv_bfloat162 s = bf162bf162(a);
198
- bf16_4_t c;
199
- c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
200
- c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
201
- return c;
202
- }
203
-
204
- template<>
205
- inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
206
- bf16_8_t c;
207
- c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
208
- c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
209
- c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
210
- c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
211
- return c;
212
- }
213
-
214
- template<>
215
- inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
216
- __nv_bfloat162 s = bf162bf162(a);
217
- bf16_8_t c;
218
- c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
219
- c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
220
- c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
221
- c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
222
- return c;
223
- }
224
-
225
- template<>
226
- inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
227
- float fa = __bfloat162float(a);
228
- float fb = __bfloat162float(b);
229
- return fa * fb;
230
- }
231
-
232
- template<>
233
- inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
234
- float2 fa = bf1622float2(a);
235
- float2 fb = bf1622float2(b);
236
- return mul<float2, float2, float2>(fa, fb);
237
- }
238
-
239
- template<>
240
- inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
241
- return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
242
- }
243
-
244
- template<>
245
- inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
246
- Float4_ fc;
247
- fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
248
- fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
249
- return fc;
250
- }
251
-
252
- template<>
253
- inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
254
- __nv_bfloat162 s = bf162bf162(a);
255
- Float4_ fc;
256
- fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
257
- fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
258
- return fc;
259
- }
260
-
261
- template<>
262
- inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
263
- Float8_ fc;
264
- fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
265
- fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
266
- fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
267
- fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
268
- return fc;
269
- }
270
-
271
- template<>
272
- inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
273
- __nv_bfloat162 s = bf162bf162(a);
274
- Float8_ fc;
275
- fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
276
- fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
277
- fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
278
- fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
279
- return fc;
280
- }
281
-
282
- // Vector fused multiply-add.
283
- inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
284
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
285
- assert(false);
286
- #else
287
- return __hfma2(a, b, c);
288
- #endif
289
- }
290
-
291
- inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
292
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
293
- assert(false);
294
- #else
295
- return __hfma2(bf162bf162(a), b, c);
296
- #endif
297
- }
298
-
299
- inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
300
- bf16_4_t d;
301
- d.x = fma(a.x, b.x, c.x);
302
- d.y = fma(a.y, b.y, c.y);
303
- return d;
304
- }
305
-
306
- inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
307
- __nv_bfloat162 s = bf162bf162(a);
308
- bf16_4_t d;
309
- d.x = fma(s, b.x, c.x);
310
- d.y = fma(s, b.y, c.y);
311
- return d;
312
- }
313
-
314
- inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
315
- bf16_8_t d;
316
- d.x = fma(a.x, b.x, c.x);
317
- d.y = fma(a.y, b.y, c.y);
318
- d.z = fma(a.z, b.z, c.z);
319
- d.w = fma(a.w, b.w, c.w);
320
- return d;
321
- }
322
-
323
- inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
324
- __nv_bfloat162 s = bf162bf162(a);
325
- bf16_8_t d;
326
- d.x = fma(s, b.x, c.x);
327
- d.y = fma(s, b.y, c.y);
328
- d.z = fma(s, b.z, c.z);
329
- d.w = fma(s, b.w, c.w);
330
- return d;
331
- }
332
-
333
- inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
334
- return __bfloat162float(a) * __bfloat162float(b) + fc;
335
- }
336
-
337
- inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
338
- float2 fa = bf1622float2(a);
339
- float2 fb = bf1622float2(b);
340
- return fma(fa, fb, fc);
341
- }
342
-
343
- inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
344
- return fma(bf162bf162(a), b, fc);
345
- }
346
-
347
- inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
348
- Float4_ fd;
349
- fd.x = fma(a.x, b.x, fc.x);
350
- fd.y = fma(a.y, b.y, fc.y);
351
- return fd;
352
- }
353
-
354
- inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
355
- __nv_bfloat162 s = bf162bf162(a);
356
- Float4_ fd;
357
- fd.x = fma(s, b.x, fc.x);
358
- fd.y = fma(s, b.y, fc.y);
359
- return fd;
360
- }
361
-
362
- inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
363
- Float8_ fd;
364
- fd.x = fma(a.x, b.x, fc.x);
365
- fd.y = fma(a.y, b.y, fc.y);
366
- fd.z = fma(a.z, b.z, fc.z);
367
- fd.w = fma(a.w, b.w, fc.w);
368
- return fd;
369
- }
370
-
371
- inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
372
- __nv_bfloat162 s = bf162bf162(a);
373
- Float8_ fd;
374
- fd.x = fma(s, b.x, fc.x);
375
- fd.y = fma(s, b.y, fc.y);
376
- fd.z = fma(s, b.z, fc.z);
377
- fd.w = fma(s, b.w, fc.w);
378
- return fd;
379
- }
380
-
381
- // Vector sum.
382
- template<>
383
- inline __device__ float sum(__nv_bfloat16 v) {
384
- return __bfloat162float(v);
385
- }
386
-
387
- template<>
388
- inline __device__ float sum(__nv_bfloat162 v) {
389
- float2 vf = bf1622float2(v);
390
- return vf.x + vf.y;
391
- }
392
-
393
- template<>
394
- inline __device__ float sum(bf16_4_t v) {
395
- return sum(v.x) + sum(v.y);
396
- }
397
-
398
- template<>
399
- inline __device__ float sum(bf16_8_t v) {
400
- return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
401
- }
402
-
403
- // From float32 to bfloat16.
404
- inline __device__ void from_float(__nv_bfloat16& dst, float src) {
405
- dst = __float2bfloat16(src);
406
- }
407
-
408
- inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
409
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
410
- assert(false);
411
- #else
412
- dst = __float22bfloat162_rn(src);
413
- #endif
414
- }
415
-
416
- inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
417
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
418
- assert(false);
419
- #else
420
- dst.x = __float22bfloat162_rn(src.x);
421
- dst.y = __float22bfloat162_rn(src.y);
422
- #endif
423
- }
424
-
425
- inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
426
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
427
- assert(false);
428
- #else
429
- dst.x = __float22bfloat162_rn(src.x);
430
- dst.y = __float22bfloat162_rn(src.y);
431
- dst.z = __float22bfloat162_rn(src.z);
432
- dst.w = __float22bfloat162_rn(src.w);
433
- #endif
434
- }
435
-
436
- // From bfloat16 to float32.
437
- inline __device__ float to_float(__nv_bfloat16 u) {
438
- return __bfloat162float(u);
439
- }
440
-
441
- // Zero-out a variable.
442
- inline __device__ void zero(__nv_bfloat16& dst) {
443
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
444
- assert(false);
445
- #else
446
- // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
447
- dst = __ushort_as_bfloat16((unsigned short)0x0000U);
448
- #endif
449
- }
450
-
451
- } // namespace vllm
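Note: the deleted header above maps an element type plus vector width to a packed load type (Vec) and to a matching FP32 accumulator type (FloatVec). The following is a minimal standalone sketch of that trait pattern, not taken from the repository; the stand-in types half2_t and float2_t and the main() driver are hypothetical and exist only so the example compiles without the CUDA bf16 headers.

```cpp
#include <cstdio>

// Stand-ins for the packed 16-bit and FP32 pair types (hypothetical).
struct half2_t  { unsigned short x, y; };
struct float2_t { float x, y; };

// Primary templates are declared but never defined, so an unsupported
// (type, width) combination fails at compile time.
template <typename T, int N> struct Vec;
template <typename T>        struct FloatVec;

template <> struct Vec<unsigned short, 1> { using Type = unsigned short; };
template <> struct Vec<unsigned short, 2> { using Type = half2_t; };

template <> struct FloatVec<unsigned short> { using Type = float; };
template <> struct FloatVec<half2_t>        { using Type = float2_t; };

int main() {
  // A kernel picks its load type and its accumulator type the same way:
  using K_vec   = Vec<unsigned short, 2>::Type;   // packed two-element load
  using Acc_vec = FloatVec<K_vec>::Type;          // FP32 accumulator of equal width
  std::printf("sizeof(K_vec)=%zu sizeof(Acc_vec)=%zu\n", sizeof(K_vec), sizeof(Acc_vec));
  return 0;
}
```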
csrc/attention/dtype_float16.cuh DELETED
@@ -1,502 +0,0 @@
1
- /*
2
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
4
- * Copyright (c) 2023, The vLLM team.
5
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6
- *
7
- * Licensed under the Apache License, Version 2.0 (the "License");
8
- * you may not use this file except in compliance with the License.
9
- * You may obtain a copy of the License at
10
- *
11
- * http://www.apache.org/licenses/LICENSE-2.0
12
- *
13
- * Unless required by applicable law or agreed to in writing, software
14
- * distributed under the License is distributed on an "AS IS" BASIS,
15
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
- * See the License for the specific language governing permissions and
17
- * limitations under the License.
18
- */
19
- #pragma once
20
-
21
- #include "attention_generic.cuh"
22
- #include "dtype_float32.cuh"
23
-
24
- #ifdef USE_ROCM
25
- #include <hip/hip_fp16.h>
26
- #endif
27
-
28
- #include <stdint.h>
29
-
30
- namespace vllm {
31
-
32
- // FP16 vector types for Q, K, V.
33
- template<>
34
- struct Vec<uint16_t, 1> {
35
- using Type = uint16_t;
36
- };
37
- template<>
38
- struct Vec<uint16_t, 2> {
39
- using Type = uint32_t;
40
- };
41
- template<>
42
- struct Vec<uint16_t, 4> {
43
- using Type = uint2;
44
- };
45
- template<>
46
- struct Vec<uint16_t, 8> {
47
- using Type = uint4;
48
- };
49
-
50
- // FP32 accumulator vector types corresponding to Vec.
51
- template<>
52
- struct FloatVec<uint16_t> {
53
- using Type = float;
54
- };
55
- template<>
56
- struct FloatVec<uint32_t> {
57
- using Type = float2;
58
- };
59
- template<>
60
- struct FloatVec<uint2> {
61
- using Type = Float4_;
62
- };
63
- template<>
64
- struct FloatVec<uint4> {
65
- using Type = Float8_;
66
- };
67
-
68
- // Utility functions for type conversions.
69
- inline __device__ uint32_t h0_h0(uint16_t a) {
70
- #ifndef USE_ROCM
71
- uint32_t b;
72
- asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
73
- return b;
74
- #else
75
- union {
76
- uint32_t u32;
77
- uint16_t u16[2];
78
- } tmp;
79
- tmp.u16[0] = a;
80
- tmp.u16[1] = a;
81
- return tmp.u32;
82
- #endif
83
- }
84
-
85
- inline __device__ float half_to_float(uint16_t h) {
86
- float f;
87
- #ifndef USE_ROCM
88
- asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
89
- #else
90
- asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
91
- #endif
92
- return f;
93
- }
94
-
95
- inline __device__ float2 half2_to_float2(uint32_t v) {
96
- #ifndef USE_ROCM
97
- uint16_t lo, hi;
98
- asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
99
- return make_float2(half_to_float(lo), half_to_float(hi));
100
- #else
101
- union {
102
- uint32_t u32;
103
- uint16_t u16[2];
104
- } tmp;
105
- tmp.u32 = v;
106
- float2 ret;
107
- ret.x = half_to_float(tmp.u16[0]);
108
- ret.y = half_to_float(tmp.u16[1]);
109
- return ret;
110
- #endif
111
- }
112
-
113
- inline __device__ uint16_t float_to_half(float f) {
114
- union {
115
- uint32_t u32;
116
- uint16_t u16[2];
117
- } tmp;
118
- #ifndef USE_ROCM
119
- asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
120
- #else
121
- asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
122
- #endif
123
- return tmp.u16[0];
124
- }
125
-
126
- inline __device__ uint32_t float2_to_half2(float2 f) {
127
- union {
128
- uint32_t u32;
129
- uint16_t u16[2];
130
- } tmp;
131
- #ifndef USE_ROCM
132
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
133
- asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
134
- #else
135
- asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
136
- asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
137
- #endif
138
- #else
139
- tmp.u16[0] = float_to_half(f.x);
140
- tmp.u16[1] = float_to_half(f.y);
141
- #endif
142
- return tmp.u32;
143
- }
144
-
145
- // Vector addition.
146
- inline __device__ uint16_t add(uint16_t a, uint16_t b) {
147
- uint16_t c;
148
- #ifndef USE_ROCM
149
- asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
150
- #else
151
- asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
152
- #endif
153
- return c;
154
- }
155
-
156
- inline __device__ uint32_t add(uint32_t a, uint32_t b) {
157
- uint32_t c;
158
- #ifndef USE_ROCM
159
- asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
160
- #else
161
- asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
162
- #endif
163
- return c;
164
- }
165
-
166
- inline __device__ uint2 add(uint2 a, uint2 b) {
167
- uint2 c;
168
- c.x = add(a.x, b.x);
169
- c.y = add(a.y, b.y);
170
- return c;
171
- }
172
-
173
- inline __device__ uint4 add(uint4 a, uint4 b) {
174
- uint4 c;
175
- c.x = add(a.x, b.x);
176
- c.y = add(a.y, b.y);
177
- c.z = add(a.z, b.z);
178
- c.w = add(a.w, b.w);
179
- return c;
180
- }
181
-
182
- inline __device__ float2 add(uint32_t a, float2 fb) {
183
- float2 fa = half2_to_float2(a);
184
- return add(fa, fb);
185
- }
186
-
187
- inline __device__ Float4_ add(uint2 a, Float4_ fb) {
188
- Float4_ fc;
189
- fc.x = add(a.x, fb.x);
190
- fc.y = add(a.y, fb.y);
191
- return fc;
192
- }
193
-
194
- inline __device__ Float8_ add(uint4 a, Float8_ fb) {
195
- Float8_ fc;
196
- fc.x = add(a.x, fb.x);
197
- fc.y = add(a.y, fb.y);
198
- fc.z = add(a.z, fb.z);
199
- fc.w = add(a.w, fb.w);
200
- return fc;
201
- }
202
-
203
- // Vector multiplication.
204
- template<>
205
- inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
206
- uint16_t c;
207
- #ifndef USE_ROCM
208
- asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
209
- #else
210
- asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
211
- #endif
212
- return c;
213
- }
214
-
215
- template<>
216
- inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
217
- uint32_t c;
218
- #ifndef USE_ROCM
219
- asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
220
- #else
221
- asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
222
- #endif
223
- return c;
224
- }
225
-
226
- template<>
227
- inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
228
- return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
229
- }
230
-
231
- template<>
232
- inline __device__ uint2 mul(uint2 a, uint2 b) {
233
- uint2 c;
234
- c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
235
- c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
236
- return c;
237
- }
238
-
239
- template<>
240
- inline __device__ uint2 mul(uint16_t a, uint2 b) {
241
- uint32_t s = h0_h0(a);
242
- uint2 c;
243
- c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
244
- c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
245
- return c;
246
- }
247
-
248
- template<>
249
- inline __device__ uint4 mul(uint4 a, uint4 b) {
250
- uint4 c;
251
- c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
252
- c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
253
- c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
254
- c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
255
- return c;
256
- }
257
-
258
- template<>
259
- inline __device__ uint4 mul(uint16_t a, uint4 b) {
260
- uint32_t s = h0_h0(a);
261
- uint4 c;
262
- c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
263
- c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
264
- c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
265
- c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
266
- return c;
267
- }
268
-
269
- template<>
270
- inline __device__ float mul(uint16_t a, uint16_t b) {
271
- float fa = half_to_float(a);
272
- float fb = half_to_float(b);
273
- return fa * fb;
274
- }
275
-
276
- template<>
277
- inline __device__ float2 mul(uint32_t a, uint32_t b) {
278
- float2 fa = half2_to_float2(a);
279
- float2 fb = half2_to_float2(b);
280
- return mul<float2, float2, float2>(fa, fb);
281
- }
282
-
283
- template<>
284
- inline __device__ float2 mul(uint16_t a, uint32_t b) {
285
- return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
286
- }
287
-
288
- template<>
289
- inline __device__ Float4_ mul(uint2 a, uint2 b) {
290
- Float4_ fc;
291
- fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
292
- fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
293
- return fc;
294
- }
295
-
296
- template<>
297
- inline __device__ Float4_ mul(uint16_t a, uint2 b) {
298
- uint32_t s = h0_h0(a);
299
- Float4_ fc;
300
- fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
301
- fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
302
- return fc;
303
- }
304
-
305
- template<>
306
- inline __device__ Float8_ mul(uint4 a, uint4 b) {
307
- Float8_ fc;
308
- fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
309
- fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
310
- fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
311
- fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
312
- return fc;
313
- }
314
-
315
- template<>
316
- inline __device__ Float8_ mul(uint16_t a, uint4 b) {
317
- uint32_t s = h0_h0(a);
318
- Float8_ fc;
319
- fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
320
- fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
321
- fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
322
- fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
323
- return fc;
324
- }
325
-
326
- // Vector fused multiply-add.
327
- inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
328
- uint32_t d;
329
- #ifndef USE_ROCM
330
- asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
331
- #else
332
- asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
333
- #endif
334
- return d;
335
- }
336
-
337
- inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
338
- return fma(h0_h0(a), b, c);
339
- }
340
-
341
- inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
342
- uint2 d;
343
- d.x = fma(a.x, b.x, c.x);
344
- d.y = fma(a.y, b.y, c.y);
345
- return d;
346
- }
347
-
348
- inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
349
- uint32_t s = h0_h0(a);
350
- uint2 d;
351
- d.x = fma(s, b.x, c.x);
352
- d.y = fma(s, b.y, c.y);
353
- return d;
354
- }
355
-
356
- inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
357
- uint4 d;
358
- d.x = fma(a.x, b.x, c.x);
359
- d.y = fma(a.y, b.y, c.y);
360
- d.z = fma(a.z, b.z, c.z);
361
- d.w = fma(a.w, b.w, c.w);
362
- return d;
363
- }
364
-
365
- inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
366
- uint32_t s = h0_h0(a);
367
- uint4 d;
368
- d.x = fma(s, b.x, c.x);
369
- d.y = fma(s, b.y, c.y);
370
- d.z = fma(s, b.z, c.z);
371
- d.w = fma(s, b.w, c.w);
372
- return d;
373
- }
374
-
375
- inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
376
- float fa = half_to_float(a);
377
- float fb = half_to_float(b);
378
- return fa * fb + fc;
379
- }
380
-
381
- inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
382
- float2 fa = half2_to_float2(a);
383
- float2 fb = half2_to_float2(b);
384
- return fma(fa, fb, fc);
385
- }
386
-
387
- inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
388
- return fma(h0_h0(a), b, fc);
389
- }
390
-
391
- inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
392
- Float4_ fd;
393
- fd.x = fma(a.x, b.x, fc.x);
394
- fd.y = fma(a.y, b.y, fc.y);
395
- return fd;
396
- }
397
-
398
- inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
399
- uint32_t s = h0_h0(a);
400
- Float4_ fd;
401
- fd.x = fma(s, b.x, fc.x);
402
- fd.y = fma(s, b.y, fc.y);
403
- return fd;
404
- }
405
-
406
- inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
407
- Float8_ fd;
408
- fd.x = fma(a.x, b.x, fc.x);
409
- fd.y = fma(a.y, b.y, fc.y);
410
- fd.z = fma(a.z, b.z, fc.z);
411
- fd.w = fma(a.w, b.w, fc.w);
412
- return fd;
413
- }
414
-
415
- inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
416
- uint32_t s = h0_h0(a);
417
- Float8_ fd;
418
- fd.x = fma(s, b.x, fc.x);
419
- fd.y = fma(s, b.y, fc.y);
420
- fd.z = fma(s, b.z, fc.z);
421
- fd.w = fma(s, b.w, fc.w);
422
- return fd;
423
- }
424
-
425
- // Vector sum.
426
- template<>
427
- inline __device__ float sum(uint16_t v) {
428
- return half_to_float(v);
429
- }
430
-
431
- template<>
432
- inline __device__ float sum(uint32_t v) {
433
- float2 tmp = half2_to_float2(v);
434
- return tmp.x + tmp.y;
435
- }
436
-
437
- template<>
438
- inline __device__ float sum(uint2 v) {
439
- uint32_t c = add(v.x, v.y);
440
- return sum(c);
441
- }
442
-
443
- template<>
444
- inline __device__ float sum(uint4 v) {
445
- uint32_t c = add(v.x, v.y);
446
- c = add(c, v.z);
447
- c = add(c, v.w);
448
- return sum(c);
449
- }
450
-
451
- // From float32 to float16.
452
- inline __device__ void from_float(uint16_t& dst, float src) {
453
- dst = float_to_half(src);
454
- }
455
-
456
- inline __device__ void from_float(uint32_t& dst, float2 src) {
457
- dst = float2_to_half2(src);
458
- }
459
-
460
- inline __device__ void from_float(uint2& dst, Float4_ src) {
461
- dst.x = float2_to_half2(src.x);
462
- dst.y = float2_to_half2(src.y);
463
- }
464
-
465
- inline __device__ void from_float(uint4& dst, Float8_ src) {
466
- dst.x = float2_to_half2(src.x);
467
- dst.y = float2_to_half2(src.y);
468
- dst.z = float2_to_half2(src.z);
469
- dst.w = float2_to_half2(src.w);
470
- }
471
-
472
- // From float16 to float32.
473
- inline __device__ float to_float(uint16_t u) {
474
- return half_to_float(u);
475
- }
476
-
477
- inline __device__ float2 to_float(uint32_t u) {
478
- return half2_to_float2(u);
479
- }
480
-
481
- inline __device__ Float4_ to_float(uint2 u) {
482
- Float4_ tmp;
483
- tmp.x = half2_to_float2(u.x);
484
- tmp.y = half2_to_float2(u.y);
485
- return tmp;
486
- }
487
-
488
- inline __device__ Float8_ to_float(uint4 u) {
489
- Float8_ tmp;
490
- tmp.x = half2_to_float2(u.x);
491
- tmp.y = half2_to_float2(u.y);
492
- tmp.z = half2_to_float2(u.z);
493
- tmp.w = half2_to_float2(u.w);
494
- return tmp;
495
- }
496
-
497
- // Zero-out a variable.
498
- inline __device__ void zero(uint16_t& dst) {
499
- dst = uint16_t(0);
500
- }
501
-
502
- } // namespace vllm
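Note: the ROCm branch of h0_h0() above packs the same 16-bit value into both halves of a 32-bit register through a union, while the CUDA branch does it with a mov.b32 PTX instruction. Below is a small standalone sketch of the same packing in plain C++ using memcpy; h0_h0_sketch and the example value are illustrative only and assume a little-endian layout, like the union in the original.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Duplicate one 16-bit value into both halves of a 32-bit word.
uint32_t h0_h0_sketch(uint16_t a) {
  uint16_t pair[2] = {a, a};
  uint32_t packed = 0;
  std::memcpy(&packed, pair, sizeof(packed));  // defined-behavior form of the union trick
  return packed;
}

int main() {
  // 0x3c00 is fp16 1.0; the packed word holds two copies of it.
  std::printf("0x%08x\n", (unsigned)h0_h0_sketch(0x3c00));  // prints 0x3c003c00
  return 0;
}
```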
csrc/attention/dtype_float32.cuh DELETED
@@ -1,273 +0,0 @@
1
- /*
2
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
4
- * Copyright (c) 2023, The vLLM team.
5
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6
- *
7
- * Licensed under the Apache License, Version 2.0 (the "License");
8
- * you may not use this file except in compliance with the License.
9
- * You may obtain a copy of the License at
10
- *
11
- * http://www.apache.org/licenses/LICENSE-2.0
12
- *
13
- * Unless required by applicable law or agreed to in writing, software
14
- * distributed under the License is distributed on an "AS IS" BASIS,
15
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
- * See the License for the specific language governing permissions and
17
- * limitations under the License.
18
- */
19
- #pragma once
20
-
21
- #include "attention_generic.cuh"
22
-
23
- #include <stdint.h>
24
-
25
- namespace vllm {
26
-
27
- // Define custom FP32 vector data types.
28
- struct Float4_ {
29
- float2 x;
30
- float2 y;
31
- };
32
-
33
- struct Float8_ {
34
- float2 x;
35
- float2 y;
36
- float2 z;
37
- float2 w;
38
- };
39
-
40
- // FP32 vector types for Q, K, V.
41
- template<>
42
- struct Vec<float, 1> {
43
- using Type = float;
44
- };
45
- template<>
46
- struct Vec<float, 2> {
47
- using Type = float2;
48
- };
49
- template<>
50
- struct Vec<float, 4> {
51
- using Type = float4;
52
- };
53
-
54
- // FP32 accumulator vector types corresponding to Vec.
55
- template<>
56
- struct FloatVec<float> {
57
- using Type = float;
58
- };
59
- template<>
60
- struct FloatVec<float2> {
61
- using Type = float2;
62
- };
63
- template<>
64
- struct FloatVec<float4> {
65
- using Type = float4;
66
- };
67
-
68
- // Vector addition.
69
- inline __device__ float add(float a, float b) {
70
- return a + b;
71
- }
72
-
73
- inline __device__ float2 add(float2 a, float2 b) {
74
- float2 c;
75
- c.x = add(a.x, b.x);
76
- c.y = add(a.y, b.y);
77
- return c;
78
- }
79
-
80
- inline __device__ float4 add(float4 a, float4 b) {
81
- float4 c;
82
- c.x = add(a.x, b.x);
83
- c.y = add(a.y, b.y);
84
- c.z = add(a.z, b.z);
85
- c.w = add(a.w, b.w);
86
- return c;
87
- }
88
-
89
- // Vector multiplication.
90
- template<>
91
- inline __device__ float mul<float, float>(float a, float b) {
92
- return a * b;
93
- }
94
-
95
- template<>
96
- inline __device__ float2 mul(float2 a, float2 b) {
97
- float2 c;
98
- c.x = a.x * b.x;
99
- c.y = a.y * b.y;
100
- return c;
101
- }
102
-
103
- template<>
104
- inline __device__ float2 mul(float a, float2 b) {
105
- float2 c;
106
- c.x = a * b.x;
107
- c.y = a * b.y;
108
- return c;
109
- }
110
-
111
- template<>
112
- inline __device__ float4 mul(float4 a, float4 b) {
113
- float4 c;
114
- c.x = a.x * b.x;
115
- c.y = a.y * b.y;
116
- c.z = a.z * b.z;
117
- c.w = a.w * b.w;
118
- return c;
119
- }
120
-
121
- template<>
122
- inline __device__ float4 mul(float a, float4 b) {
123
- float4 c;
124
- c.x = a * b.x;
125
- c.y = a * b.y;
126
- c.z = a * b.z;
127
- c.w = a * b.w;
128
- return c;
129
- }
130
-
131
- // Vector fused multiply-add.
132
- inline __device__ float fma(float a, float b, float c) {
133
- return a * b + c;
134
- }
135
-
136
- inline __device__ float2 fma(float2 a, float2 b, float2 c) {
137
- float2 d;
138
- d.x = fma(a.x, b.x, c.x);
139
- d.y = fma(a.y, b.y, c.y);
140
- return d;
141
- }
142
-
143
- inline __device__ float2 fma(float a, float2 b, float2 c) {
144
- float2 d;
145
- d.x = fma(a, b.x, c.x);
146
- d.y = fma(a, b.y, c.y);
147
- return d;
148
- }
149
-
150
- inline __device__ float4 fma(float4 a, float4 b, float4 c) {
151
- float4 d;
152
- d.x = fma(a.x, b.x, c.x);
153
- d.y = fma(a.y, b.y, c.y);
154
- d.z = fma(a.z, b.z, c.z);
155
- d.w = fma(a.w, b.w, c.w);
156
- return d;
157
- }
158
-
159
- inline __device__ float4 fma(float a, float4 b, float4 c) {
160
- float4 d;
161
- d.x = fma(a, b.x, c.x);
162
- d.y = fma(a, b.y, c.y);
163
- d.z = fma(a, b.z, c.z);
164
- d.w = fma(a, b.w, c.w);
165
- return d;
166
- }
167
-
168
- inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
169
- Float4_ d;
170
- d.x = fma(a, b.x, c.x);
171
- d.y = fma(a, b.y, c.y);
172
- return d;
173
- }
174
-
175
- inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
176
- Float8_ d;
177
- d.x = fma(a, b.x, c.x);
178
- d.y = fma(a, b.y, c.y);
179
- d.z = fma(a, b.z, c.z);
180
- d.w = fma(a, b.w, c.w);
181
- return d;
182
- }
183
-
184
- // Vector sum.
185
- template<>
186
- inline __device__ float sum(float v) {
187
- return v;
188
- }
189
-
190
- template<>
191
- inline __device__ float sum(float2 v) {
192
- return v.x + v.y;
193
- }
194
-
195
- template<>
196
- inline __device__ float sum(float4 v) {
197
- return v.x + v.y + v.z + v.w;
198
- }
199
-
200
- template<>
201
- inline __device__ float sum(Float4_ v) {
202
- return v.x.x + v.x.y + v.y.x + v.y.y;
203
- }
204
-
205
- template<>
206
- inline __device__ float sum(Float8_ v) {
207
- return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
208
- }
209
-
210
- // Vector dot product.
211
- inline __device__ float dot(float a, float b) {
212
- return a * b;
213
- }
214
-
215
- inline __device__ float dot(float2 a, float2 b) {
216
- float2 c = mul<float2, float2, float2>(a, b);
217
- return c.x + c.y;
218
- }
219
-
220
- inline __device__ float dot(Float4_ a, Float4_ b) {
221
- float2 acc = mul<float2, float2, float2>(a.x, b.x);
222
- acc = fma(a.y, b.y, acc);
223
- return acc.x + acc.y;
224
- }
225
-
226
- inline __device__ float dot(Float8_ a, Float8_ b) {
227
- float2 acc = mul<float2, float2, float2>(a.x, b.x);
228
- acc = fma(a.y, b.y, acc);
229
- acc = fma(a.z, b.z, acc);
230
- acc = fma(a.w, b.w, acc);
231
- return acc.x + acc.y;
232
- }
233
-
234
- // From float to float.
235
- inline __device__ void from_float(float& dst, float src) {
236
- dst = src;
237
- }
238
-
239
- inline __device__ void from_float(float2& dst, float2 src) {
240
- dst = src;
241
- }
242
-
243
- inline __device__ void from_float(float4& dst, float4 src) {
244
- dst = src;
245
- }
246
-
247
- // From float to float.
248
- inline __device__ float to_float(float u) {
249
- return u;
250
- }
251
-
252
- inline __device__ float2 to_float(float2 u) {
253
- return u;
254
- }
255
-
256
- inline __device__ float4 to_float(float4 u) {
257
- return u;
258
- }
259
-
260
- inline __device__ Float4_ to_float(Float4_ u) {
261
- return u;
262
- }
263
-
264
- inline __device__ Float8_ to_float(Float8_ u) {
265
- return u;
266
- }
267
-
268
- // Zero-out a variable.
269
- inline __device__ void zero(float& dst) {
270
- dst = 0.f;
271
- }
272
-
273
- } // namespace vllm
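Note: the dot() overloads above follow one shape: multiply the first lane pair, fold the remaining pairs in with fma, then reduce horizontally. The snippet below is a standalone plain-C++ sketch of that pattern; Float2, Float4, and dot4 are hypothetical stand-ins for the CUDA float2 and Float4_ types, not the originals.

```cpp
#include <cmath>
#include <cstdio>

struct Float2 { float x, y; };
struct Float4 { Float2 x, y; };   // stand-in for Float4_ (two float2 lanes)

static Float2 mul2(Float2 a, Float2 b) { return {a.x * b.x, a.y * b.y}; }
static Float2 fma2(Float2 a, Float2 b, Float2 c) {
  return {std::fma(a.x, b.x, c.x), std::fma(a.y, b.y, c.y)};
}

static float dot4(Float4 a, Float4 b) {
  Float2 acc = mul2(a.x, b.x);   // first lane pair
  acc = fma2(a.y, b.y, acc);     // fold in the second lane pair
  return acc.x + acc.y;          // horizontal reduction
}

int main() {
  Float4 a{{1.f, 2.f}, {3.f, 4.f}};
  Float4 b{{5.f, 6.f}, {7.f, 8.f}};
  std::printf("dot = %f (expected 70)\n", dot4(a, b));  // 5 + 12 + 21 + 32
  return 0;
}
```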
csrc/attention/dtype_fp8_e5m2.cuh DELETED
@@ -1,35 +0,0 @@
- #pragma once
-
- #include "attention_generic.cuh"
-
- #include <stdint.h>
- #ifdef ENABLE_FP8_E5M2
- #include <cuda_fp8.h>
- #endif
-
- namespace vllm {
- #ifdef ENABLE_FP8_E5M2
- // fp8 vector types for quantization of kv cache
-
- template<>
- struct Vec<uint8_t, 1> {
- using Type = uint8_t;
- };
-
- template<>
- struct Vec<uint8_t, 2> {
- using Type = uint16_t;
- };
-
- template<>
- struct Vec<uint8_t, 4> {
- using Type = uint32_t;
- };
-
- template<>
- struct Vec<uint8_t, 8> {
- using Type = uint2;
- };
- #endif // ENABLE_FP8_E5M2
-
- } // namespace vllm
csrc/cache.h DELETED
@@ -1,36 +0,0 @@
- #pragma once
-
- #include <torch/extension.h>
-
- #include <map>
- #include <vector>
-
- void swap_blocks(
- torch::Tensor& src,
- torch::Tensor& dst,
- const std::map<int64_t, int64_t>& block_mapping);
-
- void copy_blocks(
- std::vector<torch::Tensor>& key_caches,
- std::vector<torch::Tensor>& value_caches,
- const std::map<int64_t, std::vector<int64_t>>& block_mapping);
-
- void reshape_and_cache(
- torch::Tensor& key,
- torch::Tensor& value,
- torch::Tensor& key_cache,
- torch::Tensor& value_cache,
- torch::Tensor& slot_mapping,
- const std::string& kv_cache_dtype);
-
- void gather_cached_kv(
- torch::Tensor& key,
- torch::Tensor& value,
- torch::Tensor& key_cache,
- torch::Tensor& value_cache,
- torch::Tensor& slot_mapping);
-
- // Just for unittest
- void convert_fp8_e5m2(
- torch::Tensor& src_cache,
- torch::Tensor& dst_cache);
csrc/cache_kernels.cu DELETED
@@ -1,474 +0,0 @@
1
- #include <torch/extension.h>
2
- #include <ATen/cuda/CUDAContext.h>
3
- #include <c10/cuda/CUDAGuard.h>
4
-
5
- #include "cuda_compat.h"
6
- #include "dispatch_utils.h"
7
- #include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
8
-
9
- #include <algorithm>
10
- #include <cassert>
11
- #include <map>
12
- #include <vector>
13
-
14
- void swap_blocks(
15
- torch::Tensor& src,
16
- torch::Tensor& dst,
17
- const std::map<int64_t, int64_t>& block_mapping) {
18
- torch::Device src_device = src.device();
19
- torch::Device dst_device = dst.device();
20
- cudaMemcpyKind memcpy_type;
21
- if (src_device.is_cuda() && dst_device.is_cuda()) {
22
- TORCH_CHECK(
23
- src_device.index() == dst_device.index(),
24
- "src and dst must be on the same GPU");
25
- memcpy_type = cudaMemcpyDeviceToDevice;
26
- } else if (src_device.is_cuda() && dst_device.is_cpu()) {
27
- memcpy_type = cudaMemcpyDeviceToHost;
28
- } else if (src_device.is_cpu() && dst_device.is_cuda()) {
29
- memcpy_type = cudaMemcpyHostToDevice;
30
- } else {
31
- TORCH_CHECK(false, "Invalid device combination");
32
- }
33
-
34
- char *src_ptr = static_cast<char*>(src.data_ptr());
35
- char *dst_ptr = static_cast<char*>(dst.data_ptr());
36
-
37
- const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
38
- const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
39
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
40
- // NOTE(woosuk): This can be slow if the number of blocks is large.
41
- for (const auto& pair : block_mapping) {
42
- int64_t src_block_number = pair.first;
43
- int64_t dst_block_number = pair.second;
44
- int64_t src_offset = src_block_number * block_size_in_bytes;
45
- int64_t dst_offset = dst_block_number * block_size_in_bytes;
46
- cudaMemcpyAsync(
47
- dst_ptr + dst_offset,
48
- src_ptr + src_offset,
49
- block_size_in_bytes,
50
- memcpy_type,
51
- stream);
52
- }
53
- }
54
-
55
- namespace vllm {
56
-
57
- // Grid: (num_layers, num_pairs)
58
- template<typename scalar_t>
59
- __global__ void copy_blocks_kernel(
60
- int64_t* key_cache_ptrs,
61
- int64_t* value_cache_ptrs,
62
- const int64_t* __restrict__ block_mapping,
63
- const int numel_per_block) {
64
- const int layer_idx = blockIdx.x;
65
- const int pair_idx = blockIdx.y;
66
-
67
- scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
68
- scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
69
- int64_t src_block_number = block_mapping[2 * pair_idx];
70
- int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
71
-
72
- const int64_t src_block_offset = src_block_number * numel_per_block;
73
- const int64_t dst_block_offset = dst_block_number * numel_per_block;
74
- for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
75
- int64_t src_offset = src_block_offset + i;
76
- int64_t dst_offset = dst_block_offset + i;
77
- key_cache[dst_offset] = key_cache[src_offset];
78
- }
79
- for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
80
- int64_t src_offset = src_block_offset + i;
81
- int64_t dst_offset = dst_block_offset + i;
82
- value_cache[dst_offset] = value_cache[src_offset];
83
- }
84
- }
85
-
86
- } // namespace vllm
87
-
88
- void copy_blocks(
89
- std::vector<torch::Tensor>& key_caches,
90
- std::vector<torch::Tensor>& value_caches,
91
- const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
92
- int num_layers = key_caches.size();
93
- TORCH_CHECK(num_layers == value_caches.size());
94
- if (num_layers == 0) {
95
- return;
96
- }
97
- torch::Device cache_device = key_caches[0].device();
98
- TORCH_CHECK(cache_device.is_cuda());
99
-
100
- // Create data structures for the kernel.
101
- // Create an array of pointers to the key and value caches.
102
- int64_t key_cache_ptrs[num_layers];
103
- int64_t value_cache_ptrs[num_layers];
104
- for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
105
- key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
106
- value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
107
- }
108
- // Create block mapping array.
109
- std::vector<int64_t> block_mapping_vec;
110
- for (const auto& pair : block_mapping) {
111
- int64_t src_block_number = pair.first;
112
- for (int64_t dst_block_number : pair.second) {
113
- block_mapping_vec.push_back(src_block_number);
114
- block_mapping_vec.push_back(dst_block_number);
115
- }
116
- }
117
- int64_t* block_mapping_array = block_mapping_vec.data();
118
- int num_pairs = block_mapping_vec.size() / 2;
119
-
120
- // Move the data structures to the GPU.
121
- // NOTE: This synchronizes the CPU and GPU.
122
- torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
123
- key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
124
- torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
125
- value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
126
- torch::Tensor block_mapping_tensor = torch::from_blob(
127
- block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
128
-
129
- // Launch the kernel.
130
- const int numel_per_block = key_caches[0][0].numel();
131
- dim3 grid(num_layers, num_pairs);
132
- dim3 block(std::min(1024, numel_per_block));
133
- const at::cuda::OptionalCUDAGuard device_guard(cache_device);
134
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
135
- VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
136
- key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
137
- vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
138
- key_cache_ptrs_tensor.data_ptr<int64_t>(),
139
- value_cache_ptrs_tensor.data_ptr<int64_t>(),
140
- block_mapping_tensor.data_ptr<int64_t>(),
141
- numel_per_block);
142
- }));
143
- }
144
-
145
- namespace vllm {
146
-
147
- template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
148
- __global__ void reshape_and_cache_kernel(
149
- const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
150
- const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
151
- cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
152
- cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
153
- const int64_t* __restrict__ slot_mapping, // [num_tokens]
154
- const int key_stride,
155
- const int value_stride,
156
- const int num_heads,
157
- const int head_size,
158
- const int block_size,
159
- const int x) {
160
- const int64_t token_idx = blockIdx.x;
161
- const int64_t slot_idx = slot_mapping[token_idx];
162
- if (slot_idx < 0) {
163
- // Padding token that should be ignored.
164
- return;
165
- }
166
-
167
- const int64_t block_idx = slot_idx / block_size;
168
- const int64_t block_offset = slot_idx % block_size;
169
-
170
- const int n = num_heads * head_size;
171
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
172
- const int64_t src_key_idx = token_idx * key_stride + i;
173
- const int64_t src_value_idx = token_idx * value_stride + i;
174
-
175
- const int head_idx = i / head_size;
176
- const int head_offset = i % head_size;
177
- const int x_idx = head_offset / x;
178
- const int x_offset = head_offset % x;
179
-
180
- const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
181
- + head_idx * (head_size / x) * block_size * x
182
- + x_idx * block_size * x
183
- + block_offset * x
184
- + x_offset;
185
- const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
186
- + head_idx * head_size * block_size
187
- + head_offset * block_size
188
- + block_offset;
189
- scalar_t tgt_key = key[src_key_idx];
190
- scalar_t tgt_value = value[src_value_idx];
191
- if constexpr (is_fp8_e5m2_kv_cache) {
192
- #ifdef ENABLE_FP8_E5M2
193
- key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
194
- value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
195
- #else
196
- assert(false);
197
- #endif
198
- } else {
199
- key_cache[tgt_key_idx] = tgt_key;
200
- value_cache[tgt_value_idx] = tgt_value;
201
- }
202
- }
203
- }
204
-
205
- } // namespace vllm
206
-
207
- #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
208
- vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
209
- reinterpret_cast<KV_T*>(key.data_ptr()), \
210
- reinterpret_cast<KV_T*>(value.data_ptr()), \
211
- reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
212
- reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
213
- slot_mapping.data_ptr<int64_t>(), \
214
- key_stride, \
215
- value_stride, \
216
- num_heads, \
217
- head_size, \
218
- block_size, \
219
- x);
220
-
221
- void reshape_and_cache(
222
- torch::Tensor& key, // [num_tokens, num_heads, head_size]
223
- torch::Tensor& value, // [num_tokens, num_heads, head_size]
224
- torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
225
- torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
226
- torch::Tensor& slot_mapping, // [num_tokens]
227
- const std::string& kv_cache_dtype)
228
- {
229
- int num_tokens = key.size(0);
230
- int num_heads = key.size(1);
231
- int head_size = key.size(2);
232
- int block_size = key_cache.size(3);
233
- int x = key_cache.size(4);
234
-
235
- int key_stride = key.stride(0);
236
- int value_stride = value.stride(0);
237
-
238
- dim3 grid(num_tokens);
239
- dim3 block(std::min(num_heads * head_size, 512));
240
- const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
241
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
242
- if (kv_cache_dtype == "auto") {
243
- if (key.dtype() == at::ScalarType::Float) {
244
- CALL_RESHAPE_AND_CACHE(float, float, false);
245
- } else if (key.dtype() == at::ScalarType::Half) {
246
- CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
247
- } else if (key.dtype() == at::ScalarType::BFloat16) {
248
- CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
249
- }
250
- } else if (kv_cache_dtype == "fp8_e5m2") {
251
- if (key.dtype() == at::ScalarType::Float) {
252
- CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
253
- } else if (key.dtype() == at::ScalarType::Half) {
254
- CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
255
- } else if (key.dtype() == at::ScalarType::BFloat16) {
256
- CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
257
- }
258
- } else {
259
- TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
260
- }
261
- }
262
-
263
- namespace vllm {
264
-
265
- // Grid: (num_blocks, block_size).
266
- template<typename scalar_t>
267
- __global__ void gather_cached_kv_kernel(
268
- scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size]
269
- scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size]
270
- const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
271
- const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
272
- const int* __restrict__ slot_mapping, // [num_tokens]
273
- const int key_stride,
274
- const int value_stride,
275
- const int num_heads,
276
- const int head_size,
277
- const int block_size,
278
- const int x) {
279
- const int token_idx = blockIdx.x;
280
- const int slot_idx = slot_mapping[token_idx];
281
- const int block_idx = slot_idx / block_size;
282
- const int block_offset = slot_idx % block_size;
283
-
284
- const int num_tokens = num_heads * head_size;
285
- for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
286
- const int tgt_key_idx = token_idx * key_stride + i;
287
- const int tgt_value_idx = token_idx * value_stride + i;
288
-
289
- const int head_idx = i / head_size;
290
- const int head_offset = i % head_size;
291
- const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
292
- const int x_offset = head_offset % x;
293
-
294
- const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
295
- + head_idx * (head_size / x) * block_size * x
296
- + x_idx * block_size * x
297
- + block_offset * x
298
- + x_offset;
299
- const int src_value_idx = block_idx * num_heads * head_size * block_size
300
- + head_idx * head_size * block_size
301
- + head_offset * block_size
302
- + block_offset;
303
-
304
- key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
305
- value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
306
- }
307
- }
308
-
309
- template <typename scalar_t>
310
- __global__ void gather_cached_kv_kernel_optimized(
311
- scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size]
312
- scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size]
313
- const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
314
- const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
315
- const int *__restrict__ slot_mapping, // [num_tokens]
316
- const int key_stride,
317
- const int value_stride,
318
- const int num_heads,
319
- const int head_size,
320
- const int block_size,
321
- const int x)
322
- {
323
- const int token_idx = blockIdx.x;
324
- const int slot_idx = slot_mapping[token_idx];
325
- const int block_idx = slot_idx / block_size;
326
- const int block_offset = slot_idx % block_size;
327
-
328
- const int dim = num_heads * head_size;
329
- assert(dim % 4 == 0); // this is true for known use cases
330
- const int unroll_factor = 4;
331
- const int unrolled_dim = dim / unroll_factor;
332
-
333
- for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
334
- {
335
- int tgt_key_indices[unroll_factor];
336
- int tgt_value_indices[unroll_factor];
337
- int src_key_indices[unroll_factor];
338
- int src_value_indices[unroll_factor];
339
- scalar_t keys_to_store[unroll_factor];
340
- scalar_t values_to_store[unroll_factor];
341
-
342
- #pragma unroll
343
- for (int j = 0; j < unroll_factor; ++j)
344
- {
345
- int index = i + j * unrolled_dim;
346
-
347
- const int tgt_key_idx = token_idx * key_stride + index;
348
- const int tgt_value_idx = token_idx * value_stride + index;
349
-
350
- const int head_idx = index / head_size;
351
- const int head_offset = index % head_size;
352
- const int x_idx = head_offset / x;
353
- const int x_offset = head_offset % x;
354
-
355
- const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
356
- + head_idx * (head_size / x) * block_size * x
357
- + x_idx * block_size * x
358
- + block_offset * x
359
- + x_offset;
360
- const int src_value_idx = block_idx * num_heads * head_size * block_size
361
- + head_idx * head_size * block_size
362
- + head_offset * block_size
363
- + block_offset;
364
-
365
- tgt_key_indices[j] = tgt_key_idx;
366
- tgt_value_indices[j] = tgt_value_idx;
367
- src_key_indices[j] = src_key_idx;
368
- src_value_indices[j] = src_value_idx;
369
-
370
- keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
371
- values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
372
- }
373
-
374
- #pragma unroll
375
- for (int j = 0; j < unroll_factor; ++j)
376
- {
377
- key[tgt_key_indices[j]] = keys_to_store[j];
378
- value[tgt_value_indices[j]] = values_to_store[j];
379
- }
380
- }
381
- }
382
-
383
- } // namespace vllm
384
-
385
- void gather_cached_kv(
386
- torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
387
- torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
388
- torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x]
389
- torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size]
390
- torch::Tensor& slot_mapping) // [in] [num_tokens]
391
- {
392
- int num_tokens = key.size(0);
393
- int num_heads = key.size(1);
394
- int head_size = key.size(2);
395
- int block_size = key_cache.size(3);
396
- int x = key_cache.size(4);
397
-
398
- int key_stride = key.stride(0);
399
- int value_stride = value.stride(0);
400
-
401
- dim3 grid(num_tokens);
402
- dim3 block(std::min(num_heads * head_size, 512));
403
- const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
404
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
405
- VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
406
- key.scalar_type(),
407
- "gather_cached_kv_kernel_optimized",
408
- [&] {
409
- vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
410
- key.data_ptr<scalar_t>(),
411
- value.data_ptr<scalar_t>(),
412
- key_cache.data_ptr<scalar_t>(),
413
- value_cache.data_ptr<scalar_t>(),
414
- slot_mapping.data_ptr<int>(),
415
- key_stride,
416
- value_stride,
417
- num_heads,
418
- head_size,
419
- block_size,
420
- x);
421
- });
422
- }
423
-
424
- namespace vllm {
425
-
426
- template<typename Tout, typename Tin>
427
- __global__ void convert_fp8_e5m2_kernel(
428
- const Tin* __restrict__ src_cache,
429
- Tout* __restrict__ dst_cache,
430
- const int64_t block_stride) {
431
- const int64_t block_idx = blockIdx.x;
432
- for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
433
- int64_t idx = block_idx * block_stride + i;
434
- #ifdef ENABLE_FP8_E5M2
435
- dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
436
- #else
437
- assert(false);
438
- #endif
439
- }
440
- }
441
-
442
- } // namespace vllm
443
-
444
- #define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
445
- vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
446
- reinterpret_cast<Tin*>(src_cache.data_ptr()), \
447
- reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
448
- block_stride);
449
-
450
- void convert_fp8_e5m2(
451
- torch::Tensor& src_cache,
452
- torch::Tensor& dst_cache)
453
- {
454
- int64_t num_blocks = src_cache.size(0);
455
- int64_t block_stride = src_cache.stride(0);
456
-
457
- dim3 grid(num_blocks);
458
- dim3 block(std::min(block_stride, int64_t(512)));
459
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
460
-
461
- if (src_cache.dtype() == at::ScalarType::Float) {
462
- CALL_CONVERT_FP8_E5M2(uint8_t, float);
463
- } else if (src_cache.dtype() == at::ScalarType::Half) {
464
- CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
465
- } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
466
- CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
467
- } else if (dst_cache.dtype() == at::ScalarType::Float) {
468
- CALL_CONVERT_FP8_E5M2(float, uint8_t);
469
- } else if (dst_cache.dtype() == at::ScalarType::Half) {
470
- CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
471
- } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
472
- CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
473
- }
474
- }
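Note: reshape_and_cache_kernel above scatters each token into a paged key cache laid out as [num_blocks, num_heads, head_size/x, block_size, x]. The host-side sketch below reproduces only that index arithmetic so it can be checked in isolation; key_cache_index and the numbers in main() are hypothetical.

```cpp
#include <cstdint>
#include <cstdio>

// Flat offset of one key element inside a cache shaped
// [num_blocks, num_heads, head_size/x, block_size, x].
int64_t key_cache_index(int64_t slot_idx, int num_heads, int head_size,
                        int block_size, int x, int head_idx, int head_offset) {
  const int64_t block_idx    = slot_idx / block_size;   // which page
  const int64_t block_offset = slot_idx % block_size;   // position inside the page
  const int x_idx    = head_offset / x;                 // which x-wide chunk of the head dim
  const int x_offset = head_offset % x;                 // position inside that chunk
  return block_idx * num_heads * (head_size / x) * block_size * x
       + head_idx * (head_size / x) * block_size * x
       + x_idx * block_size * x
       + block_offset * x
       + x_offset;
}

int main() {
  // Example: 32 heads, head_size 128, block_size 16, x = 8,
  // token in slot 37, head 2, element 11 of that head.
  std::printf("offset = %lld\n",
              (long long)key_cache_index(37, 32, 128, 16, 8, 2, 11));
  return 0;
}
```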
csrc/cuda_compat.h DELETED
@@ -1,28 +0,0 @@
- #pragma once
-
- #ifndef USE_ROCM
- #define VLLM_LDG(arg) __ldg(arg)
- #else
- #define VLLM_LDG(arg) *(arg)
- #endif
-
- #ifndef USE_ROCM
- #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
- #else
- #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
- #endif
-
- #ifndef USE_ROCM
- #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
- #else
- #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
- #endif
-
- #ifndef USE_ROCM
- #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
- cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
- #else
- #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
- hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
- #endif
-
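Note: macros like VLLM_SHFL_XOR_SYNC above exist so the same warp-level code compiles on CUDA and ROCm. The kernel below is a hypothetical, self-contained butterfly warp reduction showing the intended usage; the macro is repeated here only so the snippet builds on its own, and warp_sum is not part of the deleted sources.

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif

__global__ void warp_sum(const float* in, float* out) {
  float v = in[threadIdx.x];
  // Butterfly reduction across the 32 lanes of one warp.
  for (int mask = 16; mask > 0; mask >>= 1) {
    v += VLLM_SHFL_XOR_SYNC(v, mask);
  }
  if (threadIdx.x == 0) *out = v;
}

int main() {
  float h_in[32], h_out = 0.f, *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_sum<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("warp sum = %.1f (expected 32.0)\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```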
csrc/cuda_utils.h DELETED
@@ -1,10 +0,0 @@
- #pragma once
-
- #include <torch/extension.h>
-
- int get_device_attribute(
- int attribute,
- int device_id);
-
- int get_max_shared_memory_per_block_device_attribute(
- int device_id);
csrc/cuda_utils_kernels.cu DELETED
@@ -1,35 +0,0 @@
- #ifdef USE_ROCM
- #include <hip/hip_runtime.h>
- #include <hip/hip_runtime_api.h>
- #endif
- int get_device_attribute(
- int attribute,
- int device_id)
- {
- int device, value;
- if (device_id < 0) {
- cudaGetDevice(&device);
- }
- else {
- device = device_id;
- }
- cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
- return value;
- }
-
-
- int get_max_shared_memory_per_block_device_attribute(
- int device_id)
- {
- int attribute;
- // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
- // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
-
- #ifdef USE_ROCM
- attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
- #else
- attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
- #endif
-
- return get_device_attribute(attribute, device_id);
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
csrc/custom_all_reduce.cu DELETED
@@ -1,148 +0,0 @@
1
- #include <ATen/cuda/Exceptions.h>
2
- #include <c10/cuda/CUDAGuard.h>
3
- #include <c10/cuda/CUDAStream.h>
4
- #include <torch/extension.h>
5
-
6
- #include "custom_all_reduce.cuh"
7
-
8
- // fake pointer type
9
- using fptr_t = uint64_t;
10
- static_assert(sizeof(void *) == sizeof(fptr_t));
11
-
12
- fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
13
- const std::vector<std::string> &handles,
14
- const std::vector<int64_t> &offsets, int rank,
15
- bool full_nvlink) {
16
- int world_size = offsets.size();
17
- if (world_size > 8)
18
- throw std::invalid_argument("world size > 8 is not supported");
19
- if (world_size % 2 != 0)
20
- throw std::invalid_argument("Odd num gpus is not supported for now");
21
- if (world_size != handles.size())
22
- throw std::invalid_argument(
23
- "handles length should equal to offsets length");
24
- if (rank < 0 || rank >= world_size)
25
- throw std::invalid_argument("invalid rank passed in");
26
-
27
- cudaIpcMemHandle_t ipc_handles[8];
28
- for (int i = 0; i < world_size; i++) {
29
- std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
30
- }
31
- return (fptr_t) new vllm::CustomAllreduce(
32
- reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
33
- rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
34
- }
35
-
36
- /**
37
- * Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
38
- * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
39
- * because it allows transpose of contiguous slice (i.e. slicing the first
40
- * dimension). Currently, we require this because stride information is not
41
- * passed into the kernels and we treat input tensors as flat.
42
- *
43
- * Examples
44
- * A = torch.zeros(3, 3, 3)
45
- * 1. A: OK
46
- * 2. A[1:]: OK
47
- * 3. A.permute(2, 0, 1): OK
48
- * 4. A[1:].permute(2, 0, 1): OK
49
- * 5. A[None].expand(2, -1, -1, -1): Not OK
50
- * 6. A[:, 1:, 1:]: Not OK
51
- */
52
- bool _is_weak_contiguous(torch::Tensor &t) {
53
- return t.is_contiguous() ||
54
- (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
55
- t.numel() * t.element_size());
56
- }
57
-
58
- bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
59
- bool full_nvlink) {
60
- auto inp_size = inp.numel() * inp.element_size();
61
- // custom allreduce requires input byte size to be multiples of 16
62
- if (inp_size % 16 != 0) return false;
63
- if (!_is_weak_contiguous(inp)) return false;
64
- if (world_size == 2 || full_nvlink) return inp_size <= max_size;
65
- // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
66
- // <= 512k
67
- return world_size <= 4 && inp_size <= 512 * 1024;
68
- }
69
-
70
- void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
71
- cudaStream_t stream) {
72
- auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
73
- TORCH_CHECK(_is_weak_contiguous(out));
74
- switch (out.scalar_type()) {
75
- case at::ScalarType::Float: {
76
- fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
77
- reinterpret_cast<float *>(out.data_ptr()),
78
- out.numel());
79
- break;
80
- }
81
- case at::ScalarType::Half: {
82
- fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
83
- reinterpret_cast<half *>(out.data_ptr()),
84
- out.numel());
85
- break;
86
- }
87
- #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
88
- case at::ScalarType::BFloat16: {
89
- fa->allreduce<nv_bfloat16>(
90
- stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
91
- reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
92
- break;
93
- }
94
- #endif
95
- default:
96
- throw std::runtime_error(
97
- "custom allreduce only supports float32, float16 and bfloat16");
98
- }
99
- }
100
-
101
- void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
102
- const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
103
- auto stream = c10::cuda::getCurrentCUDAStream().stream();
104
- TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
105
- TORCH_CHECK_EQ(inp.numel(), out.numel());
106
- _all_reduce(_fa, inp, out, stream);
107
- }
108
-
109
- void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
110
- torch::Tensor &out) {
111
- const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
112
- auto stream = c10::cuda::getCurrentCUDAStream().stream();
113
-
114
- auto input_size = inp.numel() * inp.element_size();
115
- TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
116
- TORCH_CHECK_EQ(inp.numel(), out.numel());
117
- TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
118
- "registered buffer is too small to contain the input");
119
- AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
120
- input_size, cudaMemcpyDeviceToDevice, stream));
121
- _all_reduce(_fa, reg_buffer, out, stream);
122
- }
123
-
124
- void dispose(fptr_t _fa) {
125
- auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
126
- delete fa;
127
- }
128
-
129
- int meta_size() { return sizeof(vllm::Metadata); }
130
-
131
- void register_buffer(fptr_t _fa, torch::Tensor &t,
132
- const std::vector<std::string> &handles,
133
- const std::vector<int64_t> &offsets) {
134
- auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
135
- fa->register_buffer(handles, offsets, t.data_ptr());
136
- }
137
-
138
- std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
139
- fptr_t _fa) {
140
- auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
141
- return fa->get_graph_buffer_ipc_meta();
142
- }
143
-
144
- void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
145
- const std::vector<std::vector<int64_t>> &offsets) {
146
- auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
147
- fa->register_graph_buffers(handles, offsets);
148
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
csrc/custom_all_reduce.cuh DELETED
@@ -1,562 +0,0 @@
1
- #pragma once
2
-
3
- #include <cuda.h>
4
- #include <cuda_bf16.h>
5
- #include <cuda_fp16.h>
6
- #include <cuda_runtime.h>
7
-
8
- #include <iostream>
9
- #include <limits>
10
- #include <map>
11
- #include <unordered_map>
12
- #include <vector>
13
-
14
- #define CUDACHECK(cmd) \
15
- do { \
16
- cudaError_t e = cmd; \
17
- if (e != cudaSuccess) { \
18
- printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
19
- cudaGetErrorString(e)); \
20
- exit(EXIT_FAILURE); \
21
- } \
22
- } while (0)
23
-
24
- namespace vllm {
25
-
26
- struct Signal {
27
- alignas(64) union {
28
- uint64_t flag;
29
- unsigned char data[8];
30
- } start;
31
- alignas(64) union {
32
- uint64_t flag;
33
- unsigned char data[8];
34
- } end;
35
- };
36
-
37
- struct Metadata {
38
- alignas(128) Signal sg;
39
- alignas(128) int counter;
40
- };
41
- static_assert(offsetof(Metadata, counter) == 128);
42
- static_assert(sizeof(Metadata) == 256);
43
-
44
- struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
45
-
46
- struct RankSignals {
47
- volatile Signal *signals[8];
48
- };
49
-
50
- // like std::array, but aligned
51
- template <typename T, int sz>
52
- struct __align__(alignof(T) * sz) array_t {
53
- T data[sz];
54
- using type = T;
55
- static constexpr int size = sz;
56
- };
57
-
58
- // use packed type to maximize memory efficiency
59
- // goal: generate ld.128 and st.128 instructions
60
- template <typename T>
61
- struct packed_t {
62
- // the (P)acked type for load/store
63
- using P = array_t<T, 16 / sizeof(T)>;
64
- // the (A)ccumulator type for reduction
65
- using A = array_t<float, 16 / sizeof(T)>;
66
- };
67
-
68
- #define DINLINE __device__ __forceinline__
69
-
70
- // scalar cast functions
71
- DINLINE float upcast_s(half val) { return __half2float(val); }
72
-
73
- template <typename T>
74
- DINLINE T downcast_s(float val);
75
- template <>
76
- DINLINE half downcast_s(float val) {
77
- return __float2half(val);
78
- }
79
-
80
- // scalar add functions
81
- // for some reason when compiling with Pytorch, the + operator for half and
82
- // bfloat is disabled so we call the intrinsics directly
83
- DINLINE half &assign_add(half &a, half b) {
84
- a = __hadd(a, b);
85
- return a;
86
- }
87
- DINLINE float &assign_add(float &a, float b) { return a += b; }
88
-
89
- #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
90
- DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
91
- template <>
92
- DINLINE nv_bfloat16 downcast_s(float val) {
93
- return __float2bfloat16(val);
94
- }
95
- DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
96
- a = __hadd(a, b);
97
- return a;
98
- }
99
- #endif
100
-
101
- template <typename T, int N>
102
- DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
103
- #pragma unroll
104
- for (int i = 0; i < N; i++) {
105
- assign_add(a.data[i], b.data[i]);
106
- }
107
- return a;
108
- }
109
-
110
- template <typename T, int N>
111
- DINLINE array_t<float, N> upcast(array_t<T, N> val) {
112
- if constexpr (std::is_same<T, float>::value) {
113
- return val;
114
- } else {
115
- array_t<float, N> out;
116
- #pragma unroll
117
- for (int i = 0; i < N; i++) {
118
- out.data[i] = upcast_s(val.data[i]);
119
- }
120
- return out;
121
- }
122
- }
123
-
124
- template <typename O>
125
- DINLINE O downcast(array_t<float, O::size> val) {
126
- if constexpr (std::is_same<typename O::type, float>::value) {
127
- return val;
128
- } else {
129
- O out;
130
- #pragma unroll
131
- for (int i = 0; i < O::size; i++) {
132
- out.data[i] = downcast_s<typename O::type>(val.data[i]);
133
- }
134
- return out;
135
- }
136
- }
137
-
138
- // compute flag at compile time
139
- __host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
140
- auto m = std::numeric_limits<uint64_t>::max();
141
- return m >> ((8 - ngpus) * 8);
142
- }
143
-
144
- template <int ngpus>
145
- DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
146
- int rank) {
147
- constexpr auto FLAG = compute_flag(ngpus);
148
- if (blockIdx.x == 0) {
149
- if (threadIdx.x < ngpus)
150
- // simultaneously write to the corresponding byte to all other ranks.
151
- // Latency = 1 p2p write
152
- sg.signals[threadIdx.x]->start.data[rank] = 255;
153
- else if (threadIdx.x == 32)
154
- // reset
155
- meta->sg.end.flag = 0;
156
- }
157
- if (threadIdx.x == 0) {
158
- while (meta->sg.start.flag != FLAG)
159
- ;
160
- }
161
- __syncthreads();
162
- }
163
-
164
- template <int ngpus, bool final_sync = false>
165
- DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
166
- int rank) {
167
- constexpr auto FLAG = compute_flag(ngpus);
168
- __syncthreads();
169
- __shared__ int num;
170
- if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
171
- __syncthreads();
172
-
173
- // Only the last completing block can perform the end synchronization
174
- // This can ensures when the final busy wait ends, all ranks must have
175
- // finished reading each other's buffer.
176
- if (num == gridDim.x - 1) {
177
- if (threadIdx.x == 32) {
178
- // reset in a different warp
179
- meta->counter = 0;
180
- meta->sg.start.flag = 0;
181
- } else if (threadIdx.x < ngpus) {
182
- // simultaneously write to the corresponding byte to all other ranks.
183
- // Latency = 1 p2p write
184
- sg.signals[threadIdx.x]->end.data[rank] = 255;
185
- }
186
- // if this is the final sync, only one block needs it
187
- // because kernel exit can serve as sync
188
- if constexpr (final_sync) {
189
- if (threadIdx.x == 0) {
190
- while (meta->sg.end.flag != FLAG)
191
- ;
192
- }
193
- }
194
- }
195
- if constexpr (!final_sync) {
196
- if (threadIdx.x == 0) {
197
- while (meta->sg.end.flag != FLAG)
198
- ;
199
- }
200
- __syncthreads();
201
- }
202
- }
203
-
204
- template <typename P, int ngpus, typename A>
205
- DINLINE P packed_reduce(const P *ptrs[], int idx) {
206
- A tmp = upcast(ptrs[0][idx]);
207
- #pragma unroll
208
- for (int i = 1; i < ngpus; i++) {
209
- packed_assign_add(tmp, upcast(ptrs[i][idx]));
210
- }
211
- return downcast<P>(tmp);
212
- }
213
-
214
- template <typename T, int ngpus>
215
- __global__ void __launch_bounds__(512, 1)
216
- cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
217
- volatile Metadata *meta, T *__restrict__ result,
218
- int rank, int size) {
219
- using P = typename packed_t<T>::P;
220
- using A = typename packed_t<T>::A;
221
- // note: we don't reorder the address so the accumulation order is the same
222
- // for all ranks, ensuring bitwise identical results
223
- auto dp = *_dp;
224
- start_sync<ngpus>(sg, meta, rank);
225
- // do the actual reduction
226
- for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
227
- idx += gridDim.x * blockDim.x) {
228
- ((P *)result)[idx] =
229
- packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
230
- }
231
- end_sync<ngpus, true>(sg, meta, rank);
232
- }
233
-
234
- template <typename P>
235
- DINLINE P *get_tmp_buf(volatile Signal *sg) {
236
- return (P *)(((Metadata *)sg) + 1);
237
- }
238
-
239
- template <typename T, int ngpus>
240
- __global__ void __launch_bounds__(512, 1)
241
- cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
242
- volatile Metadata *meta, T *__restrict__ result,
243
- int rank, int size) {
244
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
245
- int stride = gridDim.x * blockDim.x;
246
- using P = typename packed_t<T>::P;
247
- using A = typename packed_t<T>::A;
248
- int part = size / ngpus;
249
- int start = rank * part;
250
- int end = rank == ngpus - 1 ? size : start + part;
251
- const P *ptrs[ngpus];
252
- P *tmps[ngpus];
253
- #pragma unroll
254
- for (int i = 0; i < ngpus; i++) {
255
- int target = (rank + i) % ngpus;
256
- ptrs[i] = (const P *)_dp->ptrs[target];
257
- tmps[i] = get_tmp_buf<P>(sg.signals[target]);
258
- }
259
- auto tmp_out = tmps[0];
260
- start_sync<ngpus>(sg, meta, rank);
261
- // stage 1: reduce scatter
262
- for (int idx = start + tid; idx < end; idx += stride) {
263
- tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
264
- }
265
- // Maybe TODO: replace this with per-block release-acquire
266
- // can save about 1-2us (not a lot though)
267
- end_sync<ngpus>(sg, meta, rank);
268
-
269
- // stage 2: allgather
270
- for (int idx = tid; idx < part; idx += stride) {
271
- #pragma unroll
272
- for (int i = 0; i < ngpus; i++) {
273
- int dst_idx = ((rank + i) % ngpus) * part + idx;
274
- ((P *)result)[dst_idx] = tmps[i][idx];
275
- }
276
- }
277
- // process the last larger partition
278
- int remaining = size - part * ngpus;
279
- if (tid < remaining) {
280
- int dst_idx = tid + part * ngpus;
281
- ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
282
- }
283
-
284
- // faster than this
285
- // for (int idx = tid; idx < size; idx += stride) {
286
- // int target_rank = idx / part;
287
- // if (target_rank == ngpus) target_rank -= 1;
288
- // ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
289
- // }
290
- }
291
-
292
- template <typename T, int ngpus>
293
- __global__ void __launch_bounds__(512, 1)
294
- cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
295
- volatile Metadata *meta,
296
- T *__restrict__ result, int rank,
297
- int size) {
298
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
299
- int stride = gridDim.x * blockDim.x;
300
- using P = typename packed_t<T>::P;
301
- using A = typename packed_t<T>::A;
302
- auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
303
- constexpr int hg = ngpus / 2;
304
- // Actually not quite half butterfly.
305
- // This is an all-to-all within each group containing half of the ranks
306
- // followed by cross-group add. Equivalent to half butterfly when there
307
- // are 4 GPUs, a common case for PCIe cards like T4 and A10.
308
- const P *ptrs[hg];
309
- {
310
- int start = rank - rank % hg;
311
- #pragma unroll
312
- for (int i = 0; i < hg; i++) {
313
- ptrs[i] = (const P *)_dp->ptrs[i + start];
314
- }
315
- }
316
- start_sync<ngpus>(sg, meta, rank);
317
- for (int idx = tid; idx < size; idx += stride) {
318
- tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
319
- }
320
- end_sync<ngpus>(sg, meta, rank);
321
-
322
- auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
323
- // do the cross group reduction
324
- for (int idx = tid; idx < size; idx += stride) {
325
- auto tmp = tmp_out[idx];
326
- packed_assign_add(tmp, src[idx]);
327
- ((P *)result)[idx] = tmp;
328
- }
329
- }
330
-
331
- using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
332
- static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
333
- static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
334
-
335
- class CustomAllreduce {
336
- public:
337
- int rank_;
338
- int world_size_;
339
- bool full_nvlink_;
340
-
341
- // below are device pointers
342
- RankSignals sg_;
343
- std::unordered_map<void *, RankData *> buffers_;
344
- Metadata *meta_;
345
-
346
- // stores the registered device pointers from all ranks
347
- RankData *d_rank_data_base_, *d_rank_data_end_;
348
- std::vector<void *> graph_unreg_buffers_;
349
- // a map from IPC handles to opened IPC pointers
350
- std::map<IPC_KEY, char *> ipc_handles_;
351
-
352
- /**
353
- * meta is a pointer to device metadata and temporary buffer for allreduce.
354
- *
355
- * There's a total of sizeof(Metadata) of prefix before the actual data,
356
- * so meta + 1 points to actual temporary buffer.
357
- *
358
- * note: this class does not own any device memory. Any required buffers
359
- * are passed in from the constructor
360
- */
361
- CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
362
- const cudaIpcMemHandle_t *handles,
363
- const std::vector<int64_t> &offsets, int rank,
364
- bool full_nvlink = true)
365
- : rank_(rank),
366
- world_size_(offsets.size()),
367
- full_nvlink_(full_nvlink),
368
- meta_(meta),
369
- d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
370
- d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
371
- for (int i = 0; i < world_size_; i++) {
372
- Metadata *rank_meta;
373
- if (i != rank_) {
374
- char *handle = open_ipc_handle(&handles[i]);
375
- handle += offsets[i];
376
- rank_meta = (Metadata *)handle;
377
- } else {
378
- rank_meta = meta_;
379
- }
380
- sg_.signals[i] = &rank_meta->sg;
381
- }
382
- }
383
-
384
- char *open_ipc_handle(const void *ipc_handle) {
385
- auto [it, new_handle] =
386
- ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
387
- if (new_handle) {
388
- char *ipc_ptr;
389
- CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
390
- *((const cudaIpcMemHandle_t *)ipc_handle),
391
- cudaIpcMemLazyEnablePeerAccess));
392
- it->second = ipc_ptr;
393
- }
394
- return it->second;
395
- }
396
-
397
- std::pair<std::vector<uint8_t>, std::vector<int64_t>>
398
- get_graph_buffer_ipc_meta() {
399
- auto num_buffers = graph_unreg_buffers_.size();
400
- auto handle_sz = sizeof(cudaIpcMemHandle_t);
401
- std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
402
- std::vector<int64_t> offsets(num_buffers);
403
- for (int i = 0; i < num_buffers; i++) {
404
- auto ptr = graph_unreg_buffers_[i];
405
- void *base_ptr;
406
- // note: must share the base address of each allocation, or we get wrong
407
- // address
408
- if (cuPointerGetAttribute(&base_ptr,
409
- CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
410
- (CUdeviceptr)ptr) != CUDA_SUCCESS)
411
- throw std::runtime_error("failed to get pointer attr");
412
- CUDACHECK(cudaIpcGetMemHandle(
413
- (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
414
- offsets[i] = ((char *)ptr) - ((char *)base_ptr);
415
- }
416
- return std::make_pair(handles, offsets);
417
- }
418
-
419
- void check_rank_data_capacity(size_t num = 1) {
420
- if (d_rank_data_base_ + num > d_rank_data_end_)
421
- throw std::runtime_error(
422
- "Rank data buffer is overflowed by " +
423
- std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
424
- }
425
-
426
- void register_buffer(const std::vector<std::string> &handles,
427
- const std::vector<int64_t> &offsets, void *self) {
428
- check_rank_data_capacity();
429
- RankData data;
430
- for (int i = 0; i < world_size_; i++) {
431
- if (i != rank_) {
432
- char *handle = open_ipc_handle(handles[i].data());
433
- handle += offsets[i];
434
- data.ptrs[i] = handle;
435
- } else {
436
- data.ptrs[i] = self;
437
- }
438
- }
439
- auto d_data = d_rank_data_base_++;
440
- CUDACHECK(
441
- cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
442
- buffers_[self] = d_data;
443
- }
444
-
445
- // note: when registering graph buffers, we intentionally choose to not
446
- // deduplicate the addresses. That means if the allocator reuses some
447
- // addresses, they will be registered again. This is to account for the remote
448
- // possibility of different allocation patterns between ranks. For example,
449
- // rank 1 may get the same input address for the second allreduce, but rank 2
450
- // got a different address. IPC handles have internal reference counting
451
- // mechanism so overhead should be small.
452
- void register_graph_buffers(
453
- const std::vector<std::string> &handles,
454
- const std::vector<std::vector<int64_t>> &offsets) {
455
- auto num_buffers = graph_unreg_buffers_.size();
456
- check_rank_data_capacity(num_buffers);
457
- std::vector<RankData> rank_data(num_buffers);
458
- for (int i = 0; i < num_buffers; i++) {
459
- auto self_ptr = graph_unreg_buffers_[i];
460
- auto &rd = rank_data[i];
461
- for (int j = 0; j < world_size_; j++) {
462
- if (j != rank_) {
463
- char *handle =
464
- open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
465
- handle += offsets[j][i];
466
- rd.ptrs[j] = handle;
467
- } else {
468
- rd.ptrs[j] = self_ptr;
469
- }
470
- }
471
- }
472
- CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
473
- sizeof(RankData) * num_buffers,
474
- cudaMemcpyHostToDevice));
475
- d_rank_data_base_ += num_buffers;
476
- graph_unreg_buffers_.clear();
477
- }
478
-
479
- /**
480
- * This is the result after careful grid search. Using 36 blocks give the best
481
- * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
482
- * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
483
- * Not quite sure the underlying reason, but my guess is that too many SMs
484
- * will cause contention on NVLink bus.
485
- */
486
- template <typename T>
487
- void allreduce(cudaStream_t stream, T *input, T *output, int size,
488
- int threads = 512, int block_limit = 36) {
489
- auto d = packed_t<T>::P::size;
490
- if (size % d != 0)
491
- throw std::runtime_error(
492
- "custom allreduce currently requires input length to be multiple "
493
- "of " +
494
- std::to_string(d));
495
-
496
- RankData *ptrs;
497
- cudaStreamCaptureStatus status;
498
- CUDACHECK(cudaStreamIsCapturing(stream, &status));
499
- if (status == cudaStreamCaptureStatusActive) {
500
- ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
501
- graph_unreg_buffers_.push_back(input);
502
- } else {
503
- auto it = buffers_.find(input);
504
- if (it == buffers_.end())
505
- throw std::runtime_error(
506
- "buffer address " +
507
- std::to_string(reinterpret_cast<uint64_t>(input)) +
508
- " is not registered!");
509
- ptrs = it->second;
510
- }
511
-
512
- size /= d;
513
- auto bytes = size * sizeof(typename packed_t<T>::P);
514
- int blocks = std::min(block_limit, (size + threads - 1) / threads);
515
- #define KL(ngpus, name) \
516
- name<T, ngpus> \
517
- <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
518
- #define REDUCE_CASE(ngpus) \
519
- case ngpus: { \
520
- if (world_size_ == 2) { \
521
- KL(ngpus, cross_device_reduce_1stage); \
522
- } else if (full_nvlink_) { \
523
- if ((world_size_ <= 4 && bytes < 512 * 1024) || \
524
- (world_size_ <= 8 && bytes < 256 * 1024)) { \
525
- KL(ngpus, cross_device_reduce_1stage); \
526
- } else { \
527
- KL(ngpus, cross_device_reduce_2stage); \
528
- } \
529
- } else { \
530
- KL(ngpus, cross_device_reduce_half_butterfly); \
531
- } \
532
- break; \
533
- }
534
-
535
- switch (world_size_) {
536
- REDUCE_CASE(2)
537
- REDUCE_CASE(4)
538
- REDUCE_CASE(6)
539
- REDUCE_CASE(8)
540
- default:
541
- throw std::runtime_error(
542
- "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
543
- "gpus = " +
544
- std::to_string(world_size_));
545
- }
546
- #undef REDUCE_CASE
547
- #undef KL
548
- }
549
-
550
- ~CustomAllreduce() {
551
- for (auto [_, ptr] : ipc_handles_) {
552
- CUDACHECK(cudaIpcCloseMemHandle(ptr));
553
- }
554
- }
555
- };
556
- /**
557
- * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
558
- a template instantiation:
559
- * template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
560
- int, int, int);
561
- */
562
- } // namespace vllm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
csrc/custom_all_reduce_test.cu DELETED
@@ -1,284 +0,0 @@
1
- /**
2
- * This is a standalone test for custom allreduce.
3
- * To compile, make sure you have MPI and NCCL installed in your system.
4
- * export MPI_HOME=XXX
5
- * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
6
- * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
7
- *
8
- * Warning: this C++ test is not designed to be very readable and was used
9
- * during the rapid prototyping process.
10
- *
11
- * To run:
12
- * mpirun -np 8 ./custom_all_reduce_test
13
- */
14
- #include <cuda.h>
15
- #include <curand_kernel.h>
16
- #include <stdio.h>
17
- #include <stdlib.h>
18
-
19
- #include <limits>
20
- #include <vector>
21
-
22
- #include "cuda_profiler_api.h"
23
- #include "custom_all_reduce.cuh"
24
- #include "mpi.h"
25
- #include "nccl.h"
26
-
27
- #define MPICHECK(cmd) \
28
- do { \
29
- int e = cmd; \
30
- if (e != MPI_SUCCESS) { \
31
- printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
32
- exit(EXIT_FAILURE); \
33
- } \
34
- } while (0)
35
-
36
- #define NCCLCHECK(cmd) \
37
- do { \
38
- ncclResult_t r = cmd; \
39
- if (r != ncclSuccess) { \
40
- printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
41
- ncclGetErrorString(r)); \
42
- exit(EXIT_FAILURE); \
43
- } \
44
- } while (0)
45
-
46
- __global__ void dummy_kernel() {
47
- for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
48
- }
49
-
50
- template <typename T>
51
- __global__ void set_data(T *data, int size, int myRank) {
52
- for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
53
- idx += gridDim.x * blockDim.x) {
54
- data[idx] = myRank * 0.11f;
55
- }
56
- }
57
-
58
- template <typename T>
59
- __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
60
- double *fdata2, int size) {
61
- for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
62
- idx += gridDim.x * blockDim.x) {
63
- fdata1[idx] = data1[idx];
64
- fdata2[idx] = data2[idx];
65
- }
66
- }
67
-
68
- __global__ void init_rand(curandState_t *state, int size, int nRanks) {
69
- for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
70
- idx += gridDim.x * blockDim.x) {
71
- for (int i = 0; i < nRanks; i++) {
72
- curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
73
- }
74
- }
75
- }
76
-
77
- template <typename T>
78
- __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
79
- int myRank, int nRanks, int size) {
80
- for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
81
- idx += gridDim.x * blockDim.x) {
82
- double sum = 0.0;
83
- for (int i = 0; i < nRanks; i++) {
84
- double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
85
- T hval = val; // downcast first
86
- sum += static_cast<double>(hval);
87
- if (i == myRank) data[idx] = hval;
88
- }
89
- ground_truth[idx] = sum;
90
- }
91
- }
92
-
93
- template <typename T>
94
- void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
95
- int data_size) {
96
- T *result;
97
- cudaStream_t stream;
98
- CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
99
- CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
100
- CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));
101
-
102
- cudaIpcMemHandle_t self_data_handle;
103
- cudaIpcMemHandle_t data_handles[8];
104
- vllm::Metadata *buffer;
105
- T *self_data_copy;
106
- /**
107
- * Allocate IPC buffer
108
- *
109
- * The first section is a temporary buffer for storing intermediate allreduce
110
- * results, if a particular algorithm requires it. The second section is for
111
- * the input to the allreduce. The actual API takes the input pointer as an
112
- * argument (that is, they can and usually should be allocated separately).
113
- * But since the input pointers and the temporary buffer all require IPC
114
- * registration, they are allocated and registered together in the test for
115
- * convenience.
116
- */
117
- CUDACHECK(
118
- cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
119
- CUDACHECK(cudaMemset(buffer, 0,
120
- 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
121
- CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
122
- CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));
123
-
124
- MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
125
- MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
126
- MPI_BYTE, MPI_COMM_WORLD));
127
-
128
- void *rank_data;
129
- size_t rank_data_sz = 16 * 1024 * 1024;
130
- CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
131
- std::vector<int64_t> offsets(nRanks, 0);
132
- vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
133
- offsets, myRank);
134
- auto *self_data =
135
- reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
136
- sizeof(vllm::Metadata) + data_size * sizeof(T));
137
- // hack buffer registration
138
- {
139
- std::vector<std::string> handles;
140
- handles.reserve(nRanks);
141
- for (int i = 0; i < nRanks; i++) {
142
- char *begin = (char *)&data_handles[i];
143
- char *end = (char *)&data_handles[i + 1];
144
- handles.emplace_back(begin, end);
145
- }
146
- std::vector<int64_t> offsets(
147
- nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T));
148
- fa.register_buffer(handles, offsets, self_data);
149
- }
150
-
151
- double *ground_truth;
152
- CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
153
- curandState_t *states;
154
- CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
155
- init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
156
- gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
157
- nRanks, data_size);
158
- CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
159
- cudaMemcpyDeviceToDevice, stream));
160
- cudaEvent_t start, stop;
161
- CUDACHECK(cudaEventCreate(&start));
162
- CUDACHECK(cudaEventCreate(&stop));
163
-
164
- ncclDataType_t ncclDtype;
165
- if (std::is_same<T, half>::value) {
166
- ncclDtype = ncclFloat16;
167
- } else if (std::is_same<T, nv_bfloat16>::value) {
168
- ncclDtype = ncclBfloat16;
169
- } else {
170
- ncclDtype = ncclFloat;
171
- }
172
-
173
- dummy_kernel<<<1, 1, 0, stream>>>();
174
- constexpr int warmup_iters = 5;
175
- constexpr int num_iters = 25;
176
- // warmup
177
- for (int i = 0; i < warmup_iters; i++) {
178
- NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
179
- stream));
180
- }
181
- CUDACHECK(cudaEventRecord(start, stream));
182
- for (int i = 0; i < num_iters; i++) {
183
- NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
184
- stream));
185
- }
186
- CUDACHECK(cudaEventRecord(stop, stream));
187
- CUDACHECK(cudaStreamSynchronize(stream));
188
- float allreduce_ms = 0;
189
- cudaEventElapsedTime(&allreduce_ms, start, stop);
190
-
191
- // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
192
- // set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);
193
-
194
- dummy_kernel<<<1, 1, 0, stream>>>();
195
- // warm up
196
- for (int i = 0; i < warmup_iters; i++) {
197
- fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
198
- }
199
- CUDACHECK(cudaEventRecord(start, stream));
200
- for (int i = 0; i < num_iters; i++) {
201
- fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
202
- }
203
- CUDACHECK(cudaEventRecord(stop, stream));
204
- CUDACHECK(cudaStreamSynchronize(stream));
205
-
206
- float duration_ms = 0;
207
- cudaEventElapsedTime(&duration_ms, start, stop);
208
- if (myRank == 0)
209
- printf(
210
- "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
211
- "time:%.2fus\n",
212
- myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
213
- duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
214
-
215
- // And wait for all the queued up work to complete
216
- CUDACHECK(cudaStreamSynchronize(stream));
217
-
218
- NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
219
- ncclSum, comm, stream));
220
-
221
- double *nccl_result, *my_result;
222
- CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
223
- CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));
224
-
225
- convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
226
- my_result, data_size);
227
- CUDACHECK(cudaStreamSynchronize(stream));
228
-
229
- for (unsigned long j = 0; j < data_size; j++) {
230
- auto diff = abs(nccl_result[j] - my_result[j]);
231
- if (diff >= 1e-2) {
232
- printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
233
- myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
234
- break;
235
- }
236
- }
237
-
238
- long double nccl_diffs = 0.0;
239
- long double my_diffs = 0.0;
240
- for (int j = 0; j < data_size; j++) {
241
- nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
242
- my_diffs += abs(my_result[j] - ground_truth[j]);
243
- }
244
- if (myRank == 0)
245
- std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
246
- << " me: " << my_diffs / data_size << std::endl;
247
-
248
- CUDACHECK(cudaFree(result));
249
- CUDACHECK(cudaFree(self_data_copy));
250
- CUDACHECK(cudaFree(rank_data));
251
- CUDACHECK(cudaFree(buffer));
252
- CUDACHECK(cudaFree(states));
253
- CUDACHECK(cudaFreeHost(ground_truth));
254
- CUDACHECK(cudaFreeHost(nccl_result));
255
- CUDACHECK(cudaFreeHost(my_result));
256
- CUDACHECK(cudaStreamDestroy(stream));
257
- }
258
-
259
- int main(int argc, char **argv) {
260
- int nRanks, myRank;
261
- MPICHECK(MPI_Init(&argc, &argv));
262
- MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
263
- MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
264
- CUDACHECK(cudaSetDevice(myRank));
265
- ncclUniqueId id;
266
- ncclComm_t comm;
267
- if (myRank == 0) ncclGetUniqueId(&id);
268
- MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
269
- MPI_COMM_WORLD));
270
- NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
271
-
272
- cudaProfilerStart();
273
- // for (int threads : {256, 512}) {
274
- // for (int block_limit = 16; block_limit < 112; block_limit += 4) {
275
- // run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
276
- // }
277
- // }
278
- for (int sz = 512; sz <= (32 << 20); sz *= 2) {
279
- run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
280
- }
281
-
282
- cudaProfilerStop();
283
- return EXIT_SUCCESS;
284
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
csrc/dispatch_utils.h DELETED
@@ -1,37 +0,0 @@
1
- /*
2
- * Adapted from
3
- * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
4
- */
5
- #pragma once
6
-
7
- #include <torch/extension.h>
8
-
9
- #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
10
- AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
11
- AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
12
- AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
13
-
14
- #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
15
- AT_DISPATCH_SWITCH( \
16
- TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
17
-
18
- #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
19
- AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
20
- AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
21
- AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
22
- AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
23
-
24
- #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
25
- AT_DISPATCH_SWITCH( \
26
- TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
27
-
28
- #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
29
- AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
30
- AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
31
- AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
32
- AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
33
- AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
34
-
35
- #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
36
- AT_DISPATCH_SWITCH( \
37
- TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
csrc/layernorm_kernels.cu DELETED
@@ -1,120 +0,0 @@
1
- #include <torch/extension.h>
2
- #include <ATen/cuda/CUDAContext.h>
3
- #include <c10/cuda/CUDAGuard.h>
4
-
5
- #include "dispatch_utils.h"
6
- #include "reduction_utils.cuh"
7
-
8
- namespace vllm {
9
-
10
- // TODO(woosuk): Further optimize this kernel.
11
- template<typename scalar_t>
12
- __global__ void rms_norm_kernel(
13
- scalar_t* __restrict__ out, // [..., hidden_size]
14
- const scalar_t* __restrict__ input, // [..., hidden_size]
15
- const scalar_t* __restrict__ weight, // [hidden_size]
16
- const float epsilon,
17
- const int num_tokens,
18
- const int hidden_size) {
19
- __shared__ float s_variance;
20
- float variance = 0.0f;
21
-
22
- for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
23
- const float x = (float) input[blockIdx.x * hidden_size + idx];
24
- variance += x * x;
25
- }
26
- variance = blockReduceSum<float>(variance);
27
- if (threadIdx.x == 0) {
28
- s_variance = rsqrtf(variance / hidden_size + epsilon);
29
- }
30
- __syncthreads();
31
-
32
- for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
33
- float x = (float) input[blockIdx.x * hidden_size + idx];
34
- out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
35
- }
36
- }
37
-
38
- // TODO: Further optimize this kernel.
39
- template<typename scalar_t>
40
- __global__ void fused_add_rms_norm_kernel(
41
- scalar_t* __restrict__ input, // [..., hidden_size]
42
- scalar_t* __restrict__ residual, // [..., hidden_size]
43
- const scalar_t* __restrict__ weight, // [hidden_size]
44
- const float epsilon,
45
- const int num_tokens,
46
- const int hidden_size) {
47
- __shared__ float s_variance;
48
- float variance = 0.0f;
49
-
50
- for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
51
- float x = (float) input[blockIdx.x * hidden_size + idx];
52
- x += (float) residual[blockIdx.x * hidden_size + idx];
53
- variance += x * x;
54
- residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
55
- }
56
- variance = blockReduceSum<float>(variance);
57
- if (threadIdx.x == 0) {
58
- s_variance = rsqrtf(variance / hidden_size + epsilon);
59
- }
60
- __syncthreads();
61
-
62
- for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
63
- float x = (float) residual[blockIdx.x * hidden_size + idx];
64
- input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
65
- }
66
- }
67
-
68
- } // namespace vllm
69
-
70
- void rms_norm(
71
- torch::Tensor& out, // [..., hidden_size]
72
- torch::Tensor& input, // [..., hidden_size]
73
- torch::Tensor& weight, // [hidden_size]
74
- float epsilon) {
75
- int hidden_size = input.size(-1);
76
- int num_tokens = input.numel() / hidden_size;
77
-
78
- dim3 grid(num_tokens);
79
- dim3 block(std::min(hidden_size, 1024));
80
- const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
81
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
82
- VLLM_DISPATCH_FLOATING_TYPES(
83
- input.scalar_type(),
84
- "rms_norm_kernel",
85
- [&] {
86
- vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
87
- out.data_ptr<scalar_t>(),
88
- input.data_ptr<scalar_t>(),
89
- weight.data_ptr<scalar_t>(),
90
- epsilon,
91
- num_tokens,
92
- hidden_size);
93
- });
94
- }
95
-
96
- void fused_add_rms_norm(
97
- torch::Tensor& input, // [..., hidden_size]
98
- torch::Tensor& residual, // [..., hidden_size]
99
- torch::Tensor& weight, // [hidden_size]
100
- float epsilon) {
101
- int hidden_size = input.size(-1);
102
- int num_tokens = input.numel() / hidden_size;
103
-
104
- dim3 grid(num_tokens);
105
- dim3 block(std::min(hidden_size, 1024));
106
- const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
107
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
108
- VLLM_DISPATCH_FLOATING_TYPES(
109
- input.scalar_type(),
110
- "fused_add_rms_norm_kernel",
111
- [&] {
112
- vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
113
- input.data_ptr<scalar_t>(),
114
- residual.data_ptr<scalar_t>(),
115
- weight.data_ptr<scalar_t>(),
116
- epsilon,
117
- num_tokens,
118
- hidden_size);
119
- });
120
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
csrc/moe_align_block_size_kernels.cu DELETED
@@ -1,108 +0,0 @@
1
- #include <torch/extension.h>
2
- #include <ATen/cuda/CUDAContext.h>
3
-
4
- #include <ATen/ATen.h>
5
- #include <THC/THCAtomics.cuh>
6
-
7
- #include "cuda_compat.h"
8
- #include "dispatch_utils.h"
9
-
10
- const static size_t NUM_MAX_EXPERTS = 64;
11
- #define CEILDIV(x,y) (((x) + (y) - 1) / (y))
12
-
13
- namespace vllm {
14
- template <typename scalar_t>
15
- __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
16
- int32_t *sorted_token_ids,
17
- int32_t *expert_ids,
18
- int32_t *total_tokens_post_pad,
19
- int32_t num_experts,
20
- int32_t block_size,
21
- size_t numel) {
22
- const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
23
- const size_t start_idx = threadIdx.x * tokens_per_thread;
24
- __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
25
- __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
26
- for (int i = 0; i < num_experts; ++i) {
27
- tokens_cnts[threadIdx.x + 1][i] = 0;
28
- }
29
-
30
- /**
31
- * In the first step we compute token_cnts[thread_index + 1][expert_index],
32
- * which counts how many tokens in the token shard of thread_index are assigned
33
- * to expert expert_index.
34
- */
35
- for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
36
- ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
37
- }
38
-
39
- __syncthreads();
40
-
41
- // For each expert we accumulate the token counts from the different threads.
42
- tokens_cnts[0][threadIdx.x] = 0;
43
- for (int i = 1; i <= blockDim.x; ++i) {
44
- tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
45
- }
46
-
47
- __syncthreads();
48
-
49
- // We accumulate the token counts of all experts in thread 0.
50
- if (threadIdx.x == 0) {
51
- cumsum[0] = 0;
52
- for (int i = 1; i <= num_experts; ++i) {
53
- cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
54
- }
55
- *total_tokens_post_pad = cumsum[num_experts];
56
- }
57
-
58
- __syncthreads();
59
-
60
- /**
61
- * For each expert, each thread processes the tokens of the corresponding blocks
62
- * and stores the corresponding expert_id for each block.
63
- */
64
- for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
65
- expert_ids[i / block_size] = threadIdx.x;
66
- }
67
-
68
- /**
69
- * Each thread processes a token shard, calculating the index of each token after
70
- * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
71
- * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
72
- * where * represents a padding value(preset in python).
73
- */
74
- for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
75
- int32_t expert_id = topk_ids[i];
76
- /** The cumsum[expert_id] stores the starting index of the tokens that the
77
- * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
78
- * stores the indices of the tokens processed by the expert with expert_id within
79
- * the current thread's token shard.
80
- */
81
- int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
82
- sorted_token_ids[rank_post_pad] = i;
83
- ++tokens_cnts[threadIdx.x][expert_id];
84
- }
85
- }
86
- }
87
-
88
- void moe_align_block_size(
89
- torch::Tensor topk_ids,
90
- int num_experts,
91
- int block_size,
92
- torch::Tensor sorted_token_ids,
93
- torch::Tensor experts_ids,
94
- torch::Tensor num_tokens_post_pad) {
95
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
96
- assert(num_experts <= NUM_MAX_EXPERTS);
97
- VLLM_DISPATCH_INTEGRAL_TYPES(
98
- topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
99
- vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
100
- topk_ids.data_ptr<scalar_t>(),
101
- sorted_token_ids.data_ptr<int32_t>(),
102
- experts_ids.data_ptr<int32_t>(),
103
- num_tokens_post_pad.data_ptr<int32_t>(),
104
- num_experts,
105
- block_size,
106
- topk_ids.numel());
107
- });
108
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
csrc/ops.h DELETED
@@ -1,130 +0,0 @@
1
- #pragma once
2
-
3
- #include <torch/extension.h>
4
-
5
- void paged_attention_v1(
6
- torch::Tensor& out,
7
- torch::Tensor& query,
8
- torch::Tensor& key_cache,
9
- torch::Tensor& value_cache,
10
- int num_kv_heads,
11
- float scale,
12
- torch::Tensor& block_tables,
13
- torch::Tensor& context_lens,
14
- int block_size,
15
- int max_context_len,
16
- const c10::optional<torch::Tensor>& alibi_slopes,
17
- const std::string& kv_cache_dtype);
18
-
19
- void paged_attention_v2(
20
- torch::Tensor& out,
21
- torch::Tensor& exp_sums,
22
- torch::Tensor& max_logits,
23
- torch::Tensor& tmp_out,
24
- torch::Tensor& query,
25
- torch::Tensor& key_cache,
26
- torch::Tensor& value_cache,
27
- int num_kv_heads,
28
- float scale,
29
- torch::Tensor& block_tables,
30
- torch::Tensor& context_lens,
31
- int block_size,
32
- int max_context_len,
33
- const c10::optional<torch::Tensor>& alibi_slopes,
34
- const std::string& kv_cache_dtype);
35
-
36
- void rms_norm(
37
- torch::Tensor& out,
38
- torch::Tensor& input,
39
- torch::Tensor& weight,
40
- float epsilon);
41
-
42
- void fused_add_rms_norm(
43
- torch::Tensor& input,
44
- torch::Tensor& residual,
45
- torch::Tensor& weight,
46
- float epsilon);
47
-
48
- void rotary_embedding(
49
- torch::Tensor& positions,
50
- torch::Tensor& query,
51
- torch::Tensor& key,
52
- int head_size,
53
- torch::Tensor& cos_sin_cache,
54
- bool is_neox);
55
-
56
- void silu_and_mul(
57
- torch::Tensor& out,
58
- torch::Tensor& input);
59
-
60
- void gelu_new(
61
- torch::Tensor& out,
62
- torch::Tensor& input);
63
-
64
- void gelu_fast(
65
- torch::Tensor& out,
66
- torch::Tensor& input);
67
-
68
- #ifndef USE_ROCM
69
- torch::Tensor awq_gemm(
70
- torch::Tensor _in_feats,
71
- torch::Tensor _kernel,
72
- torch::Tensor _scaling_factors,
73
- torch::Tensor _zeros,
74
- int split_k_iters);
75
-
76
- torch::Tensor awq_dequantize(
77
- torch::Tensor _kernel,
78
- torch::Tensor _scaling_factors,
79
- torch::Tensor _zeros,
80
- int split_k_iters,
81
- int thx,
82
- int thy);
83
- #endif
84
-
85
- void squeezellm_gemm(
86
- torch::Tensor vec,
87
- torch::Tensor mat,
88
- torch::Tensor mul,
89
- torch::Tensor lookup_table);
90
-
91
- torch::Tensor gptq_gemm(
92
- torch::Tensor a,
93
- torch::Tensor b_q_weight,
94
- torch::Tensor b_gptq_qzeros,
95
- torch::Tensor b_gptq_scales,
96
- torch::Tensor b_g_idx,
97
- bool use_exllama);
98
-
99
- void gptq_shuffle(
100
- torch::Tensor q_weight,
101
- torch::Tensor q_perm);
102
-
103
- void moe_align_block_size(
104
- torch::Tensor topk_ids,
105
- int num_experts,
106
- int block_size,
107
- torch::Tensor sorted_token_ids,
108
- torch::Tensor experts_ids,
109
- torch::Tensor num_tokens_post_pad);
110
-
111
- #ifndef USE_ROCM
112
- using fptr_t = uint64_t;
113
- fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
114
- const std::vector<std::string> &handles,
115
- const std::vector<int64_t> &offsets, int rank,
116
- bool full_nvlink);
117
- bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
118
- bool full_nvlink);
119
- void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
120
- void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
121
- torch::Tensor &out);
122
- void dispose(fptr_t _fa);
123
- int meta_size();
124
- void register_buffer(fptr_t _fa, torch::Tensor &t,
125
- const std::vector<std::string> &handles,
126
- const std::vector<int64_t> &offsets);
127
- std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
128
- void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
129
- const std::vector<std::vector<int64_t>> &offsets);
130
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
csrc/pos_encoding_kernels.cu DELETED
@@ -1,130 +0,0 @@
1
- #include <torch/extension.h>
2
- #include <ATen/cuda/CUDAContext.h>
3
- #include <c10/cuda/CUDAGuard.h>
4
-
5
- #include "cuda_compat.h"
6
- #include "dispatch_utils.h"
7
-
8
- namespace vllm {
9
-
10
- template<typename scalar_t, bool IS_NEOX>
11
- inline __device__ void apply_rotary_embedding(
12
- scalar_t* __restrict__ arr,
13
-   const scalar_t* __restrict__ cos_ptr,
-   const scalar_t* __restrict__ sin_ptr,
-   int rot_offset,
-   int embed_dim)
- {
-   int x_index, y_index;
-   scalar_t cos, sin;
-   if (IS_NEOX) {
-     // GPT-NeoX style rotary embedding.
-     x_index = rot_offset;
-     y_index = embed_dim + rot_offset;
-     cos = VLLM_LDG(cos_ptr + x_index);
-     sin = VLLM_LDG(sin_ptr + x_index);
-   } else {
-     // GPT-J style rotary embedding.
-     x_index = 2 * rot_offset;
-     y_index = 2 * rot_offset + 1;
-     cos = VLLM_LDG(cos_ptr + x_index / 2);
-     sin = VLLM_LDG(sin_ptr + x_index / 2);
-   }
-
-   const scalar_t x = arr[x_index];
-   const scalar_t y = arr[y_index];
-   arr[x_index] = x * cos - y * sin;
-   arr[y_index] = y * cos + x * sin;
- }
-
- template<typename scalar_t, bool IS_NEOX>
- __global__ void rotary_embedding_kernel(
-   const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
-   scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
-   scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
-   const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
-   const int rot_dim,
-   const int64_t query_stride,
-   const int64_t key_stride,
-   const int num_heads,
-   const int num_kv_heads,
-   const int head_size) {
-   // Each thread block is responsible for one token.
-   const int token_idx = blockIdx.x;
-   int64_t pos = positions[token_idx];
-   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
-
-   const int embed_dim = rot_dim / 2;
-   const scalar_t* cos_ptr = cache_ptr;
-   const scalar_t* sin_ptr = cache_ptr + embed_dim;
-
-   const int nq = num_heads * embed_dim;
-   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
-     const int head_idx = i / embed_dim;
-     const int64_t token_head = token_idx * query_stride + head_idx * head_size;
-     const int rot_offset = i % embed_dim;
-     apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
-                                               sin_ptr, rot_offset, embed_dim);
-   }
-
-   const int nk = num_kv_heads * embed_dim;
-   for (int i = threadIdx.x; i < nk; i += blockDim.x) {
-     const int head_idx = i / embed_dim;
-     const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-     const int rot_offset = i % embed_dim;
-     apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
-                                               sin_ptr, rot_offset, embed_dim);
-   }
- }
-
- } // namespace vllm
-
- void rotary_embedding(
-   torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
-   torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
-   torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
-   int head_size,
-   torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
-   bool is_neox) {
-   int64_t num_tokens = query.numel() / query.size(-1);
-   int rot_dim = cos_sin_cache.size(1);
-   int num_heads = query.size(-1) / head_size;
-   int num_kv_heads = key.size(-1) / head_size;
-   int64_t query_stride = query.stride(-2);
-   int64_t key_stride = key.stride(-2);
-
-   dim3 grid(num_tokens);
-   dim3 block(std::min(num_heads * rot_dim / 2, 512));
-   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
-   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-   VLLM_DISPATCH_FLOATING_TYPES(
-     query.scalar_type(),
-     "rotary_embedding",
-     [&] {
-       if (is_neox) {
-         vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
-           positions.data_ptr<int64_t>(),
-           query.data_ptr<scalar_t>(),
-           key.data_ptr<scalar_t>(),
-           cos_sin_cache.data_ptr<scalar_t>(),
-           rot_dim,
-           query_stride,
-           key_stride,
-           num_heads,
-           num_kv_heads,
-           head_size);
-       } else {
-         vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
-           positions.data_ptr<int64_t>(),
-           query.data_ptr<scalar_t>(),
-           key.data_ptr<scalar_t>(),
-           cos_sin_cache.data_ptr<scalar_t>(),
-           rot_dim,
-           query_stride,
-           key_stride,
-           num_heads,
-           num_kv_heads,
-           head_size);
-       }
-     });
- }
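
Note on the deleted kernel above: for every (x, y) pair in a head, `apply_rotary_embedding` applies the rotation x' = x*cos - y*sin, y' = y*cos + x*sin using the cos/sin values cached for the token's position. The two layouts differ only in how elements are paired: NeoX-style pairs element i with element i + rot_dim/2, while GPT-J-style pairs adjacent elements 2i and 2i+1. The host launcher runs one thread block per token with min(num_heads * rot_dim / 2, 512) threads. A minimal CPU sketch of the per-head update (illustrative only, not part of the deleted file; names here are assumptions):

```cpp
#include <vector>

// Hypothetical CPU sketch of the rotation the deleted CUDA kernel applies to
// one head of one token. cos_half and sin_half each hold rot_dim / 2 values
// taken from the cos_sin_cache row for this token's position.
void rotate_head_reference(std::vector<float>& head,
                           const std::vector<float>& cos_half,
                           const std::vector<float>& sin_half,
                           int rot_dim, bool is_neox) {
  const int embed_dim = rot_dim / 2;
  for (int i = 0; i < embed_dim; ++i) {
    // NeoX pairs (i, i + embed_dim); GPT-J pairs (2i, 2i + 1).
    const int x_index = is_neox ? i : 2 * i;
    const int y_index = is_neox ? embed_dim + i : 2 * i + 1;
    const float c = cos_half[i];
    const float s = sin_half[i];
    const float x = head[x_index];
    const float y = head[y_index];
    head[x_index] = x * c - y * s;  // x' = x*cos - y*sin
    head[y_index] = y * c + x * s;  // y' = y*cos + x*sin
  }
}
```

The deleted CUDA kernel parallelizes this loop across the block: thread i updates pair i % embed_dim of head i / embed_dim, once over the query heads and once over the key heads.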
 
csrc/punica/LICENSE DELETED
@@ -1,217 +0,0 @@
- Contains code from https://github.com/punica-ai/punica
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "{}"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright {yyyy} {name of copyright owner}
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- ------------------------------------------------------------------------------------
-
- This product bundles various third-party components under other open source licenses.
- This section summarizes those components and their licenses. See licenses/
- for text of these licenses.
-
-
- Apache-2.0
- * third_party/nvbench (with LLVM exception)
- * third_party/flashinfer
-
- BSD-3-Clause:
- * third_party/cutlass
 
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)

csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)

csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)

csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)

csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)

csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
 
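The six bgmv_bf16_*.cu stubs above (and the fp16/fp32 variants further down) each instantiate the punica BGMV (batched gather matrix-vector) kernel for exactly one (input, output, weight) dtype combination, so the large template expansion driven by FOR_BGMV_WIDE_NARROW is split across translation units and can build in parallel. INST_BGMV_TWOSIDE itself is defined in the deleted bgmv_impl.cuh and is not shown in this diff; as a loose, hypothetical sketch of the same one-combination-per-file pattern (names invented for illustration, not the punica API):

```cpp
#include <cstdint>

// Shared template, analogous in spirit to bgmv_kernel declared in bgmv_config.h.
template <typename in_T, typename out_T, typename W_T>
void scaled_dot_accumulate(out_T* y, const in_T* x, const W_T* w,
                           int64_t n, float scale) {
  for (int64_t i = 0; i < n; ++i) {
    y[i] += static_cast<out_T>(static_cast<float>(x[i]) *
                               static_cast<float>(w[i]) * scale);
  }
}

// In the real files, a macro emits one explicit instantiation like this for
// every supported size pair and the file's dtype triple.
template void scaled_dot_accumulate<float, float, float>(
    float*, const float*, const float*, int64_t, float);
```

Each real stub differs only in the dtype triple it passes to the macro, which is why they are otherwise identical four-line files.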
csrc/punica/bgmv/bgmv_config.h DELETED
@@ -1,59 +0,0 @@
- #pragma once
-
- template <int feat_in, int feat_out, typename in_T, typename out_T,
-           typename W_T>
- void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
-                  const W_T *__restrict__ W,
-                  const int64_t *__restrict__ indicies, int64_t y_offset,
-                  int64_t full_y_size, int64_t batch_size, int64_t num_layers,
-                  int64_t layer_idx, float scale);
-
- // clang-format off
-
- #define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
-     f(in_T, out_T, W_T, narrow, 128) \
-     f(in_T, out_T, W_T, narrow, 256) \
-     f(in_T, out_T, W_T, narrow, 512) \
-     f(in_T, out_T, W_T, narrow, 1024) \
-     f(in_T, out_T, W_T, narrow, 1280) \
-     f(in_T, out_T, W_T, narrow, 1728) \
-     f(in_T, out_T, W_T, narrow, 1792) \
-     f(in_T, out_T, W_T, narrow, 2048) \
-     f(in_T, out_T, W_T, narrow, 2560) \
-     f(in_T, out_T, W_T, narrow, 2752) \
-     f(in_T, out_T, W_T, narrow, 3072) \
-     f(in_T, out_T, W_T, narrow, 3456) \
-     f(in_T, out_T, W_T, narrow, 3584) \
-     f(in_T, out_T, W_T, narrow, 4096) \
-     f(in_T, out_T, W_T, narrow, 5120) \
-     f(in_T, out_T, W_T, narrow, 5504) \
-     f(in_T, out_T, W_T, narrow, 5632) \
-     f(in_T, out_T, W_T, narrow, 6912) \
-     f(in_T, out_T, W_T, narrow, 7168) \
-     f(in_T, out_T, W_T, narrow, 8192) \
-     f(in_T, out_T, W_T, narrow, 9216) \
-     f(in_T, out_T, W_T, narrow, 10240) \
-     f(in_T, out_T, W_T, narrow, 11008) \
-     f(in_T, out_T, W_T, narrow, 12288) \
-     f(in_T, out_T, W_T, narrow, 13824) \
-     f(in_T, out_T, W_T, narrow, 14336) \
-     f(in_T, out_T, W_T, narrow, 16384) \
-     f(in_T, out_T, W_T, narrow, 20480) \
-     f(in_T, out_T, W_T, narrow, 28672) \
-     f(in_T, out_T, W_T, narrow, 32000) \
-     f(in_T, out_T, W_T, narrow, 32256) \
-     f(in_T, out_T, W_T, narrow, 32512) \
-     f(in_T, out_T, W_T, narrow, 32768) \
-     f(in_T, out_T, W_T, narrow, 33024) \
-     f(in_T, out_T, W_T, narrow, 36864) \
-     f(in_T, out_T, W_T, narrow, 49152) \
- // Keep above in sync with vllm/lora/layers::SamplerWithLoRA
-
- // Keep this in sync with vllm/config::LoRAConfig
- #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
-     FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
-     FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
-     FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
-     FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
-
- // clang-format on
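
FOR_BGMV_WIDE applies its argument `f` once per supported output ("wide") dimension, and FOR_BGMV_WIDE_NARROW crosses that list with the four narrow sizes 8, 16, 32 and 64 (the LoRA ranks referenced by the vllm/config::LoRAConfig comment), so each dtype combination expands to 36 x 4 = 144 instantiations. A toy, self-contained example of the same X-macro expansion pattern (hypothetical names, not the punica macros):

```cpp
#include <cstdio>

// Toy version of the X-macro pattern used by FOR_BGMV_WIDE / FOR_BGMV_WIDE_NARROW.
#define FOR_WIDE(f, narrow) f(narrow, 128) f(narrow, 256) f(narrow, 512)
#define FOR_WIDE_NARROW(f) FOR_WIDE(f, 8) FOR_WIDE(f, 16)

// The callback decides what gets generated for each (narrow, wide) pair;
// the real code generates explicit kernel instantiations instead of printf calls.
#define PRINT_PAIR(narrow, wide) std::printf("instantiate bgmv<%d, %d>\n", narrow, wide);

int main() {
  FOR_WIDE_NARROW(PRINT_PAIR)  // expands to six printf calls
  return 0;
}
```

Compiling and running this prints one line per (narrow, wide) pair; the real macros emit kernel instantiations instead.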
 
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)

csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)

csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)

csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu DELETED
@@ -1,4 +0,0 @@
- #include "bgmv_config.h"
- #include "bgmv_impl.cuh"
-
- FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)