bsmit1659 committed
Commit ca1ecab · 1 Parent(s): dc22273

Adding vllm package

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .buildkite/run-benchmarks.sh +63 -0
  2. .buildkite/test-pipeline.yaml +51 -0
  3. .buildkite/test-template.j2 +54 -0
  4. .dockerignore +1 -0
  5. .github/workflows/publish.yml +102 -0
  6. .github/workflows/ruff.yml +31 -0
  7. .github/workflows/scripts/build.sh +20 -0
  8. .github/workflows/scripts/create_release.js +20 -0
  9. .github/workflows/scripts/cuda-install.sh +23 -0
  10. .github/workflows/scripts/env.sh +56 -0
  11. .github/workflows/scripts/pytorch-install.sh +15 -0
  12. .github/workflows/yapf.yml +31 -0
  13. .gitignore +186 -0
  14. .readthedocs.yaml +21 -0
  15. CONTRIBUTING.md +77 -0
  16. Dockerfile +1 -1
  17. Dockerfile.rocm +88 -0
  18. LICENSE +201 -0
  19. MANIFEST.in +4 -0
  20. README.md +110 -7
  21. benchmarks/README.md +8 -0
  22. benchmarks/benchmark_latency.py +139 -0
  23. benchmarks/benchmark_serving.py +249 -0
  24. benchmarks/benchmark_throughput.py +328 -0
  25. benchmarks/kernels/benchmark_paged_attention.py +196 -0
  26. benchmarks/launch_tgi_server.sh +16 -0
  27. csrc/activation_kernels.cu +118 -0
  28. csrc/attention/attention_dtypes.h +7 -0
  29. csrc/attention/attention_generic.cuh +64 -0
  30. csrc/attention/attention_kernels.cu +951 -0
  31. csrc/attention/attention_utils.cuh +56 -0
  32. csrc/attention/dtype_bfloat16.cuh +451 -0
  33. csrc/attention/dtype_float16.cuh +502 -0
  34. csrc/attention/dtype_float32.cuh +273 -0
  35. csrc/attention/dtype_fp8_e5m2.cuh +35 -0
  36. csrc/cache.h +36 -0
  37. csrc/cache_kernels.cu +474 -0
  38. csrc/cuda_compat.h +28 -0
  39. csrc/cuda_utils.h +10 -0
  40. csrc/cuda_utils_kernels.cu +35 -0
  41. csrc/custom_all_reduce.cu +148 -0
  42. csrc/custom_all_reduce.cuh +562 -0
  43. csrc/custom_all_reduce_test.cu +284 -0
  44. csrc/dispatch_utils.h +37 -0
  45. csrc/layernorm_kernels.cu +120 -0
  46. csrc/moe_align_block_size_kernels.cu +108 -0
  47. csrc/ops.h +130 -0
  48. csrc/pos_encoding_kernels.cu +130 -0
  49. csrc/punica/LICENSE +217 -0
  50. csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu +4 -0
.buildkite/run-benchmarks.sh ADDED
@@ -0,0 +1,63 @@
+ # This script is run by buildkite to run the benchmarks and upload the results to buildkite
+
+ set -ex
+ set -o pipefail
+
+ # cd into parent directory of this file
+ cd "$(dirname "${BASH_SOURCE[0]}")/.."
+
+ (wget && curl) || (apt-get update && apt-get install -y wget curl)
+
+ # run benchmarks and upload the result to buildkite
+ python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
+ bench_latency_exit_code=$?
+
+ python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
+ bench_throughput_exit_code=$?
+
+ python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
+ server_pid=$!
+ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+
+ # wait for server to start, timeout after 600 seconds
+ timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
+ python3 benchmarks/benchmark_serving.py \
+     --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
+     --model meta-llama/Llama-2-7b-chat-hf \
+     --num-prompts 20 \
+     --endpoint /v1/completions \
+     --tokenizer meta-llama/Llama-2-7b-chat-hf 2>&1 | tee benchmark_serving.txt
+ bench_serving_exit_code=$?
+ kill $server_pid
+
+ # write the results into a markdown file
+ echo "### Latency Benchmarks" >> benchmark_results.md
+ sed -n '1p' benchmark_latency.txt >> benchmark_results.md # first line
+ echo "" >> benchmark_results.md
+ sed -n '$p' benchmark_latency.txt >> benchmark_results.md # last line
+
+ echo "### Throughput Benchmarks" >> benchmark_results.md
+ sed -n '1p' benchmark_throughput.txt >> benchmark_results.md # first line
+ echo "" >> benchmark_results.md
+ sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line
+
+ echo "### Serving Benchmarks" >> benchmark_results.md
+ sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
+ echo "" >> benchmark_results.md
+ tail -n 5 benchmark_serving.txt >> benchmark_results.md # last 5 lines
+
+ # upload the results to buildkite
+ /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
+
+ # exit with the exit code of the benchmarks
+ if [ $bench_latency_exit_code -ne 0 ]; then
+     exit $bench_latency_exit_code
+ fi
+
+ if [ $bench_throughput_exit_code -ne 0 ]; then
+     exit $bench_throughput_exit_code
+ fi
+
+ if [ $bench_serving_exit_code -ne 0 ]; then
+     exit $bench_serving_exit_code
+ fi
.buildkite/test-pipeline.yaml ADDED
@@ -0,0 +1,51 @@
+ # In this file, you can add more tests to run either by adding a new step or
+ # adding a new command to an existing step. See different options here for examples.
+ # This script will be fed into the Jinja template in `test-template.j2` to generate
+ # the final pipeline yaml file.
+
+ steps:
+ - label: Regression Test
+   command: pytest -v -s test_regression.py
+   working_dir: "/vllm-workspace/tests" # optional
+
+ - label: AsyncEngine Test
+   command: pytest -v -s async_engine
+
+ - label: Distributed Test
+   command: pytest -v -s test_comm_ops.py
+   working_dir: "/vllm-workspace/tests/distributed"
+   num_gpus: 2 # only support 1 or 2 for now.
+
+ - label: Engine Test
+   command: pytest -v -s engine
+
+ - label: Entrypoints Test
+   command: pytest -v -s entrypoints
+
+ - label: Kernels Test
+   command: pytest -v -s kernels
+   soft_fail: true
+
+ - label: Models Test
+   commands:
+   - pytest -v -s models --forked
+   soft_fail: true
+
+ - label: Prefix Caching Test
+   commands:
+   - pytest -v -s prefix_caching
+
+ - label: Samplers Test
+   command: pytest -v -s samplers --forked
+
+ - label: Worker Test
+   command: pytest -v -s worker
+
+ - label: LoRA Test
+   command: pytest -v -s lora
+
+ - label: Benchmarks
+   working_dir: "/vllm-workspace/.buildkite"
+   commands:
+   - pip install aiohttp
+   - bash run-benchmarks.sh
.buildkite/test-template.j2 ADDED
@@ -0,0 +1,54 @@
+ {% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %}
+ {% set default_num_gpu = 1 %}
+ {% set default_working_dir = "/vllm-workspace/tests" %}
+
+ steps:
+ - label: ":docker: build image"
+   commands:
+   - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
+   - "docker push {{ docker_image }}"
+   env:
+     DOCKER_BUILDKIT: "1"
+   retry:
+     automatic:
+       - exit_status: -1  # Agent was lost
+         limit: 5
+ - wait
+
+ {% for step in steps %}
+ - label: "{{ step.label }}"
+   agents:
+     queue: kubernetes
+   soft_fail: {{ step.soft_fail or false }}
+   retry:
+     automatic:
+       - exit_status: -1  # Agent was lost
+         limit: 5
+   plugins:
+   - kubernetes:
+       podSpec:
+         volumes:
+         - name: dshm
+           emptyDir:
+             medium: Memory
+         containers:
+         - image: "{{ docker_image }}"
+           command: ["bash"]
+           args:
+           - "-c"
+           - "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
+           resources:
+             requests:
+               nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
+             limits:
+               nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
+           env:
+           - name: HF_TOKEN
+             valueFrom:
+               secretKeyRef:
+                 name: hf-token-secret
+                 key: token
+           volumeMounts:
+           - mountPath: /dev/shm
+             name: dshm
+ {% endfor %}
.dockerignore ADDED
@@ -0,0 +1 @@
+ vllm/*.so
.github/workflows/publish.yml ADDED
@@ -0,0 +1,102 @@
+ # This workflow will upload a Python Package to Release asset
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
+
+ name: Create Release
+
+ on:
+   push:
+     tags:
+       - v*
+
+ # Needed to create release and upload assets
+ permissions:
+   contents: write
+
+ jobs:
+   release:
+     # Retrieve tag and create release
+     name: Create Release
+     runs-on: ubuntu-latest
+     outputs:
+       upload_url: ${{ steps.create_release.outputs.upload_url }}
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+
+       - name: Extract branch info
+         shell: bash
+         run: |
+           echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
+
+       - name: Create Release
+         id: create_release
+         uses: "actions/github-script@v6"
+         env:
+           RELEASE_TAG: ${{ env.release_tag }}
+         with:
+           github-token: "${{ secrets.GITHUB_TOKEN }}"
+           script: |
+             const script = require('.github/workflows/scripts/create_release.js')
+             await script(github, context, core)
+
+   wheel:
+     name: Build Wheel
+     runs-on: ${{ matrix.os }}
+     needs: release
+
+     strategy:
+       fail-fast: false
+       matrix:
+         os: ['ubuntu-20.04']
+         python-version: ['3.8', '3.9', '3.10', '3.11']
+         pytorch-version: ['2.1.2']  # Must be the most recent version that meets requirements.txt.
+         cuda-version: ['11.8', '12.1']
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v3
+
+       - name: Set up Linux Env
+         if: ${{ runner.os == 'Linux' }}
+         run: |
+           bash -x .github/workflows/scripts/env.sh
+
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python-version }}
+
+       - name: Install CUDA ${{ matrix.cuda-version }}
+         run: |
+           bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
+
+       - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
+         run: |
+           bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
+
+       - name: Build wheel
+         shell: bash
+         run: |
+           bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
+           wheel_name=$(ls dist/*whl | xargs -n 1 basename)
+           asset_name=${wheel_name//"linux"/"manylinux1"}
+           echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
+           echo "asset_name=${asset_name}" >> $GITHUB_ENV
+
+       - name: Upload Release Asset
+         uses: actions/upload-release-asset@v1
+         env:
+           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+         with:
+           upload_url: ${{ needs.release.outputs.upload_url }}
+           asset_path: ./dist/${{ env.wheel_name }}
+           asset_name: ${{ env.asset_name }}
+           asset_content_type: application/*
+
+     # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
+     # - name: Publish package
+     #   uses: pypa/gh-action-pypi-publish@release/v1.8
+     #   with:
+     #     repository-url: https://test.pypi.org/legacy/
+     #     password: ${{ secrets.PYPI_API_TOKEN }}
+     #     skip-existing: true
.github/workflows/ruff.yml ADDED
@@ -0,0 +1,31 @@
+ name: ruff
+
+ on:
+   # Trigger the workflow on push or pull request,
+   # but only for the main branch
+   push:
+     branches:
+       - main
+   pull_request:
+     branches:
+       - main
+
+ jobs:
+   ruff:
+     runs-on: ubuntu-latest
+     strategy:
+       matrix:
+         python-version: ["3.10"]
+     steps:
+     - uses: actions/checkout@v2
+     - name: Set up Python ${{ matrix.python-version }}
+       uses: actions/setup-python@v2
+       with:
+         python-version: ${{ matrix.python-version }}
+     - name: Install dependencies
+       run: |
+         python -m pip install --upgrade pip
+         pip install ruff==0.1.5
+     - name: Analysing the code with ruff
+       run: |
+         ruff vllm tests
.github/workflows/scripts/build.sh ADDED
@@ -0,0 +1,20 @@
+ #!/bin/bash
+
+ python_executable=python$1
+ cuda_home=/usr/local/cuda-$2
+
+ # Update paths
+ PATH=${cuda_home}/bin:$PATH
+ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
+
+ # Install requirements
+ $python_executable -m pip install wheel packaging
+ $python_executable -m pip install -r requirements.txt
+
+ # Limit the number of parallel jobs to avoid OOM
+ export MAX_JOBS=1
+ # Make sure punica is built for the release (for LoRA)
+ export VLLM_INSTALL_PUNICA_KERNELS=1
+
+ # Build
+ $python_executable setup.py bdist_wheel --dist-dir=dist
.github/workflows/scripts/create_release.js ADDED
@@ -0,0 +1,20 @@
+ // Uses Github's API to create the release and wait for result.
+ // We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.
+
+ module.exports = async (github, context, core) => {
+     try {
+         const response = await github.rest.repos.createRelease({
+             draft: false,
+             generate_release_notes: true,
+             name: process.env.RELEASE_TAG,
+             owner: context.repo.owner,
+             prerelease: false,
+             repo: context.repo.repo,
+             tag_name: process.env.RELEASE_TAG,
+         });
+
+         core.setOutput('upload_url', response.data.upload_url);
+     } catch (error) {
+         core.setFailed(error.message);
+     }
+ }
.github/workflows/scripts/cuda-install.sh ADDED
@@ -0,0 +1,23 @@
+ #!/bin/bash
+
+ # Replace '.' with '-' ex: 11.8 -> 11-8
+ cuda_version=$(echo $1 | tr "." "-")
+ # Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
+ OS=$(echo $2 | tr -d ".\-")
+
+ # Installs CUDA
+ wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
+ sudo dpkg -i cuda-keyring_1.1-1_all.deb
+ rm cuda-keyring_1.1-1_all.deb
+ sudo apt -qq update
+ sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
+ sudo apt clean
+
+ # Test nvcc
+ PATH=/usr/local/cuda-$1/bin:${PATH}
+ nvcc --version
+
+ # Log gcc, g++, c++ versions
+ gcc --version
+ g++ --version
+ c++ --version
.github/workflows/scripts/env.sh ADDED
@@ -0,0 +1,56 @@
+ #!/bin/bash
+
+ # This file installs common linux environment tools
+
+ export LANG=C.UTF-8
+
+ # python_version=$1
+
+ sudo apt-get update && \
+ sudo apt-get install -y --no-install-recommends \
+     software-properties-common \
+
+ sudo apt-get install -y --no-install-recommends \
+     build-essential \
+     apt-utils \
+     ca-certificates \
+     wget \
+     git \
+     vim \
+     libssl-dev \
+     curl \
+     unzip \
+     unrar \
+     cmake \
+     net-tools \
+     sudo \
+     autotools-dev \
+     rsync \
+     jq \
+     openssh-server \
+     tmux \
+     screen \
+     htop \
+     pdsh \
+     openssh-client \
+     lshw \
+     dmidecode \
+     util-linux \
+     automake \
+     autoconf \
+     libtool \
+     net-tools \
+     pciutils \
+     libpci-dev \
+     libaio-dev \
+     libcap2 \
+     libtinfo5 \
+     fakeroot \
+     devscripts \
+     debhelper \
+     nfs-common
+
+ # Remove github bloat files to free up disk space
+ sudo rm -rf "/usr/local/share/boost"
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+ sudo rm -rf "/usr/share/dotnet"
.github/workflows/scripts/pytorch-install.sh ADDED
@@ -0,0 +1,15 @@
+ #!/bin/bash
+
+ python_executable=python$1
+ pytorch_version=$2
+ cuda_version=$3
+
+ # Install torch
+ $python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
+ $python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./}
+
+ # Print version information
+ $python_executable --version
+ $python_executable -c "import torch; print('PyTorch:', torch.__version__)"
+ $python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
+ $python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
.github/workflows/yapf.yml ADDED
@@ -0,0 +1,31 @@
+ name: yapf
+
+ on:
+   # Trigger the workflow on push or pull request,
+   # but only for the main branch
+   push:
+     branches:
+       - main
+   pull_request:
+     branches:
+       - main
+ jobs:
+   yapf:
+     runs-on: ubuntu-latest
+     strategy:
+       matrix:
+         python-version: ["3.10"]
+     steps:
+     - uses: actions/checkout@v2
+     - name: Set up Python ${{ matrix.python-version }}
+       uses: actions/setup-python@v2
+       with:
+         python-version: ${{ matrix.python-version }}
+     - name: Install dependencies
+       run: |
+         python -m pip install --upgrade pip
+         pip install yapf==0.32.0
+         pip install toml==0.10.2
+     - name: Running yapf
+       run: |
+         yapf --diff --recursive .
.gitignore ADDED
@@ -0,0 +1,186 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ # VSCode
+ .vscode/
+
+ # DS Store
+ .DS_Store
+
+ # Results
+ *.csv
+
+ # Python pickle files
+ *.pkl
+
+ # Sphinx documentation
+ _build/
+
+ # vim swap files
+ *.swo
+ *.swp
+
+ # hip files generated by PyTorch
+ *.hip
+ *_hip*
+
+ # Benchmark dataset
+ *.json
.readthedocs.yaml ADDED
@@ -0,0 +1,21 @@
+ # Read the Docs configuration file
+ # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+ version: 2
+
+ build:
+   os: ubuntu-22.04
+   tools:
+     python: "3.8"
+
+ sphinx:
+   configuration: docs/source/conf.py
+
+ # If using Sphinx, optionally build your docs in additional formats such as PDF
+ formats:
+   - pdf
+
+ # Optionally declare the Python requirements required to build your docs
+ python:
+   install:
+     - requirements: docs/requirements-docs.txt
CONTRIBUTING.md ADDED
@@ -0,0 +1,77 @@
+ # Contributing to vLLM
+
+ Thank you for your interest in contributing to vLLM!
+ Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
+ There are several ways you can contribute to the project:
+
+ - Identify and report any issues or bugs.
+ - Request or add a new model.
+ - Suggest or implement new features.
+
+ However, remember that contributions aren't just about code.
+ We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.
+
+ Finally, one of the most impactful ways to support us is by raising awareness about vLLM.
+ Talk about it in your blog posts, highlighting how it's driving your incredible projects.
+ Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository.
+
+
+ ## Setup for development
+
+ ### Build from source
+
+ ```bash
+ pip install -r requirements.txt
+ pip install -e .  # This may take several minutes.
+ ```
+
+ ### Testing
+
+ ```bash
+ pip install -r requirements-dev.txt
+
+ # Static type checking
+ mypy
+ # Unit tests
+ pytest tests/
+ ```
+ **Note:** Currently, the repository does not pass the mypy tests.
+
+
+ ## Contributing Guidelines
+
+ ### Issue Reporting
+
+ If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
+ If not, please file a new issue, providing as much relevant information as possible.
+
+ ### Coding Style Guide
+
+ In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
+
+ We include a formatting script [`format.sh`](./format.sh) to format the code.
+
+ ### Pull Requests
+
+ When submitting a pull request:
+
+ 1. Make sure your code has been rebased on top of the latest commit on the main branch.
+ 2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
+ 3. Include a detailed description of the changes in the pull request.
+ Explain why you made the changes you did.
+ If your pull request fixes an open issue, please include a reference to it in the description.
+
+ ### Code Reviews
+
+ All submissions, including submissions by project members, require a code review.
+ To make the review process as smooth as possible, please:
+
+ 1. Keep your changes as concise as possible.
+ If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests.
+ 2. Respond to all comments within a reasonable time frame.
+ If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
+
+ ### Thank You
+
+ Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
+ Your contributions make vLLM a great tool for everyone!
Dockerfile CHANGED
@@ -94,4 +94,4 @@ COPY --from=build /workspace/vllm/*.so /workspace/vllm/
  COPY vllm vllm

  ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
- #################### OPENAI API SERVER ####################
+ #################### OPENAI API SERVER ####################
Dockerfile.rocm ADDED
@@ -0,0 +1,88 @@
+ # default base image
+ ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+ FROM $BASE_IMAGE
+
+ ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+ RUN echo "Base image is $BASE_IMAGE"
+
+ # BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
+ # BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+ # this does not always work for all rocm versions
+ RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
+     echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"
+
+ ARG FA_GFX_ARCHS="gfx90a;gfx942"
+ RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
+
+ ARG FA_BRANCH="3d2b6f5"
+ RUN echo "FA_BRANCH is $FA_BRANCH"
+
+ # Install some basic utilities
+ RUN apt-get update && apt-get install python3 python3-pip -y
+
+ # Install some basic utilities
+ RUN apt-get update && apt-get install -y \
+     curl \
+     ca-certificates \
+     sudo \
+     git \
+     bzip2 \
+     libx11-6 \
+     build-essential \
+     wget \
+     unzip \
+     nvidia-cuda-toolkit \
+     tmux \
+     && rm -rf /var/lib/apt/lists/*
+
+ ### Mount Point ###
+ # When launching the container, mount the code directory to /app
+ ARG APP_MOUNT=/app
+ VOLUME [ ${APP_MOUNT} ]
+ WORKDIR ${APP_MOUNT}
+
+ RUN python3 -m pip install --upgrade pip
+ RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
+
+ ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
+ ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
+ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
+ ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
+
+ # Install ROCm flash-attention
+ RUN mkdir libs \
+     && cd libs \
+     && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
+     && cd flash-attention \
+     && git checkout ${FA_BRANCH} \
+     && git submodule update --init \
+     && export GPU_ARCHS=${FA_GFX_ARCHS} \
+     && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
+        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
+     && python3 setup.py install \
+     && cd ..
+
+ COPY ./ /app/vllm
+
+ RUN python3 -m pip install --upgrade pip
+ RUN python3 -m pip install xformers==0.0.23 --no-deps
+
+ # Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
+ # Manually removed it so that later steps of numpy upgrade can continue
+ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
+     rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
+
+ RUN cd /app \
+     && cd vllm \
+     && pip install -U -r requirements-rocm.txt \
+     && bash patch_xformers.rocm.sh \
+     && python3 setup.py install \
+     && cd ..
+
+ RUN python3 -m pip install --upgrade pip
+ RUN python3 -m pip install --no-cache-dir ray[all]
+
+ CMD ["/bin/bash"]
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
MANIFEST.in ADDED
@@ -0,0 +1,4 @@
+ include LICENSE
+ include requirements.txt
+
+ recursive-include csrc *
README.md CHANGED
@@ -1,10 +1,113 @@
+ <p align="center">
+ <picture>
+ <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-dark.png">
+ <img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-light.png" width=55%>
+ </picture>
+ </p>
+
+ <h3 align="center">
+ Easy, fast, and cheap LLM serving for everyone
+ </h3>
+
+ <p align="center">
+ | <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
+
+ </p>
+
+ ---
+
+ **The Second vLLM Bay Area Meetup (Jan 31st 5pm-7:30pm PT)**
+
+ We are thrilled to announce our second vLLM Meetup!
+ The vLLM team will share recent updates and roadmap.
+ We will also have vLLM collaborators from IBM coming up to the stage to discuss their insights on LLM optimizations.
+ Please register [here](https://lu.ma/ygxbpzhl) and join us!
+
  ---
- title: Certifaier
- emoji: 🏆
- colorFrom: indigo
- colorTo: red
- sdk: docker
- pinned: false
+
+ *Latest News* 🔥
+ - [2024/01] Added ROCm 6.0 support to vLLM.
+ - [2023/12] Added ROCm 5.7 support to vLLM.
+ - [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
+ - [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
+ - [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
+ - [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
+ - [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
+ - [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
+ - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
+
  ---
+ ## About
+ vLLM is a fast and easy-to-use library for LLM inference and serving.
+
+ vLLM is fast with:
+
+ - State-of-the-art serving throughput
+ - Efficient management of attention key and value memory with **PagedAttention**
+ - Continuous batching of incoming requests
+ - Fast model execution with CUDA/HIP graph
+ - Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629)
+ - Optimized CUDA kernels
+
+ vLLM is flexible and easy to use with:
+
+ - Seamless integration with popular Hugging Face models
+ - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
+ - Tensor parallelism support for distributed inference
+ - Streaming outputs
+ - OpenAI-compatible API server
+ - Support NVIDIA GPUs and AMD GPUs
+
+ vLLM seamlessly supports many Hugging Face models, including the following architectures:
+
+ - Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
+ - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
+ - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+ - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
+ - DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
+ - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
+ - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
+ - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
+ - GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
+ - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
+ - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
+ - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+ - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
+ - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
+ - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
+ - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
+ - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
+ - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
+ - Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
+ - StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
+ - Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
+
+ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
+
+ ```bash
+ pip install vllm
+ ```
+
+ ## Getting Started
+
+ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
+ - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
+ - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
+ - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
+
+ ## Contributing
+
+ We welcome and value any contributions and collaborations.
+ Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
+
+ ## Citation

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
+ ```bibtex
+ @inproceedings{kwon2023efficient,
+   title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
+   author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
+   booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
+   year={2023}
+ }
+ ```
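For orientation, here is a minimal offline-inference sketch that uses only pieces visible in this commit: the `LLM` and `SamplingParams` classes imported by `benchmarks/benchmark_latency.py`, with that script's default `facebook/opt-125m` model standing in for any entry in the supported-model list above. Treat it as an illustrative sketch rather than part of the diff itself.

```python
# Minimal sketch of offline inference with vLLM (assumes `pip install vllm` has succeeded).
# facebook/opt-125m is simply the default model of benchmarks/benchmark_latency.py in this commit.
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The capital of France is",
]
# Standard sampling settings; max_tokens bounds the generated length per prompt.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    # Each result holds the original prompt and its generated completions (n=1 by default).
    print(f"{output.prompt!r} -> {output.outputs[0].text!r}")
```

The OpenAI-compatible server started in `.buildkite/run-benchmarks.sh` (`python3 -m vllm.entrypoints.openai.api_server --model ...`) exposes the same engine over HTTP instead of calling `generate` in-process.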
benchmarks/README.md ADDED
@@ -0,0 +1,8 @@
+ # Benchmarking vLLM
+
+ ## Downloading the ShareGPT dataset
+
+ You can download the dataset by running:
+ ```bash
+ wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+ ```
benchmarks/benchmark_latency.py ADDED
@@ -0,0 +1,139 @@
+ """Benchmark the latency of processing a single batch of requests."""
+ import argparse
+ import time
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from vllm import LLM, SamplingParams
+
+
+ def main(args: argparse.Namespace):
+     print(args)
+
+     # NOTE(woosuk): If the request cannot be processed in a single batch,
+     # the engine will automatically process the request in multiple batches.
+     llm = LLM(
+         model=args.model,
+         tokenizer=args.tokenizer,
+         quantization=args.quantization,
+         tensor_parallel_size=args.tensor_parallel_size,
+         trust_remote_code=args.trust_remote_code,
+         dtype=args.dtype,
+         enforce_eager=args.enforce_eager,
+         kv_cache_dtype=args.kv_cache_dtype,
+     )
+
+     sampling_params = SamplingParams(
+         n=args.n,
+         temperature=0.0 if args.use_beam_search else 1.0,
+         top_p=1.0,
+         use_beam_search=args.use_beam_search,
+         ignore_eos=True,
+         max_tokens=args.output_len,
+     )
+     print(sampling_params)
+     dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
+
+     def run_to_completion(profile_dir: Optional[str] = None):
+         if profile_dir:
+             with torch.profiler.profile(
+                     activities=[
+                         torch.profiler.ProfilerActivity.CPU,
+                         torch.profiler.ProfilerActivity.CUDA,
+                     ],
+                     on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                         str(profile_dir))) as p:
+                 llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                              sampling_params=sampling_params,
+                              use_tqdm=False)
+             print(p.key_averages())
+         else:
+             start_time = time.perf_counter()
+             llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                          sampling_params=sampling_params,
+                          use_tqdm=False)
+             end_time = time.perf_counter()
+             latency = end_time - start_time
+             return latency
+
+     print("Warming up...")
+     run_to_completion(profile_dir=None)
+
+     if args.profile:
+         profile_dir = args.profile_result_dir
+         if not profile_dir:
+             profile_dir = Path(
+                 "."
+             ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+         print(f"Profiling (results will be saved to '{profile_dir}')...")
+         run_to_completion(profile_dir=args.profile_result_dir)
+         return
+
+     # Benchmark.
+     latencies = []
+     for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+         latencies.append(run_to_completion(profile_dir=None))
+     print(f'Avg latency: {np.mean(latencies)} seconds')
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(
+         description='Benchmark the latency of processing a single batch of '
+         'requests till completion.')
+     parser.add_argument('--model', type=str, default='facebook/opt-125m')
+     parser.add_argument('--tokenizer', type=str, default=None)
+     parser.add_argument('--quantization',
+                         '-q',
+                         choices=['awq', 'gptq', 'squeezellm', None],
+                         default=None)
+     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
+     parser.add_argument('--input-len', type=int, default=32)
+     parser.add_argument('--output-len', type=int, default=128)
+     parser.add_argument('--batch-size', type=int, default=8)
+     parser.add_argument('--n',
+                         type=int,
+                         default=1,
+                         help='Number of generated sequences per prompt.')
+     parser.add_argument('--use-beam-search', action='store_true')
+     parser.add_argument('--num-iters',
+                         type=int,
+                         default=3,
+                         help='Number of iterations to run.')
+     parser.add_argument('--trust-remote-code',
+                         action='store_true',
+                         help='trust remote code from huggingface')
+     parser.add_argument(
+         '--dtype',
+         type=str,
+         default='auto',
+         choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
+         help='data type for model weights and activations. '
+         'The "auto" option will use FP16 precision '
+         'for FP32 and FP16 models, and BF16 precision '
+         'for BF16 models.')
+     parser.add_argument('--enforce-eager',
+                         action='store_true',
+                         help='enforce eager mode and disable CUDA graph')
+     parser.add_argument(
+         "--kv-cache-dtype",
+         type=str,
+         choices=['auto', 'fp8_e5m2'],
+         default='auto',
+         help=
+         'Data type for kv cache storage. If "auto", will use model data type.')
+     parser.add_argument(
+         '--profile',
+         action='store_true',
+         help='profile the generation process of a single batch')
+     parser.add_argument(
+         '--profile-result-dir',
+         type=str,
+         default=None,
+         help=('path to save the pytorch profiler output. Can be visualized '
+               'with ui.perfetto.dev or Tensorboard.'))
+     args = parser.parse_args()
+     main(args)
benchmarks/benchmark_serving.py ADDED
@@ -0,0 +1,249 @@
+ """Benchmark online serving throughput.
+
+ On the server side, run one of the following commands:
+     (vLLM backend)
+     python -m vllm.entrypoints.api_server \
+         --model <your_model> --swap-space 16 \
+         --disable-log-requests
+
+     (TGI backend)
+     ./launch_hf_server.sh <your_model>
+
+ On the client side, run:
+     python benchmarks/benchmark_serving.py \
+         --backend <backend> \
+         --tokenizer <your_model> --dataset <target_dataset> \
+         --request-rate <request_rate>
+ """
+ import argparse
+ import asyncio
+ import json
+ import random
+ import time
+ from typing import AsyncGenerator, List, Tuple
+
+ import aiohttp
+ import numpy as np
+ from tqdm.asyncio import tqdm
+ from transformers import PreTrainedTokenizerBase
+ from vllm.transformers_utils.tokenizer import get_tokenizer
+
+ # (prompt len, output len, latency)
+ REQUEST_LATENCY: List[Tuple[int, int, float]] = []
+
+
+ def sample_requests(
+     dataset_path: str,
+     num_requests: int,
+     tokenizer: PreTrainedTokenizerBase,
+ ) -> List[Tuple[str, int, int]]:
+     # Load the dataset.
+     with open(dataset_path) as f:
+         dataset = json.load(f)
+     # Filter out the conversations with less than 2 turns.
+     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+     # Only keep the first two turns of each conversation.
+     dataset = [(data["conversations"][0]["value"],
+                 data["conversations"][1]["value"]) for data in dataset]
+
+     # Tokenize the prompts and completions.
+     prompts = [prompt for prompt, _ in dataset]
+     prompt_token_ids = tokenizer(prompts).input_ids
+     completions = [completion for _, completion in dataset]
+     completion_token_ids = tokenizer(completions).input_ids
+     tokenized_dataset = []
+     for i in range(len(dataset)):
+         output_len = len(completion_token_ids[i])
+         tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+
+     # Filter out too long sequences.
+     filtered_dataset: List[Tuple[str, int, int]] = []
+     for prompt, prompt_token_ids, output_len in tokenized_dataset:
+         prompt_len = len(prompt_token_ids)
+         if prompt_len < 4 or output_len < 4:
+             # Prune too short sequences.
+             # This is because TGI causes errors when the input or output length
+             # is too short.
+             continue
+         if prompt_len > 1024 or prompt_len + output_len > 2048:
+             # Prune too long sequences.
+             continue
+         filtered_dataset.append((prompt, prompt_len, output_len))
+
+     # Sample the requests.
+     sampled_requests = random.sample(filtered_dataset, num_requests)
+     return sampled_requests
+
+
+ async def get_request(
+     input_requests: List[Tuple[str, int, int]],
+     request_rate: float,
+ ) -> AsyncGenerator[Tuple[str, int, int], None]:
+     input_requests = iter(input_requests)
+     for request in input_requests:
+         yield request
+
+         if request_rate == float("inf"):
+             # If the request rate is infinity, then we don't need to wait.
+             continue
+         # Sample the request interval from the exponential distribution.
+         interval = np.random.exponential(1.0 / request_rate)
+         # The next request will be sent after the interval.
+         await asyncio.sleep(interval)
+
+
+ async def send_request(backend: str, model: str, api_url: str, prompt: str,
+                        prompt_len: int, output_len: int, best_of: int,
+                        use_beam_search: bool, pbar: tqdm) -> None:
+     request_start_time = time.perf_counter()
+
+     headers = {"User-Agent": "Benchmark Client"}
+     if backend == "vllm":
+         pload = {
+             "prompt": prompt,
+             "n": 1,
+             "best_of": best_of,
+             "use_beam_search": use_beam_search,
+             "temperature": 0.0 if use_beam_search else 1.0,
+             "top_p": 1.0,
+             "max_tokens": output_len,
+             "ignore_eos": True,
+             "stream": False,
+         }
+         if model is not None:
+             pload["model"] = model
+     elif backend == "tgi":
+         assert not use_beam_search
+         params = {
+             "best_of": best_of,
+             "max_new_tokens": output_len,
+             "do_sample": True,
+         }
+         pload = {
+             "inputs": prompt,
+             "parameters": params,
+         }
+     else:
+         raise ValueError(f"Unknown backend: {backend}")
+
+     timeout = aiohttp.ClientTimeout(total=3 * 3600)
+     async with aiohttp.ClientSession(timeout=timeout) as session:
+         while True:
+             async with session.post(api_url, headers=headers,
+                                     json=pload) as response:
+                 chunks = []
+                 async for chunk, _ in response.content.iter_chunks():
+                     chunks.append(chunk)
+             output = b"".join(chunks).decode("utf-8")
+             output = json.loads(output)
+
+             # Re-send the request if it failed.
+             if "error" not in output:
+                 break
+
+     request_end_time = time.perf_counter()
145
+ request_latency = request_end_time - request_start_time
146
+ REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
147
+ pbar.update(1)
148
+
149
+
150
+ async def benchmark(
151
+ backend: str,
152
+ model: str,
153
+ api_url: str,
154
+ input_requests: List[Tuple[str, int, int]],
155
+ best_of: int,
156
+ use_beam_search: bool,
157
+ request_rate: float,
158
+ ) -> None:
159
+ tasks: List[asyncio.Task] = []
160
+ pbar = tqdm(total=len(input_requests))
161
+ async for request in get_request(input_requests, request_rate):
162
+ prompt, prompt_len, output_len = request
163
+ task = asyncio.create_task(
164
+ send_request(backend, model, api_url, prompt, prompt_len,
165
+ output_len, best_of, use_beam_search, pbar))
166
+ tasks.append(task)
167
+ await asyncio.gather(*tasks)
168
+ pbar.close()
169
+
170
+
171
+ def main(args: argparse.Namespace):
172
+ print(args)
173
+ random.seed(args.seed)
174
+ np.random.seed(args.seed)
175
+
176
+ api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}"
177
+ tokenizer = get_tokenizer(args.tokenizer,
178
+ trust_remote_code=args.trust_remote_code)
179
+ input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
180
+
181
+ benchmark_start_time = time.perf_counter()
182
+ asyncio.run(
183
+ benchmark(args.backend, args.model, api_url, input_requests,
184
+ args.best_of, args.use_beam_search, args.request_rate))
185
+ benchmark_end_time = time.perf_counter()
186
+ benchmark_time = benchmark_end_time - benchmark_start_time
187
+ print(f"Total time: {benchmark_time:.2f} s")
188
+ print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
189
+
190
+ # Compute the latency statistics.
191
+ avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
192
+ print(f"Average latency: {avg_latency:.2f} s")
193
+ avg_per_token_latency = np.mean([
194
+ latency / (prompt_len + output_len)
195
+ for prompt_len, output_len, latency in REQUEST_LATENCY
196
+ ])
197
+ print(f"Average latency per token: {avg_per_token_latency:.2f} s")
198
+ avg_per_output_token_latency = np.mean(
199
+ [latency / output_len for _, output_len, latency in REQUEST_LATENCY])
200
+ print("Average latency per output token: "
201
+ f"{avg_per_output_token_latency:.2f} s")
202
+
203
+
204
+ if __name__ == "__main__":
205
+ parser = argparse.ArgumentParser(
206
+ description="Benchmark the online serving throughput.")
207
+ parser.add_argument("--backend",
208
+ type=str,
209
+ default="vllm",
210
+ choices=["vllm", "tgi"])
211
+ parser.add_argument("--protocol",
212
+ type=str,
213
+ default="http",
214
+ choices=["http", "https"])
215
+ parser.add_argument("--host", type=str, default="localhost")
216
+ parser.add_argument("--port", type=int, default=8000)
217
+ parser.add_argument("--endpoint", type=str, default="/generate")
218
+ parser.add_argument("--model", type=str, default=None)
219
+ parser.add_argument("--dataset",
220
+ type=str,
221
+ required=True,
222
+ help="Path to the dataset.")
223
+ parser.add_argument("--tokenizer",
224
+ type=str,
225
+ required=True,
226
+ help="Name or path of the tokenizer.")
227
+ parser.add_argument("--best-of",
228
+ type=int,
229
+ default=1,
230
+ help="Generates `best_of` sequences per prompt and "
231
+ "returns the best one.")
232
+ parser.add_argument("--use-beam-search", action="store_true")
233
+ parser.add_argument("--num-prompts",
234
+ type=int,
235
+ default=1000,
236
+ help="Number of prompts to process.")
237
+ parser.add_argument("--request-rate",
238
+ type=float,
239
+ default=float("inf"),
240
+ help="Number of requests per second. If this is inf, "
241
+ "then all the requests are sent at time 0. "
242
+ "Otherwise, we use Poisson process to synthesize "
243
+ "the request arrival times.")
244
+ parser.add_argument("--seed", type=int, default=0)
245
+ parser.add_argument('--trust-remote-code',
246
+ action='store_true',
247
+ help='trust remote code from huggingface')
248
+ args = parser.parse_args()
249
+ main(args)
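
For reference, the `--request-rate` flag feeds the exponential inter-arrival sampling in get_request() above, which emulates a Poisson arrival process. A self-contained sketch of the resulting schedule (the function name and numbers are illustrative, not part of the script):

import numpy as np

def arrival_times(num_requests: int, request_rate: float, seed: int = 0):
    # Exponential gaps with mean 1/request_rate give Poisson arrivals at `request_rate` req/s.
    rng = np.random.default_rng(seed)
    gaps = rng.exponential(1.0 / request_rate, size=num_requests)
    return np.cumsum(gaps)  # absolute send times in seconds

print(arrival_times(5, request_rate=2.0))  # on average one request every 0.5 s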
benchmarks/benchmark_throughput.py ADDED
@@ -0,0 +1,328 @@
1
+ """Benchmark offline inference throughput."""
2
+ import argparse
3
+ import json
4
+ import random
5
+ import time
6
+ from typing import List, Optional, Tuple
7
+
8
+ import torch
9
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
10
+ PreTrainedTokenizerBase)
11
+ from tqdm import tqdm
12
+
13
+
14
+ def sample_requests(
15
+ dataset_path: str,
16
+ num_requests: int,
17
+ tokenizer: PreTrainedTokenizerBase,
18
+ fixed_output_len: Optional[int],
19
+ ) -> List[Tuple[str, int, int]]:
20
+ if fixed_output_len is not None and fixed_output_len < 4:
21
+ raise ValueError("output_len too small")
22
+
23
+ # Load the dataset.
24
+ with open(dataset_path) as f:
25
+ dataset = json.load(f)
26
+ # Filter out the conversations with less than 2 turns.
27
+ dataset = [data for data in dataset if len(data["conversations"]) >= 2]
28
+ # Only keep the first two turns of each conversation.
29
+ dataset = [(data["conversations"][0]["value"],
30
+ data["conversations"][1]["value"]) for data in dataset]
31
+
32
+ # Tokenize the prompts and completions.
33
+ prompts = [prompt for prompt, _ in dataset]
34
+ prompt_token_ids = tokenizer(prompts).input_ids
35
+ completions = [completion for _, completion in dataset]
36
+ completion_token_ids = tokenizer(completions).input_ids
37
+ tokenized_dataset = []
38
+ for i in range(len(dataset)):
39
+ output_len = len(completion_token_ids[i])
40
+ if fixed_output_len is not None:
41
+ output_len = fixed_output_len
42
+ tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
43
+
44
+ # Filter out too long sequences.
45
+ filtered_dataset: List[Tuple[str, int, int]] = []
46
+ for prompt, prompt_token_ids, output_len in tokenized_dataset:
47
+ prompt_len = len(prompt_token_ids)
48
+ if prompt_len < 4 or output_len < 4:
49
+ # Prune too short sequences.
50
+ continue
51
+ if prompt_len > 1024 or prompt_len + output_len > 2048:
52
+ # Prune too long sequences.
53
+ continue
54
+ filtered_dataset.append((prompt, prompt_len, output_len))
55
+
56
+ # Sample the requests.
57
+ sampled_requests = random.sample(filtered_dataset, num_requests)
58
+ return sampled_requests
59
+
60
+
61
+ def run_vllm(
62
+ requests: List[Tuple[str, int, int]],
63
+ model: str,
64
+ tokenizer: str,
65
+ quantization: Optional[str],
66
+ tensor_parallel_size: int,
67
+ seed: int,
68
+ n: int,
69
+ use_beam_search: bool,
70
+ trust_remote_code: bool,
71
+ dtype: str,
72
+ max_model_len: Optional[int],
73
+ enforce_eager: bool,
74
+ kv_cache_dtype: str,
75
+ ) -> float:
76
+ from vllm import LLM, SamplingParams
77
+ llm = LLM(
78
+ model=model,
79
+ tokenizer=tokenizer,
80
+ quantization=quantization,
81
+ tensor_parallel_size=tensor_parallel_size,
82
+ seed=seed,
83
+ trust_remote_code=trust_remote_code,
84
+ dtype=dtype,
85
+ max_model_len=max_model_len,
86
+ enforce_eager=enforce_eager,
87
+ kv_cache_dtype=kv_cache_dtype,
88
+ )
89
+
90
+ # Add the requests to the engine.
91
+ for prompt, _, output_len in requests:
92
+ sampling_params = SamplingParams(
93
+ n=n,
94
+ temperature=0.0 if use_beam_search else 1.0,
95
+ top_p=1.0,
96
+ use_beam_search=use_beam_search,
97
+ ignore_eos=True,
98
+ max_tokens=output_len,
99
+ )
100
+ # FIXME(woosuk): Do not use internal method.
101
+ llm._add_request(
102
+ prompt=prompt,
103
+ prompt_token_ids=None,
104
+ sampling_params=sampling_params,
105
+ )
106
+
107
+ start = time.perf_counter()
108
+ # FIXME(woosuk): Do not use internal method.
109
+ llm._run_engine(use_tqdm=True)
110
+ end = time.perf_counter()
111
+ return end - start
112
+
113
+
114
+ def run_hf(
115
+ requests: List[Tuple[str, int, int]],
116
+ model: str,
117
+ tokenizer: PreTrainedTokenizerBase,
118
+ n: int,
119
+ use_beam_search: bool,
120
+ max_batch_size: int,
121
+ trust_remote_code: bool,
122
+ ) -> float:
123
+ assert not use_beam_search
124
+ llm = AutoModelForCausalLM.from_pretrained(
125
+ model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
126
+ if llm.config.model_type == "llama":
127
+ # To enable padding in the HF backend.
128
+ tokenizer.pad_token = tokenizer.eos_token
129
+ llm = llm.cuda()
130
+
131
+ pbar = tqdm(total=len(requests))
132
+ start = time.perf_counter()
133
+ batch: List[str] = []
134
+ max_prompt_len = 0
135
+ max_output_len = 0
136
+ for i in range(len(requests)):
137
+ prompt, prompt_len, output_len = requests[i]
138
+ # Add the prompt to the batch.
139
+ batch.append(prompt)
140
+ max_prompt_len = max(max_prompt_len, prompt_len)
141
+ max_output_len = max(max_output_len, output_len)
142
+ if len(batch) < max_batch_size and i != len(requests) - 1:
143
+ # Check if we can add more requests to the batch.
144
+ _, next_prompt_len, next_output_len = requests[i + 1]
145
+ if (max(max_prompt_len, next_prompt_len) +
146
+ max(max_output_len, next_output_len)) <= 2048:
147
+ # We can add more requests to the batch.
148
+ continue
149
+
150
+ # Generate the sequences.
151
+ input_ids = tokenizer(batch, return_tensors="pt",
152
+ padding=True).input_ids
153
+ llm_outputs = llm.generate(
154
+ input_ids=input_ids.cuda(),
155
+ do_sample=not use_beam_search,
156
+ num_return_sequences=n,
157
+ temperature=1.0,
158
+ top_p=1.0,
159
+ use_cache=True,
160
+ max_new_tokens=max_output_len,
161
+ )
162
+ # Include the decoding time.
163
+ tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
164
+ pbar.update(len(batch))
165
+
166
+ # Clear the batch.
167
+ batch = []
168
+ max_prompt_len = 0
169
+ max_output_len = 0
170
+ end = time.perf_counter()
171
+ return end - start
172
+
173
+
174
+ def run_mii(
175
+ requests: List[Tuple[str, int, int]],
176
+ model: str,
177
+ tensor_parallel_size: int,
178
+ output_len: int,
179
+ ) -> float:
180
+ from mii import pipeline
181
+ llm = pipeline(model, tensor_parallel=tensor_parallel_size)
182
+ prompts = [prompt for prompt, _, _ in requests]
183
+
184
+ start = time.perf_counter()
185
+ llm(prompts, max_new_tokens=output_len)
186
+ end = time.perf_counter()
187
+ return end - start
188
+
189
+
190
+ def main(args: argparse.Namespace):
191
+ print(args)
192
+ random.seed(args.seed)
193
+
194
+ # Sample the requests.
195
+ tokenizer = AutoTokenizer.from_pretrained(
196
+ args.tokenizer, trust_remote_code=args.trust_remote_code)
197
+ if args.dataset is None:
198
+ # Synthesize a prompt with the given input length.
199
+ prompt = "hi" * (args.input_len - 1)
200
+ requests = [(prompt, args.input_len, args.output_len)
201
+ for _ in range(args.num_prompts)]
202
+ else:
203
+ requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
204
+ args.output_len)
205
+
206
+ if args.backend == "vllm":
207
+ elapsed_time = run_vllm(requests, args.model, args.tokenizer,
208
+ args.quantization, args.tensor_parallel_size,
209
+ args.seed, args.n, args.use_beam_search,
210
+ args.trust_remote_code, args.dtype,
211
+ args.max_model_len, args.enforce_eager,
212
+ args.kv_cache_dtype)
213
+ elif args.backend == "hf":
214
+ assert args.tensor_parallel_size == 1
215
+ elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
216
+ args.use_beam_search, args.hf_max_batch_size,
217
+ args.trust_remote_code)
218
+ elif args.backend == "mii":
219
+ elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
220
+ args.output_len)
221
+ else:
222
+ raise ValueError(f"Unknown backend: {args.backend}")
223
+ total_num_tokens = sum(prompt_len + output_len
224
+ for _, prompt_len, output_len in requests)
225
+ print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
226
+ f"{total_num_tokens / elapsed_time:.2f} tokens/s")
227
+
228
+
229
+ if __name__ == "__main__":
230
+ parser = argparse.ArgumentParser(description="Benchmark the throughput.")
231
+ parser.add_argument("--backend",
232
+ type=str,
233
+ choices=["vllm", "hf", "mii"],
234
+ default="vllm")
235
+ parser.add_argument("--dataset",
236
+ type=str,
237
+ default=None,
238
+ help="Path to the dataset.")
239
+ parser.add_argument("--input-len",
240
+ type=int,
241
+ default=None,
242
+ help="Input prompt length for each request")
243
+ parser.add_argument("--output-len",
244
+ type=int,
245
+ default=None,
246
+ help="Output length for each request. Overrides the "
247
+ "output length from the dataset.")
248
+ parser.add_argument("--model", type=str, default="facebook/opt-125m")
249
+ parser.add_argument("--tokenizer", type=str, default=None)
250
+ parser.add_argument('--quantization',
251
+ '-q',
252
+ choices=['awq', 'gptq', 'squeezellm', None],
253
+ default=None)
254
+ parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
255
+ parser.add_argument("--n",
256
+ type=int,
257
+ default=1,
258
+ help="Number of generated sequences per prompt.")
259
+ parser.add_argument("--use-beam-search", action="store_true")
260
+ parser.add_argument("--num-prompts",
261
+ type=int,
262
+ default=1000,
263
+ help="Number of prompts to process.")
264
+ parser.add_argument("--seed", type=int, default=0)
265
+ parser.add_argument("--hf-max-batch-size",
266
+ type=int,
267
+ default=None,
268
+ help="Maximum batch size for HF backend.")
269
+ parser.add_argument('--trust-remote-code',
270
+ action='store_true',
271
+ help='trust remote code from huggingface')
272
+ parser.add_argument(
273
+ '--max-model-len',
274
+ type=int,
275
+ default=None,
276
+ help='Maximum length of a sequence (including prompt and output). '
277
+ 'If None, will be derived from the model.')
278
+ parser.add_argument(
279
+ '--dtype',
280
+ type=str,
281
+ default='auto',
282
+ choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
283
+ help='data type for model weights and activations. '
284
+ 'The "auto" option will use FP16 precision '
285
+ 'for FP32 and FP16 models, and BF16 precision '
286
+ 'for BF16 models.')
287
+ parser.add_argument("--enforce-eager",
288
+ action="store_true",
289
+ help="enforce eager execution")
290
+ parser.add_argument(
291
+ "--kv-cache-dtype",
292
+ type=str,
293
+ choices=["auto", "fp8_e5m2"],
294
+ default="auto",
295
+ help=
296
+ 'Data type for kv cache storage. If "auto", will use model data type.')
297
+ args = parser.parse_args()
298
+ if args.tokenizer is None:
299
+ args.tokenizer = args.model
300
+ if args.dataset is None:
301
+ assert args.input_len is not None
302
+ assert args.output_len is not None
303
+ else:
304
+ assert args.input_len is None
305
+
306
+ if args.backend == "vllm":
307
+ if args.hf_max_batch_size is not None:
308
+ raise ValueError("HF max batch size is only for HF backend.")
309
+ elif args.backend == "hf":
310
+ if args.hf_max_batch_size is None:
311
+ raise ValueError("HF max batch size is required for HF backend.")
312
+ if args.quantization is not None:
313
+ raise ValueError("Quantization is only for vLLM backend.")
314
+ elif args.backend == "mii":
315
+ if args.dtype != "auto":
316
+ raise ValueError("dtype must be auto for MII backend.")
317
+ if args.n != 1:
318
+ raise ValueError("n must be 1 for MII backend.")
319
+ if args.use_beam_search:
320
+ raise ValueError("Beam search is not supported for MII backend.")
321
+ if args.quantization is not None:
322
+ raise ValueError("Quantization is only for vLLM backend.")
323
+ if args.hf_max_batch_size is not None:
324
+ raise ValueError("HF max batch size is only for HF backend.")
325
+ if args.tokenizer != args.model:
326
+ raise ValueError("Tokenizer must be the same as the model for MII "
327
+ "backend.")
328
+ main(args)
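
The throughput numbers printed by main() count both prompt and generated tokens over the wall-clock time returned by the chosen backend runner. A tiny sketch of that arithmetic, with made-up values for illustration:

# (prompt, prompt_len, output_len) tuples plus an elapsed time, values invented for illustration
requests = [("<prompt>", 128, 256), ("<prompt>", 512, 128)]
elapsed_time = 4.0  # seconds, as returned by run_vllm/run_hf/run_mii

total_num_tokens = sum(prompt_len + output_len for _, prompt_len, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
      f"{total_num_tokens / elapsed_time:.2f} tokens/s")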
benchmarks/kernels/benchmark_paged_attention.py ADDED
@@ -0,0 +1,196 @@
1
+ from typing import Optional
2
+ import argparse
3
+ import random
4
+ import time
5
+
6
+ import torch
7
+
8
+ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
9
+ from vllm._C import ops
10
+
11
+ NUM_BLOCKS = 1024
12
+ PARTITION_SIZE = 512
13
+
14
+
15
+ @torch.inference_mode()
16
+ def main(
17
+ version: str,
18
+ num_seqs: int,
19
+ context_len: int,
20
+ num_query_heads: int,
21
+ num_kv_heads: int,
22
+ head_size: int,
23
+ use_alibi: bool,
24
+ block_size: int,
25
+ dtype: torch.dtype,
26
+ seed: int,
27
+ do_profile: bool,
28
+ kv_cache_dtype: Optional[str] = None,
29
+ ) -> None:
30
+ random.seed(seed)
31
+ torch.random.manual_seed(seed)
32
+ torch.cuda.manual_seed(seed)
33
+
34
+ scale = float(1.0 / (head_size**0.5))
35
+ query = torch.empty(num_seqs,
36
+ num_query_heads,
37
+ head_size,
38
+ dtype=dtype,
39
+ device="cuda")
40
+ query.uniform_(-scale, scale)
41
+
42
+ assert num_query_heads % num_kv_heads == 0
43
+ alibi_slopes = None
44
+ if use_alibi:
45
+ alibi_slopes = torch.randn(num_query_heads,
46
+ dtype=torch.float,
47
+ device="cuda")
48
+
49
+ context_lens = [context_len for _ in range(num_seqs)]
50
+ max_context_len = max(context_lens)
51
+ context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
52
+
53
+ # Create the block tables.
54
+ max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
55
+ block_tables = []
56
+ for _ in range(num_seqs):
57
+ block_table = [
58
+ random.randint(0, NUM_BLOCKS - 1)
59
+ for _ in range(max_num_blocks_per_seq)
60
+ ]
61
+ block_tables.append(block_table)
62
+ block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
63
+
64
+ # Create the KV cache.
65
+ key_caches, value_caches = create_kv_caches_with_random(
66
+ NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
67
+ dtype)
68
+ key_cache, value_cache = key_caches[0], value_caches[0]
69
+
70
+ # Prepare for the paged attention kernel.
71
+ output = torch.empty_like(query)
72
+ if version == "v2":
73
+ num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
74
+ PARTITION_SIZE)
75
+ tmp_output = torch.empty(
76
+ size=(num_seqs, num_query_heads, num_partitions, head_size),
77
+ dtype=output.dtype,
78
+ device=output.device,
79
+ )
80
+ exp_sums = torch.empty(
81
+ size=(num_seqs, num_query_heads, num_partitions),
82
+ dtype=torch.float32,
83
+ device=output.device,
84
+ )
85
+ max_logits = torch.empty_like(exp_sums)
86
+
87
+ def run_benchmark(num_iters: int, profile: bool = False) -> float:
88
+ torch.cuda.synchronize()
89
+ if profile:
90
+ torch.cuda.cudart().cudaProfilerStart()
91
+ start_time = time.perf_counter()
92
+
93
+ for _ in range(num_iters):
94
+ if version == "v1":
95
+ ops.paged_attention_v1(
96
+ output,
97
+ query,
98
+ key_cache,
99
+ value_cache,
100
+ num_kv_heads,
101
+ scale,
102
+ block_tables,
103
+ context_lens,
104
+ block_size,
105
+ max_context_len,
106
+ alibi_slopes,
107
+ kv_cache_dtype,
108
+ )
109
+ elif version == "v2":
110
+ ops.paged_attention_v2(
111
+ output,
112
+ exp_sums,
113
+ max_logits,
114
+ tmp_output,
115
+ query,
116
+ key_cache,
117
+ value_cache,
118
+ num_kv_heads,
119
+ scale,
120
+ block_tables,
121
+ context_lens,
122
+ block_size,
123
+ max_context_len,
124
+ alibi_slopes,
125
+ kv_cache_dtype,
126
+ )
127
+ else:
128
+ raise ValueError(f"Invalid version: {version}")
129
+ torch.cuda.synchronize()
130
+
131
+ end_time = time.perf_counter()
132
+ if profile:
133
+ torch.cuda.cudart().cudaProfilerStop()
134
+ return (end_time - start_time) / num_iters
135
+
136
+ # Warmup.
137
+ print("Warming up...")
138
+ run_benchmark(num_iters=3, profile=False)
139
+
140
+ # Benchmark.
141
+ if do_profile:
142
+ latency = run_benchmark(num_iters=1, profile=True)
143
+ else:
144
+ latency = run_benchmark(num_iters=100, profile=False)
145
+ print(f"Kernel running time: {latency * 1000000:.3f} us")
146
+
147
+
148
+ if __name__ == '__main__':
149
+ parser = argparse.ArgumentParser(
150
+ description="Benchmark the paged attention kernel.")
151
+ parser.add_argument("--version",
152
+ type=str,
153
+ choices=["v1", "v2"],
154
+ default="v2")
155
+ parser.add_argument("--batch-size", type=int, default=8)
156
+ parser.add_argument("--context-len", type=int, default=4096)
157
+ parser.add_argument("--num-query-heads", type=int, default=64)
158
+ parser.add_argument("--num-kv-heads", type=int, default=8)
159
+ parser.add_argument("--head-size",
160
+ type=int,
161
+ choices=[64, 80, 96, 112, 128, 256],
162
+ default=128)
163
+ parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
164
+ parser.add_argument("--use-alibi", action="store_true")
165
+ parser.add_argument("--dtype",
166
+ type=str,
167
+ choices=["half", "bfloat16", "float"],
168
+ default="half")
169
+ parser.add_argument("--seed", type=int, default=0)
170
+ parser.add_argument("--profile", action="store_true")
171
+ parser.add_argument(
172
+ "--kv-cache-dtype",
173
+ type=str,
174
+ choices=["auto", "fp8_e5m2"],
175
+ default="auto",
176
+ help=
177
+ 'Data type for kv cache storage. If "auto", will use model data type.')
178
+ args = parser.parse_args()
179
+ print(args)
180
+
181
+ if args.num_query_heads % args.num_kv_heads != 0:
182
+ raise ValueError("num_query_heads must be divisible by num_kv_heads")
183
+ main(
184
+ version=args.version,
185
+ num_seqs=args.batch_size,
186
+ context_len=args.context_len,
187
+ num_query_heads=args.num_query_heads,
188
+ num_kv_heads=args.num_kv_heads,
189
+ head_size=args.head_size,
190
+ block_size=args.block_size,
191
+ use_alibi=args.use_alibi,
192
+ dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
193
+ seed=args.seed,
194
+ do_profile=args.profile,
195
+ kv_cache_dtype=args.kv_cache_dtype,
196
+ )
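
For the v2 kernel, main() splits the context into PARTITION_SIZE-token chunks and allocates per-partition scratch buffers (tmp_output, exp_sums, max_logits). A quick sketch of those shapes using the script's default arguments (values shown purely for illustration):

num_seqs, num_query_heads, head_size = 8, 64, 128   # --batch-size, --num-query-heads, --head-size
max_context_len, PARTITION_SIZE = 4096, 512

num_partitions = (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE  # ceiling division -> 8
tmp_output_shape = (num_seqs, num_query_heads, num_partitions, head_size)  # (8, 64, 8, 128)
exp_sums_shape = (num_seqs, num_query_heads, num_partitions)               # (8, 64, 8)
print(num_partitions, tmp_output_shape, exp_sums_shape)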
benchmarks/launch_tgi_server.sh ADDED
@@ -0,0 +1,16 @@
1
+ #!/bin/bash
2
+
3
+ PORT=8000
4
+ MODEL=$1
5
+ TOKENS=$2
6
+
7
+ docker run --gpus all --shm-size 1g -p $PORT:80 \
8
+ -v $PWD/data:/data \
9
+ ghcr.io/huggingface/text-generation-inference:0.8 \
10
+ --model-id $MODEL \
11
+ --sharded false \
12
+ --max-input-length 1024 \
13
+ --max-total-tokens 2048 \
14
+ --max-best-of 5 \
15
+ --max-concurrent-requests 5000 \
16
+ --max-batch-total-tokens $TOKENS
csrc/activation_kernels.cu ADDED
@@ -0,0 +1,118 @@
1
+ #include <ATen/cuda/CUDAContext.h>
2
+ #include <torch/extension.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ #include "cuda_compat.h"
6
+ #include "dispatch_utils.h"
7
+
8
+ namespace vllm {
9
+
10
+ template<typename T>
11
+ __device__ __forceinline__ T silu(const T& x) {
12
+ // x * sigmoid(x)
13
+ return (T) (((float) x) / (1.0f + expf((float) -x)));
14
+ }
15
+
16
+ template<typename scalar_t>
17
+ __global__ void silu_and_mul_kernel(
18
+ scalar_t* __restrict__ out, // [..., d]
19
+ const scalar_t* __restrict__ input, // [..., 2, d]
20
+ const int d) {
21
+ const int64_t token_idx = blockIdx.x;
22
+ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
23
+ const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
24
+ const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
25
+ out[token_idx * d + idx] = silu(x) * y;
26
+ }
27
+ }
28
+
29
+ } // namespace vllm
30
+
31
+ void silu_and_mul(
32
+ torch::Tensor& out, // [..., d]
33
+ torch::Tensor& input) // [..., 2 * d]
34
+ {
35
+ int64_t num_tokens = input.numel() / input.size(-1);
36
+ int d = input.size(-1) / 2;
37
+
38
+ dim3 grid(num_tokens);
39
+ dim3 block(std::min(d, 1024));
40
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
41
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
42
+ VLLM_DISPATCH_FLOATING_TYPES(
43
+ input.scalar_type(),
44
+ "silu_and_mul_kernel",
45
+ [&] {
46
+ vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
47
+ out.data_ptr<scalar_t>(),
48
+ input.data_ptr<scalar_t>(),
49
+ d);
50
+ });
51
+ }
52
+
53
+ namespace vllm {
54
+
55
+ // Element-wise activation kernel template.
56
+ template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
57
+ __global__ void activation_kernel(
58
+ scalar_t* __restrict__ out, // [..., d]
59
+ const scalar_t* __restrict__ input, // [..., d]
60
+ const int d) {
61
+ const int64_t token_idx = blockIdx.x;
62
+ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
63
+ const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
64
+ out[token_idx * d + idx] = ACT_FN(x);
65
+ }
66
+ }
67
+
68
+ } // namespace vllm
69
+
70
+ // Launch element-wise activation kernel.
71
+ #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
72
+ int d = input.size(-1); \
73
+ int64_t num_tokens = input.numel() / d; \
74
+ dim3 grid(num_tokens); \
75
+ dim3 block(std::min(d, 1024)); \
76
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
77
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
78
+ VLLM_DISPATCH_FLOATING_TYPES( \
79
+ input.scalar_type(), \
80
+ "activation_kernel", \
81
+ [&] { \
82
+ vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
83
+ out.data_ptr<scalar_t>(), \
84
+ input.data_ptr<scalar_t>(), \
85
+ d); \
86
+ });
87
+
88
+ namespace vllm {
89
+
90
+ template<typename T>
91
+ __device__ __forceinline__ T gelu_new_kernel(const T& x) {
92
+ const float x3 = (float) (x * x * x);
93
+ const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
94
+ return ((T) 0.5) * x * (((T) 1.0) + t);
95
+ }
96
+
97
+ template<typename T>
98
+ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
99
+ const float f = (float) x;
100
+ const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
101
+ return ((T) 0.5) * x * (((T) 1.0) + t);
102
+ }
103
+
104
+ } // namespace vllm
105
+
106
+ void gelu_new(
107
+ torch::Tensor& out, // [..., d]
108
+ torch::Tensor& input) // [..., d]
109
+ {
110
+ LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
111
+ }
112
+
113
+ void gelu_fast(
114
+ torch::Tensor& out, // [..., d]
115
+ torch::Tensor& input) // [..., d]
116
+ {
117
+ LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
118
+ }
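
For readers more comfortable with PyTorch than CUDA, here is a reference sketch (not the kernel itself) of what silu_and_mul computes: the input packs [x, y] along its last dimension and the output is silu(x) * y, with silu(x) = x * sigmoid(x) as in the device function above.

import torch
import torch.nn.functional as F

def silu_and_mul_ref(input: torch.Tensor) -> torch.Tensor:
    # input: [..., 2 * d] -> output: [..., d]
    d = input.shape[-1] // 2
    x, y = input[..., :d], input[..., d:]
    return F.silu(x) * y

out = silu_and_mul_ref(torch.randn(4, 2 * 11))
print(out.shape)  # torch.Size([4, 11])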
csrc/attention/attention_dtypes.h ADDED
@@ -0,0 +1,7 @@
1
+ #pragma once
2
+
3
+ #include "attention_generic.cuh"
4
+ #include "dtype_float16.cuh"
5
+ #include "dtype_float32.cuh"
6
+ #include "dtype_bfloat16.cuh"
7
+ #include "dtype_fp8_e5m2.cuh"
csrc/attention/attention_generic.cuh ADDED
@@ -0,0 +1,64 @@
1
+ /*
2
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
3
+ * Copyright (c) 2023, The vLLM team.
4
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+ #pragma once
19
+
20
+ #include <stdint.h>
21
+
22
+ namespace vllm {
23
+
24
+ // A vector type to store Q, K, V elements.
25
+ template<typename T, int VEC_SIZE>
26
+ struct Vec {};
27
+
28
+ // A vector type to store FP32 accumulators.
29
+ template<typename T>
30
+ struct FloatVec {};
31
+
32
+ // Template vector operations.
33
+ template<typename Acc, typename A, typename B>
34
+ inline __device__ Acc mul(A a, B b);
35
+
36
+ template<typename T>
37
+ inline __device__ float sum(T v);
38
+
39
+ template<typename T>
40
+ inline __device__ float dot(T a, T b) {
41
+ return sum(mul<T, T, T>(a, b));
42
+ }
43
+
44
+ template<typename A, typename T>
45
+ inline __device__ float dot(T a, T b) {
46
+ return sum(mul<A, T, T>(a, b));
47
+ }
48
+
49
+ template<typename T>
50
+ inline __device__ void zero(T& dst) {
51
+ constexpr int WORDS = sizeof(T) / 4;
52
+ union {
53
+ T raw;
54
+ uint32_t words[WORDS];
55
+ } tmp;
56
+
57
+ #pragma unroll
58
+ for (int ii = 0; ii < WORDS; ++ii) {
59
+ tmp.words[ii] = 0u;
60
+ }
61
+ dst = tmp.raw;
62
+ }
63
+
64
+ } // namespace vllm
csrc/attention/attention_kernels.cu ADDED
@@ -0,0 +1,951 @@
1
+ /*
2
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
+ * Copyright (c) 2023, The vLLM team.
4
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+ #ifdef USE_ROCM
19
+ #include <hip/hip_runtime.h>
20
+ #endif
21
+
22
+ #include <torch/extension.h>
23
+ #include <ATen/cuda/CUDAContext.h>
24
+ #include <c10/cuda/CUDAGuard.h>
25
+
26
+ #include "attention_dtypes.h"
27
+ #include "attention_utils.cuh"
28
+ #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
29
+
30
+ #include <algorithm>
31
+
32
+ #ifndef USE_ROCM
33
+ #define WARP_SIZE 32
34
+ #else
35
+ #define WARP_SIZE warpSize
36
+ #endif
37
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
38
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
39
+ #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
40
+
41
+ namespace vllm {
42
+
43
+ // Utility function for attention softmax.
44
+ template<int NUM_WARPS>
45
+ inline __device__ float block_sum(float* red_smem, float sum) {
46
+ // Decompose the thread index into warp / lane.
47
+ int warp = threadIdx.x / WARP_SIZE;
48
+ int lane = threadIdx.x % WARP_SIZE;
49
+
50
+ // Compute the sum per warp.
51
+ #pragma unroll
52
+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
53
+ sum += VLLM_SHFL_XOR_SYNC(sum, mask);
54
+ }
55
+
56
+ // Warp leaders store the data to shared memory.
57
+ if (lane == 0) {
58
+ red_smem[warp] = sum;
59
+ }
60
+
61
+ // Make sure the data is in shared memory.
62
+ __syncthreads();
63
+
64
+ // The warps compute the final sums.
65
+ if (lane < NUM_WARPS) {
66
+ sum = red_smem[lane];
67
+ }
68
+
69
+ // Parallel reduction inside the warp.
70
+ #pragma unroll
71
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
72
+ sum += VLLM_SHFL_XOR_SYNC(sum, mask);
73
+ }
74
+
75
+ // Broadcast to other threads.
76
+ return VLLM_SHFL_SYNC(sum, 0);
77
+ }
78
+
79
+ // TODO(woosuk): Merge the last two dimensions of the grid.
80
+ // Grid: (num_heads, num_seqs, max_num_partitions).
81
+ template<
82
+ typename scalar_t,
83
+ typename cache_t,
84
+ int HEAD_SIZE,
85
+ int BLOCK_SIZE,
86
+ int NUM_THREADS,
87
+ bool IS_FP8_E5M2_KV_CACHE,
88
+ int PARTITION_SIZE = 0> // Zero means no partitioning.
89
+ __device__ void paged_attention_kernel(
90
+ float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
91
+ float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
92
+ scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
93
+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
94
+ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
95
+ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
96
+ const int num_kv_heads, // [num_heads]
97
+ const float scale,
98
+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
99
+ const int* __restrict__ context_lens, // [num_seqs]
100
+ const int max_num_blocks_per_seq,
101
+ const float* __restrict__ alibi_slopes, // [num_heads]
102
+ const int q_stride,
103
+ const int kv_block_stride,
104
+ const int kv_head_stride) {
105
+ const int seq_idx = blockIdx.y;
106
+ const int partition_idx = blockIdx.z;
107
+ const int max_num_partitions = gridDim.z;
108
+ constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
109
+ const int context_len = context_lens[seq_idx];
110
+ if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
111
+ // No work to do. Terminate the thread block.
112
+ return;
113
+ }
114
+
115
+ const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
116
+ const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
117
+
118
+ // [start_block_idx, end_block_idx) is the range of blocks to process.
119
+ const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
120
+ const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
121
+ const int num_blocks = end_block_idx - start_block_idx;
122
+
123
+ // [start_token_idx, end_token_idx) is the range of tokens to process.
124
+ const int start_token_idx = start_block_idx * BLOCK_SIZE;
125
+ const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
126
+ const int num_tokens = end_token_idx - start_token_idx;
127
+
128
+ constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
129
+ constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
130
+ assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
131
+ constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
132
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
133
+ const int thread_idx = threadIdx.x;
134
+ const int warp_idx = thread_idx / WARP_SIZE;
135
+ const int lane = thread_idx % WARP_SIZE;
136
+
137
+ const int head_idx = blockIdx.x;
138
+ const int num_heads = gridDim.x;
139
+ const int num_queries_per_kv = num_heads / num_kv_heads;
140
+ const int kv_head_idx = head_idx / num_queries_per_kv;
141
+ const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
142
+
143
+ // A vector type to store a part of a key or a query.
144
+ // The vector size is configured in such a way that the threads in a thread group
145
+ // fetch or compute 16 bytes at a time.
146
+ // For example, if the size of a thread group is 4 and the data type is half,
147
+ // then the vector size is 16 / (4 * sizeof(half)) == 2.
148
+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
149
+ using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
150
+ using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
151
+ #ifdef ENABLE_FP8_E5M2
152
+ using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
153
+ #endif
154
+
155
+ constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
156
+ constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
157
+
158
+ const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
159
+ const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
160
+
161
+ // Load the query to registers.
162
+ // Each thread in a thread group has a different part of the query.
163
+ // For example, if the the thread group size is 4, then the first thread in the group
164
+ // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
165
+ // th vectors of the query, and so on.
166
+ // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
167
+ const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
168
+ __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
169
+ #pragma unroll
170
+ for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
171
+ const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
172
+ q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
173
+ }
174
+ __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
175
+
176
+ // Memory planning.
177
+ extern __shared__ char shared_mem[];
178
+ // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
179
+ float* logits = reinterpret_cast<float*>(shared_mem);
180
+ // Workspace for reduction.
181
+ __shared__ float red_smem[2 * NUM_WARPS];
182
+
183
+ // x == THREAD_GROUP_SIZE * VEC_SIZE
184
+ // Each thread group fetches x elements from the key at a time.
185
+ constexpr int x = 16 / sizeof(cache_t);
186
+ float qk_max = -FLT_MAX;
187
+
188
+ // Iterate over the key blocks.
189
+ // Each warp fetches a block of keys for each iteration.
190
+ // Each thread group in a warp fetches a key from the block, and computes
191
+ // dot product with the query.
192
+ const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
193
+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
194
+ // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
195
+ // because int32 can lead to overflow when this variable is multiplied by large numbers
196
+ // (e.g., kv_block_stride).
197
+ const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
198
+
199
+ // Load a key to registers.
200
+ // Each thread in a thread group has a different part of the key.
201
+ // For example, if the the thread group size is 4, then the first thread in the group
202
+ // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
203
+ // vectors of the key, and so on.
204
+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
205
+ const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
206
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
207
+ K_vec k_vecs[NUM_VECS_PER_THREAD];
208
+
209
+ #pragma unroll
210
+ for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
211
+ const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
212
+ + kv_head_idx * kv_head_stride
213
+ + physical_block_offset * x;
214
+ const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
215
+ const int offset1 = (vec_idx * VEC_SIZE) / x;
216
+ const int offset2 = (vec_idx * VEC_SIZE) % x;
217
+ if constexpr (IS_FP8_E5M2_KV_CACHE) {
218
+ #ifdef ENABLE_FP8_E5M2
219
+ Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
220
+ // Vector conversion from Quant_vec to K_vec.
221
+ k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
222
+ #else
223
+ assert(false);
224
+ #endif
225
+ } else {
226
+ k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
227
+ }
228
+ }
229
+
230
+ // Compute dot product.
231
+ // This includes a reduction across the threads in the same thread group.
232
+ float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
233
+ // Add the ALiBi bias if slopes are given.
234
+ qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
235
+
236
+ if (thread_group_offset == 0) {
237
+ // Store the partial reductions to shared memory.
238
+ // NOTE(woosuk): It is required to zero out the masked logits.
239
+ const bool mask = token_idx >= context_len;
240
+ logits[token_idx - start_token_idx] = mask ? 0.f : qk;
241
+ // Update the max value.
242
+ qk_max = mask ? qk_max : fmaxf(qk_max, qk);
243
+ }
244
+ }
245
+ }
246
+
247
+ // Perform reduction across the threads in the same warp to get the
248
+ // max qk value for each "warp" (not across the thread block yet).
249
+ // The 0-th thread of each thread group already has its max qk value.
250
+ #pragma unroll
251
+ for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
252
+ qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
253
+ }
254
+ if (lane == 0) {
255
+ red_smem[warp_idx] = qk_max;
256
+ }
257
+ __syncthreads();
258
+
259
+ // TODO(woosuk): Refactor this part.
260
+ // Get the max qk value for the sequence.
261
+ qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
262
+ #pragma unroll
263
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
264
+ qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
265
+ }
266
+ // Broadcast the max qk value to all threads.
267
+ qk_max = VLLM_SHFL_SYNC(qk_max, 0);
268
+
269
+ // Get the sum of the exp values.
270
+ float exp_sum = 0.f;
271
+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
272
+ float val = __expf(logits[i] - qk_max);
273
+ logits[i] = val;
274
+ exp_sum += val;
275
+ }
276
+ exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
277
+
278
+ // Compute softmax.
279
+ const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
280
+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
281
+ logits[i] *= inv_sum;
282
+ }
283
+ __syncthreads();
284
+
285
+ // If partitioning is enabled, store the max logit and exp_sum.
286
+ if (USE_PARTITIONING && thread_idx == 0) {
287
+ float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
288
+ + head_idx * max_num_partitions
289
+ + partition_idx;
290
+ *max_logits_ptr = qk_max;
291
+ float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
292
+ + head_idx * max_num_partitions
293
+ + partition_idx;
294
+ *exp_sums_ptr = exp_sum;
295
+ }
296
+
297
+ // Each thread will fetch 16 bytes from the value cache at a time.
298
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
299
+ using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
300
+ using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
301
+ #ifdef ENABLE_FP8_E5M2
302
+ using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
303
+ #endif
304
+ using Float_L_vec = typename FloatVec<L_vec>::Type;
305
+
306
+ constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
307
+ constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
308
+ constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
309
+
310
+ // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
311
+ float accs[NUM_ROWS_PER_THREAD];
312
+ #pragma unroll
313
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
314
+ accs[i] = 0.f;
315
+ }
316
+
317
+ scalar_t zero_value;
318
+ zero(zero_value);
319
+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
320
+ // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
321
+ // because int32 can lead to overflow when this variable is multiplied by large numbers
322
+ // (e.g., kv_block_stride).
323
+ const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
324
+ const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
325
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
326
+ L_vec logits_vec;
327
+ from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
328
+
329
+ const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
330
+ + kv_head_idx * kv_head_stride;
331
+ #pragma unroll
332
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
333
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
334
+ if (row_idx < HEAD_SIZE) {
335
+ const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
336
+ V_vec v_vec;
337
+ if constexpr (IS_FP8_E5M2_KV_CACHE) {
338
+ #ifdef ENABLE_FP8_E5M2
339
+ V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
340
+ // Vector conversion from V_quant_vec to V_vec.
341
+ v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
342
+ #else
343
+ assert(false);
344
+ #endif
345
+ } else {
346
+ v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
347
+ }
348
+ if (block_idx == num_context_blocks - 1) {
349
+ // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
350
+ // we should explicitly zero out the values since they may contain NaNs.
351
+ // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
352
+ scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
353
+ #pragma unroll
354
+ for (int j = 0; j < V_VEC_SIZE; j++) {
355
+ v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
356
+ }
357
+ }
358
+ accs[i] += dot(logits_vec, v_vec);
359
+ }
360
+ }
361
+ }
362
+
363
+ // Perform reduction within each warp.
364
+ #pragma unroll
365
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
366
+ float acc = accs[i];
367
+ #pragma unroll
368
+ for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
369
+ acc += VLLM_SHFL_XOR_SYNC(acc, mask);
370
+ }
371
+ accs[i] = acc;
372
+ }
373
+
374
+ // NOTE(woosuk): A barrier is required because the shared memory space for logits
375
+ // is reused for the output.
376
+ __syncthreads();
377
+
378
+ // Perform reduction across warps.
379
+ float* out_smem = reinterpret_cast<float*>(shared_mem);
380
+ #pragma unroll
381
+ for (int i = NUM_WARPS; i > 1; i /= 2) {
382
+ int mid = i / 2;
383
+ // Upper warps write to shared memory.
384
+ if (warp_idx >= mid && warp_idx < i) {
385
+ float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
386
+ #pragma unroll
387
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
388
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
389
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
390
+ dst[row_idx] = accs[i];
391
+ }
392
+ }
393
+ }
394
+ __syncthreads();
395
+
396
+ // Lower warps update the output.
397
+ if (warp_idx < mid) {
398
+ const float* src = &out_smem[warp_idx * HEAD_SIZE];
399
+ #pragma unroll
400
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
401
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
402
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
403
+ accs[i] += src[row_idx];
404
+ }
405
+ }
406
+ }
407
+ __syncthreads();
408
+ }
409
+
410
+ // Write the final output.
411
+ if (warp_idx == 0) {
412
+ scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
413
+ + head_idx * max_num_partitions * HEAD_SIZE
414
+ + partition_idx * HEAD_SIZE;
415
+ #pragma unroll
416
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
417
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
418
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
419
+ from_float(*(out_ptr + row_idx), accs[i]);
420
+ }
421
+ }
422
+ }
423
+ }
424
+
425
+ // Grid: (num_heads, num_seqs, 1).
426
+ template<
427
+ typename scalar_t,
428
+ typename cache_t,
429
+ int HEAD_SIZE,
430
+ int BLOCK_SIZE,
431
+ int NUM_THREADS,
432
+ bool IS_FP8_E5M2_KV_CACHE>
433
+ __global__ void paged_attention_v1_kernel(
434
+ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
435
+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
436
+ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
437
+ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
438
+ const int num_kv_heads, // [num_heads]
439
+ const float scale,
440
+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
441
+ const int* __restrict__ context_lens, // [num_seqs]
442
+ const int max_num_blocks_per_seq,
443
+ const float* __restrict__ alibi_slopes, // [num_heads]
444
+ const int q_stride,
445
+ const int kv_block_stride,
446
+ const int kv_head_stride) {
447
+ paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
448
+ /* exp_sums */ nullptr, /* max_logits */ nullptr,
449
+ out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
450
+ max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
451
+ }
452
+
453
+ // Grid: (num_heads, num_seqs, max_num_partitions).
454
+ template<
455
+ typename scalar_t,
456
+ typename cache_t,
457
+ int HEAD_SIZE,
458
+ int BLOCK_SIZE,
459
+ int NUM_THREADS,
460
+ bool IS_FP8_E5M2_KV_CACHE,
461
+ int PARTITION_SIZE>
462
+ __global__ void paged_attention_v2_kernel(
463
+ float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
464
+ float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
465
+ scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
466
+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
467
+ const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
468
+ const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
469
+ const int num_kv_heads, // [num_heads]
470
+ const float scale,
471
+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
472
+ const int* __restrict__ context_lens, // [num_seqs]
473
+ const int max_num_blocks_per_seq,
474
+ const float* __restrict__ alibi_slopes, // [num_heads]
475
+ const int q_stride,
476
+ const int kv_block_stride,
477
+ const int kv_head_stride) {
478
+ paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
479
+ exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
480
+ block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
481
+ q_stride, kv_block_stride, kv_head_stride);
482
+ }
483
+
484
+ // Grid: (num_heads, num_seqs).
485
+ template<
486
+ typename scalar_t,
487
+ int HEAD_SIZE,
488
+ int NUM_THREADS,
489
+ int PARTITION_SIZE>
490
+ __global__ void paged_attention_v2_reduce_kernel(
491
+ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
492
+ const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
493
+ const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
494
+ const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
495
+ const int* __restrict__ context_lens, // [num_seqs]
496
+ const int max_num_partitions) {
497
+ const int num_heads = gridDim.x;
498
+ const int head_idx = blockIdx.x;
499
+ const int seq_idx = blockIdx.y;
500
+ const int context_len = context_lens[seq_idx];
501
+ const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
502
+ if (num_partitions == 1) {
503
+ // No need to reduce. Only copy tmp_out to out.
504
+ scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
505
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
506
+ + head_idx * max_num_partitions * HEAD_SIZE;
507
+ for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
508
+ out_ptr[i] = tmp_out_ptr[i];
509
+ }
510
+ // Terminate the thread block.
511
+ return;
512
+ }
513
+
514
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
515
+ const int warp_idx = threadIdx.x / WARP_SIZE;
516
+ const int lane = threadIdx.x % WARP_SIZE;
517
+
518
+ // Size: 2 * num_partitions.
519
+ extern __shared__ char shared_mem[];
520
+ // Workspace for reduction.
521
+ __shared__ float red_smem[2 * NUM_WARPS];
522
+
523
+ // Load max logits to shared memory.
524
+ float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
525
+ const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
526
+ + head_idx * max_num_partitions;
527
+ float max_logit = -FLT_MAX;
528
+ for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
529
+ const float l = max_logits_ptr[i];
530
+ shared_max_logits[i] = l;
531
+ max_logit = fmaxf(max_logit, l);
532
+ }
533
+ __syncthreads();
534
+
535
+ // Get the global max logit.
536
+ // Reduce within the warp.
537
+ #pragma unroll
538
+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
539
+ max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
540
+ }
541
+ if (lane == 0) {
542
+ red_smem[warp_idx] = max_logit;
543
+ }
544
+ __syncthreads();
545
+ // Reduce across warps.
546
+ max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
547
+ #pragma unroll
548
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
549
+ max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
550
+ }
551
+ // Broadcast the max value to all threads.
552
+ max_logit = VLLM_SHFL_SYNC(max_logit, 0);
553
+
554
+ // Load rescaled exp sums to shared memory.
555
+ float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
556
+ const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
557
+ + head_idx * max_num_partitions;
558
+ float global_exp_sum = 0.0f;
559
+ for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
560
+ float l = shared_max_logits[i];
561
+ float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
562
+ global_exp_sum += rescaled_exp_sum;
563
+ shared_exp_sums[i] = rescaled_exp_sum;
564
+ }
565
+ __syncthreads();
566
+ global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
567
+ const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
568
+
569
+ // Aggregate tmp_out to out.
570
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
571
+ + head_idx * max_num_partitions * HEAD_SIZE;
572
+ scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
573
+ #pragma unroll
574
+ for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
575
+ float acc = 0.0f;
576
+ for (int j = 0; j < num_partitions; ++j) {
577
+ acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
578
+ }
579
+ from_float(out_ptr[i], acc);
580
+ }
581
+ }
582
+
583
+ } // namespace vllm
584
+
585
+ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
586
+ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
587
+ ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
588
+ IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
589
+ vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
590
+ IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
591
+ out_ptr, \
592
+ query_ptr, \
593
+ key_cache_ptr, \
594
+ value_cache_ptr, \
595
+ num_kv_heads, \
596
+ scale, \
597
+ block_tables_ptr, \
598
+ context_lens_ptr, \
599
+ max_num_blocks_per_seq, \
600
+ alibi_slopes_ptr, \
601
+ q_stride, \
602
+ kv_block_stride, \
603
+ kv_head_stride);
604
+
605
+ // TODO(woosuk): Tune NUM_THREADS.
606
+ template<
607
+ typename T,
608
+ typename CACHE_T,
609
+ int BLOCK_SIZE,
610
+ bool IS_FP8_E5M2_KV_CACHE,
611
+ int NUM_THREADS = 128>
612
+ void paged_attention_v1_launcher(
613
+ torch::Tensor& out,
614
+ torch::Tensor& query,
615
+ torch::Tensor& key_cache,
616
+ torch::Tensor& value_cache,
617
+ int num_kv_heads,
618
+ float scale,
619
+ torch::Tensor& block_tables,
620
+ torch::Tensor& context_lens,
621
+ int max_context_len,
622
+ const c10::optional<torch::Tensor>& alibi_slopes) {
623
+ int num_seqs = query.size(0);
624
+ int num_heads = query.size(1);
625
+ int head_size = query.size(2);
626
+ int max_num_blocks_per_seq = block_tables.size(1);
627
+ int q_stride = query.stride(0);
628
+ int kv_block_stride = key_cache.stride(0);
629
+ int kv_head_stride = key_cache.stride(1);
630
+
631
+ int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
632
+ assert(head_size % thread_group_size == 0);
633
+
634
+ // NOTE: alibi_slopes is optional.
635
+ const float* alibi_slopes_ptr = alibi_slopes ?
636
+ reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
637
+ : nullptr;
638
+
639
+ T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
640
+ T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
641
+ CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
642
+ CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
643
+ int* block_tables_ptr = block_tables.data_ptr<int>();
644
+ int* context_lens_ptr = context_lens.data_ptr<int>();
645
+
646
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
647
+ int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
648
+ int logits_size = padded_max_context_len * sizeof(float);
649
+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
650
+ // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
651
+ // Keep that in sync with the logic here!
652
+ int shared_mem_size = std::max(logits_size, outputs_size);
653
+
654
+ dim3 grid(num_heads, num_seqs, 1);
655
+ dim3 block(NUM_THREADS);
656
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
657
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
658
+ switch (head_size) {
659
+ // NOTE(woosuk): To reduce the compilation time, we only compile for the
660
+ // head sizes that we use in the model. However, we can easily extend this
661
+ // to support any head size which is a multiple of 16.
662
+ case 64:
663
+ LAUNCH_PAGED_ATTENTION_V1(64);
664
+ break;
665
+ case 80:
666
+ LAUNCH_PAGED_ATTENTION_V1(80);
667
+ break;
668
+ case 96:
669
+ LAUNCH_PAGED_ATTENTION_V1(96);
670
+ break;
671
+ case 112:
672
+ LAUNCH_PAGED_ATTENTION_V1(112);
673
+ break;
674
+ case 128:
675
+ LAUNCH_PAGED_ATTENTION_V1(128);
676
+ break;
677
+ case 256:
678
+ LAUNCH_PAGED_ATTENTION_V1(256);
679
+ break;
680
+ default:
681
+ TORCH_CHECK(false, "Unsupported head size: ", head_size);
682
+ break;
683
+ }
684
+ }
685
+
686
+ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
687
+ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
688
+ out, \
689
+ query, \
690
+ key_cache, \
691
+ value_cache, \
692
+ num_kv_heads, \
693
+ scale, \
694
+ block_tables, \
695
+ context_lens, \
696
+ max_context_len, \
697
+ alibi_slopes);
698
+
699
+ // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
700
+ // 1, 2, 4, 64, 128, 256.
701
+ #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
702
+ switch (block_size) { \
703
+ case 8: \
704
+ CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
705
+ break; \
706
+ case 16: \
707
+ CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
708
+ break; \
709
+ case 32: \
710
+ CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
711
+ break; \
712
+ default: \
713
+ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
714
+ break; \
715
+ }
716
+
717
+ void paged_attention_v1(
718
+ torch::Tensor& out, // [num_seqs, num_heads, head_size]
719
+ torch::Tensor& query, // [num_seqs, num_heads, head_size]
720
+ torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
721
+ torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
722
+ int num_kv_heads, // [num_heads]
723
+ float scale,
724
+ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
725
+ torch::Tensor& context_lens, // [num_seqs]
726
+ int block_size,
727
+ int max_context_len,
728
+ const c10::optional<torch::Tensor>& alibi_slopes,
729
+ const std::string& kv_cache_dtype) {
730
+ if (kv_cache_dtype == "auto") {
731
+ if (query.dtype() == at::ScalarType::Float) {
732
+ CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
733
+ } else if (query.dtype() == at::ScalarType::Half) {
734
+ CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
735
+ } else if (query.dtype() == at::ScalarType::BFloat16) {
736
+ CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
737
+ } else {
738
+ TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
739
+ }
740
+ } else if (kv_cache_dtype == "fp8_e5m2") {
741
+ if (query.dtype() == at::ScalarType::Float) {
742
+ CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
743
+ } else if (query.dtype() == at::ScalarType::Half) {
744
+ CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
745
+ } else if (query.dtype() == at::ScalarType::BFloat16) {
746
+ CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
747
+ } else {
748
+ TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
749
+ }
750
+ } else {
751
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
752
+ }
753
+ }
754
+
755
+ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
756
+ vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
757
+ IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
758
+ <<<grid, block, shared_mem_size, stream>>>( \
759
+ exp_sums_ptr, \
760
+ max_logits_ptr, \
761
+ tmp_out_ptr, \
762
+ query_ptr, \
763
+ key_cache_ptr, \
764
+ value_cache_ptr, \
765
+ num_kv_heads, \
766
+ scale, \
767
+ block_tables_ptr, \
768
+ context_lens_ptr, \
769
+ max_num_blocks_per_seq, \
770
+ alibi_slopes_ptr, \
771
+ q_stride, \
772
+ kv_block_stride, \
773
+ kv_head_stride); \
774
+ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
775
+ <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
776
+ out_ptr, \
777
+ exp_sums_ptr, \
778
+ max_logits_ptr, \
779
+ tmp_out_ptr, \
780
+ context_lens_ptr, \
781
+ max_num_partitions);
782
+
783
+ template<
784
+ typename T,
785
+ typename CACHE_T,
786
+ int BLOCK_SIZE,
787
+ bool IS_FP8_E5M2_KV_CACHE,
788
+ int NUM_THREADS = 128,
789
+ int PARTITION_SIZE = 512>
790
+ void paged_attention_v2_launcher(
791
+ torch::Tensor& out,
792
+ torch::Tensor& exp_sums,
793
+ torch::Tensor& max_logits,
794
+ torch::Tensor& tmp_out,
795
+ torch::Tensor& query,
796
+ torch::Tensor& key_cache,
797
+ torch::Tensor& value_cache,
798
+ int num_kv_heads,
799
+ float scale,
800
+ torch::Tensor& block_tables,
801
+ torch::Tensor& context_lens,
802
+ int max_context_len,
803
+ const c10::optional<torch::Tensor>& alibi_slopes) {
804
+ int num_seqs = query.size(0);
805
+ int num_heads = query.size(1);
806
+ int head_size = query.size(2);
807
+ int max_num_blocks_per_seq = block_tables.size(1);
808
+ int q_stride = query.stride(0);
809
+ int kv_block_stride = key_cache.stride(0);
810
+ int kv_head_stride = key_cache.stride(1);
811
+
812
+ int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
813
+ assert(head_size % thread_group_size == 0);
814
+
815
+ // NOTE: alibi_slopes is optional.
816
+ const float* alibi_slopes_ptr = alibi_slopes ?
817
+ reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
818
+ : nullptr;
819
+
820
+ T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
821
+ float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
822
+ float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
823
+ T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
824
+ T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
825
+ CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
826
+ CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
827
+ int* block_tables_ptr = block_tables.data_ptr<int>();
828
+ int* context_lens_ptr = context_lens.data_ptr<int>();
829
+
830
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
831
+ int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
832
+ int logits_size = PARTITION_SIZE * sizeof(float);
833
+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
834
+
835
+ // For paged attention v2 kernel.
836
+ dim3 grid(num_heads, num_seqs, max_num_partitions);
837
+ int shared_mem_size = std::max(logits_size, outputs_size);
838
+ // For paged attention v2 reduce kernel.
839
+ dim3 reduce_grid(num_heads, num_seqs);
840
+ int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
841
+
842
+ dim3 block(NUM_THREADS);
843
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
844
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
845
+ switch (head_size) {
846
+ // NOTE(woosuk): To reduce the compilation time, we only compile for the
847
+ // head sizes that we use in the model. However, we can easily extend this
848
+ // to support any head size which is a multiple of 16.
849
+ case 64:
850
+ LAUNCH_PAGED_ATTENTION_V2(64);
851
+ break;
852
+ case 80:
853
+ LAUNCH_PAGED_ATTENTION_V2(80);
854
+ break;
855
+ case 96:
856
+ LAUNCH_PAGED_ATTENTION_V2(96);
857
+ break;
858
+ case 112:
859
+ LAUNCH_PAGED_ATTENTION_V2(112);
860
+ break;
861
+ case 128:
862
+ LAUNCH_PAGED_ATTENTION_V2(128);
863
+ break;
864
+ case 256:
865
+ LAUNCH_PAGED_ATTENTION_V2(256);
866
+ break;
867
+ default:
868
+ TORCH_CHECK(false, "Unsupported head size: ", head_size);
869
+ break;
870
+ }
871
+ }
872
+
873
+ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
874
+ paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
875
+ out, \
876
+ exp_sums, \
877
+ max_logits, \
878
+ tmp_out, \
879
+ query, \
880
+ key_cache, \
881
+ value_cache, \
882
+ num_kv_heads, \
883
+ scale, \
884
+ block_tables, \
885
+ context_lens, \
886
+ max_context_len, \
887
+ alibi_slopes);
888
+
889
+ // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
890
+ // 1, 2, 4, 64, 128, 256.
891
+ #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
892
+ switch (block_size) { \
893
+ case 8: \
894
+ CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
895
+ break; \
896
+ case 16: \
897
+ CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
898
+ break; \
899
+ case 32: \
900
+ CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
901
+ break; \
902
+ default: \
903
+ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
904
+ break; \
905
+ }
906
+
907
+ void paged_attention_v2(
908
+ torch::Tensor& out, // [num_seqs, num_heads, head_size]
909
+ torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
910
+ torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
911
+ torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
912
+ torch::Tensor& query, // [num_seqs, num_heads, head_size]
913
+ torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
914
+ torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
915
+ int num_kv_heads, // [num_heads]
916
+ float scale,
917
+ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
918
+ torch::Tensor& context_lens, // [num_seqs]
919
+ int block_size,
920
+ int max_context_len,
921
+ const c10::optional<torch::Tensor>& alibi_slopes,
922
+ const std::string& kv_cache_dtype) {
923
+ if (kv_cache_dtype == "auto") {
924
+ if (query.dtype() == at::ScalarType::Float) {
925
+ CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
926
+ } else if (query.dtype() == at::ScalarType::Half) {
927
+ CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
928
+ } else if (query.dtype() == at::ScalarType::BFloat16) {
929
+ CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
930
+ } else {
931
+ TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
932
+ }
933
+ } else if (kv_cache_dtype == "fp8_e5m2") {
934
+ if (query.dtype() == at::ScalarType::Float) {
935
+ CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
936
+ } else if (query.dtype() == at::ScalarType::Half) {
937
+ CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
938
+ } else if (query.dtype() == at::ScalarType::BFloat16) {
939
+ CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
940
+ } else {
941
+ TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
942
+ }
943
+ } else {
944
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
945
+ }
946
+ }
947
+
948
+ #undef WARP_SIZE
949
+ #undef MAX
950
+ #undef MIN
951
+ #undef DIVIDE_ROUND_UP
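
Note (editorial sketch, not part of the commit): the v2 path above splits each context into PARTITION_SIZE-token chunks, runs paged_attention_v2_kernel on a (num_heads, num_seqs, max_num_partitions) grid, and lets paged_attention_v2_reduce_kernel combine the per-partition softmax statistics. A minimal host-only illustration of that launch arithmetic follows; the shape values (num_seqs, num_heads, head_size, max_context_len) are hypothetical.

// partition_math_demo.cc -- host-only sketch of the grid / shared-memory
// arithmetic used by paged_attention_v2_launcher above (illustrative values).
#include <algorithm>
#include <cstdio>

int main() {
  // Constants mirror the launcher defaults; the request shape is made up.
  const int WARP_SIZE = 32, NUM_THREADS = 128, PARTITION_SIZE = 512;
  const int num_seqs = 4, num_heads = 32, head_size = 128, max_context_len = 1500;

  const int num_warps = NUM_THREADS / WARP_SIZE;                                           // 4
  const int max_num_partitions = (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;  // 3
  const int logits_size = PARTITION_SIZE * sizeof(float);                                  // 2048 B
  const int outputs_size = (num_warps / 2) * head_size * sizeof(float);                    // 1024 B
  const int shared_mem_size = std::max(logits_size, outputs_size);                         // 2048 B
  const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);               // 24 B

  // Main kernel grid: (num_heads, num_seqs, max_num_partitions) = (32, 4, 3);
  // reduce kernel grid: (num_heads, num_seqs) = (32, 4).
  std::printf("partitions=%d, main smem=%dB, reduce smem=%dB (seqs=%d, heads=%d)\n",
              max_num_partitions, shared_mem_size, reduce_shared_mem_size,
              num_seqs, num_heads);
  return 0;
}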
csrc/attention/attention_utils.cuh ADDED
@@ -0,0 +1,56 @@
1
+ /*
2
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
+ * Copyright (c) 2023, The vLLM team.
4
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+ #pragma once
19
+
20
+ #include "../cuda_compat.h"
21
+ #include "attention_dtypes.h"
22
+
23
+ #include <float.h>
24
+ #include <type_traits>
25
+
26
+ namespace vllm {
27
+
28
+ // Q*K^T operation.
29
+ template<int THREAD_GROUP_SIZE, typename Vec, int N>
30
+ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
31
+ using A_vec = typename FloatVec<Vec>::Type;
32
+ // Compute the parallel products for Q*K^T (treat vector lanes separately).
33
+ A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
34
+ #pragma unroll
35
+ for (int ii = 1; ii < N; ++ii) {
36
+ qk_vec = fma(q[ii], k[ii], qk_vec);
37
+ }
38
+
39
+ // Finalize the reduction across lanes.
40
+ float qk = sum(qk_vec);
41
+ #pragma unroll
42
+ for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
43
+ qk += VLLM_SHFL_XOR_SYNC(qk, mask);
44
+ }
45
+ return qk;
46
+ }
47
+
48
+ template<typename T, int THREAD_GROUP_SIZE>
49
+ struct Qk_dot {
50
+ template<typename Vec, int N>
51
+ static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
52
+ return qk_dot_<THREAD_GROUP_SIZE>(q, k);
53
+ }
54
+ };
55
+
56
+ } // namespace vllm
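
Editorial sketch (not part of this header): qk_dot_ finishes its per-thread-group dot product with an XOR-shuffle butterfly, the pattern VLLM_SHFL_XOR_SYNC wraps. A standalone CUDA illustration of that reduction over one warp; the file and kernel names (warp_reduce_demo, warp_sum_demo) are hypothetical.

// warp_reduce_demo.cu -- standalone sketch of the XOR-shuffle butterfly
// reduction used by qk_dot_ above.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void warp_sum_demo(const float* __restrict__ in, float* __restrict__ out) {
  // One warp: each lane holds a partial value, e.g. a per-lane Q*K product.
  float v = in[threadIdx.x];
  #pragma unroll
  for (int mask = 16; mask >= 1; mask /= 2) {
    // After the loop every lane holds the full warp sum, as in qk_dot_.
    v += __shfl_xor_sync(0xffffffffu, v, mask);
  }
  if (threadIdx.x == 0) {
    *out = v;
  }
}

int main() {
  float h_in[32], *d_in, *d_out, h_out = 0.0f;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;        // expected sum: 32
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_sum_demo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("warp sum = %f\n", h_out);              // expected: 32
  cudaFree(d_in); cudaFree(d_out);
  return 0;
}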
csrc/attention/dtype_bfloat16.cuh ADDED
@@ -0,0 +1,451 @@
1
+ /*
2
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
4
+ * Copyright (c) 2023, The vLLM team.
5
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6
+ *
7
+ * Licensed under the Apache License, Version 2.0 (the "License");
8
+ * you may not use this file except in compliance with the License.
9
+ * You may obtain a copy of the License at
10
+ *
11
+ * http://www.apache.org/licenses/LICENSE-2.0
12
+ *
13
+ * Unless required by applicable law or agreed to in writing, software
14
+ * distributed under the License is distributed on an "AS IS" BASIS,
15
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ * See the License for the specific language governing permissions and
17
+ * limitations under the License.
18
+ */
19
+ #pragma once
20
+
21
+ #include "attention_generic.cuh"
22
+ #include "dtype_float32.cuh"
23
+
24
+ #ifndef USE_ROCM
25
+ #include <cuda_bf16.h>
26
+ #include <cuda_fp16.h>
27
+ #else
28
+ #include <hip/hip_bf16.h>
29
+ #include <hip/hip_fp16.h>
30
+
31
+ typedef __hip_bfloat162 __nv_bfloat162;
32
+ typedef __hip_bfloat16 __nv_bfloat16;
33
+ #endif
34
+
35
+ #include <stdint.h>
36
+
37
+ namespace vllm {
38
+
39
+ // Define custom BF16 vector data types.
40
+ struct bf16_4_t {
41
+ __nv_bfloat162 x;
42
+ __nv_bfloat162 y;
43
+ };
44
+
45
+ struct bf16_8_t {
46
+ __nv_bfloat162 x;
47
+ __nv_bfloat162 y;
48
+ __nv_bfloat162 z;
49
+ __nv_bfloat162 w;
50
+ };
51
+
52
+ // BF16 vector types for Q, K, V.
53
+ template<>
54
+ struct Vec<__nv_bfloat16, 1> {
55
+ using Type = __nv_bfloat16;
56
+ };
57
+ template<>
58
+ struct Vec<__nv_bfloat16, 2> {
59
+ using Type = __nv_bfloat162;
60
+ };
61
+ template<>
62
+ struct Vec<__nv_bfloat16, 4> {
63
+ using Type = bf16_4_t;
64
+ };
65
+ template<>
66
+ struct Vec<__nv_bfloat16, 8> {
67
+ using Type = bf16_8_t;
68
+ };
69
+
70
+ // FP32 accumulator vector types corresponding to Vec.
71
+ template<>
72
+ struct FloatVec<__nv_bfloat16> {
73
+ using Type = float;
74
+ };
75
+ template<>
76
+ struct FloatVec<__nv_bfloat162> {
77
+ using Type = float2;
78
+ };
79
+ template<>
80
+ struct FloatVec<bf16_4_t> {
81
+ using Type = Float4_;
82
+ };
83
+ template<>
84
+ struct FloatVec<bf16_8_t> {
85
+ using Type = Float8_;
86
+ };
87
+
88
+ // Utility functions for type conversions.
89
+ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
90
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
91
+ assert(false);
92
+ #else
93
+ return __bfloat1622float2(val);
94
+ #endif
95
+ }
96
+
97
+ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
98
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
99
+ assert(false);
100
+ #else
101
+ return __bfloat162bfloat162(val);
102
+ #endif
103
+ }
104
+
105
+ // Vector addition.
106
+ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
107
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
108
+ assert(false);
109
+ #else
110
+ #ifndef USE_ROCM
111
+ return a + b;
112
+ #else
113
+ return __hadd(a, b);
114
+ #endif
115
+ #endif
116
+ }
117
+
118
+ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
119
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
120
+ assert(false);
121
+ #else
122
+ return __hadd2(a, b);
123
+ #endif
124
+ }
125
+
126
+ inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
127
+ bf16_4_t c;
128
+ c.x = add(a.x, b.x);
129
+ c.y = add(a.y, b.y);
130
+ return c;
131
+ }
132
+
133
+ inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
134
+ bf16_8_t c;
135
+ c.x = add(a.x, b.x);
136
+ c.y = add(a.y, b.y);
137
+ c.z = add(a.z, b.z);
138
+ c.w = add(a.w, b.w);
139
+ return c;
140
+ }
141
+
142
+ inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
143
+ float2 fa = bf1622float2(a);
144
+ return add(fa, fb);
145
+ }
146
+
147
+ inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
148
+ Float4_ fc;
149
+ fc.x = add(a.x, fb.x);
150
+ fc.y = add(a.y, fb.y);
151
+ return fc;
152
+ }
153
+
154
+ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
155
+ Float8_ fc;
156
+ fc.x = add(a.x, fb.x);
157
+ fc.y = add(a.y, fb.y);
158
+ fc.z = add(a.z, fb.z);
159
+ fc.w = add(a.w, fb.w);
160
+ return fc;
161
+ }
162
+
163
+ // Vector multiplication.
164
+ template<>
165
+ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
166
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
167
+ assert(false);
168
+ #else
169
+ return __hmul(a, b);
170
+ #endif
171
+ }
172
+
173
+ template<>
174
+ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
175
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
176
+ assert(false);
177
+ #else
178
+ return __hmul2(a, b);
179
+ #endif
180
+ }
181
+
182
+ template<>
183
+ inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
184
+ return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
185
+ }
186
+
187
+ template<>
188
+ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
189
+ bf16_4_t c;
190
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
191
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
192
+ return c;
193
+ }
194
+
195
+ template<>
196
+ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
197
+ __nv_bfloat162 s = bf162bf162(a);
198
+ bf16_4_t c;
199
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
200
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
201
+ return c;
202
+ }
203
+
204
+ template<>
205
+ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
206
+ bf16_8_t c;
207
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
208
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
209
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
210
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
211
+ return c;
212
+ }
213
+
214
+ template<>
215
+ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
216
+ __nv_bfloat162 s = bf162bf162(a);
217
+ bf16_8_t c;
218
+ c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
219
+ c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
220
+ c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
221
+ c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
222
+ return c;
223
+ }
224
+
225
+ template<>
226
+ inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
227
+ float fa = __bfloat162float(a);
228
+ float fb = __bfloat162float(b);
229
+ return fa * fb;
230
+ }
231
+
232
+ template<>
233
+ inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
234
+ float2 fa = bf1622float2(a);
235
+ float2 fb = bf1622float2(b);
236
+ return mul<float2, float2, float2>(fa, fb);
237
+ }
238
+
239
+ template<>
240
+ inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
241
+ return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
242
+ }
243
+
244
+ template<>
245
+ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
246
+ Float4_ fc;
247
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
248
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
249
+ return fc;
250
+ }
251
+
252
+ template<>
253
+ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
254
+ __nv_bfloat162 s = bf162bf162(a);
255
+ Float4_ fc;
256
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
257
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
258
+ return fc;
259
+ }
260
+
261
+ template<>
262
+ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
263
+ Float8_ fc;
264
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
265
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
266
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
267
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
268
+ return fc;
269
+ }
270
+
271
+ template<>
272
+ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
273
+ __nv_bfloat162 s = bf162bf162(a);
274
+ Float8_ fc;
275
+ fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
276
+ fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
277
+ fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
278
+ fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
279
+ return fc;
280
+ }
281
+
282
+ // Vector fused multiply-add.
283
+ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
284
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
285
+ assert(false);
286
+ #else
287
+ return __hfma2(a, b, c);
288
+ #endif
289
+ }
290
+
291
+ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
292
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
293
+ assert(false);
294
+ #else
295
+ return __hfma2(bf162bf162(a), b, c);
296
+ #endif
297
+ }
298
+
299
+ inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
300
+ bf16_4_t d;
301
+ d.x = fma(a.x, b.x, c.x);
302
+ d.y = fma(a.y, b.y, c.y);
303
+ return d;
304
+ }
305
+
306
+ inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
307
+ __nv_bfloat162 s = bf162bf162(a);
308
+ bf16_4_t d;
309
+ d.x = fma(s, b.x, c.x);
310
+ d.y = fma(s, b.y, c.y);
311
+ return d;
312
+ }
313
+
314
+ inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
315
+ bf16_8_t d;
316
+ d.x = fma(a.x, b.x, c.x);
317
+ d.y = fma(a.y, b.y, c.y);
318
+ d.z = fma(a.z, b.z, c.z);
319
+ d.w = fma(a.w, b.w, c.w);
320
+ return d;
321
+ }
322
+
323
+ inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
324
+ __nv_bfloat162 s = bf162bf162(a);
325
+ bf16_8_t d;
326
+ d.x = fma(s, b.x, c.x);
327
+ d.y = fma(s, b.y, c.y);
328
+ d.z = fma(s, b.z, c.z);
329
+ d.w = fma(s, b.w, c.w);
330
+ return d;
331
+ }
332
+
333
+ inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
334
+ return __bfloat162float(a) * __bfloat162float(b) + fc;
335
+ }
336
+
337
+ inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
338
+ float2 fa = bf1622float2(a);
339
+ float2 fb = bf1622float2(b);
340
+ return fma(fa, fb, fc);
341
+ }
342
+
343
+ inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
344
+ return fma(bf162bf162(a), b, fc);
345
+ }
346
+
347
+ inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
348
+ Float4_ fd;
349
+ fd.x = fma(a.x, b.x, fc.x);
350
+ fd.y = fma(a.y, b.y, fc.y);
351
+ return fd;
352
+ }
353
+
354
+ inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
355
+ __nv_bfloat162 s = bf162bf162(a);
356
+ Float4_ fd;
357
+ fd.x = fma(s, b.x, fc.x);
358
+ fd.y = fma(s, b.y, fc.y);
359
+ return fd;
360
+ }
361
+
362
+ inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
363
+ Float8_ fd;
364
+ fd.x = fma(a.x, b.x, fc.x);
365
+ fd.y = fma(a.y, b.y, fc.y);
366
+ fd.z = fma(a.z, b.z, fc.z);
367
+ fd.w = fma(a.w, b.w, fc.w);
368
+ return fd;
369
+ }
370
+
371
+ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
372
+ __nv_bfloat162 s = bf162bf162(a);
373
+ Float8_ fd;
374
+ fd.x = fma(s, b.x, fc.x);
375
+ fd.y = fma(s, b.y, fc.y);
376
+ fd.z = fma(s, b.z, fc.z);
377
+ fd.w = fma(s, b.w, fc.w);
378
+ return fd;
379
+ }
380
+
381
+ // Vector sum.
382
+ template<>
383
+ inline __device__ float sum(__nv_bfloat16 v) {
384
+ return __bfloat162float(v);
385
+ }
386
+
387
+ template<>
388
+ inline __device__ float sum(__nv_bfloat162 v) {
389
+ float2 vf = bf1622float2(v);
390
+ return vf.x + vf.y;
391
+ }
392
+
393
+ template<>
394
+ inline __device__ float sum(bf16_4_t v) {
395
+ return sum(v.x) + sum(v.y);
396
+ }
397
+
398
+ template<>
399
+ inline __device__ float sum(bf16_8_t v) {
400
+ return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
401
+ }
402
+
403
+ // From float32 to bfloat16.
404
+ inline __device__ void from_float(__nv_bfloat16& dst, float src) {
405
+ dst = __float2bfloat16(src);
406
+ }
407
+
408
+ inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
409
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
410
+ assert(false);
411
+ #else
412
+ dst = __float22bfloat162_rn(src);
413
+ #endif
414
+ }
415
+
416
+ inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
417
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
418
+ assert(false);
419
+ #else
420
+ dst.x = __float22bfloat162_rn(src.x);
421
+ dst.y = __float22bfloat162_rn(src.y);
422
+ #endif
423
+ }
424
+
425
+ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
426
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
427
+ assert(false);
428
+ #else
429
+ dst.x = __float22bfloat162_rn(src.x);
430
+ dst.y = __float22bfloat162_rn(src.y);
431
+ dst.z = __float22bfloat162_rn(src.z);
432
+ dst.w = __float22bfloat162_rn(src.w);
433
+ #endif
434
+ }
435
+
436
+ // From bfloat16 to float32.
437
+ inline __device__ float to_float(__nv_bfloat16 u) {
438
+ return __bfloat162float(u);
439
+ }
440
+
441
+ // Zero-out a variable.
442
+ inline __device__ void zero(__nv_bfloat16& dst) {
443
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
444
+ assert(false);
445
+ #else
446
+ // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
447
+ dst = __ushort_as_bfloat16((unsigned short)0x0000U);
448
+ #endif
449
+ }
450
+
451
+ } // namespace vllm
csrc/attention/dtype_float16.cuh ADDED
@@ -0,0 +1,502 @@
1
+ /*
2
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
4
+ * Copyright (c) 2023, The vLLM team.
5
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6
+ *
7
+ * Licensed under the Apache License, Version 2.0 (the "License");
8
+ * you may not use this file except in compliance with the License.
9
+ * You may obtain a copy of the License at
10
+ *
11
+ * http://www.apache.org/licenses/LICENSE-2.0
12
+ *
13
+ * Unless required by applicable law or agreed to in writing, software
14
+ * distributed under the License is distributed on an "AS IS" BASIS,
15
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ * See the License for the specific language governing permissions and
17
+ * limitations under the License.
18
+ */
19
+ #pragma once
20
+
21
+ #include "attention_generic.cuh"
22
+ #include "dtype_float32.cuh"
23
+
24
+ #ifdef USE_ROCM
25
+ #include <hip/hip_fp16.h>
26
+ #endif
27
+
28
+ #include <stdint.h>
29
+
30
+ namespace vllm {
31
+
32
+ // FP16 vector types for Q, K, V.
33
+ template<>
34
+ struct Vec<uint16_t, 1> {
35
+ using Type = uint16_t;
36
+ };
37
+ template<>
38
+ struct Vec<uint16_t, 2> {
39
+ using Type = uint32_t;
40
+ };
41
+ template<>
42
+ struct Vec<uint16_t, 4> {
43
+ using Type = uint2;
44
+ };
45
+ template<>
46
+ struct Vec<uint16_t, 8> {
47
+ using Type = uint4;
48
+ };
49
+
50
+ // FP32 accumulator vector types corresponding to Vec.
51
+ template<>
52
+ struct FloatVec<uint16_t> {
53
+ using Type = float;
54
+ };
55
+ template<>
56
+ struct FloatVec<uint32_t> {
57
+ using Type = float2;
58
+ };
59
+ template<>
60
+ struct FloatVec<uint2> {
61
+ using Type = Float4_;
62
+ };
63
+ template<>
64
+ struct FloatVec<uint4> {
65
+ using Type = Float8_;
66
+ };
67
+
68
+ // Utility functions for type conversions.
69
+ inline __device__ uint32_t h0_h0(uint16_t a) {
70
+ #ifndef USE_ROCM
71
+ uint32_t b;
72
+ asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
73
+ return b;
74
+ #else
75
+ union {
76
+ uint32_t u32;
77
+ uint16_t u16[2];
78
+ } tmp;
79
+ tmp.u16[0] = a;
80
+ tmp.u16[1] = a;
81
+ return tmp.u32;
82
+ #endif
83
+ }
84
+
85
+ inline __device__ float half_to_float(uint16_t h) {
86
+ float f;
87
+ #ifndef USE_ROCM
88
+ asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
89
+ #else
90
+ asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
91
+ #endif
92
+ return f;
93
+ }
94
+
95
+ inline __device__ float2 half2_to_float2(uint32_t v) {
96
+ #ifndef USE_ROCM
97
+ uint16_t lo, hi;
98
+ asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
99
+ return make_float2(half_to_float(lo), half_to_float(hi));
100
+ #else
101
+ union {
102
+ uint32_t u32;
103
+ uint16_t u16[2];
104
+ } tmp;
105
+ tmp.u32 = v;
106
+ float2 ret;
107
+ ret.x = half_to_float(tmp.u16[0]);
108
+ ret.y = half_to_float(tmp.u16[1]);
109
+ return ret;
110
+ #endif
111
+ }
112
+
113
+ inline __device__ uint16_t float_to_half(float f) {
114
+ union {
115
+ uint32_t u32;
116
+ uint16_t u16[2];
117
+ } tmp;
118
+ #ifndef USE_ROCM
119
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
120
+ #else
121
+ asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
122
+ #endif
123
+ return tmp.u16[0];
124
+ }
125
+
126
+ inline __device__ uint32_t float2_to_half2(float2 f) {
127
+ union {
128
+ uint32_t u32;
129
+ uint16_t u16[2];
130
+ } tmp;
131
+ #ifndef USE_ROCM
132
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
133
+ asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
134
+ #else
135
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
136
+ asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
137
+ #endif
138
+ #else
139
+ tmp.u16[0] = float_to_half(f.x);
140
+ tmp.u16[1] = float_to_half(f.y);
141
+ #endif
142
+ return tmp.u32;
143
+ }
144
+
145
+ // Vector addition.
146
+ inline __device__ uint16_t add(uint16_t a, uint16_t b) {
147
+ uint16_t c;
148
+ #ifndef USE_ROCM
149
+ asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
150
+ #else
151
+ asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
152
+ #endif
153
+ return c;
154
+ }
155
+
156
+ inline __device__ uint32_t add(uint32_t a, uint32_t b) {
157
+ uint32_t c;
158
+ #ifndef USE_ROCM
159
+ asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
160
+ #else
161
+ asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
162
+ #endif
163
+ return c;
164
+ }
165
+
166
+ inline __device__ uint2 add(uint2 a, uint2 b) {
167
+ uint2 c;
168
+ c.x = add(a.x, b.x);
169
+ c.y = add(a.y, b.y);
170
+ return c;
171
+ }
172
+
173
+ inline __device__ uint4 add(uint4 a, uint4 b) {
174
+ uint4 c;
175
+ c.x = add(a.x, b.x);
176
+ c.y = add(a.y, b.y);
177
+ c.z = add(a.z, b.z);
178
+ c.w = add(a.w, b.w);
179
+ return c;
180
+ }
181
+
182
+ inline __device__ float2 add(uint32_t a, float2 fb) {
183
+ float2 fa = half2_to_float2(a);
184
+ return add(fa, fb);
185
+ }
186
+
187
+ inline __device__ Float4_ add(uint2 a, Float4_ fb) {
188
+ Float4_ fc;
189
+ fc.x = add(a.x, fb.x);
190
+ fc.y = add(a.y, fb.y);
191
+ return fc;
192
+ }
193
+
194
+ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
195
+ Float8_ fc;
196
+ fc.x = add(a.x, fb.x);
197
+ fc.y = add(a.y, fb.y);
198
+ fc.z = add(a.z, fb.z);
199
+ fc.w = add(a.w, fb.w);
200
+ return fc;
201
+ }
202
+
203
+ // Vector multiplication.
204
+ template<>
205
+ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
206
+ uint16_t c;
207
+ #ifndef USE_ROCM
208
+ asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
209
+ #else
210
+ asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
211
+ #endif
212
+ return c;
213
+ }
214
+
215
+ template<>
216
+ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
217
+ uint32_t c;
218
+ #ifndef USE_ROCM
219
+ asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
220
+ #else
221
+ asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
222
+ #endif
223
+ return c;
224
+ }
225
+
226
+ template<>
227
+ inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
228
+ return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
229
+ }
230
+
231
+ template<>
232
+ inline __device__ uint2 mul(uint2 a, uint2 b) {
233
+ uint2 c;
234
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
235
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
236
+ return c;
237
+ }
238
+
239
+ template<>
240
+ inline __device__ uint2 mul(uint16_t a, uint2 b) {
241
+ uint32_t s = h0_h0(a);
242
+ uint2 c;
243
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
244
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
245
+ return c;
246
+ }
247
+
248
+ template<>
249
+ inline __device__ uint4 mul(uint4 a, uint4 b) {
250
+ uint4 c;
251
+ c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
252
+ c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
253
+ c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
254
+ c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
255
+ return c;
256
+ }
257
+
258
+ template<>
259
+ inline __device__ uint4 mul(uint16_t a, uint4 b) {
260
+ uint32_t s = h0_h0(a);
261
+ uint4 c;
262
+ c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
263
+ c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
264
+ c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
265
+ c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
266
+ return c;
267
+ }
268
+
269
+ template<>
270
+ inline __device__ float mul(uint16_t a, uint16_t b) {
271
+ float fa = half_to_float(a);
272
+ float fb = half_to_float(b);
273
+ return fa * fb;
274
+ }
275
+
276
+ template<>
277
+ inline __device__ float2 mul(uint32_t a, uint32_t b) {
278
+ float2 fa = half2_to_float2(a);
279
+ float2 fb = half2_to_float2(b);
280
+ return mul<float2, float2, float2>(fa, fb);
281
+ }
282
+
283
+ template<>
284
+ inline __device__ float2 mul(uint16_t a, uint32_t b) {
285
+ return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
286
+ }
287
+
288
+ template<>
289
+ inline __device__ Float4_ mul(uint2 a, uint2 b) {
290
+ Float4_ fc;
291
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
292
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
293
+ return fc;
294
+ }
295
+
296
+ template<>
297
+ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
298
+ uint32_t s = h0_h0(a);
299
+ Float4_ fc;
300
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
301
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
302
+ return fc;
303
+ }
304
+
305
+ template<>
306
+ inline __device__ Float8_ mul(uint4 a, uint4 b) {
307
+ Float8_ fc;
308
+ fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
309
+ fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
310
+ fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
311
+ fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
312
+ return fc;
313
+ }
314
+
315
+ template<>
316
+ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
317
+ uint32_t s = h0_h0(a);
318
+ Float8_ fc;
319
+ fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
320
+ fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
321
+ fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
322
+ fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
323
+ return fc;
324
+ }
325
+
326
+ // Vector fused multiply-add.
327
+ inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
328
+ uint32_t d;
329
+ #ifndef USE_ROCM
330
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
331
+ #else
332
+ asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
333
+ #endif
334
+ return d;
335
+ }
336
+
337
+ inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
338
+ return fma(h0_h0(a), b, c);
339
+ }
340
+
341
+ inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
342
+ uint2 d;
343
+ d.x = fma(a.x, b.x, c.x);
344
+ d.y = fma(a.y, b.y, c.y);
345
+ return d;
346
+ }
347
+
348
+ inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
349
+ uint32_t s = h0_h0(a);
350
+ uint2 d;
351
+ d.x = fma(s, b.x, c.x);
352
+ d.y = fma(s, b.y, c.y);
353
+ return d;
354
+ }
355
+
356
+ inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
357
+ uint4 d;
358
+ d.x = fma(a.x, b.x, c.x);
359
+ d.y = fma(a.y, b.y, c.y);
360
+ d.z = fma(a.z, b.z, c.z);
361
+ d.w = fma(a.w, b.w, c.w);
362
+ return d;
363
+ }
364
+
365
+ inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
366
+ uint32_t s = h0_h0(a);
367
+ uint4 d;
368
+ d.x = fma(s, b.x, c.x);
369
+ d.y = fma(s, b.y, c.y);
370
+ d.z = fma(s, b.z, c.z);
371
+ d.w = fma(s, b.w, c.w);
372
+ return d;
373
+ }
374
+
375
+ inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
376
+ float fa = half_to_float(a);
377
+ float fb = half_to_float(b);
378
+ return fa * fb + fc;
379
+ }
380
+
381
+ inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
382
+ float2 fa = half2_to_float2(a);
383
+ float2 fb = half2_to_float2(b);
384
+ return fma(fa, fb, fc);
385
+ }
386
+
387
+ inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
388
+ return fma(h0_h0(a), b, fc);
389
+ }
390
+
391
+ inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
392
+ Float4_ fd;
393
+ fd.x = fma(a.x, b.x, fc.x);
394
+ fd.y = fma(a.y, b.y, fc.y);
395
+ return fd;
396
+ }
397
+
398
+ inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
399
+ uint32_t s = h0_h0(a);
400
+ Float4_ fd;
401
+ fd.x = fma(s, b.x, fc.x);
402
+ fd.y = fma(s, b.y, fc.y);
403
+ return fd;
404
+ }
405
+
406
+ inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
407
+ Float8_ fd;
408
+ fd.x = fma(a.x, b.x, fc.x);
409
+ fd.y = fma(a.y, b.y, fc.y);
410
+ fd.z = fma(a.z, b.z, fc.z);
411
+ fd.w = fma(a.w, b.w, fc.w);
412
+ return fd;
413
+ }
414
+
415
+ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
416
+ uint32_t s = h0_h0(a);
417
+ Float8_ fd;
418
+ fd.x = fma(s, b.x, fc.x);
419
+ fd.y = fma(s, b.y, fc.y);
420
+ fd.z = fma(s, b.z, fc.z);
421
+ fd.w = fma(s, b.w, fc.w);
422
+ return fd;
423
+ }
424
+
425
+ // Vector sum.
426
+ template<>
427
+ inline __device__ float sum(uint16_t v) {
428
+ return half_to_float(v);
429
+ }
430
+
431
+ template<>
432
+ inline __device__ float sum(uint32_t v) {
433
+ float2 tmp = half2_to_float2(v);
434
+ return tmp.x + tmp.y;
435
+ }
436
+
437
+ template<>
438
+ inline __device__ float sum(uint2 v) {
439
+ uint32_t c = add(v.x, v.y);
440
+ return sum(c);
441
+ }
442
+
443
+ template<>
444
+ inline __device__ float sum(uint4 v) {
445
+ uint32_t c = add(v.x, v.y);
446
+ c = add(c, v.z);
447
+ c = add(c, v.w);
448
+ return sum(c);
449
+ }
450
+
451
+ // From float32 to float16.
452
+ inline __device__ void from_float(uint16_t& dst, float src) {
453
+ dst = float_to_half(src);
454
+ }
455
+
456
+ inline __device__ void from_float(uint32_t& dst, float2 src) {
457
+ dst = float2_to_half2(src);
458
+ }
459
+
460
+ inline __device__ void from_float(uint2& dst, Float4_ src) {
461
+ dst.x = float2_to_half2(src.x);
462
+ dst.y = float2_to_half2(src.y);
463
+ }
464
+
465
+ inline __device__ void from_float(uint4& dst, Float8_ src) {
466
+ dst.x = float2_to_half2(src.x);
467
+ dst.y = float2_to_half2(src.y);
468
+ dst.z = float2_to_half2(src.z);
469
+ dst.w = float2_to_half2(src.w);
470
+ }
471
+
472
+ // From float16 to float32.
473
+ inline __device__ float to_float(uint16_t u) {
474
+ return half_to_float(u);
475
+ }
476
+
477
+ inline __device__ float2 to_float(uint32_t u) {
478
+ return half2_to_float2(u);
479
+ }
480
+
481
+ inline __device__ Float4_ to_float(uint2 u) {
482
+ Float4_ tmp;
483
+ tmp.x = half2_to_float2(u.x);
484
+ tmp.y = half2_to_float2(u.y);
485
+ return tmp;
486
+ }
487
+
488
+ inline __device__ Float8_ to_float(uint4 u) {
489
+ Float8_ tmp;
490
+ tmp.x = half2_to_float2(u.x);
491
+ tmp.y = half2_to_float2(u.y);
492
+ tmp.z = half2_to_float2(u.z);
493
+ tmp.w = half2_to_float2(u.w);
494
+ return tmp;
495
+ }
496
+
497
+ // Zero-out a variable.
498
+ inline __device__ void zero(uint16_t& dst) {
499
+ dst = uint16_t(0);
500
+ }
501
+
502
+ } // namespace vllm
csrc/attention/dtype_float32.cuh ADDED
@@ -0,0 +1,273 @@
1
+ /*
2
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
4
+ * Copyright (c) 2023, The vLLM team.
5
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
6
+ *
7
+ * Licensed under the Apache License, Version 2.0 (the "License");
8
+ * you may not use this file except in compliance with the License.
9
+ * You may obtain a copy of the License at
10
+ *
11
+ * http://www.apache.org/licenses/LICENSE-2.0
12
+ *
13
+ * Unless required by applicable law or agreed to in writing, software
14
+ * distributed under the License is distributed on an "AS IS" BASIS,
15
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ * See the License for the specific language governing permissions and
17
+ * limitations under the License.
18
+ */
19
+ #pragma once
20
+
21
+ #include "attention_generic.cuh"
22
+
23
+ #include <stdint.h>
24
+
25
+ namespace vllm {
26
+
27
+ // Define custom FP32 vector data types.
28
+ struct Float4_ {
29
+ float2 x;
30
+ float2 y;
31
+ };
32
+
33
+ struct Float8_ {
34
+ float2 x;
35
+ float2 y;
36
+ float2 z;
37
+ float2 w;
38
+ };
39
+
40
+ // FP32 vector types for Q, K, V.
41
+ template<>
42
+ struct Vec<float, 1> {
43
+ using Type = float;
44
+ };
45
+ template<>
46
+ struct Vec<float, 2> {
47
+ using Type = float2;
48
+ };
49
+ template<>
50
+ struct Vec<float, 4> {
51
+ using Type = float4;
52
+ };
53
+
54
+ // FP32 accumulator vector types corresponding to Vec.
55
+ template<>
56
+ struct FloatVec<float> {
57
+ using Type = float;
58
+ };
59
+ template<>
60
+ struct FloatVec<float2> {
61
+ using Type = float2;
62
+ };
63
+ template<>
64
+ struct FloatVec<float4> {
65
+ using Type = float4;
66
+ };
67
+
68
+ // Vector addition.
69
+ inline __device__ float add(float a, float b) {
70
+ return a + b;
71
+ }
72
+
73
+ inline __device__ float2 add(float2 a, float2 b) {
74
+ float2 c;
75
+ c.x = add(a.x, b.x);
76
+ c.y = add(a.y, b.y);
77
+ return c;
78
+ }
79
+
80
+ inline __device__ float4 add(float4 a, float4 b) {
81
+ float4 c;
82
+ c.x = add(a.x, b.x);
83
+ c.y = add(a.y, b.y);
84
+ c.z = add(a.z, b.z);
85
+ c.w = add(a.w, b.w);
86
+ return c;
87
+ }
88
+
89
+ // Vector multiplication.
90
+ template<>
91
+ inline __device__ float mul<float, float>(float a, float b) {
92
+ return a * b;
93
+ }
94
+
95
+ template<>
96
+ inline __device__ float2 mul(float2 a, float2 b) {
97
+ float2 c;
98
+ c.x = a.x * b.x;
99
+ c.y = a.y * b.y;
100
+ return c;
101
+ }
102
+
103
+ template<>
104
+ inline __device__ float2 mul(float a, float2 b) {
+ float2 c;
+ c.x = a * b.x;
+ c.y = a * b.y;
+ return c;
+ }
+
+ template<>
+ inline __device__ float4 mul(float4 a, float4 b) {
+ float4 c;
+ c.x = a.x * b.x;
+ c.y = a.y * b.y;
+ c.z = a.z * b.z;
+ c.w = a.w * b.w;
+ return c;
+ }
+
+ template<>
+ inline __device__ float4 mul(float a, float4 b) {
+ float4 c;
+ c.x = a * b.x;
+ c.y = a * b.y;
+ c.z = a * b.z;
+ c.w = a * b.w;
+ return c;
+ }
+
+ // Vector fused multiply-add.
+ inline __device__ float fma(float a, float b, float c) {
+ return a * b + c;
+ }
+
+ inline __device__ float2 fma(float2 a, float2 b, float2 c) {
+ float2 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ return d;
+ }
+
+ inline __device__ float2 fma(float a, float2 b, float2 c) {
+ float2 d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ return d;
+ }
+
+ inline __device__ float4 fma(float4 a, float4 b, float4 c) {
+ float4 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ d.z = fma(a.z, b.z, c.z);
+ d.w = fma(a.w, b.w, c.w);
+ return d;
+ }
+
+ inline __device__ float4 fma(float a, float4 b, float4 c) {
+ float4 d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ d.z = fma(a, b.z, c.z);
+ d.w = fma(a, b.w, c.w);
+ return d;
+ }
+
+ inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
+ Float4_ d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ return d;
+ }
+
+ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
+ Float8_ d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ d.z = fma(a, b.z, c.z);
+ d.w = fma(a, b.w, c.w);
+ return d;
+ }
+
+ // Vector sum.
+ template<>
+ inline __device__ float sum(float v) {
+ return v;
+ }
+
+ template<>
+ inline __device__ float sum(float2 v) {
+ return v.x + v.y;
+ }
+
+ template<>
+ inline __device__ float sum(float4 v) {
+ return v.x + v.y + v.z + v.w;
+ }
+
+ template<>
+ inline __device__ float sum(Float4_ v) {
+ return v.x.x + v.x.y + v.y.x + v.y.y;
+ }
+
+ template<>
+ inline __device__ float sum(Float8_ v) {
+ return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
+ }
+
+ // Vector dot product.
+ inline __device__ float dot(float a, float b) {
+ return a * b;
+ }
+
+ inline __device__ float dot(float2 a, float2 b) {
+ float2 c = mul<float2, float2, float2>(a, b);
+ return c.x + c.y;
+ }
+
+ inline __device__ float dot(Float4_ a, Float4_ b) {
+ float2 acc = mul<float2, float2, float2>(a.x, b.x);
+ acc = fma(a.y, b.y, acc);
+ return acc.x + acc.y;
+ }
+
+ inline __device__ float dot(Float8_ a, Float8_ b) {
+ float2 acc = mul<float2, float2, float2>(a.x, b.x);
+ acc = fma(a.y, b.y, acc);
+ acc = fma(a.z, b.z, acc);
+ acc = fma(a.w, b.w, acc);
+ return acc.x + acc.y;
+ }
+
+ // From float to float.
+ inline __device__ void from_float(float& dst, float src) {
+ dst = src;
+ }
+
+ inline __device__ void from_float(float2& dst, float2 src) {
+ dst = src;
+ }
+
+ inline __device__ void from_float(float4& dst, float4 src) {
+ dst = src;
+ }
+
+ // From float to float.
+ inline __device__ float to_float(float u) {
+ return u;
+ }
+
+ inline __device__ float2 to_float(float2 u) {
+ return u;
+ }
+
+ inline __device__ float4 to_float(float4 u) {
+ return u;
+ }
+
+ inline __device__ Float4_ to_float(Float4_ u) {
+ return u;
+ }
+
+ inline __device__ Float8_ to_float(Float8_ u) {
+ return u;
+ }
+
+ // Zero-out a variable.
+ inline __device__ void zero(float& dst) {
+ dst = 0.f;
+ }
+
+ } // namespace vllm
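
The float helpers above all follow the same shape: an elementwise mul, a lane-wise fma into a float2 accumulator, then a horizontal sum. A minimal host-side sketch of that composition; the float2_/Float4x stand-ins are assumptions so it compiles with a plain C++ compiler, the real types come from CUDA and attention_generic.cuh:

#include <cstdio>

struct float2_ { float x, y; };                   // stand-in for CUDA float2
struct Float4x { float2_ x, y; };                 // stand-in for vllm::Float4_

static float2_ mul2(float2_ a, float2_ b) { return {a.x * b.x, a.y * b.y}; }
static float2_ fma2(float2_ a, float2_ b, float2_ c) {
  return {a.x * b.x + c.x, a.y * b.y + c.y};
}

static float dot4(Float4x a, Float4x b) {
  float2_ acc = mul2(a.x, b.x);                   // first pair of lanes
  acc = fma2(a.y, b.y, acc);                      // accumulate the second pair
  return acc.x + acc.y;                           // horizontal sum
}

int main() {
  Float4x a{{1, 2}, {3, 4}}, b{{5, 6}, {7, 8}};
  std::printf("%f\n", dot4(a, b));                // 1*5 + 2*6 + 3*7 + 4*8 = 70
  return 0;
}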
csrc/attention/dtype_fp8_e5m2.cuh ADDED
@@ -0,0 +1,35 @@
+ #pragma once
+
+ #include "attention_generic.cuh"
+
+ #include <stdint.h>
+ #ifdef ENABLE_FP8_E5M2
+ #include <cuda_fp8.h>
+ #endif
+
+ namespace vllm {
+ #ifdef ENABLE_FP8_E5M2
+ // fp8 vector types for quantization of kv cache
+
+ template<>
+ struct Vec<uint8_t, 1> {
+ using Type = uint8_t;
+ };
+
+ template<>
+ struct Vec<uint8_t, 2> {
+ using Type = uint16_t;
+ };
+
+ template<>
+ struct Vec<uint8_t, 4> {
+ using Type = uint32_t;
+ };
+
+ template<>
+ struct Vec<uint8_t, 8> {
+ using Type = uint2;
+ };
+ #endif // ENABLE_FP8_E5M2
+
+ } // namespace vllm
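
The point of these specializations is that N fp8 values are moved as a single N-byte word. A small compile-time check of that size mapping (illustrative only; uint2 comes from the CUDA vector types and only the header is needed):

#include <cstdint>
#include <cuda_runtime.h>  // for uint2

// N fp8 bytes travel as one N-byte word:
static_assert(sizeof(uint8_t)  == 1, "Vec<uint8_t, 1>::Type");
static_assert(sizeof(uint16_t) == 2, "Vec<uint8_t, 2>::Type");
static_assert(sizeof(uint32_t) == 4, "Vec<uint8_t, 4>::Type");
static_assert(sizeof(uint2)    == 8, "Vec<uint8_t, 8>::Type");

int main() { return 0; }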
csrc/cache.h ADDED
@@ -0,0 +1,36 @@
+ #pragma once
+
+ #include <torch/extension.h>
+
+ #include <map>
+ #include <vector>
+
+ void swap_blocks(
+ torch::Tensor& src,
+ torch::Tensor& dst,
+ const std::map<int64_t, int64_t>& block_mapping);
+
+ void copy_blocks(
+ std::vector<torch::Tensor>& key_caches,
+ std::vector<torch::Tensor>& value_caches,
+ const std::map<int64_t, std::vector<int64_t>>& block_mapping);
+
+ void reshape_and_cache(
+ torch::Tensor& key,
+ torch::Tensor& value,
+ torch::Tensor& key_cache,
+ torch::Tensor& value_cache,
+ torch::Tensor& slot_mapping,
+ const std::string& kv_cache_dtype);
+
+ void gather_cached_kv(
+ torch::Tensor& key,
+ torch::Tensor& value,
+ torch::Tensor& key_cache,
+ torch::Tensor& value_cache,
+ torch::Tensor& slot_mapping);
+
+ // Just for unittest
+ void convert_fp8_e5m2(
+ torch::Tensor& src_cache,
+ torch::Tensor& dst_cache);
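
As a rough illustration of the mapping arguments these declarations expect (a sketch, not part of the extension's API surface): swap_blocks copies whole KV-cache blocks from a source tensor to a destination tensor, keyed as source block number to destination block number, while copy_blocks duplicates one source block into a list of destination blocks on the same GPU, e.g. when a sequence is forked.

#include <cstdint>
#include <map>
#include <vector>

int main() {
  // swap_blocks: block 3 of src goes to block 0 of dst, block 7 goes to block 1.
  std::map<int64_t, int64_t> swap_mapping{{3, 0}, {7, 1}};

  // copy_blocks: block 2 is copied into blocks 5 and 9.
  std::map<int64_t, std::vector<int64_t>> copy_mapping{{2, {5, 9}}};

  (void)swap_mapping;
  (void)copy_mapping;
  return 0;
}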
csrc/cache_kernels.cu ADDED
@@ -0,0 +1,474 @@
1
+ #include <torch/extension.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ #include "cuda_compat.h"
6
+ #include "dispatch_utils.h"
7
+ #include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
8
+
9
+ #include <algorithm>
10
+ #include <cassert>
11
+ #include <map>
12
+ #include <vector>
13
+
14
+ void swap_blocks(
15
+ torch::Tensor& src,
16
+ torch::Tensor& dst,
17
+ const std::map<int64_t, int64_t>& block_mapping) {
18
+ torch::Device src_device = src.device();
19
+ torch::Device dst_device = dst.device();
20
+ cudaMemcpyKind memcpy_type;
21
+ if (src_device.is_cuda() && dst_device.is_cuda()) {
22
+ TORCH_CHECK(
23
+ src_device.index() == dst_device.index(),
24
+ "src and dst must be on the same GPU");
25
+ memcpy_type = cudaMemcpyDeviceToDevice;
26
+ } else if (src_device.is_cuda() && dst_device.is_cpu()) {
27
+ memcpy_type = cudaMemcpyDeviceToHost;
28
+ } else if (src_device.is_cpu() && dst_device.is_cuda()) {
29
+ memcpy_type = cudaMemcpyHostToDevice;
30
+ } else {
31
+ TORCH_CHECK(false, "Invalid device combination");
32
+ }
33
+
34
+ char *src_ptr = static_cast<char*>(src.data_ptr());
35
+ char *dst_ptr = static_cast<char*>(dst.data_ptr());
36
+
37
+ const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
38
+ const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
39
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
40
+ // NOTE(woosuk): This can be slow if the number of blocks is large.
41
+ for (const auto& pair : block_mapping) {
42
+ int64_t src_block_number = pair.first;
43
+ int64_t dst_block_number = pair.second;
44
+ int64_t src_offset = src_block_number * block_size_in_bytes;
45
+ int64_t dst_offset = dst_block_number * block_size_in_bytes;
46
+ cudaMemcpyAsync(
47
+ dst_ptr + dst_offset,
48
+ src_ptr + src_offset,
49
+ block_size_in_bytes,
50
+ memcpy_type,
51
+ stream);
52
+ }
53
+ }
54
+
55
+ namespace vllm {
56
+
57
+ // Grid: (num_layers, num_pairs)
58
+ template<typename scalar_t>
59
+ __global__ void copy_blocks_kernel(
60
+ int64_t* key_cache_ptrs,
61
+ int64_t* value_cache_ptrs,
62
+ const int64_t* __restrict__ block_mapping,
63
+ const int numel_per_block) {
64
+ const int layer_idx = blockIdx.x;
65
+ const int pair_idx = blockIdx.y;
66
+
67
+ scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
68
+ scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
69
+ int64_t src_block_number = block_mapping[2 * pair_idx];
70
+ int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
71
+
72
+ const int64_t src_block_offset = src_block_number * numel_per_block;
73
+ const int64_t dst_block_offset = dst_block_number * numel_per_block;
74
+ for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
75
+ int64_t src_offset = src_block_offset + i;
76
+ int64_t dst_offset = dst_block_offset + i;
77
+ key_cache[dst_offset] = key_cache[src_offset];
78
+ }
79
+ for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
80
+ int64_t src_offset = src_block_offset + i;
81
+ int64_t dst_offset = dst_block_offset + i;
82
+ value_cache[dst_offset] = value_cache[src_offset];
83
+ }
84
+ }
85
+
86
+ } // namespace vllm
87
+
88
+ void copy_blocks(
89
+ std::vector<torch::Tensor>& key_caches,
90
+ std::vector<torch::Tensor>& value_caches,
91
+ const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
92
+ int num_layers = key_caches.size();
93
+ TORCH_CHECK(num_layers == value_caches.size());
94
+ if (num_layers == 0) {
95
+ return;
96
+ }
97
+ torch::Device cache_device = key_caches[0].device();
98
+ TORCH_CHECK(cache_device.is_cuda());
99
+
100
+ // Create data structures for the kernel.
101
+ // Create an array of pointers to the key and value caches.
102
+ int64_t key_cache_ptrs[num_layers];
103
+ int64_t value_cache_ptrs[num_layers];
104
+ for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
105
+ key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
106
+ value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
107
+ }
108
+ // Create block mapping array.
109
+ std::vector<int64_t> block_mapping_vec;
110
+ for (const auto& pair : block_mapping) {
111
+ int64_t src_block_number = pair.first;
112
+ for (int64_t dst_block_number : pair.second) {
113
+ block_mapping_vec.push_back(src_block_number);
114
+ block_mapping_vec.push_back(dst_block_number);
115
+ }
116
+ }
117
+ int64_t* block_mapping_array = block_mapping_vec.data();
118
+ int num_pairs = block_mapping_vec.size() / 2;
119
+
120
+ // Move the data structures to the GPU.
121
+ // NOTE: This synchronizes the CPU and GPU.
122
+ torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
123
+ key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
124
+ torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
125
+ value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
126
+ torch::Tensor block_mapping_tensor = torch::from_blob(
127
+ block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
128
+
129
+ // Launch the kernel.
130
+ const int numel_per_block = key_caches[0][0].numel();
131
+ dim3 grid(num_layers, num_pairs);
132
+ dim3 block(std::min(1024, numel_per_block));
133
+ const at::cuda::OptionalCUDAGuard device_guard(cache_device);
134
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
135
+ VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
136
+ key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
137
+ vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
138
+ key_cache_ptrs_tensor.data_ptr<int64_t>(),
139
+ value_cache_ptrs_tensor.data_ptr<int64_t>(),
140
+ block_mapping_tensor.data_ptr<int64_t>(),
141
+ numel_per_block);
142
+ }));
143
+ }
144
+
145
+ namespace vllm {
146
+
147
+ template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
148
+ __global__ void reshape_and_cache_kernel(
149
+ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
150
+ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
151
+ cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
152
+ cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
153
+ const int64_t* __restrict__ slot_mapping, // [num_tokens]
154
+ const int key_stride,
155
+ const int value_stride,
156
+ const int num_heads,
157
+ const int head_size,
158
+ const int block_size,
159
+ const int x) {
160
+ const int64_t token_idx = blockIdx.x;
161
+ const int64_t slot_idx = slot_mapping[token_idx];
162
+ if (slot_idx < 0) {
163
+ // Padding token that should be ignored.
164
+ return;
165
+ }
166
+
167
+ const int64_t block_idx = slot_idx / block_size;
168
+ const int64_t block_offset = slot_idx % block_size;
169
+
170
+ const int n = num_heads * head_size;
171
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
172
+ const int64_t src_key_idx = token_idx * key_stride + i;
173
+ const int64_t src_value_idx = token_idx * value_stride + i;
174
+
175
+ const int head_idx = i / head_size;
176
+ const int head_offset = i % head_size;
177
+ const int x_idx = head_offset / x;
178
+ const int x_offset = head_offset % x;
179
+
180
+ const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
181
+ + head_idx * (head_size / x) * block_size * x
182
+ + x_idx * block_size * x
183
+ + block_offset * x
184
+ + x_offset;
185
+ const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
186
+ + head_idx * head_size * block_size
187
+ + head_offset * block_size
188
+ + block_offset;
189
+ scalar_t tgt_key = key[src_key_idx];
190
+ scalar_t tgt_value = value[src_value_idx];
191
+ if constexpr (is_fp8_e5m2_kv_cache) {
192
+ #ifdef ENABLE_FP8_E5M2
193
+ key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
194
+ value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
195
+ #else
196
+ assert(false);
197
+ #endif
198
+ } else {
199
+ key_cache[tgt_key_idx] = tgt_key;
200
+ value_cache[tgt_value_idx] = tgt_value;
201
+ }
202
+ }
203
+ }
204
+
205
+ } // namespace vllm
206
+
207
+ #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
208
+ vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
209
+ reinterpret_cast<KV_T*>(key.data_ptr()), \
210
+ reinterpret_cast<KV_T*>(value.data_ptr()), \
211
+ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
212
+ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
213
+ slot_mapping.data_ptr<int64_t>(), \
214
+ key_stride, \
215
+ value_stride, \
216
+ num_heads, \
217
+ head_size, \
218
+ block_size, \
219
+ x);
220
+
221
+ void reshape_and_cache(
222
+ torch::Tensor& key, // [num_tokens, num_heads, head_size]
223
+ torch::Tensor& value, // [num_tokens, num_heads, head_size]
224
+ torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
225
+ torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
226
+ torch::Tensor& slot_mapping, // [num_tokens]
227
+ const std::string& kv_cache_dtype)
228
+ {
229
+ int num_tokens = key.size(0);
230
+ int num_heads = key.size(1);
231
+ int head_size = key.size(2);
232
+ int block_size = key_cache.size(3);
233
+ int x = key_cache.size(4);
234
+
235
+ int key_stride = key.stride(0);
236
+ int value_stride = value.stride(0);
237
+
238
+ dim3 grid(num_tokens);
239
+ dim3 block(std::min(num_heads * head_size, 512));
240
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
241
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
242
+ if (kv_cache_dtype == "auto") {
243
+ if (key.dtype() == at::ScalarType::Float) {
244
+ CALL_RESHAPE_AND_CACHE(float, float, false);
245
+ } else if (key.dtype() == at::ScalarType::Half) {
246
+ CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
247
+ } else if (key.dtype() == at::ScalarType::BFloat16) {
248
+ CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
249
+ }
250
+ } else if (kv_cache_dtype == "fp8_e5m2") {
251
+ if (key.dtype() == at::ScalarType::Float) {
252
+ CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
253
+ } else if (key.dtype() == at::ScalarType::Half) {
254
+ CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
255
+ } else if (key.dtype() == at::ScalarType::BFloat16) {
256
+ CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
257
+ }
258
+ } else {
259
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
260
+ }
261
+ }
262
+
263
+ namespace vllm {
264
+
265
+ // Grid: (num_blocks, block_size).
266
+ template<typename scalar_t>
267
+ __global__ void gather_cached_kv_kernel(
268
+ scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size]
269
+ scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size]
270
+ const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
271
+ const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
272
+ const int* __restrict__ slot_mapping, // [num_tokens]
273
+ const int key_stride,
274
+ const int value_stride,
275
+ const int num_heads,
276
+ const int head_size,
277
+ const int block_size,
278
+ const int x) {
279
+ const int token_idx = blockIdx.x;
280
+ const int slot_idx = slot_mapping[token_idx];
281
+ const int block_idx = slot_idx / block_size;
282
+ const int block_offset = slot_idx % block_size;
283
+
284
+ const int num_tokens = num_heads * head_size;
285
+ for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
286
+ const int tgt_key_idx = token_idx * key_stride + i;
287
+ const int tgt_value_idx = token_idx * value_stride + i;
288
+
289
+ const int head_idx = i / head_size;
290
+ const int head_offset = i % head_size;
291
+ const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
292
+ const int x_offset = head_offset % x;
293
+
294
+ const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
295
+ + head_idx * (head_size / x) * block_size * x
296
+ + x_idx * block_size * x
297
+ + block_offset * x
298
+ + x_offset;
299
+ const int src_value_idx = block_idx * num_heads * head_size * block_size
300
+ + head_idx * head_size * block_size
301
+ + head_offset * block_size
302
+ + block_offset;
303
+
304
+ key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
305
+ value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
306
+ }
307
+ }
308
+
309
+ template <typename scalar_t>
310
+ __global__ void gather_cached_kv_kernel_optimized(
311
+ scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size]
312
+ scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size]
313
+ const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
314
+ const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
315
+ const int *__restrict__ slot_mapping, // [num_tokens]
316
+ const int key_stride,
317
+ const int value_stride,
318
+ const int num_heads,
319
+ const int head_size,
320
+ const int block_size,
321
+ const int x)
322
+ {
323
+ const int token_idx = blockIdx.x;
324
+ const int slot_idx = slot_mapping[token_idx];
325
+ const int block_idx = slot_idx / block_size;
326
+ const int block_offset = slot_idx % block_size;
327
+
328
+ const int dim = num_heads * head_size;
329
+ assert(dim % 4 == 0); // this is true for known use cases
330
+ const int unroll_factor = 4;
331
+ const int unrolled_dim = dim / unroll_factor;
332
+
333
+ for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
334
+ {
335
+ int tgt_key_indices[unroll_factor];
336
+ int tgt_value_indices[unroll_factor];
337
+ int src_key_indices[unroll_factor];
338
+ int src_value_indices[unroll_factor];
339
+ scalar_t keys_to_store[unroll_factor];
340
+ scalar_t values_to_store[unroll_factor];
341
+
342
+ #pragma unroll
343
+ for (int j = 0; j < unroll_factor; ++j)
344
+ {
345
+ int index = i + j * unrolled_dim;
346
+
347
+ const int tgt_key_idx = token_idx * key_stride + index;
348
+ const int tgt_value_idx = token_idx * value_stride + index;
349
+
350
+ const int head_idx = index / head_size;
351
+ const int head_offset = index % head_size;
352
+ const int x_idx = head_offset / x;
353
+ const int x_offset = head_offset % x;
354
+
355
+ const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
356
+ + head_idx * (head_size / x) * block_size * x
357
+ + x_idx * block_size * x
358
+ + block_offset * x
359
+ + x_offset;
360
+ const int src_value_idx = block_idx * num_heads * head_size * block_size
361
+ + head_idx * head_size * block_size
362
+ + head_offset * block_size
363
+ + block_offset;
364
+
365
+ tgt_key_indices[j] = tgt_key_idx;
366
+ tgt_value_indices[j] = tgt_value_idx;
367
+ src_key_indices[j] = src_key_idx;
368
+ src_value_indices[j] = src_value_idx;
369
+
370
+ keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
371
+ values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
372
+ }
373
+
374
+ #pragma unroll
375
+ for (int j = 0; j < unroll_factor; ++j)
376
+ {
377
+ key[tgt_key_indices[j]] = keys_to_store[j];
378
+ value[tgt_value_indices[j]] = values_to_store[j];
379
+ }
380
+ }
381
+ }
382
+
383
+ } // namespace vllm
384
+
385
+ void gather_cached_kv(
386
+ torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
387
+ torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
388
+ torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x]
389
+ torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size]
390
+ torch::Tensor& slot_mapping) // [in] [num_tokens]
391
+ {
392
+ int num_tokens = key.size(0);
393
+ int num_heads = key.size(1);
394
+ int head_size = key.size(2);
395
+ int block_size = key_cache.size(3);
396
+ int x = key_cache.size(4);
397
+
398
+ int key_stride = key.stride(0);
399
+ int value_stride = value.stride(0);
400
+
401
+ dim3 grid(num_tokens);
402
+ dim3 block(std::min(num_heads * head_size, 512));
403
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
404
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
405
+ VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
406
+ key.scalar_type(),
407
+ "gather_cached_kv_kernel_optimized",
408
+ [&] {
409
+ vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
410
+ key.data_ptr<scalar_t>(),
411
+ value.data_ptr<scalar_t>(),
412
+ key_cache.data_ptr<scalar_t>(),
413
+ value_cache.data_ptr<scalar_t>(),
414
+ slot_mapping.data_ptr<int>(),
415
+ key_stride,
416
+ value_stride,
417
+ num_heads,
418
+ head_size,
419
+ block_size,
420
+ x);
421
+ });
422
+ }
423
+
424
+ namespace vllm {
425
+
426
+ template<typename Tout, typename Tin>
427
+ __global__ void convert_fp8_e5m2_kernel(
428
+ const Tin* __restrict__ src_cache,
429
+ Tout* __restrict__ dst_cache,
430
+ const int64_t block_stride) {
431
+ const int64_t block_idx = blockIdx.x;
432
+ for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
433
+ int64_t idx = block_idx * block_stride + i;
434
+ #ifdef ENABLE_FP8_E5M2
435
+ dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
436
+ #else
437
+ assert(false);
438
+ #endif
439
+ }
440
+ }
441
+
442
+ } // namespace vllm
443
+
444
+ #define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
445
+ vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
446
+ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
447
+ reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
448
+ block_stride);
449
+
450
+ void convert_fp8_e5m2(
451
+ torch::Tensor& src_cache,
452
+ torch::Tensor& dst_cache)
453
+ {
454
+ int64_t num_blocks = src_cache.size(0);
455
+ int64_t block_stride = src_cache.stride(0);
456
+
457
+ dim3 grid(num_blocks);
458
+ dim3 block(std::min(block_stride, int64_t(512)));
459
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
460
+
461
+ if (src_cache.dtype() == at::ScalarType::Float) {
462
+ CALL_CONVERT_FP8_E5M2(uint8_t, float);
463
+ } else if (src_cache.dtype() == at::ScalarType::Half) {
464
+ CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
465
+ } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
466
+ CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
467
+ } else if (dst_cache.dtype() == at::ScalarType::Float) {
468
+ CALL_CONVERT_FP8_E5M2(float, uint8_t);
469
+ } else if (dst_cache.dtype() == at::ScalarType::Half) {
470
+ CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
471
+ } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
472
+ CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
473
+ }
474
+ }
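
The index arithmetic in reshape_and_cache_kernel encodes the key-cache layout [num_blocks, num_heads, head_size/x, block_size, x]. A small host-side sketch of that computation with toy sizes; the helper name and the example dimensions are illustrative, the formula mirrors tgt_key_idx above:

#include <cassert>
#include <cstdint>

int64_t key_cache_index(int64_t block_idx, int head_idx, int head_offset,
                        int64_t block_offset, int num_heads, int head_size,
                        int block_size, int x) {
  const int x_idx = head_offset / x;     // which x-wide chunk of the head dim
  const int x_offset = head_offset % x;  // position inside that chunk
  return block_idx * num_heads * (head_size / x) * block_size * x
         + head_idx * (head_size / x) * block_size * x
         + x_idx * block_size * x
         + block_offset * x
         + x_offset;
}

int main() {
  // Toy sizes: 4 heads, head_size 8, block_size 16, x = 4 (so head_size/x = 2).
  // Element (block 0, head 0, head_offset 0, slot 0) sits at the very start...
  assert(key_cache_index(0, 0, 0, 0, 4, 8, 16, 4) == 0);
  // ...consecutive head offsets within one x-chunk are contiguous...
  assert(key_cache_index(0, 0, 1, 0, 4, 8, 16, 4) == 1);
  // ...and moving to the next token slot in the block jumps by x elements.
  assert(key_cache_index(0, 0, 0, 1, 4, 8, 16, 4) == 4);
  return 0;
}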
csrc/cuda_compat.h ADDED
@@ -0,0 +1,28 @@
+ #pragma once
+
+ #ifndef USE_ROCM
+ #define VLLM_LDG(arg) __ldg(arg)
+ #else
+ #define VLLM_LDG(arg) *(arg)
+ #endif
+
+ #ifndef USE_ROCM
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+ #else
+ #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+ #endif
+
+ #ifndef USE_ROCM
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
+ #else
+ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
+ #endif
+
+ #ifndef USE_ROCM
+ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
+ #else
+ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
+ #endif
+
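
These shuffle macros exist so the same warp-level reductions compile on both CUDA and ROCm, since the macro hides the *_sync versus non-sync shuffle difference. A sketch of the typical consumer, a butterfly warp sum; WARP_SIZE = 32 is an assumption here and the helper name is made up for illustration:

#include "cuda_compat.h"

#define WARP_SIZE 32

__device__ float warp_sum(float val) {
  // After log2(WARP_SIZE) xor-shuffles every lane holds the full sum.
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return val;
}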
csrc/cuda_utils.h ADDED
@@ -0,0 +1,10 @@
+ #pragma once
+
+ #include <torch/extension.h>
+
+ int get_device_attribute(
+ int attribute,
+ int device_id);
+
+ int get_max_shared_memory_per_block_device_attribute(
+ int device_id);
csrc/cuda_utils_kernels.cu ADDED
@@ -0,0 +1,35 @@
+ #ifdef USE_ROCM
+ #include <hip/hip_runtime.h>
+ #include <hip/hip_runtime_api.h>
+ #endif
+ int get_device_attribute(
+ int attribute,
+ int device_id)
+ {
+ int device, value;
+ if (device_id < 0) {
+ cudaGetDevice(&device);
+ }
+ else {
+ device = device_id;
+ }
+ cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
+ return value;
+ }
+
+
+ int get_max_shared_memory_per_block_device_attribute(
+ int device_id)
+ {
+ int attribute;
+ // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+ // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
+
+ #ifdef USE_ROCM
+ attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
+ #else
+ attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
+ #endif
+
+ return get_device_attribute(attribute, device_id);
+ }
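
On the CUDA side this helper reduces to a single attribute query. A standalone host sketch of the equivalent call, not wired into the extension, which prints the opt-in per-block shared memory limit for device 0:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int value = 0;
  cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlockOptin, /*device=*/0);
  std::printf("max dynamic shared memory per block (opt-in): %d bytes\n", value);
  return 0;
}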
csrc/custom_all_reduce.cu ADDED
@@ -0,0 +1,148 @@
+ #include <ATen/cuda/Exceptions.h>
+ #include <c10/cuda/CUDAGuard.h>
+ #include <c10/cuda/CUDAStream.h>
+ #include <torch/extension.h>
+
+ #include "custom_all_reduce.cuh"
+
+ // fake pointer type
+ using fptr_t = uint64_t;
+ static_assert(sizeof(void *) == sizeof(fptr_t));
+
+ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
+ const std::vector<std::string> &handles,
+ const std::vector<int64_t> &offsets, int rank,
+ bool full_nvlink) {
+ int world_size = offsets.size();
+ if (world_size > 8)
+ throw std::invalid_argument("world size > 8 is not supported");
+ if (world_size % 2 != 0)
+ throw std::invalid_argument("Odd num gpus is not supported for now");
+ if (world_size != handles.size())
+ throw std::invalid_argument(
+ "handles length should equal to offsets length");
+ if (rank < 0 || rank >= world_size)
+ throw std::invalid_argument("invalid rank passed in");
+
+ cudaIpcMemHandle_t ipc_handles[8];
+ for (int i = 0; i < world_size; i++) {
+ std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
+ }
+ return (fptr_t) new vllm::CustomAllreduce(
+ reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
+ rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
+ }
+
+ /**
+ * Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
+ * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
+ * because it allows transpose of contiguous slice (i.e. slicing the first
+ * dimension). Currently, we require this because stride information is not
+ * passed into the kernels and we treat input tensors as flat.
+ *
+ * Examples
+ * A = torch.zeros(3, 3, 3)
+ * 1. A: OK
+ * 2. A[1:]: OK
+ * 3. A.permute(2, 0, 1): OK
+ * 4. A[1:].permute(2, 0, 1): OK
+ * 5. A[None].expand(2, -1, -1, -1): Not OK
+ * 6. A[:, 1:, 1:]: Not OK
+ */
+ bool _is_weak_contiguous(torch::Tensor &t) {
+ return t.is_contiguous() ||
+ (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
+ t.numel() * t.element_size());
+ }
+
+ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
+ bool full_nvlink) {
+ auto inp_size = inp.numel() * inp.element_size();
+ // custom allreduce requires input byte size to be multiples of 16
+ if (inp_size % 16 != 0) return false;
+ if (!_is_weak_contiguous(inp)) return false;
+ if (world_size == 2 || full_nvlink) return inp_size <= max_size;
+ // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
+ // <= 512k
+ return world_size <= 4 && inp_size <= 512 * 1024;
+ }
+
+ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
+ cudaStream_t stream) {
+ auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ TORCH_CHECK(_is_weak_contiguous(out));
+ switch (out.scalar_type()) {
+ case at::ScalarType::Float: {
+ fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
+ reinterpret_cast<float *>(out.data_ptr()),
+ out.numel());
+ break;
+ }
+ case at::ScalarType::Half: {
+ fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
+ reinterpret_cast<half *>(out.data_ptr()),
+ out.numel());
+ break;
+ }
+ #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
+ case at::ScalarType::BFloat16: {
+ fa->allreduce<nv_bfloat16>(
+ stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
+ reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
+ break;
+ }
+ #endif
+ default:
+ throw std::runtime_error(
+ "custom allreduce only supports float32, float16 and bfloat16");
+ }
+ }
+
+ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
+ auto stream = c10::cuda::getCurrentCUDAStream().stream();
+ TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
+ TORCH_CHECK_EQ(inp.numel(), out.numel());
+ _all_reduce(_fa, inp, out, stream);
+ }
+
+ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
+ torch::Tensor &out) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
+ auto stream = c10::cuda::getCurrentCUDAStream().stream();
+
+ auto input_size = inp.numel() * inp.element_size();
+ TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
+ TORCH_CHECK_EQ(inp.numel(), out.numel());
+ TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
+ "registered buffer is too small to contain the input");
+ AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
+ input_size, cudaMemcpyDeviceToDevice, stream));
+ _all_reduce(_fa, reg_buffer, out, stream);
+ }
+
+ void dispose(fptr_t _fa) {
+ auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ delete fa;
+ }
+
+ int meta_size() { return sizeof(vllm::Metadata); }
+
+ void register_buffer(fptr_t _fa, torch::Tensor &t,
+ const std::vector<std::string> &handles,
+ const std::vector<int64_t> &offsets) {
+ auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ fa->register_buffer(handles, offsets, t.data_ptr());
+ }
+
+ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
+ fptr_t _fa) {
+ auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ return fa->get_graph_buffer_ipc_meta();
+ }
+
+ void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
+ const std::vector<std::vector<int64_t>> &offsets) {
+ auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
+ fa->register_graph_buffers(handles, offsets);
+ }
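
The gate in should_custom_ar is easiest to read as plain byte thresholds. A host-side restatement with a few sample checks; the contiguity test is omitted here and custom_ar_ok is a made-up name for illustration:

#include <cassert>
#include <cstdint>

bool custom_ar_ok(int64_t bytes, int max_size, int world_size, bool full_nvlink) {
  if (bytes % 16 != 0) return false;              // input byte size must be a multiple of 16
  if (world_size == 2 || full_nvlink) return bytes <= max_size;
  return world_size <= 4 && bytes <= 512 * 1024;  // PCIe path: 2-stage, small inputs only
}

int main() {
  // 2 GPUs or full NVLink: anything up to max_size (here 8 MiB) qualifies.
  assert(custom_ar_ok(1 << 20, 8 * 1024 * 1024, 8, /*full_nvlink=*/true));
  // 4 PCIe GPUs without full NVLink: only up to 512 KiB.
  assert(!custom_ar_ok(1 << 20, 8 * 1024 * 1024, 4, /*full_nvlink=*/false));
  // Sizes that are not multiples of 16 bytes always fall back to NCCL.
  assert(!custom_ar_ok(100, 8 * 1024 * 1024, 2, /*full_nvlink=*/true));
  return 0;
}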
csrc/custom_all_reduce.cuh ADDED
@@ -0,0 +1,562 @@
1
+ #pragma once
2
+
3
+ #include <cuda.h>
4
+ #include <cuda_bf16.h>
5
+ #include <cuda_fp16.h>
6
+ #include <cuda_runtime.h>
7
+
8
+ #include <iostream>
9
+ #include <limits>
10
+ #include <map>
11
+ #include <unordered_map>
12
+ #include <vector>
13
+
14
+ #define CUDACHECK(cmd) \
15
+ do { \
16
+ cudaError_t e = cmd; \
17
+ if (e != cudaSuccess) { \
18
+ printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
19
+ cudaGetErrorString(e)); \
20
+ exit(EXIT_FAILURE); \
21
+ } \
22
+ } while (0)
23
+
24
+ namespace vllm {
25
+
26
+ struct Signal {
27
+ alignas(64) union {
28
+ uint64_t flag;
29
+ unsigned char data[8];
30
+ } start;
31
+ alignas(64) union {
32
+ uint64_t flag;
33
+ unsigned char data[8];
34
+ } end;
35
+ };
36
+
37
+ struct Metadata {
38
+ alignas(128) Signal sg;
39
+ alignas(128) int counter;
40
+ };
41
+ static_assert(offsetof(Metadata, counter) == 128);
42
+ static_assert(sizeof(Metadata) == 256);
43
+
44
+ struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
45
+
46
+ struct RankSignals {
47
+ volatile Signal *signals[8];
48
+ };
49
+
50
+ // like std::array, but aligned
51
+ template <typename T, int sz>
52
+ struct __align__(alignof(T) * sz) array_t {
53
+ T data[sz];
54
+ using type = T;
55
+ static constexpr int size = sz;
56
+ };
57
+
58
+ // use packed type to maximize memory efficiency
59
+ // goal: generate ld.128 and st.128 instructions
60
+ template <typename T>
61
+ struct packed_t {
62
+ // the (P)acked type for load/store
63
+ using P = array_t<T, 16 / sizeof(T)>;
64
+ // the (A)ccumulator type for reduction
65
+ using A = array_t<float, 16 / sizeof(T)>;
66
+ };
67
+
68
+ #define DINLINE __device__ __forceinline__
69
+
70
+ // scalar cast functions
71
+ DINLINE float upcast_s(half val) { return __half2float(val); }
72
+
73
+ template <typename T>
74
+ DINLINE T downcast_s(float val);
75
+ template <>
76
+ DINLINE half downcast_s(float val) {
77
+ return __float2half(val);
78
+ }
79
+
80
+ // scalar add functions
81
+ // for some reason when compiling with Pytorch, the + operator for half and
82
+ // bfloat is disabled so we call the intrinsics directly
83
+ DINLINE half &assign_add(half &a, half b) {
84
+ a = __hadd(a, b);
85
+ return a;
86
+ }
87
+ DINLINE float &assign_add(float &a, float b) { return a += b; }
88
+
89
+ #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
90
+ DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
91
+ template <>
92
+ DINLINE nv_bfloat16 downcast_s(float val) {
93
+ return __float2bfloat16(val);
94
+ }
95
+ DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
96
+ a = __hadd(a, b);
97
+ return a;
98
+ }
99
+ #endif
100
+
101
+ template <typename T, int N>
102
+ DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
103
+ #pragma unroll
104
+ for (int i = 0; i < N; i++) {
105
+ assign_add(a.data[i], b.data[i]);
106
+ }
107
+ return a;
108
+ }
109
+
110
+ template <typename T, int N>
111
+ DINLINE array_t<float, N> upcast(array_t<T, N> val) {
112
+ if constexpr (std::is_same<T, float>::value) {
113
+ return val;
114
+ } else {
115
+ array_t<float, N> out;
116
+ #pragma unroll
117
+ for (int i = 0; i < N; i++) {
118
+ out.data[i] = upcast_s(val.data[i]);
119
+ }
120
+ return out;
121
+ }
122
+ }
123
+
124
+ template <typename O>
125
+ DINLINE O downcast(array_t<float, O::size> val) {
126
+ if constexpr (std::is_same<typename O::type, float>::value) {
127
+ return val;
128
+ } else {
129
+ O out;
130
+ #pragma unroll
131
+ for (int i = 0; i < O::size; i++) {
132
+ out.data[i] = downcast_s<typename O::type>(val.data[i]);
133
+ }
134
+ return out;
135
+ }
136
+ }
137
+
138
+ // compute flag at compile time
139
+ __host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
140
+ auto m = std::numeric_limits<uint64_t>::max();
141
+ return m >> ((8 - ngpus) * 8);
142
+ }
143
+
144
+ template <int ngpus>
145
+ DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
146
+ int rank) {
147
+ constexpr auto FLAG = compute_flag(ngpus);
148
+ if (blockIdx.x == 0) {
149
+ if (threadIdx.x < ngpus)
150
+ // simultaneously write to the corresponding byte to all other ranks.
151
+ // Latency = 1 p2p write
152
+ sg.signals[threadIdx.x]->start.data[rank] = 255;
153
+ else if (threadIdx.x == 32)
154
+ // reset
155
+ meta->sg.end.flag = 0;
156
+ }
157
+ if (threadIdx.x == 0) {
158
+ while (meta->sg.start.flag != FLAG)
159
+ ;
160
+ }
161
+ __syncthreads();
162
+ }
163
+
164
+ template <int ngpus, bool final_sync = false>
165
+ DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
166
+ int rank) {
167
+ constexpr auto FLAG = compute_flag(ngpus);
168
+ __syncthreads();
169
+ __shared__ int num;
170
+ if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
171
+ __syncthreads();
172
+
173
+ // Only the last completing block can perform the end synchronization
174
+ // This can ensures when the final busy wait ends, all ranks must have
175
+ // finished reading each other's buffer.
176
+ if (num == gridDim.x - 1) {
177
+ if (threadIdx.x == 32) {
178
+ // reset in a different warp
179
+ meta->counter = 0;
180
+ meta->sg.start.flag = 0;
181
+ } else if (threadIdx.x < ngpus) {
182
+ // simultaneously write to the corresponding byte to all other ranks.
183
+ // Latency = 1 p2p write
184
+ sg.signals[threadIdx.x]->end.data[rank] = 255;
185
+ }
186
+ // if this is the final sync, only one block needs it
187
+ // because kernel exit can serve as sync
188
+ if constexpr (final_sync) {
189
+ if (threadIdx.x == 0) {
190
+ while (meta->sg.end.flag != FLAG)
191
+ ;
192
+ }
193
+ }
194
+ }
195
+ if constexpr (!final_sync) {
196
+ if (threadIdx.x == 0) {
197
+ while (meta->sg.end.flag != FLAG)
198
+ ;
199
+ }
200
+ __syncthreads();
201
+ }
202
+ }
203
+
204
+ template <typename P, int ngpus, typename A>
205
+ DINLINE P packed_reduce(const P *ptrs[], int idx) {
206
+ A tmp = upcast(ptrs[0][idx]);
207
+ #pragma unroll
208
+ for (int i = 1; i < ngpus; i++) {
209
+ packed_assign_add(tmp, upcast(ptrs[i][idx]));
210
+ }
211
+ return downcast<P>(tmp);
212
+ }
213
+
214
+ template <typename T, int ngpus>
215
+ __global__ void __launch_bounds__(512, 1)
216
+ cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
217
+ volatile Metadata *meta, T *__restrict__ result,
218
+ int rank, int size) {
219
+ using P = typename packed_t<T>::P;
220
+ using A = typename packed_t<T>::A;
221
+ // note: we don't reorder the address so the accumulation order is the same
222
+ // for all ranks, ensuring bitwise identical results
223
+ auto dp = *_dp;
224
+ start_sync<ngpus>(sg, meta, rank);
225
+ // do the actual reduction
226
+ for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
227
+ idx += gridDim.x * blockDim.x) {
228
+ ((P *)result)[idx] =
229
+ packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
230
+ }
231
+ end_sync<ngpus, true>(sg, meta, rank);
232
+ }
233
+
234
+ template <typename P>
235
+ DINLINE P *get_tmp_buf(volatile Signal *sg) {
236
+ return (P *)(((Metadata *)sg) + 1);
237
+ }
238
+
239
+ template <typename T, int ngpus>
240
+ __global__ void __launch_bounds__(512, 1)
241
+ cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
242
+ volatile Metadata *meta, T *__restrict__ result,
243
+ int rank, int size) {
244
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
245
+ int stride = gridDim.x * blockDim.x;
246
+ using P = typename packed_t<T>::P;
247
+ using A = typename packed_t<T>::A;
248
+ int part = size / ngpus;
249
+ int start = rank * part;
250
+ int end = rank == ngpus - 1 ? size : start + part;
251
+ const P *ptrs[ngpus];
252
+ P *tmps[ngpus];
253
+ #pragma unroll
254
+ for (int i = 0; i < ngpus; i++) {
255
+ int target = (rank + i) % ngpus;
256
+ ptrs[i] = (const P *)_dp->ptrs[target];
257
+ tmps[i] = get_tmp_buf<P>(sg.signals[target]);
258
+ }
259
+ auto tmp_out = tmps[0];
260
+ start_sync<ngpus>(sg, meta, rank);
261
+ // stage 1: reduce scatter
262
+ for (int idx = start + tid; idx < end; idx += stride) {
263
+ tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
264
+ }
265
+ // Maybe TODO: replace this with per-block release-acquire
266
+ // can save about 1-2us (not a lot though)
267
+ end_sync<ngpus>(sg, meta, rank);
268
+
269
+ // stage 2: allgather
270
+ for (int idx = tid; idx < part; idx += stride) {
271
+ #pragma unroll
272
+ for (int i = 0; i < ngpus; i++) {
273
+ int dst_idx = ((rank + i) % ngpus) * part + idx;
274
+ ((P *)result)[dst_idx] = tmps[i][idx];
275
+ }
276
+ }
277
+ // process the last larger partition
278
+ int remaining = size - part * ngpus;
279
+ if (tid < remaining) {
280
+ int dst_idx = tid + part * ngpus;
281
+ ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
282
+ }
283
+
284
+ // faster than this
285
+ // for (int idx = tid; idx < size; idx += stride) {
286
+ // int target_rank = idx / part;
287
+ // if (target_rank == ngpus) target_rank -= 1;
288
+ // ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
289
+ // }
290
+ }
291
+
292
+ template <typename T, int ngpus>
293
+ __global__ void __launch_bounds__(512, 1)
294
+ cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
295
+ volatile Metadata *meta,
296
+ T *__restrict__ result, int rank,
297
+ int size) {
298
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
299
+ int stride = gridDim.x * blockDim.x;
300
+ using P = typename packed_t<T>::P;
301
+ using A = typename packed_t<T>::A;
302
+ auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
303
+ constexpr int hg = ngpus / 2;
304
+ // Actually not quite half butterfly.
305
+ // This is an all-to-all within each group containing half of the ranks
306
+ // followed by cross-group add. Equivalent to half butterfly when there
307
+ // are 4 GPUs, a common case for PCIe cards like T4 and A10.
308
+ const P *ptrs[hg];
309
+ {
310
+ int start = rank - rank % hg;
311
+ #pragma unroll
312
+ for (int i = 0; i < hg; i++) {
313
+ ptrs[i] = (const P *)_dp->ptrs[i + start];
314
+ }
315
+ }
316
+ start_sync<ngpus>(sg, meta, rank);
317
+ for (int idx = tid; idx < size; idx += stride) {
318
+ tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
319
+ }
320
+ end_sync<ngpus>(sg, meta, rank);
321
+
322
+ auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
323
+ // do the cross group reduction
324
+ for (int idx = tid; idx < size; idx += stride) {
325
+ auto tmp = tmp_out[idx];
326
+ packed_assign_add(tmp, src[idx]);
327
+ ((P *)result)[idx] = tmp;
328
+ }
329
+ }
330
+
331
+ using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
332
+ static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
333
+ static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
334
+
335
+ class CustomAllreduce {
336
+ public:
337
+ int rank_;
338
+ int world_size_;
339
+ bool full_nvlink_;
340
+
341
+ // below are device pointers
342
+ RankSignals sg_;
343
+ std::unordered_map<void *, RankData *> buffers_;
344
+ Metadata *meta_;
345
+
346
+ // stores the registered device pointers from all ranks
347
+ RankData *d_rank_data_base_, *d_rank_data_end_;
348
+ std::vector<void *> graph_unreg_buffers_;
349
+ // a map from IPC handles to opened IPC pointers
350
+ std::map<IPC_KEY, char *> ipc_handles_;
351
+
352
+ /**
353
+ * meta is a pointer to device metadata and temporary buffer for allreduce.
354
+ *
355
+ * There's a total of sizeof(Metadata) of prefix before the actual data,
356
+ * so meta + 1 points to actual temporary buffer.
357
+ *
358
+ * note: this class does not own any device memory. Any required buffers
359
+ * are passed in from the constructor
360
+ */
361
+ CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
362
+ const cudaIpcMemHandle_t *handles,
363
+ const std::vector<int64_t> &offsets, int rank,
364
+ bool full_nvlink = true)
365
+ : rank_(rank),
366
+ world_size_(offsets.size()),
367
+ full_nvlink_(full_nvlink),
368
+ meta_(meta),
369
+ d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
370
+ d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
371
+ for (int i = 0; i < world_size_; i++) {
372
+ Metadata *rank_meta;
373
+ if (i != rank_) {
374
+ char *handle = open_ipc_handle(&handles[i]);
375
+ handle += offsets[i];
376
+ rank_meta = (Metadata *)handle;
377
+ } else {
378
+ rank_meta = meta_;
379
+ }
380
+ sg_.signals[i] = &rank_meta->sg;
381
+ }
382
+ }
383
+
384
+ char *open_ipc_handle(const void *ipc_handle) {
385
+ auto [it, new_handle] =
386
+ ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
387
+ if (new_handle) {
388
+ char *ipc_ptr;
389
+ CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
390
+ *((const cudaIpcMemHandle_t *)ipc_handle),
391
+ cudaIpcMemLazyEnablePeerAccess));
392
+ it->second = ipc_ptr;
393
+ }
394
+ return it->second;
395
+ }
396
+
397
+ std::pair<std::vector<uint8_t>, std::vector<int64_t>>
398
+ get_graph_buffer_ipc_meta() {
399
+ auto num_buffers = graph_unreg_buffers_.size();
400
+ auto handle_sz = sizeof(cudaIpcMemHandle_t);
401
+ std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
402
+ std::vector<int64_t> offsets(num_buffers);
403
+ for (int i = 0; i < num_buffers; i++) {
404
+ auto ptr = graph_unreg_buffers_[i];
405
+ void *base_ptr;
406
+ // note: must share the base address of each allocation, or we get wrong
407
+ // address
408
+ if (cuPointerGetAttribute(&base_ptr,
409
+ CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
410
+ (CUdeviceptr)ptr) != CUDA_SUCCESS)
411
+ throw std::runtime_error("failed to get pointer attr");
412
+ CUDACHECK(cudaIpcGetMemHandle(
413
+ (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
414
+ offsets[i] = ((char *)ptr) - ((char *)base_ptr);
415
+ }
416
+ return std::make_pair(handles, offsets);
417
+ }
418
+
419
+ void check_rank_data_capacity(size_t num = 1) {
420
+ if (d_rank_data_base_ + num > d_rank_data_end_)
421
+ throw std::runtime_error(
422
+ "Rank data buffer is overflowed by " +
423
+ std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
424
+ }
425
+
426
+ void register_buffer(const std::vector<std::string> &handles,
427
+ const std::vector<int64_t> &offsets, void *self) {
428
+ check_rank_data_capacity();
429
+ RankData data;
430
+ for (int i = 0; i < world_size_; i++) {
431
+ if (i != rank_) {
432
+ char *handle = open_ipc_handle(handles[i].data());
433
+ handle += offsets[i];
434
+ data.ptrs[i] = handle;
435
+ } else {
436
+ data.ptrs[i] = self;
437
+ }
438
+ }
439
+ auto d_data = d_rank_data_base_++;
440
+ CUDACHECK(
441
+ cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
442
+ buffers_[self] = d_data;
443
+ }
444
+
445
+ // note: when registering graph buffers, we intentionally choose to not
446
+ // deduplicate the addresses. That means if the allocator reuses some
447
+ // addresses, they will be registered again. This is to account for the remote
448
+ // possibility of different allocation patterns between ranks. For example,
449
+ // rank 1 may get the same input address for the second allreduce, but rank 2
450
+ // got a different address. IPC handles have internal reference counting
451
+ // mechanism so overhead should be small.
452
+ void register_graph_buffers(
453
+ const std::vector<std::string> &handles,
454
+ const std::vector<std::vector<int64_t>> &offsets) {
455
+ auto num_buffers = graph_unreg_buffers_.size();
456
+ check_rank_data_capacity(num_buffers);
457
+ std::vector<RankData> rank_data(num_buffers);
458
+ for (int i = 0; i < num_buffers; i++) {
459
+ auto self_ptr = graph_unreg_buffers_[i];
460
+ auto &rd = rank_data[i];
461
+ for (int j = 0; j < world_size_; j++) {
462
+ if (j != rank_) {
463
+ char *handle =
464
+ open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
465
+ handle += offsets[j][i];
466
+ rd.ptrs[j] = handle;
467
+ } else {
468
+ rd.ptrs[j] = self_ptr;
469
+ }
470
+ }
471
+ }
472
+ CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
473
+ sizeof(RankData) * num_buffers,
474
+ cudaMemcpyHostToDevice));
475
+ d_rank_data_base_ += num_buffers;
476
+ graph_unreg_buffers_.clear();
477
+ }
478
+
479
+ /**
480
+ * This is the result after careful grid search. Using 36 blocks give the best
481
+ * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
482
+ * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
483
+ * Not quite sure the underlying reason, but my guess is that too many SMs
484
+ * will cause contention on NVLink bus.
485
+ */
486
+ template <typename T>
487
+ void allreduce(cudaStream_t stream, T *input, T *output, int size,
488
+ int threads = 512, int block_limit = 36) {
489
+ auto d = packed_t<T>::P::size;
490
+ if (size % d != 0)
491
+ throw std::runtime_error(
492
+ "custom allreduce currently requires input length to be multiple "
493
+ "of " +
494
+ std::to_string(d));
495
+
496
+ RankData *ptrs;
497
+ cudaStreamCaptureStatus status;
498
+ CUDACHECK(cudaStreamIsCapturing(stream, &status));
499
+ if (status == cudaStreamCaptureStatusActive) {
500
+ ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
501
+ graph_unreg_buffers_.push_back(input);
502
+ } else {
503
+ auto it = buffers_.find(input);
504
+ if (it == buffers_.end())
505
+ throw std::runtime_error(
506
+ "buffer address " +
507
+ std::to_string(reinterpret_cast<uint64_t>(input)) +
508
+ " is not registered!");
509
+ ptrs = it->second;
510
+ }
511
+
512
+ size /= d;
513
+ auto bytes = size * sizeof(typename packed_t<T>::P);
514
+ int blocks = std::min(block_limit, (size + threads - 1) / threads);
515
+ #define KL(ngpus, name) \
516
+ name<T, ngpus> \
517
+ <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
518
+ #define REDUCE_CASE(ngpus) \
519
+ case ngpus: { \
520
+ if (world_size_ == 2) { \
521
+ KL(ngpus, cross_device_reduce_1stage); \
522
+ } else if (full_nvlink_) { \
523
+ if ((world_size_ <= 4 && bytes < 512 * 1024) || \
524
+ (world_size_ <= 8 && bytes < 256 * 1024)) { \
525
+ KL(ngpus, cross_device_reduce_1stage); \
526
+ } else { \
527
+ KL(ngpus, cross_device_reduce_2stage); \
528
+ } \
529
+ } else { \
530
+ KL(ngpus, cross_device_reduce_half_butterfly); \
531
+ } \
532
+ break; \
533
+ }
534
+
535
+ switch (world_size_) {
536
+ REDUCE_CASE(2)
537
+ REDUCE_CASE(4)
538
+ REDUCE_CASE(6)
539
+ REDUCE_CASE(8)
540
+ default:
541
+ throw std::runtime_error(
542
+ "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
543
+ "gpus = " +
544
+ std::to_string(world_size_));
545
+ }
546
+ #undef REDUCE_CASE
547
+ #undef KL
548
+ }
549
+
550
+ ~CustomAllreduce() {
551
+ for (auto [_, ptr] : ipc_handles_) {
552
+ CUDACHECK(cudaIpcCloseMemHandle(ptr));
553
+ }
554
+ }
555
+ };
556
+ /**
557
+ * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
558
+ a template instantiation:
559
+ * template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
560
+ int, int, int);
561
+ */
562
+ } // namespace vllm
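
compute_flag builds the value the busy-wait loops compare against: with each of the ngpus ranks writing 255 into its own byte of the 8-byte signal word, the all-arrived flag is just the low ngpus bytes set to 0xFF. A compile-time sanity check of that; compute_flag_ref is a host-side mirror written for this sketch:

#include <cstdint>
#include <limits>

constexpr uint64_t compute_flag_ref(int ngpus) {
  return std::numeric_limits<uint64_t>::max() >> ((8 - ngpus) * 8);
}

static_assert(compute_flag_ref(2) == 0xFFFFull);
static_assert(compute_flag_ref(4) == 0xFFFFFFFFull);
static_assert(compute_flag_ref(8) == 0xFFFFFFFFFFFFFFFFull);

int main() { return 0; }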
csrc/custom_all_reduce_test.cu ADDED
@@ -0,0 +1,284 @@
1
+ /**
2
+ * This is a standalone test for custom allreduce.
3
+ * To compile, make sure you have MPI and NCCL installed in your system.
4
+ * export MPI_HOME=XXX
5
+ * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
6
+ * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
7
+ *
8
+ * Warning: this C++ test is not designed to be very readable and was used
9
+ * during the rapid prototyping process.
10
+ *
11
+ * To run:
12
+ * mpirun -np 8 ./custom_all_reduce_test
13
+ */
14
+ #include <cuda.h>
15
+ #include <curand_kernel.h>
16
+ #include <stdio.h>
17
+ #include <stdlib.h>
18
+
19
+ #include <limits>
20
+ #include <vector>
21
+
22
+ #include "cuda_profiler_api.h"
23
+ #include "custom_all_reduce.cuh"
24
+ #include "mpi.h"
25
+ #include "nccl.h"
26
+
27
+ #define MPICHECK(cmd) \
28
+ do { \
29
+ int e = cmd; \
30
+ if (e != MPI_SUCCESS) { \
31
+ printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
32
+ exit(EXIT_FAILURE); \
33
+ } \
34
+ } while (0)
35
+
36
+ #define NCCLCHECK(cmd) \
37
+ do { \
38
+ ncclResult_t r = cmd; \
39
+ if (r != ncclSuccess) { \
40
+ printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
41
+ ncclGetErrorString(r)); \
42
+ exit(EXIT_FAILURE); \
43
+ } \
44
+ } while (0)
45
+
46
+ __global__ void dummy_kernel() {
47
+ for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
48
+ }
49
+
50
+ template <typename T>
51
+ __global__ void set_data(T *data, int size, int myRank) {
52
+ for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
53
+ idx += gridDim.x * blockDim.x) {
54
+ data[idx] = myRank * 0.11f;
55
+ }
56
+ }
57
+
58
+ template <typename T>
59
+ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
60
+ double *fdata2, int size) {
61
+ for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
62
+ idx += gridDim.x * blockDim.x) {
63
+ fdata1[idx] = data1[idx];
64
+ fdata2[idx] = data2[idx];
65
+ }
66
+ }
67
+
68
+ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
69
+ for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
70
+ idx += gridDim.x * blockDim.x) {
71
+ for (int i = 0; i < nRanks; i++) {
72
+ curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
73
+ }
74
+ }
75
+ }
76
+
77
+ template <typename T>
78
+ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
79
+ int myRank, int nRanks, int size) {
80
+ for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
81
+ idx += gridDim.x * blockDim.x) {
82
+ double sum = 0.0;
83
+ for (int i = 0; i < nRanks; i++) {
84
+ double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
85
+ T hval = val; // downcast first
86
+ sum += static_cast<double>(hval);
87
+ if (i == myRank) data[idx] = hval;
88
+ }
89
+ ground_truth[idx] = sum;
90
+ }
91
+ }
92
+
93
+ template <typename T>
94
+ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
95
+ int data_size) {
96
+ T *result;
97
+ cudaStream_t stream;
98
+ CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
99
+ CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
100
+ CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));
101
+
102
+ cudaIpcMemHandle_t self_data_handle;
103
+ cudaIpcMemHandle_t data_handles[8];
104
+ vllm::Metadata *buffer;
105
+ T *self_data_copy;
106
+ /**
107
+ * Allocate IPC buffer
108
+ *
109
+ * The first section is a temporary buffer for storing intermediate allreduce
110
+ * results, if a particular algorithm requires it. The second section is for
111
+ * the input to the allreduce. The actual API takes the input pointer as an
112
+ * argument (that is, they can and usually should be allocated separately).
113
+ * But since the input pointers and the temporary buffer all require IPC
114
+ * registration, they are allocated and registered together in the test for
115
+ * convenience.
116
+ */
117
+ CUDACHECK(
118
+ cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
119
+ CUDACHECK(cudaMemset(buffer, 0,
120
+ 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
121
+ CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
122
+ CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));
123
+
124
+ MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
125
+ MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
126
+ MPI_BYTE, MPI_COMM_WORLD));
127
+
128
+ void *rank_data;
129
+ size_t rank_data_sz = 16 * 1024 * 1024;
130
+ CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
131
+ std::vector<int64_t> offsets(nRanks, 0);
132
+ vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
133
+ offsets, myRank);
134
+ auto *self_data =
135
+ reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
136
+ sizeof(vllm::Metadata) + data_size * sizeof(T));
137
+ // hack buffer registration
138
+ {
139
+ std::vector<std::string> handles;
140
+ handles.reserve(nRanks);
141
+ for (int i = 0; i < nRanks; i++) {
142
+ char *begin = (char *)&data_handles[i];
143
+ char *end = (char *)&data_handles[i + 1];
144
+ handles.emplace_back(begin, end);
145
+ }
146
+ std::vector<int64_t> offsets(
147
+ nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T));
148
+ fa.register_buffer(handles, offsets, self_data);
149
+ }
150
+
151
+ double *ground_truth;
152
+ CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
153
+ curandState_t *states;
154
+ CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
155
+ init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
156
+ gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
157
+ nRanks, data_size);
158
+ CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
159
+ cudaMemcpyDeviceToDevice, stream));
160
+ cudaEvent_t start, stop;
161
+ CUDACHECK(cudaEventCreate(&start));
162
+ CUDACHECK(cudaEventCreate(&stop));
163
+
164
+ ncclDataType_t ncclDtype;
165
+ if (std::is_same<T, half>::value) {
166
+ ncclDtype = ncclFloat16;
167
+ } else if (std::is_same<T, nv_bfloat16>::value) {
168
+ ncclDtype = ncclBfloat16;
169
+ } else {
170
+ ncclDtype = ncclFloat;
171
+ }
172
+
173
+ dummy_kernel<<<1, 1, 0, stream>>>();
174
+ constexpr int warmup_iters = 5;
175
+ constexpr int num_iters = 25;
176
+ // warmup
177
+ for (int i = 0; i < warmup_iters; i++) {
178
+ NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
179
+ stream));
180
+ }
181
+ CUDACHECK(cudaEventRecord(start, stream));
182
+ for (int i = 0; i < num_iters; i++) {
183
+ NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
184
+ stream));
185
+ }
186
+ CUDACHECK(cudaEventRecord(stop, stream));
187
+ CUDACHECK(cudaStreamSynchronize(stream));
188
+ float allreduce_ms = 0;
189
+ cudaEventElapsedTime(&allreduce_ms, start, stop);
190
+
191
+ // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
192
+ // set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);
193
+
194
+ dummy_kernel<<<1, 1, 0, stream>>>();
195
+ // warm up
196
+ for (int i = 0; i < warmup_iters; i++) {
197
+ fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
198
+ }
199
+ CUDACHECK(cudaEventRecord(start, stream));
200
+ for (int i = 0; i < num_iters; i++) {
201
+ fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
202
+ }
203
+ CUDACHECK(cudaEventRecord(stop, stream));
204
+ CUDACHECK(cudaStreamSynchronize(stream));
205
+
206
+ float duration_ms = 0;
207
+ cudaEventElapsedTime(&duration_ms, start, stop);
208
+ if (myRank == 0)
209
+ printf(
210
+ "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
211
+ "time:%.2fus\n",
212
+ myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
213
+ duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
214
+
215
+ // And wait for all the queued up work to complete
216
+ CUDACHECK(cudaStreamSynchronize(stream));
217
+
218
+ NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
219
+ ncclSum, comm, stream));
220
+
221
+ double *nccl_result, *my_result;
222
+ CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
223
+ CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));
224
+
225
+ convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
226
+ my_result, data_size);
227
+ CUDACHECK(cudaStreamSynchronize(stream));
228
+
229
+ for (unsigned long j = 0; j < data_size; j++) {
230
+ auto diff = abs(nccl_result[j] - my_result[j]);
231
+ if (diff >= 1e-2) {
232
+ printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
233
+ myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
234
+ break;
235
+ }
236
+ }
237
+
238
+ long double nccl_diffs = 0.0;
239
+ long double my_diffs = 0.0;
240
+ for (int j = 0; j < data_size; j++) {
241
+ nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
242
+ my_diffs += abs(my_result[j] - ground_truth[j]);
243
+ }
244
+ if (myRank == 0)
245
+ std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
246
+ << " me: " << my_diffs / data_size << std::endl;
247
+
248
+ CUDACHECK(cudaFree(result));
249
+ CUDACHECK(cudaFree(self_data_copy));
250
+ CUDACHECK(cudaFree(rank_data));
251
+ CUDACHECK(cudaFree(buffer));
252
+ CUDACHECK(cudaFree(states));
253
+ CUDACHECK(cudaFreeHost(ground_truth));
254
+ CUDACHECK(cudaFreeHost(nccl_result));
255
+ CUDACHECK(cudaFreeHost(my_result));
256
+ CUDACHECK(cudaStreamDestroy(stream));
257
+ }
258
+
259
+ int main(int argc, char **argv) {
260
+ int nRanks, myRank;
261
+ MPICHECK(MPI_Init(&argc, &argv));
262
+ MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
263
+ MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
264
+ CUDACHECK(cudaSetDevice(myRank));
265
+ ncclUniqueId id;
266
+ ncclComm_t comm;
267
+ if (myRank == 0) ncclGetUniqueId(&id);
268
+ MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
269
+ MPI_COMM_WORLD));
270
+ NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
271
+
272
+ cudaProfilerStart();
273
+ // for (int threads : {256, 512}) {
274
+ // for (int block_limit = 16; block_limit < 112; block_limit += 4) {
275
+ // run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
276
+ // }
277
+ // }
278
+ for (int sz = 512; sz <= (32 << 20); sz *= 2) {
279
+ run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
280
+ }
281
+
282
+ cudaProfilerStop();
283
+ return EXIT_SUCCESS;
284
+ }
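Note on the IPC exchange in this test: each rank allocates one buffer, publishes a cudaIpcMemHandle_t for it, and MPI_Allgather distributes every handle to every rank. Opening the peers' handles is done inside vllm::CustomAllreduce rather than in the test itself. The snippet below is only an illustrative sketch of that opening step (the function name is hypothetical; the real registration logic lives in custom_all_reduce.cuh):

#include <cuda_runtime.h>

// Maps a peer rank's allocation into this process's address space.
// A rank must not open its own handle; it already owns that pointer.
void* open_peer_buffer(cudaIpcMemHandle_t handle, int peer, int my_rank) {
  if (peer == my_rank) return nullptr;
  void* peer_ptr = nullptr;
  cudaIpcOpenMemHandle(&peer_ptr, handle, cudaIpcMemLazyEnablePeerAccess);
  return peer_ptr;  // released later with cudaIpcCloseMemHandle(peer_ptr)
}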
csrc/dispatch_utils.h ADDED
@@ -0,0 +1,37 @@
1
+ /*
2
+ * Adapted from
3
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
4
+ */
5
+ #pragma once
6
+
7
+ #include <torch/extension.h>
8
+
9
+ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
10
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
11
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
12
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
13
+
14
+ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
15
+ AT_DISPATCH_SWITCH( \
16
+ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
17
+
18
+ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
19
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
20
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
21
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
22
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
23
+
24
+ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
25
+ AT_DISPATCH_SWITCH( \
26
+ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
27
+
28
+ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
29
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
30
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
31
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
32
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
33
+ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
34
+
35
+ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
36
+ AT_DISPATCH_SWITCH( \
37
+ TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
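These macros wrap ATen's AT_DISPATCH_SWITCH/AT_DISPATCH_CASE: given a tensor's runtime scalar type, they instantiate the lambda once with scalar_t bound to the matching C++ type (float, at::Half, at::BFloat16, and so on). The kernels in the following files use them directly; as a standalone illustration, here is a usage sketch with a hypothetical scale_kernel that is not part of vLLM:

#include <torch/extension.h>
#include "dispatch_utils.h"

template <typename scalar_t>
__global__ void scale_kernel(scalar_t* x, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] = (scalar_t)((float)x[i] * s);
}

void scale(torch::Tensor& x, float s) {
  int n = x.numel();
  VLLM_DISPATCH_FLOATING_TYPES(x.scalar_type(), "scale_kernel", [&] {
    // Inside the lambda, scalar_t is the C++ type matching x.scalar_type().
    scale_kernel<scalar_t><<<(n + 255) / 256, 256>>>(x.data_ptr<scalar_t>(), s, n);
  });
}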
csrc/layernorm_kernels.cu ADDED
@@ -0,0 +1,120 @@
1
+ #include <torch/extension.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ #include "dispatch_utils.h"
6
+ #include "reduction_utils.cuh"
7
+
8
+ namespace vllm {
9
+
10
+ // TODO(woosuk): Further optimize this kernel.
11
+ template<typename scalar_t>
12
+ __global__ void rms_norm_kernel(
13
+ scalar_t* __restrict__ out, // [..., hidden_size]
14
+ const scalar_t* __restrict__ input, // [..., hidden_size]
15
+ const scalar_t* __restrict__ weight, // [hidden_size]
16
+ const float epsilon,
17
+ const int num_tokens,
18
+ const int hidden_size) {
19
+ __shared__ float s_variance;
20
+ float variance = 0.0f;
21
+
22
+ for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
23
+ const float x = (float) input[blockIdx.x * hidden_size + idx];
24
+ variance += x * x;
25
+ }
26
+ variance = blockReduceSum<float>(variance);
27
+ if (threadIdx.x == 0) {
28
+ s_variance = rsqrtf(variance / hidden_size + epsilon);
29
+ }
30
+ __syncthreads();
31
+
32
+ for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
33
+ float x = (float) input[blockIdx.x * hidden_size + idx];
34
+ out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
35
+ }
36
+ }
37
+
38
+ // TODO: Further optimize this kernel.
39
+ template<typename scalar_t>
40
+ __global__ void fused_add_rms_norm_kernel(
41
+ scalar_t* __restrict__ input, // [..., hidden_size]
42
+ scalar_t* __restrict__ residual, // [..., hidden_size]
43
+ const scalar_t* __restrict__ weight, // [hidden_size]
44
+ const float epsilon,
45
+ const int num_tokens,
46
+ const int hidden_size) {
47
+ __shared__ float s_variance;
48
+ float variance = 0.0f;
49
+
50
+ for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
51
+ float x = (float) input[blockIdx.x * hidden_size + idx];
52
+ x += (float) residual[blockIdx.x * hidden_size + idx];
53
+ variance += x * x;
54
+ residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
55
+ }
56
+ variance = blockReduceSum<float>(variance);
57
+ if (threadIdx.x == 0) {
58
+ s_variance = rsqrtf(variance / hidden_size + epsilon);
59
+ }
60
+ __syncthreads();
61
+
62
+ for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
63
+ float x = (float) residual[blockIdx.x * hidden_size + idx];
64
+ input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
65
+ }
66
+ }
67
+
68
+ } // namespace vllm
69
+
70
+ void rms_norm(
71
+ torch::Tensor& out, // [..., hidden_size]
72
+ torch::Tensor& input, // [..., hidden_size]
73
+ torch::Tensor& weight, // [hidden_size]
74
+ float epsilon) {
75
+ int hidden_size = input.size(-1);
76
+ int num_tokens = input.numel() / hidden_size;
77
+
78
+ dim3 grid(num_tokens);
79
+ dim3 block(std::min(hidden_size, 1024));
80
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
81
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
82
+ VLLM_DISPATCH_FLOATING_TYPES(
83
+ input.scalar_type(),
84
+ "rms_norm_kernel",
85
+ [&] {
86
+ vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
87
+ out.data_ptr<scalar_t>(),
88
+ input.data_ptr<scalar_t>(),
89
+ weight.data_ptr<scalar_t>(),
90
+ epsilon,
91
+ num_tokens,
92
+ hidden_size);
93
+ });
94
+ }
95
+
96
+ void fused_add_rms_norm(
97
+ torch::Tensor& input, // [..., hidden_size]
98
+ torch::Tensor& residual, // [..., hidden_size]
99
+ torch::Tensor& weight, // [hidden_size]
100
+ float epsilon) {
101
+ int hidden_size = input.size(-1);
102
+ int num_tokens = input.numel() / hidden_size;
103
+
104
+ dim3 grid(num_tokens);
105
+ dim3 block(std::min(hidden_size, 1024));
106
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
107
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
108
+ VLLM_DISPATCH_FLOATING_TYPES(
109
+ input.scalar_type(),
110
+ "fused_add_rms_norm_kernel",
111
+ [&] {
112
+ vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
113
+ input.data_ptr<scalar_t>(),
114
+ residual.data_ptr<scalar_t>(),
115
+ weight.data_ptr<scalar_t>(),
116
+ epsilon,
117
+ num_tokens,
118
+ hidden_size);
119
+ });
120
+ }
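For reference, the math both kernels implement (the fused variant first adds the residual and writes the sum back into residual) reduces to out = x * rsqrt(mean(x^2) + eps) * weight, accumulated in fp32. A host-side sketch for a single token:

#include <cmath>
#include <vector>

std::vector<float> rms_norm_ref(const std::vector<float>& x,
                                const std::vector<float>& w, float eps) {
  double sum_sq = 0.0;
  for (float v : x) sum_sq += double(v) * v;
  const float inv_rms = 1.0f / std::sqrt(float(sum_sq / x.size()) + eps);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] * inv_rms * w[i];
  return out;
}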
csrc/moe_align_block_size_kernels.cu ADDED
@@ -0,0 +1,108 @@
1
+ #include <torch/extension.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+
4
+ #include <ATen/ATen.h>
5
+ #include <THC/THCAtomics.cuh>
6
+
7
+ #include "cuda_compat.h"
8
+ #include "dispatch_utils.h"
9
+
10
+ const static size_t NUM_MAX_EXPERTS = 64;
11
+ #define CEILDIV(x,y) (((x) + (y) - 1) / (y))
12
+
13
+ namespace vllm {
14
+ template <typename scalar_t>
15
+ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
16
+ int32_t *sorted_token_ids,
17
+ int32_t *expert_ids,
18
+ int32_t *total_tokens_post_pad,
19
+ int32_t num_experts,
20
+ int32_t block_size,
21
+ size_t numel) {
22
+ const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
23
+ const size_t start_idx = threadIdx.x * tokens_per_thread;
24
+ __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
25
+ __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
26
+ for (int i = 0; i < num_experts; ++i) {
27
+ tokens_cnts[threadIdx.x + 1][i] = 0;
28
+ }
29
+
30
+ /**
31
+ * In the first step we compute token_cnts[thread_index + 1][expert_index],
32
+ * which counts how many tokens in the token shard of thread_index are assigned
33
+ * to expert expert_index.
34
+ */
35
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
36
+ ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
37
+ }
38
+
39
+ __syncthreads();
40
+
41
+ // For each expert we accumulate the token counts from the different threads.
42
+ tokens_cnts[0][threadIdx.x] = 0;
43
+ for (int i = 1; i <= blockDim.x; ++i) {
44
+ tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
45
+ }
46
+
47
+ __syncthreads();
48
+
49
+ // We accumulate the token counts of all experts in thread 0.
50
+ if (threadIdx.x == 0) {
51
+ cumsum[0] = 0;
52
+ for (int i = 1; i <= num_experts; ++i) {
53
+ cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
54
+ }
55
+ *total_tokens_post_pad = cumsum[num_experts];
56
+ }
57
+
58
+ __syncthreads();
59
+
60
+ /**
61
+ * For each expert, each thread processes the tokens of the corresponding blocks
62
+ * and stores the corresponding expert_id for each block.
63
+ */
64
+ for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
65
+ expert_ids[i / block_size] = threadIdx.x;
66
+ }
67
+
68
+ /**
69
+ * Each thread processes a token shard, calculating the index of each token after
70
+ * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
71
+ * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
72
+ * where * represents a padding value (preset in Python).
73
+ */
74
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
75
+ int32_t expert_id = topk_ids[i];
76
+ /** The cumsum[expert_id] stores the starting index of the tokens that the
77
+ * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
78
+ * stores the indices of the tokens processed by the expert with expert_id within
79
+ * the current thread's token shard.
80
+ */
81
+ int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
82
+ sorted_token_ids[rank_post_pad] = i;
83
+ ++tokens_cnts[threadIdx.x][expert_id];
84
+ }
85
+ }
86
+ }
87
+
88
+ void moe_align_block_size(
89
+ torch::Tensor topk_ids,
90
+ int num_experts,
91
+ int block_size,
92
+ torch::Tensor sorted_token_ids,
93
+ torch::Tensor experts_ids,
94
+ torch::Tensor num_tokens_post_pad) {
95
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
96
+ assert(num_experts <= NUM_MAX_EXPERTS);
97
+ VLLM_DISPATCH_INTEGRAL_TYPES(
98
+ topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
99
+ vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
100
+ topk_ids.data_ptr<scalar_t>(),
101
+ sorted_token_ids.data_ptr<int32_t>(),
102
+ experts_ids.data_ptr<int32_t>(),
103
+ num_tokens_post_pad.data_ptr<int32_t>(),
104
+ num_experts,
105
+ block_size,
106
+ topk_ids.numel());
107
+ });
108
+ }
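A host-side reference of the same alignment logic may help when reading the kernel. With the example from the comment above (topk_ids = {0,1,2,1,2,3,0,3,4}, num_experts = 5, block_size = 4) it produces sorted_token_ids = [0,6,-,-, 1,3,-,-, 2,4,-,-, 5,7,-,-, 8,-,-,-]; padding slots are left at -1 here, whereas the real op presets them from Python:

#include <vector>

void moe_align_ref(const std::vector<int>& topk_ids, int num_experts,
                   int block_size, std::vector<int>& sorted_token_ids,
                   std::vector<int>& expert_ids, int& total_post_pad) {
  auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };

  // Per-expert token counts.
  std::vector<int> cnt(num_experts, 0);
  for (int e : topk_ids) ++cnt[e];

  // Each expert's segment is rounded up to a multiple of block_size.
  std::vector<int> cumsum(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e)
    cumsum[e + 1] = cumsum[e] + ceildiv(cnt[e], block_size) * block_size;
  total_post_pad = cumsum[num_experts];

  // Expert id owning each block, then token indices scattered by expert.
  sorted_token_ids.assign(total_post_pad, -1);
  expert_ids.assign(total_post_pad / block_size, 0);
  for (int e = 0; e < num_experts; ++e)
    for (int b = cumsum[e]; b < cumsum[e + 1]; b += block_size)
      expert_ids[b / block_size] = e;

  std::vector<int> fill = cumsum;  // next free slot per expert
  for (int i = 0; i < (int)topk_ids.size(); ++i)
    sorted_token_ids[fill[topk_ids[i]]++] = i;
}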
csrc/ops.h ADDED
@@ -0,0 +1,130 @@
1
+ #pragma once
2
+
3
+ #include <torch/extension.h>
4
+
5
+ void paged_attention_v1(
6
+ torch::Tensor& out,
7
+ torch::Tensor& query,
8
+ torch::Tensor& key_cache,
9
+ torch::Tensor& value_cache,
10
+ int num_kv_heads,
11
+ float scale,
12
+ torch::Tensor& block_tables,
13
+ torch::Tensor& context_lens,
14
+ int block_size,
15
+ int max_context_len,
16
+ const c10::optional<torch::Tensor>& alibi_slopes,
17
+ const std::string& kv_cache_dtype);
18
+
19
+ void paged_attention_v2(
20
+ torch::Tensor& out,
21
+ torch::Tensor& exp_sums,
22
+ torch::Tensor& max_logits,
23
+ torch::Tensor& tmp_out,
24
+ torch::Tensor& query,
25
+ torch::Tensor& key_cache,
26
+ torch::Tensor& value_cache,
27
+ int num_kv_heads,
28
+ float scale,
29
+ torch::Tensor& block_tables,
30
+ torch::Tensor& context_lens,
31
+ int block_size,
32
+ int max_context_len,
33
+ const c10::optional<torch::Tensor>& alibi_slopes,
34
+ const std::string& kv_cache_dtype);
35
+
36
+ void rms_norm(
37
+ torch::Tensor& out,
38
+ torch::Tensor& input,
39
+ torch::Tensor& weight,
40
+ float epsilon);
41
+
42
+ void fused_add_rms_norm(
43
+ torch::Tensor& input,
44
+ torch::Tensor& residual,
45
+ torch::Tensor& weight,
46
+ float epsilon);
47
+
48
+ void rotary_embedding(
49
+ torch::Tensor& positions,
50
+ torch::Tensor& query,
51
+ torch::Tensor& key,
52
+ int head_size,
53
+ torch::Tensor& cos_sin_cache,
54
+ bool is_neox);
55
+
56
+ void silu_and_mul(
57
+ torch::Tensor& out,
58
+ torch::Tensor& input);
59
+
60
+ void gelu_new(
61
+ torch::Tensor& out,
62
+ torch::Tensor& input);
63
+
64
+ void gelu_fast(
65
+ torch::Tensor& out,
66
+ torch::Tensor& input);
67
+
68
+ #ifndef USE_ROCM
69
+ torch::Tensor awq_gemm(
70
+ torch::Tensor _in_feats,
71
+ torch::Tensor _kernel,
72
+ torch::Tensor _scaling_factors,
73
+ torch::Tensor _zeros,
74
+ int split_k_iters);
75
+
76
+ torch::Tensor awq_dequantize(
77
+ torch::Tensor _kernel,
78
+ torch::Tensor _scaling_factors,
79
+ torch::Tensor _zeros,
80
+ int split_k_iters,
81
+ int thx,
82
+ int thy);
83
+ #endif
84
+
85
+ void squeezellm_gemm(
86
+ torch::Tensor vec,
87
+ torch::Tensor mat,
88
+ torch::Tensor mul,
89
+ torch::Tensor lookup_table);
90
+
91
+ torch::Tensor gptq_gemm(
92
+ torch::Tensor a,
93
+ torch::Tensor b_q_weight,
94
+ torch::Tensor b_gptq_qzeros,
95
+ torch::Tensor b_gptq_scales,
96
+ torch::Tensor b_g_idx,
97
+ bool use_exllama);
98
+
99
+ void gptq_shuffle(
100
+ torch::Tensor q_weight,
101
+ torch::Tensor q_perm);
102
+
103
+ void moe_align_block_size(
104
+ torch::Tensor topk_ids,
105
+ int num_experts,
106
+ int block_size,
107
+ torch::Tensor sorted_token_ids,
108
+ torch::Tensor experts_ids,
109
+ torch::Tensor num_tokens_post_pad);
110
+
111
+ #ifndef USE_ROCM
112
+ using fptr_t = uint64_t;
113
+ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
114
+ const std::vector<std::string> &handles,
115
+ const std::vector<int64_t> &offsets, int rank,
116
+ bool full_nvlink);
117
+ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
118
+ bool full_nvlink);
119
+ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
120
+ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
121
+ torch::Tensor &out);
122
+ void dispose(fptr_t _fa);
123
+ int meta_size();
124
+ void register_buffer(fptr_t _fa, torch::Tensor &t,
125
+ const std::vector<std::string> &handles,
126
+ const std::vector<int64_t> &offsets);
127
+ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
128
+ void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
129
+ const std::vector<std::vector<int64_t>> &offsets);
130
+ #endif
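These declarations are exposed to Python through a separate pybind11 translation unit that is not shown in this part of the diff. The snippet below is only a sketch of that standard torch-extension pattern; the selection of ops and the doc strings are illustrative, not the actual vLLM binding file:

#include <torch/extension.h>
#include "ops.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rms_norm", &rms_norm, "Apply RMS normalization to the input");
  m.def("fused_add_rms_norm", &fused_add_rms_norm, "Fused residual add + RMS norm");
  m.def("rotary_embedding", &rotary_embedding, "Apply rotary positional embeddings");
  m.def("moe_align_block_size", &moe_align_block_size,
        "Align expert token counts to a block size for fused MoE");
}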
csrc/pos_encoding_kernels.cu ADDED
@@ -0,0 +1,130 @@
1
+ #include <torch/extension.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+
5
+ #include "cuda_compat.h"
6
+ #include "dispatch_utils.h"
7
+
8
+ namespace vllm {
9
+
10
+ template<typename scalar_t, bool IS_NEOX>
11
+ inline __device__ void apply_rotary_embedding(
12
+ scalar_t* __restrict__ arr,
13
+ const scalar_t* __restrict__ cos_ptr,
14
+ const scalar_t* __restrict__ sin_ptr,
15
+ int rot_offset,
16
+ int embed_dim)
17
+ {
18
+ int x_index, y_index;
19
+ scalar_t cos, sin;
20
+ if (IS_NEOX) {
21
+ // GPT-NeoX style rotary embedding.
22
+ x_index = rot_offset;
23
+ y_index = embed_dim + rot_offset;
24
+ cos = VLLM_LDG(cos_ptr + x_index);
25
+ sin = VLLM_LDG(sin_ptr + x_index);
26
+ } else {
27
+ // GPT-J style rotary embedding.
28
+ x_index = 2 * rot_offset;
29
+ y_index = 2 * rot_offset + 1;
30
+ cos = VLLM_LDG(cos_ptr + x_index / 2);
31
+ sin = VLLM_LDG(sin_ptr + x_index / 2);
32
+ }
33
+
34
+ const scalar_t x = arr[x_index];
35
+ const scalar_t y = arr[y_index];
36
+ arr[x_index] = x * cos - y * sin;
37
+ arr[y_index] = y * cos + x * sin;
38
+ }
39
+
40
+ template<typename scalar_t, bool IS_NEOX>
41
+ __global__ void rotary_embedding_kernel(
42
+ const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
43
+ scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
44
+ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
45
+ const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
46
+ const int rot_dim,
47
+ const int64_t query_stride,
48
+ const int64_t key_stride,
49
+ const int num_heads,
50
+ const int num_kv_heads,
51
+ const int head_size) {
52
+ // Each thread block is responsible for one token.
53
+ const int token_idx = blockIdx.x;
54
+ int64_t pos = positions[token_idx];
55
+ const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
56
+
57
+ const int embed_dim = rot_dim / 2;
58
+ const scalar_t* cos_ptr = cache_ptr;
59
+ const scalar_t* sin_ptr = cache_ptr + embed_dim;
60
+
61
+ const int nq = num_heads * embed_dim;
62
+ for (int i = threadIdx.x; i < nq; i += blockDim.x) {
63
+ const int head_idx = i / embed_dim;
64
+ const int64_t token_head = token_idx * query_stride + head_idx * head_size;
65
+ const int rot_offset = i % embed_dim;
66
+ apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
67
+ sin_ptr, rot_offset, embed_dim);
68
+ }
69
+
70
+ const int nk = num_kv_heads * embed_dim;
71
+ for (int i = threadIdx.x; i < nk; i += blockDim.x) {
72
+ const int head_idx = i / embed_dim;
73
+ const int64_t token_head = token_idx * key_stride + head_idx * head_size;
74
+ const int rot_offset = i % embed_dim;
75
+ apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
76
+ sin_ptr, rot_offset, embed_dim);
77
+ }
78
+ }
79
+
80
+ } // namespace vllm
81
+
82
+ void rotary_embedding(
83
+ torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
84
+ torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
85
+ torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
86
+ int head_size,
87
+ torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
88
+ bool is_neox) {
89
+ int64_t num_tokens = query.numel() / query.size(-1);
90
+ int rot_dim = cos_sin_cache.size(1);
91
+ int num_heads = query.size(-1) / head_size;
92
+ int num_kv_heads = key.size(-1) / head_size;
93
+ int64_t query_stride = query.stride(-2);
94
+ int64_t key_stride = key.stride(-2);
95
+
96
+ dim3 grid(num_tokens);
97
+ dim3 block(std::min(num_heads * rot_dim / 2, 512));
98
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
99
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
100
+ VLLM_DISPATCH_FLOATING_TYPES(
101
+ query.scalar_type(),
102
+ "rotary_embedding",
103
+ [&] {
104
+ if (is_neox) {
105
+ vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
106
+ positions.data_ptr<int64_t>(),
107
+ query.data_ptr<scalar_t>(),
108
+ key.data_ptr<scalar_t>(),
109
+ cos_sin_cache.data_ptr<scalar_t>(),
110
+ rot_dim,
111
+ query_stride,
112
+ key_stride,
113
+ num_heads,
114
+ num_kv_heads,
115
+ head_size);
116
+ } else {
117
+ vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
118
+ positions.data_ptr<int64_t>(),
119
+ query.data_ptr<scalar_t>(),
120
+ key.data_ptr<scalar_t>(),
121
+ cos_sin_cache.data_ptr<scalar_t>(),
122
+ rot_dim,
123
+ query_stride,
124
+ key_stride,
125
+ num_heads,
126
+ num_kv_heads,
127
+ head_size);
128
+ }
129
+ });
130
+ }
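For clarity, a host-side sketch of what apply_rotary_embedding does for one head: the NeoX path rotates element pairs split across the two halves of the rotary dimension, (i, i + embed_dim), while the GPT-J path rotates adjacent pairs (2i, 2i + 1); both index cos/sin by the pair index i:

#include <vector>

void rotary_ref(std::vector<float>& head, const std::vector<float>& cos_v,
                const std::vector<float>& sin_v, int embed_dim, bool is_neox) {
  for (int i = 0; i < embed_dim; ++i) {
    const int xi = is_neox ? i : 2 * i;
    const int yi = is_neox ? embed_dim + i : 2 * i + 1;
    const float x = head[xi];
    const float y = head[yi];
    head[xi] = x * cos_v[i] - y * sin_v[i];
    head[yi] = y * cos_v[i] + x * sin_v[i];
  }
}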
csrc/punica/LICENSE ADDED
@@ -0,0 +1,217 @@
1
+ Contains code from https://github.com/punica-ai/punica
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "{}"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright {yyyy} {name of copyright owner}
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
204
+
205
+ ------------------------------------------------------------------------------------
206
+
207
+ This product bundles various third-party components under other open source licenses.
208
+ This section summarizes those components and their licenses. See licenses/
209
+ for text of these licenses.
210
+
211
+
212
+ Apache-2.0
213
+ * third_party/nvbench (with LLVM exception)
214
+ * third_party/flashinfer
215
+
216
+ BSD-3-Clause:
217
+ * third_party/cutlass
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu ADDED
@@ -0,0 +1,4 @@
1
+ #include "bgmv_config.h"
2
+ #include "bgmv_impl.cuh"
3
+
4
+ FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)